Merge branch 'master' into sparse-xent-op-hessian

MichaelKonobeev 2020-02-06 20:10:56 -07:00
commit f0df47fddf
2291 changed files with 84192 additions and 26605 deletions
.bazelrc
RELEASE.md
SECURITY.md
tensorflow/c
tensorflow/cc
tensorflow/compiler


@ -69,6 +69,7 @@
# rbe_linux_py3: Linux Python 3 RBE config
#
# rbe_win_py37: Windows Python 3.7 RBE config
# rbe_win_py38: Windows Python 3.8 RBE config
#
# tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
# tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
@ -392,6 +393,7 @@ build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
build:rbe_win --define=override_eigen_strong_inline=true
build:rbe_win --jobs=500
build:rbe_win_py37 --config=rbe
build:rbe_win_py37 --repo_env=PYTHON_BIN_PATH=C:\\Python37\\python.exe
@ -399,6 +401,12 @@ build:rbe_win_py37 --repo_env=PYTHON_LIB_PATH=C:\\Python37\\lib\\site-packages
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py37
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe
build:rbe_win_py38 --config=rbe
build:rbe_win_py38 --repo_env=PYTHON_BIN_PATH=C:\\Python38\\python.exe
build:rbe_win_py38 --repo_env=PYTHON_LIB_PATH=C:\\Python38\\lib\\site-packages
build:rbe_win_py38 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py38
build:rbe_win_py38 --python_path=C:\\Python38\\python.exe
# These you may need to change for your own GCP project.
build:tensorflow_testing_rbe --project_id=tensorflow-testing
common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance


@ -1,3 +1,19 @@
# Release 2.0.1
## Bug Fixes and Other Changes
* Fixes a security vulnerability where converting a Python string to a `tf.float16` value produces a segmentation fault ([CVE-2020-5215](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-5215))
* Updates `curl` to `7.66.0` to handle [CVE-2019-5482](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-5482) and [CVE-2019-5481](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-5481)
* Updates `sqlite3` to `3.30.01` to handle [CVE-2019-19646](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19646), [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) and [CVE-2019-16168](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-16168)
# Release 1.15.2
## Bug Fixes and Other Changes
* Fixes a security vulnerability where converting a Python string to a `tf.float16` value produces a segmentation fault ([CVE-2020-5215](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-5215))
* Updates `curl` to `7.66.0` to handle [CVE-2019-5482](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-5482) and [CVE-2019-5481](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-5481)
* Updates `sqlite3` to `3.30.01` to handle [CVE-2019-19646](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19646), [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) and [CVE-2019-16168](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-16168)
# Release 2.1.0
TensorFlow 2.1 will be the last TF release supporting Python 2. Python 2 support [officially ends on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update). [As announced earlier](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ), TensorFlow will also stop supporting Python 2 starting January 1, 2020, and no more releases are expected in 2019.


@ -245,4 +245,4 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
### Known Vulnerabilities
For a list of known vulnerabilities and security advisories for TensorFlow,
[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md).
[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/README.md).


@ -54,9 +54,10 @@ filegroup(
)
filegroup(
name = "pywrap_eager_hdrs",
name = "pywrap_required_hdrs",
srcs = [
"c_api_internal.h",
"python_api.h",
"tf_status_helper.h",
"tf_status_internal.h",
"tf_tensor_internal.h",
@ -98,6 +99,17 @@ tf_cuda_library(
],
)
filegroup(
name = "pywrap_tf_session_hdrs",
srcs = [
"python_api.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "tf_attrtype",
hdrs = ["tf_attrtype.h"],
@ -536,6 +548,7 @@ tf_cc_test(
"//tensorflow:macos": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
tags = ["notsan"], # b/149031034
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
@ -640,7 +653,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/kernels:ops_testutil",
"//third_party/eigen3",
"@com_google_absl//absl/container:inlined_vector",
],
)


@ -519,72 +519,6 @@ TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
}
void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
auto* status = TF_NewStatus();
TF_Tensor* t = TFE_TensorHandleResolve(handle, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow::Tensor dst;
TF_CHECK_OK(TF_TensorToTensor(t, &dst));
LOG(INFO) << dst.DebugString();
TF_DeleteTensor(t);
TF_DeleteStatus(status);
}
void TFE_OpPrintDebugString(TFE_Op* op) {
VLOG(1) << "TFE_OpPrintDebugString() over " << op;
LOG(INFO) << op->operation.DebugString();
}
struct TFE_ExecuteOpNotification {
TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {}
tensorflow::Notification n;
std::unique_ptr<tensorflow::Thread> thread;
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status;
};
TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
TFE_TensorHandle** retvals,
int* num_retvals,
TF_Status* status) {
TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;
n->thread.reset(op->operation.EagerContext().TFEnv()->StartThread(
tensorflow::ThreadOptions(), "ExecuteOpThread",
[op, retvals, num_retvals, n]() {
TFE_Execute(op, retvals, num_retvals, n->status.get());
n->n.Notify();
}));
return n;
}
void TFE_ExecuteOpNotificationWaitAndDelete(
TFE_ExecuteOpNotification* notification, TF_Status* status) {
if (notification == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"Passed in notification is a nullptr.");
return;
}
if (notification->thread == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"Passed in notification didn't start a thread correctly. Cleaning up "
"this notification. Please re-execute the operation to get a new "
"notification.");
delete notification;
return;
}
notification->n.WaitForNotification();
status->status = notification->status->status;
delete notification;
}
void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
status->status = tensorflow::errors::Internal(errMsg);
}


@ -188,31 +188,6 @@ TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
TF_Session* session, int tensor_id, TF_Status* status);
// Prints `handle` in a human readable format to standard output for debugging.
TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString(
TFE_TensorHandle* handle);
TF_CAPI_EXPORT extern void TFE_OpPrintDebugString(TFE_Op* op);
typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification;
// Allows invoking a kernel asynchronously, and explicitly returns a
// notification that can be waited upon. This always executes the kernel in a
// new thread.
// 1. `retvals` and `num_retvals` can only be consumed after
// `TFE_ExecuteOp` returns successfully. They shouldn't be used
// if the return is unsuccessful
// 2. These new APIs cannot be used together with the TFE context level async
// support.
TF_CAPI_EXPORT extern TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(
TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status);
// Waits to complete the op execution, and cleans up the notification.
// Errors reported by op execution are set in `status`.
TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete(
TFE_ExecuteOpNotification* notification, TF_Status* status);
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
const char* errMsg);


@ -84,127 +84,6 @@ TEST(CAPI_EXPERIMENTAL, IsStateful) {
EXPECT_EQ(id, 0);
}
TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Simple) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul_op = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
auto* r =
TFE_ExecuteOpInNewThread(matmul_op, &retvals[0], &num_retvals, status);
TFE_ExecuteOpNotificationWaitAndDelete(r, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteOp(matmul_op);
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
// Perform a send/recv test. Recv blocks, so they need to be executed
// asynchronously.
TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
// Returns a 2x2 float32 Tensor on the CPU, with data 1., 2., 3., 4.
TFE_TensorHandle* m = TestMatrixTensorHandle();
// Build a send op.
TFE_Op* send_op = TFE_NewOp(ctx, "_Send", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(send_op, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
string tensor_name = "Tensor";
TFE_OpSetAttrType(send_op, "T", TF_FLOAT);
TFE_OpSetAttrString(send_op, "tensor_name", tensor_name.c_str(),
tensor_name.size());
string send_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_OpSetAttrString(send_op, "send_device", send_device.c_str(),
send_device.size());
TFE_OpSetAttrInt(send_op, "send_device_incarnation", 1234);
string recv_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_OpSetAttrString(send_op, "recv_device", recv_device.c_str(),
recv_device.size());
TFE_OpSetAttrBool(send_op, "client_terminated", true);
// Build a recv op.
TFE_Op* recv_op = TFE_NewOp(ctx, "_Recv", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(recv_op, "tensor_type", TF_FLOAT);
TFE_OpSetAttrString(recv_op, "tensor_name", tensor_name.c_str(),
tensor_name.size());
TFE_OpSetAttrString(recv_op, "send_device", send_device.c_str(),
send_device.size());
TFE_OpSetAttrInt(recv_op, "send_device_incarnation", 1234);
TFE_OpSetAttrString(recv_op, "recv_device", recv_device.c_str(),
recv_device.size());
TFE_OpSetAttrBool(recv_op, "client_terminated", true);
TFE_TensorHandle* send_retvals;
int send_num_retvals = 0;
auto* send_result = TFE_ExecuteOpInNewThread(send_op, &send_retvals,
&send_num_retvals, status);
TFE_TensorHandle* recv_retvals[1] = {nullptr};
int recv_num_retvals = 1;
auto* recv_result = TFE_ExecuteOpInNewThread(recv_op, &recv_retvals[0],
&recv_num_retvals, status);
TFE_ExecuteOpNotificationWaitAndDelete(send_result, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_ExecuteOpNotificationWaitAndDelete(recv_result, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(recv_retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(1, product[0]);
EXPECT_EQ(2, product[1]);
EXPECT_EQ(3, product[2]);
EXPECT_EQ(4, product[3]);
TFE_DeleteOp(send_op);
TFE_DeleteOp(recv_op);
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(recv_retvals[0]);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
class ShapeInferenceTest : public ::testing::Test {
protected:
ShapeInferenceTest()


@ -2,6 +2,7 @@
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
"tf_copts",
"tf_cuda_cc_test",
"tf_cuda_library",
@ -89,7 +90,7 @@ tf_cuda_library(
)
filegroup(
name = "pywrap_eager_hdrs",
name = "pywrap_required_hdrs",
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
@ -129,18 +130,6 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/profiler/lib:profiler_lib",
"//tensorflow/core/profiler/lib:profiler_session",
],
)
@ -301,6 +290,27 @@ tf_cuda_cc_test(
],
)
tf_cc_test(
name = "custom_device_test",
size = "small",
srcs = [
"custom_device_test.cc",
],
deps = [
":c_api",
":c_api_experimental",
":c_api_test_util",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "tape",
hdrs = ["tape.h"],
@ -313,7 +323,10 @@ cc_library(
filegroup(
name = "headers",
srcs = ["c_api.h"],
srcs = [
"c_api.h",
"c_api_experimental.h",
],
visibility = ["//tensorflow:__subpackages__"],
)


@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/platform.h" // NOLINT
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@ -102,7 +103,12 @@ const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
return op_def;
}
bool IsCPU(const tensorflow::Device* d) {
bool IsCPU(
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant) {
if (VariantDeviceIsCustom(variant)) {
return false;
}
tensorflow::Device* d = absl::get<tensorflow::Device*>(variant);
return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}
@ -265,9 +271,9 @@ tensorflow::Status GetReplacedFromExistingWorkers(
}
tensorflow::Status CreateRemoteContexts(
const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
tensorflow::uint64 context_view_id, int keep_alive_secs,
const tensorflow::ServerDef& server_def,
TFE_Context* ctx, const std::vector<string>& remote_workers,
tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
const bool lazy_copy_remote_function_inputs,
const tensorflow::eager::CreateContextRequest& base_request) {
@ -296,7 +302,7 @@ tensorflow::Status CreateRemoteContexts(
continue;
}
tensorflow::eager::CreateContextRequest request(base_request);
tensorflow::eager::CreateContextRequest request;
tensorflow::eager::CreateContextResponse* response =
new tensorflow::eager::CreateContextResponse();
request.set_context_id(context_id);
@ -304,6 +310,21 @@ tensorflow::Status CreateRemoteContexts(
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
server_def.default_session_config());
std::vector<bool> filtered_device_mask;
ctx->context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(),
base_request.cluster_device_attributes_size());
for (int i = 0; i < filtered_device_mask.size(); i++) {
if (filtered_device_mask[i]) {
const auto& da = base_request.cluster_device_attributes(i);
*request.add_cluster_device_attributes() = da;
}
}
request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
request.set_lazy_copy_remote_function_inputs(
@ -325,13 +346,34 @@ tensorflow::Status CreateRemoteContexts(
}
tensorflow::Status UpdateRemoteContexts(
const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
TFE_Context* ctx, const std::vector<string>& remote_workers,
const std::vector<string>& added_workers,
const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers,
const tensorflow::eager::CreateContextRequest& base_request) {
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);
int cluster_device_count = base_request.cluster_device_attributes_size();
std::unordered_set<string> added_or_removed(added_workers.begin(),
added_workers.end());
std::copy(removed_workers.begin(), removed_workers.end(),
std::inserter(added_or_removed, added_or_removed.end()));
// Whether each device is in the updated (added or removed) workers
std::vector<bool> device_added_or_removed(cluster_device_count);
for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
const auto& da = base_request.cluster_device_attributes().at(i);
tensorflow::DeviceNameUtils::ParsedName pn;
tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
string task_name;
tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
if (added_or_removed.find(task_name) != added_or_removed.end()) {
device_added_or_removed[i] = true;
}
}
for (int i = 0; i < num_remote_workers; i++) {
const string& remote_worker = remote_workers[i];
tensorflow::DeviceNameUtils::ParsedName parsed_name;
@ -354,17 +396,42 @@ tensorflow::Status UpdateRemoteContexts(
continue;
}
std::vector<bool> filtered_device_mask;
ctx->context->FilterDevicesForRemoteWorkers(
remote_worker, base_request.cluster_device_attributes(),
&filtered_device_mask);
DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);
// If any of the devices that match the device filters are in the set of
// added or removed workers, we must send a complete UpdateContextRequest.
// Otherwise, only send a simple request to increment context view ID.
std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
std::transform(device_added_or_removed.begin(),
device_added_or_removed.end(), filtered_device_mask.begin(),
added_or_removed_filtered_devices.begin(),
std::logical_and<bool>());
const bool full_update_request =
std::accumulate(added_or_removed_filtered_devices.begin(),
added_or_removed_filtered_devices.end(), false,
std::logical_or<bool>());
tensorflow::eager::UpdateContextRequest request;
auto* response = new tensorflow::eager::UpdateContextResponse();
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
for (const auto& da : base_request.cluster_device_attributes()) {
*request.add_cluster_device_attributes() = da;
}
request.set_context_id(context_id);
request.set_context_view_id(context_view_id);
if (full_update_request) {
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
server_def.default_session_config());
for (int i = 0; i < cluster_device_count; i++) {
if (filtered_device_mask[i]) {
const auto& da = base_request.cluster_device_attributes(i);
*request.add_cluster_device_attributes() = da;
}
}
}
eager_client->UpdateContextAsync(
&request, response,
@ -525,15 +592,12 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
for (const auto& da : local_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
base_request.mutable_server_def()
->mutable_default_session_config()
->MergeFrom(server_def.default_session_config());
// Initialize remote eager workers.
// TODO(b/138847548) Create remote eager contexts in async mode by default.
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
remote_workers, context_id, context_view_id, keep_alive_secs,
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request));
} else {
@ -543,7 +607,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// we must set their context_view_id to the existing master's
// context_view_id + 1.
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
added_workers, context_id, context_view_id + 1, keep_alive_secs,
ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request));
if (!existing_workers.empty()) {
@ -553,8 +617,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
}
}
LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
existing_workers, context_id, context_view_id + 1, server_def,
remote_eager_workers.get(), base_request));
ctx, existing_workers, added_workers, removed_workers, context_id,
context_view_id + 1, server_def, remote_eager_workers.get(),
base_request));
}
}
@ -709,6 +774,22 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
"Invalid tensorflow.ServerDef protocol buffer");
return;
}
if (server_def.has_cluster_device_filters()) {
const auto& cdf = server_def.cluster_device_filters();
for (const auto& jdf : cdf.jobs()) {
const string& remote_prefix = "/job:" + jdf.name() + "/task:";
for (const auto& tdf : jdf.tasks()) {
const int32_t task_index = tdf.first;
std::vector<string> device_filters(tdf.second.device_filters_size());
for (int i = 0; i < tdf.second.device_filters_size(); i++) {
device_filters[i] = tdf.second.device_filters(i);
}
const string remote_worker = remote_prefix + std::to_string(task_index);
status->status =
ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters);
}
}
}
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
ctx, /*reset_context=*/true);
#endif // !IS_MOBILE_PLATFORM
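For context on the device-filter handling added above, here is a minimal sketch (not part of this commit) of how a caller might populate `cluster_device_filters` in a `ServerDef` before serializing it and handing it to `TFE_ContextSetServerDef`. It only uses the proto fields the new code reads (`jobs()`, `name()`, `tasks()`, `device_filters()`); the job names, task index, filter strings, and include path are assumptions for illustration.

// Sketch only: building a ServerDef whose per-task device filters will be
// picked up by the TFE_ContextSetServerDef code above. All concrete names
// and filter strings here are illustrative.
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"  // assumed location of ServerDef

tensorflow::ServerDef MakeServerDefWithDeviceFilters() {
  tensorflow::ServerDef server_def;
  server_def.set_job_name("worker");
  server_def.set_task_index(0);
  auto* job_filters = server_def.mutable_cluster_device_filters()->add_jobs();
  job_filters->set_name("worker");
  // Task 1 of job "worker" should only see the ps job's devices and its own.
  auto& task_filters = (*job_filters->mutable_tasks())[1];
  task_filters.add_device_filters("/job:ps");
  task_filters.add_device_filters("/job:worker/task:1");
  return server_def;
}
// Serializing this proto and passing it to TFE_ContextSetServerDef would reach
// the loop above, which forwards each task's filters to SetRemoteDeviceFilters
// on the eager context.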
@ -733,6 +814,11 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
status->status = tensorflow::errors::InvalidArgument(
"Trying to update a context with invalid context id.");
}
if (server_def.has_cluster_device_filters()) {
LOG(WARNING) << "Device filters can only be specified when initializing "
"the cluster. Any changes in device filters are ignored "
"when updating the server def.";
}
// TODO(haoyuzhang): Check server_def compatibility before the update
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
ctx, /*reset_context=*/false);
@ -797,6 +883,15 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
#endif // !IS_MOBILE_PLATFORM
}
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::Status::OK();
#else // !defined(IS_MOBILE_PLATFORM)
status->status = ctx->context->ClearRemoteExecutors();
#endif // !IS_MOBILE_PLATFORM
}
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
ctx->context->SetThreadLocalDevicePlacementPolicy(
@ -928,6 +1023,9 @@ const char* tensorflow::TensorHandleInterface::DeviceName(
if (!IsValid(status)) {
return nullptr;
}
if (VariantDeviceIsCustom(handle_->device())) {
return absl::get<CustomDevice*>(handle_->device())->name().c_str();
}
tensorflow::Device* d = handle_->op_device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
@ -948,9 +1046,15 @@ const char* tensorflow::TensorHandleInterface::BackingDeviceName(
if (!IsValid(status)) {
return nullptr;
}
tensorflow::Device* d = handle_->device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
if (VariantDeviceIsCustom(handle_->device())) {
return absl::get<tensorflow::CustomDevice*>(handle_->device())
->name()
.c_str();
} else {
tensorflow::Device* d = absl::get<tensorflow::Device*>(handle_->device());
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
}
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
@ -984,6 +1088,18 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
if (!IsValid(status)) {
return nullptr;
}
if (VariantDeviceIsCustom(handle_->device())) {
tensorflow::CustomDevice* custom_device =
absl::get<tensorflow::CustomDevice*>(handle_->device());
tensorflow::TensorHandle* copy;
*status = custom_device->CopyTensorFromDevice(
handle_, "/job:localhost/task:0/replica:0/device:CPU:0", &copy);
if (status->ok()) {
return TensorHandleInterface(copy).Resolve(status);
} else {
return nullptr;
}
}
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
if (handle_->IsRemote()) {
@ -1029,6 +1145,11 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
if (VariantDeviceIsCustom(handle->device())) {
const tensorflow::Tensor* t;
status->status = handle->Tensor(&t);
return t->data();
}
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
@ -1036,8 +1157,9 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
"handle.");
return nullptr;
}
if (handle->device() != nullptr) {
status->status = handle->device()->Sync();
tensorflow::Device* device(absl::get<tensorflow::Device*>(handle->device()));
if (device != nullptr) {
status->status = device->Sync();
if (!status->status.ok()) {
return nullptr;
}
@ -1056,12 +1178,17 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
const int64_t* dims, int num_dims, void* data, size_t len,
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status) {
tensorflow::Device* device;
tensorflow::Device* device = nullptr;
tensorflow::EagerContext* context = ctx->context;
status->status = context->FindDeviceFromName(device_name, &device);
tensorflow::CustomDevice* custom_device = nullptr;
if (!status->status.ok()) {
deallocator(data, len, deallocator_arg);
return nullptr;
status->status =
context->FindCustomDeviceFromName(device_name, &custom_device);
if (!status->status.ok()) {
deallocator(data, len, deallocator_arg);
return nullptr;
}
}
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
@ -1085,8 +1212,14 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
tensorflow::TensorShape(dimvec), buf);
buf->Unref();
tensorflow::TensorHandle* ret_handle;
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant_device;
if (custom_device == nullptr) {
variant_device = device;
} else {
variant_device = custom_device;
}
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, device, context, &ret_handle);
t, variant_device, context, &ret_handle);
if (!status->status.ok()) {
return nullptr;
}
@ -1427,8 +1560,42 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
tensorflow::EagerContext* context = ctx->context;
status->status = context->FindDeviceFromName(device_name, &device);
if (!status->status.ok()) {
tensorflow::CustomDevice* dev;
status->status = context->FindCustomDeviceFromName(device_name, &dev);
if (status->status.ok()) {
status->status = dev->CopyTensorToDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h->handle.get())
->Handle(),
&handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
}
}
return nullptr;
}
// Handle tensor handles currently in custom devices
const char* handle_device_name = h->handle->DeviceName(&status->status);
if (!status->status.ok()) {
return nullptr;
}
tensorflow::CustomDevice* dev;
status->status = context->FindCustomDeviceFromName(handle_device_name, &dev);
if (status->status.ok()) {
status->status = dev->CopyTensorFromDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h->handle.get())
->Handle(),
device_name, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
}
return nullptr;
}
// Handle regular case.
status->status = tensorflow::EagerCopyToDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle(),
@ -1567,3 +1734,94 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
}
}
} // namespace tensorflow
namespace {
class CustomDeviceAPI : public tensorflow::CustomDevice {
public:
CustomDeviceAPI(TFE_CustomDevice device, void* info, string name)
: device_(device), info_(info), name_(name) {}
~CustomDeviceAPI() override { device_.delete_device(info_); }
const string& name() override { return name_; }
tensorflow::Status CopyTensorToDevice(
tensorflow::TensorHandle* tensor,
tensorflow::TensorHandle** result) override {
tensor->Ref();
TFE_TensorHandle tensor_handle{
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
TF_Status status;
TFE_TensorHandle* result_handle =
device_.copy_tensor_to_device(&tensor_handle, &status, info_);
if (!status.status.ok()) return status.status;
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
result_handle->handle.get())
->Handle();
(*result)->Ref();
delete result_handle;
return status.status;
}
tensorflow::Status CopyTensorFromDevice(
tensorflow::TensorHandle* tensor,
const tensorflow::string& target_device_name,
tensorflow::TensorHandle** result) override {
TF_Status status;
tensor->Ref();
TFE_TensorHandle tensor_handle{
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
&tensor_handle, target_device_name.c_str(), &status, info_);
if (!status.status.ok()) return status.status;
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
result_handle->handle.get())
->Handle();
(*result)->Ref();
delete result_handle;
return status.status;
}
tensorflow::Status Execute(tensorflow::EagerOperation* op,
tensorflow::TensorHandle** retvals,
int* num_retvals) override {
std::vector<TFE_TensorHandle*> inputs;
inputs.reserve(op->Inputs().size());
for (int i = 0; i < op->Inputs().size(); ++i) {
op->Inputs()[i]->Ref();
inputs.push_back(new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
op->Inputs()[i])});
}
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
// TODO(allenl): figure out how to get attrs from EagerOperation
TF_Status status;
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
num_retvals, outputs.data(), &status, info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
outputs[i]->handle.get())
->Handle();
retvals[i]->Ref();
}
}
for (auto inp : inputs) {
delete inp;
}
return status.status;
}
private:
TFE_CustomDevice device_;
void* info_;
string name_;
};
} // namespace
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info) {
auto custom_device =
std::make_unique<CustomDeviceAPI>(device, device_info, device_name);
ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
}


@ -66,7 +66,7 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
}
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Device* device = handle_->device();
tensorflow::Device* device = absl::get<Device*>(handle_->device());
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
tensorflow::XlaDevice* xla_device =


@ -88,14 +88,14 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
int num_tracing_attempts,
TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::client::ValidateHostPortPair(service_addr);
tensorflow::profiler::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return false;
}
s = tensorflow::profiler::client::StartTracing(
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
num_tracing_attempts);
s = tensorflow::profiler::Trace(service_addr, logdir, worker_list,
include_dataset_ops, duration_ms,
num_tracing_attempts);
tensorflow::Set_TF_Status_from_Status(status, s);
return s.ok();
}
@ -104,14 +104,14 @@ void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
int monitoring_level, bool display_timestamp,
TF_Buffer* result, TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::client::ValidateHostPortPair(service_addr);
tensorflow::profiler::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return;
}
string content;
s = tensorflow::profiler::client::Monitor(
service_addr, duration_ms, monitoring_level, display_timestamp, &content);
s = tensorflow::profiler::Monitor(service_addr, duration_ms, monitoring_level,
display_timestamp, &content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
result->data = data;


@ -27,7 +27,7 @@ extern "C" {
// creating a new op every time. If `raw_device_name` is `NULL` or empty, it
// does not set the device name. If it's not `NULL`, then it attempts to parse
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
// than seperately calling it because if the existing op has the same
// than separately calling it because if the existing op has the same
// `raw_device_name`, it skips parsing and just leave as it is.
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
const char* op_or_function_name,
@ -434,6 +434,10 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
const char* worker_name,
TF_Status* status);
// Clear pending streaming requests and error statuses on remote executors.
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
TF_Status* status);
// This function will block till the operation that produces `h` has
// completed. This is only valid on local TFE_TensorHandles. The pointer
// returned will be on the device in which the TFE_TensorHandle resides (so e.g.
@ -463,6 +467,57 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
TF_Buffer* buf);
#define TFE_CUSTOM_DEVICE_VERSION 0
// Struct to be filled in
typedef struct TFE_CustomDevice {
int version = TFE_CUSTOM_DEVICE_VERSION;
// Method to copy a tensor to the custom device.
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_TensorHandle* tensor,
TF_Status* status,
void* device_info) = nullptr;
// Method to copy a tensor from the custom device to a target device.
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info);
// Method to execute an operation.
// TODO(allenl) figure out a generic way of passing attrs here
void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
const char* operation_name, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
// Method to delete a device.
void (*delete_device)(void* device_info);
} TFE_CustomDevice;
// Registers a custom device for use with eager execution.
//
// Eager operations may be placed on this device, e.g. `with
// tf.device("CUSTOM"):` from Python if `device_name` for this call is
// "/job:localhost/replica:0/task:0/device:CUSTOM:0".
//
// The custom device defines copy operations for moving TensorHandles on and
// off, and an execution operation for named operations. Often execution will
// simply wrap op execution on one or more physical devices.
//
// device_info is an opaque caller-defined type stored with the custom device
// which is passed to the functions referenced in the TFE_CustomDevice struct
// `device` (execute, delete_device, etc.). It can for example contain the
// names of wrapped devices.
//
// There are currently no graph semantics implemented for registered custom
// devices, so executing tf.functions which contain operations placed on custom
// devices will fail.
//
// This API is highly experimental, and in particular is expected to change when
// it starts supporting operations with attributes and when tf.function support
// is added.
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info);
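As a quick illustration of the struct and registration call declared above, a minimal sketch follows (not from this commit; the logging-device test added elsewhere in this commit is the complete, working example). The callback names and device string are made up, and each callback simply reports `TF_UNIMPLEMENTED`.

// Sketch only: fill the TFE_CustomDevice callbacks and register the device.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"

TFE_TensorHandle* SketchCopyToDevice(TFE_TensorHandle* tensor,
                                     TF_Status* status, void* device_info) {
  TF_SetStatus(status, TF_UNIMPLEMENTED, "copy-in not implemented in sketch");
  return nullptr;
}
TFE_TensorHandle* SketchCopyFromDevice(TFE_TensorHandle* tensor,
                                       const char* target_device_name,
                                       TF_Status* status, void* device_info) {
  TF_SetStatus(status, TF_UNIMPLEMENTED, "copy-out not implemented in sketch");
  return nullptr;
}
void SketchExecute(int num_inputs, TFE_TensorHandle** inputs,
                   const char* operation_name, int* num_outputs,
                   TFE_TensorHandle** outputs, TF_Status* s, void* device_info) {
  TF_SetStatus(s, TF_UNIMPLEMENTED, "execute not implemented in sketch");
}
void SketchDeleteDevice(void* device_info) {}

void RegisterSketchDevice(TFE_Context* ctx) {
  TFE_CustomDevice device;
  device.copy_tensor_to_device = &SketchCopyToDevice;
  device.copy_tensor_from_device = &SketchCopyFromDevice;
  device.execute = &SketchExecute;
  device.delete_device = &SketchDeleteDevice;
  TFE_RegisterCustomDevice(ctx, device,
                           "/job:localhost/replica:0/task:0/device:CUSTOM:0",
                           /*device_info=*/nullptr);
}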
#ifdef __cplusplus
} /* end extern "C" */
#endif


@ -363,13 +363,18 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
TensorHandleCopyBetweenTwoGPUDevices(true);
}
void TensorHandleSilentCopy(bool async) {
void TensorHandleSilentCopy(bool async,
TFE_ContextDevicePlacementPolicy global_policy,
TFE_ContextDevicePlacementPolicy thread_policy) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, global_policy);
TFE_Context* ctx = TFE_NewContext(opts, status.get());
if (thread_policy != global_policy) {
TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx, thread_policy);
}
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@ -404,57 +409,21 @@ void TensorHandleSilentCopy(bool async) {
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); }
TEST(CAPI, TensorHandleSilentCopyAsync) { TensorHandleSilentCopy(true); }
void TensorHandleSilentCopyLocal(bool async) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx,
TFE_DEVICE_PLACEMENT_SILENT);
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Disable the test if no GPU is present.
string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
}
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
TEST(CAPI, TensorHandleSilentCopy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT);
}
TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); }
TEST(CAPI, TensorHandleSilentCopyLocalAsync) {
TensorHandleSilentCopyLocal(true);
TEST(CAPI, TensorHandleSilentCopyAsync) {
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT,
TFE_DEVICE_PLACEMENT_SILENT);
}
TEST(CAPI, TensorHandleSilentCopyLocalPolicy) {
TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT,
TFE_DEVICE_PLACEMENT_SILENT);
}
TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) {
TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT,
TFE_DEVICE_PLACEMENT_SILENT);
}
void SetAndGetOpDevices(bool async) {


@ -0,0 +1,159 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/test.h"
namespace {
struct LoggingDevice {
TFE_Context* ctx;
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
};
struct LoggedTensor {
TFE_TensorHandle* tensor;
LoggedTensor() = delete;
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<LoggedTensor*>(data);
}
TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* ctx, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
ctx, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}
TFE_TensorHandle* CopyToLoggingDevice(TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, dev->ctx, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(dev->ctx, dev->device_name, std::move(dst),
status);
}
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a logging device.");
return nullptr;
}
void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
const char* operation_name, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(input, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddInput(op, t->tensor, s);
} else {
TFE_OpAddInput(op, input, s);
}
if (TF_GetCode(s) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
TFE_Execute(op, op_outputs.data(), num_outputs, s);
TFE_DeleteOp(op);
if (TF_GetCode(s) != TF_OK) return;
std::vector<TFE_TensorHandle*> unwrapped_outputs;
for (auto* handle : op_outputs) {
unwrapped_outputs.push_back(handle);
}
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name,
std::move(logged_tensor), s);
}
}
void DeleteLoggingDevice(void* device_info) {
delete reinterpret_cast<LoggingDevice*>(device_info);
}
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->ctx = context;
device->arrived_flag = arrived_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device);
}
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* context = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context, name, &arrived);
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
ASSERT_FALSE(arrived);
TFE_TensorHandle* hdevice =
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
ASSERT_TRUE(arrived);
TFE_DeleteTensorHandle(hcpu);
TFE_DeleteTensorHandle(hdevice);
TFE_DeleteContext(context);
}
} // namespace


@ -56,7 +56,7 @@ extern "C" {
/// Lifetime: The wrapper data structures are owned by core TensorFlow. The data
/// pointed to by the `void*` members is always owned by the plugin. The plugin
/// will provide functions to call to allocate and deallocate this data (see
/// next section) and core TensorFlow ensures to call these at the proper time.
/// next sections) and core TensorFlow ensures to call these at the proper time.
///
/// Plugins will never receive a `TF_*` pointer that is `nullptr`. Core
/// TensorFlow will never touch the `void*` wrapped by these structures, except
@ -601,6 +601,10 @@ typedef struct TF_FilesystemOps {
///
/// Plugins must not return `nullptr`. Returning empty strings is allowed.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// This function will be called by core TensorFlow to clean up all path
/// arguments for all other methods in the filesystem API.
///
@ -618,6 +622,10 @@ typedef struct TF_FilesystemOps {
/// In case of error, plugins must set `status` to a value different than
/// `TF_OK`, free memory allocated for `entries` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if all children were returned.
/// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to a
@ -654,6 +662,10 @@ typedef struct TF_FilesystemOps {
/// different than `TF_OK`, free any memory that might have been allocated for
/// `entries` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if all matches were returned.
/// * Might use any other error value for `status` to signal other errors.
@ -741,8 +753,11 @@ constexpr size_t TF_FILESYSTEM_OPS_SIZE = sizeof(TF_FilesystemOps);
/// * `TF_InitPlugin` function: must be present in the plugin shared object as
/// it will be called by core TensorFlow when the filesystem plugin is
/// loaded;
/// * `TF_FilesystemPluginInfo` struct: used to transfer information between
/// * `TF_FilesystemPluginOps` struct: used to transfer information between
/// plugins and core TensorFlow about the operations provided and metadata;
/// * `TF_FilesystemPluginInfo` struct: similar to the above structure, but
/// collects information about all the file schemes that the plugin provides
/// support for, as well as about the plugin's memory handling routines;
/// * `TF_SetFilesystemVersionMetadata` function: must be called by plugins in
/// their `TF_InitPlugin` to record the versioning information the plugins
/// are compiled against.
@ -774,7 +789,7 @@ constexpr size_t TF_FILESYSTEM_OPS_SIZE = sizeof(TF_FilesystemOps);
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
/// must not change! In the unlikely case that a new type of file needs to be
/// supported, add the new ops and metadata at the end of the structure.
typedef struct TF_FilesystemPluginInfo {
typedef struct TF_FilesystemPluginOps {
char* scheme;
int filesystem_ops_abi;
int filesystem_ops_api;
@ -792,6 +807,29 @@ typedef struct TF_FilesystemPluginInfo {
int read_only_memory_region_ops_api;
size_t read_only_memory_region_ops_size;
TF_ReadOnlyMemoryRegionOps* read_only_memory_region_ops;
} TF_FilesystemPluginOps;
/// This structure gathers together all the operations provided by the plugin.
///
/// Plugins must provide exactly `num_schemes` elements in the `ops` array.
///
/// Since memory that is allocated by the DSO gets transferred to core
/// TensorFlow, we need to provide a way for the allocation and deallocation to
/// match. This is why this structure also defines `plugin_memory_allocate` and
/// `plugin_memory_free` members.
///
/// All memory allocated by the plugin that will be owned by core TensorFlow
/// must be allocated using the allocator in this structure. Core TensorFlow
/// will use the deallocator to free this memory once it no longer needs it.
///
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
/// must not change! In the unlikely case that new global operations must be
/// provided, add them at the end of the structure.
typedef struct TF_FilesystemPluginInfo {
size_t num_schemes;
TF_FilesystemPluginOps* ops;
void* (*plugin_memory_allocate)(size_t size);
void (*plugin_memory_free)(void* ptr);
} TF_FilesystemPluginInfo;
/// Convenience function for setting the versioning metadata.
@ -801,19 +839,19 @@ typedef struct TF_FilesystemPluginInfo {
/// We want this to be defined in the plugin's memory space and we guarantee
/// that core TensorFlow will never call this.
static inline void TF_SetFilesystemVersionMetadata(
TF_FilesystemPluginInfo* info) {
info->filesystem_ops_abi = TF_FILESYSTEM_OPS_ABI;
info->filesystem_ops_api = TF_FILESYSTEM_OPS_API;
info->filesystem_ops_size = TF_FILESYSTEM_OPS_SIZE;
info->random_access_file_ops_abi = TF_RANDOM_ACCESS_FILE_OPS_ABI;
info->random_access_file_ops_api = TF_RANDOM_ACCESS_FILE_OPS_API;
info->random_access_file_ops_size = TF_RANDOM_ACCESS_FILE_OPS_SIZE;
info->writable_file_ops_abi = TF_WRITABLE_FILE_OPS_ABI;
info->writable_file_ops_api = TF_WRITABLE_FILE_OPS_API;
info->writable_file_ops_size = TF_WRITABLE_FILE_OPS_SIZE;
info->read_only_memory_region_ops_abi = TF_READ_ONLY_MEMORY_REGION_OPS_ABI;
info->read_only_memory_region_ops_api = TF_READ_ONLY_MEMORY_REGION_OPS_API;
info->read_only_memory_region_ops_size = TF_READ_ONLY_MEMORY_REGION_OPS_SIZE;
TF_FilesystemPluginOps* ops) {
ops->filesystem_ops_abi = TF_FILESYSTEM_OPS_ABI;
ops->filesystem_ops_api = TF_FILESYSTEM_OPS_API;
ops->filesystem_ops_size = TF_FILESYSTEM_OPS_SIZE;
ops->random_access_file_ops_abi = TF_RANDOM_ACCESS_FILE_OPS_ABI;
ops->random_access_file_ops_api = TF_RANDOM_ACCESS_FILE_OPS_API;
ops->random_access_file_ops_size = TF_RANDOM_ACCESS_FILE_OPS_SIZE;
ops->writable_file_ops_abi = TF_WRITABLE_FILE_OPS_ABI;
ops->writable_file_ops_api = TF_WRITABLE_FILE_OPS_API;
ops->writable_file_ops_size = TF_WRITABLE_FILE_OPS_SIZE;
ops->read_only_memory_region_ops_abi = TF_READ_ONLY_MEMORY_REGION_OPS_ABI;
ops->read_only_memory_region_ops_api = TF_READ_ONLY_MEMORY_REGION_OPS_API;
ops->read_only_memory_region_ops_size = TF_READ_ONLY_MEMORY_REGION_OPS_SIZE;
}
/// Initializes a TensorFlow plugin.
@ -828,16 +866,14 @@ static inline void TF_SetFilesystemVersionMetadata(
/// manage themselves). In both of these cases, core TensorFlow looks for
/// the `TF_InitPlugin` symbol and calls this function.
///
/// All memory allocated by this function must be allocated via the `allocator`
/// argument.
///
/// For every filesystem URI scheme that this plugin supports, the plugin must
/// add one `TF_FilesystemPluginInfo` entry in `plugin_info`.
/// add one `TF_FilesystemPluginInfo` entry in `plugin_info->ops` and call
/// `TF_SetFilesystemVersionMetadata` for that entry.
///
/// Returns number of entries in `plugin_info` (i.e., number of URI schemes
/// supported).
TF_CAPI_EXPORT extern int TF_InitPlugin(void* (*allocator)(size_t size),
TF_FilesystemPluginInfo** plugin_info);
/// Plugins must also initialize `plugin_info->plugin_memory_allocate` and
/// `plugin_info->plugin_memory_free` to ensure memory allocated by plugin is
/// freed in a compatible way.
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_FilesystemPluginInfo* plugin_info);
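To make the new registration flow concrete, here is a minimal sketch of a plugin's `TF_InitPlugin` under the signature declared above (not part of this commit): it registers one URI scheme, hands over the plugin's allocator and deallocator, and records version metadata. The scheme name is illustrative, the `filesystem_ops` wiring is elided, and the declarations above are assumed to be in scope.

// Sketch only: one-scheme plugin initialization using the structures above.
#include <cstdlib>
#include <cstring>

static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }

void TF_InitPlugin(TF_FilesystemPluginInfo* plugin_info) {
  plugin_info->plugin_memory_allocate = &plugin_memory_allocate;
  plugin_info->plugin_memory_free = &plugin_memory_free;
  plugin_info->num_schemes = 1;
  plugin_info->ops = static_cast<TF_FilesystemPluginOps*>(
      plugin_memory_allocate(sizeof(TF_FilesystemPluginOps)));
  TF_FilesystemPluginOps* ops = &plugin_info->ops[0];
  // The scheme string is owned by core TensorFlow afterwards, so it must be
  // allocated with the plugin allocator registered above.
  ops->scheme = static_cast<char*>(plugin_memory_allocate(strlen("demofs") + 1));
  strcpy(ops->scheme, "demofs");
  // ops->filesystem_ops (and the optional file ops tables) would be filled
  // here, then the version numbers the plugin was compiled against recorded:
  TF_SetFilesystemVersionMetadata(ops);
}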
#ifdef __cplusplus
} // end extern "C"


@ -164,16 +164,18 @@ Status ModularFileSystem::GetChildren(const std::string& dir,
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
std::string translated_name = TranslateName(dir);
char** children;
// Note that `children` is allocated by the plugin and freed by core
// TensorFlow, so we need to use `plugin_memory_free_` here.
char** children = nullptr;
const int num_children =
ops_->get_children(filesystem_.get(), translated_name.c_str(), &children,
plugin_status.get());
if (num_children >= 0) {
for (int i = 0; i < num_children; i++) {
result->push_back(std::string(children[i]));
free(children[i]);
plugin_memory_free_(children[i]);
}
free(children);
plugin_memory_free_(children);
}
return StatusFromTF_Status(plugin_status.get());
@ -185,15 +187,17 @@ Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
return internal::GetMatchingPaths(this, Env::Default(), pattern, result);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
char** matches;
// Note that `matches` is allocated by the plugin and freed by core
// TensorFlow, so we need to use `plugin_memory_free_` here.
char** matches = nullptr;
const int num_matches = ops_->get_matching_paths(
filesystem_.get(), pattern.c_str(), &matches, plugin_status.get());
if (num_matches >= 0) {
for (int i = 0; i < num_matches; i++) {
result->push_back(std::string(matches[i]));
free(matches[i]);
plugin_memory_free_(matches[i]);
}
free(matches);
plugin_memory_free_(matches);
}
return StatusFromTF_Status(plugin_status.get());
@ -357,7 +361,8 @@ std::string ModularFileSystem::TranslateName(const std::string& name) const {
CHECK(p != nullptr) << "TranslateName(" << name << ") returned nullptr";
std::string ret(p);
free(p);
// Since `p` is allocated by the plugin, free it using the plugin's method.
plugin_memory_free_(p);
return ret;
}

View File

@ -46,12 +46,16 @@ class ModularFileSystem final : public FileSystem {
std::unique_ptr<const TF_RandomAccessFileOps> random_access_file_ops,
std::unique_ptr<const TF_WritableFileOps> writable_file_ops,
std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
read_only_memory_region_ops)
read_only_memory_region_ops,
std::function<void*(size_t)> plugin_memory_allocate,
std::function<void(void*)> plugin_memory_free)
: filesystem_(std::move(filesystem)),
ops_(std::move(filesystem_ops)),
random_access_file_ops_(std::move(random_access_file_ops)),
writable_file_ops_(std::move(writable_file_ops)),
read_only_memory_region_ops_(std::move(read_only_memory_region_ops)) {}
read_only_memory_region_ops_(std::move(read_only_memory_region_ops)),
plugin_memory_allocate_(std::move(plugin_memory_allocate)),
plugin_memory_free_(std::move(plugin_memory_free)) {}
~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); }
@ -93,6 +97,8 @@ class ModularFileSystem final : public FileSystem {
std::unique_ptr<const TF_WritableFileOps> writable_file_ops_;
std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
read_only_memory_region_ops_;
std::function<void*(size_t)> plugin_memory_allocate_;
std::function<void(void*)> plugin_memory_free_;
TF_DISALLOW_COPY_AND_ASSIGN(ModularFileSystem);
};

View File

@ -50,21 +50,21 @@ static Status CheckABI(int pluginABI, int coreABI, StringPiece where) {
// If the numbers don't match, the plugin cannot be loaded.
//
// Uses the simpler `CheckABI(int, int, StringPiece)`.
static Status ValidateABI(const TF_FilesystemPluginInfo* info) {
static Status ValidateABI(const TF_FilesystemPluginOps* ops) {
TF_RETURN_IF_ERROR(
CheckABI(info->filesystem_ops_abi, TF_FILESYSTEM_OPS_ABI, "filesystem"));
CheckABI(ops->filesystem_ops_abi, TF_FILESYSTEM_OPS_ABI, "filesystem"));
if (info->random_access_file_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(info->random_access_file_ops_abi,
if (ops->random_access_file_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(ops->random_access_file_ops_abi,
TF_RANDOM_ACCESS_FILE_OPS_ABI,
"random access file"));
if (info->writable_file_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(info->writable_file_ops_abi,
if (ops->writable_file_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(ops->writable_file_ops_abi,
TF_WRITABLE_FILE_OPS_ABI, "writable file"));
if (info->read_only_memory_region_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(info->read_only_memory_region_ops_abi,
if (ops->read_only_memory_region_ops != nullptr)
TF_RETURN_IF_ERROR(CheckABI(ops->read_only_memory_region_ops_abi,
TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
"read only memory region"));
@ -83,19 +83,19 @@ static void CheckAPI(int plugin_API, int core_API, StringPiece where) {
// Checks if the plugin and core API numbers match, for all operations.
//
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
static void ValidateAPI(const TF_FilesystemPluginInfo* info) {
CheckAPI(info->filesystem_ops_api, TF_FILESYSTEM_OPS_API, "filesystem");
static void ValidateAPI(const TF_FilesystemPluginOps* ops) {
CheckAPI(ops->filesystem_ops_api, TF_FILESYSTEM_OPS_API, "filesystem");
if (info->random_access_file_ops != nullptr)
CheckAPI(info->random_access_file_ops_api, TF_RANDOM_ACCESS_FILE_OPS_API,
if (ops->random_access_file_ops != nullptr)
CheckAPI(ops->random_access_file_ops_api, TF_RANDOM_ACCESS_FILE_OPS_API,
"random access file");
if (info->writable_file_ops != nullptr)
CheckAPI(info->writable_file_ops_api, TF_WRITABLE_FILE_OPS_API,
if (ops->writable_file_ops != nullptr)
CheckAPI(ops->writable_file_ops_api, TF_WRITABLE_FILE_OPS_API,
"writable file");
if (info->read_only_memory_region_ops != nullptr)
CheckAPI(info->read_only_memory_region_ops_api,
if (ops->read_only_memory_region_ops != nullptr)
CheckAPI(ops->read_only_memory_region_ops_api,
TF_READ_ONLY_MEMORY_REGION_OPS_API, "read only memory region");
}
@ -177,27 +177,27 @@ static Status ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops) {
// individual function table and then checks that the function table for a
// specific file type exists if the plugin offers support for creating that
// type of file.
static Status ValidateOperations(const TF_FilesystemPluginInfo* info) {
TF_RETURN_IF_ERROR(ValidateHelper(info->filesystem_ops));
TF_RETURN_IF_ERROR(ValidateHelper(info->random_access_file_ops));
TF_RETURN_IF_ERROR(ValidateHelper(info->writable_file_ops));
TF_RETURN_IF_ERROR(ValidateHelper(info->read_only_memory_region_ops));
static Status ValidateOperations(const TF_FilesystemPluginOps* ops) {
TF_RETURN_IF_ERROR(ValidateHelper(ops->filesystem_ops));
TF_RETURN_IF_ERROR(ValidateHelper(ops->random_access_file_ops));
TF_RETURN_IF_ERROR(ValidateHelper(ops->writable_file_ops));
TF_RETURN_IF_ERROR(ValidateHelper(ops->read_only_memory_region_ops));
if (info->filesystem_ops->new_random_access_file != nullptr &&
info->random_access_file_ops == nullptr)
if (ops->filesystem_ops->new_random_access_file != nullptr &&
ops->random_access_file_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of random access files but no "
"operations on them have been supplied.");
if ((info->filesystem_ops->new_writable_file != nullptr ||
info->filesystem_ops->new_appendable_file != nullptr) &&
info->writable_file_ops == nullptr)
if ((ops->filesystem_ops->new_writable_file != nullptr ||
ops->filesystem_ops->new_appendable_file != nullptr) &&
ops->writable_file_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of writable files but no "
"operations on them have been supplied.");
if (info->filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
info->read_only_memory_region_ops == nullptr)
if (ops->filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
ops->read_only_memory_region_ops == nullptr)
return errors::FailedPrecondition(
"Filesystem allows creation of readonly memory regions but no "
"operations on them have been supplied.");
@ -232,18 +232,23 @@ static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
}
// Registers one filesystem from the plugin.
static Status RegisterFileSystem(const TF_FilesystemPluginInfo* info) {
//
// Must be called only with `index` a valid index in `info->ops`.
static Status RegisterFileSystem(const TF_FilesystemPluginInfo* info,
int index) {
// Step 1: Copy all the function tables to core TensorFlow memory space
auto core_filesystem_ops = CopyToCore<TF_FilesystemOps>(
info->filesystem_ops, info->filesystem_ops_size);
info->ops[index].filesystem_ops, info->ops[index].filesystem_ops_size);
auto core_random_access_file_ops = CopyToCore<TF_RandomAccessFileOps>(
info->random_access_file_ops, info->random_access_file_ops_size);
auto core_writable_file_ops = CopyToCore<TF_WritableFileOps>(
info->writable_file_ops, info->writable_file_ops_size);
info->ops[index].random_access_file_ops,
info->ops[index].random_access_file_ops_size);
auto core_writable_file_ops =
CopyToCore<TF_WritableFileOps>(info->ops[index].writable_file_ops,
info->ops[index].writable_file_ops_size);
auto core_read_only_memory_region_ops =
CopyToCore<TF_ReadOnlyMemoryRegionOps>(
info->read_only_memory_region_ops,
info->read_only_memory_region_ops_size);
info->ops[index].read_only_memory_region_ops,
info->ops[index].read_only_memory_region_ops_size);
// Step 2: Initialize the opaque filesystem structure
auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
@ -256,32 +261,46 @@ static Status RegisterFileSystem(const TF_FilesystemPluginInfo* info) {
// Step 3: Actual registration
return Env::Default()->RegisterFileSystem(
info->scheme, tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
std::move(filesystem), std::move(core_filesystem_ops),
std::move(core_random_access_file_ops),
std::move(core_writable_file_ops),
std::move(core_read_only_memory_region_ops)));
info->ops[index].scheme,
tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
std::move(filesystem), std::move(core_filesystem_ops),
std::move(core_random_access_file_ops),
std::move(core_writable_file_ops),
std::move(core_read_only_memory_region_ops),
info->plugin_memory_allocate, info->plugin_memory_free));
}
// Registers all filesystems, if plugin is providing valid information.
// Registers filesystem at `index`, if plugin is providing valid information.
//
// Extracted to a separate function so that pointers inside `info` are freed
// by the caller regardless of whether validation/registration failed or not.
//
// Must be called only with `index` a valid index in `info->ops`.
static Status ValidateAndRegisterFilesystems(
const TF_FilesystemPluginInfo* info) {
TF_RETURN_IF_ERROR(ValidateScheme(info->scheme));
TF_RETURN_IF_ERROR(ValidateABI(info));
ValidateAPI(info); // we just warn on API number mismatch
TF_RETURN_IF_ERROR(ValidateOperations(info));
TF_RETURN_IF_ERROR(RegisterFileSystem(info));
const TF_FilesystemPluginInfo* info, int index) {
TF_RETURN_IF_ERROR(ValidateScheme(info->ops[index].scheme));
TF_RETURN_IF_ERROR(ValidateABI(&info->ops[index]));
ValidateAPI(&info->ops[index]); // we just warn on API number mismatch
TF_RETURN_IF_ERROR(ValidateOperations(&info->ops[index]));
TF_RETURN_IF_ERROR(RegisterFileSystem(info, index));
return Status::OK();
}
// Allocates memory in the plugin DSO.
//
// Provided by core TensorFlow so that it can free this memory after DSO is
// loaded and filesystem information has been used to register the filesystem.
static void* basic_allocator(size_t size) { return calloc(1, size); }
// Ensures that the plugin provides the required memory management operations.
static Status ValidatePluginMemoryRoutines(
const TF_FilesystemPluginInfo* info) {
if (info->plugin_memory_allocate == nullptr)
return errors::FailedPrecondition(
"Cannot load filesystem plugin which does not provide "
"`plugin_memory_allocate`");
if (info->plugin_memory_free == nullptr)
return errors::FailedPrecondition(
"Cannot load filesystem plugin which does not provide "
"`plugin_memory_free`");
return Status::OK();
}
namespace filesystem_registration {
@ -297,26 +316,28 @@ Status RegisterFilesystemPluginImpl(const std::string& dso_path) {
env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));
// Step 3: Call `TF_InitPlugin`
TF_FilesystemPluginInfo* info = nullptr;
auto TF_InitPlugin = reinterpret_cast<int (*)(
decltype(&basic_allocator), TF_FilesystemPluginInfo**)>(dso_symbol);
int num_schemes = TF_InitPlugin(&basic_allocator, &info);
if (num_schemes < 0 || info == nullptr)
return errors::InvalidArgument("DSO returned invalid filesystem data");
TF_FilesystemPluginInfo info;
memset(&info, 0, sizeof(info));
auto TF_InitPlugin =
reinterpret_cast<int (*)(TF_FilesystemPluginInfo*)>(dso_symbol);
TF_InitPlugin(&info);
// Step 4: Validate and register all filesystems
// Step 4: Ensure plugin provides the memory management functions.
TF_RETURN_IF_ERROR(ValidatePluginMemoryRoutines(&info));
// Step 5: Validate and register all filesystems
// Try to register as many filesystems as possible.
// Free memory once we no longer need it
Status status;
for (int i = 0; i < num_schemes; i++) {
status.Update(ValidateAndRegisterFilesystems(&info[i]));
free(info[i].scheme);
free(info[i].filesystem_ops);
free(info[i].random_access_file_ops);
free(info[i].writable_file_ops);
free(info[i].read_only_memory_region_ops);
for (int i = 0; i < info.num_schemes; i++) {
status.Update(ValidateAndRegisterFilesystems(&info, i));
info.plugin_memory_free(info.ops[i].scheme);
info.plugin_memory_free(info.ops[i].filesystem_ops);
info.plugin_memory_free(info.ops[i].random_access_file_ops);
info.plugin_memory_free(info.ops[i].writable_file_ops);
info.plugin_memory_free(info.ops[i].read_only_memory_region_ops);
}
free(info);
info.plugin_memory_free(info.ops);
return status;
}

View File

@ -1569,7 +1569,7 @@ TEST_P(ModularFileSystemTest, TestRoundTrip) {
if (!status.ok())
GTEST_SKIP() << "NewRandomAccessFile() not supported: " << status;
char scratch[64 /* big enough to accomodate test_data */] = {0};
char scratch[64 /* big enough to accommodate test_data */] = {0};
StringPiece result;
status = read_file->Read(0, test_data.size(), &result, scratch);
EXPECT_PRED2(UnimplementedOrReturnsCode, status, Code::OK);

View File

@ -31,6 +31,9 @@ limitations under the License.
// Implementation of a filesystem for POSIX environments.
// This filesystem will support `file://` and empty (local) URI schemes.
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
@ -43,7 +46,9 @@ typedef struct PosixFile {
static void Cleanup(TF_RandomAccessFile* file) {
auto posix_file = static_cast<PosixFile*>(file->plugin_file);
close(posix_file->fd);
free(const_cast<char*>(posix_file->filename));
// This would be safe to free with `free` directly, since the filename is just
// an opaque buffer. However, it is better to be consistent everywhere.
plugin_memory_free(const_cast<char*>(posix_file->filename));
delete posix_file;
}
@ -98,7 +103,7 @@ typedef struct PosixFile {
static void Cleanup(TF_WritableFile* file) {
auto posix_file = static_cast<PosixFile*>(file->plugin_file);
free(const_cast<char*>(posix_file->filename));
plugin_memory_free(const_cast<char*>(posix_file->filename));
delete posix_file;
}
@ -381,12 +386,13 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
if (num_entries < 0) {
TF_SetStatusFromIOError(status, errno, path);
} else {
*entries = static_cast<char**>(calloc(num_entries, sizeof((*entries)[0])));
*entries = static_cast<char**>(
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
for (int i = 0; i < num_entries; i++) {
(*entries)[i] = strdup(dir_entries[i]->d_name);
free(dir_entries[i]);
plugin_memory_free(dir_entries[i]);
}
free(dir_entries);
plugin_memory_free(dir_entries);
}
return num_entries;
@ -394,65 +400,59 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
} // namespace tf_posix_filesystem
int TF_InitPlugin(void* (*allocator)(size_t), TF_FilesystemPluginInfo** info) {
const int num_schemes = 2;
*info = static_cast<TF_FilesystemPluginInfo*>(
allocator(num_schemes * sizeof((*info)[0])));
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
for (int i = 0; i < num_schemes; i++) {
TF_FilesystemPluginInfo* current_info = &((*info)[i]);
TF_SetFilesystemVersionMetadata(current_info);
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
ops->random_access_file_ops->read = tf_random_access_file::Read;
current_info->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
allocator(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
current_info->random_access_file_ops->cleanup =
tf_random_access_file::Cleanup;
current_info->random_access_file_ops->read = tf_random_access_file::Read;
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->writable_file_ops->append = tf_writable_file::Append;
ops->writable_file_ops->tell = tf_writable_file::Tell;
ops->writable_file_ops->flush = tf_writable_file::Flush;
ops->writable_file_ops->sync = tf_writable_file::Sync;
ops->writable_file_ops->close = tf_writable_file::Close;
current_info->writable_file_ops =
static_cast<TF_WritableFileOps*>(allocator(TF_WRITABLE_FILE_OPS_SIZE));
current_info->writable_file_ops->cleanup = tf_writable_file::Cleanup;
current_info->writable_file_ops->append = tf_writable_file::Append;
current_info->writable_file_ops->tell = tf_writable_file::Tell;
current_info->writable_file_ops->flush = tf_writable_file::Flush;
current_info->writable_file_ops->sync = tf_writable_file::Sync;
current_info->writable_file_ops->close = tf_writable_file::Close;
ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
ops->read_only_memory_region_ops->cleanup =
tf_read_only_memory_region::Cleanup;
ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data;
ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length;
current_info->read_only_memory_region_ops =
static_cast<TF_ReadOnlyMemoryRegionOps*>(
allocator(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
current_info->read_only_memory_region_ops->cleanup =
tf_read_only_memory_region::Cleanup;
current_info->read_only_memory_region_ops->data =
tf_read_only_memory_region::Data;
current_info->read_only_memory_region_ops->length =
tf_read_only_memory_region::Length;
current_info->filesystem_ops =
static_cast<TF_FilesystemOps*>(allocator(TF_FILESYSTEM_OPS_SIZE));
current_info->filesystem_ops->init = tf_posix_filesystem::Init;
current_info->filesystem_ops->cleanup = tf_posix_filesystem::Cleanup;
current_info->filesystem_ops->new_random_access_file =
tf_posix_filesystem::NewRandomAccessFile;
current_info->filesystem_ops->new_writable_file =
tf_posix_filesystem::NewWritableFile;
current_info->filesystem_ops->new_appendable_file =
tf_posix_filesystem::NewAppendableFile;
current_info->filesystem_ops->new_read_only_memory_region_from_file =
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile;
current_info->filesystem_ops->create_dir = tf_posix_filesystem::CreateDir;
current_info->filesystem_ops->delete_file = tf_posix_filesystem::DeleteFile;
current_info->filesystem_ops->delete_dir = tf_posix_filesystem::DeleteDir;
current_info->filesystem_ops->rename_file = tf_posix_filesystem::RenameFile;
current_info->filesystem_ops->copy_file = tf_posix_filesystem::CopyFile;
current_info->filesystem_ops->path_exists = tf_posix_filesystem::PathExists;
current_info->filesystem_ops->stat = tf_posix_filesystem::Stat;
current_info->filesystem_ops->get_children =
tf_posix_filesystem::GetChildren;
}
(*info)[0].scheme = strdup("");
(*info)[1].scheme = strdup("file");
return num_schemes;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_posix_filesystem::Init;
ops->filesystem_ops->cleanup = tf_posix_filesystem::Cleanup;
ops->filesystem_ops->new_random_access_file =
tf_posix_filesystem::NewRandomAccessFile;
ops->filesystem_ops->new_writable_file = tf_posix_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_posix_filesystem::NewAppendableFile;
ops->filesystem_ops->new_read_only_memory_region_from_file =
tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile;
ops->filesystem_ops->create_dir = tf_posix_filesystem::CreateDir;
ops->filesystem_ops->delete_file = tf_posix_filesystem::DeleteFile;
ops->filesystem_ops->delete_dir = tf_posix_filesystem::DeleteDir;
ops->filesystem_ops->rename_file = tf_posix_filesystem::RenameFile;
ops->filesystem_ops->copy_file = tf_posix_filesystem::CopyFile;
ops->filesystem_ops->path_exists = tf_posix_filesystem::PathExists;
ops->filesystem_ops->stat = tf_posix_filesystem::Stat;
ops->filesystem_ops->get_children = tf_posix_filesystem::GetChildren;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
info->plugin_memory_allocate = plugin_memory_allocate;
info->plugin_memory_free = plugin_memory_free;
info->num_schemes = 2;
info->ops = static_cast<TF_FilesystemPluginOps*>(
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
ProvideFilesystemSupportFor(&info->ops[0], "");
ProvideFilesystemSupportFor(&info->ops[1], "file");
}

View File

@ -21,6 +21,9 @@ limitations under the License.
// Implementation of a filesystem for POSIX environments.
// This filesystem will support `file://` and empty (local) URI schemes.
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
@ -53,18 +56,18 @@ namespace tf_windows_filesystem {
} // namespace tf_windows_filesystem
int TF_InitPlugin(void* (*allocator)(size_t), TF_FilesystemPluginInfo** info) {
const int num_schemes = 2;
*info = static_cast<TF_FilesystemPluginInfo*>(
allocator(num_schemes * sizeof((*info)[0])));
for (int i = 0; i < num_schemes; i++) {
TF_FilesystemPluginInfo* current_info = &((*info)[i]);
TF_SetFilesystemVersionMetadata(current_info);
}
(*info)[0].scheme = strdup("");
(*info)[1].scheme = strdup("file");
return num_schemes;
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
info->plugin_memory_allocate = plugin_memory_allocate;
info->plugin_memory_free = plugin_memory_free;
info->num_schemes = 2;
info->ops = static_cast<TF_FilesystemPluginOps*>(
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
ProvideFilesystemSupportFor(&info->ops[0], "");
ProvideFilesystemSupportFor(&info->ops[1], "file");
}

View File

@ -18,19 +18,36 @@ limitations under the License.
#include "tensorflow/c/kernels.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <memory>
#include <string>
#include "absl/container/inlined_vector.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
struct MyCustomKernel {
bool created;

View File

@ -41,6 +41,16 @@ filegroup(
],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"training/coordinator.h",
],
visibility = [
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "gradients",
srcs = [

View File

@ -33,6 +33,7 @@ cc_library(
deps = [
":aot_only_var_handle_op",
":embedded_protocol_buffers",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@ -133,8 +134,9 @@ cc_library(
# tfcompile.bzl correctly handles usage from outside of the package that it is
# defined in.
# A simple test of tf_library from a text protobuf, mostly to enable the
# benchmark_test.
# A simple test of tf_library from a text protobuf, to enable benchmark_test.
# This test uses an incomplete graph with a node that is not defined. The
# compilation works because the undefined node is a feed node.
tf_library(
name = "test_graph_tfadd",
testonly = 1,
@ -146,8 +148,21 @@ tf_library(
],
)
tf_library(
name = "test_graph_tfadd_mlir_bridge",
testonly = 1,
config = "test_graph_tfadd.config.pbtxt",
cpp_class = "AddComp",
graph = "test_graph_tfadd.pbtxt",
mlir_components = "Bridge",
tags = [
"manual",
],
)
# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the unknown op is not needed for the fetches.
# the compilation works because the node with the unknown op is not needed
# for the fetches.
tf_library(
name = "test_graph_tfunknownop",
testonly = 1,
@ -159,9 +174,21 @@ tf_library(
],
)
tf_library(
name = "test_graph_tfunknownop_mlir_bridge",
testonly = 1,
config = "test_graph_tfunknownop.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
mlir_components = "Bridge",
tags = [
"manual",
],
)
# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the op between the unknown op and the
# fetches is a feed.
# the compilation works because the node with the unknown op is only used as
# an input of a feed node.
tf_library(
name = "test_graph_tfunknownop2",
testonly = 1,
@ -173,8 +200,20 @@ tf_library(
],
)
tf_library(
name = "test_graph_tfunknownop2_mlir_bridge",
testonly = 1,
config = "test_graph_tfunknownop2.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
mlir_components = "Bridge",
tags = [
"manual",
],
)
# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the unknown op is fed.
# the compilation works because the node with the unknown op is a feed node.
tf_library(
name = "test_graph_tfunknownop3",
testonly = 1,
@ -186,6 +225,18 @@ tf_library(
],
)
tf_library(
name = "test_graph_tfunknownop3_mlir_bridge",
testonly = 1,
config = "test_graph_tfunknownop3.config.pbtxt",
cpp_class = "UnknownOpAddComp",
graph = "test_graph_tfunknownop.pbtxt",
mlir_components = "Bridge",
tags = [
"manual",
],
)
# Utility library for benchmark binaries, used by the *_benchmark rules that are
# added by the tfcompile bazel macro.
cc_library(
@ -260,9 +311,13 @@ test_suite(
tests = [
":benchmark_test",
":codegen_test",
":test_graph_tfadd_mlir_bridge_test",
":test_graph_tfadd_test",
":test_graph_tfunknownop2_mlir_bridge_test",
":test_graph_tfunknownop2_test",
":test_graph_tfunknownop3_mlir_bridge_test",
":test_graph_tfunknownop3_test",
":test_graph_tfunknownop_mlir_bridge_test",
":test_graph_tfunknownop_test",
"//tensorflow/compiler/aot/tests:all_tests",
],

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/base/call_once.h"
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
@ -142,7 +143,7 @@ static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
}
}
static std::once_flag targets_init;
static absl::once_flag targets_init;
static void InitializeTargets() {
// Initialize all LLVM targets so we can cross compile.
@ -167,7 +168,7 @@ static void InitializeTargets() {
}
Status Main(const MainFlags& flags) {
std::call_once(targets_init, &InitializeTargets);
absl::call_once(targets_init, &InitializeTargets);
// Process config.
tf2xla::Config config;

View File

@ -349,6 +349,18 @@ tf_library(
],
)
tf_library(
name = "test_graph_tffunction_mlir_bridge",
testonly = 1,
config = "test_graph_tffunction.config.pbtxt",
cpp_class = "FunctionComp",
graph = "test_graph_tffunction.pb",
mlir_components = "Bridge",
tags = [
"manual",
],
)
tf_library(
name = "test_graph_tfassert_eq_mlir_bridge",
testonly = 1,
@ -484,6 +496,7 @@ tf_cc_test(
":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
":test_graph_tfassert_eq_mlir_bridge",
":test_graph_tfcond_mlir_bridge",
":test_graph_tffunction_mlir_bridge",
":test_graph_tfgather_mlir_bridge",
":test_graph_tfmatmul_mlir_bridge",
":test_graph_tfmatmulandadd_mlir_bridge",

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
@ -429,8 +430,6 @@ TEST(TFCompileTest, MatMulAndAdd1) {
}
}
// TODO(bixia): the following tests failed with MLIR bridge.
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
TEST(TFCompileTest, Function) {
// The function is equivalent to an addition
FunctionComp add_fn;
@ -445,7 +444,6 @@ TEST(TFCompileTest, Function) {
EXPECT_EQ(add_fn.result0_data()[0], 3);
EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
}
#endif
TEST(TFCompileTest, Splits) {
Eigen::ThreadPool tp(1);

View File

@ -57,6 +57,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":jit_compilation_passes",
":xla_kernel_creator", # buildcleaner: keep
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@ -70,6 +71,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = if_cuda_or_rocm([
":jit_compilation_passes",
":xla_kernel_creator", # buildcleaner: keep
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
@ -157,7 +159,9 @@ XLA_DEVICE_DEPS = [
":common",
":xla_launch_util",
":xla_tensor",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:optional",
"//tensorflow/compiler/jit/ops:xla_ops",
@ -250,13 +254,26 @@ cc_library(
}),
)
# Internal targets below this point.
cc_library(
name = "flags",
srcs = ["flags.cc"],
hdrs = ["flags.h"],
visibility = [":friends"],
deps = [
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/base",
"@com_google_absl//absl/strings",
],
)
# Header-only version of "flags" library, for linking from the shared object
# without ODR violations.
cc_library(
name = "flags_headers_only",
hdrs = ["flags.h"],
visibility = [":friends"],
deps = [
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
@ -276,6 +293,8 @@ cc_library(
visibility = [":friends"],
)
# Internal targets below this point.
cc_library(
name = "xla_launch_util",
srcs = ["xla_launch_util.cc"],
@ -397,6 +416,7 @@ cc_library(
"xla_kernel_creator.h",
],
deps = [
":flags",
":jit_compilation_passes",
":xla_kernel_creator_util",
"//tensorflow/core:core_cpu_internal",
@ -625,6 +645,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",

View File

@ -266,9 +266,9 @@ bool RecursiveCompilabilityChecker::IsCompilableCall(
s = lib_runtime->Instantiate(function.name(), AttrSlice(&function.attr()),
&handle);
}
if (!s.ok()) {
std::string uncompilable_reason = "could not instantiate call";
std::string uncompilable_reason =
absl::StrCat("could not instantiate call: '", function.name(), "'");
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
encapsulating_function, uncompilable_nodes);
VLOG(2) << "Rejecting " << call_def.DebugString() << ": "

View File

@ -13,13 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/flags.h"
#include <mutex> // NOLINT
#include "absl/base/call_once.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
@ -32,7 +35,7 @@ XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;
std::vector<Flag>* flag_list;
std::once_flag flags_init;
absl::once_flag flags_init;
bool SetterForXlaAutoJitFlag(const string& value) {
int32 opt_level;
@ -213,38 +216,45 @@ void AllocateAndParseFlags() {
} // namespace
bool SetXlaAutoJitFlagFromFlagString(const string& value) {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return SetterForXlaAutoJitFlag(value);
}
BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return build_ops_flags;
}
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return mark_for_compilation_flags;
}
XlaDeviceFlags* GetXlaDeviceFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return device_flags;
}
const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return *ops_flags;
}
const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
return *jitter_flags;
}
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
std::call_once(flags_init, &AllocateAndParseFlags);
absl::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list);
}
static bool xla_is_enabled = false;
void SetXlaIsEnabled() { xla_is_enabled = true; }
bool IsXlaEnabled() { return xla_is_enabled; }
} // namespace tensorflow

View File

@ -154,6 +154,15 @@ GetIntroduceFloatingPointJitterPassFlags();
// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
void AppendMarkForCompilationPassFlags(
std::vector<tensorflow::Flag>* flag_list);
// Makes all future calls to `IsXlaEnabled()` return `true`.
//
// Should only be called when XLA is linked in.
void SetXlaIsEnabled();
// Returns whether XLA is enabled.
bool IsXlaEnabled();
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_
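As a usage sketch only (these call sites are hypothetical and not part of this change): the kernel-creator registration shown further down in this diff calls SetXlaIsEnabled() from a static initializer, and code that must degrade gracefully when XLA is not linked in can guard on IsXlaEnabled():
#include "tensorflow/compiler/jit/flags.h"
// Hypothetical helper: only attempt XLA-specific work when the XLA kernel
// creator (and thus XLA itself) was linked into the binary.
static bool MaybeDoXlaOnlyWork() {
  if (!tensorflow::IsXlaEnabled()) return false;  // XLA never linked in.
  // ... perform XLA-specific setup here ...
  return true;
}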

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
#include "absl/base/call_once.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
@ -1616,8 +1617,8 @@ StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
if (!should_compile && global_jit_level_ != OptimizerOptions::OFF &&
device_type.type_string() == DEVICE_CPU) {
static std::once_flag once;
std::call_once(once, [] {
static absl::once_flag once;
absl::call_once(once, [] {
LOG(WARNING)
<< "(One-time warning): Not using XLA:CPU for cluster because envvar "
"TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set. If you want "

View File

@ -163,12 +163,11 @@ Status XlaCompilationCache::BuildExecutable(
build_options.set_device_allocator(options.device_allocator);
build_options.set_alias_passthrough_params(options.alias_passthrough_params);
auto compile_result =
client_->Compile(*result.computation, argument_layouts, build_options);
if (!compile_result.ok()) {
return compile_result.status();
}
*executable = std::move(compile_result.ValueOrDie());
TF_ASSIGN_OR_RETURN(
auto executables,
client_->Compile(*result.computation, argument_layouts, build_options));
TF_RET_CHECK(executables.size() == 1);
*executable = std::move(executables[0]);
return Status::OK();
}

View File

@ -20,7 +20,9 @@ limitations under the License.
#include <unordered_set>
#include <utility>
#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
@ -386,14 +388,33 @@ Status XlaDevice::TryGetDeviceContext(DeviceContext** out_context) {
return Status::OK();
}
// Warn about XLA_CPU/XLA_GPU exactly once.
static void ShowXlaDeviceDeprecationWarning(
absl::string_view compilation_device_name) {
static absl::once_flag once;
if (absl::StrContains(compilation_device_name, "CPU") ||
absl::StrContains(compilation_device_name, "GPU")) {
absl::call_once(once, [] {
LOG(WARNING)
<< "XLA_GPU and XLA_CPU devices are deprecated and will be "
"removed in subsequent releases. Instead, use either "
"@tf.function(experimental_compile=True) for must-compile "
"semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
"for auto-clustering best-effort compilation.";
});
}
}
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
<< op_kernel->type_string();
ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
op_kernel->Compute(context);
}
void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) {
ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
<< op_kernel->type_string();
op_kernel->ComputeAsync(context, done);

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
#include "tensorflow/core/common_runtime/function.h"
@ -39,6 +40,10 @@ bool RegisterLaunchOpCreator() {
}
static bool register_me = RegisterLaunchOpCreator();
static bool register_xla = [] {
SetXlaIsEnabled();
return true;
}();
} // end namespace
} // namespace tensorflow

View File

@ -143,11 +143,11 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
}
string message = absl::StrCat(
"Function invoked by the following node is not compilable: ",
node_def.ShortDebugString(), ".\n");
absl::StrAppend(&message, "Uncompilable nodes:\n");
SummarizeNodeDef(node_def), ".\n");
absl::StrAppend(&message, "Uncompilable nodes:");
for (const auto& node_info : uncompilable_node_info) {
string node_message =
absl::StrCat("\t", node_info.name, ": ",
absl::StrCat("\n", node_info.name, ": ",
node_info.uncompilable_reason, "\n", "\tStacktrace:\n");
for (const auto& stack_frame : node_info.stack_trace) {
absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
@ -156,7 +156,6 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
absl::StrAppend(&message, node_message);
}
VLOG(1) << message;
// node_def is calling a function that XLA can't compile.
return errors::InvalidArgument(message);
}

View File

@ -66,6 +66,8 @@ cc_library(
"//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
"//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant",
"//tensorflow/compiler/mlir/lite/quantization/xla:hlo_xla_quantization_passes",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
@ -77,10 +79,10 @@ cc_library(
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_linalg",
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
"//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
"//tensorflow/compiler/mlir/xla:xla_lower",
"//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",

View File

@ -26,9 +26,11 @@ package_group(
filegroup(
name = "tensorflow_lite_ops_td_files",
srcs = [
"ir/tfl_op_interfaces.td",
"ir/tfl_ops.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
],
)
@ -43,10 +45,29 @@ gentbl(
"-gen-op-defs",
"ir/tfl_ops.cc.inc",
),
(
"-gen-struct-attr-decls",
"ir/tfl_structs.h.inc",
),
(
"-gen-struct-attr-defs",
"ir/tfl_structs.cc.inc",
),
(
"-gen-op-doc",
"g3doc/tfl_ops.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/tfl_ops.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
],
)
gentbl(
name = "tensorflow_lite_op_interfaces_inc_gen",
tbl_outs = [
(
"-gen-op-interface-decls",
"ir/tfl_ops_interface.h.inc",
@ -57,7 +78,7 @@ gentbl(
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/tfl_ops.td",
td_file = "ir/tfl_op_interfaces.td",
td_srcs = [
":tensorflow_lite_ops_td_files",
],
@ -199,8 +220,6 @@ cc_library(
deps = [
":tensorflow_lite_ops_inc_gen",
":validators",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",
@ -209,6 +228,10 @@ cc_library(
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
# TODO(jpienaar): Move this out after splitting out LoopLikeOpInterface.
"@llvm-project//mlir:Transforms",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)
@ -267,6 +290,7 @@ tf_cc_test(
cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
"transforms/dilated_conv.cc",
"transforms/extract_ophint.cc",
"transforms/generated_legalize_tf.inc",
"transforms/generated_lower_static_tensor_list.inc",
@ -280,8 +304,10 @@ cc_library(
"transforms/split_merged_operands.cc",
"transforms/trim_functions_tf.cc",
"transforms/unroll_batch_matmul.cc",
"transforms/while_loop_outline.cc",
],
hdrs = [
"transforms/dilated_conv.h",
"transforms/passes.h",
"transforms/unroll_batch_matmul.h",
],
@ -291,15 +317,19 @@ cc_library(
":stateful_ops_utils",
":tensorflow_lite",
":validators",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:tensor_list",
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
@ -357,6 +387,7 @@ cc_library(
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
@ -370,6 +401,24 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "tensorflow_lite_d2s",
srcs = [
"transforms/dense_to_sparse.cc",
],
hdrs = [
"transforms/passes.h",
],
deps = [
":tensorflow_lite",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
alwayslink = 1,
)
filegroup(
name = "generated_op_quant_spec_getters",
srcs = [
@ -381,6 +430,8 @@ genrule(
name = "op_quant_spec_getters_inc",
srcs = [
"ir/tfl_ops.td",
"ir/tfl_op_interfaces.td",
"@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
],
outs = [
@ -628,6 +679,7 @@ cc_library(
],
deps = [
":common",
":tensorflow_lite_d2s",
":tensorflow_lite_legalize_tf",
":tensorflow_lite_optimize",
":tensorflow_lite_quantize",

View File

@ -90,6 +90,7 @@ using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::NoneType;
using mlir::Operation;
using mlir::Region;
using mlir::StringAttr;
using mlir::TensorType;
using mlir::TranslateFromMLIRRegistration;
@ -309,7 +310,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
return true;
}
static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef(
static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef(
::mlir::Operation* inst) {
// We pass empty string for the original node_def name since Flex runtime
// does not care about this being set correctly on node_def. There is no
@ -425,6 +426,11 @@ class Translator {
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
// Build while operator where cond & body are regions.
Optional<BufferOffset<tflite::Operator>> BuildWhileOperator(
mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
// Builds custom operators.
// Templated on a) data type of custom_option to be stored into flatbuffer,
// and b) TFL custom op type.
@ -472,7 +478,10 @@ class Translator {
Operation* inst, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(FuncOp fn);
// Build a subgraph with the given name out of a region corresponding either
// to a function's body or to a while op.
Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(
const std::string& name, Region* region);
// Builds Metadata with the given `name` and buffer `content`.
BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
@ -523,7 +532,7 @@ class Translator {
};
std::string Translator::UniqueName(mlir::Value val) {
return name_mapper_.GetUniqueName(val);
return std::string(name_mapper_.GetUniqueName(val));
}
Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
@ -539,9 +548,14 @@ Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
attr = cst.value();
} else if (auto cst = dyn_cast<tfl::QConstOp>(inst)) {
attr = cst.value();
} else if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
attr = cst.value();
} else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
attr = cst.value();
} else {
return empty_buffer_;
}
tensorflow::Tensor tensor;
auto status = tensorflow::ConvertToTensor(attr, &tensor);
if (!status.ok()) {
@ -595,6 +609,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
};
std::vector<int32_t> shape;
std::vector<int32_t> shape_signature;
if (type.hasStaticShape()) {
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
@ -612,7 +627,25 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
}
} else if (type.hasRank()) {
llvm::ArrayRef<int64_t> shape_ref = type.getShape();
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
shape.reserve(shape_ref.size());
for (auto& dim : shape_ref) {
shape.push_back(dim == -1 ? 1 : dim);
}
shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
}
if (auto* inst = value.getDefiningOp()) {
if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
// CreateSparsityParameters(cst.s_param());
} else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
// CreateSparsityParameters(cst.s_param());
}
}
Type element_type = type.getElementType();
tflite::TensorType tflite_element_type =
GetTFLiteType(type.getElementType()).ValueOrDie();
@ -649,10 +682,19 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
break;
}
}
return tflite::CreateTensor(
builder_, builder_.CreateVector(shape), tflite_element_type,
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
/*is_variable=*/is_variable);
if (shape_signature.empty()) {
return tflite::CreateTensor(
builder_, builder_.CreateVector(shape), tflite_element_type,
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
/*is_variable=*/is_variable);
} else {
return tflite::CreateTensor(
builder_, builder_.CreateVector(shape), tflite_element_type,
(is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
/*is_variable=*/is_variable, /*sparsity=*/0,
/*shape_signature=*/builder_.CreateVector(shape_signature));
}
}
BufferOffset<tflite::Operator> Translator::BuildIfOperator(
@ -687,6 +729,32 @@ BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
builtin_options);
}
Optional<BufferOffset<tflite::Operator>> Translator::BuildWhileOperator(
mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
auto get_call_index = [&](mlir::Block& b) -> Optional<int> {
if (b.getOperations().size() != 2) return llvm::None;
if (auto call_op = dyn_cast<mlir::CallOp>(b.front()))
return subgraph_index_map_.at(call_op.callee().str());
return llvm::None;
};
auto body_subgraph_index = get_call_index(op.body().front());
auto cond_subgraph_index = get_call_index(op.cond().front());
if (!body_subgraph_index || !cond_subgraph_index)
return op.emitOpError("only single call cond/body while export supported"),
llvm::None;
auto builtin_options =
tflite::CreateWhileOptions(builder_, *cond_subgraph_index,
*body_subgraph_index)
.Union();
auto inputs = builder_.CreateVector(operands);
auto outputs = builder_.CreateVector(results);
return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
tflite::BuiltinOptions_WhileOptions,
builtin_options);
}
template <typename CustomOptionType, typename TFLOp>
BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
const CustomOptionType& custom_option, const std::string& opcode_name,
@ -908,6 +976,16 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands,
results);
}
if (auto whileOp = dyn_cast<mlir::TFL::WhileOp>(inst)) {
if (inst->getNumOperands() != inst->getNumResults()) {
inst->emitOpError(
"number of operands and results don't match, only canonical "
"TFL While supported");
return llvm::None;
}
return BuildWhileOperator(whileOp, operands, results);
}
inst->emitOpError("is not a supported TFLite op");
return llvm::None;
}
@ -944,7 +1022,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
// we emit op as flex.
// if custom is enabled
// we emit the op as custom.
auto node_def = getTensorFlowNodeDef(inst);
auto node_def = GetTensorFlowNodeDef(inst);
if (!node_def) {
return llvm::None;
}
@ -1047,9 +1125,12 @@ bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
return absl::c_find(operand_indices, operand_index) != operand_indices.end();
}
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
const std::string& name, Region* region) {
bool has_input_attr = false;
InitializeNamesFromAttribute(fn, &has_input_attr);
if (auto fn = dyn_cast<FuncOp>(region->getParentOp())) {
InitializeNamesFromAttribute(fn, &has_input_attr);
}
std::vector<BufferOffset<tflite::Tensor>> tensors;
llvm::DenseMap<Value, int> tensor_index_map;
@ -1081,7 +1162,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
};
std::vector<BufferOffset<tflite::Operator>> operators;
auto& bb = fn.getBlocks().front();
auto& bb = region->front();
// Main function's arguments are first passed to `input` op so they don't
// have associated tensor and buffer. Build FlatBuffer tensor and buffer for
@ -1089,7 +1170,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) {
mlir::BlockArgument arg = bb.getArgument(i);
std::string name;
if (has_input_attr) name = name_mapper_.GetUniqueName(arg);
if (has_input_attr) name = std::string(name_mapper_.GetUniqueName(arg));
if (name.empty()) name = absl::StrCat("arg", i);
if (!build_tensor_and_buffer(arg, name)) return llvm::None;
}
@ -1141,7 +1222,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
return tflite::CreateSubGraph(
builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs),
builder_.CreateVector(outputs), builder_.CreateVector(operators),
/*name=*/builder_.CreateString(fn.getName().str()));
/*name=*/builder_.CreateString(name));
}
BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
@ -1184,35 +1265,36 @@ Optional<std::string> Translator::Translate(
}
Optional<std::string> Translator::TranslateInternal() {
// Create a list of functions in the module with main function being the
// first function in the list. This is required as the first subgraph in the
// model is the entry point for the model.
std::vector<FuncOp> functions;
functions.reserve(std::distance(module_.begin(), module_.end()));
// A list of named regions in the module, with the main function first in the
// list. The main function must come first because the first subgraph in the
// model is the entry point for the model.
std::vector<std::pair<std::string, Region*>> named_regions;
named_regions.reserve(std::distance(module_.begin(), module_.end()));
int subgraph_idx = 0;
FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++;
functions.push_back(main_fn);
for (auto fn : module_.getOps<FuncOp>()) {
if (fn == main_fn) continue;
named_regions.emplace_back("main", &main_fn.getBody());
// Walk over the module, collecting functions and while ops.
module_.walk([&](FuncOp fn) {
if (fn != main_fn) {
subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
named_regions.emplace_back(fn.getName().str(), &fn.getBody());
}
});
subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
functions.push_back(fn);
}
// Build subgraph for each of the functions.
// Build subgraph for each of the named regions.
std::vector<BufferOffset<tflite::SubGraph>> subgraphs;
subgraphs.reserve(functions.size());
subgraphs.reserve(named_regions.size());
int first_failed_func = -1;
for (int i = 0; i < functions.size(); ++i) {
auto subgraph_or = BuildSubGraph(functions[i]);
for (auto it : llvm::enumerate(named_regions)) {
auto subgraph_or = BuildSubGraph(it.value().first, it.value().second);
if (!subgraph_or) {
if (first_failed_func == -1)
// Record the index of the first function that cannot be converted.
// Record the index of the first region that cannot be converted.
// Keep looping through all subgraphs in the module to make sure that
// we collect the list of missing ops from the entire module.
first_failed_func = i;
first_failed_func = it.index();
} else {
subgraphs.push_back(*subgraph_or);
}
@ -1233,9 +1315,10 @@ Optional<std::string> Translator::TranslateInternal() {
"-emit-custom-ops flag): " +
failed_custom_ops_list;
return functions[first_failed_func].emitError("failed while converting: '")
<< functions[first_failed_func].getName() << "\'\n"
<< err,
auto& failed_region = named_regions[first_failed_func];
return failed_region.second->getParentOp()->emitError()
<< "failed while converting: '" << failed_region.first
<< "': " << err,
llvm::None;
}

View File

@ -0,0 +1,74 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation interface definition file for TensorFlow Lite.
#ifndef TFL_OP_INTERFACES
#define TFL_OP_INTERFACES
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
// TFL op interface for stateful operands.
def TFL_StatefulOp : OpInterface<"StatefulOpInterface"> {
let description = [{
Interface for ops that are stateful and need to identify stateful operands.
Stateful operands correspond to TF's variables semantics. An op that has 1
or more stateful operands is a stateful op.
}];
let methods = [
InterfaceMethod<
[{Returns the indices of stateful operands.}],
"std::vector<int>", "GetStatefulOperands", (ins)
>,
];
}
//===----------------------------------------------------------------------===//
// TFL op interface for output channel index.
def TFL_ChannelDimIndexInterface : OpInterface<"ChannelDimIndexInterface"> {
let description = [{
Interface for defining the index of the output channel dimension.
}];
let methods = [
InterfaceMethod<
[{Returns the dimension index of the output channels.}],
"int", "GetChannelDimIndex", (ins)
>,
];
}
//===----------------------------------------------------------------------===//
// TFL op interface for sparse operands.
def TFL_SparseOp : OpInterface<"SparseOpInterface"> {
let description = [{
Interface for ops that support sparse computation.
}];
let methods = [
InterfaceMethod<
[{Returns the indices of sparse operands.}],
"std::vector<int>", "GetSparseOperands", (ins)
>,
];
}
#endif // TFL_OP_INTERFACES
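As a rough illustration of how these interfaces are meant to be consumed, the sketch below walks a function and asks each op that implements the stateful interface for its stateful operand indices. This is a minimal, hypothetical example: it assumes the generated C++ class is TFL::StatefulOpInterface and is pulled in via tfl_ops.h, and the helper name CollectStatefulValues is made up for illustration.

#include "llvm/ADT/DenseSet.h"
#include "mlir/IR/Function.h"   // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

// Hypothetical helper: gather every value that some op reports as a stateful
// (variable-like) operand, e.g. so an exporter could mark the backing tensors.
llvm::DenseSet<mlir::Value> CollectStatefulValues(mlir::FuncOp func) {
  llvm::DenseSet<mlir::Value> stateful_values;
  func.walk([&](mlir::Operation* op) {
    // dyn_cast succeeds only for ops that declare the TFL_StatefulOp interface.
    if (auto stateful = llvm::dyn_cast<mlir::TFL::StatefulOpInterface>(op)) {
      for (int index : stateful.GetStatefulOperands())
        stateful_values.insert(op->getOperand(index));
    }
  });
  return stateful_values;
}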


@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
namespace mlir {
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
namespace TFL {
//===----------------------------------------------------------------------===//
@ -797,8 +798,7 @@ struct RemoveAdjacentReshape : public RewritePattern {
// With
// %2 = "tfl.reshape"(%0, %shape1)
rewriter.replaceOpWithNewOp<ReshapeOp>(
{prevOp.getResult()}, op, thisOp.getType(), prevOp.getOperand(0),
thisOp.getOperand(1));
op, thisOp.getType(), prevOp.getOperand(0), thisOp.getOperand(1));
}
};
@ -1302,6 +1302,19 @@ OpFoldResult AbsOp::fold(ArrayRef<Attribute> operands) {
return ConstFoldUnaryOp(result_type, operands[0], compute);
}
//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//
OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
Type result_type = getType();
// Only constant folding for tensors of f32 is implemented.
if (!IsF32ShapedType(result_type)) return nullptr;
auto compute = [](APFloat value) -> APFloat { return llvm::neg(value); };
return ConstFoldUnaryOp(result_type, operands[0], compute);
}
//===----------------------------------------------------------------------===//
// SinOp
//===----------------------------------------------------------------------===//
@ -1724,6 +1737,67 @@ static LogicalResult Verify(TransposeOp op) {
return success();
}
namespace {
struct WhileResultOperandsMatch : public OpRewritePattern<WhileOp> {
using OpRewritePattern<WhileOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(WhileOp while_op,
PatternRewriter &rewriter) const override {
auto size = while_op.body().front().getArguments().size();
Operation *op = while_op.getOperation();
auto old_size = op->getNumResults();
// No change needed as the number of operands matches the number of results.
if (size == old_size) return matchFailure();
// Collect the new result types by combining the old op's result types with
// the types of the additional body arguments.
llvm::SmallVector<Type, 4> types;
types.reserve(size);
for (auto type : while_op.getResultTypes()) types.push_back(type);
for (auto arg : while_op.body().front().getArguments().drop_front(old_size))
types.push_back(arg.getType());
// Collect operands.
llvm::SmallVector<Value, 8> operands;
operands.reserve(while_op.getNumOperands());
for (auto operand : while_op.getOperands()) operands.push_back(operand);
// Replace with new While with matching operands and results.
Operation *new_op = rewriter.insert(
Operation::create(op->getLoc(), op->getName(), types, operands,
op->getAttrs(), {}, /*numRegions=*/2,
/*resizableOperandList=*/true));
for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
rewriter.replaceOp(op,
new_op->getResults().take_front(op->getNumResults()));
return matchSuccess();
}
};
} // namespace
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<WhileResultOperandsMatch>(context);
}
Region &WhileOp::getLoopBody() { return body(); }
bool WhileOp::isDefinedOutsideOfLoop(Value value) {
// TODO(jpienaar): This is overly conservative and initially disables anything
// other than constant hoisting.
return false;
}
LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
if (ops.empty()) return success();
// Move the hoisted value to just before the while.
Operation *while_op = this->getOperation();
for (auto op : ops) op->moveBefore(while_op);
return success();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//


@ -27,10 +27,12 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Transforms/LoopLikeInterface.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mlir {
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.h.inc"
namespace TFL {
class TensorFlowLiteDialect : public Dialect {


@ -19,6 +19,8 @@ limitations under the License.
#define TFL_OPS
include "mlir/IR/OpBase.td"
include "mlir/Transforms/LoopLikeInterface.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
include "tensorflow/compiler/mlir/lite/quantization/quantization.td"
def TFL_Dialect : Dialect {
@ -248,41 +250,6 @@ def TFL_ComparisonBinaryBuilder : OpBuilder<
buildComparisonBinOp(builder, result, lhs, rhs);
}]>;
//===----------------------------------------------------------------------===//
// TFL op interface for stateful operands.
def TFL_StatefulOp : OpInterface<"StatefulOpInterface"> {
let description = [{
Interface for ops that are stateful and need to identify stateful operands.
Stateful operands correspond to TF's variables semantics. An op that has 1
or more stateful operands is a stateful op.
}];
let methods = [
InterfaceMethod<
[{Returns the indices of stateful operands.}],
"std::vector<int>", "GetStatefulOperands", (ins)
>,
];
}
//===----------------------------------------------------------------------===//
// TFL op interface for output channel index.
def TFL_ChannelDimIndexInterface : OpInterface<"ChannelDimIndexInterface"> {
let description = [{
Interface for defining the index of the output channel dimension.
}];
let methods = [
InterfaceMethod<
[{Returns the dimension index of the output channels.}],
"int", "GetChannelDimIndex", (ins)
>,
];
}
//===----------------------------------------------------------------------===//
// TFL op base class.
//===----------------------------------------------------------------------===//
@ -583,14 +550,14 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
let arguments = (
ins Variadic<TensorOf<
[F32, I64, I32, I16, I8, QI8, QUI8, TFL_Uint8]>>:$values,
[F32, I64, I32, I16, I8, QI8, QUI8, QI16, TFL_Uint8]>>:$values,
I32Attr:$axis,
TFL_AFAttr:$fused_activation_function
);
let results = (outs
TensorOf<
[F32, I64, I32, I16, I8, QI8, QUI8, TFL_Uint8]>:$output
[F32, I64, I32, I16, I8, QI8, QUI8, QI16, TFL_Uint8]>:$output
);
let hasOptions = 1;
@ -627,6 +594,57 @@ def TFL_ConstOp : Op<TFL_Dialect, "pseudo_const", [NoSideEffect,
];
}
// Attributes used for encoding sparse tensors.
// Please find a detailed explanation of these parameters in the TFLite schema.
def TFL_DT_Dense : StrEnumAttrCase<"DENSE", 0>;
def TFL_DT_SparseCSR : StrEnumAttrCase<"SPARSE_CSR", 1>;
def TFL_DimensionTypeAttr : StrEnumAttr<
"DimensionType", "dimension type", [TFL_DT_Dense, TFL_DT_SparseCSR]>;
def DimensionMetadataAttr : StructAttr<"DimensionMetadataAttr", TFL_Dialect, [
StructFieldAttr<"format", TFL_DimensionTypeAttr>,
StructFieldAttr<"dense_size", I32Attr>,
StructFieldAttr<"segments", I32ArrayAttr>,
StructFieldAttr<"indices", I32ArrayAttr>] > {
let description = "Dimension metadata.";
}
def DimensionMetadataArrayAttr : TypedArrayAttrBase<DimensionMetadataAttr,
"Array of DimensionMetadata">{}
def SparsityParameterAttr : StructAttr<"SparsityParameterAttr", TFL_Dialect, [
StructFieldAttr<"traversal_order", I32ArrayAttr>,
StructFieldAttr<"block_map", I32ArrayAttr>,
StructFieldAttr<"dim_metadata", DimensionMetadataArrayAttr>]> {
let description = "Sparsity parameter.";
let storageType = [{ TFL::SparsityParameterAttr }];
}
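To make the `segments`/`indices` fields concrete: for a DENSE outer dimension followed by a SPARSE_CSR inner dimension, they carry the usual CSR row-pointer and column-index arrays. The plain C++ sketch below (no TFLite or MLIR APIs involved) builds them for a small dense matrix; see the TFLite schema for the authoritative definition.

#include <cstdio>
#include <vector>

// Standalone sketch: derive CSR-style "segments" (per-row offsets) and
// "indices" (column indices of stored values) from a dense row-major matrix.
int main() {
  const int kRows = 3, kCols = 4;
  const float dense[kRows][kCols] = {
      {1.f, 0.f, 0.f, 2.f},
      {0.f, 0.f, 0.f, 0.f},
      {0.f, 3.f, 0.f, 0.f},
  };
  std::vector<int> segments = {0};  // segments[r+1] - segments[r] = nnz in row r
  std::vector<int> indices;         // column index of each stored value
  std::vector<float> values;        // the stored (non-zero) values
  for (int r = 0; r < kRows; ++r) {
    for (int c = 0; c < kCols; ++c) {
      if (dense[r][c] != 0.f) {
        indices.push_back(c);
        values.push_back(dense[r][c]);
      }
    }
    segments.push_back(static_cast<int>(indices.size()));
  }
  // Here: segments = {0, 2, 2, 3}, indices = {0, 3, 1}, values = {1, 2, 3}.
  std::printf("stored %zu of %d values\n", values.size(), kRows * kCols);
  return 0;
}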
def TFL_SparseConstOp : Op<TFL_Dialect, "pseudo_sparse_const", [NoSideEffect,
FirstAttrDerivedResultType]> {
let summary = "Sparse constant pseudo op.";
let description = [{
Represents a sparse constant value in the TensorFlow Lite dialect. This is not
an actual operation; it will be lowered to a buffer instead.
}];
let arguments = (ins ElementsAttr:$value, SparsityParameterAttr:$s_param);
let results = (outs AnyTensor:$output);
let builders = [OpBuilder<
"Builder *, OperationState &state, Attribute value, "
"SparsityParameterAttr s_param",
[{
state.addTypes(value.getType());
state.addAttribute("value", value);
state.addAttribute("s_param", s_param);
}]>
];
}
def TFL_ExternalConstOp : Op<TFL_Dialect, "external_const", [NoSideEffect]> {
let summary = "External const op.";
@ -685,7 +703,8 @@ def TFL_FullyConnectedOptionsWeightFormatAttr :
def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
TFL_ChannelDimIndexInterface,
AffineOpCoefficient<-1, 1>]> {
AffineOpCoefficient<-1, 1>,
TFL_SparseOp]> {
let summary = "Fully connected op";
let arguments = (ins
@ -710,6 +729,8 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
let extraClassDeclaration = [{
// ChannelDimIndexInterface:
int GetChannelDimIndex() { return 0; }
// SparseOpInterface:
std::vector<int> GetSparseOperands() { return {1}; }
}];
}
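With the declarations above, a pass that already has a handle to a tfl.fully_connected op can call the interface methods directly. The sketch below is a hypothetical helper (the function name and checks are illustrative only); it relies on operand 1 being the filter and dimension 0 of the filter holding the output channels, as declared above.

#include "mlir/IR/Value.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

// Hypothetical helper: report which operands of a fully_connected op a
// sparsification pass would inspect, and which dimension a per-channel
// quantization pass would treat as the output-channel axis.
void InspectFullyConnected(mlir::TFL::FullyConnectedOp fc) {
  // SparseOpInterface: {1} is the filter/weights operand.
  for (int index : fc.GetSparseOperands()) {
    mlir::Value candidate = fc.getOperation()->getOperand(index);
    (void)candidate;  // e.g. test whether it comes from tfl.pseudo_sparse_const
  }
  // ChannelDimIndexInterface: output channels live in dimension 0 of the filter.
  int channel_dim = fc.GetChannelDimIndex();
  (void)channel_dim;
}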
@ -718,7 +739,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
SameOperandsAndResultsScale,
TFL_OperandHasAtleastRank<0, 1>,
PredOpTrait<"params and output must have same element type",
TCresVTEtIsSameAsOp<0, 0>>
TFL_TCresVTEtIsSameAsOp<0, 0>>
]> {
let summary = "Gather operator";
@ -727,7 +748,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
}];
let arguments = (ins
TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8]>:$params,
TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$params,
TensorOf<[I32, I64]>:$indices,
I32Attr:$axis
);
@ -740,7 +761,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
];
let results = (outs
TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8]>:$output
TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$output
);
let hasOptions = 1;
@ -1102,7 +1123,8 @@ def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> {
let hasOptions = 0b1;
}
def TFL_ExpandDimsOp: TFL_Op<"expand_dims", [NoSideEffect]> {
def TFL_ExpandDimsOp: TFL_Op<"expand_dims", [
NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Inserts a dimension of 1 into a tensor's shape.";
let description = [{
@ -1694,7 +1716,8 @@ def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> {
let customOption = "ReducerOptions";
}
def TFL_ReduceMinOp: TFL_Op<"reduce_min", [NoSideEffect]> {
def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Min-reduction operator";
let description = [{
@ -1713,7 +1736,8 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [NoSideEffect]> {
let customOption = "ReducerOptions";
}
def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [NoSideEffect]> {
def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [
NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Max-reduction operator";
let description = [{
@ -1810,6 +1834,8 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
let results = (outs AnyTensor:$y);
let hasOptions = 0b1;
let hasFolder = 1;
}
def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
@ -1843,14 +1869,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
}];
let arguments = (ins
Variadic<TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>>:$values,
Variadic<TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>>:$values,
I32Attr:$values_count,
I32Attr:$axis
);
let results = (outs
TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>:$output
TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$output
);
let verifier = [{ return Verify(*this); }];
@ -2466,7 +2492,7 @@ def TFL_TransposeOp : TFL_Op<"transpose",
let hasFolder = 1;
}
def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect]> {
def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Unpacks a tensor along a dimension into multiple tensors";
let description = [{
@ -2632,12 +2658,12 @@ def TFL_SplitOp : TFL_Op<"split", [
let arguments = (ins
TensorOf<[I32]>:$split_dim,
TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$value,
TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value,
PositiveI32Attr:$num_splits
);
let results = (outs
Variadic<TensorOf<[F32, I16, I32, I64, QI8, QUI8]>>:$outputs
Variadic<TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>>:$outputs
);
let verifier = [{ return Verify(*this); }];
@ -2655,14 +2681,14 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale]
}];
let arguments = (ins
TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$value,
TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value,
1DTensorOf<[I32]>:$size_splits,
0DTensorOf<[I32]>:$split_dim,
PositiveI32Attr:$num_splits
);
let results = (outs
Variadic<TensorOf<[F32, I16, I32, I64, QI8, QUI8]>>:$outputs
Variadic<TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>>:$outputs
);
let verifier = [{ return Verify(*this); }];
@ -2793,12 +2819,11 @@ def TFL_CastOp : TFL_Op<"cast", [
Casts input from input type to output type.
}];
// TODO(b/135538711): Add complex types here.
let arguments = (ins
TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8]>:$input
TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex<F<32>>]>:$input
);
let results = (outs TensorOf<[F32, I1, I32, I64]>:$output);
let results = (outs TensorOf<[F32, I1, I32, I64, Complex<F<32>>]>:$output);
// TFLite's cast op does not utilize CastOptions, instead derives types
// from the TfLiteTensors.
@ -2898,7 +2923,9 @@ def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> {
let arguments = (
ins AnyTensor:$input,
// The expected [min, max] range of values.
MinMaxAttr:$minmax,
F32Attr:$min,
F32Attr:$max,
// The bitwidth of the quantization; between 2 and 16, inclusive.
I32Attr:$num_bits,
// Quantization range starts from 0 or 1; starts from 1 if true.
@ -2907,6 +2934,8 @@ def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> {
let results = (outs AnyTensor:$output);
let hasCanonicalizer = 0b1;
let hasOptions = 1;
}
def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [
@ -2936,6 +2965,36 @@ def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [
];
}
def TFL_SparseQConstOp : Op<TFL_Dialect, "pseudo_sparse_qconst", [
NoSideEffect, FirstAttrDerivedResultType, NoQuantizableResult]> {
let summary = "Sparse quantized constant pseudo op";
let description = [{
Represents a sparse quantized constant value in the TensorFlow Lite dialect.
This is not an actual operation; it will be lowered to a buffer instead.
The quantization parameters are stored as a type attribute in this constant.
}];
let arguments = (
ins TensorTypeAttr:$qtype,
ElementsAttr:$value,
SparsityParameterAttr:$s_param
);
let results = (outs AnyTensor:$output);
let builders = [OpBuilder<
"Builder *, OperationState &state, TypeAttr qtype, "
"Attribute value, SparsityParameterAttr s_param",
[{
state.addTypes(qtype.getValue());
state.addAttribute("qtype", qtype);
state.addAttribute("value", value);
state.addAttribute("s_param", s_param);
}]>
];
}
def TFL_QuantizeOp: TFL_Op<"quantize", [
FirstAttrDerivedResultType, NoQuantizableResult]> {
let summary = "Quantize operator";
@ -3396,4 +3455,48 @@ def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> {
let results = (outs TensorOf<[F32, I32]>:$output);
}
def TFL_YieldOp : Op<TFL_Dialect, "yield", [Terminator]> {
let summary = "Yield operation";
let description = [{
The "yield" operation represents a return operation within the conditional
and body of structured control flow (e.g., while). The operation takes a
variable number of operands and produces no results. The operand number and
types must match the signature of the region that contains the operation.
}];
let arguments = (ins Variadic<AnyType>:$operands);
}
def TFL_WhileOp : Op<TFL_Dialect, "while", [
DeclareOpInterfaceMethods<LoopLikeOpInterface>,
SingleBlockImplicitTerminator<"YieldOp">]> {
let summary = [{While loop}];
let description = [{
output = input; while (cond(output)) { output = body(output) }
While loop where all values are passed through arguments with implicit
capture.
input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A region that takes 'input' and returns a boolean scalar tensor.
body: A region that takes a list of tensors and returns another
list of tensors. Both lists have the same types.
}];
let arguments = (ins
Variadic<AnyTensor>:$input,
// Used to map StatelessWhile and While op defined in TensorFlow to a common
// op.
DefaultValuedAttr<BoolAttr, "false">:$is_stateless
);
let results = (outs Variadic<AnyTensor>:$output);
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
let hasCanonicalizer = 1;
}
#endif // TFL_OPS
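The summary line `output = input; while (cond(output)) { output = body(output) }` can be read literally. The standalone C++ analogue below (no MLIR involved) runs a two-value loop-carried state through a condition and a body, which is exactly the shape of values a tfl.while carries through its cond and body regions; the counter example itself is made up for illustration.

#include <cstdio>
#include <vector>

// Standalone analogue of the tfl.while semantics: the carried values play the
// role of the region arguments; body returns the values for the next iteration.
int main() {
  std::vector<int> state = {0, 10};  // input: {i, n}
  auto cond = [](const std::vector<int>& s) { return s[0] < s[1]; };
  auto body = [](std::vector<int> s) { ++s[0]; return s; };
  while (cond(state)) state = body(state);  // output = body(output)
  std::printf("i = %d\n", state[0]);        // prints: i = 10
  return 0;
}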


@ -41,13 +41,20 @@ limitations under the License.
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/optional_debug_tools.h"
using llvm::cl::desc;
using llvm::cl::init;
using llvm::cl::opt;
// NOLINTNEXTLINE
static opt<std::string> inputFileName(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
static opt<std::string> input_filename(llvm::cl::Positional,
desc("<input file>"), init("-"));
// NOLINTNEXTLINE
static opt<bool> dump_state("dump-interpreter-state",
desc("dump interpreter state post execution"),
init(false));
// TODO(jpienaar): Move these functions to some debug utils.
static std::string TfLiteTensorDimString(const TfLiteTensor& tensor) {
@ -82,9 +89,9 @@ int main(int argc, char** argv) {
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR TFLite runner\n");
auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(inputFileName.c_str());
auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(input_filename.c_str());
if (std::error_code error = file_or_err.getError()) {
LOG(ERROR) << argv[0] << ": could not open input file '" << inputFileName
LOG(ERROR) << argv[0] << ": could not open input file '" << input_filename
<< "': " << error.message() << "\n";
return 1;
}
@ -133,5 +140,7 @@ int main(int argc, char** argv) {
TfLiteTensorString(out).c_str());
}
if (dump_state) tflite::PrintInterpreterState(interpreter.get());
return 0;
}


@ -122,7 +122,7 @@ static void EmitOptionBuilders(const RecordKeeper &record_keeper,
os << formatv(
" auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n",
val.getName(), record->getClasses()[0]->getName());
options.push_back(val.getName());
options.push_back(std::string(val.getName()));
}
}
}


@ -71,18 +71,17 @@ cc_library(
"quantization_utils.cc",
],
hdrs = [
"quantization_traits.h",
"quantization_utils.h",
],
deps = [
"//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
# TODO(fengliuai): remove this dependence.
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/core:lib_proto_parsing",
],
)


@ -12,6 +12,7 @@ package_group(
includes = ["//third_party/mlir:subpackages"],
packages = [
"//learning/brain/experimental/mlir/...",
"//tensorflow/compiler/mlir/lite/...",
"//tensorflow/lite/...",
],
)
@ -41,6 +42,24 @@ cc_library(
],
)
cc_library(
name = "tfl_to_std",
srcs = [
"tfl_to_std.cc",
],
hdrs = [
"tfl_to_std.h",
],
deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
],
)
# Binary to apply quantization on the annotated files.
tf_cc_binary(
name = "tfl_quantizer",


@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {
namespace TFL {
void ConvertTFLQuantOpsToMlirQuantOps(FuncOp func) {
OpBuilder b(func);
func.walk([&](Operation* op) {
b.setInsertionPoint(op);
if (auto dq = llvm::dyn_cast<DequantizeOp>(op)) {
auto dcast = b.create<quant::DequantizeCastOp>(
dq.getLoc(), dq.output().getType(), dq.input());
dq.output().replaceAllUsesWith(dcast);
dq.erase();
} else if (auto q = llvm::dyn_cast<QuantizeOp>(op)) {
auto qcast = b.create<quant::QuantizeCastOp>(
q.getLoc(), q.output().getType(), q.input());
q.output().replaceAllUsesWith(qcast);
q.erase();
}
});
}
void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func) {
OpBuilder b(func);
func.walk([&](Operation* op) {
b.setInsertionPoint(op);
if (auto dq = llvm::dyn_cast<quant::DequantizeCastOp>(op)) {
auto dcast = b.create<DequantizeOp>(dq.getLoc(), dq.getResult().getType(),
dq.arg());
dq.getResult().replaceAllUsesWith(dcast);
dq.erase();
} else if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(op)) {
auto out_type = q.getResult().getType();
auto qcast = b.create<QuantizeOp>(q.getLoc(), out_type, q.arg(),
TypeAttr::get(out_type));
q.getResult().replaceAllUsesWith(qcast);
q.erase();
}
});
}
} // namespace TFL
} // namespace mlir


@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
#include "mlir/IR/Function.h" // TF:llvm-project
namespace mlir {
namespace TFL {
// Converts all the tfl.quantize/tfl.dequantize ops in the function to the
// corresponding ops in the mlir.quant dialect.
void ConvertTFLQuantOpsToMlirQuantOps(FuncOp func);
// Converts all the mlir.quant dialect ops to the tfl.quantize/tfl.dequantize
// ops in the function.
void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func);
} // namespace TFL
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
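A typical (hypothetical) use of this pair is to round-trip a function through the mlir.quant form so that dialect-agnostic quantization logic only ever sees quant.qcast/quant.dcast. The pass below is a sketch: the pass name is made up for illustration, and it assumes the FunctionPass base from mlir/Pass/Pass.h.

#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"

namespace mlir {
namespace TFL {
namespace {

// Hypothetical pass: rewrite tfl.quantize/tfl.dequantize into the mlir.quant
// cast ops, run logic that only understands the quant dialect, then restore
// the TFL ops so the rest of the TFLite pipeline is unaffected.
struct ExampleQuantRoundTripPass
    : public FunctionPass<ExampleQuantRoundTripPass> {
  void runOnFunction() override {
    FuncOp func = getFunction();
    ConvertTFLQuantOpsToMlirQuantOps(func);
    // ... patterns/analyses on quant.qcast / quant.dcast go here ...
    ConvertMlirQuantOpsToTFLQuantOps(func);
  }
};

}  // namespace
}  // namespace TFL
}  // namespace mlir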


@ -22,21 +22,6 @@ limitations under the License.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/QuantOps/QuantPredicates.td"
//===----------------------------------------------------------------------===//
// Min-max range pair definitions.
//===----------------------------------------------------------------------===//
// A pair of floating point values which defines the min and max of a value
// range for quantization. The attribute is allowed to be empty or
// have 2 elements.
def MinMaxAttr : Attr<Or<[CPred<"$_self.cast<ArrayAttr>().size() == 0">,
CPred<"$_self.cast<ArrayAttr>().size() == 2">]>,
"min-max range pair"> {
let storageType = [{ ArrayAttr }];
let returnType = [{ ArrayRef<Attribute> }];
}
//===----------------------------------------------------------------------===//
// QuantizedType definitions.
//===----------------------------------------------------------------------===//


@ -23,6 +23,8 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
@ -34,13 +36,14 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"
#define DEBUG_TYPE "quantization-driver"
namespace mlir {
namespace TFL {
namespace quant {
namespace {
static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }
@ -281,6 +284,37 @@ class QuantizationDriver {
cached.first->second = InitializeState(op, index, res, /*as_result=*/true);
}
void DumpStates(Operation *current_op) {
if (current_op) {
llvm::errs() << "\n\n\n" << current_op->getName() << "\n";
}
fn_.walk([&](Operation *op) {
if (llvm::isa<quant::QuantizeCastOp>(op) ||
llvm::isa<quant::DequantizeCastOp>(op) || llvm::isa<ConstantOp>(op))
return;
if (current_op == op) llvm::errs() << "===>>>";
llvm::errs() << op->getName() << " : (";
for (auto i = 0; i < op->getNumOperands(); ++i) {
if (auto params = GetOperandQuantState(op, i).params)
params.print(llvm::errs());
else
op->getOperand(i).getType().cast<ShapedType>().getElementType().print(
llvm::errs());
llvm::errs() << ",";
}
llvm::errs() << ") -> (";
for (auto i = 0; i < op->getNumResults(); ++i) {
if (auto params = GetResultQuantState(op, i).params)
params.print(llvm::errs());
else
op->getResult(i).getType().cast<ShapedType>().getElementType().print(
llvm::errs());
llvm::errs() << ",";
}
llvm::errs() << ")\n";
});
}
FuncOp fn_;
OpBuilder builder_;
bool is_signed_;
@ -350,7 +384,7 @@ int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
}
bool QuantizationDriver::SetConstantResultParams(Operation *op) {
ElementsAttr attr;
DenseFPElementsAttr attr;
Value res = op->getResult(0);
if (!matchPattern(res, m_Constant(&attr))) {
return false;
@ -457,11 +491,9 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
// This value isn't an expressed type (float), skip.
if (!new_type) return;
TypeAttr type_attr = TypeAttr::get(new_type);
auto quantize =
builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
auto dequantize = builder_.create<TFL::DequantizeOp>(loc, expressed_type,
quantize.output());
auto quantize = builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
auto dequantize = builder_.create<quant::DequantizeCastOp>(
loc, expressed_type, quantize.getResult());
// `original_result` is used by `quantize`, so this will replace that use
// with the result of `dequantize`. Remember to reset that use afterwards
value.replaceAllUsesWith(dequantize);
@ -475,7 +507,7 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
Value value = op->getResult(index);
if (state->pos == RequantizeState::ON_OUTPUT) {
Operation *user = value.getUses().begin().getUser();
if (llvm::isa<TFL::QuantizeOp>(user)) {
if (llvm::isa<quant::QuantizeCastOp>(user)) {
// The requantize op is inserted between `quantize` and `dequantize` ops.
value = user->getResult(0);
builder_.setInsertionPointAfter(user);
@ -490,8 +522,8 @@ void QuantizationDriver::RequantizeArg(BlockArgument arg,
builder_.setInsertionPointToStart(arg.getOwner());
if (value.hasOneUse()) {
auto user = value.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
value = q.output();
if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
value = q.getResult();
builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
}
}
@ -518,9 +550,8 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
// This value isn't an expressed type (float), skip.
if (!new_type) return;
TypeAttr type_attr = TypeAttr::get(new_type);
auto requantize_op =
builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
value.replaceAllUsesWith(requantize_op);
requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
}
@ -650,8 +681,8 @@ void QuantizationDriver::SetupAllStates() {
// If the argument is quantized, it should only have one user.
if (arg.hasOneUse()) {
auto user = value.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
value = q.output();
if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
value = q.getResult();
}
}
InitializeArgState(arg, value, &value_to_state);
@ -659,7 +690,9 @@ void QuantizationDriver::SetupAllStates() {
fn_.walk([&](Operation *op) {
if (op->isKnownTerminator() ||
op->hasTrait<OpTrait::quant::NoQuantizableResult>())
op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
llvm::isa<quant::DequantizeCastOp>(op) ||
llvm::isa<quant::QuantizeCastOp>(op))
return;
work_list_.push_back(op);
@ -668,8 +701,8 @@ void QuantizationDriver::SetupAllStates() {
if (auto *inst = operand.getDefiningOp()) {
// If the operand comes from a tfl.dequantize op, we use the quantized
// input of this tfl.dequantize op to set the state.
if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
operand = dq.input();
if (auto dq = llvm::dyn_cast<quant::DequantizeCastOp>(inst)) {
operand = dq.arg();
}
}
InitializeOperandState(op, i, operand, &value_to_state);
@ -682,8 +715,8 @@ void QuantizationDriver::SetupAllStates() {
// create the state and mark it immutable.
if (result.hasOneUse()) {
auto user = result.use_begin().getUser();
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
result = q.output();
if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
result = q.getResult();
}
}
InitializeResultState(op, res, result, &value_to_state);
@ -713,6 +746,8 @@ bool QuantizationDriver::PropagateParams() {
Operation *op = work_list_.back();
work_list_.pop_back();
LLVM_DEBUG(DumpStates(op));
// This op has been quantized, so we should not consider it again.
if (llvm::is_contained(quantized_, op)) continue;
quantized_.insert(op);
@ -737,12 +772,23 @@ bool QuantizationDriver::PropagateParams() {
}
// Use the final state to set all the operands' parameters.
for (int i = 0, e = op->getNumOperands(); i != e; ++i)
changed |= SetOperandParams(op, i, params);
for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
if (auto type = op->getOperand(i).getType().dyn_cast<ShapedType>()) {
// Without this check, it will accidentally propagate the quantization
// information through the shared non-float tensors.
if (type.getElementType().isa<FloatType>())
changed |= SetOperandParams(op, i, params);
}
}
// Use the final state to set all the results' parameters.
for (int res = 0, e = op->getNumResults(); res != e; ++res)
changed |= SetResultParams(op, res, params);
if (auto type = op->getResult(res).getType().dyn_cast<ShapedType>()) {
// Without this check, it will accidentally propagate the quantization
// information through the shared non-float tensors.
if (type.getElementType().isa<FloatType>())
changed |= SetResultParams(op, res, params);
}
}
// TODO(fengliuai): make the bit width configurable.
@ -821,5 +867,5 @@ void ApplyQuantizationParamsPropagation(
.Run();
}
} // namespace TFL
} // namespace quant
} // namespace mlir

View File

@ -70,7 +70,8 @@ class FixedResultUniformScale {
QuantizedType GetResultQuantizedType(int index) {
auto op = this->getOperation();
auto result_type =
op->getResult(index).getType().template cast<TensorType>();
op->getResult(index).getType().template cast<ShapedType>();
if (!result_type.getElementType().template isa<FloatType>()) return {};
Builder builder(op->getContext());
IntegerType storage_type = builder.getIntegerType(BitWidth);
const double scale = static_cast<double>(ScaleMantissa) *


@ -30,10 +30,9 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
namespace mlir {
namespace TFL {
namespace quant {
const float kNearZeroTolerance = 1.0e-6;
@ -400,7 +399,7 @@ static bool PreferResultScale(Operation* op) {
for (auto operand : op->getOperands()) {
if (auto operand_type = operand.getType().dyn_cast<ShapedType>()) {
if (operand_type.getElementType().isa<FloatType>()) {
if (float_operands++ > 1) return true;
if (++float_operands > 1) return true;
}
}
}
@ -460,7 +459,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
}
// Step 2: backward pass: For the ops skipped in the forward pass, propagate
// their result scale backwards.
// their result scale backwards as far as possible.
func.walk([&](quant::StatisticsOp stats_op) {
if (redundant_stats_ops.find(stats_op) == redundant_stats_ops.end()) {
all_stats_ops.push_back(stats_op);
@ -472,8 +471,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
all_stats_ops.pop_back();
if (auto def = stats_op.arg().getDefiningOp()) {
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
PreferResultScale(def)) {
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>()) {
for (auto input : def->getOperands()) {
if (auto next_stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(
input.getDefiningOp())) {
@ -496,5 +494,5 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
// Returns false if the steps finish without errors.
return false;
}
} // namespace TFL
} // namespace quant
} // namespace mlir


@ -38,7 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
namespace mlir {
namespace TFL {
namespace quant {
using QuantParams = quant::QuantizedType;
using SignedInteger = std::pair<unsigned, unsigned>; // bitwidth and sign
@ -113,8 +113,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
rewriter.setInsertionPointAfter(op);
Type result_type = quant_type.castFromExpressedType(op.getType());
auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg(),
TypeAttr::get(result_type));
auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg());
auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
op.getResult().replaceAllUsesWith(dq);
q.getOperation()->replaceUsesOfWith(dq, op.arg());
@ -168,9 +167,12 @@ struct QuantizationPattern : public RewritePattern {
return matchFailure();
}
// If it is a terminator or not quantizable, we shouldn't rewrite.
// If it is a terminator, not quantizable, or an op from the mlir quant
// ops dialect, we shouldn't rewrite.
if (quantized_op->isKnownTerminator() ||
quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
llvm::isa<quant::QuantizeCastOp>(quantized_op) ||
llvm::isa<quant::DequantizeCastOp>(quantized_op)) {
return matchFailure();
}
@ -316,7 +318,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
PatternMatchResult matchAndRewrite(Q op,
PatternRewriter& rewriter) const override {
Type output_type = op.output().getType();
Type output_type = op.getResult().getType();
auto qtype = QType::getQuantizedElementType(output_type);
if (!qtype || qtype.isSigned()) return this->matchFailure();
@ -355,8 +357,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
if (!new_qtype) return this->matchFailure();
Type new_output_type = new_qtype.castFromExpressedType(
QType::castToExpressedType(output_type));
rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.input(),
TypeAttr::get(new_output_type));
rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.arg());
return this->matchSuccess();
}
};
@ -444,7 +445,7 @@ void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
bool RemoveRedundantStatsOps(mlir::FuncOp func,
OpQuantSpecGetter op_quant_spec_getter);
} // namespace TFL
} // namespace quant
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_


@ -0,0 +1,36 @@
package(
default_visibility = [
":friends",
],
licenses = ["notice"], # Apache 2.0
)
package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = [
"//tensorflow/compiler/mlir/...",
"//tensorflow/compiler/mlir/lite/...",
],
)
cc_library(
name = "tf_to_quant",
srcs = [
"tf_to_quant.cc",
],
hdrs = [
"passes.h",
],
deps = [
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
],
alwayslink = 1,
)

View File

@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_
#include <memory>
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
namespace mlir {
namespace TF {
// Legalize the tf ops to the quant ops, so the quantization passes can work.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass();
} // namespace TF
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_


@ -0,0 +1,19 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(licenses = ["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["mlir"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir:tf-opt",
"@llvm-project//llvm:FileCheck",
],
)


@ -0,0 +1,148 @@
// RUN: tf-opt -tf-to-quant %s | FileCheck %s
// CHECK-LABEL: fakeQuantPerChannelForActivation
func @fakeQuantPerChannelForActivation(%arg0: tensor<8x3xf32>) -> (tensor<8x3xf32>) {
%arg1 = constant dense<[0.0, -1.0, 1.0]> : tensor<3xf32>
%arg2 = constant dense<[255.0, 254.0, 256.0]> : tensor<3xf32>
%0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<8x3xf32>
return %0 : tensor<8x3xf32>
// CHECK: %[[fq:.*]] = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %cst, %cst_0)
// CHECK: %[[q:.*]] = "quant.qcast"(%[[fq]]) : (tensor<8x3xf32>) -> tensor<8x3x!quant.uniform<i8:f32:1, {1.000000e+00:-128,1.000000e+00:-127,1.000000e+00:-128}>>
// CHECK: %[[dq:.*]] = "quant.dcast"(%[[q]])
// CHECK: return %[[dq]]
}
// CHECK-LABEL: fakeQuantForActivation
func @fakeQuantForActivation(tensor<8xf32>) -> (tensor<8xf32>) {
^bb0(%arg0: tensor<8xf32>):
%arg1 = constant dense<0.0> : tensor<f32>
%arg2 = constant dense<255.0> : tensor<f32>
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
return %0 : tensor<8xf32>
// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0)
// CHECK: %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %2 = "quant.dcast"(%1)
// CHECK: return %2
}
// CHECK-LABEL: fakeQuantForActivationNoDuplication
func @fakeQuantForActivationNoDuplication(tensor<8xf32>) -> (tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>) {
^bb0(%arg0: tensor<8xf32>):
%arg1 = constant dense<0.0> : tensor<f32>
%arg2 = constant dense<255.0> : tensor<f32>
%0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
%1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
return %1 : tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64}
// CHECK: %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: return %1
}
// CHECK-LABEL: fakeQuantFolded
func @fakeQuantFolded() -> (tensor<8xf32>) {
%in = constant dense<0.0> : tensor<8xf32>
%min = constant dense<0.0> : tensor<f32>
%max = constant dense<255.0> : tensor<f32>
%mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
%maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
%rst = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
return %rst : tensor<8xf32>
// CHECK: %[[CONSTANT:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT]]) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: return %[[DEQUANTIZE]] : tensor<8xf32>
}
// CHECK-LABEL: fakeQuantNotFolded
func @fakeQuantNotFolded(tensor<8xf32>, tensor<f32>, tensor<f32>) -> (tensor<8xf32>) {
^bb0(%arg0: tensor<8xf32>, %arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg3, %arg4) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
return %1 : tensor<8xf32>
// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2)
// CHECK: return %0 : tensor<8xf32>
}
// CHECK-LABEL: fakeQuantWithConv2D
func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
%in = constant dense<0.0> : tensor<3x3x3x16xf32>
%min = constant dense<0.0> : tensor<f32>
%max = constant dense<255.0> : tensor<f32>
%mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
%maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
%fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
%rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
return %rst : tensor<256x30x30x16xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]]
}
// CHECK-LABEL: perChannelFakeQuantWithConv2D
func @perChannelFakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
%in = constant dense<0.0> : tensor<3x3x3x16xf32>
%min = constant dense<0.0> : tensor<16xf32>
%max = constant dense<255.0> : tensor<16xf32>
%mini = "tf.Identity"(%min) : (tensor<16xf32>) -> tensor<16xf32>
%maxi = "tf.Identity"(%max) : (tensor<16xf32>) -> tensor<16xf32>
%fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<3x3x3x16xf32>
%rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
return %rst : tensor<256x30x30x16xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
// CHECK-SAME: {1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,
// CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]] : tensor<256x30x30x16xf32>
}
// CHECK-LABEL: fakeQuantWithDepthwiseConv2D
func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
%in = constant dense<0.0> : tensor<3x3x3x16xf32>
%min = constant dense<0.0> : tensor<f32>
%max = constant dense<255.0> : tensor<f32>
%mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
%maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
%fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
%rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
return %rst : tensor<256x30x30x16xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]]
}
// CHECK-LABEL: perChannelFakeQuantWithDepthwiseConv2D
func @perChannelFakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
%in = constant dense<0.0> : tensor<3x3x3x16xf32>
%min = constant dense<0.0> : tensor<16xf32>
%max = constant dense<255.0> : tensor<16xf32>
%mini = "tf.Identity"(%min) : (tensor<16xf32>) -> tensor<16xf32>
%maxi = "tf.Identity"(%max) : (tensor<16xf32>) -> tensor<16xf32>
%fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<3x3x3x16xf32>
%rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
return %rst : tensor<256x30x30x16xf32>
// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
// CHECK-SAME: {1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,
// CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]]
}


@ -0,0 +1,162 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TF {
//===----------------------------------------------------------------------===//
// The pass to legalize the quantization emulation ops from TF.
//
namespace {
// Legalize TF quantization emulation ops to those in the Quant ops dialect.
struct LegalizeTFToQuant : public FunctionPass<LegalizeTFToQuant> {
explicit LegalizeTFToQuant() = default;
LegalizeTFToQuant(const LegalizeTFToQuant &) {}
/// Performs the lowering to Quant ops dialect.
void runOnFunction() override;
};
// TODO(fengliuai): move this rule to PreparePatterns.td
// TODO(b/140968741): propagate the sign from the command line. Currently all
// the FakeQuant ops are assumed to target UINT8, but the per-channel kernel is
// actually INT8.
// Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the
// "tf.FakeQuantWithMinMaxVarsOp" to be constant folded. Since the constant
// folding logic will use a "std.constant" op to replace the
// "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve
// the quantization parameters as a TypeAttr, and the "tfl.dequantize" op is
// used to convert the output type to that of the next op. Here are the transformations:
//
// input min cst max cst input min cst max cst
// \ | | \ | |
// \ (tf.Identity) (tf.Identity) => \ (tf.Identity) (tf.Identity)
// \ | | \ | |
// tf.FakeQuantWithMinMaxVars tf.FakeQuantWithMinMaxVars
// | |
// tf.quantize
// |
// tf.dequantize
// |
// If the input is a constant, the result pattern will eventually be converted to
//
// quant-emulated input
// |
// tf.quantize
// |
// tf.dequantize
// |
template <typename TFFakeQuantOp, bool PerAxis>
struct InsertQuantOpsAfterTFFakeQuantOp
: public OpRewritePattern<TFFakeQuantOp> {
using BaseType = InsertQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>;
explicit InsertQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>(
MLIRContext *ctx)
: OpRewritePattern<TFFakeQuantOp>(ctx) {}
PatternMatchResult matchAndRewrite(TFFakeQuantOp tf_op,
PatternRewriter &rewriter) const override {
// We don't want to insert quantize/dequantize if the quantize op exists.
auto res = tf_op.outputs();
if (!res.hasOneUse() || isa<quant::QuantizeCastOp>(*res.user_begin()))
return this->matchFailure();
// Extract the min/max constant values from the operands. We also consider
// a special case where there are tf.Identity ops between the min/max
// constants and the tf.FakeQuantWithMinMaxVarsOp.
Value min = tf_op.min(), max = tf_op.max();
DenseFPElementsAttr min_value, max_value;
if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp())) {
id1.replaceAllUsesWith(id1.input());
min = tf_op.min();
rewriter.eraseOp(id1);
}
if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp())) {
id2.replaceAllUsesWith(id2.input());
max = tf_op.max();
rewriter.eraseOp(id2);
}
if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();
int quant_dim = -1;
if (PerAxis) {
// This is a special case where the quant_dim is the last dimension,
// per the semantics of tf.FakeQuantWithMinMaxVarsPerChannel.
quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
}
// Use the min/max from the operands and the num_bits and narrow_range
// attribute to create the quantization parameter for the new quantize op.
rewriter.setInsertionPointAfter(tf_op);
IntegerAttr num_bits =
rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue());
BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
Type res_type = tf_op.getType();
TypeAttr qtype = quant::GetQuantizedTypeAttr(
rewriter, res_type, min_value, max_value, quant_dim, num_bits,
narrow_range, /*is_signed=*/true);
if (!qtype) return this->matchFailure();
// Finally, use the quantization parameter to create the quantize and
// dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
// and its users.
Value value = tf_op.outputs();
auto quantize = rewriter.create<quant::QuantizeCastOp>(
tf_op.getLoc(), qtype.getValue(), value);
auto dequantize = rewriter.create<quant::DequantizeCastOp>(
tf_op.getLoc(), res_type, quantize.getResult());
value.replaceAllUsesWith(dequantize);
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
return this->matchSuccess();
}
};
using PreparePerTensorFakeQuant =
InsertQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsOp, false>;
using PreparePerChannelFakeQuant =
InsertQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsPerChannelOp,
true>;
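For reference, the quantization parameters that quant::GetQuantizedTypeAttr derives from the min/max pair behave roughly like the conventional affine-quantization arithmetic sketched below. This is a standalone illustration of the usual formula, not the library's exact code (it omits range nudging and other corner cases); with min=0, max=255 and 8 signed bits it reproduces the scale 1.0 and zero point -128 seen in the tf-to-quant tests earlier in this diff.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Standalone sketch of asymmetric (affine) quantization parameters derived
// from a [min, max] range; zero-range handling and nudging are omitted.
int main() {
  const double min = 0.0, max = 255.0;  // range taken from the FakeQuant op
  const int num_bits = 8;
  const bool narrow_range = false;      // if true, drop the lowest bucket
  const bool is_signed = true;

  const int64_t qmin =
      (is_signed ? -(int64_t{1} << (num_bits - 1)) : 0) + (narrow_range ? 1 : 0);
  const int64_t qmax =
      is_signed ? (int64_t{1} << (num_bits - 1)) - 1 : (int64_t{1} << num_bits) - 1;
  const double scale = (max - min) / static_cast<double>(qmax - qmin);
  const int64_t zero_point =
      qmin - static_cast<int64_t>(std::llround(min / scale));

  // Prints: scale = 1, zero_point = -128, i.e. !quant.uniform<i8:f32, 1.0:-128>.
  std::printf("scale = %g, zero_point = %lld\n", scale,
              static_cast<long long>(zero_point));
  return 0;
}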
// TODO(fengliuai): add support for the tf.QuantizeAndDequantize*
// legalization.
void LegalizeTFToQuant::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
auto *ctx = func.getContext();
patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
applyPatternsGreedily(func, patterns);
}
} // namespace
// Creates an instance of the TensorFlow dialect to QuantOps dialect pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass() {
return std::make_unique<LegalizeTFToQuant>();
}
static PassRegistration<LegalizeTFToQuant> pass(
"tf-to-quant", "Legalize TF to quant ops dialect");
} // namespace TF
} // namespace mlir


@ -46,9 +46,9 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
llvm::sort(defs, LessRecord());
OUT(0) << "static std::unique_ptr<OpQuantSpec> "
OUT(0) << "static std::unique_ptr<quant::OpQuantSpec> "
"GetOpQuantSpec(mlir::Operation *op) {\n";
OUT(2) << "auto spec = absl::make_unique<OpQuantSpec>();\n";
OUT(2) << "auto spec = absl::make_unique<quant::OpQuantSpec>();\n";
llvm::SmallVector<llvm::StringRef, 3> matches;
for (auto *def : defs) {
Operator op(def);
@ -74,7 +74,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
if (acc_uniform_trait_regex.match(trait_str, &matches)) {
OUT(4) << "spec->biases_params.emplace(std::make_pair(" << matches[1]
<< ", std::make_pair(tfl.GetAllNonBiasOperands(),"
<< "GetUniformQuantizedTypeForBias)));\n";
<< "quant::GetUniformQuantizedTypeForBias)));\n";
matches.clear();
}
// If there is a "QuantChannelDim" trait, set the quantization dimension.


@ -0,0 +1,40 @@
package(
default_visibility = [
":friends",
],
licenses = ["notice"], # Apache 2.0
)
package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = [
"//tensorflow/compiler/mlir/...",
"//tensorflow/compiler/mlir/lite/...",
],
)
cc_library(
name = "hlo_xla_quantization_passes",
srcs = [
"materialize.cc",
"op_quant_spec.inc",
"propagate.cc",
],
hdrs = [
"passes.h",
],
deps = [
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/xla/client/lib:quantize",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
],
alwayslink = 1,
)

View File

@ -0,0 +1,174 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass quantizes constants and rewrites the quantization
// ops with xla_hlo primitive ops.
#include <cstdint>
#include <iterator>
#include <numeric>
#include <string>
#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"
//===----------------------------------------------------------------------===//
// The pass to materialize the quantization results with xla primitive ops.
//
namespace mlir {
namespace xla_hlo {
namespace {
// This pattern matches the "constant->qcast->dcast" pattern and replaces it
// with "quantized constant->xla_hlo.dequantize". If it only matches the
// "non-constant->qcast->dcast" pattern, it removes both the qcast and dcast ops.
// We chain the pattern as a whole to bypass the type checks of the normal
// xla_hlo ops.
// TODO(fengliuai): make this pass work for bf16 input.
class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
public:
explicit RewriteDequantize(int64_t size, MLIRContext *context)
: OpRewritePattern<quant::DequantizeCastOp>(context), size_(size) {}
PatternMatchResult matchAndRewrite(quant::DequantizeCastOp op,
PatternRewriter &rewriter) const override {
// quant.dcast
// xla_hlo dequantize only takes min/max, so let's recover them from
// the quantization parameters.
Value dcast = op.arg();
auto type = quant::QuantizedType::getQuantizedElementType(dcast.getType());
if (!type || !type.isa<quant::UniformQuantizedType>()) {
return matchFailure();
}
auto qtype = type.cast<quant::UniformQuantizedType>();
double scale = qtype.getScale();
int64_t zero_point = qtype.getZeroPoint();
float min = scale * (qtype.getStorageTypeMin() - zero_point);
float max = scale * (qtype.getStorageTypeMax() - zero_point);
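// Worked example (this is the case exercised by the @quantize_rewrite test in
// this change): for !quant.uniform<u8:f32, 0.0078431372549019607:128> the
// storage range is [0, 255], so min = 0.0078431... * (0 - 128) ~= -1.00392163
// and max = 0.0078431... * (255 - 128) ~= 0.996078431.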
// quant.qcast
auto qcast =
llvm::dyn_cast_or_null<quant::QuantizeCastOp>(dcast.getDefiningOp());
if (!qcast) return matchFailure();
// constant
DenseFPElementsAttr attr;
// If it isn't a floating-point constant or the size is too small, remove
// the quantization. The last dimension size must also be a multiple of 4
// so the shape isn't broken during packing and unpacking.
if (!matchPattern(qcast.arg(), m_Constant(&attr)) ||
attr.getNumElements() <= size_ ||
attr.getType().getDimSize(attr.getType().getRank() - 1) % 4 != 0) {
op.getResult().replaceAllUsesWith(qcast.arg());
return matchSuccess();
}
// TODO(fengliuai): implement transpose if it has high dimension.
// Create the quantized result
auto quantized_result =
quant::Quantize(attr, qtype).dyn_cast_or_null<DenseIntElementsAttr>();
if (!quantized_result) {
return matchFailure();
}
// Pack the uint8 values into uint32 words. The shape is changed from
// [n0, n1, ..., nk] to [n0, n1, ..., nk / 4].
std::vector<uint8_t> raw_data;
for (auto d : quantized_result.getValues<uint8_t>()) {
raw_data.push_back(d);
}
// Packing might increase the data size due to padding.
auto packed_data = xla::PackToUint32<uint8_t>(raw_data);
auto packed_shape = attr.getType().getShape().vec();
int lower_dims = std::accumulate(
packed_shape.begin(),
std::next(packed_shape.begin(), packed_shape.size() - 1), 1,
std::multiplies<int>());
packed_shape[packed_shape.size() - 1] = packed_data.size() / lower_dims;
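// Worked example: a 2x4 uint8 tensor packs into 8 bytes, i.e. two uint32
// words; with lower_dims = 2 the packed last dimension becomes 1, giving the
// tensor<2x1xi32> constant checked in the @quantize_rewrite test in this change.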
auto packed_type =
RankedTensorType::get(packed_shape, rewriter.getIntegerType(32));
auto packed_quantized_result =
DenseElementsAttr::get<uint32_t>(packed_type, packed_data);
auto quantized_constant =
rewriter.create<ConstantOp>(qcast.getLoc(), packed_quantized_result);
// Create the xla dequantize op with bf16 output
auto dequantized_type = RankedTensorType::get(attr.getType().getShape(),
rewriter.getBF16Type());
auto dequantize = rewriter.create<DequantizeOp>(
qcast.getLoc(), dequantized_type, quantized_constant,
rewriter.getF32FloatAttr(min), rewriter.getF32FloatAttr(max),
rewriter.getStringAttr("MIN_COMBINED"), rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
// Convert bf16 output back to f32
rewriter.replaceOpWithNewOp<ConvertOp>(op, op.getResult().getType(),
dequantize);
return matchSuccess();
}
private:
int64_t size_;
};
// Materialize the quantization results with hlo primitive ops.
struct MaterializeToXlaPass : public FunctionPass<MaterializeToXlaPass> {
explicit MaterializeToXlaPass() = default;
MaterializeToXlaPass(const MaterializeToXlaPass &) {}
void runOnFunction() override;
};
void MaterializeToXlaPass::runOnFunction() {
FuncOp func = getFunction();
MLIRContext *ctx = &getContext();
OwningRewritePatternList patterns;
// TODO(fengliuai): make the size 6 configurable.
patterns.insert<RewriteDequantize>(6, ctx);
applyPatternsGreedily(func, patterns);
}
} // namespace
// Creates an instance of the xla_hlo dialect quantization materialization pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializeToXlaPass() {
return std::make_unique<MaterializeToXlaPass>();
}
static PassRegistration<MaterializeToXlaPass> pass(
"xla-hlo-materialize-quant",
"Materialize the quantization results by xla primitve ops");
} // namespace xla_hlo
} // namespace mlir

View File

@ -0,0 +1,7 @@
// TODO(fengliuai): automatically generate this file
// TODO(fengliuai): add all the xla_hlo ops
static std::unique_ptr<quant::OpQuantSpec> GetOpQuantSpec(mlir::Operation *op) {
auto spec = absl::make_unique<quant::OpQuantSpec>();
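  // A generated entry would presumably mirror the TFLite tblgen output, e.g.
  // for a hypothetical op whose operand 2 is a bias accumulated over operands
  // 0 and 1:
  //   spec->biases_params.emplace(std::make_pair(
  //       2, std::make_pair(std::vector<int>({0, 1}),
  //                         quant::GetUniformQuantizedTypeForBias)));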
return spec;
}

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_
#include <memory>
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
namespace mlir {
namespace xla_hlo {
// Propagate the quantization information to all the tensors according to the
// op quant spec.
std::unique_ptr<OpPassBase<FuncOp>> CreatePropagateQuantPass();
// Rewrites the graph and quantizes the constants.
std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializeToXlaPass();
} // namespace xla_hlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_

View File

@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass applies quantization propagation on the xla_hlo dialect.
#include <iterator>
#include <string>
#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
// NOLINTNEXTLINE
static llvm::cl::opt<bool> disable_per_channel(
"xla-disable-per-channel", llvm::cl::value_desc("bool"),
llvm::cl::desc("Whether disable per-channel quantized weights."),
llvm::cl::init(false));
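// Example invocation (a sketch, combined with the pass registered below):
//   tf-opt -xla-hlo-propagate-quant -xla-disable-per-channel input.mlir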
//===----------------------------------------------------------------------===//
// The quantization propagation Pass.
//
namespace mlir {
namespace xla_hlo {
namespace {
// Applies quantization propagation to the input function. The propagation
// respects two constraints:
// - the quantization types (params) already present on ops in the function;
// - the quantization spec for each op.
// The result should assign a quantization type to every tensor while keeping
// both constraints satisfied.
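// For example (taken from the @mul propagation test in this change), a float
// constant feeding xla_hlo.mul:
//   %w = constant dense<...> : tensor<2x2xf32>
//   %mul = xla_hlo.mul %arg0, %w : tensor<2x2xf32>
// is annotated as
//   %q = "quant.qcast"(%w) : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, ...>>
//   %dq = "quant.dcast"(%q) : (tensor<2x2x!quant.uniform<u8:f32, ...>>) -> tensor<2x2xf32>
//   %mul = xla_hlo.mul %arg0, %dq : tensor<2x2xf32>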
struct PropagateQuantPass : public FunctionPass<PropagateQuantPass> {
explicit PropagateQuantPass() = default;
PropagateQuantPass(const PropagateQuantPass &) {}
void runOnFunction() override;
};
#include "tensorflow/compiler/mlir/lite/quantization/xla/op_quant_spec.inc"
void PropagateQuantPass::runOnFunction() {
FuncOp func = getFunction();
// XLA only supports uint8/uint16 quantization for now.
ApplyQuantizationParamsPropagation(func, /*is_signed=*/false,
disable_per_channel, GetOpQuantSpec);
}
} // namespace
// Creates an instance of the xla_hlo dialect quantization propagation pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePropagateQuantPass() {
return std::make_unique<PropagateQuantPass>();
}
static PassRegistration<PropagateQuantPass> pass(
"xla-hlo-propagate-quant", "Propagate quantization information");
} // namespace xla_hlo
} // namespace mlir

View File

@ -0,0 +1,19 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(licenses = ["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["mlir"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir:tf-opt",
"@llvm-project//llvm:FileCheck",
],
)

View File

@ -0,0 +1,54 @@
// RUN: tf-opt -xla-hlo-materialize-quant %s | FileCheck %s
// CHECK-LABEL: func @quantize_rewrite
func @quantize_rewrite(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK: %[[qcst:.*]] = constant dense<{{\[\[}}21004416], [-1056997248]]> : tensor<2x1xi32>
// CHECK-NEXT: %[[dq:.*]] = "xla_hlo.dequantize"(%[[qcst]]) {is_16bits = false, max_range = 0.996078431 : f32, min_range = -1.00392163 : f32,
// CHECK-SAME: mode = "MIN_COMBINED", transpose_output = false} : (tensor<2x1xi32>) -> tensor<2x4xbf16>
// CHECK-NEXT: %[[cast:.*]] = "xla_hlo.convert"(%[[dq]]) : (tensor<2x4xbf16>) -> tensor<2x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[cast]] : tensor<2x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>
%w = constant dense<[[-1.0, -0.5, 0.0, 0.0], [0.5, 1.0, 0.0, 0.0]]> : tensor<2x4xf32>
%q = "quant.qcast"(%w) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
%mul = xla_hlo.mul %arg0, %dq : tensor<2x4xf32>
return %mul: tensor<2x4xf32>
}
// CHECK-LABEL: func @quantize_small
func @quantize_small(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<1x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[w]] : tensor<1x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<1x4xf32>
%w = constant dense<1.0> : tensor<1x4xf32>
%q = "quant.qcast"(%w) : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<1x4xf32>
%mul = xla_hlo.mul %arg0, %dq : tensor<1x4xf32>
return %mul: tensor<1x4xf32>
}
// CHECK-LABEL: func @quantize_non_cst
func @quantize_non_cst(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %arg0 : tensor<2x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>
%q = "quant.qcast"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
%mul = xla_hlo.mul %arg0, %dq : tensor<2x4xf32>
return %mul: tensor<2x4xf32>
}
// CHECK-LABEL: func @quantize_non_4x
func @quantize_non_4x(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<2x5xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[w]] : tensor<2x5xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x5xf32>
%w = constant dense<1.0> : tensor<2x5xf32>
%q = "quant.qcast"(%w) : (tensor<2x5xf32>) -> tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
%dq = "quant.dcast"(%q) : (tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x5xf32>
%mul = xla_hlo.mul %arg0, %dq : tensor<2x5xf32>
return %mul: tensor<2x5xf32>
}

View File

@ -0,0 +1,25 @@
// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s
// CHECK-LABEL: func @mul
func @mul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[w:.*]] = constant dense<{{\[\[}}-1.000000e+00, -5.000000e-01], [5.000000e-01, 1.000000e+00]]> : tensor<2x2xf32>
// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[w]]) : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x2xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[dq]] : tensor<2x2xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x2xf32>
%w = constant dense<[[-1.0, -0.5], [0.5, 1.0]]> : tensor<2x2xf32>
%mul = xla_hlo.mul %arg0, %w : tensor<2x2xf32>
return %mul: tensor<2x2xf32>
}
// CHECK-LABEL: func @add
func @add(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[b:.*]] = constant dense<1.000000e+00> : tensor<2xf32>
// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[b]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 0.0039215686274509803>>
// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x!quant.uniform<u8:f32, 0.0039215686274509803>>) -> tensor<2xf32>
// CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg0, %[[dq]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %[[add]] : tensor<2x2xf32>
%b = constant dense<1.0> : tensor<2xf32>
%add = "xla_hlo.add"(%arg0, %b) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
return %add: tensor<2x2xf32>
}

View File

@ -0,0 +1,39 @@
package(
default_visibility = [
":friends",
],
licenses = ["notice"], # Apache 2.0
)
package_group(
name = "friends",
includes = ["//third_party/mlir:subpackages"],
packages = [
"//learning/brain/experimental/mlir/...",
"//tensorflow/lite/...",
],
)
cc_library(
name = "sparsify_model",
srcs = [
"sparsify_model.cc",
],
hdrs = [
"sparsify_model.h",
],
deps = [
"//tensorflow/compiler/mlir/lite:common",
"//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/core:protos_all_cc",
"//tensorflow/lite:framework",
"//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
)

View File

@ -0,0 +1,84 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/types.pb.h"
namespace mlir {
namespace lite {
TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
flatbuffers::FlatBufferBuilder* builder,
tflite::ErrorReporter* error_reporter) {
MLIRContext context;
StatusScopedDiagnosticHandler statusHandler(&context,
/*propagate=*/true);
// Import input_model into an MLIR module.
flatbuffers::FlatBufferBuilder input_builder;
flatbuffers::Offset<tflite::Model> input_model_location =
tflite::Model::Pack(input_builder, &input_model);
tflite::FinishModelBuffer(input_builder, input_model_location);
std::string serialized_model(
reinterpret_cast<const char*>(input_builder.GetBufferPointer()),
input_builder.GetSize());
std::vector<std::string> output_arrays_order;
OwningModuleRef module =
tflite::FlatBufferToMlir(serialized_model, &context,
UnknownLoc::get(&context), output_arrays_order);
if (!module) {
error_reporter->Report("Couldn't import flatbuffer to MLIR.");
return kTfLiteError;
}
PassManager pm(module->getContext());
if (failed(pm.run(module.get()))) {
const std::string& err = statusHandler.ConsumeStatus().error_message();
error_reporter->Report("Failed to sparsify: %s", err.c_str());
return kTfLiteError;
}
// Export the results to the builder
std::string result;
if (tflite::MlirToFlatBufferTranslateFunction(
module.get(), &result, /*emit_builtin_tflite_ops=*/true,
/*emit_select_tf_ops=*/true, /*emit_custom_ops=*/true)) {
error_reporter->Report("Failed to export MLIR to flatbuffer.");
return kTfLiteError;
}
builder->PushFlatBuffer(reinterpret_cast<const uint8_t*>(result.data()),
result.size());
return kTfLiteOk;
}
} // namespace lite
} // namespace mlir

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_SPARSITY_SPARSIFY_MODEL_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_SPARSITY_SPARSIFY_MODEL_H_
#include <memory>
#include <unordered_set>
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mlir {
namespace lite {
// Sparsifies the `input_model` and writes the result to the flatbuffer `builder`.
TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
flatbuffers::FlatBufferBuilder* builder,
tflite::ErrorReporter* error_reporter);
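// Example usage (a sketch; assumes a populated `tflite::ModelT` named `model`
// and an error reporter such as tflite::StderrReporter):
//   flatbuffers::FlatBufferBuilder builder;
//   tflite::StderrReporter error_reporter;
//   if (mlir::lite::SparsifyModel(model, &builder, &error_reporter) == kTfLiteOk) {
//     // builder.GetBufferPointer() / builder.GetSize() hold the new flatbuffer.
//   }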
} // namespace lite
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_SPARSITY_SPARSIFY_MODEL_H_

View File

@ -76,42 +76,6 @@ func @reshape_not_removeIdentity(%arg0: tensor<?xf32>, %arg1: tensor<3xi32>) ->
// CHECK-NEXT: "tfl.reshape"
}
// Checks that tfl.fake_quant should be removed if all its users have valid
// "minmax" attributes.
func @fakequant_dropfakequant(tensor<i32>, f32, f32) -> tensor<i32> {
^bb0(%arg0: tensor<i32>, %arg1: f32, %arg2: f32):
%0 = "tfl.fake_quant"(%arg0) {name = 0, minmax = [0.1, 0.2], num_bits = 4 : i32, narrow_range = false} : (tensor<i32>) -> tensor<i32>
%1 = tfl.pow %arg0, %0 {minmax = [0.4, 0.6]} : tensor<i32>
%2 = tfl.pow %1, %0 {minmax = [0.5, 0.7]} : tensor<i32>
return %2 : tensor<i32>
// CHECK-LABEL: fakequant_dropfakequant
// CHECK-NEXT: %0 = tfl.pow %arg0, %arg0 {minmax = [4.000000e-01, 6.000000e-01]} : tensor<i32>
// CHECK-NEXT: %1 = tfl.pow %0, %arg0 {minmax = [5.000000e-01, 0.69999999999999996]} : tensor<i32>
// CHECK-NEXT: return %1 : tensor<i32>
}
// Checks that tfl.fake_quant should not be removed if some of its users or
// itself don't have valid "minmax" attributes.
func @fakequant_notdropfakequant(tensor<i32>, f32, f32) -> tensor<i32> {
^bb0(%arg0: tensor<i32>, %arg1: f32, %arg2: f32):
%0 = "tfl.fake_quant"(%arg0) {name = 0, minmax = [], num_bits = 4 : i32, narrow_range = false} : (tensor<i32>) -> tensor<i32>
%1 = tfl.pow %arg0, %0 : tensor<i32>
%2 = tfl.pow %1, %0 : tensor<i32>
%5 = "tfl.fake_quant"(%arg0) {name = 1, minmax = [0.1, 0.2], num_bits = 4 : i32, narrow_range = false} : (tensor<i32>) -> tensor<i32>
%6 = tfl.pow %arg0, %5 : tensor<i32>
%7 = tfl.pow %6, %5 : tensor<i32>
%11 = addi %2, %7 : tensor<i32>
return %11 : tensor<i32>
// CHECK-LABEL: fakequant_notdropfakequant
// CHECK: %0 = "tfl.fake_quant"(%arg0) {minmax = [], name = 0 : i64, narrow_range = false, num_bits = 4 : i32} : (tensor<i32>) -> tensor<i32>
// CHECK: %3 = "tfl.fake_quant"(%arg0) {minmax = [1.000000e-01, 2.000000e-01], name = 1 : i64, narrow_range = false, num_bits = 4 : i32} : (tensor<i32>) -> tensor<i32>
}
// -----
// CHECK-LABEL: @RemoveRedundantUnpackPack

View File

@ -0,0 +1,231 @@
// RUN: tf-opt %s -tfl-identify-dilated-conv | FileCheck %s --dump-input-on-failure
func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
return %2 : tensor<1x128x128x8xf32>
// CHECK-LABEL: testDilatedConv
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}
func @testDilatedConvWithNonZeroSTBPadding(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<2> : tensor<2x2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
return %2 : tensor<1x128x128x8xf32>
// CHECK-LABEL: testDilatedConvWithNonZeroSTBPadding
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}
func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
%1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
return %2 : tensor<1x128x128x8xf32>
// CHECK-LABEL: testDilatedDepthWiseConv
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
// CHECK-NEXT: [[RESULT:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}
func @testDilatedConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
%2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32>
%3 = "tf.BatchToSpaceND"(%2, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
%4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
return %4 : tensor<1x128x128x8xf32>
// CHECK-LABEL: testDilatedConvWithPad
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}
func @testDilatedDepthWiseConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
%1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
%2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32>
%3 = "tf.BatchToSpaceND"(%2, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
%4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
return %4 : tensor<1x128x128x8xf32>
// CHECK-LABEL: testDilatedDepthWiseConvWithPad
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}
func @testDilatedConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
%3 = "tf.BiasAdd"(%2, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
return %3 : tensor<1x128x128x8xf32>
// CHECK-LABEL: testDilatedConvWithBiasAdd
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}
func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
%1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
%3 = "tf.BiasAdd"(%2, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
return %3 : tensor<1x128x128x8xf32>
// CHECK-LABEL: testDilatedDepthWiseConvWithBiasAdd
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}
func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
%4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
%5 = "tf.BiasAdd"(%4, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
return %5 : tensor<1x128x128xf32>
// CHECK-LABEL: testDilatedConvWithExpandSqueeze1
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}
func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
%4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
%5 = "tf.BiasAdd"(%4, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
return %5 : tensor<1x128x128xf32>
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze1
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}
func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x?x?x1xf32>) -> tensor<4x?x?xf32>
%4 = "tf.BiasAdd"(%3, %arg3) : (tensor<4x?x?xf32>, tensor<?xf32>) -> tensor<4x?x?xf32>
%5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
return %5 : tensor<1x128x128xf32>
// CHECK-LABEL: testDilatedConvWithExpandSqueeze2
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<?xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}
func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x?x?x1xf32>) -> tensor<4x?x?xf32>
%4 = "tf.BiasAdd"(%3, %arg3) : (tensor<4x?x?xf32>, tensor<?xf32>) -> tensor<4x?x?xf32>
%5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
return %5 : tensor<1x128x128xf32>
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<?xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}
func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
%4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32>
%5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
%6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
return %6 : tensor<1x128x128xf32>
// CHECK-LABEL: testDilatedConvWithExpandSqueeze3
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}
func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>
%cst_0 = constant dense<3> : tensor<i32>
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
%3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
%4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32>
%5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
%6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
return %6 : tensor<1x128x128xf32>
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze3
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}

View File

@ -420,6 +420,15 @@ func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>)
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32>
}
func @gatherV2VectorIndices_I64Axis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> {
%0 = "tf.Const"() { value = dense<[1]> : tensor<1xi64> } : () -> tensor<1xi64>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi64>) -> tensor<1x3x5x20xf32>
return %1 : tensor<1x3x5x20xf32>
// CHECK-LABEL:gatherV2VectorIndices_I64Axis
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32>
}
func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> {
%0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
%1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x2x3x5xf32>
@ -1074,6 +1083,14 @@ func @cast(%arg0: tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> {
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32>
}
func @castComplex(%arg0: tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>> {
%0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>>
return %0 : tensor<1x2x2x5xcomplex<f32>>
// CHECK-LABEL: castComplex
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>>
}
func @unique(%arg0: tensor<5xf32>) -> (tensor<?xf32>, tensor<?xi32>) {
%0, %1 = "tf.Unique"(%arg0) : (tensor<5xf32>) -> (tensor<?xf32>, tensor<?xi32>)
return %0, %1 : tensor<?xf32> , tensor<?xi32>

View File

@ -1,5 +1,26 @@
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: tensorlistConst
func @tensorlistConst(%arg0 : tensor<1xi32>) -> tensor<2x3xi32> {
// CHECK: %[[ELEMENT0:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK: %[[ELEMENT1:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK: %[[LIST:.*]] = "tf.Pack"(%[[ELEMENT0]], %[[ELEMENT1]]) {axis = 0 : i64} : (tensor<3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%0 = "tf.Const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A2022485C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C3030335C3033325C725C3031305C3030335C3032325C3030345C3032325C3030325C3031305C3030333A5C3030335C3030305C3030315C3030325C3033325C725C3031305C3030335C3032325C3030345C3032325C3030325C3031305C3030333A5C3030335C3030335C3030345C30303522"> : tensor<!tf.variant>} : () -> tensor<!tf.variant<tensor<3xi32>>>
// CHECK: return %[[LIST]]
%1 = "tf.TensorListStack"(%0, %arg0) : (tensor<!tf.variant<tensor<3xi32>>>, tensor<1xi32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
}
func @emptyTensorlistConst(%arg0 : tensor<1xi32>) -> tensor<0x3xi32> {
// CHECK: %[[LIST:.*]] = "tf.Const"() {value = dense<{{\[\[}}]]> : tensor<0x3xi32>} : () -> tensor<0x3xi32>
%0 = "tf.Const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A20222A5C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C30303322"> : tensor<!tf.variant>} : () -> tensor<!tf.variant<tensor<3xi32>>>
// CHECK: return %[[LIST]]
%1 = "tf.TensorListStack"(%0, %arg0) : (tensor<!tf.variant<tensor<3xi32>>>, tensor<1xi32>) -> tensor<0x3xi32>
return %1 : tensor<0x3xi32>
}
func @tensorlistGetItem(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>, tensor<1xi32>) -> tensor<10xf32>

View File

@ -0,0 +1,29 @@
# Description:
# Integration tests of converter & interpreter.
#
# There should be few tests here, and it should be rare for an execution test
# not to be covered by unit tests already. This is useful for verifying some
# runtime behavior, but the majority of runtime tests should live on the TFLite
# side, with the converter/compiler verifying invariants only.
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
licenses(["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["mlir"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir:tf-opt",
"//tensorflow/compiler/mlir/lite:mlir-tflite-runner",
"@llvm-project//llvm:FileCheck",
"@llvm-project//llvm:not",
],
)

View File

@ -0,0 +1,46 @@
// Test to verify translation & export work as intended with runtime.
// RUN: not mlir-tflite-runner --dump-interpreter-state %s 2>&1 | FileCheck %s --check-prefix ERROR --dump-input-on-failure
// RUN: tf-opt --mlir-print-debuginfo --canonicalize --tfl-while-loop-outline %s | mlir-tflite-runner --dump-interpreter-state 2>&1 | FileCheck %s --dump-input-on-failure
// ERROR: number of operands and results don't match
// Verify value computed:
// ----------------------
// CHECK: result: Tensor<type: FLOAT32, shape: 1, values: 96>
// Verify tensors in interpreter state:
// ------------------------------------
// CHECK: Tensor 0 dec kTfLiteInt32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 1 N kTfLiteInt32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 2 val kTfLiteFloat32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 3 std.constant kTfLiteInt32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 4 tfl.while kTfLiteInt32 kTfLiteArenaRw 4 bytes
// CHECK-NEXT: Tensor 5 result kTfLiteFloat32 kTfLiteArenaRw 4 bytes
// CHECK-NEXT: Tensor 6 tfl.while:2 kTfLiteInt32 kTfLiteArenaRw 4 bytes
// CHECK-NEXT: Tensor 7 tfl.while:3 kTfLiteInt32 kTfLiteArenaRw 4 bytes
// Verify while was not folded away:
// ------------------------------------
// CHECK: Operator Builtin Code {{[0-9]*}} WHILE
func @main() -> tensor<1xf32>
attributes {tf.entry_function = {outputs = "result"}} {
%cst = constant dense<1> : tensor<i32> loc("dec")
%arg0 = constant dense<5> : tensor<i32> loc("N")
%arg1 = constant dense<3.0> : tensor<1xf32> loc("val")
%0:2 = "tfl.while"(%arg0, %arg1, %cst) ( {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>, %arg4: tensor<i32>):
%cst_0 = constant dense<0> : tensor<i32>
%1 = "tfl.greater"(%arg2, %cst_0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
"tfl.yield"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>, %arg4: tensor<i32>):
%1 = "tfl.sub"(%arg2, %arg4) {fused_activation_function = "NONE"} :
(tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%2 = tfl.add %arg3, %arg3 {fused_activation_function = "NONE"} : tensor<*xf32>
"tfl.yield"(%1, %2, %arg4) : (tensor<*xi32>, tensor<*xf32>, tensor<i32>) -> ()
}) : (tensor<i32>, tensor<1xf32>, tensor<i32>) -> (tensor<i32>, tensor<1xf32>)
return %0#1 : tensor<1xf32>
}

View File

@ -1,139 +1,56 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string -
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -tflite-flatbuffer-to-mlir - -o - | FileCheck --check-prefix=IMPORT %s
// TODO(b/141520199): Currently fake quant is not being written to flatbuffer
// since it is legalized to quantize and dequantize. Update this test and add
// fake_quant_v2.mlir when the op is being written to flatbuffer.
func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: SQUARED_DIFFERENCE,
// CHECK-NEXT: version: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: MUL,
// CHECK-NEXT: version: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: DIV,
// CHECK-NEXT: version: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: EXP,
// CHECK-NEXT: version: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: NEG,
// CHECK-NEXT: version: 1
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "Const",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "squared_difference",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "mul",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "div",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 6,
// CHECK-NEXT: name: "exp",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 7,
// CHECK-NEXT: name: "neg",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 6 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 0, 2 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: builtin_options_type: MulOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 2,
// CHECK-NEXT: inputs: [ 3, 2 ],
// CHECK-NEXT: outputs: [ 4 ],
// CHECK-NEXT: builtin_options_type: DivOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 3,
// CHECK-NEXT: inputs: [ 4 ],
// CHECK-NEXT: outputs: [ 5 ],
// CHECK-NEXT: builtin_options_type: ExpOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 4,
// CHECK-NEXT: inputs: [ 5 ],
// CHECK-NEXT: outputs: [ 6 ],
// CHECK-NEXT: builtin_options_type: NegOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ]
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT: }
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: FAKE_QUANT,
// CHECK-NEXT: version: 1
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "tfl.fake_quant",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1 ],
// CHECK-NEXT: builtin_options_type: FakeQuantOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-NEXT: min: 0.3,
// CHECK-NEXT: max: 1.4,
// CHECK-NEXT: num_bits: 6
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT: }
%0 = "tfl.fake_quant"(%arg0) {num_bits = 6 : i32, narrow_range = false, minmax = [0.3, 1.4]} : (tensor<4 x f32>) -> tensor<4 x f32>
// IMPORT: "tfl.fake_quant"(%arg0) {max = 1.400000e+00 : f32, min = 3.000000e-01 : f32, narrow_range = false, num_bits = 6 : i32}
%0 = "tfl.fake_quant"(%arg0) {num_bits = 6 : i32, narrow_range = false, min = 0.3:f32, max = 1.4:f32} : (tensor<4 x f32>) -> tensor<4 x f32>
return %0 : tensor<4xf32>
}

View File

@ -0,0 +1,219 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s --dump-input-on-failure
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: WHILE,
// CHECK-NEXT: version: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: GREATER,
// CHECK-NEXT: version: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: SUB,
// CHECK-NEXT: version: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: version: 1
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "WhileOp",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "WhileOp1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2, 3 ],
// CHECK-NEXT: builtin_options_type: WhileOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-NEXT: cond_subgraph_index: 1,
// CHECK-NEXT: body_subgraph_index: 2
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: }, {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 6,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 7,
// CHECK-NEXT: name: "Const",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: BOOL,
// CHECK-NEXT: buffer: 8,
// CHECK-NEXT: name: "tfl.greater",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 0, 2 ],
// CHECK-NEXT: outputs: [ 3 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "WhileOp_cond"
// CHECK-NEXT: }, {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 9,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 10,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 11,
// CHECK-NEXT: name: "Const1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 12,
// CHECK-NEXT: name: "tfl.sub",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 13,
// CHECK-NEXT: name: "tfl.add",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 3, 4 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: opcode_index: 2,
// CHECK-NEXT: inputs: [ 0, 2 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: builtin_options_type: SubOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 3,
// CHECK-NEXT: inputs: [ 1, 1 ],
// CHECK-NEXT: outputs: [ 4 ],
// CHECK-NEXT: builtin_options_type: AddOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "WhileOp_body"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 1, 0, 0, 0 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT: }
func @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
%cst = constant dense<0> : tensor<i32> loc("Const")
%0 = "tfl.greater"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @WhileOp_body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) {
%cst = constant dense<1> : tensor<i32> loc("Const1")
%0 = "tfl.sub"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%1 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
return %0, %1 : tensor<*xi32>, tensor<*xf32>
}
func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
%0:2 = "tfl.while"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>): // no predecessors
%1 = call @WhileOp_cond(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> tensor<i1>
"tfl.yield"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>): // no predecessors
%1:2 = call @WhileOp_body(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>)
"tfl.yield"(%1#0, %1#1) : (tensor<*xi32>, tensor<*xf32>) -> ()
}) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>) loc("WhileOp")
return %0#1 : tensor<1xf32>
}

View File

@ -355,10 +355,8 @@ func @testConv2DNoBias(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf3
// CHECK-LABEL: testFakeQuant
func @testFakeQuant(tensor<? x f32>, f32, f32) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>, %arg1: f32, %arg2: f32):
// CHECK: %0 = "tfl.fake_quant"(%arg0) {minmax = [], narrow_range = true, num_bits = 2 : i32} : (tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.fake_quant"(%arg0) {minmax = [], num_bits = 2 : i32, narrow_range = true} : (tensor<? x f32>) -> tensor<? x f32>
// CHECK: %1 = "tfl.fake_quant"(%0) {minmax = [3.000000e-01, 1.400000e+00], narrow_range = false, num_bits = 6 : i32} : (tensor<?xf32>) -> tensor<?xf32>
%1 = "tfl.fake_quant"(%0) {num_bits = 6 : i32, narrow_range = false, minmax = [0.3, 1.4]} : (tensor<? x f32>) -> tensor<? x f32>
// CHECK: "tfl.fake_quant"(%arg0) {max = 1.400000e+00 : f32, min = 3.000000e-01 : f32, narrow_range = false, num_bits = 6 : i32} : (tensor<?xf32>) -> tensor<?xf32>
%1 = "tfl.fake_quant"(%arg0) {num_bits = 6 : i32, narrow_range = false, min = 0.3:f32, max = 1.4:f32} : (tensor<? x f32>) -> tensor<? x f32>
return %1 : tensor<? x f32>
}

View File

@ -1,5 +1,5 @@
// Run optimize pass only and check the results.
// RUN: tf-opt %s -tfl-optimize | FileCheck %s
// RUN: tf-opt %s -tfl-optimize | FileCheck %s --dump-input-on-failure
// Run optimize pass and then canonicalize pass, and make sure some folding is applied.
// RUN: tf-opt %s -tfl-optimize -canonicalize | FileCheck --check-prefix=FOLD %s
@ -294,8 +294,68 @@ func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x
// CHECK: return %1
}
// CHECK-LABEL: @FuseFullyConnectedAddUnit
func @FuseFullyConnectedAddUnit(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
// CHECK-LABEL: @FuseFullyConnectedAddWithNoBias
func @FuseFullyConnectedAddWithNoBias(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant unit
%cst2 = constant dense<2.0> : tensor<40xf32>
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40xf32>) -> tensor<40x40xf32>
return %1 : tensor<40x40xf32>
// CHECK: %cst = constant dense<2.000000e+00> : tensor<40xf32>
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %cst)
// CHECK: return %[[fc]]
}
// CHECK-LABEL: @FuseFullyConnectedAddWithExistingBias
func @FuseFullyConnectedAddWithExistingBias(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant dense<3.0> : tensor<40xf32>
%cst2 = constant dense<2.0> : tensor<40xf32>
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40xf32>) -> (tensor<40x40xf32>)
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40xf32>) -> tensor<40x40xf32>
return %1 : tensor<40x40xf32>
// CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40xf32>
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
// CHECK: return %[[fc]]
}
// CHECK-LABEL: @FuseFullyConnectedAddWithNoBiasAndScalarRhs
func @FuseFullyConnectedAddWithNoBiasAndScalarRhs(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant unit
%cst2 = constant dense<2.0> : tensor<f32>
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<f32>) -> tensor<40x40xf32>
return %1 : tensor<40x40xf32>
// CHECK: %[[cst:.*]] = constant dense<2.000000e+00> : tensor<40xf32>
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
// CHECK: return %[[fc]]
}
// CHECK-LABEL: @FuseFullyConnectedAddWithScalarRhs
func @FuseFullyConnectedAddWithScalarRhs(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant dense<3.0> : tensor<40xf32>
%cst2 = constant dense<2.0> : tensor<f32>
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40xf32>) -> (tensor<40x40xf32>)
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<f32>) -> tensor<40x40xf32>
return %1 : tensor<40x40xf32>
// CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40xf32>
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
// CHECK: return %[[fc]]
}
// CHECK-LABEL: @FuseFullyConnectedAddWithUnfusableRhs
func @FuseFullyConnectedAddWithUnfusableRhs(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant unit
%cst2 = constant dense<2.0> : tensor<40x40xf32>
@ -304,24 +364,11 @@ func @FuseFullyConnectedAddUnit(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf3
return %1 : tensor<40x40xf32>
// CHECK: %cst = constant dense<2.000000e+00> : tensor<40x40xf32>
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %cst)
// CHECK: return %[[fc]]
}
// CHECK-LABEL: @FuseFullyConnectedAddConst
func @FuseFullyConnectedAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant dense<3.0> : tensor<40x40xf32>
%cst2 = constant dense<2.0> : tensor<40x40xf32>
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>)
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
return %1 : tensor<40x40xf32>
// CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
// CHECK: return %[[fc]]
// CHECK: %[[unit:.*]] = constant unit
// CHECK: %[[filter:.*]] = constant dense<2.000000e+00> : tensor<40x40xf32>
// CHECK: %[[fc_result:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[unit]])
// CHECK: %[[add_result:.*]] = tfl.add %[[fc_result]], %[[filter]]
// CHECK: return %[[add_result]]
}
// CHECK-LABEL: @FuseFullyConnectedReshapeAddConst
@ -690,6 +737,54 @@ func @fuse_relu_to_add(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor
// CHECK: return %[[RES]]
}
// CHECK-LABEL: leaky_relu_fusion
func @leaky_relu_fusion(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%alpha = constant dense<0.2> : tensor<f32>
%0 = "tfl.mul"(%arg0, %alpha) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
%1 = "tfl.maximum"(%0, %arg0) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
// CHECK: %[[RESULT:[0-9].*]] = "tfl.leaky_relu"
}
// CHECK-LABEL: leaky_relu_not_fused
// Should not fuse to LeakyRelu, since alpha > 1.
func @leaky_relu_not_fused(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%alpha = constant dense<1.2> : tensor<f32>
%0 = "tfl.mul"(%arg0, %alpha) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
%1 = "tfl.maximum"(%0, %arg0) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
return %1 : tensor<2x3xf32>
// CHECK: %[[RESULT:[0-9].*]] = "tfl.maximum"
}
// CHECK-LABEL: prelu_fusion
func @prelu_fusion(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%alpha = constant dense<-0.2> : tensor<3xf32>
%0 = "tfl.relu"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%1 = "tfl.neg"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%2 = "tfl.relu"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%3 = "tfl.mul"(%alpha, %2) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
%4 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
return %4 : tensor<2x3xf32>
// CHECK: %[[RESULT:[0-9].*]] = "tfl.prelu"
}
// CHECK-LABEL: prelu_not_fused
// The rank of alpha should be one less than that of the input for PReLU, which is not the case here.
func @prelu_not_fused(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%alpha = constant dense<-0.2> : tensor<f32>
%0 = "tfl.relu"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%1 = "tfl.neg"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%2 = "tfl.relu"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%3 = "tfl.mul"(%alpha, %2) {fused_activation_function = "NONE"} : (tensor<f32>, tensor<2x3xf32>) -> tensor<2x3xf32>
%4 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
return %4 : tensor<2x3xf32>
// CHECK: %[[RESULT:[0-9].*]] = "tfl.relu"
}
// CHECK-LABEL: NotfuseAddIntoConv2d_MultipleUsers
func @NotfuseAddIntoConv2d_MultipleUsers(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
%cst = constant dense<1.5> : tensor<16xf32>

View File

@ -242,6 +242,22 @@ func @QuantizePad(tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, tensor<3x2xi32>) ->
// CHECK: return %3 : tensor<?xf32>
}
// CHECK-LABEL: QuantizePad2
// Only the second tfl.pad has sufficient quantization information.
func @QuantizePad2(tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, tensor<2x1x3xf32>, tensor<3x2xi32>) -> (tensor<?xf32>, tensor<?xf32>) {
^bb0(%arg0: tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<2x1x3xf32>, %arg2: tensor<3x2xi32>):
%0 = "tfl.dequantize"(%arg0) : (tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x1x3xf32>
%1 = "tfl.pad"(%arg1, %arg2) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
%2 = "tfl.pad"(%0, %arg2) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
return %1, %2 : tensor<?xf32>, tensor<?xf32>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%arg0)
// CHECK: %[[pad1:.*]] = "tfl.pad"(%arg1, %arg2)
// CHECK: %[[pad2:.*]] = "tfl.pad"(%[[dq]], %arg2)
// CHECK: %[[q2:.*]] = "tfl.quantize"(%[[pad2]])
// CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
}
// CHECK-LABEL: QuantizeReshape2D
func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x36x16xf32> {
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
@ -418,16 +434,15 @@ func @QuantizeConcatResToAllRequantizeArg(tensor<1x2x!quant.uniform<u8:f32, 2.0:
// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
}
// CHECK-LABEL: RequantizeAlreadyQuantizedModel
func @RequantizeAlreadyQuantizedModel(%arg0: tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>> {
// CHECK-LABEL: NotRequantizeAlreadyQuantizedModel
func @NotRequantizeAlreadyQuantizedModel(%arg0: tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>> {
%9 = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.0>>
%10 = "tfl.concatenation"(%arg0, %9) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>>
return %10 : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>>
// CHECK: %0 = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>
// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<1x73x73x96x!quant.uniform<u8:f32, 1.000000e+00>>} : (tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 1.000000e+00>>
// CHECK: %2 = "tfl.concatenation"(%arg0, %1) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.000000e+00>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 1.000000e+00>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
// CHECK: return %2 : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
// CHECK: %[[max:.*]] = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>
// CHECK: %[[cat:.*]] = "tfl.concatenation"(%arg0, %[[max]]) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.000000e+00>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
// CHECK: return %[[cat]] : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
}
// CHECK-LABEL: QuantizeChain

View File

@ -0,0 +1,26 @@
// RUN: tf-opt -loop-invariant-code-motion %s -o - | FileCheck %s --dump-input-on-failure
// CHECK: while_1([[ARG0:%[^ :]*]]: tensor<i32>, [[ARG1:%[^ :]*]]: tensor<1xf32>)
func @while_1(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// CHECK: [[CST:%[^ ]*]] = constant dense<1> : tensor<i32>
// CHECK: "tfl.while"([[ARG0]], [[ARG1]])
// CHECK: (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>)
%0:2 = "tfl.while"(%arg0, %arg1) (
// cond
{
^bb0(%condArg0: tensor<*xi32>, %condArg1: tensor<*xf32>):
%0 = "std.constant" () {value = dense<0> : tensor<i32>} : () -> tensor<i32> loc("Const")
%1 = "tfl.greater"(%condArg0, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
"tfl.yield"(%1) : (tensor<i1>) -> ()
},
// body
{
^bb0(%bodyArg0: tensor<*xi32>, %bodyArg1: tensor<*xf32>):
%0 = "std.constant" () {value = dense<1> : tensor<i32>} : () -> tensor<i32> loc("Const")
%1 = "tfl.sub"(%bodyArg0, %0) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%2 = tfl.add %bodyArg1, %bodyArg1 {fused_activation_function = "NONE"} : tensor<*xf32>
"tfl.yield"(%1, %2) : (tensor<*xi32>, tensor<*xf32>) -> ()
}
) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>) loc("WhileOp")
return %0#1 : tensor<1xf32>
}

View File

@ -0,0 +1,31 @@
// Test to verify loop outlining.
// RUN: tf-opt --tfl-while-loop-outline %s | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: func @while
func @while() -> tensor<1xf32>
attributes {tf.entry_function = {outputs = "result"}} {
%cst = constant dense<1> : tensor<i32> loc("dec")
%arg0 = constant dense<5> : tensor<i32> loc("N")
%arg1 = constant dense<3.0> : tensor<1xf32> loc("val")
%0:2 = "tfl.while"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
// CHECK: call @WhileOp_cond
%cst_0 = constant dense<0> : tensor<i32>
%1 = "tfl.greater"(%arg2, %cst_0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
"tfl.yield"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
// CHECK: call @WhileOp_body
%1 = "tfl.sub"(%arg2, %cst) {fused_activation_function = "NONE"} :
(tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%2 = tfl.add %arg3, %arg3 {fused_activation_function = "NONE"} : tensor<*xf32>
"tfl.yield"(%1, %2) : (tensor<*xi32>, tensor<*xf32>) -> ()
}) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>) loc("WhileOp")
return %0#1 : tensor<1xf32>
}
// CHECK-LABEL: func @WhileOp_cond(
// CHECK: tfl.greater
// CHECK-LABEL: func @WhileOp_body(
// CHECK: tfl.sub
// CHECK: tfl.add

View File

@ -80,10 +80,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
}
if (pass_config.lower_tensor_list_ops) {
// Execute this pass before `CanonicalizerPass` in case some TensorList
// ops are constant folded into variant types.
// TODO(b/137125056): Move this pass after `CanonicalizerPass` after we
// handle constant ops that produce `TensorList`.
// TODO(haoliang): Add this pass by default.
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
}

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
@ -65,23 +66,23 @@ class DefaultQuantParamsPass : public FunctionPass<DefaultQuantParamsPass> {
// Uses `quant_params` to quantize `value` and inserts a pair of
// tfl.quantize and tfl.dequantize ops for this `value`.
void QuantizeValue(OpBuilder builder, Value value,
TFL::QuantParams quant_params);
quant::QuantParams quant_params);
// If the value hasn't been quantized, the function adds it to `values`.
void AddToWorkListIfUnquantized(Value value, std::vector<Value> *values);
// Converts the default min/max to the default quantization parameters.
TFL::QuantParams GetDefaultQuantParams(Builder builder);
quant::QuantParams GetDefaultQuantParams(Builder builder);
// Gets the quantization parameters for the bias of an operation by using the
// quantization parameters from the non-bias operands.
TFL::QuantParams GetQuantParamsForBias(Operation *op, int bias,
const std::vector<int> &non_biases,
TFL::AccumulatorScaleFunc func);
quant::QuantParams GetQuantParamsForBias(Operation *op, int bias,
const std::vector<int> &non_biases,
quant::AccumulatorScaleFunc func);
double default_min_;
double default_max_;
TFL::QuantParams default_quant_params_;
quant::QuantParams default_quant_params_;
};
} // namespace
@ -104,7 +105,9 @@ void DefaultQuantParamsPass::runOnFunction() {
func.walk([&](Operation *op) {
if (op->isKnownTerminator() ||
op->hasTrait<OpTrait::quant::NoQuantizableResult>())
op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
llvm::isa<quant::QuantizeCastOp>(op) ||
llvm::isa<quant::DequantizeCastOp>(op))
return;
for (auto res : op->getResults()) {
@ -117,7 +120,7 @@ void DefaultQuantParamsPass::runOnFunction() {
});
// Apply the default quantization parameters for these activation values.
TFL::QuantParams default_params = GetDefaultQuantParams(builder);
quant::QuantParams default_params = GetDefaultQuantParams(builder);
for (Value value : activation_values) {
QuantizeValue(builder, value, default_params);
}
@ -128,7 +131,7 @@ void DefaultQuantParamsPass::runOnFunction() {
Operation *op = *bias.user_begin();
auto spec = TFL::GetOpQuantSpec(op);
for (auto &it : spec->biases_params) {
TFL::QuantParams bias_params = GetQuantParamsForBias(
quant::QuantParams bias_params = GetQuantParamsForBias(
op, it.first, it.second.first, it.second.second);
if (!bias_params) continue;
QuantizeValue(builder, bias, bias_params);
@ -157,7 +160,7 @@ void DefaultQuantParamsPass::AddToWorkListIfUnquantized(
}
void DefaultQuantParamsPass::QuantizeValue(OpBuilder builder, Value value,
TFL::QuantParams quant_params) {
quant::QuantParams quant_params) {
Type expressed_type = value.getType();
Type new_type = quant_params.castFromExpressedType(expressed_type);
// This value isn't an expressed type (float), skip.
@ -182,9 +185,9 @@ void DefaultQuantParamsPass::QuantizeValue(OpBuilder builder, Value value,
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
}
TFL::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias(
quant::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias(
Operation *op, int bias, const std::vector<int> &non_biases,
TFL::AccumulatorScaleFunc func) {
quant::AccumulatorScaleFunc func) {
std::vector<quant::QuantizedType> non_bias_types;
non_bias_types.reserve(non_biases.size());
for (int non_bias : non_biases) {
@ -205,7 +208,7 @@ TFL::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias(
return func(non_bias_types);
}
TFL::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
Builder builder) {
if (!default_quant_params_) {
default_quant_params_ = quant::fakeQuantAttrsToType(

View File

@ -0,0 +1,74 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass converts dense tensors to sparse format.
#include "absl/memory/memory.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
//===----------------------------------------------------------------------===//
// The DenseToSparse Pass.
//
namespace mlir {
namespace TFL {
namespace {
struct DenseToSparse : public FunctionPass<DenseToSparse> {
void runOnFunction() override;
};
void DenseToSparse::runOnFunction() {
FuncOp func = getFunction();
OpBuilder builder(func);
func.walk([&](SparseOpInterface sparse_op) {
const auto& sparse_operands = sparse_op.GetSparseOperands();
for (const int operand : sparse_operands) {
auto* op = sparse_op.getOperation();
const auto& value = op->getOperand(operand);
builder.setInsertionPoint(op);
if (auto* inst = value.getDefiningOp()) {
// Replace defining op with SparseConst or SparseQConst.
// TODO(yunluli): Implement.
}
// TODO(yunluli): Implement.
bool needs_densify = false;
if (needs_densify) {
auto densify = builder.create<DensifyOp>(op->getLoc(), value);
value.replaceAllUsesWith(densify);
densify.setOperand(value);
}
}
});
}
} // namespace
// Creates an instance of the TensorFlow Lite dialect DenseToSparse pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateDenseToSparsePass() {
return absl::make_unique<DenseToSparse>();
}
static PassRegistration<DenseToSparse> pass(
"tfl-dense-to-sparse", "Convert dense tensor to sparse format.");
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,41 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
namespace mlir {
namespace TFL {
namespace {
struct IdentifyDilatedConvPass : public FunctionPass<IdentifyDilatedConvPass> {
void runOnFunction() override;
};
void IdentifyDilatedConvPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>,
ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(
&getContext());
applyPatternsGreedily(func, patterns);
}
} // namespace
static PassRegistration<IdentifyDilatedConvPass> pass(
"tfl-identify-dilated-conv",
"Identify and replace patterns for dilated convolution.");
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,234 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass identifies patterns for dilated convolution and replaces them
// with a real convolution op.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_
#include <cstdint>
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TFL {
// A dilated convolution can be emulated with a regular convolution by chaining
// SpaceToBatch and BatchToSpace ops before and after it:
//
// SpaceToBatchND -> Conv2D -> BatchToSpaceND
//
// This method was common before Conv2D fully supported dilated convolution in
// TensorFlow. This transformation detects this "emulation", and replaces it
// with a true dilated convolution, eliminating the SpaceToBatch and
// BatchToSpace ops.
//
// Detecting this alone would be relatively easy. However, in practice some
// extra ops are used, so we detect the following patterns:
//
//
// SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND -> BiasAdd
//
// SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> Pad -> BatchToSpaceND ->
// BiasAdd
//
// SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BiasAdd -> BatchToSpaceND
//
// SpaceToBatchND -> Conv2D -> Pad -> BatchToSpaceND -> BiasAdd
//
// SpaceToBatchND -> Conv2D -> BatchToSpaceND -> BiasAdd
//
//
// The Expand/Squeeze combination is used to adapt a 3D array (such as in
// WaveNet) to the 4D arrays that Conv2D requires. Padding and BiasAdd are
// thrown in just for the extra headache. Padding adapts non-conforming input
// sizes, and can be discarded. The bias is necessary, so is kept.
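//
// As a rough, hypothetical sketch (op shapes and attribute values are
// illustrative only, not taken from a real graph), the emulated form
//
//   %s = "tf.SpaceToBatchND"(%input, %block_shape, %paddings)
//   %c = "tf.Conv2D"(%s, %filter) {dilations = [1, 1, 1, 1], padding = "VALID"}
//   %b = "tf.BatchToSpaceND"(%c, %block_shape, %crops)
//   %r = "tf.BiasAdd"(%b, %bias)
//
// is rewritten, for block_shape = [2, 2], into
//
//   %c = "tf.Conv2D"(%input, %filter) {dilations = [1, 2, 2, 1], ...}
//   %r = "tf.BiasAdd"(%c, %bias)
//
// with the padding flipped from 'VALID' to 'SAME' only when the SpaceToBatchND
// carried non-zero paddings (see matchAndRewrite below).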
template <typename Conv2dOpTy>
class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
private:
using OpRewritePattern<Conv2dOpTy>::OpRewritePattern;
// Extract the dilation factor from `block_shape` and pack it in an ArrayAttr.
llvm::Optional<ArrayAttr> ExtractDilationsAttrFromBlockShape(
Value stb_block_shape, Value bts_block_shape,
PatternRewriter& rewriter) const;
public:
PatternMatchResult matchAndRewrite(Conv2dOpTy op,
PatternRewriter& rewriter) const override;
};
template <typename Conv2dOpTy>
PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
Conv2dOpTy op, PatternRewriter& rewriter) const {
// Check if the ConvOp is preceded by an `Expand` op and followed by a
// `Squeeze` op.
Operation* prev_op = op.getOperation()->getPrevNode();
if (!prev_op) return Pattern::matchFailure();
Operation* next_op = op.getOperation()->getNextNode();
if (!next_op) return Pattern::matchFailure();
TF::ExpandDimsOp expand_op;
TF::SqueezeOp squeeze_op;
// Expand + Squeeze op.
if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
if (!llvm::isa<TF::SqueezeOp>(next_op)) {
// Expand and Squeeze ops must come as a pair.
return Pattern::matchFailure();
}
expand_op = llvm::cast<TF::ExpandDimsOp>(prev_op);
squeeze_op = llvm::cast<TF::SqueezeOp>(next_op);
// Update previous/next op pointer.
prev_op = prev_op->getPrevNode();
if (!prev_op) return Pattern::matchFailure();
next_op = next_op->getNextNode();
if (!next_op) return Pattern::matchFailure();
}
// SpaceToBatchND op.
if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return Pattern::matchFailure();
TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(prev_op);
// Pad op.
TF::PadOp pad_op;
if (llvm::isa<TF::PadOp>(next_op)) {
pad_op = llvm::cast<TF::PadOp>(next_op);
next_op = next_op->getNextNode();
if (!next_op) return Pattern::matchFailure();
}
// BatchToSpaceND + BiasAdd.
TF::BatchToSpaceNDOp bts_op;
TF::BiasAddOp biasadd_op;
bool final_op_is_bts = true;
if (llvm::isa<TF::BiasAddOp>(next_op)) {
// Must be BiasAdd + BatchToSpaceND.
biasadd_op = llvm::cast<TF::BiasAddOp>(next_op);
next_op = next_op->getNextNode();
if (!next_op || !llvm::isa<TF::BatchToSpaceNDOp>(next_op))
return Pattern::matchFailure();
bts_op = llvm::cast<TF::BatchToSpaceNDOp>(next_op);
} else if (llvm::isa<TF::BatchToSpaceNDOp>(next_op)) {
// BatchToSpaceND + (optional) BiasAdd.
bts_op = llvm::cast<TF::BatchToSpaceNDOp>(next_op);
next_op = next_op->getNextNode();
if (next_op && llvm::isa<TF::BiasAddOp>(next_op)) {
biasadd_op = llvm::cast<TF::BiasAddOp>(next_op);
final_op_is_bts = false;
}
} else {
return Pattern::matchFailure();
}
llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
stb_op.block_shape(), bts_op.block_shape(), rewriter);
if (!dilations_attr.hasValue()) return Pattern::matchFailure();
op.setAttr("dilations", dilations_attr.getValue());
// Here we need to set the correct padding for the Conv op. In TF, the conv
// op inserted after 'SpaceToBatch' always has 'VALID' padding. This might
// become a problem if the original Conv op had 'SAME' padding: in that case
// TF sets a non-zero padding for the 'SpaceToBatch' op, so we rely on this
// information to decide whether to change the padding from 'VALID' to 'SAME'
// (i.e., when we see non-zero values in `stb_op.paddings`, we change the
// current Conv's padding to 'SAME').
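// As a hypothetical example: a 'SAME' convolution dilated by 2 typically
// reaches this point with stb_op.paddings == [[2, 2], [2, 2]], so the
// non-zero entries tell us to flip the rewritten conv back to 'SAME'; an
// all-zero paddings attribute means the original conv really was 'VALID'
// and the attribute is left untouched.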
auto stb_paddings = stb_op.paddings();
ElementsAttr stb_paddings_attr;
if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr))) {
if (llvm::any_of(stb_paddings_attr.getValues<IntegerAttr>(),
[](IntegerAttr attr) { return attr.getInt() != 0; })) {
op.setAttr("padding", rewriter.getStringAttr("SAME"));
}
}
if (expand_op) {
// If there is an `expand_op`, we need to rewire the inputs to bypass the
// `SpaceToBatch`, `BatchToSpace` and `Pad` ops, e.g. turning
// 'SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND ->
// BiasAdd' into 'Expand -> Conv2D -> Squeeze -> BiasAdd'.
// Connect `expand_op` with the input of `stb_op`.
expand_op.setOperand(0, stb_op.input());
// Calculate the shape for expand.
auto input_shape = stb_op.input().getType().cast<ShapedType>().getShape();
SmallVector<int64_t, 4> expand_shape(input_shape.begin(),
input_shape.end());
expand_shape.push_back(1);
auto expand_result_type = RankedTensorType::get(
expand_shape, getElementTypeOrSelf(stb_op.input()));
expand_op.getResult().setType(expand_result_type);
op.getResult().setType(expand_result_type);
squeeze_op.getResult().setType(bts_op.output().getType());
// Connect `biasadd_op` with the output of `squeeze_op`.
biasadd_op.setOperand(0, squeeze_op.output());
biasadd_op.output().setType(squeeze_op.output().getType());
} else {
if (biasadd_op) biasadd_op.setOperand(0, op.output());
op.setOperand(0, stb_op.input());
op.getResult().setType(bts_op.getResult().getType());
}
if (final_op_is_bts) {
bts_op.getResult().replaceAllUsesWith(bts_op.input());
}
stb_op.getResult().dropAllUses();
return Pattern::matchSuccess();
}
template <typename Conv2dOpTy>
llvm::Optional<ArrayAttr>
ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
Value stb_block_shape, Value bts_block_shape,
PatternRewriter& rewriter) const {
ElementsAttr stb_bs_attr, bts_bs_attr;
if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
!matchPattern(bts_block_shape, m_Constant(&bts_bs_attr))) {
// Returns failure status if block shape is not a constant.
return {};
}
// Check that the block_shape of `stb_op` and `bts_op` are equal.
if (stb_bs_attr.getNumElements() != bts_bs_attr.getNumElements()) return {};
for (uint64_t i = 0; i < stb_bs_attr.getNumElements(); ++i) {
if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
}
// TODO(haoliang): support 1-D dilated conv.
if (stb_bs_attr.getNumElements() < 2) return {};
int dilation_h_factor =
stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
int dilation_w_factor =
stb_bs_attr.getValue({1}).cast<IntegerAttr>().getInt();
return rewriter.getI64ArrayAttr({1, dilation_h_factor, dilation_w_factor, 1});
}
} // namespace TFL
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_

View File

@ -39,7 +39,7 @@ def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
// Use the tensor type information from $0 and convert min $1, max $2 and
// numBits $3 and narrowRange $4 to a QuantizedType.
def ConvertToQuantTypeFromAttrs : NativeCodeCall<
"GetQuantizedTypeAttr($_builder, $0.getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;
"quant::GetQuantizedTypeAttr($_builder, $0.getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;
// Converts an integer attribute $0 to 32-bit with builder.
def convertIntAttrTo32Bit : NativeCodeCall<
@ -49,11 +49,19 @@ def convertIntAttrTo32Bit : NativeCodeCall<
def ExtractSingleElementAsInteger : NativeCodeCall<
"ExtractSingleElementAsInteger($_self.cast<ElementsAttr>())">;
// Extracts the single int32 element from $_self.
def ExtractSingleElementAsInt32 : NativeCodeCall<
"$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger($_self.cast<ElementsAttr>()).getInt())">;
// Checks whether the given operation has static shapes and all of its inputs have the same shape.
def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">;
def HasSameStaticShapes : Constraint<HasSameStaticShapesPred, "op must have static same input shapes">;
def HasNotSameStaticShapes : Constraint<Neg<HasSameStaticShapesPred>, "op must have not static same input shapes">;
// Checks if the value has only one user.
// TODO(karimnosseir): Move to a common place?
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
//===----------------------------------------------------------------------===//
// Nullary ops patterns.
//===----------------------------------------------------------------------===//
@ -186,7 +194,7 @@ def : Pat<(TF_GatherV2Op $params, $indices,
(ConstantOp ElementsAttr:$axis),
ConstantAttr<I64Attr, "0">:$batch_dims),
(TFL_GatherOp $params, $indices,
ExtractSingleElementAsInteger:$axis)>;
ExtractSingleElementAsInt32:$axis)>;
def : Pat<(TF_FloorDivOp $l, $r), (TFL_FloorDivOp $l, $r)>;
@ -198,16 +206,20 @@ def : Pat<(TF_LogicalAndOp $l, $r), (TFL_LogicalAndOp $l, $r)>;
def : Pat<(TF_LogicalOrOp $l, $r), (TFL_LogicalOrOp $l, $r)>;
// Multi-pattern consisting of matching stand-alone op or op followed by relu.
// TODO(karimnosseir): Can the activation part here be removed by modifying the
// very similar pass in optimize_patterns.td?
multiclass FusedBinaryActivationFuncOpPat<dag FromOp, dag ToOp> {
def : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
(ToOp $l, $r, TFL_AF_None)>;
foreach actFnPair = [[TF_ReluOp, TFL_AF_Relu],
[TF_Relu6Op, TFL_AF_Relu6]] in {
def : Pat<(actFnPair[0] (FromOp $lhs, $rhs)),
(ToOp $lhs, $rhs, actFnPair[1])>;
def : Pat<(actFnPair[0] (FromOp:$bin_out $lhs, $rhs)),
(ToOp $lhs, $rhs, actFnPair[1]),
[(HasOneUse $bin_out)]>;
// TODO: Maybe move these below to general pass?
def : Pat<(actFnPair[0] (ToOp $lhs, $rhs, TFL_AF_None)),
(ToOp $lhs, $rhs, actFnPair[1])>;
def : Pat<(actFnPair[0] (ToOp:$bin_out $lhs, $rhs, TFL_AF_None)),
(ToOp $lhs, $rhs, actFnPair[1]),
[(HasOneUse $bin_out)]>;
}
}

View File

@ -23,9 +23,11 @@ limitations under the License.
#include <climits>
#include <cstdint>
#include "absl/container/inlined_vector.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
@ -57,6 +59,10 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/tensor_list.h"
#define DEBUG_TYPE "tf-tfl-legalization"
@ -162,6 +168,86 @@ TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
start_position, slice_size);
}
// Converts a tf.Const containing a variant of type TensorList to a tensor of
// primitive element type. Each individual tensor in the list is converted to
// an ElementsAttr, and those are then packed together using the tf.Pack op.
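// A minimal sketch of the intended rewrite (tensor contents and the exact
// opaque attribute syntax are made up for illustration):
//
//   %list = "tf.Const"() {value = opaque<...>} : () -> tensor<!tf.variant<tensor<2xi32>>>
//
// becomes
//
//   %e0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>}
//   %e1 = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi32>}
//   %packed = "tf.Pack"(%e0, %e1) {axis = 0 : i64}
//       : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>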
struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
TF::ConstOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Verify that the opaque elements attribute contains a tensor of variant
// type with scalar shape. The variant type should hold a TensorList.
auto opaque_attr = op.value().dyn_cast<OpaqueElementsAttr>();
if (!opaque_attr) return matchFailure();
tensorflow::Tensor tensor;
if (!tensorflow::ConvertToTensor(opaque_attr, &tensor).ok())
return matchFailure();
if (tensor.dtype() != tensorflow::DT_VARIANT) return matchFailure();
if (!tensorflow::TensorShapeUtils::IsScalar(tensor.shape()))
return matchFailure();
const tensorflow::TensorList *list =
tensor.scalar<tensorflow::Variant>()().get<tensorflow::TensorList>();
if (!list) return matchFailure();
// Verify the output type is variant and contains exactly one ranked subtype.
auto variant_ty =
getElementTypeOrSelf(op.getType()).dyn_cast<TF::VariantType>();
if (!variant_ty) return matchFailure();
ArrayRef<TensorType> subtypes = variant_ty.getSubtypes();
if (subtypes.size() != 1) return matchFailure();
RankedTensorType list_element_ty =
subtypes.front().dyn_cast<RankedTensorType>();
if (!list_element_ty) return matchFailure();
// Extract tensor elements for the TensorList and construct result type
// based on the number of elements and element shape.
const std::vector<tensorflow::Tensor> &tensors = list->tensors();
llvm::SmallVector<int64_t, 4> result_shape = {
static_cast<int64_t>(tensors.size())};
result_shape.append(list_element_ty.getShape().begin(),
list_element_ty.getShape().end());
auto result_ty =
RankedTensorType::get(result_shape, list_element_ty.getElementType());
// If the list is empty, directly create the final result instead of
// creating the tf.Pack op; the tf.Pack op requires at least one operand.
if (tensors.empty()) {
absl::InlinedVector<tensorflow::int64, 4> tf_shape;
tf_shape.reserve(result_shape.size());
for (int64_t dim : result_shape) {
tf_shape.push_back(dim);
}
tensorflow::Tensor tensor(list->element_dtype,
tensorflow::TensorShape(tf_shape));
auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
if (!attr_or.ok()) return matchFailure();
rewriter.replaceOpWithNewOp<TF::ConstOp>(op, attr_or.ValueOrDie());
return matchSuccess();
}
// Extract the individual tensor list elements and combine them using the
// tf.Pack op.
Location loc = op.getLoc();
llvm::SmallVector<Value, 4> values;
values.reserve(tensors.size());
for (const tensorflow::Tensor &tensor : tensors) {
auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
if (!attr_or.ok()) return matchFailure();
auto value = rewriter.create<TF::ConstOp>(loc, attr_or.ValueOrDie());
values.push_back(value);
}
rewriter.replaceOpWithNewOp<TF::PackOp>(
op, result_ty, values, /*axis=*/rewriter.getI64IntegerAttr(0));
return matchSuccess();
}
};
struct ConvertTensorListSetItem
: public OpConversionPattern<TF::TensorListSetItemOp> {
using OpConversionPattern::OpConversionPattern;
@ -615,7 +701,7 @@ struct ConvertTensorListStack
if ((ranked_type && ranked_type.getRank() == 0) ||
!matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
// If no constant is spotted, just forward the operand.
rewriter.replaceOp(op, {input}, llvm::None);
rewriter.replaceOp(op, {input});
return matchSuccess();
}
@ -768,7 +854,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
OwningRewritePatternList patterns;
patterns
.insert<ConvertEmptyTensorList, ConvertIdentity,
.insert<ConvertConst, ConvertEmptyTensorList, ConvertIdentity,
ConvertTensorListFromTensor, ConvertTensorListGetItem,
ConvertTensorListLength, ConvertTensorListPushBack,
ConvertTensorListReserve, ConvertTensorListSetItem,

View File

@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@ -173,7 +174,7 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
}
TypeAttr RescaleQtype(Type input, Attribute factor) {
return TFL::RescaleQuantizedType(input, factor);
return quant::RescaleQuantizedType(input, factor);
}
// Returns shape of a ranked tensor.
@ -201,34 +202,80 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
PatternMatchResult matchAndRewrite(TFL::AddOp add_op,
PatternRewriter &rewriter) const override {
// Add.
// Match Add.
DenseElementsAttr added_value;
Value constant_val = add_op.rhs();
if (!matchPattern(constant_val, m_Constant(&added_value)))
return matchFailure();
// Fully Connected.
// Match Fully Connected.
auto fc_op =
dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
if (!fc_op) return matchFailure();
// Check whether the constant RHS is either 0D (a scalar) or 1D with
// `{num_channels}` shape.
auto constant_val_type = constant_val.getType().cast<TensorType>();
// In the TFLite FullyConnected definition, the bias must be a 1D tensor
// whose number of elements equals the number of channels.
// If it's not 1D or 0D (which can be broadcast to 1D), reject the match.
bool is_scalar_rhs = false;
if (constant_val_type.getRank() == 0) {
is_scalar_rhs = true;
} else if (constant_val_type.getRank() != 1) {
return matchFailure();
}
Value filter = fc_op.filter();
Value bias = fc_op.bias();
ElementsAttr bias_value;
const bool is_none_bias = bias.getType().isa<NoneType>();
if (fc_op.fused_activation_function() != "NONE") return matchFailure();
if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
return matchFailure();
if (fc_op.fused_activation_function() != "NONE") return matchFailure();
// Rewrite
Location loc = fc_op.getLoc();
// If bias isn't None, it needs to be added as well.
if (is_none_bias) {
bias = constant_val;
if (is_scalar_rhs) {
// If the `constant_val` is a scalar, we must use the shape of the filter
// to properly broadcast the scalar to `{num_channels}` shape.
// Get the number of channels if possible.
auto filter_type = filter.getType().cast<ShapedType>();
// The filter must be a `2D` tensor with `{num_channels, num_features}`
// shape. The following check rejects an unknown rank (-1).
if (filter_type.getRank() != 2) {
return matchFailure();
}
int num_channels = filter_type.getShape()[0];
// Create a zero tensor with shape {num_channels}; its element type needs
// to be the same as that of constant_val.
// This is a way to gracefully handle scalar tensors: the Add below will
// always be constant-folded away regardless of whether `constant_val` is a
// scalar or not.
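// A hypothetical sketch using the shapes from the
// FuseFullyConnectedAddWithNoBiasAndScalarRhs test above: for a filter of
// type tensor<40x37xf32> and a scalar rhs of dense<2.0> : tensor<f32>, we
// build dense<0.0> : tensor<40xf32> here and emit a tfl.add of the two
// below; that add then folds into the dense<2.0> : tensor<40xf32> bias.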
RankedTensorType type = RankedTensorType::get(
{num_channels}, constant_val_type.getElementType());
auto attr = rewriter.getZeroAttr(type);
bias = rewriter.create<ConstantOp>(loc, type, attr);
auto none_af = rewriter.getStringAttr("NONE");
bias =
rewriter.create<AddOp>(loc, bias, constant_val, none_af).output();
} else {
// If there is no pre-existing bias and the `constant_val` is 1D, simply
// use `constant_val` as the bias.
bias = constant_val;
}
} else {
auto none_af = rewriter.getStringAttr("NONE");
bias = rewriter.create<AddOp>(loc, bias, constant_val, none_af).output();
}
rewriter.replaceOpWithNewOp<TFL::FullyConnectedOp>(
add_op, add_op.getType(),
/*input=*/fc_op.input(),

View File

@ -23,26 +23,34 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
def F32ElementsAttr : ElementsAttrBase<
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
def ExtractSingleElementAsFloat : NativeCodeCall<
"ExtractSingleElementAsFloat($_self.cast<ElementsAttr>())">;
// Checks if the value has only one user.
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
// Multi-pattern consisting of matching stand-alone convolution op followed by
// activation op.
multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
def : Pat<(ActFnOp (TFL_Conv2DOp $input, $filter, $bias,
def : Pat<(ActFnOp (TFL_Conv2DOp:$conv_out $input, $filter, $bias,
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w)),
(TFL_Conv2DOp $input, $filter, $bias,
$h_factor, $w_factor, ActFnAttr,
$padding, $stride_h, $stride_w)>;
def : Pat<(ActFnOp (TFL_DepthwiseConv2DOp $input, $filter, $bias,
$padding, $stride_h, $stride_w),
[(HasOneUse $conv_out)]>;
def : Pat<(ActFnOp (TFL_DepthwiseConv2DOp:$conv_out $input, $filter, $bias,
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w,
$multiplier)),
(TFL_DepthwiseConv2DOp $input, $filter, $bias,
$h_factor, $w_factor, ActFnAttr,
$padding, $stride_h, $stride_w,
$multiplier)>;
$multiplier),
[(HasOneUse $conv_out)]>;
}
// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
@ -58,9 +66,6 @@ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>;
// Checks if the value has only one user.
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
// If we see a binary op (add, sub) op adding a constant value to a convolution
// op with constant bias, we can fuse the binary op into the convolution op by
// constant folding the bias and the binary op's constant operand. The following
@ -288,8 +293,9 @@ multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
[TFL_Relu6Op, TFL_AF_Relu6],
[TFL_Relu1Op, TFL_AF_Relu1]] in {
def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
(BinaryOp $lhs, $rhs, actFnPair[1])>;
def : Pat<(actFnPair[0] (BinaryOp:$binary_out $lhs, $rhs, TFL_AF_None)),
(BinaryOp $lhs, $rhs, actFnPair[1]),
[(HasOneUse $binary_out)]>;
}
}
@ -358,6 +364,7 @@ class ValueEquals<string val> : Constraint<CPred<
"$0.cast<DenseElementsAttr>().getNumElements() == 1 &&"
"*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>;
// ReLU patterns
def : Pat<(TFL_MinimumOp (TFL_MaximumOp $input,
(ConstantOp $NegOne)),
(ConstantOp $One)),
@ -370,6 +377,35 @@ def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
(TFL_Relu1Op $input),
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
def : Pat<(TFL_MaximumOp (TFL_MulOp:$mul_out $input1,
(ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
$input2),
(TFL_LeakyReluOp $input1, ExtractSingleElementAsFloat:$alpha),
[(ConstDoubleValueLessThan<"1"> $alpha),
(EqualOperands $input1, $input2),
(HasOneUse $mul_out)]>;
// Checks whether operand 0's rank is one less than operand 1's rank.
CPred<"$0.getType().cast<ShapedType>().getRank() == "
"$1.getType().cast<ShapedType>().getRank() - 1">>;
// PReLU pattern from Keras:
// f(x) = Relu(x) + (-alpha * Relu(-x))
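// A quick case check (not part of the original comment): for x >= 0,
// Relu(x) = x and Relu(-x) = 0, so f(x) = x; for x < 0, Relu(x) = 0 and
// Relu(-x) = -x, so f(x) = (-alpha) * (-x) = alpha * x, which is exactly
// PReLU with slope alpha on the negative side.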
def : Pat<(TFL_AddOp
(TFL_ReluOp:$relu_out $input1),
(TFL_MulOp:$mul_out
(TFL_ReluOp (TFL_NegOp:$input_neg_out $input2)),
$neg_alpha,
TFL_AF_None),
TFL_AF_None),
(TFL_PReluOp $input1, (TFL_NegOp $neg_alpha)),
[(EqualOperands $input1, $input2),
(PReluAlphaRankCheck $neg_alpha, $input1),
(HasOneUse $relu_out),
(HasOneUse $mul_out),
(HasOneUse $input_neg_out)]>;
// The constant folding in this pass might produce constants in the tf
// dialect. This rule legalizes these constants to the tfl dialect.
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;

View File

@ -50,7 +50,7 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateQuantizePass();
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareQuantizePass(
const QuantizationSpecs& quant_specs);
// Creates a instance of the TensorFlow Lite dialect PostQuantize pass.
// Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePostQuantizePass(
bool emit_quant_adaptor_ops);
@ -70,14 +70,20 @@ std::unique_ptr<OpPassBase<ModuleOp>> CreateExtractOphintPass();
// pass. The composite op is created from the ophint extraction pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeOphintFuncOpPass();
// Creates an instance of TensorFlow Lite dialect SplitMergedOperandsPass.
// Creates an instance of the TensorFlow Lite dialect SplitMergedOperandsPass.
std::unique_ptr<OpPassBase<FuncOp>> CreateSplitMergedOperandsPass();
// Creates an instance of the TensorFlow Lite dialect OptimizeFunctionalOpsPass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateOptimizeFunctionalOpsPass();
// Creates an instance pass to add default quantization parameters.
// Creates an instance of the TensorFlow Lite dialect pass to add default
// quantization parameters.
std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
double default_min, double default_max);
// Creates an instance of the TensorFlow Lite dialect pass to convert dense
// tensors to sparse format.
std::unique_ptr<OpPassBase<FuncOp>> CreateDenseToSparsePass();
} // namespace TFL
} // namespace mlir
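For context on the new `CreateDenseToSparsePass` declaration above, here is a generic sketch of what a dense-to-sparse conversion does, using an ad-hoc COO layout purely for illustration; the real pass emits TensorFlow Lite's own sparse tensor encoding, not this struct:

```cpp
#include <cstdio>
#include <vector>

// Ad-hoc COO (coordinate) layout used only to illustrate the dense -> sparse
// idea; not the TFLite sparse tensor format.
struct CooMatrix {
  std::vector<int> rows, cols;
  std::vector<float> values;
};

CooMatrix DenseToCoo(const std::vector<std::vector<float>>& dense) {
  CooMatrix coo;
  for (int r = 0; r < static_cast<int>(dense.size()); ++r)
    for (int c = 0; c < static_cast<int>(dense[r].size()); ++c)
      if (dense[r][c] != 0.0f) {
        coo.rows.push_back(r);
        coo.cols.push_back(c);
        coo.values.push_back(dense[r][c]);
      }
  return coo;
}

int main() {
  const CooMatrix coo = DenseToCoo({{0.0f, 1.5f, 0.0f}, {0.0f, 0.0f, -2.0f}});
  for (size_t i = 0; i < coo.values.size(); ++i)
    std::printf("(%d, %d) = %g\n", coo.rows[i], coo.cols[i], coo.values[i]);
  return 0;
}
```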

View File

@ -136,7 +136,7 @@ def : Pat<(TF_ReshapeOp
// Casts result type of $1 to a quantized type by using the quantization
// parameters from the type in $0.
class UpdateShapeWithAxis<int i> : NativeCodeCall<
"CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1.getType(), " # i # ")">;
"quant::CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1.getType(), " # i # ")">;
class UsedBy<string op> : Constraint<
CPred<"llvm::isa<mlir::TFL::" # op # "Op>(*$0.getUsers().begin())">>;

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -144,16 +145,16 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
if (auto shaped = input_type.dyn_cast<ShapedType>()) {
if (shaped.getElementType().isa<FloatType>()) {
auto min_max = GetMinMaxValuesForArgument(func_name, i);
TypeAttr params = GetQuantizedTypeAttr(
TypeAttr params = quant::GetQuantizedTypeAttr(
builder, input_type, builder.getF64FloatAttr(min_max.first),
builder.getF64FloatAttr(min_max.second), /*quant_dim=*/-1, num_bits,
narrow_range, is_signed);
builder.setInsertionPoint(block, insertion_point);
auto q_op = builder.create<TFL::QuantizeOp>(loc, params.getValue(), arg,
params);
auto dq_op =
builder.create<TFL::DequantizeOp>(loc, input_type, q_op.output());
arg.replaceAllUsesWith(dq_op.output());
auto q_op =
builder.create<quant::QuantizeCastOp>(loc, params.getValue(), arg);
auto dq_op = builder.create<quant::DequantizeCastOp>(loc, input_type,
q_op.getResult());
arg.replaceAllUsesWith(dq_op.getResult());
q_op.setOperand(arg);
}
}
@ -176,12 +177,14 @@ bool PrepareQuantizePass::RemoveRedundantStats(FuncOp func) {
}
using PrepareQuantStats =
TFL::ConvertStatsToQDQs<TFL::QuantizeOp, TFL::DequantizeOp>;
quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
void PrepareQuantizePass::runOnFunction() {
FuncOp func = getFunction();
MLIRContext* ctx = func.getContext();
ConvertTFLQuantOpsToMlirQuantOps(func);
if (quant_specs_.post_training_quantization) {
RemoveRedundantStats(func);
} else {
@ -198,7 +201,7 @@ void PrepareQuantizePass::runOnFunction() {
OwningRewritePatternList patterns;
bool is_signed = quant_specs_.IsSignedInferenceType();
if (is_signed) {
patterns.insert<ConvertUnsignedToSigned<TFL::QuantizeOp>>(ctx);
patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
// Convert quant stats to int8 quantization parameters.
// Currently, only activation stats are imported, so narrow_range = false.
patterns.insert<PrepareQuantStats>(8, false, true, ctx);
@ -213,6 +216,8 @@ void PrepareQuantizePass::runOnFunction() {
// values (tensors).
ApplyQuantizationParamsPropagation(func, is_signed, disable_per_channel,
GetOpQuantSpec);
ConvertMlirQuantOpsToTFLQuantOps(func);
}
} // namespace
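The rewritten argument handling above derives a quantized type from each argument's (min, max) range via `quant::GetQuantizedTypeAttr` and then wraps the argument in `quant::QuantizeCastOp`/`quant::DequantizeCastOp` pairs. As a rough illustration of what such a type encodes, here is a generic 8-bit affine quantization sketch computed from a range; it is only a stand-in, not the TensorFlow Lite helper itself:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Generic 8-bit asymmetric (affine) quantization parameters derived from a
// float (min, max) range. A sketch of the idea only, not the implementation
// behind quant::GetQuantizedTypeAttr.
struct QuantParams {
  double scale;
  int32_t zero_point;
};

QuantParams ParamsFromMinMax(double min, double max) {
  // Nudge the range so that 0.0 is exactly representable.
  min = std::min(min, 0.0);
  max = std::max(max, 0.0);
  QuantParams p;
  p.scale = (max - min) / 255.0;  // uint8 storage: q in [0, 255]
  p.zero_point = static_cast<int32_t>(std::lround(-min / p.scale));
  return p;
}

uint8_t Quantize(double x, QuantParams p) {
  const double q = std::round(x / p.scale) + p.zero_point;
  return static_cast<uint8_t>(std::clamp(q, 0.0, 255.0));
}

double Dequantize(uint8_t q, QuantParams p) {
  return p.scale * (static_cast<int32_t>(q) - p.zero_point);
}

int main() {
  const QuantParams p = ParamsFromMinMax(-1.0, 3.0);
  std::printf("scale=%f zero_point=%d\n", p.scale, p.zero_point);
  std::printf("0.5 -> q=%d -> %f\n", static_cast<int>(Quantize(0.5, p)),
              Dequantize(Quantize(0.5, p), p));
  return 0;
}
```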

View File

@ -51,6 +51,7 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
@ -81,6 +82,7 @@ class PrepareTFPass : public FunctionPass<PrepareTFPass> {
};
// TODO(fengliuai): move this rule to PreparePatterns.td
// TODO(fengliuai): reuse the quantization/tensorflow/tf_to_quant pass.
// TODO(b/140968741): propagate the sign from the command line. Currently all
// FakeQuant ops are assumed to target UINT8, but the per-channel kernel is
// actually INT8.
@ -149,9 +151,9 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue());
BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
Type res_type = tf_op.getType();
TypeAttr qtype = GetQuantizedTypeAttr(rewriter, res_type, min_value,
max_value, quant_dim, num_bits,
narrow_range, /*is_signed=*/false);
TypeAttr qtype = quant::GetQuantizedTypeAttr(
rewriter, res_type, min_value, max_value, quant_dim, num_bits,
narrow_range, /*is_signed=*/false);
if (!qtype) return this->matchFailure();
// Finally, use the quantization parameter to create the quantize and
@ -503,6 +505,12 @@ void PrepareTFPass::runOnFunction() {
// first `applyPatternsGreedily` method, which would otherwise remove the
// TF FakeQuant ops by constant folding.
patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
// This pattern will try to identify and optimize dilated convolutions:
// patterns like "SpaceToBatchND -> Conv2D -> BatchToSpaceND" will be
// replaced with a single Conv op with a dilation parameter (see the 1-D
// sketch after this file's hunks).
patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>,
ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);
TFL::populateWithGenerated(ctx, &patterns);
// TODO(karimnosseir): Split into a separate pass, probably after
// deciding on a long-term plan for this optimization.
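Regarding the dilated-convolution comment above: a dilation parameter simply spaces the kernel taps apart in the input, which is also equivalent to a stride-1 convolution with zeros inserted between the taps. A small 1-D sketch under those assumptions (it does not model the SpaceToBatchND/BatchToSpaceND bookkeeping itself):

```cpp
#include <cassert>
#include <vector>

// 1-D dilated convolution with VALID padding: kernel taps are spaced
// `dilation` input elements apart.
std::vector<float> DilatedConv1D(const std::vector<float>& x,
                                 const std::vector<float>& w, int dilation) {
  const int span = static_cast<int>(w.size() - 1) * dilation + 1;
  std::vector<float> y;
  for (int i = 0; i + span <= static_cast<int>(x.size()); ++i) {
    float acc = 0.0f;
    for (int k = 0; k < static_cast<int>(w.size()); ++k)
      acc += w[k] * x[i + k * dilation];
    y.push_back(acc);
  }
  return y;
}

int main() {
  const std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};
  const std::vector<float> w = {1, -1, 2};
  // Dilation 2 gives the same result as a stride-1 convolution with a
  // zero-stuffed kernel, which is why the three-op pattern can collapse
  // into a single Conv op carrying a dilation attribute.
  const std::vector<float> w_zero_stuffed = {1, 0, -1, 0, 2};
  assert(DilatedConv1D(x, w, 2) == DilatedConv1D(x, w_zero_stuffed, 1));
  return 0;
}
```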

View File

@ -65,8 +65,8 @@ namespace {
// Full integer quantization rewrite pattern for TFLite.
struct TFLFullQuantization
: public QuantizationPattern<TFLFullQuantization, QuantizeOp, DequantizeOp,
NumericVerifyOp> {
: public quant::QuantizationPattern<TFLFullQuantization, QuantizeOp,
DequantizeOp, NumericVerifyOp> {
explicit TFLFullQuantization(MLIRContext* ctx, bool verify_numeric,
float tolerance, bool verify_single_layer)
: BaseType(ctx, verify_numeric, tolerance, verify_single_layer) {}

View File

@ -20,7 +20,7 @@ include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
// Quantize attribute $0 by using the quantization parameter from $1.
def QuantizeByQuantizedType : NativeCodeCall<"Quantize($0, $1.getValue())">;
def QuantizeByQuantizedType : NativeCodeCall<"quant::Quantize($0, $1.getValue())">;
// Squash tfl.dequantize and tfl.quantize pairs.
// TODO(fengliuai): Compare the scale of input and output. This can also be

View File

@ -263,6 +263,18 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
return this->matchSuccess();
}
// Input dimensions must be defined. MatMulBCast does not support partial
// shapes.
for (auto dim : lhs_shape) {
if (dim == -1) {
return this->matchFailure();
}
}
for (auto dim : rhs_shape) {
if (dim == -1) {
return this->matchFailure();
}
}
// Ensure that batch shapes are broadcastable.
tensorflow::MatMulBCast bcast(absl::InlinedVector<tensorflow::int64, 4>(
lhs_shape.begin(), lhs_shape.end()),
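The new check above rejects inputs whose shapes still contain dynamic dimensions, since MatMulBCast needs fully defined shapes to decide whether the batch dimensions are broadcastable. A standalone sketch of the same guard, using -1 as the dynamic-dimension marker as in the code above:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Returns true only if every dimension is static; -1 marks a dynamic dim.
bool HasOnlyStaticDims(const std::vector<int64_t>& shape) {
  for (int64_t dim : shape)
    if (dim == -1) return false;
  return true;
}

int main() {
  assert(HasOnlyStaticDims({4, 16, 8}));
  assert(!HasOnlyStaticDims({4, -1, 8}));  // partial shape: the rewrite bails out
  return 0;
}
```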

Some files were not shown because too many files have changed in this diff.