commit f0df47fddf
Merge branch 'master' into sparse-xent-op-hessian
Changed files (listing truncated):
.bazelrc  RELEASE.md  SECURITY.md
tensorflow/c: BUILD  c_api_experimental.cc  c_api_experimental.h  c_api_experimental_test.cc  kernels_test.cc
tensorflow/c/eager: BUILD  c_api.cc  c_api_debug.cc  c_api_experimental.cc  c_api_experimental.h  c_api_test.cc  custom_device_test.cc
tensorflow/c/experimental/filesystem: filesystem_interface.h  modular_filesystem.cc  modular_filesystem.h  modular_filesystem_registration.cc  modular_filesystem_test.cc  plugins/
tensorflow/cc
tensorflow/compiler/aot
tensorflow/compiler/jit: BUILD  compilability_check_util.cc  flags.cc  flags.h  mark_for_compilation_pass.cc  xla_compilation_cache.cc  xla_device.cc  xla_kernel_creator.cc  xla_kernel_creator_util.cc
tensorflow/compiler/mlir: BUILD
tensorflow/compiler/mlir/lite: BUILD  flatbuffer_translate.cc  ir/  mlir_tflite_runner.cc  operator_converter_gen.cc  tf_tfl_passes.cc  transforms/
tensorflow/compiler/mlir/lite/quantization: BUILD  lite/  quantization.td  quantization_driver.cc  quantization_traits.h  quantization_utils.cc  quantization_utils.h  tensorflow/  tools/  xla/
tensorflow/compiler/mlir/lite/sparsity
tensorflow/compiler/mlir/lite/tests: canonicalize.mlir  dilated-conv.mlir  legalize-tf.mlir  lower-static-tensor-list.mlir  mlir2exec/  mlir2flatbuffer/  ops.mlir  optimize.mlir  prepare-quantize.mlir  tfl_while_op_licm.mlir  tfl_while_outline.mlir
.bazelrc: 8 changed lines
@@ -69,6 +69,7 @@
# rbe_linux_py3: Linux Python 3 RBE config
#
# rbe_win_py37: Windows Python 3.7 RBE config
# rbe_win_py38: Windows Python 3.8 RBE config
#
# tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
# tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
@@ -392,6 +393,7 @@ build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe

# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
build:rbe_win --define=override_eigen_strong_inline=true
build:rbe_win --jobs=500

build:rbe_win_py37 --config=rbe
build:rbe_win_py37 --repo_env=PYTHON_BIN_PATH=C:\\Python37\\python.exe
@@ -399,6 +401,12 @@ build:rbe_win_py37 --repo_env=PYTHON_LIB_PATH=C:\\Python37\\lib\\site-packages
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py37
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe

build:rbe_win_py38 --config=rbe
build:rbe_win_py38 --repo_env=PYTHON_BIN_PATH=C:\\Python38\\python.exe
build:rbe_win_py38 --repo_env=PYTHON_LIB_PATH=C:\\Python38\\lib\\site-packages
build:rbe_win_py38 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py38
build:rbe_win_py38 --python_path=C:\\Python38\\python.exe

# These you may need to change for your own GCP project.
build:tensorflow_testing_rbe --project_id=tensorflow-testing
common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance
RELEASE.md: 16 changed lines
@@ -1,3 +1,19 @@
# Release 2.0.1

## Bug Fixes and Other Changes
* Fixes a security vulnerability where converting a Python string to a `tf.float16` value produces a segmentation fault ([CVE-2020-5215](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-5215))
* Updates `curl` to `7.66.0` to handle [CVE-2019-5482](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-5482) and [CVE-2019-5481](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-5481)
* Updates `sqlite3` to `3.30.01` to handle [CVE-2019-19646](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19646), [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) and [CVE-2019-16168](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-16168)


# Release 1.15.2

## Bug Fixes and Other Changes
* Fixes a security vulnerability where converting a Python string to a `tf.float16` value produces a segmentation fault ([CVE-2020-5215](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-5215))
* Updates `curl` to `7.66.0` to handle [CVE-2019-5482](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-5482) and [CVE-2019-5481](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-5481)
* Updates `sqlite3` to `3.30.01` to handle [CVE-2019-19646](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19646), [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) and [CVE-2019-16168](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-16168)


# Release 2.1.0

TensorFlow 2.1 will be the last TF release supporting Python 2. Python 2 support [officially ends on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update). [As announced earlier](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ), TensorFlow will also stop supporting Python 2 starting January 1, 2020, and no more releases are expected in 2019.

@@ -245,4 +245,4 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
### Known Vulnerabilities

For a list of known vulnerabilities and security advisories for TensorFlow,
-[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md).
+[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/README.md).
tensorflow/c/BUILD:

@@ -54,9 +54,10 @@ filegroup(
)

filegroup(
-    name = "pywrap_eager_hdrs",
+    name = "pywrap_required_hdrs",
    srcs = [
        "c_api_internal.h",
        "python_api.h",
        "tf_status_helper.h",
        "tf_status_internal.h",
        "tf_tensor_internal.h",
@@ -98,6 +99,17 @@ tf_cuda_library(
    ],
)

filegroup(
    name = "pywrap_tf_session_hdrs",
    srcs = [
        "python_api.h",
    ],
    visibility = [
        "//tensorflow/core:__pkg__",
        "//tensorflow/python:__pkg__",
    ],
)

cc_library(
    name = "tf_attrtype",
    hdrs = ["tf_attrtype.h"],
@@ -536,6 +548,7 @@ tf_cc_test(
        "//tensorflow:macos": ["-headerpad_max_install_names"],
        "//conditions:default": [],
    }),
    tags = ["notsan"],  # b/149031034
    # We must ensure that the dependencies can be dynamically linked since
    # the shared library must be able to use core:framework.
    # linkstatic = tf_kernel_tests_linkstatic(),
@@ -640,7 +653,7 @@ tf_cuda_cc_test(
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/kernels:ops_testutil",
        "//third_party/eigen3",
        "@com_google_absl//absl/container:inlined_vector",
    ],
)
tensorflow/c/c_api_experimental.cc:

@@ -519,72 +519,6 @@ TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
  return createTFEDequeue(ctx, TF_VARIANT, queue, status);
}

void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
  auto* status = TF_NewStatus();
  TF_Tensor* t = TFE_TensorHandleResolve(handle, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::Tensor dst;
  TF_CHECK_OK(TF_TensorToTensor(t, &dst));
  LOG(INFO) << dst.DebugString();

  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
}

void TFE_OpPrintDebugString(TFE_Op* op) {
  VLOG(1) << "TFE_OpPrintDebugString() over " << op;
  LOG(INFO) << op->operation.DebugString();
}

struct TFE_ExecuteOpNotification {
  TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {}
  tensorflow::Notification n;
  std::unique_ptr<tensorflow::Thread> thread;
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status;
};

TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
                                                    TFE_TensorHandle** retvals,
                                                    int* num_retvals,
                                                    TF_Status* status) {
  TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;

  n->thread.reset(op->operation.EagerContext().TFEnv()->StartThread(
      tensorflow::ThreadOptions(), "ExecuteOpThread",
      [op, retvals, num_retvals, n]() {
        TFE_Execute(op, retvals, num_retvals, n->status.get());
        n->n.Notify();
      }));

  return n;
}

void TFE_ExecuteOpNotificationWaitAndDelete(
    TFE_ExecuteOpNotification* notification, TF_Status* status) {
  if (notification == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "Passed in notification is a nullptr.");

    return;
  }
  if (notification->thread == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "Passed in notification didn't start a thread correctly. Cleaning up "
        "this notification. Please re-execute the operation to get a new "
        "notification.");

    delete notification;
    return;
  }

  notification->n.WaitForNotification();

  status->status = notification->status->status;

  delete notification;
}

void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
  status->status = tensorflow::errors::Internal(errMsg);
}
tensorflow/c/c_api_experimental.h:

@@ -188,31 +188,6 @@ TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
    TF_Session* session, int tensor_id, TF_Status* status);

// Prints `handle` in a human readable format to standard output for debugging.
TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString(
    TFE_TensorHandle* handle);

TF_CAPI_EXPORT extern void TFE_OpPrintDebugString(TFE_Op* op);

typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification;

// Allows invoking a kernel asynchronously, and explicitly returns a
// notification that can be waited upon. This always executes the kernel in a
// new thread.
// 1. `retvals` and `num_retvals` can only be consumed after
// `TFE_ExecuteOp` returns successfully. They shouldn't be used
// if the return is unsuccessful
// 2. These new APIs cannot be used together with the TFE context level async
// support.
TF_CAPI_EXPORT extern TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(
    TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
    TF_Status* status);

// Waits to complete the op execution, and cleans up the notification.
// Errors reported by op execution are set in `status`.
TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete(
    TFE_ExecuteOpNotification* notification, TF_Status* status);

TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
                                                      const char* errMsg);
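For reference, the asynchronous-execution API being removed here was used in the following pattern (a minimal sketch, assuming an already-constructed `TFE_Op* op` and `TF_Status* status`; the deleted tests below show the complete setup):

// Sketch only: relies on the removed TFE_ExecuteOpInNewThread API above.
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
// Kicks off execution of `op` on a freshly created thread, returning
// immediately with a notification handle.
TFE_ExecuteOpNotification* n =
    TFE_ExecuteOpInNewThread(op, &retvals[0], &num_retvals, status);
// Blocks until the op finishes, copies the op's status into `status`, and
// deletes the notification.
TFE_ExecuteOpNotificationWaitAndDelete(n, status);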
tensorflow/c/c_api_experimental_test.cc:

@@ -84,127 +84,6 @@ TEST(CAPI_EXPERIMENTAL, IsStateful) {
  EXPECT_EQ(id, 0);
}

TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Simple) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* m = TestMatrixTensorHandle();

  TFE_Op* matmul_op = MatMulOp(ctx, m, m);

  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;

  auto* r =
      TFE_ExecuteOpInNewThread(matmul_op, &retvals[0], &num_retvals, status);

  TFE_ExecuteOpNotificationWaitAndDelete(r, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(7, product[0]);
  EXPECT_EQ(10, product[1]);
  EXPECT_EQ(15, product[2]);
  EXPECT_EQ(22, product[3]);

  TFE_DeleteOp(matmul_op);
  TFE_DeleteTensorHandle(m);

  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}

// Perform a send/recv test. Recv blocks, so they need to be executed
// asynchronously.
TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  // Returns a 2x2 float32 Tensor on the CPU, with data 1., 2., 3., 4.
  TFE_TensorHandle* m = TestMatrixTensorHandle();

  // Build a send op.
  TFE_Op* send_op = TFE_NewOp(ctx, "_Send", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_OpAddInput(send_op, m, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  string tensor_name = "Tensor";
  TFE_OpSetAttrType(send_op, "T", TF_FLOAT);
  TFE_OpSetAttrString(send_op, "tensor_name", tensor_name.c_str(),
                      tensor_name.size());
  string send_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  TFE_OpSetAttrString(send_op, "send_device", send_device.c_str(),
                      send_device.size());
  TFE_OpSetAttrInt(send_op, "send_device_incarnation", 1234);
  string recv_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  TFE_OpSetAttrString(send_op, "recv_device", recv_device.c_str(),
                      recv_device.size());
  TFE_OpSetAttrBool(send_op, "client_terminated", true);

  // Build a recv op.
  TFE_Op* recv_op = TFE_NewOp(ctx, "_Recv", status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_OpSetAttrType(recv_op, "tensor_type", TF_FLOAT);
  TFE_OpSetAttrString(recv_op, "tensor_name", tensor_name.c_str(),
                      tensor_name.size());
  TFE_OpSetAttrString(recv_op, "send_device", send_device.c_str(),
                      send_device.size());
  TFE_OpSetAttrInt(recv_op, "send_device_incarnation", 1234);
  TFE_OpSetAttrString(recv_op, "recv_device", recv_device.c_str(),
                      recv_device.size());
  TFE_OpSetAttrBool(recv_op, "client_terminated", true);

  TFE_TensorHandle* send_retvals;
  int send_num_retvals = 0;
  auto* send_result = TFE_ExecuteOpInNewThread(send_op, &send_retvals,
                                               &send_num_retvals, status);

  TFE_TensorHandle* recv_retvals[1] = {nullptr};
  int recv_num_retvals = 1;
  auto* recv_result = TFE_ExecuteOpInNewThread(recv_op, &recv_retvals[0],
                                               &recv_num_retvals, status);

  TFE_ExecuteOpNotificationWaitAndDelete(send_result, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_ExecuteOpNotificationWaitAndDelete(recv_result, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_Tensor* t = TFE_TensorHandleResolve(recv_retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(1, product[0]);
  EXPECT_EQ(2, product[1]);
  EXPECT_EQ(3, product[2]);
  EXPECT_EQ(4, product[3]);

  TFE_DeleteOp(send_op);
  TFE_DeleteOp(recv_op);
  TFE_DeleteTensorHandle(m);

  TFE_DeleteTensorHandle(recv_retvals[0]);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}

class ShapeInferenceTest : public ::testing::Test {
 protected:
  ShapeInferenceTest()
tensorflow/c/eager/BUILD:

@@ -2,6 +2,7 @@

load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_test",
    "tf_copts",
    "tf_cuda_cc_test",
    "tf_cuda_library",
@@ -89,7 +90,7 @@ tf_cuda_library(
)

filegroup(
-    name = "pywrap_eager_hdrs",
+    name = "pywrap_required_hdrs",
    srcs = [
        "c_api_experimental.h",
        "c_api_internal.h",
@@ -129,18 +130,6 @@ tf_cuda_library(
        "//tensorflow/core/common_runtime/eager:eager_operation",
        "//tensorflow/core/common_runtime/eager:kernel_and_device",
        "//tensorflow/core/common_runtime/eager:tensor_handle",
        "//tensorflow/core/distributed_runtime:remote_device",
        "//tensorflow/core/distributed_runtime:server_lib",
        "//tensorflow/core/distributed_runtime:worker_env",
        "//tensorflow/core/distributed_runtime/eager:eager_client",
        "//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
        "//tensorflow/core/distributed_runtime/rpc:grpc_channel",
        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
        "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
        "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
        "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
        "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
        "//tensorflow/core/profiler/lib:profiler_lib",
        "//tensorflow/core/profiler/lib:profiler_session",
    ],
)
@@ -301,6 +290,27 @@ tf_cuda_cc_test(
    ],
)

tf_cc_test(
    name = "custom_device_test",
    size = "small",
    srcs = [
        "custom_device_test.cc",
    ],
    deps = [
        ":c_api",
        ":c_api_experimental",
        ":c_api_test_util",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_test_util",
        "//tensorflow/cc/profiler",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "tape",
    hdrs = ["tape.h"],
@@ -313,7 +323,10 @@ cc_library(

filegroup(
    name = "headers",
-    srcs = ["c_api.h"],
+    srcs = [
+        "c_api.h",
+        "c_api_experimental.h",
+    ],
    visibility = ["//tensorflow:__subpackages__"],
)
tensorflow/c/eager/c_api.cc:

@@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/platform.h"  // NOLINT
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/device_filters.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -102,7 +103,12 @@ const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
  return op_def;
}

-bool IsCPU(const tensorflow::Device* d) {
+bool IsCPU(
+    absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant) {
+  if (VariantDeviceIsCustom(variant)) {
+    return false;
+  }
+  tensorflow::Device* d = absl::get<tensorflow::Device*>(variant);
  return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}

@@ -265,9 +271,9 @@ tensorflow::Status GetReplacedFromExistingWorkers(
}

tensorflow::Status CreateRemoteContexts(
-    const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
-    tensorflow::uint64 context_view_id, int keep_alive_secs,
-    const tensorflow::ServerDef& server_def,
+    TFE_Context* ctx, const std::vector<string>& remote_workers,
+    tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
+    int keep_alive_secs, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
    const bool lazy_copy_remote_function_inputs,
    const tensorflow::eager::CreateContextRequest& base_request) {
@@ -296,7 +302,7 @@ tensorflow::Status CreateRemoteContexts(
      continue;
    }

-    tensorflow::eager::CreateContextRequest request(base_request);
+    tensorflow::eager::CreateContextRequest request;
    tensorflow::eager::CreateContextResponse* response =
        new tensorflow::eager::CreateContextResponse();
    request.set_context_id(context_id);
@@ -304,6 +310,21 @@ tensorflow::Status CreateRemoteContexts(
    *request.mutable_server_def() = server_def;
    request.mutable_server_def()->set_job_name(parsed_name.job);
    request.mutable_server_def()->set_task_index(parsed_name.task);
    request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
        server_def.default_session_config());

    std::vector<bool> filtered_device_mask;
    ctx->context->FilterDevicesForRemoteWorkers(
        remote_worker, base_request.cluster_device_attributes(),
        &filtered_device_mask);
    DCHECK_EQ(filtered_device_mask.size(),
              base_request.cluster_device_attributes_size());
    for (int i = 0; i < filtered_device_mask.size(); i++) {
      if (filtered_device_mask[i]) {
        const auto& da = base_request.cluster_device_attributes(i);
        *request.add_cluster_device_attributes() = da;
      }
    }
    request.set_async(async);
    request.set_keep_alive_secs(keep_alive_secs);
    request.set_lazy_copy_remote_function_inputs(
@@ -325,13 +346,34 @@ tensorflow::Status CreateRemoteContexts(
}

tensorflow::Status UpdateRemoteContexts(
-    const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
+    TFE_Context* ctx, const std::vector<string>& remote_workers,
+    const std::vector<string>& added_workers,
+    const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
    tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* remote_eager_workers,
    const tensorflow::eager::CreateContextRequest& base_request) {
  int num_remote_workers = remote_workers.size();
  tensorflow::BlockingCounter counter(num_remote_workers);
  std::vector<tensorflow::Status> statuses(num_remote_workers);

  int cluster_device_count = base_request.cluster_device_attributes_size();
  std::unordered_set<string> added_or_removed(added_workers.begin(),
                                              added_workers.end());
  std::copy(removed_workers.begin(), removed_workers.end(),
            std::inserter(added_or_removed, added_or_removed.end()));
  // Whether each device is in the updated (added or removed) workers
  std::vector<bool> device_added_or_removed(cluster_device_count);
  for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) {
    const auto& da = base_request.cluster_device_attributes().at(i);
    tensorflow::DeviceNameUtils::ParsedName pn;
    tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn);
    string task_name;
    tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name);
    if (added_or_removed.find(task_name) != added_or_removed.end()) {
      device_added_or_removed[i] = true;
    }
  }

  for (int i = 0; i < num_remote_workers; i++) {
    const string& remote_worker = remote_workers[i];
    tensorflow::DeviceNameUtils::ParsedName parsed_name;
@@ -354,17 +396,42 @@ tensorflow::Status UpdateRemoteContexts(
      continue;
    }

    std::vector<bool> filtered_device_mask;
    ctx->context->FilterDevicesForRemoteWorkers(
        remote_worker, base_request.cluster_device_attributes(),
        &filtered_device_mask);
    DCHECK_EQ(filtered_device_mask.size(), cluster_device_count);

    // If any of the devices that match the device filters are in the set of
    // added or removed workers, we must send a complete UpdateContextRequest.
    // Otherwise, only send a simple request to increment context view ID.
    std::vector<bool> added_or_removed_filtered_devices(cluster_device_count);
    std::transform(device_added_or_removed.begin(),
                   device_added_or_removed.end(), filtered_device_mask.begin(),
                   added_or_removed_filtered_devices.begin(),
                   std::logical_and<bool>());
    const bool full_update_request =
        std::accumulate(added_or_removed_filtered_devices.begin(),
                        added_or_removed_filtered_devices.end(), false,
                        std::logical_or<bool>());

    tensorflow::eager::UpdateContextRequest request;
    auto* response = new tensorflow::eager::UpdateContextResponse();

-    *request.mutable_server_def() = server_def;
-    request.mutable_server_def()->set_job_name(parsed_name.job);
-    request.mutable_server_def()->set_task_index(parsed_name.task);
-    for (const auto& da : base_request.cluster_device_attributes()) {
-      *request.add_cluster_device_attributes() = da;
-    }
    request.set_context_id(context_id);
    request.set_context_view_id(context_view_id);
+    if (full_update_request) {
+      *request.mutable_server_def() = server_def;
+      request.mutable_server_def()->set_job_name(parsed_name.job);
+      request.mutable_server_def()->set_task_index(parsed_name.task);
+      request.mutable_server_def()->mutable_default_session_config()->MergeFrom(
+          server_def.default_session_config());
+      for (int i = 0; i < cluster_device_count; i++) {
+        if (filtered_device_mask[i]) {
+          const auto& da = base_request.cluster_device_attributes(i);
+          *request.add_cluster_device_attributes() = da;
+        }
+      }
+    }

    eager_client->UpdateContextAsync(
        &request, response,
@@ -525,15 +592,12 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
  for (const auto& da : local_device_attributes) {
    *base_request.add_cluster_device_attributes() = da;
  }
-  base_request.mutable_server_def()
-      ->mutable_default_session_config()
-      ->MergeFrom(server_def.default_session_config());

  // Initialize remote eager workers.
  // TODO(b/138847548) Create remote eager contexts in async mode by default.
  if (reset_context) {
    LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
-        remote_workers, context_id, context_view_id, keep_alive_secs,
+        ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
        server_def, remote_eager_workers.get(), context->Executor().Async(),
        context->LazyCopyFunctionRemoteInputs(), base_request));
  } else {
@@ -543,7 +607,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
    // we must set their context_view_id to the existing master's
    // context_view_id + 1.
    LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
-        added_workers, context_id, context_view_id + 1, keep_alive_secs,
+        ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
        server_def, remote_eager_workers.get(), context->Executor().Async(),
        context->LazyCopyFunctionRemoteInputs(), base_request));
    if (!existing_workers.empty()) {
@@ -553,8 +617,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
      }
    }
    LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
-        existing_workers, context_id, context_view_id + 1, server_def,
-        remote_eager_workers.get(), base_request));
+        ctx, existing_workers, added_workers, removed_workers, context_id,
+        context_view_id + 1, server_def, remote_eager_workers.get(),
+        base_request));
    }
  }

@@ -709,6 +774,22 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  if (server_def.has_cluster_device_filters()) {
    const auto& cdf = server_def.cluster_device_filters();
    for (const auto& jdf : cdf.jobs()) {
      const string& remote_prefix = "/job:" + jdf.name() + "/task:";
      for (const auto& tdf : jdf.tasks()) {
        const int32_t task_index = tdf.first;
        std::vector<string> device_filters(tdf.second.device_filters_size());
        for (int i = 0; i < tdf.second.device_filters_size(); i++) {
          device_filters[i] = tdf.second.device_filters(i);
        }
        const string remote_worker = remote_prefix + std::to_string(task_index);
        status->status =
            ctx->context->SetRemoteDeviceFilters(remote_worker, device_filters);
      }
    }
  }
  status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
                                                  ctx, /*reset_context=*/true);
#endif  // !IS_MOBILE_PLATFORM
@@ -733,6 +814,11 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
    status->status = tensorflow::errors::InvalidArgument(
        "Trying to update a context with invalid context id.");
  }
  if (server_def.has_cluster_device_filters()) {
    LOG(WARNING) << "Device filters can only be specified when initializing "
                    "the cluster. Any changes in device filters are ignored "
                    "when updating the server def.";
  }
  // TODO(haoyuzhang): Check server_def compatibility before the update
  status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
                                                  ctx, /*reset_context=*/false);
@@ -797,6 +883,15 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
#endif  // !IS_MOBILE_PLATFORM
}

TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
                                                           TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::Status::OK();
#else   // !defined(IS_MOBILE_PLATFORM)
  status->status = ctx->context->ClearRemoteExecutors();
#endif  // !IS_MOBILE_PLATFORM
}

void TFE_ContextSetThreadLocalDevicePlacementPolicy(
    TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
  ctx->context->SetThreadLocalDevicePlacementPolicy(
@@ -928,6 +1023,9 @@ const char* tensorflow::TensorHandleInterface::DeviceName(
  if (!IsValid(status)) {
    return nullptr;
  }
  if (VariantDeviceIsCustom(handle_->device())) {
    return absl::get<CustomDevice*>(handle_->device())->name().c_str();
  }
  tensorflow::Device* d = handle_->op_device();
  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                        : d->name().c_str();
@@ -948,9 +1046,15 @@ const char* tensorflow::TensorHandleInterface::BackingDeviceName(
  if (!IsValid(status)) {
    return nullptr;
  }
-  tensorflow::Device* d = handle_->device();
-  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
-                        : d->name().c_str();
+  if (VariantDeviceIsCustom(handle_->device())) {
+    return absl::get<tensorflow::CustomDevice*>(handle_->device())
+        ->name()
+        .c_str();
+  } else {
+    tensorflow::Device* d = absl::get<tensorflow::Device*>(handle_->device());
+    return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
+                          : d->name().c_str();
+  }
}

TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
@@ -984,6 +1088,18 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
  if (!IsValid(status)) {
    return nullptr;
  }
  if (VariantDeviceIsCustom(handle_->device())) {
    tensorflow::CustomDevice* custom_device =
        absl::get<tensorflow::CustomDevice*>(handle_->device());
    tensorflow::TensorHandle* copy;
    *status = custom_device->CopyTensorFromDevice(
        handle_, "/job:localhost/task:0/replica:0/device:CPU:0", &copy);
    if (status->ok()) {
      return TensorHandleInterface(copy).Resolve(status);
    } else {
      return nullptr;
    }
  }

  // TODO(agarwal): move this implementation inside TFE_TensorHandle.
  if (handle_->IsRemote()) {
@@ -1029,6 +1145,11 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
  tensorflow::TensorHandle* handle =
      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
          ->Handle();
  if (VariantDeviceIsCustom(handle->device())) {
    const tensorflow::Tensor* t;
    status->status = handle->Tensor(&t);
    return t->data();
  }

  if (handle->IsRemote()) {
    status->status = tensorflow::errors::InvalidArgument(
@@ -1036,8 +1157,9 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
        "handle.");
    return nullptr;
  }
-  if (handle->device() != nullptr) {
-    status->status = handle->device()->Sync();
+  tensorflow::Device* device(absl::get<tensorflow::Device*>(handle->device()));
+  if (device != nullptr) {
+    status->status = device->Sync();
    if (!status->status.ok()) {
      return nullptr;
    }
@@ -1056,12 +1178,17 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
    const int64_t* dims, int num_dims, void* data, size_t len,
    void (*deallocator)(void* data, size_t len, void* arg),
    void* deallocator_arg, TF_Status* status) {
-  tensorflow::Device* device;
+  tensorflow::Device* device = nullptr;
  tensorflow::EagerContext* context = ctx->context;
  status->status = context->FindDeviceFromName(device_name, &device);
+  tensorflow::CustomDevice* custom_device = nullptr;
  if (!status->status.ok()) {
-    deallocator(data, len, deallocator_arg);
-    return nullptr;
+    status->status =
+        context->FindCustomDeviceFromName(device_name, &custom_device);
+    if (!status->status.ok()) {
+      deallocator(data, len, deallocator_arg);
+      return nullptr;
+    }
  }
  std::vector<tensorflow::int64> dimvec(num_dims);
  for (int i = 0; i < num_dims; ++i) {
@@ -1085,8 +1212,14 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
      tensorflow::TensorShape(dimvec), buf);
  buf->Unref();
  tensorflow::TensorHandle* ret_handle;
+  absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant_device;
+  if (custom_device == nullptr) {
+    variant_device = device;
+  } else {
+    variant_device = custom_device;
+  }
  status->status = tensorflow::TensorHandle::CreateLocalHandle(
-      t, device, context, &ret_handle);
+      t, variant_device, context, &ret_handle);
  if (!status->status.ok()) {
    return nullptr;
  }
@@ -1427,8 +1560,42 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
  tensorflow::EagerContext* context = ctx->context;
  status->status = context->FindDeviceFromName(device_name, &device);
  if (!status->status.ok()) {
    tensorflow::CustomDevice* dev;
    status->status = context->FindCustomDeviceFromName(device_name, &dev);
    if (status->status.ok()) {
      status->status = dev->CopyTensorToDevice(
          tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
              h->handle.get())
              ->Handle(),
          &handle);
      if (status->status.ok()) {
        return new TFE_TensorHandle{
            std::make_unique<tensorflow::TensorHandleInterface>(handle)};
      }
    }
    return nullptr;
  }
  // Handle tensor handles currently in custom devices
  const char* handle_device_name = h->handle->DeviceName(&status->status);
  if (!status->status.ok()) {
    return nullptr;
  }
  tensorflow::CustomDevice* dev;
  status->status = context->FindCustomDeviceFromName(handle_device_name, &dev);
  if (status->status.ok()) {
    status->status = dev->CopyTensorFromDevice(
        tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
            h->handle.get())
            ->Handle(),
        device_name, &handle);
    if (status->status.ok()) {
      return new TFE_TensorHandle{
          std::make_unique<tensorflow::TensorHandleInterface>(handle)};
    }
    return nullptr;
  }

  // Handle regular case.
  status->status = tensorflow::EagerCopyToDevice(
      tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
          ->Handle(),
@@ -1567,3 +1734,94 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
  }
}
}  // namespace tensorflow

namespace {
class CustomDeviceAPI : public tensorflow::CustomDevice {
 public:
  CustomDeviceAPI(TFE_CustomDevice device, void* info, string name)
      : device_(device), info_(info), name_(name) {}

  ~CustomDeviceAPI() override { device_.delete_device(info_); }

  const string& name() override { return name_; }

  tensorflow::Status CopyTensorToDevice(
      tensorflow::TensorHandle* tensor,
      tensorflow::TensorHandle** result) override {
    tensor->Ref();
    TFE_TensorHandle tensor_handle{
        std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
    TF_Status status;
    TFE_TensorHandle* result_handle =
        device_.copy_tensor_to_device(&tensor_handle, &status, info_);
    if (!status.status.ok()) return status.status;
    *result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
                  result_handle->handle.get())
                  ->Handle();
    (*result)->Ref();
    delete result_handle;
    return status.status;
  }

  tensorflow::Status CopyTensorFromDevice(
      tensorflow::TensorHandle* tensor,
      const tensorflow::string& target_device_name,
      tensorflow::TensorHandle** result) override {
    TF_Status status;
    tensor->Ref();
    TFE_TensorHandle tensor_handle{
        std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
    TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
        &tensor_handle, target_device_name.c_str(), &status, info_);
    if (!status.status.ok()) return status.status;
    *result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
                  result_handle->handle.get())
                  ->Handle();
    (*result)->Ref();
    delete result_handle;
    return status.status;
  }

  tensorflow::Status Execute(tensorflow::EagerOperation* op,
                             tensorflow::TensorHandle** retvals,
                             int* num_retvals) override {
    std::vector<TFE_TensorHandle*> inputs;
    inputs.reserve(op->Inputs().size());
    for (int i = 0; i < op->Inputs().size(); ++i) {
      op->Inputs()[i]->Ref();
      inputs.push_back(new TFE_TensorHandle{
          std::make_unique<tensorflow::TensorHandleInterface>(
              op->Inputs()[i])});
    }
    std::vector<TFE_TensorHandle*> outputs(*num_retvals);
    // TODO(allenl): figure out how to get attrs from EagerOperation
    TF_Status status;
    device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
                    num_retvals, outputs.data(), &status, info_);
    if (status.status.ok()) {
      for (int i = 0; i < *num_retvals; ++i) {
        retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
                         outputs[i]->handle.get())
                         ->Handle();
        retvals[i]->Ref();
      }
    }
    for (auto inp : inputs) {
      delete inp;
    }
    return status.status;
  }

 private:
  TFE_CustomDevice device_;
  void* info_;
  string name_;
};
}  // namespace

void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
                              const char* device_name, void* device_info) {
  auto custom_device =
      std::make_unique<CustomDeviceAPI>(device, device_info, device_name);
  ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
}
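The UpdateRemoteContexts change above decides, per remote worker, whether a full UpdateContextRequest is needed: it intersects that worker's device-filter mask with the mask of devices belonging to added or removed workers, then OR-reduces the result. A self-contained sketch of just that mask logic (the function name and plain `std::vector<bool>` inputs are illustrative; the real code derives the masks from `cluster_device_attributes`):

#include <algorithm>
#include <functional>
#include <numeric>
#include <vector>

// Returns true if any device both passes this worker's device filter and
// belongs to an added/removed worker -- i.e. a full update must be sent.
bool NeedsFullUpdate(const std::vector<bool>& device_added_or_removed,
                     const std::vector<bool>& filtered_device_mask) {
  std::vector<bool> hits(device_added_or_removed.size());
  // Element-wise AND of the two masks.
  std::transform(device_added_or_removed.begin(),
                 device_added_or_removed.end(), filtered_device_mask.begin(),
                 hits.begin(), std::logical_and<bool>());
  // OR-reduce: true if any element survived the intersection.
  return std::accumulate(hits.begin(), hits.end(), false,
                         std::logical_or<bool>());
}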
tensorflow/c/eager/c_api_debug.cc:

@@ -66,7 +66,7 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
  }

#ifdef TENSORFLOW_EAGER_USE_XLA
-  tensorflow::Device* device = handle_->device();
+  tensorflow::Device* device = absl::get<Device*>(handle_->device());

  // If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
  tensorflow::XlaDevice* xla_device =
tensorflow/c/eager/c_api_experimental.cc:

@@ -88,14 +88,14 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
                                    int num_tracing_attempts,
                                    TF_Status* status) {
  tensorflow::Status s =
-      tensorflow::profiler::client::ValidateHostPortPair(service_addr);
+      tensorflow::profiler::ValidateHostPortPair(service_addr);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return false;
  }
-  s = tensorflow::profiler::client::StartTracing(
-      service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
-      num_tracing_attempts);
+  s = tensorflow::profiler::Trace(service_addr, logdir, worker_list,
+                                  include_dataset_ops, duration_ms,
+                                  num_tracing_attempts);
  tensorflow::Set_TF_Status_from_Status(status, s);
  return s.ok();
}
@@ -104,14 +104,14 @@ void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
                               int monitoring_level, bool display_timestamp,
                               TF_Buffer* result, TF_Status* status) {
  tensorflow::Status s =
-      tensorflow::profiler::client::ValidateHostPortPair(service_addr);
+      tensorflow::profiler::ValidateHostPortPair(service_addr);
  if (!s.ok()) {
    Set_TF_Status_from_Status(status, s);
    return;
  }
  string content;
-  s = tensorflow::profiler::client::Monitor(
-      service_addr, duration_ms, monitoring_level, display_timestamp, &content);
+  s = tensorflow::profiler::Monitor(service_addr, duration_ms,
+                                    monitoring_level, display_timestamp,
+                                    &content);
  void* data = tensorflow::port::Malloc(content.length());
  content.copy(static_cast<char*>(data), content.length(), 0);
  result->data = data;
tensorflow/c/eager/c_api_experimental.h:

@@ -27,7 +27,7 @@ extern "C" {
// creating a new op every time. If `raw_device_name` is `NULL` or empty, it
// does not set the device name. If it's not `NULL`, then it attempts to parse
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
-// than seperately calling it because if the existing op has the same
+// than separately calling it because if the existing op has the same
// `raw_device_name`, it skips parsing and just leave as it is.
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
                                       const char* op_or_function_name,
@@ -434,6 +434,10 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
                                                 const char* worker_name,
                                                 TF_Status* status);

// Clear pending streaming requests and error statuses on remote executors.
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
                                                           TF_Status* status);

// This function will block till the operation that produces `h` has
// completed. This is only valid on local TFE_TensorHandles. The pointer
// returned will be on the device in which the TFE_TensorHandle resides (so e.g.
@@ -463,6 +467,57 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
                                                TF_Buffer* buf);

#define TFE_CUSTOM_DEVICE_VERSION 0

// Struct to be filled in
typedef struct TFE_CustomDevice {
  int version = TFE_CUSTOM_DEVICE_VERSION;
  // Method to copy a tensor to the custom device.
  TFE_TensorHandle* (*copy_tensor_to_device)(TFE_TensorHandle* tensor,
                                             TF_Status* status,
                                             void* device_info) = nullptr;

  // Method to copy a tensor from the custom device to a target device.
  TFE_TensorHandle* (*copy_tensor_from_device)(TFE_TensorHandle* tensor,
                                               const char* target_device_name,
                                               TF_Status* status,
                                               void* device_info);

  // Method to execute an operation.
  // TODO(allenl) figure out a generic way of passing attrs here
  void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
                  const char* operation_name, int* num_outputs,
                  TFE_TensorHandle** outputs, TF_Status* s, void* device_info);

  // Method to delete a device.
  void (*delete_device)(void* device_info);
} TFE_CustomDevice;

// Registers a custom device for use with eager execution.
//
// Eager operations may be placed on this device, e.g. `with
// tf.device("CUSTOM"):` from Python if `device_name` for this call is
// "/job:localhost/replica:0/task:0/device:CUSTOM:0".
//
// The custom device defines copy operations for moving TensorHandles on and
// off, and an execution operation for named operations. Often execution will
// simply wrap op execution on one or more physical devices.
//
// device_info is an opaque caller-defined type stored with the custom device
// which is passed to the functions referenced in the TFE_CustomDevice struct
// `device` (execute, delete_device, etc.). It can for example contain the
// names of wrapped devices.
//
// There are currently no graph semantics implemented for registered custom
// devices, so executing tf.functions which contain operations placed on custom
// devices will fail.
//
// This API is highly experimental, and in particular is expected to change when
// it starts supporting operations with attributes and when tf.function support
// is added.
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
                              const char* device_name, void* device_info);

#ifdef __cplusplus
} /* end extern "C" */
#endif
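A minimal sketch of filling in this struct and registering it (a hypothetical no-op device for illustration; the logging device in the new custom_device_test.cc below is the complete worked example from this commit):

// Hypothetical no-op callbacks standing in for real device logic.
TFE_TensorHandle* NoopCopyOn(TFE_TensorHandle* tensor, TF_Status* status,
                             void* device_info) {
  TF_SetStatus(status, TF_UNIMPLEMENTED, "No copies onto this device.");
  return nullptr;
}

TFE_TensorHandle* NoopCopyOff(TFE_TensorHandle* tensor,
                              const char* target_device_name,
                              TF_Status* status, void* device_info) {
  TF_SetStatus(status, TF_UNIMPLEMENTED, "No copies off of this device.");
  return nullptr;
}

void NoopExecute(int num_inputs, TFE_TensorHandle** inputs,
                 const char* operation_name, int* num_outputs,
                 TFE_TensorHandle** outputs, TF_Status* s, void* device_info) {
  TF_SetStatus(s, TF_UNIMPLEMENTED, "No op execution on this device.");
}

void NoopDelete(void* device_info) {}

void RegisterNoopDevice(TFE_Context* ctx) {
  TFE_CustomDevice device;
  device.copy_tensor_to_device = &NoopCopyOn;
  device.copy_tensor_from_device = &NoopCopyOff;
  device.execute = &NoopExecute;
  device.delete_device = &NoopDelete;
  // No per-device state, so device_info can be null here.
  TFE_RegisterCustomDevice(ctx, device,
                           "/job:localhost/replica:0/task:0/device:CUSTOM:1",
                           /*device_info=*/nullptr);
}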
tensorflow/c/eager/c_api_test.cc:

@@ -363,13 +363,18 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) {
  TensorHandleCopyBetweenTwoGPUDevices(true);
}

-void TensorHandleSilentCopy(bool async) {
+void TensorHandleSilentCopy(bool async,
+                            TFE_ContextDevicePlacementPolicy global_policy,
+                            TFE_ContextDevicePlacementPolicy thread_policy) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
-  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, global_policy);
  TFE_Context* ctx = TFE_NewContext(opts, status.get());
+  if (thread_policy != global_policy) {
+    TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx, thread_policy);
+  }
  TFE_DeleteContextOptions(opts);
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

@@ -404,57 +409,21 @@ void TensorHandleSilentCopy(bool async) {
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
}

-TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); }
-TEST(CAPI, TensorHandleSilentCopyAsync) { TensorHandleSilentCopy(true); }
-
-void TensorHandleSilentCopyLocal(bool async) {
-  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
-      TF_NewStatus(), TF_DeleteStatus);
-  TFE_ContextOptions* opts = TFE_NewContextOptions();
-  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
-  TFE_ContextOptionsSetDevicePlacementPolicy(opts,
-                                             TFE_DEVICE_PLACEMENT_EXPLICIT);
-  TFE_Context* ctx = TFE_NewContext(opts, status.get());
-  TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx,
-                                                 TFE_DEVICE_PLACEMENT_SILENT);
-  TFE_DeleteContextOptions(opts);
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
-  TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
-  TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
-  // Disable the test if no GPU is present.
-  string gpu_device_name;
-  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
-    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
-        hcpu, ctx, gpu_device_name.c_str(), status.get());
-    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-
-    TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
-    TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
-    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-    TFE_TensorHandle* retvals[1];
-    int num_retvals = 1;
-    TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
-    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-    TFE_DeleteOp(matmul);
-    TFE_DeleteTensorHandle(retvals[0]);
-    TFE_DeleteTensorHandle(hgpu);
-  }
-
-  TF_DeleteTensor(t);
-  TFE_DeleteTensorHandle(hcpu);
-  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
-  TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TFE_DeleteExecutor(executor);
-  TFE_DeleteContext(ctx);
-}
-TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); }
-TEST(CAPI, TensorHandleSilentCopyLocalAsync) {
-  TensorHandleSilentCopyLocal(true);
-}
+TEST(CAPI, TensorHandleSilentCopy) {
+  TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_SILENT,
+                         TFE_DEVICE_PLACEMENT_SILENT);
+}
+TEST(CAPI, TensorHandleSilentCopyAsync) {
+  TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_SILENT,
+                         TFE_DEVICE_PLACEMENT_SILENT);
+}
+TEST(CAPI, TensorHandleSilentCopyLocalPolicy) {
+  TensorHandleSilentCopy(false, TFE_DEVICE_PLACEMENT_EXPLICIT,
+                         TFE_DEVICE_PLACEMENT_SILENT);
+}
+TEST(CAPI, TensorHandleSilentCopyLocalPolicyAsync) {
+  TensorHandleSilentCopy(true, TFE_DEVICE_PLACEMENT_EXPLICIT,
+                         TFE_DEVICE_PLACEMENT_SILENT);
+}

void SetAndGetOpDevices(bool async) {
tensorflow/c/eager/custom_device_test.cc: 159 lines (new file)
@ -0,0 +1,159 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// A simple logging device to test custom device registration.
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace {
|
||||
|
||||
struct LoggingDevice {
|
||||
TFE_Context* ctx;
|
||||
tensorflow::string device_name;
|
||||
tensorflow::string underlying_device;
|
||||
// Set to true whenever a TensorHandle is copied onto the device
|
||||
bool* arrived_flag;
|
||||
};
|
||||
|
||||
struct LoggedTensor {
|
||||
TFE_TensorHandle* tensor;
|
||||
LoggedTensor() = delete;
|
||||
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
|
||||
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
|
||||
};
|
||||
|
||||
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
|
||||
delete reinterpret_cast<LoggedTensor*>(data);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* MakeLoggedTensorHandle(
|
||||
TFE_Context* ctx, const tensorflow::string& logging_device_name,
|
||||
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
|
||||
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
auto dtype = TFE_TensorHandleDataType(t->tensor);
|
||||
return TFE_NewTensorHandleFromDeviceMemory(
|
||||
      ctx, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
      t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}

TFE_TensorHandle* CopyToLoggingDevice(TFE_TensorHandle* tensor,
                                      TF_Status* status, void* device_info) {
  LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
  TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
      tensor, dev->ctx, dev->underlying_device.c_str(), status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  auto dst = std::make_unique<LoggedTensor>(t);
  *(dev->arrived_flag) = true;
  return MakeLoggedTensorHandle(dev->ctx, dev->device_name, std::move(dst),
                                status);
}

TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
                                              const char* target_device_name,
                                              TF_Status* status,
                                              void* device_info) {
  TF_SetStatus(status, TF_INTERNAL,
               "Trying to copy a tensor out of a logging device.");
  return nullptr;
}

void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
                          const char* operation_name, int* num_outputs,
                          TFE_TensorHandle** outputs, TF_Status* s,
                          void* device_info) {
  LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
  TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
  if (TF_GetCode(s) != TF_OK) return;
  TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
  for (int j = 0; j < num_inputs; ++j) {
    TFE_TensorHandle* input = inputs[j];
    const char* input_device = TFE_TensorHandleDeviceName(input, s);
    if (TF_GetCode(s) != TF_OK) return;
    if (dev->device_name == input_device) {
      LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
          TFE_TensorHandleDevicePointer(input, s));
      if (TF_GetCode(s) != TF_OK) return;
      TFE_OpAddInput(op, t->tensor, s);
    } else {
      TFE_OpAddInput(op, input, s);
    }
    if (TF_GetCode(s) != TF_OK) return;
  }
  std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
  TFE_Execute(op, op_outputs.data(), num_outputs, s);
  TFE_DeleteOp(op);
  if (TF_GetCode(s) != TF_OK) return;
  std::vector<TFE_TensorHandle*> unwrapped_outputs;
  for (auto* handle : op_outputs) {
    unwrapped_outputs.push_back(handle);
  }
  for (int i = 0; i < *num_outputs; ++i) {
    auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
    outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name,
                                        std::move(logged_tensor), s);
  }
}

void DeleteLoggingDevice(void* device_info) {
  delete reinterpret_cast<LoggingDevice*>(device_info);
}

void RegisterLoggingDevice(TFE_Context* context, const char* name,
                           bool* arrived_flag) {
  TFE_CustomDevice custom_device;
  custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
  custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
  custom_device.delete_device = &DeleteLoggingDevice;
  custom_device.execute = &LoggingDeviceExecute;
  LoggingDevice* device = new LoggingDevice;
  device->ctx = context;
  device->arrived_flag = arrived_flag;
  device->device_name = name;
  device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
  TFE_RegisterCustomDevice(context, custom_device, name, device);
}

TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* context = TFE_NewContext(opts, status.get());
  TFE_DeleteContextOptions(opts);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool arrived = false;
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  RegisterLoggingDevice(context, name, &arrived);
  TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
  ASSERT_FALSE(arrived);
  TFE_TensorHandle* hdevice =
      TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
  ASSERT_TRUE(arrived);
  TFE_DeleteTensorHandle(hcpu);
  TFE_DeleteTensorHandle(hdevice);
  TFE_DeleteContext(context);
}

}  // namespace
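A minimal hedged sketch of what a follow-up assertion in the test above could look like: the `copy_tensor_from_device` hook shown here unconditionally reports `TF_INTERNAL`, so copying a handle back off the logging device should fail. This is not part of the commit; `hdevice`, `context`, and `status` are the variables from the test above.

  // Hypothetical extra check: copying out of the logging device must fail,
  // because CopyTensorFromLoggingDevice always sets TF_INTERNAL.
  TFE_TensorHandle* back = TFE_TensorHandleCopyToDevice(
      hdevice, context, "/job:localhost/replica:0/task:0/device:CPU:0",
      status.get());
  ASSERT_EQ(TF_GetCode(status.get()), TF_INTERNAL);
  ASSERT_EQ(back, nullptr);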
@ -56,7 +56,7 @@ extern "C" {
/// Lifetime: The wrapper data structures are owned by core TensorFlow. The data
/// pointed to by the `void*` members is always owned by the plugin. The plugin
/// will provide functions to call to allocate and deallocate this data (see
/// next section) and core TensorFlow ensures to call these at the proper time.
/// next sections) and core TensorFlow ensures to call these at the proper time.
///
/// Plugins will never receive a `TF_*` pointer that is `nullptr`. Core
/// TensorFlow will never touch the `void*` wrapped by these structures, except
@ -601,6 +601,10 @@ typedef struct TF_FilesystemOps {
  ///
  /// Plugins must not return `nullptr`. Returning empty strings is allowed.
  ///
  /// The allocation and freeing of memory must happen via the functions sent to
  /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
  /// structure in Section 4).
  ///
  /// This function will be called by core TensorFlow to clean up all path
  /// arguments for all other methods in the filesystem API.
  ///
@ -618,6 +622,10 @@ typedef struct TF_FilesystemOps {
  /// In case of error, plugins must set `status` to a value different than
  /// `TF_OK`, free memory allocated for `entries` and return -1.
  ///
  /// The allocation and freeing of memory must happen via the functions sent to
  /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
  /// structure in Section 4).
  ///
  /// Plugins:
  ///   * Must set `status` to `TF_OK` if all children were returned.
  ///   * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to a
@ -654,6 +662,10 @@ typedef struct TF_FilesystemOps {
  /// different than `TF_OK`, free any memory that might have been allocated for
  /// `entries` and return -1.
  ///
  /// The allocation and freeing of memory must happen via the functions sent to
  /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
  /// structure in Section 4).
  ///
  /// Plugins:
  ///   * Must set `status` to `TF_OK` if all matches were returned.
  ///   * Might use any other error value for `status` to signal other errors.
@ -741,8 +753,11 @@ constexpr size_t TF_FILESYSTEM_OPS_SIZE = sizeof(TF_FilesystemOps);
///   * `TF_InitPlugin` function: must be present in the plugin shared object as
///     it will be called by core TensorFlow when the filesystem plugin is
///     loaded;
///   * `TF_FilesystemPluginInfo` struct: used to transfer information between
///   * `TF_FilesystemPluginOps` struct: used to transfer information between
///     plugins and core TensorFlow about the operations provided and metadata;
///   * `TF_FilesystemPluginInfo` struct: similar to the above structure, but
///     collects information about all the file schemes that the plugin provides
///     support for, as well as about the plugin's memory handling routines;
///   * `TF_SetFilesystemVersionMetadata` function: must be called by plugins in
///     their `TF_InitPlugin` to record the versioning information the plugins
///     are compiled against.
@ -774,7 +789,7 @@ constexpr size_t TF_FILESYSTEM_OPS_SIZE = sizeof(TF_FilesystemOps);
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
/// must not change! In the unlikely case that a new type of file needs to be
/// supported, add the new ops and metadata at the end of the structure.
typedef struct TF_FilesystemPluginInfo {
typedef struct TF_FilesystemPluginOps {
  char* scheme;
  int filesystem_ops_abi;
  int filesystem_ops_api;
@ -792,6 +807,29 @@ typedef struct TF_FilesystemPluginInfo {
  int read_only_memory_region_ops_api;
  size_t read_only_memory_region_ops_size;
  TF_ReadOnlyMemoryRegionOps* read_only_memory_region_ops;
} TF_FilesystemPluginOps;

/// This structure gathers together all the operations provided by the plugin.
///
/// Plugins must provide exactly `num_schemes` elements in the `ops` array.
///
/// Since memory that is allocated by the DSO gets transferred to core
/// TensorFlow, we need to provide a way for the allocation and deallocation to
/// match. This is why this structure also defines `plugin_memory_allocate` and
/// `plugin_memory_free` members.
///
/// All memory allocated by the plugin that will be owned by core TensorFlow
/// must be allocated using the allocator in this structure. Core TensorFlow
/// will use the deallocator to free this memory once it no longer needs it.
///
/// IMPORTANT: To maintain binary compatibility, the layout of this structure
/// must not change! In the unlikely case that new global operations must be
/// provided, add them at the end of the structure.
typedef struct TF_FilesystemPluginInfo {
  size_t num_schemes;
  TF_FilesystemPluginOps* ops;
  void* (*plugin_memory_allocate)(size_t size);
  void (*plugin_memory_free)(void* ptr);
} TF_FilesystemPluginInfo;

/// Convenience function for setting the versioning metadata.
@ -801,19 +839,19 @@ typedef struct TF_FilesystemPluginInfo {
/// We want this to be defined in the plugin's memory space and we guarantee
/// that core TensorFlow will never call this.
static inline void TF_SetFilesystemVersionMetadata(
    TF_FilesystemPluginInfo* info) {
  info->filesystem_ops_abi = TF_FILESYSTEM_OPS_ABI;
  info->filesystem_ops_api = TF_FILESYSTEM_OPS_API;
  info->filesystem_ops_size = TF_FILESYSTEM_OPS_SIZE;
  info->random_access_file_ops_abi = TF_RANDOM_ACCESS_FILE_OPS_ABI;
  info->random_access_file_ops_api = TF_RANDOM_ACCESS_FILE_OPS_API;
  info->random_access_file_ops_size = TF_RANDOM_ACCESS_FILE_OPS_SIZE;
  info->writable_file_ops_abi = TF_WRITABLE_FILE_OPS_ABI;
  info->writable_file_ops_api = TF_WRITABLE_FILE_OPS_API;
  info->writable_file_ops_size = TF_WRITABLE_FILE_OPS_SIZE;
  info->read_only_memory_region_ops_abi = TF_READ_ONLY_MEMORY_REGION_OPS_ABI;
  info->read_only_memory_region_ops_api = TF_READ_ONLY_MEMORY_REGION_OPS_API;
  info->read_only_memory_region_ops_size = TF_READ_ONLY_MEMORY_REGION_OPS_SIZE;
    TF_FilesystemPluginOps* ops) {
  ops->filesystem_ops_abi = TF_FILESYSTEM_OPS_ABI;
  ops->filesystem_ops_api = TF_FILESYSTEM_OPS_API;
  ops->filesystem_ops_size = TF_FILESYSTEM_OPS_SIZE;
  ops->random_access_file_ops_abi = TF_RANDOM_ACCESS_FILE_OPS_ABI;
  ops->random_access_file_ops_api = TF_RANDOM_ACCESS_FILE_OPS_API;
  ops->random_access_file_ops_size = TF_RANDOM_ACCESS_FILE_OPS_SIZE;
  ops->writable_file_ops_abi = TF_WRITABLE_FILE_OPS_ABI;
  ops->writable_file_ops_api = TF_WRITABLE_FILE_OPS_API;
  ops->writable_file_ops_size = TF_WRITABLE_FILE_OPS_SIZE;
  ops->read_only_memory_region_ops_abi = TF_READ_ONLY_MEMORY_REGION_OPS_ABI;
  ops->read_only_memory_region_ops_api = TF_READ_ONLY_MEMORY_REGION_OPS_API;
  ops->read_only_memory_region_ops_size = TF_READ_ONLY_MEMORY_REGION_OPS_SIZE;
}

/// Initializes a TensorFlow plugin.
@ -828,16 +866,14 @@ static inline void TF_SetFilesystemVersionMetadata(
/// manage themselves). In both of these cases, core TensorFlow looks for
/// the `TF_InitPlugin` symbol and calls this function.
///
/// All memory allocated by this function must be allocated via the `allocator`
/// argument.
///
/// For every filesystem URI scheme that this plugin supports, the plugin must
/// add one `TF_FilesystemPluginInfo` entry in `plugin_info`.
/// add one `TF_FilesystemPluginInfo` entry in `plugin_info->ops` and call
/// `TF_SetFilesystemVersionMetadata` for that entry.
///
/// Returns number of entries in `plugin_info` (i.e., number of URI schemes
/// supported).
TF_CAPI_EXPORT extern int TF_InitPlugin(void* (*allocator)(size_t size),
                                        TF_FilesystemPluginInfo** plugin_info);
/// Plugins must also initialize `plugin_info->plugin_memory_allocate` and
/// `plugin_info->plugin_memory_free` to ensure memory allocated by plugin is
/// freed in a compatible way.
TF_CAPI_EXPORT extern void TF_InitPlugin(TF_FilesystemPluginInfo* plugin_info);

#ifdef __cplusplus
}  // end extern "C"
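To make the new two-level contract above concrete, here is a hedged sketch of a minimal plugin's `TF_InitPlugin` under the revised interface, for a single hypothetical "htfs" scheme (the real POSIX plugin further below in this commit does the same thing in full):

  static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
  static void plugin_memory_free(void* ptr) { free(ptr); }

  void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
    info->plugin_memory_allocate = plugin_memory_allocate;
    info->plugin_memory_free = plugin_memory_free;
    info->num_schemes = 1;
    info->ops = static_cast<TF_FilesystemPluginOps*>(
        plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
    TF_FilesystemPluginOps* ops = &info->ops[0];
    TF_SetFilesystemVersionMetadata(ops);  // record ABI/API numbers
    ops->scheme = strdup("htfs");          // hypothetical URI scheme
    ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
        plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
    // ... fill in the function pointers this plugin actually supports ...
  }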
@ -164,16 +164,18 @@ Status ModularFileSystem::GetChildren(const std::string& dir,

  UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
  std::string translated_name = TranslateName(dir);
  char** children;
  // Note that `children` is allocated by the plugin and freed by core
  // TensorFlow, so we need to use `plugin_memory_free_` here.
  char** children = nullptr;
  const int num_children =
      ops_->get_children(filesystem_.get(), translated_name.c_str(), &children,
                         plugin_status.get());
  if (num_children >= 0) {
    for (int i = 0; i < num_children; i++) {
      result->push_back(std::string(children[i]));
      free(children[i]);
      plugin_memory_free_(children[i]);
    }
    free(children);
    plugin_memory_free_(children);
  }

  return StatusFromTF_Status(plugin_status.get());
@ -185,15 +187,17 @@ Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
    return internal::GetMatchingPaths(this, Env::Default(), pattern, result);

  UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
  char** matches;
  // Note that `matches` is allocated by the plugin and freed by core
  // TensorFlow, so we need to use `plugin_memory_free_` here.
  char** matches = nullptr;
  const int num_matches = ops_->get_matching_paths(
      filesystem_.get(), pattern.c_str(), &matches, plugin_status.get());
  if (num_matches >= 0) {
    for (int i = 0; i < num_matches; i++) {
      result->push_back(std::string(matches[i]));
      free(matches[i]);
      plugin_memory_free_(matches[i]);
    }
    free(matches);
    plugin_memory_free_(matches);
  }

  return StatusFromTF_Status(plugin_status.get());
@ -357,7 +361,8 @@ std::string ModularFileSystem::TranslateName(const std::string& name) const {
  CHECK(p != nullptr) << "TranslateName(" << name << ") returned nullptr";

  std::string ret(p);
  free(p);
  // Since `p` is allocated by plugin, free it using plugin's method.
  plugin_memory_free_(p);
  return ret;
}

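The counterpart of the change above, seen from the plugin side, is that every buffer handed to core TensorFlow must come from the allocator the plugin registered. A hedged sketch (not part of the commit; the children names and `plugin_memory_allocate` helper are illustrative assumptions matching the POSIX plugin below):

  // Hypothetical plugin-side get_children: buffers are created with the
  // plugin's registered allocator, so core can release them with
  // `plugin_memory_free_` exactly as GetChildren above does.
  static int GetChildren(const TF_Filesystem* filesystem, const char* path,
                         char*** entries, TF_Status* status) {
    static const char* kChildren[] = {"a.txt", "b.txt"};  // placeholder data
    const int num = 2;
    *entries =
        static_cast<char**>(plugin_memory_allocate(num * sizeof(char*)));
    for (int i = 0; i < num; i++) (*entries)[i] = strdup(kChildren[i]);
    TF_SetStatus(status, TF_OK, "");
    return num;
  }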
@ -46,12 +46,16 @@ class ModularFileSystem final : public FileSystem {
      std::unique_ptr<const TF_RandomAccessFileOps> random_access_file_ops,
      std::unique_ptr<const TF_WritableFileOps> writable_file_ops,
      std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
          read_only_memory_region_ops)
          read_only_memory_region_ops,
      std::function<void*(size_t)> plugin_memory_allocate,
      std::function<void(void*)> plugin_memory_free)
      : filesystem_(std::move(filesystem)),
        ops_(std::move(filesystem_ops)),
        random_access_file_ops_(std::move(random_access_file_ops)),
        writable_file_ops_(std::move(writable_file_ops)),
        read_only_memory_region_ops_(std::move(read_only_memory_region_ops)) {}
        read_only_memory_region_ops_(std::move(read_only_memory_region_ops)),
        plugin_memory_allocate_(std::move(plugin_memory_allocate)),
        plugin_memory_free_(std::move(plugin_memory_free)) {}

  ~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); }

@ -93,6 +97,8 @@ class ModularFileSystem final : public FileSystem {
  std::unique_ptr<const TF_WritableFileOps> writable_file_ops_;
  std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
      read_only_memory_region_ops_;
  std::function<void*(size_t)> plugin_memory_allocate_;
  std::function<void(void*)> plugin_memory_free_;
  TF_DISALLOW_COPY_AND_ASSIGN(ModularFileSystem);
};

@ -50,21 +50,21 @@ static Status CheckABI(int pluginABI, int coreABI, StringPiece where) {
// If the numbers don't match, plugin cannot be loaded.
//
// Uses the simpler `CheckABI(int, int, StringPiece)`.
static Status ValidateABI(const TF_FilesystemPluginInfo* info) {
static Status ValidateABI(const TF_FilesystemPluginOps* ops) {
  TF_RETURN_IF_ERROR(
      CheckABI(info->filesystem_ops_abi, TF_FILESYSTEM_OPS_ABI, "filesystem"));
      CheckABI(ops->filesystem_ops_abi, TF_FILESYSTEM_OPS_ABI, "filesystem"));

  if (info->random_access_file_ops != nullptr)
    TF_RETURN_IF_ERROR(CheckABI(info->random_access_file_ops_abi,
  if (ops->random_access_file_ops != nullptr)
    TF_RETURN_IF_ERROR(CheckABI(ops->random_access_file_ops_abi,
                                TF_RANDOM_ACCESS_FILE_OPS_ABI,
                                "random access file"));

  if (info->writable_file_ops != nullptr)
    TF_RETURN_IF_ERROR(CheckABI(info->writable_file_ops_abi,
  if (ops->writable_file_ops != nullptr)
    TF_RETURN_IF_ERROR(CheckABI(ops->writable_file_ops_abi,
                                TF_WRITABLE_FILE_OPS_ABI, "writable file"));

  if (info->read_only_memory_region_ops != nullptr)
    TF_RETURN_IF_ERROR(CheckABI(info->read_only_memory_region_ops_abi,
  if (ops->read_only_memory_region_ops != nullptr)
    TF_RETURN_IF_ERROR(CheckABI(ops->read_only_memory_region_ops_abi,
                                TF_READ_ONLY_MEMORY_REGION_OPS_ABI,
                                "read only memory region"));

@ -83,19 +83,19 @@ static void CheckAPI(int plugin_API, int core_API, StringPiece where) {
// Checks if the plugin and core API numbers match, for all operations.
//
// Uses the simpler `CheckAPIHelper(int, int, StringPiece)`.
static void ValidateAPI(const TF_FilesystemPluginInfo* info) {
  CheckAPI(info->filesystem_ops_api, TF_FILESYSTEM_OPS_API, "filesystem");
static void ValidateAPI(const TF_FilesystemPluginOps* ops) {
  CheckAPI(ops->filesystem_ops_api, TF_FILESYSTEM_OPS_API, "filesystem");

  if (info->random_access_file_ops != nullptr)
    CheckAPI(info->random_access_file_ops_api, TF_RANDOM_ACCESS_FILE_OPS_API,
  if (ops->random_access_file_ops != nullptr)
    CheckAPI(ops->random_access_file_ops_api, TF_RANDOM_ACCESS_FILE_OPS_API,
             "random access file");

  if (info->writable_file_ops != nullptr)
    CheckAPI(info->writable_file_ops_api, TF_WRITABLE_FILE_OPS_API,
  if (ops->writable_file_ops != nullptr)
    CheckAPI(ops->writable_file_ops_api, TF_WRITABLE_FILE_OPS_API,
             "writable file");

  if (info->read_only_memory_region_ops != nullptr)
    CheckAPI(info->read_only_memory_region_ops_api,
  if (ops->read_only_memory_region_ops != nullptr)
    CheckAPI(ops->read_only_memory_region_ops_api,
             TF_READ_ONLY_MEMORY_REGION_OPS_API, "read only memory region");
}

@ -177,27 +177,27 @@ static Status ValidateHelper(const TF_ReadOnlyMemoryRegionOps* ops) {
// individual function table and then checks that the function table for a
// specific file type exists if the plugin offers support for creating that
// type of files.
static Status ValidateOperations(const TF_FilesystemPluginInfo* info) {
  TF_RETURN_IF_ERROR(ValidateHelper(info->filesystem_ops));
  TF_RETURN_IF_ERROR(ValidateHelper(info->random_access_file_ops));
  TF_RETURN_IF_ERROR(ValidateHelper(info->writable_file_ops));
  TF_RETURN_IF_ERROR(ValidateHelper(info->read_only_memory_region_ops));
static Status ValidateOperations(const TF_FilesystemPluginOps* ops) {
  TF_RETURN_IF_ERROR(ValidateHelper(ops->filesystem_ops));
  TF_RETURN_IF_ERROR(ValidateHelper(ops->random_access_file_ops));
  TF_RETURN_IF_ERROR(ValidateHelper(ops->writable_file_ops));
  TF_RETURN_IF_ERROR(ValidateHelper(ops->read_only_memory_region_ops));

  if (info->filesystem_ops->new_random_access_file != nullptr &&
      info->random_access_file_ops == nullptr)
  if (ops->filesystem_ops->new_random_access_file != nullptr &&
      ops->random_access_file_ops == nullptr)
    return errors::FailedPrecondition(
        "Filesystem allows creation of random access files but no "
        "operations on them have been supplied.");

  if ((info->filesystem_ops->new_writable_file != nullptr ||
       info->filesystem_ops->new_appendable_file != nullptr) &&
      info->writable_file_ops == nullptr)
  if ((ops->filesystem_ops->new_writable_file != nullptr ||
       ops->filesystem_ops->new_appendable_file != nullptr) &&
      ops->writable_file_ops == nullptr)
    return errors::FailedPrecondition(
        "Filesystem allows creation of writable files but no "
        "operations on them have been supplied.");

  if (info->filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
      info->read_only_memory_region_ops == nullptr)
  if (ops->filesystem_ops->new_read_only_memory_region_from_file != nullptr &&
      ops->read_only_memory_region_ops == nullptr)
    return errors::FailedPrecondition(
        "Filesystem allows creation of readonly memory regions but no "
        "operations on them have been supplied.");
@ -232,18 +232,23 @@ static std::unique_ptr<const T> CopyToCore(const T* plugin_ops,
}

// Registers one filesystem from the plugin.
static Status RegisterFileSystem(const TF_FilesystemPluginInfo* info) {
//
// Must be called only with `index` a valid index in `info->ops`.
static Status RegisterFileSystem(const TF_FilesystemPluginInfo* info,
                                 int index) {
  // Step 1: Copy all the function tables to core TensorFlow memory space
  auto core_filesystem_ops = CopyToCore<TF_FilesystemOps>(
      info->filesystem_ops, info->filesystem_ops_size);
      info->ops[index].filesystem_ops, info->ops[index].filesystem_ops_size);
  auto core_random_access_file_ops = CopyToCore<TF_RandomAccessFileOps>(
      info->random_access_file_ops, info->random_access_file_ops_size);
  auto core_writable_file_ops = CopyToCore<TF_WritableFileOps>(
      info->writable_file_ops, info->writable_file_ops_size);
      info->ops[index].random_access_file_ops,
      info->ops[index].random_access_file_ops_size);
  auto core_writable_file_ops =
      CopyToCore<TF_WritableFileOps>(info->ops[index].writable_file_ops,
                                     info->ops[index].writable_file_ops_size);
  auto core_read_only_memory_region_ops =
      CopyToCore<TF_ReadOnlyMemoryRegionOps>(
          info->read_only_memory_region_ops,
          info->read_only_memory_region_ops_size);
          info->ops[index].read_only_memory_region_ops,
          info->ops[index].read_only_memory_region_ops_size);

  // Step 2: Initialize the opaque filesystem structure
  auto filesystem = tensorflow::MakeUnique<TF_Filesystem>();
@ -256,32 +261,46 @@ static Status RegisterFileSystem(const TF_FilesystemPluginInfo* info) {

  // Step 3: Actual registration
  return Env::Default()->RegisterFileSystem(
      info->scheme, tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
                        std::move(filesystem), std::move(core_filesystem_ops),
                        std::move(core_random_access_file_ops),
                        std::move(core_writable_file_ops),
                        std::move(core_read_only_memory_region_ops)));
      info->ops[index].scheme,
      tensorflow::MakeUnique<tensorflow::ModularFileSystem>(
          std::move(filesystem), std::move(core_filesystem_ops),
          std::move(core_random_access_file_ops),
          std::move(core_writable_file_ops),
          std::move(core_read_only_memory_region_ops),
          info->plugin_memory_allocate, info->plugin_memory_free));
}

// Registers all filesystems, if plugin is providing valid information.
// Registers filesystem at `index`, if plugin is providing valid information.
//
// Extracted to a separate function so that pointers inside `info` are freed
// by the caller regardless of whether validation/registration failed or not.
//
// Must be called only with `index` a valid index in `info->ops`.
static Status ValidateAndRegisterFilesystems(
    const TF_FilesystemPluginInfo* info) {
  TF_RETURN_IF_ERROR(ValidateScheme(info->scheme));
  TF_RETURN_IF_ERROR(ValidateABI(info));
  ValidateAPI(info);  // we just warn on API number mismatch
  TF_RETURN_IF_ERROR(ValidateOperations(info));
  TF_RETURN_IF_ERROR(RegisterFileSystem(info));
    const TF_FilesystemPluginInfo* info, int index) {
  TF_RETURN_IF_ERROR(ValidateScheme(info->ops[index].scheme));
  TF_RETURN_IF_ERROR(ValidateABI(&info->ops[index]));
  ValidateAPI(&info->ops[index]);  // we just warn on API number mismatch
  TF_RETURN_IF_ERROR(ValidateOperations(&info->ops[index]));
  TF_RETURN_IF_ERROR(RegisterFileSystem(info, index));
  return Status::OK();
}

// Allocates memory in plugin DSO.
//
// Provided by core TensorFlow so that it can free this memory after DSO is
// loaded and filesystem information has been used to register the filesystem.
static void* basic_allocator(size_t size) { return calloc(1, size); }
// Ensures that the plugin provides the required memory management operations.
static Status ValidatePluginMemoryRoutines(
    const TF_FilesystemPluginInfo* info) {
  if (info->plugin_memory_allocate == nullptr)
    return errors::FailedPrecondition(
        "Cannot load filesystem plugin which does not provide "
        "`plugin_memory_allocate`");

  if (info->plugin_memory_free == nullptr)
    return errors::FailedPrecondition(
        "Cannot load filesystem plugin which does not provide "
        "`plugin_memory_free`");

  return Status::OK();
}

namespace filesystem_registration {

@ -297,26 +316,28 @@ Status RegisterFilesystemPluginImpl(const std::string& dso_path) {
      env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol));

  // Step 3: Call `TF_InitPlugin`
  TF_FilesystemPluginInfo* info = nullptr;
  auto TF_InitPlugin = reinterpret_cast<int (*)(
      decltype(&basic_allocator), TF_FilesystemPluginInfo**)>(dso_symbol);
  int num_schemes = TF_InitPlugin(&basic_allocator, &info);
  if (num_schemes < 0 || info == nullptr)
    return errors::InvalidArgument("DSO returned invalid filesystem data");
  TF_FilesystemPluginInfo info;
  memset(&info, 0, sizeof(info));
  auto TF_InitPlugin =
      reinterpret_cast<int (*)(TF_FilesystemPluginInfo*)>(dso_symbol);
  TF_InitPlugin(&info);

  // Step 4: Validate and register all filesystems
  // Step 4: Ensure plugin provides the memory management functions.
  TF_RETURN_IF_ERROR(ValidatePluginMemoryRoutines(&info));

  // Step 5: Validate and register all filesystems
  // Try to register as many filesystems as possible.
  // Free memory once we no longer need it
  Status status;
  for (int i = 0; i < num_schemes; i++) {
    status.Update(ValidateAndRegisterFilesystems(&info[i]));
    free(info[i].scheme);
    free(info[i].filesystem_ops);
    free(info[i].random_access_file_ops);
    free(info[i].writable_file_ops);
    free(info[i].read_only_memory_region_ops);
  for (int i = 0; i < info.num_schemes; i++) {
    status.Update(ValidateAndRegisterFilesystems(&info, i));
    info.plugin_memory_free(info.ops[i].scheme);
    info.plugin_memory_free(info.ops[i].filesystem_ops);
    info.plugin_memory_free(info.ops[i].random_access_file_ops);
    info.plugin_memory_free(info.ops[i].writable_file_ops);
    info.plugin_memory_free(info.ops[i].read_only_memory_region_ops);
  }
  free(info);
  info.plugin_memory_free(info.ops);
  return status;
}

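A hedged usage sketch of the loading path above, assuming a plugin shared object at a placeholder path; `RegisterFilesystemPluginImpl` is the entry point shown in this file and `myfs://` is a hypothetical scheme the plugin would register:

  // Hypothetical caller: load a filesystem plugin, then use it through Env.
  tensorflow::Status s =
      tensorflow::filesystem_registration::RegisterFilesystemPluginImpl(
          "/tmp/libmy_filesystem_plugin.so");  // placeholder path
  if (s.ok()) {
    std::vector<std::string> children;
    s = tensorflow::Env::Default()->GetChildren("myfs://bucket/dir",
                                                &children);
  }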
@ -1569,7 +1569,7 @@ TEST_P(ModularFileSystemTest, TestRoundTrip) {
  if (!status.ok())
    GTEST_SKIP() << "NewRandomAccessFile() not supported: " << status;

  char scratch[64 /* big enough to accomodate test_data */] = {0};
  char scratch[64 /* big enough to accommodate test_data */] = {0};
  StringPiece result;
  status = read_file->Read(0, test_data.size(), &result, scratch);
  EXPECT_PRED2(UnimplementedOrReturnsCode, status, Code::OK);

@ -31,6 +31,9 @@ limitations under the License.
// Implementation of a filesystem for POSIX environments.
// This filesystem will support `file://` and empty (local) URI schemes.

static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }

// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
@ -43,7 +46,9 @@ typedef struct PosixFile {
static void Cleanup(TF_RandomAccessFile* file) {
  auto posix_file = static_cast<PosixFile*>(file->plugin_file);
  close(posix_file->fd);
  free(const_cast<char*>(posix_file->filename));
  // This would be safe to free using `free` directly as it is only opaque.
  // However, it is better to be consistent everywhere.
  plugin_memory_free(const_cast<char*>(posix_file->filename));
  delete posix_file;
}

@ -98,7 +103,7 @@ typedef struct PosixFile {

static void Cleanup(TF_WritableFile* file) {
  auto posix_file = static_cast<PosixFile*>(file->plugin_file);
  free(const_cast<char*>(posix_file->filename));
  plugin_memory_free(const_cast<char*>(posix_file->filename));
  delete posix_file;
}

@ -381,12 +386,13 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,
  if (num_entries < 0) {
    TF_SetStatusFromIOError(status, errno, path);
  } else {
    *entries = static_cast<char**>(calloc(num_entries, sizeof((*entries)[0])));
    *entries = static_cast<char**>(
        plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
    for (int i = 0; i < num_entries; i++) {
      (*entries)[i] = strdup(dir_entries[i]->d_name);
      free(dir_entries[i]);
      plugin_memory_free(dir_entries[i]);
    }
    free(dir_entries);
    plugin_memory_free(dir_entries);
  }

  return num_entries;
@ -394,65 +400,59 @@ static int GetChildren(const TF_Filesystem* filesystem, const char* path,

}  // namespace tf_posix_filesystem

int TF_InitPlugin(void* (*allocator)(size_t), TF_FilesystemPluginInfo** info) {
  const int num_schemes = 2;
  *info = static_cast<TF_FilesystemPluginInfo*>(
      allocator(num_schemes * sizeof((*info)[0])));
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
                                        const char* uri) {
  TF_SetFilesystemVersionMetadata(ops);
  ops->scheme = strdup(uri);

  for (int i = 0; i < num_schemes; i++) {
    TF_FilesystemPluginInfo* current_info = &((*info)[i]);
    TF_SetFilesystemVersionMetadata(current_info);
  ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
      plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
  ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
  ops->random_access_file_ops->read = tf_random_access_file::Read;

    current_info->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
        allocator(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
    current_info->random_access_file_ops->cleanup =
        tf_random_access_file::Cleanup;
    current_info->random_access_file_ops->read = tf_random_access_file::Read;
  ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
      plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
  ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
  ops->writable_file_ops->append = tf_writable_file::Append;
  ops->writable_file_ops->tell = tf_writable_file::Tell;
  ops->writable_file_ops->flush = tf_writable_file::Flush;
  ops->writable_file_ops->sync = tf_writable_file::Sync;
  ops->writable_file_ops->close = tf_writable_file::Close;

    current_info->writable_file_ops =
        static_cast<TF_WritableFileOps*>(allocator(TF_WRITABLE_FILE_OPS_SIZE));
    current_info->writable_file_ops->cleanup = tf_writable_file::Cleanup;
    current_info->writable_file_ops->append = tf_writable_file::Append;
    current_info->writable_file_ops->tell = tf_writable_file::Tell;
    current_info->writable_file_ops->flush = tf_writable_file::Flush;
    current_info->writable_file_ops->sync = tf_writable_file::Sync;
    current_info->writable_file_ops->close = tf_writable_file::Close;
  ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
      plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
  ops->read_only_memory_region_ops->cleanup =
      tf_read_only_memory_region::Cleanup;
  ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data;
  ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length;

    current_info->read_only_memory_region_ops =
        static_cast<TF_ReadOnlyMemoryRegionOps*>(
            allocator(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
    current_info->read_only_memory_region_ops->cleanup =
        tf_read_only_memory_region::Cleanup;
    current_info->read_only_memory_region_ops->data =
        tf_read_only_memory_region::Data;
    current_info->read_only_memory_region_ops->length =
        tf_read_only_memory_region::Length;

    current_info->filesystem_ops =
        static_cast<TF_FilesystemOps*>(allocator(TF_FILESYSTEM_OPS_SIZE));
    current_info->filesystem_ops->init = tf_posix_filesystem::Init;
    current_info->filesystem_ops->cleanup = tf_posix_filesystem::Cleanup;
    current_info->filesystem_ops->new_random_access_file =
        tf_posix_filesystem::NewRandomAccessFile;
    current_info->filesystem_ops->new_writable_file =
        tf_posix_filesystem::NewWritableFile;
    current_info->filesystem_ops->new_appendable_file =
        tf_posix_filesystem::NewAppendableFile;
    current_info->filesystem_ops->new_read_only_memory_region_from_file =
        tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile;
    current_info->filesystem_ops->create_dir = tf_posix_filesystem::CreateDir;
    current_info->filesystem_ops->delete_file = tf_posix_filesystem::DeleteFile;
    current_info->filesystem_ops->delete_dir = tf_posix_filesystem::DeleteDir;
    current_info->filesystem_ops->rename_file = tf_posix_filesystem::RenameFile;
    current_info->filesystem_ops->copy_file = tf_posix_filesystem::CopyFile;
    current_info->filesystem_ops->path_exists = tf_posix_filesystem::PathExists;
    current_info->filesystem_ops->stat = tf_posix_filesystem::Stat;
    current_info->filesystem_ops->get_children =
        tf_posix_filesystem::GetChildren;
  }

  (*info)[0].scheme = strdup("");
  (*info)[1].scheme = strdup("file");

  return num_schemes;
  ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
      plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
  ops->filesystem_ops->init = tf_posix_filesystem::Init;
  ops->filesystem_ops->cleanup = tf_posix_filesystem::Cleanup;
  ops->filesystem_ops->new_random_access_file =
      tf_posix_filesystem::NewRandomAccessFile;
  ops->filesystem_ops->new_writable_file = tf_posix_filesystem::NewWritableFile;
  ops->filesystem_ops->new_appendable_file =
      tf_posix_filesystem::NewAppendableFile;
  ops->filesystem_ops->new_read_only_memory_region_from_file =
      tf_posix_filesystem::NewReadOnlyMemoryRegionFromFile;
  ops->filesystem_ops->create_dir = tf_posix_filesystem::CreateDir;
  ops->filesystem_ops->delete_file = tf_posix_filesystem::DeleteFile;
  ops->filesystem_ops->delete_dir = tf_posix_filesystem::DeleteDir;
  ops->filesystem_ops->rename_file = tf_posix_filesystem::RenameFile;
  ops->filesystem_ops->copy_file = tf_posix_filesystem::CopyFile;
  ops->filesystem_ops->path_exists = tf_posix_filesystem::PathExists;
  ops->filesystem_ops->stat = tf_posix_filesystem::Stat;
  ops->filesystem_ops->get_children = tf_posix_filesystem::GetChildren;
}

void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
  info->plugin_memory_allocate = plugin_memory_allocate;
  info->plugin_memory_free = plugin_memory_free;
  info->num_schemes = 2;
  info->ops = static_cast<TF_FilesystemPluginOps*>(
      plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
  ProvideFilesystemSupportFor(&info->ops[0], "");
  ProvideFilesystemSupportFor(&info->ops[1], "file");
}

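Since the plugin above registers both the empty scheme and "file", both spellings of a local path should resolve to this filesystem once the plugin is loaded. A hedged sketch of that observable effect (the path is a placeholder):

  // Hypothetical check after plugin load: both forms route to this plugin.
  tensorflow::Status a =
      tensorflow::Env::Default()->FileExists("/tmp/data.txt");
  tensorflow::Status b =
      tensorflow::Env::Default()->FileExists("file:///tmp/data.txt");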
@ -21,6 +21,9 @@ limitations under the License.
// Implementation of a filesystem for POSIX environments.
// This filesystem will support `file://` and empty (local) URI schemes.

static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }

// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
@ -53,18 +56,18 @@ namespace tf_windows_filesystem {

}  // namespace tf_windows_filesystem

int TF_InitPlugin(void* (*allocator)(size_t), TF_FilesystemPluginInfo** info) {
  const int num_schemes = 2;
  *info = static_cast<TF_FilesystemPluginInfo*>(
      allocator(num_schemes * sizeof((*info)[0])));

  for (int i = 0; i < num_schemes; i++) {
    TF_FilesystemPluginInfo* current_info = &((*info)[i]);
    TF_SetFilesystemVersionMetadata(current_info);
  }

  (*info)[0].scheme = strdup("");
  (*info)[1].scheme = strdup("file");

  return num_schemes;
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
                                        const char* uri) {
  TF_SetFilesystemVersionMetadata(ops);
  ops->scheme = strdup(uri);
}

void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
  info->plugin_memory_allocate = plugin_memory_allocate;
  info->plugin_memory_free = plugin_memory_free;
  info->num_schemes = 2;
  info->ops = static_cast<TF_FilesystemPluginOps*>(
      plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
  ProvideFilesystemSupportFor(&info->ops[0], "");
  ProvideFilesystemSupportFor(&info->ops[1], "file");
}

@ -18,19 +18,36 @@ limitations under the License.

#include "tensorflow/c/kernels.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <memory>
#include <string>

#include "absl/container/inlined_vector.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"

struct MyCustomKernel {
  bool created;
@ -41,6 +41,16 @@ filegroup(
    ],
)

filegroup(
    name = "pywrap_required_hdrs",
    srcs = [
        "training/coordinator.h",
    ],
    visibility = [
        "//tensorflow/python:__pkg__",
    ],
)

cc_library(
    name = "gradients",
    srcs = [
@ -33,6 +33,7 @@ cc_library(
    deps = [
        ":aot_only_var_handle_op",
        ":embedded_protocol_buffers",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
@ -133,8 +134,9 @@ cc_library(
# tfcompile.bzl correctly handles usage from outside of the package that it is
# defined in.

# A simple test of tf_library from a text protobuf, mostly to enable the
# benchmark_test.
# A simple test of tf_library from a text protobuf, to enable benchmark_test.
# This test uses an incomplete graph with a node that is not defined. The
# compilation works because the undefined node is a feed node.
tf_library(
    name = "test_graph_tfadd",
    testonly = 1,
@ -146,8 +148,21 @@ tf_library(
    ],
)

tf_library(
    name = "test_graph_tfadd_mlir_bridge",
    testonly = 1,
    config = "test_graph_tfadd.config.pbtxt",
    cpp_class = "AddComp",
    graph = "test_graph_tfadd.pbtxt",
    mlir_components = "Bridge",
    tags = [
        "manual",
    ],
)

# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the unknown op is not needed for the fetches.
# the compilation works because the node with the unknown op is not needed
# for the fetches.
tf_library(
    name = "test_graph_tfunknownop",
    testonly = 1,
@ -159,9 +174,21 @ tf_library(
    ],
)

tf_library(
    name = "test_graph_tfunknownop_mlir_bridge",
    testonly = 1,
    config = "test_graph_tfunknownop.config.pbtxt",
    cpp_class = "UnknownOpAddComp",
    graph = "test_graph_tfunknownop.pbtxt",
    mlir_components = "Bridge",
    tags = [
        "manual",
    ],
)

# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the op between the unknown op and the
# fetches is a feed.
# the compilation works because the node with the unknown op is only used as
# an input of a feed node.
tf_library(
    name = "test_graph_tfunknownop2",
    testonly = 1,
@ -173,8 +200,20 @@ tf_library(
    ],
)

tf_library(
    name = "test_graph_tfunknownop2_mlir_bridge",
    testonly = 1,
    config = "test_graph_tfunknownop2.config.pbtxt",
    cpp_class = "UnknownOpAddComp",
    graph = "test_graph_tfunknownop.pbtxt",
    mlir_components = "Bridge",
    tags = [
        "manual",
    ],
)

# A test of tf_library that includes a graph with an unknown op, but where
# the compilation works because the unknown op is fed.
# the compilation works because the node with the unknown op is a feed node.
tf_library(
    name = "test_graph_tfunknownop3",
    testonly = 1,
@ -186,6 +225,18 @@ tf_library(
    ],
)

tf_library(
    name = "test_graph_tfunknownop3_mlir_bridge",
    testonly = 1,
    config = "test_graph_tfunknownop3.config.pbtxt",
    cpp_class = "UnknownOpAddComp",
    graph = "test_graph_tfunknownop.pbtxt",
    mlir_components = "Bridge",
    tags = [
        "manual",
    ],
)

# Utility library for benchmark binaries, used by the *_benchmark rules that are
# added by the tfcompile bazel macro.
cc_library(
@ -260,9 +311,13 @@ test_suite(
    tests = [
        ":benchmark_test",
        ":codegen_test",
        ":test_graph_tfadd_mlir_bridge_test",
        ":test_graph_tfadd_test",
        ":test_graph_tfunknownop2_mlir_bridge_test",
        ":test_graph_tfunknownop2_test",
        ":test_graph_tfunknownop3_mlir_bridge_test",
        ":test_graph_tfunknownop3_test",
        ":test_graph_tfunknownop_mlir_bridge_test",
        ":test_graph_tfunknownop_test",
        "//tensorflow/compiler/aot/tests:all_tests",
    ],

@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/base/call_once.h"
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
@ -142,7 +143,7 @@ static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
  }
}

static std::once_flag targets_init;
static absl::once_flag targets_init;

static void InitializeTargets() {
  // Initialize all LLVM targets so we can cross compile.
@ -167,7 +168,7 @@ static void InitializeTargets() {
}

Status Main(const MainFlags& flags) {
  std::call_once(targets_init, &InitializeTargets);
  absl::call_once(targets_init, &InitializeTargets);

  // Process config.
  tf2xla::Config config;

@ -349,6 +349,18 @@ tf_library(
    ],
)

tf_library(
    name = "test_graph_tffunction_mlir_bridge",
    testonly = 1,
    config = "test_graph_tffunction.config.pbtxt",
    cpp_class = "FunctionComp",
    graph = "test_graph_tffunction.pb",
    mlir_components = "Bridge",
    tags = [
        "manual",
    ],
)

tf_library(
    name = "test_graph_tfassert_eq_mlir_bridge",
    testonly = 1,
@ -484,6 +496,7 @@ tf_cc_test(
        ":test_graph_tfadd_with_ckpt_saver_mlir_bridge",
        ":test_graph_tfassert_eq_mlir_bridge",
        ":test_graph_tfcond_mlir_bridge",
        ":test_graph_tffunction_mlir_bridge",
        ":test_graph_tfgather_mlir_bridge",
        ":test_graph_tfmatmul_mlir_bridge",
        ":test_graph_tfmatmulandadd_mlir_bridge",

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
@ -429,8 +430,6 @@ TEST(TFCompileTest, MatMulAndAdd1) {
  }
}

// TODO(bixia): the following tests failed with MLIR bridge.
#if !defined(ENABLE_MLIR_BRIDGE_TEST)
TEST(TFCompileTest, Function) {
  // The function is equivalent to an addition
  FunctionComp add_fn;
@ -445,7 +444,6 @@ TEST(TFCompileTest, Function) {
  EXPECT_EQ(add_fn.result0_data()[0], 3);
  EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
}
#endif

TEST(TFCompileTest, Splits) {
  Eigen::ThreadPool tp(1);

@ -57,6 +57,7 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = [
        ":jit_compilation_passes",
        ":xla_kernel_creator",  # buildcleaner: keep
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
@ -70,6 +71,7 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = if_cuda_or_rocm([
        ":jit_compilation_passes",
        ":xla_kernel_creator",  # buildcleaner: keep
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
@ -157,7 +159,9 @@ XLA_DEVICE_DEPS = [
    ":common",
    ":xla_launch_util",
    ":xla_tensor",
    "@com_google_absl//absl/base",
    "@com_google_absl//absl/memory",
    "@com_google_absl//absl/strings",
    "@com_google_absl//absl/synchronization",
    "@com_google_absl//absl/types:optional",
    "//tensorflow/compiler/jit/ops:xla_ops",
@ -250,13 +254,26 @@ cc_library(
    }),
)

# Internal targets below this point.

cc_library(
    name = "flags",
    srcs = ["flags.cc"],
    hdrs = ["flags.h"],
    visibility = [":friends"],
    deps = [
        "//tensorflow/compiler/xla:parse_flags_from_env",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/strings",
    ],
)

# Header-only version of "flags" library, for linking from the shared object
# without ODR violations.
cc_library(
    name = "flags_headers_only",
    hdrs = ["flags.h"],
    visibility = [":friends"],
    deps = [
        "//tensorflow/compiler/xla:parse_flags_from_env",
        "//tensorflow/core:framework_internal",
@ -276,6 +293,8 @@ cc_library(
    visibility = [":friends"],
)

# Internal targets below this point.

cc_library(
    name = "xla_launch_util",
    srcs = ["xla_launch_util.cc"],
@ -397,6 +416,7 @@ cc_library(
        "xla_kernel_creator.h",
    ],
    deps = [
        ":flags",
        ":jit_compilation_passes",
        ":xla_kernel_creator_util",
        "//tensorflow/core:core_cpu_internal",
@ -625,6 +645,7 @@ cc_library(
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",

@ -266,9 +266,9 @@ bool RecursiveCompilabilityChecker::IsCompilableCall(
    s = lib_runtime->Instantiate(function.name(), AttrSlice(&function.attr()),
                                 &handle);
  }

  if (!s.ok()) {
    std::string uncompilable_reason = "could not instantiate call";
    std::string uncompilable_reason =
        absl::StrCat("could not instantiate call: '", function.name(), "'");
    MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
                              encapsulating_function, uncompilable_nodes);
    VLOG(2) << "Rejecting " << call_def.DebugString() << ": "

@ -13,13 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/flags.h"

#include <mutex>  // NOLINT

#include "absl/base/call_once.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {
@ -32,7 +35,7 @@ XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;

std::vector<Flag>* flag_list;
std::once_flag flags_init;
absl::once_flag flags_init;

bool SetterForXlaAutoJitFlag(const string& value) {
  int32 opt_level;
@ -213,38 +216,45 @@ void AllocateAndParseFlags() {
}  // namespace

bool SetXlaAutoJitFlagFromFlagString(const string& value) {
  std::call_once(flags_init, &AllocateAndParseFlags);
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return SetterForXlaAutoJitFlag(value);
}

BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return build_ops_flags;
}

MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return mark_for_compilation_flags;
}

XlaDeviceFlags* GetXlaDeviceFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return device_flags;
}

const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return *ops_flags;
}

const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return *jitter_flags;
}

void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
  std::call_once(flags_init, &AllocateAndParseFlags);
  absl::call_once(flags_init, &AllocateAndParseFlags);
  AppendMarkForCompilationPassFlagsInternal(flag_list);
}

static bool xla_is_enabled = false;

void SetXlaIsEnabled() { xla_is_enabled = true; }

bool IsXlaEnabled() { return xla_is_enabled; }

}  // namespace tensorflow

|
||||
// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
|
||||
void AppendMarkForCompilationPassFlags(
|
||||
std::vector<tensorflow::Flag>* flag_list);
|
||||
|
||||
// Makes all future calls to `IsXlaEnabled()` return `true`.
|
||||
//
|
||||
// Should only be called when XLA is linked in.
|
||||
void SetXlaIsEnabled();
|
||||
|
||||
// Returns whether XLA is enabled.
|
||||
bool IsXlaEnabled();
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>

#include "absl/base/call_once.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
@ -1616,8 +1617,8 @@ StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(

  if (!should_compile && global_jit_level_ != OptimizerOptions::OFF &&
      device_type.type_string() == DEVICE_CPU) {
    static std::once_flag once;
    std::call_once(once, [] {
    static absl::once_flag once;
    absl::call_once(once, [] {
      LOG(WARNING)
          << "(One-time warning): Not using XLA:CPU for cluster because envvar "
             "TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set. If you want "

@ -163,12 +163,11 @@ Status XlaCompilationCache::BuildExecutable(
  build_options.set_device_allocator(options.device_allocator);
  build_options.set_alias_passthrough_params(options.alias_passthrough_params);

  auto compile_result =
      client_->Compile(*result.computation, argument_layouts, build_options);
  if (!compile_result.ok()) {
    return compile_result.status();
  }
  *executable = std::move(compile_result.ValueOrDie());
  TF_ASSIGN_OR_RETURN(
      auto executables,
      client_->Compile(*result.computation, argument_layouts, build_options));
  TF_RET_CHECK(executables.size() == 1);
  *executable = std::move(executables[0]);
  return Status::OK();
}

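The rewrite above replaces a manual `ok()`/`ValueOrDie()` dance with `TF_ASSIGN_OR_RETURN`, which unwraps a `StatusOr` and early-returns the error. A hedged generic sketch of the idiom (`ComputeSomething` is a placeholder, not from the commit):

  // Hypothetical StatusOr-returning helper, unwrapped the same way
  // BuildExecutable now unwraps client_->Compile's result.
  xla::StatusOr<int> ComputeSomething();

  Status UseComputeSomething() {
    TF_ASSIGN_OR_RETURN(int value, ComputeSomething());  // returns on error
    VLOG(1) << "value = " << value;
    return Status::OK();
  }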
@ -20,7 +20,9 @@ limitations under the License.
#include <unordered_set>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
@ -386,14 +388,33 @@ Status XlaDevice::TryGetDeviceContext(DeviceContext** out_context) {
  return Status::OK();
}

// Warn about XLA_CPU/XLA_GPU exactly once.
static void ShowXlaDeviceDeprecationWarning(
    absl::string_view compilation_device_name) {
  static absl::once_flag once;
  if (absl::StrContains(compilation_device_name, "CPU") ||
      absl::StrContains(compilation_device_name, "GPU")) {
    absl::call_once(once, [] {
      LOG(WARNING)
          << "XLA_GPU and XLA_CPU devices are deprecated and will be "
             "removed in subsequent releases. Instead, use either "
             "@tf.function(experimental_compile=True) for must-compile "
             "semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
             "for auto-clustering best-effort compilation.";
    });
  }
}

void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
          << op_kernel->type_string();
  ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
  op_kernel->Compute(context);
}

void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                             AsyncOpKernel::DoneCallback done) {
  ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
  VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
          << op_kernel->type_string();
  op_kernel->ComputeAsync(context, done);

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator.h"

#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
#include "tensorflow/core/common_runtime/function.h"

@ -39,6 +40,10 @@ bool RegisterLaunchOpCreator() {
}

static bool register_me = RegisterLaunchOpCreator();
static bool register_xla = [] {
  SetXlaIsEnabled();
  return true;
}();

}  // end namespace
}  // namespace tensorflow

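The static-initializer lambda above is what ties this file to the new `flags.h` API: merely linking the xla_kernel_creator target flips the flag before `main` runs. A hedged sketch of a downstream check (the call site is hypothetical):

  // Hypothetical consumer: true exactly when the XLA kernel creator object
  // file was linked in, because register_xla ran at static init time.
  if (tensorflow::IsXlaEnabled()) {
    // XLA is available; JIT code paths may be taken.
  }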
@@ -143,11 +143,11 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
     }
     string message = absl::StrCat(
         "Function invoked by the following node is not compilable: ",
-        node_def.ShortDebugString(), ".\n");
-    absl::StrAppend(&message, "Uncompilable nodes:\n");
+        SummarizeNodeDef(node_def), ".\n");
+    absl::StrAppend(&message, "Uncompilable nodes:");
     for (const auto& node_info : uncompilable_node_info) {
       string node_message =
-          absl::StrCat("\t", node_info.name, ": ",
+          absl::StrCat("\n", node_info.name, ": ",
                        node_info.uncompilable_reason, "\n", "\tStacktrace:\n");
       for (const auto& stack_frame : node_info.stack_trace) {
         absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
@@ -156,7 +156,6 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
       absl::StrAppend(&message, node_message);
     }
     VLOG(1) << message;
     // node_def is calling a function that XLA can't compile.
     return errors::InvalidArgument(message);
   }
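The message assembly above mixes absl::StrCat, StrAppend, and StrAppendFormat. A compilable sketch of the pattern with made-up node data:

    #include <string>

    #include "absl/strings/str_cat.h"
    #include "absl/strings/str_format.h"

    std::string BuildDiagnostic() {
      std::string message = absl::StrCat(
          "Function invoked by the following node is not compilable: ",
          "node{...}", ".\n");
      absl::StrAppend(&message, "Uncompilable nodes:");
      // One formatted line per offending node.
      absl::StrAppendFormat(&message, "\n%s: %s\n", "MatMul_1",
                            "unsupported dtype");
      return message;
    }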
@@ -66,6 +66,8 @@ cc_library(
         "//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize",
         "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_passes",
         "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_to_quant",
         "//tensorflow/compiler/mlir/lite/quantization/xla:hlo_xla_quantization_passes",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
@@ -77,10 +79,10 @@ cc_library(
         "//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
         "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
         "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
         "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_linalg",
         "//tensorflow/compiler/mlir/xla:xla_dialect_registration",
         "//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
         "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
         "//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg",
         "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard",
         "//tensorflow/compiler/mlir/xla:xla_lower",
         "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
@@ -26,9 +26,11 @@ package_group(
 filegroup(
     name = "tensorflow_lite_ops_td_files",
     srcs = [
+        "ir/tfl_op_interfaces.td",
         "ir/tfl_ops.td",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
         "@llvm-project//mlir:OpBaseTdFiles",
+        "@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
     ],
 )

@@ -43,10 +45,29 @@ gentbl(
             "-gen-op-defs",
             "ir/tfl_ops.cc.inc",
         ),
+        (
+            "-gen-struct-attr-decls",
+            "ir/tfl_structs.h.inc",
+        ),
+        (
+            "-gen-struct-attr-defs",
+            "ir/tfl_structs.cc.inc",
+        ),
         (
             "-gen-op-doc",
             "g3doc/tfl_ops.md",
         ),
     ],
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "ir/tfl_ops.td",
     td_srcs = [
         ":tensorflow_lite_ops_td_files",
     ],
 )

+gentbl(
+    name = "tensorflow_lite_op_interfaces_inc_gen",
+    tbl_outs = [
+        (
+            "-gen-op-interface-decls",
+            "ir/tfl_ops_interface.h.inc",
@@ -57,7 +78,7 @@ gentbl(
         ),
     ],
     tblgen = "@llvm-project//mlir:mlir-tblgen",
-    td_file = "ir/tfl_ops.td",
+    td_file = "ir/tfl_op_interfaces.td",
     td_srcs = [
         ":tensorflow_lite_ops_td_files",
     ],
@@ -199,8 +220,6 @@ cc_library(
     deps = [
         ":tensorflow_lite_ops_inc_gen",
         ":validators",
-        "//tensorflow/compiler/mlir/tensorflow",
-        "//tensorflow/lite/schema:schema_fbs",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:Dialect",
@@ -209,6 +228,10 @@ cc_library(
         "@llvm-project//mlir:QuantOps",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
+        # TODO(jpienaar): Move this out after splitting out LoopLikeOpInterface.
+        "@llvm-project//mlir:Transforms",
+        "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/lite/schema:schema_fbs",
     ],
     alwayslink = 1,
 )
@@ -267,6 +290,7 @@ tf_cc_test(
 cc_library(
     name = "tensorflow_lite_legalize_tf",
     srcs = [
+        "transforms/dilated_conv.cc",
         "transforms/extract_ophint.cc",
         "transforms/generated_legalize_tf.inc",
         "transforms/generated_lower_static_tensor_list.inc",
@@ -280,8 +304,10 @@ cc_library(
         "transforms/split_merged_operands.cc",
         "transforms/trim_functions_tf.cc",
         "transforms/unroll_batch_matmul.cc",
+        "transforms/while_loop_outline.cc",
     ],
     hdrs = [
+        "transforms/dilated_conv.h",
         "transforms/passes.h",
         "transforms/unroll_batch_matmul.h",
     ],
@@ -291,15 +317,19 @@ cc_library(
         ":stateful_ops_utils",
         ":tensorflow_lite",
         ":validators",
         "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:convert_tensor",
         "//tensorflow/compiler/mlir/tensorflow:mangling_util",
         "//tensorflow/compiler/xla:status",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/kernels:tensor_list",
         "//tensorflow/core/platform:logging",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/memory",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:Analysis",
@@ -357,6 +387,7 @@ cc_library(
         ":validators",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
         "//tensorflow/core:protos_all_cc",
         "@com_google_absl//absl/memory",
         "@llvm-project//llvm:support",
@@ -370,6 +401,24 @@ cc_library(
     alwayslink = 1,
 )

+cc_library(
+    name = "tensorflow_lite_d2s",
+    srcs = [
+        "transforms/dense_to_sparse.cc",
+    ],
+    hdrs = [
+        "transforms/passes.h",
+    ],
+    deps = [
+        ":tensorflow_lite",
+        "@com_google_absl//absl/memory",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+    ],
+    alwayslink = 1,
+)

 filegroup(
     name = "generated_op_quant_spec_getters",
     srcs = [
@@ -381,6 +430,8 @@ genrule(
     name = "op_quant_spec_getters_inc",
     srcs = [
         "ir/tfl_ops.td",
+        "ir/tfl_op_interfaces.td",
+        "@llvm-project//mlir:include/mlir/Transforms/LoopLikeInterface.td",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
     ],
     outs = [
@@ -628,6 +679,7 @@ cc_library(
     ],
     deps = [
         ":common",
+        ":tensorflow_lite_d2s",
        ":tensorflow_lite_legalize_tf",
         ":tensorflow_lite_optimize",
         ":tensorflow_lite_quantize",
@@ -90,6 +90,7 @@ using mlir::MLIRContext;
 using mlir::ModuleOp;
 using mlir::NoneType;
 using mlir::Operation;
+using mlir::Region;
 using mlir::StringAttr;
 using mlir::TensorType;
 using mlir::TranslateFromMLIRRegistration;
@@ -309,7 +310,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
   return true;
 }

-static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef(
+static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef(
     ::mlir::Operation* inst) {
   // We pass empty string for the original node_def name since Flex runtime
   // does not care about this being set correctly on node_def. There is no
@@ -425,6 +426,11 @@ class Translator {
       mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
       const std::vector<int32_t>& results);

+  // Build while operator where cond & body are regions.
+  Optional<BufferOffset<tflite::Operator>> BuildWhileOperator(
+      mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
+      const std::vector<int32_t>& results);
+
   // Builds custom operators.
   // Templated on a) data type of custom_option to be stored into flatbuffer,
   // and b) TFL custom op type.
@@ -472,7 +478,10 @@ class Translator {
       Operation* inst, const std::vector<int32_t>& operands,
       const std::vector<int32_t>& results);

-  Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(FuncOp fn);
+  // Build a subgraph with a given name out of the region either corresponding
+  // to a function's body or while op.
+  Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(
+      const std::string& name, Region* region);

   // Builds Metadata with the given `name` and buffer `content`.
   BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
@@ -523,7 +532,7 @@ class Translator {
 };

 std::string Translator::UniqueName(mlir::Value val) {
-  return name_mapper_.GetUniqueName(val);
+  return std::string(name_mapper_.GetUniqueName(val));
 }

 Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
@@ -539,9 +548,14 @@ Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
     attr = cst.value();
   } else if (auto cst = dyn_cast<tfl::QConstOp>(inst)) {
     attr = cst.value();
+  } else if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
+    attr = cst.value();
+  } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
+    attr = cst.value();
   } else {
     return empty_buffer_;
   }

   tensorflow::Tensor tensor;
   auto status = tensorflow::ConvertToTensor(attr, &tensor);
   if (!status.ok()) {
@@ -595,6 +609,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
   };

   std::vector<int32_t> shape;
+  std::vector<int32_t> shape_signature;
   if (type.hasStaticShape()) {
     llvm::ArrayRef<int64_t> shape_ref = type.getShape();
     if (mlir::failed(check_shape(shape_ref))) return llvm::None;
@@ -612,7 +627,25 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(

     shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
-  }
+  } else if (type.hasRank()) {
+    llvm::ArrayRef<int64_t> shape_ref = type.getShape();
+    if (mlir::failed(check_shape(shape_ref))) return llvm::None;
+
+    shape.reserve(shape_ref.size());
+    for (auto& dim : shape_ref) {
+      shape.push_back(dim == -1 ? 1 : dim);
+    }
+    shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
+  }
+
+  if (auto* inst = value.getDefiningOp()) {
+    if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
+      // CreateSparsityParameters(cst.s_param());
+    } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
+      // CreateSparsityParameters(cst.s_param());
+    }
+  }

   Type element_type = type.getElementType();
   tflite::TensorType tflite_element_type =
       GetTFLiteType(type.getElementType()).ValueOrDie();
@@ -649,10 +682,19 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
       break;
     }
   }
-  return tflite::CreateTensor(
-      builder_, builder_.CreateVector(shape), tflite_element_type,
-      (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
-      /*is_variable=*/is_variable);
+
+  if (shape_signature.empty()) {
+    return tflite::CreateTensor(
+        builder_, builder_.CreateVector(shape), tflite_element_type,
+        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
+        /*is_variable=*/is_variable);
+  } else {
+    return tflite::CreateTensor(
+        builder_, builder_.CreateVector(shape), tflite_element_type,
+        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
+        /*is_variable=*/is_variable, /*sparsity=*/0,
+        /*shape_signature=*/builder_.CreateVector(shape_signature));
+  }
 }

 BufferOffset<tflite::Operator> Translator::BuildIfOperator(
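BuildTensor above now emits both a `shape` (dynamic dims flattened to 1) and a `shape_signature` (keeping -1) for ranked tensors with dynamic dimensions. Standalone sketch of that split:

    #include <cstdint>
    #include <vector>

    void SplitShape(const std::vector<int64_t>& dims,
                    std::vector<int32_t>* shape,
                    std::vector<int32_t>* shape_signature) {
      shape->reserve(dims.size());
      shape_signature->reserve(dims.size());
      for (int64_t d : dims) {
        // The flatbuffer shape needs a concrete extent, so -1 becomes 1...
        shape->push_back(d == -1 ? 1 : static_cast<int32_t>(d));
        // ...while the signature keeps -1 to mark the dimension dynamic.
        shape_signature->push_back(static_cast<int32_t>(d));
      }
    }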
@@ -687,6 +729,32 @@ BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
                                 builtin_options);
 }

+Optional<BufferOffset<tflite::Operator>> Translator::BuildWhileOperator(
+    mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
+    const std::vector<int32_t>& results) {
+  auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
+  auto get_call_index = [&](mlir::Block& b) -> Optional<int> {
+    if (b.getOperations().size() != 2) return llvm::None;
+    if (auto call_op = dyn_cast<mlir::CallOp>(b.front()))
+      return subgraph_index_map_.at(call_op.callee().str());
+    return llvm::None;
+  };
+  auto body_subgraph_index = get_call_index(op.body().front());
+  auto cond_subgraph_index = get_call_index(op.cond().front());
+  if (!body_subgraph_index || !cond_subgraph_index)
+    return op.emitOpError("only single call cond/body while export supported"),
+           llvm::None;
+  auto builtin_options =
+      tflite::CreateWhileOptions(builder_, *cond_subgraph_index,
+                                 *body_subgraph_index)
+          .Union();
+  auto inputs = builder_.CreateVector(operands);
+  auto outputs = builder_.CreateVector(results);
+  return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
+                                tflite::BuiltinOptions_WhileOptions,
+                                builtin_options);
+}
+
 template <typename CustomOptionType, typename TFLOp>
 BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
     const CustomOptionType& custom_option, const std::string& opcode_name,
@@ -908,6 +976,16 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
       return BuildMaxUnpooling2DOperator(inst, max_unpooling_op, operands,
                                          results);
     }
+    if (auto whileOp = dyn_cast<mlir::TFL::WhileOp>(inst)) {
+      if (inst->getNumOperands() != inst->getNumResults()) {
+        inst->emitOpError(
+            "number of operands and results don't match, only canonical "
+            "TFL While supported");
+        return llvm::None;
+      }
+      return BuildWhileOperator(whileOp, operands, results);
+    }

     inst->emitOpError("is not a supported TFLite op");
     return llvm::None;
   }
@@ -944,7 +1022,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
     //   we emit op as flex.
     //   if custom is enabled
     //     we emit the op as custom.
-    auto node_def = getTensorFlowNodeDef(inst);
+    auto node_def = GetTensorFlowNodeDef(inst);
     if (!node_def) {
       return llvm::None;
     }
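BuildWhileOperator above bails out with `return op.emitOpError(...), llvm::None;` — the comma operator emits the diagnostic and then yields the None sentinel in a single return statement. Plain-C++ sketch of the idiom:

    #include <cstdio>
    #include <optional>

    std::optional<int> ParsePositive(int raw) {
      if (raw <= 0)
        // fprintf runs for its side effect; the comma expression's value is
        // std::nullopt, which converts to the empty optional being returned.
        return std::fprintf(stderr, "expected a positive value\n"), std::nullopt;
      return raw;
    }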
@@ -1047,9 +1125,12 @@ bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
   return absl::c_find(operand_indices, operand_index) != operand_indices.end();
 }

-Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
+Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
+    const std::string& name, Region* region) {
   bool has_input_attr = false;
-  InitializeNamesFromAttribute(fn, &has_input_attr);
+  if (auto fn = dyn_cast<FuncOp>(region->getParentOp())) {
+    InitializeNamesFromAttribute(fn, &has_input_attr);
+  }
   std::vector<BufferOffset<tflite::Tensor>> tensors;
   llvm::DenseMap<Value, int> tensor_index_map;

@@ -1081,7 +1162,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
   };

   std::vector<BufferOffset<tflite::Operator>> operators;
-  auto& bb = fn.getBlocks().front();
+  auto& bb = region->front();

   // Main function's arguments are first passed to `input` op so they don't
   // have associated tensor and buffer. Build FlatBuffer tensor and buffer for
@@ -1089,7 +1170,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
   for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) {
     mlir::BlockArgument arg = bb.getArgument(i);
     std::string name;
-    if (has_input_attr) name = name_mapper_.GetUniqueName(arg);
+    if (has_input_attr) name = std::string(name_mapper_.GetUniqueName(arg));
     if (name.empty()) name = absl::StrCat("arg", i);
     if (!build_tensor_and_buffer(arg, name)) return llvm::None;
   }
@@ -1141,7 +1222,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
   return tflite::CreateSubGraph(
       builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs),
       builder_.CreateVector(outputs), builder_.CreateVector(operators),
-      /*name=*/builder_.CreateString(fn.getName().str()));
+      /*name=*/builder_.CreateString(name));
 }

 BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
@@ -1184,35 +1265,36 @@ Optional<std::string> Translator::Translate(
 }

 Optional<std::string> Translator::TranslateInternal() {
-  // Create a list of functions in the module with main function being the
-  // first function in the list. This is required as the first subgraph in the
-  // model is entry point for the model.
-  std::vector<FuncOp> functions;
-  functions.reserve(std::distance(module_.begin(), module_.end()));
+  // A list of named regions in the module with main function being the first in
+  // the list. The main function is required as the first subgraph in the model
+  // is entry point for the model.
+  std::vector<std::pair<std::string, Region*>> named_regions;
+  named_regions.reserve(std::distance(module_.begin(), module_.end()));

   int subgraph_idx = 0;
   FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
   subgraph_index_map_[main_fn.getName().str()] = subgraph_idx++;
-  functions.push_back(main_fn);
-  for (auto fn : module_.getOps<FuncOp>()) {
-    if (fn == main_fn) continue;
-
-    subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
-    functions.push_back(fn);
-  }
+  named_regions.emplace_back("main", &main_fn.getBody());
+  // Walk over the module collection ops with functions and while ops.
+  module_.walk([&](FuncOp fn) {
+    if (fn != main_fn) {
+      subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
+      named_regions.emplace_back(fn.getName().str(), &fn.getBody());
+    }
+  });

-  // Build subgraph for each of the functions.
+  // Build subgraph for each of the named regions.
   std::vector<BufferOffset<tflite::SubGraph>> subgraphs;
-  subgraphs.reserve(functions.size());
+  subgraphs.reserve(named_regions.size());
   int first_failed_func = -1;
-  for (int i = 0; i < functions.size(); ++i) {
-    auto subgraph_or = BuildSubGraph(functions[i]);
+  for (auto it : llvm::enumerate(named_regions)) {
+    auto subgraph_or = BuildSubGraph(it.value().first, it.value().second);
     if (!subgraph_or) {
       if (first_failed_func == -1)
-        // Record the index of the first function that cannot be converted.
+        // Record the index of the first region that cannot be converted.
         // Keep looping through all subgraphs in the module to make sure that
         // we collect the list of missing ops from the entire module.
-        first_failed_func = i;
+        first_failed_func = it.index();
     } else {
       subgraphs.push_back(*subgraph_or);
     }
@@ -1233,9 +1315,10 @@ Optional<std::string> Translator::TranslateInternal() {
                "-emit-custom-ops flag): " +
                failed_custom_ops_list;

-    return functions[first_failed_func].emitError("failed while converting: '")
-               << functions[first_failed_func].getName() << "\'\n"
-               << err,
+    auto& failed_region = named_regions[first_failed_func];
+    return failed_region.second->getParentOp()->emitError()
+               << "failed while converting: '" << failed_region.first
+               << "': " << err,
            llvm::None;
   }
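TranslateInternal now iterates with llvm::enumerate, pairing each region with its index instead of indexing a vector manually. Minimal sketch:

    #include <cstdio>
    #include <string>
    #include <vector>

    #include "llvm/ADT/STLExtras.h"

    void PrintIndexed(const std::vector<std::string>& names) {
      for (auto it : llvm::enumerate(names)) {
        // it.index() is the position, it.value() a reference to the element.
        std::printf("%zu: %s\n", it.index(), it.value().c_str());
      }
    }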
74 tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td (new file)
@@ -0,0 +1,74 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This is the operation interface definition file for TensorFlow Lite.
+
+#ifndef TFL_OP_INTERFACES
+#define TFL_OP_INTERFACES
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// TFL op interface for stateful operands.
+
+def TFL_StatefulOp : OpInterface<"StatefulOpInterface"> {
+  let description = [{
+    Interface for ops that are stateful and need to identify stateful operands.
+
+    Stateful operands correspond to TF's variables semantics. An op that has 1
+    or more stateful operands is a stateful op.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      [{Returns the indices of stateful operands.}],
+      "std::vector<int>", "GetStatefulOperands", (ins)
+    >,
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// TFL op interface for output channel index.
+
+def TFL_ChannelDimIndexInterface : OpInterface<"ChannelDimIndexInterface"> {
+  let description = [{
+    Interface for defining the index of out channel index.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      [{Returns the dimension index of the output channels.}],
+      "int", "GetChannelDimIndex", (ins)
+    >,
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// TFL op interface for sparse operands.
+
+def TFL_SparseOp : OpInterface<"SparseOpInterface"> {
+  let description = [{
+    Interface for ops that support sparse computation.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      [{Returns the indices of sparse operands.}],
+      "std::vector<int>", "GetSparseOperands", (ins)
+    >,
+  ];
+}
+
+#endif // TFL_OP_INTERFACES
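Once mlir-tblgen processes the file above, each def becomes a C++ op interface. A hedged sketch of how a pass might consume the generated SparseOpInterface (class and method names assumed from the definitions above; the generated declarations are pulled in via the TFL dialect headers):

    #include "mlir/IR/Operation.h"
    #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

    void VisitSparseOperands(mlir::Operation* op) {
      // dyn_cast succeeds only for ops that declared TFL_SparseOp.
      if (auto sparse_op = llvm::dyn_cast<mlir::TFL::SparseOpInterface>(op)) {
        for (int operand_index : sparse_op.GetSparseOperands()) {
          // Inspect or rewrite the sparse operand here.
          (void)operand_index;
        }
      }
    }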
@@ -39,6 +39,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

 namespace mlir {
+#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
 namespace TFL {

 //===----------------------------------------------------------------------===//
@@ -797,8 +798,7 @@ struct RemoveAdjacentReshape : public RewritePattern {
     // With
     //   %2 = "tfl.reshape"(%0, %shape1)
     rewriter.replaceOpWithNewOp<ReshapeOp>(
-        {prevOp.getResult()}, op, thisOp.getType(), prevOp.getOperand(0),
-        thisOp.getOperand(1));
+        op, thisOp.getType(), prevOp.getOperand(0), thisOp.getOperand(1));
   }
 };

@@ -1302,6 +1302,19 @@ OpFoldResult AbsOp::fold(ArrayRef<Attribute> operands) {
   return ConstFoldUnaryOp(result_type, operands[0], compute);
 }

+//===----------------------------------------------------------------------===//
+// NegOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
+  Type result_type = getType();
+  // Only constant fold for tensor of f32 is implemented.
+  if (!IsF32ShapedType(result_type)) return nullptr;
+
+  auto compute = [](APFloat value) -> APFloat { return llvm::neg(value); };
+  return ConstFoldUnaryOp(result_type, operands[0], compute);
+}
+
 //===----------------------------------------------------------------------===//
 // SinOp
 //===----------------------------------------------------------------------===//
@@ -1724,6 +1737,67 @@ static LogicalResult Verify(TransposeOp op) {
   return success();
 }

+namespace {
+struct WhileResultOperandsMatch : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(WhileOp while_op,
+                                     PatternRewriter &rewriter) const override {
+    auto size = while_op.body().front().getArguments().size();
+    Operation *op = while_op.getOperation();
+    auto old_size = op->getNumResults();
+    // No change needed as the number of operands match the number of results.
+    if (size == old_size) return matchFailure();
+
+    // Collect the new types by combining results of old op with additional
+    // operand results.
+    llvm::SmallVector<Type, 4> types;
+    types.reserve(size);
+    for (auto type : while_op.getResultTypes()) types.push_back(type);
+    for (auto arg : while_op.body().front().getArguments().drop_front(old_size))
+      types.push_back(arg.getType());
+    // Collect operands.
+    llvm::SmallVector<Value, 8> operands;
+    operands.reserve(while_op.getNumOperands());
+    for (auto operand : while_op.getOperands()) operands.push_back(operand);
+
+    // Replace with new While with matching operands and results.
+    Operation *new_op = rewriter.insert(
+        Operation::create(op->getLoc(), op->getName(), types, operands,
+                          op->getAttrs(), {}, /*numRegions=*/2,
+                          /*resizableOperandList=*/true));
+    for (int i = 0; i < 2; ++i) new_op->getRegion(i).takeBody(op->getRegion(i));
+    rewriter.replaceOp(op,
+                       new_op->getResults().take_front(op->getNumResults()));
+
+    return matchSuccess();
+  }
+};
+}  // namespace
+
+void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context) {
+  results.insert<WhileResultOperandsMatch>(context);
+}
+
+Region &WhileOp::getLoopBody() { return body(); }
+
+bool WhileOp::isDefinedOutsideOfLoop(Value value) {
+  // TODO(jpienaar): This is to overly conservative and disables anything other
+  // than constant hoisting initially.
+  return false;
+}
+
+LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef<mlir::Operation *> ops) {
+  if (ops.empty()) return success();
+
+  // Move the hoisted value to just before the while.
+  Operation *while_op = this->getOperation();
+  for (auto op : ops) op->moveBefore(while_op);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
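NegOp::fold above delegates to ConstFoldUnaryOp with an APFloat-negating lambda. The element-wise kernel in isolation:

    #include "llvm/ADT/APFloat.h"

    // llvm::neg flips the sign bit exactly, so folding tfl.neg on f32
    // constants is lossless.
    llvm::APFloat NegateElement(llvm::APFloat value) { return llvm::neg(value); }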
@@ -27,10 +27,12 @@ limitations under the License.
 #include "mlir/IR/StandardTypes.h"  // TF:llvm-project
 #include "mlir/Support/Functional.h"  // TF:llvm-project
 #include "mlir/Support/LLVM.h"  // TF:llvm-project
+#include "mlir/Transforms/LoopLikeInterface.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
 #include "tensorflow/lite/schema/schema_generated.h"

 namespace mlir {
+#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.h.inc"
 namespace TFL {

 class TensorFlowLiteDialect : public Dialect {
@@ -19,6 +19,8 @@ limitations under the License.
 #define TFL_OPS

 include "mlir/IR/OpBase.td"
+include "mlir/Transforms/LoopLikeInterface.td"
+include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
 include "tensorflow/compiler/mlir/lite/quantization/quantization.td"

 def TFL_Dialect : Dialect {
@@ -248,41 +250,6 @@ def TFL_ComparisonBinaryBuilder : OpBuilder<
     buildComparisonBinOp(builder, result, lhs, rhs);
   }]>;

-//===----------------------------------------------------------------------===//
-// TFL op interface for stateful operands.
-
-def TFL_StatefulOp : OpInterface<"StatefulOpInterface"> {
-  let description = [{
-    Interface for ops that are stateful and need to identify stateful operands.
-
-    Stateful operands correspond to TF's variables semantics. An op that has 1
-    or more stateful operands is a stateful op.
-  }];
-
-  let methods = [
-    InterfaceMethod<
-      [{Returns the indices of stateful operands.}],
-      "std::vector<int>", "GetStatefulOperands", (ins)
-    >,
-  ];
-}
-
-//===----------------------------------------------------------------------===//
-// TFL op interface for output channel index.
-
-def TFL_ChannelDimIndexInterface : OpInterface<"ChannelDimIndexInterface"> {
-  let description = [{
-    Interface for defining the index of out channel index.
-  }];
-
-  let methods = [
-    InterfaceMethod<
-      [{Returns the dimension index of the output channels.}],
-      "int", "GetChannelDimIndex", (ins)
-    >,
-  ];
-}
-
 //===----------------------------------------------------------------------===//
 // TFL op base class.
 //===----------------------------------------------------------------------===//
@@ -583,14 +550,14 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",

   let arguments = (
     ins Variadic<TensorOf<
-      [F32, I64, I32, I16, I8, QI8, QUI8, TFL_Uint8]>>:$values,
+      [F32, I64, I32, I16, I8, QI8, QUI8, QI16, TFL_Uint8]>>:$values,
     I32Attr:$axis,
     TFL_AFAttr:$fused_activation_function
   );

   let results = (outs
     TensorOf<
-      [F32, I64, I32, I16, I8, QI8, QUI8, TFL_Uint8]>:$output
+      [F32, I64, I32, I16, I8, QI8, QUI8, QI16, TFL_Uint8]>:$output
   );

   let hasOptions = 1;
@@ -627,6 +594,57 @@ def TFL_ConstOp : Op<TFL_Dialect, "pseudo_const", [NoSideEffect,
   ];
 }

+// Attributes used for encoding sparse tensors.
+// Please find detailed explanation of these parameters in the TFLite schema.
+def TFL_DT_Dense : StrEnumAttrCase<"DENSE", 0>;
+def TFL_DT_SparseCSR : StrEnumAttrCase<"SPARSE_CSR", 1>;
+
+def TFL_DimensionTypeAttr : StrEnumAttr<
+    "DimensionType", "dimension type", [TFL_DT_Dense, TFL_DT_SparseCSR]>;
+
+def DimensionMetadataAttr : StructAttr<"DimensionMetadataAttr", TFL_Dialect, [
+    StructFieldAttr<"format", TFL_DimensionTypeAttr>,
+    StructFieldAttr<"dense_size", I32Attr>,
+    StructFieldAttr<"segments", I32ArrayAttr>,
+    StructFieldAttr<"indices", I32ArrayAttr>] > {
+  let description = "Dimension metadata.";
+}
+
+def DimensionMetadataArrayAttr : TypedArrayAttrBase<DimensionMetadataAttr,
+    "Array of DimensionMetadata">{}
+
+def SparsityParameterAttr : StructAttr<"SparsityParameterAttr", TFL_Dialect, [
+    StructFieldAttr<"traversal_order", I32ArrayAttr>,
+    StructFieldAttr<"block_map", I32ArrayAttr>,
+    StructFieldAttr<"dim_metadata", DimensionMetadataArrayAttr>]> {
+  let description = "Sparsity parameter.";
+  let storageType = [{ TFL::SparsityParameterAttr }];
+}
+
+def TFL_SparseConstOp : Op<TFL_Dialect, "pseudo_sparse_const", [NoSideEffect,
+    FirstAttrDerivedResultType]> {
+  let summary = "Sparse constant pseudo op.";
+
+  let description = [{
+    Represents a sparse constant value in TensorFlow Lite dialect. This is not
+    an actual operation and it will be lowered to buffer instead.
+  }];
+
+  let arguments = (ins ElementsAttr:$value, SparsityParameterAttr:$s_param);
+
+  let results = (outs AnyTensor:$output);
+
+  let builders = [OpBuilder<
+    "Builder *, OperationState &state, Attribute value, "
+    "SparsityParameterAttr s_param",
+    [{
+      state.addTypes(value.getType());
+      state.addAttribute("value", value);
+      state.addAttribute("s_param", s_param);
+    }]>
+  ];
+}
+
 def TFL_ExternalConstOp : Op<TFL_Dialect, "external_const", [NoSideEffect]> {
   let summary = "External const op.";

@@ -685,7 +703,8 @@ def TFL_FullyConnectedOptionsWeightFormatAttr :
 def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
     NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
     TFL_ChannelDimIndexInterface,
-    AffineOpCoefficient<-1, 1>]> {
+    AffineOpCoefficient<-1, 1>,
+    TFL_SparseOp]> {
   let summary = "Fully connected op";

   let arguments = (ins
@@ -710,6 +729,8 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
   let extraClassDeclaration = [{
     // ChannelDimIndexInterface:
     int GetChannelDimIndex() { return 0; }
+    // SparseOpInterface:
+    std::vector<int> GetSparseOperands() { return {1}; }
   }];
 }

@@ -718,7 +739,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
     SameOperandsAndResultsScale,
     TFL_OperandHasAtleastRank<0, 1>,
     PredOpTrait<"params and output must have same element type",
-                TCresVTEtIsSameAsOp<0, 0>>
+                TFL_TCresVTEtIsSameAsOp<0, 0>>
   ]> {
   let summary = "Gather operator";

@@ -727,7 +748,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
   }];

   let arguments = (ins
-    TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8]>:$params,
+    TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$params,
     TensorOf<[I32, I64]>:$indices,
     I32Attr:$axis
   );
@@ -740,7 +761,7 @@ def TFL_GatherOp : TFL_Op<"gather", [
   ];

   let results = (outs
-    TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8]>:$output
+    TensorOf<[F32, I1, I8, I32, I64, TFL_Str, QI8, QUI8, QI16]>:$output
   );

   let hasOptions = 1;
@@ -1102,7 +1123,8 @@ def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> {
   let hasOptions = 0b1;
 }

-def TFL_ExpandDimsOp: TFL_Op<"expand_dims", [NoSideEffect]> {
+def TFL_ExpandDimsOp: TFL_Op<"expand_dims", [
+    NoSideEffect, SameOperandsAndResultsScale]> {
   let summary = "Inserts a dimension of 1 into a tensor's shape.";

   let description = [{
@@ -1694,7 +1716,8 @@ def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> {
   let customOption = "ReducerOptions";
 }

-def TFL_ReduceMinOp: TFL_Op<"reduce_min", [NoSideEffect]> {
+def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
+    NoSideEffect, SameOperandsAndResultsScale]> {
   let summary = "Min-reduction operator";

   let description = [{
@@ -1713,7 +1736,8 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [NoSideEffect]> {
   let customOption = "ReducerOptions";
 }

-def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [NoSideEffect]> {
+def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [
+    NoSideEffect, SameOperandsAndResultsScale]> {
   let summary = "Max-reduction operator";

   let description = [{
@@ -1810,6 +1834,8 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
   let results = (outs AnyTensor:$y);

   let hasOptions = 0b1;
+
+  let hasFolder = 1;
 }

 def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
@@ -1843,14 +1869,14 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
   }];

   let arguments = (ins
-    Variadic<TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>>:$values,
+    Variadic<TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>>:$values,

     I32Attr:$values_count,
     I32Attr:$axis
   );

   let results = (outs
-    TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8]>:$output
+    TensorOf<[F32, I8, I16, I32, I64, QI8, QUI8, QI16]>:$output
   );

   let verifier = [{ return Verify(*this); }];
@@ -2466,7 +2492,7 @@ def TFL_TransposeOp : TFL_Op<"transpose",
   let hasFolder = 1;
 }

-def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect]> {
+def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect, SameOperandsAndResultsScale]> {
   let summary = "Unpacks a tensor along a dimension into multiple tensors";

   let description = [{
@@ -2632,12 +2658,12 @@ def TFL_SplitOp : TFL_Op<"split", [

   let arguments = (ins
     TensorOf<[I32]>:$split_dim,
-    TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$value,
+    TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value,
     PositiveI32Attr:$num_splits
   );

   let results = (outs
-    Variadic<TensorOf<[F32, I16, I32, I64, QI8, QUI8]>>:$outputs
+    Variadic<TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>>:$outputs
   );

   let verifier = [{ return Verify(*this); }];
@@ -2655,14 +2681,14 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale]
   }];

   let arguments = (ins
-    TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$value,
+    TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>:$value,
     1DTensorOf<[I32]>:$size_splits,
     0DTensorOf<[I32]>:$split_dim,
     PositiveI32Attr:$num_splits
   );

   let results = (outs
-    Variadic<TensorOf<[F32, I16, I32, I64, QI8, QUI8]>>:$outputs
+    Variadic<TensorOf<[F32, I16, I32, I64, QI8, QUI8, QI16]>>:$outputs
   );

   let verifier = [{ return Verify(*this); }];
@@ -2793,12 +2819,11 @@ def TFL_CastOp : TFL_Op<"cast", [
     Casts input from input type to output type.
   }];

-  // TODO(b/135538711): Add complex types here.
   let arguments = (ins
-    TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8]>:$input
+    TensorOf<[F32, I1, I32, I64, TFL_Quint8, TFL_Uint8, Complex<F<32>>]>:$input
   );

-  let results = (outs TensorOf<[F32, I1, I32, I64]>:$output);
+  let results = (outs TensorOf<[F32, I1, I32, I64, Complex<F<32>>]>:$output);

   // TFLite's cast op does not utilize CastOptions, instead derives types
   // from the TfLiteTensors.
@@ -2898,7 +2923,9 @@ def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> {
   let arguments = (
     ins AnyTensor:$input,
     // The expected [min, max] range of values.
-    MinMaxAttr:$minmax,
+    F32Attr:$min,
+    F32Attr:$max,

     // The bitwidth of the quantization; between 2 and 16, inclusive.
     I32Attr:$num_bits,
     // Quantization range starts from 0 or 1; starts from 1 if true.
@@ -2907,6 +2934,8 @@ def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> {
   let results = (outs AnyTensor:$output);

+  let hasCanonicalizer = 0b1;
+
   let hasOptions = 1;
 }

 def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [
@@ -2936,6 +2965,36 @@ def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [
   ];
 }

+def TFL_SparseQConstOp : Op<TFL_Dialect, "pseudo_sparse_qconst", [
+    NoSideEffect, FirstAttrDerivedResultType, NoQuantizableResult]> {
+  let summary = "Sparse quantized constant pseudo op";
+
+  let description = [{
+    Represents a sparse quantized constant value in TensorFlow Lite dialect.
+    This is not an actual operation and it will be lowered to buffer instead.
+    The quantization parameters are stored as a type attribute in this constant.
+  }];
+
+  let arguments = (
+    ins TensorTypeAttr:$qtype,
+    ElementsAttr:$value,
+    SparsityParameterAttr:$s_param
+  );
+
+  let results = (outs AnyTensor:$output);
+
+  let builders = [OpBuilder<
+    "Builder *, OperationState &state, TypeAttr qtype, "
+    "Attribute value, SparsityParameterAttr s_param",
+    [{
+      state.addTypes(qtype.getValue());
+      state.addAttribute("qtype", qtype);
+      state.addAttribute("value", value);
+      state.addAttribute("s_param", s_param);
+    }]>
+  ];
+}
+
 def TFL_QuantizeOp: TFL_Op<"quantize", [
     FirstAttrDerivedResultType, NoQuantizableResult]> {
   let summary = "Quantize operator";
@@ -3396,4 +3455,48 @@ def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> {
   let results = (outs TensorOf<[F32, I32]>:$output);
 }

+def TFL_YieldOp : Op<TFL_Dialect, "yield", [Terminator]> {
+  let summary = "Yield operation";
+  let description = [{
+    The "yield" operation represents a return operation within the conditional
+    and body of structured control flow (e.g., while). The operation takes
+    variable number of operands and produces no results. The operand number and
+    types must match the signature of the region that contains the operation.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+}
+
+def TFL_WhileOp : Op<TFL_Dialect, "while", [
+    DeclareOpInterfaceMethods<LoopLikeOpInterface>,
+    SingleBlockImplicitTerminator<"YieldOp">]> {
+  let summary = [{While loop}];
+
+  let description = [{
+    output = input; while (cond(output)) { output = body(output) }
+
+    While loop where all values are passes through arguments with implicit
+    capture.
+
+    input: A list of input tensors whose types are T.
+    output: A list of output tensors whose types are T.
+    cond: A region takes 'input' and returns a boolean scalar tensor.
+    body: A region that takes a list of tensors and returns another
+      list of tensors. Both lists have the same types.
+  }];
+
+  let arguments = (ins
+    Variadic<AnyTensor>:$input,
+
+    // Used to map StatelessWhile and While op defined in TensorFlow to a common
+    // op.
+    DefaultValuedAttr<BoolAttr, "false">:$is_stateless
+  );
+  let results = (outs Variadic<AnyTensor>:$output);
+
+  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
+
+  let hasCanonicalizer = 1;
+}
+
 #endif // TFL_OPS
@@ -41,13 +41,20 @@ limitations under the License.
 #include "tensorflow/lite/delegates/flex/delegate.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/model.h"
+#include "tensorflow/lite/optional_debug_tools.h"

+using llvm::cl::desc;
+using llvm::cl::init;
 using llvm::cl::opt;

 // NOLINTNEXTLINE
-static opt<std::string> inputFileName(llvm::cl::Positional,
-                                      llvm::cl::desc("<input file>"),
-                                      llvm::cl::init("-"));
+static opt<std::string> input_filename(llvm::cl::Positional,
+                                       desc("<input file>"), init("-"));
+
+// NOLINTNEXTLINE
+static opt<bool> dump_state("dump-interpreter-state",
+                            desc("dump interpreter state post execution"),
+                            init(false));

 // TODO(jpienaar): Move these functions to some debug utils.
 static std::string TfLiteTensorDimString(const TfLiteTensor& tensor) {
@@ -82,9 +89,9 @@ int main(int argc, char** argv) {
   llvm::InitLLVM y(argc, argv);
   llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR TFLite runner\n");

-  auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(inputFileName.c_str());
+  auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(input_filename.c_str());
   if (std::error_code error = file_or_err.getError()) {
-    LOG(ERROR) << argv[0] << ": could not open input file '" << inputFileName
+    LOG(ERROR) << argv[0] << ": could not open input file '" << input_filename
                << "': " << error.message() << "\n";
     return 1;
   }
@@ -133,5 +140,7 @@ int main(int argc, char** argv) {
                TfLiteTensorString(out).c_str());
   }

+  if (dump_state) tflite::PrintInterpreterState(interpreter.get());
+
   return 0;
 }
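The runner above moves to snake_case flag names and short `using` aliases for llvm::cl. The flag-declaration pattern in self-contained form:

    #include <string>

    #include "llvm/Support/CommandLine.h"

    using llvm::cl::desc;
    using llvm::cl::init;
    using llvm::cl::opt;

    // A positional string option defaulting to stdin ("-") plus a boolean switch.
    static opt<std::string> input_filename(llvm::cl::Positional,
                                           desc("<input file>"), init("-"));
    static opt<bool> dump_state("dump-interpreter-state",
                                desc("dump interpreter state post execution"),
                                init(false));

    int main(int argc, char** argv) {
      llvm::cl::ParseCommandLineOptions(argc, argv, "flag demo\n");
      return dump_state ? 1 : 0;
    }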
@@ -122,7 +122,7 @@ static void EmitOptionBuilders(const RecordKeeper &record_keeper,
       os << formatv(
          "  auto {0} = Convert{1}ForOptionWriter(op.{0}(), fbb);\n",
          val.getName(), record->getClasses()[0]->getName());
-      options.push_back(val.getName());
+      options.push_back(std::string(val.getName()));
     }
   }
 }
@@ -71,18 +71,17 @@ cc_library(
         "quantization_utils.cc",
     ],
     hdrs = [
-        "quantization_traits.h",
         "quantization_utils.h",
     ],
     deps = [
-        "//tensorflow/core:lib_proto_parsing",
         "@com_google_absl//absl/memory",
         "@llvm-project//llvm:support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:QuantOps",
         "@llvm-project//mlir:StandardOps",
         "@llvm-project//mlir:Support",
         # TODO(fengliuai): remove this dependence.
         "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+        "//tensorflow/core:lib_proto_parsing",
     ],
 )
@@ -12,6 +12,7 @@ package_group(
     includes = ["//third_party/mlir:subpackages"],
     packages = [
         "//learning/brain/experimental/mlir/...",
+        "//tensorflow/compiler/mlir/lite/...",
         "//tensorflow/lite/...",
     ],
 )
@@ -41,6 +42,24 @@ cc_library(
     ],
 )

+cc_library(
+    name = "tfl_to_std",
+    srcs = [
+        "tfl_to_std.cc",
+    ],
+    hdrs = [
+        "tfl_to_std.h",
+    ],
+    deps = [
+        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:QuantOps",
+    ],
+)
+
 # Binary to apply quantization on the annotated files.
 tf_cc_binary(
     name = "tfl_quantizer",
tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc (new file)
@@ -0,0 +1,62 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
+
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+
+namespace mlir {
+namespace TFL {
+
+void ConvertTFLQuantOpsToMlirQuantOps(FuncOp func) {
+  OpBuilder b(func);
+  func.walk([&](Operation* op) {
+    b.setInsertionPoint(op);
+    if (auto dq = llvm::dyn_cast<DequantizeOp>(op)) {
+      auto dcast = b.create<quant::DequantizeCastOp>(
+          dq.getLoc(), dq.output().getType(), dq.input());
+      dq.output().replaceAllUsesWith(dcast);
+      dq.erase();
+    } else if (auto q = llvm::dyn_cast<QuantizeOp>(op)) {
+      auto qcast = b.create<quant::QuantizeCastOp>(
+          q.getLoc(), q.output().getType(), q.input());
+      q.output().replaceAllUsesWith(qcast);
+      q.erase();
+    }
+  });
+}
+
+void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func) {
+  OpBuilder b(func);
+  func.walk([&](Operation* op) {
+    b.setInsertionPoint(op);
+    if (auto dq = llvm::dyn_cast<quant::DequantizeCastOp>(op)) {
+      auto dcast = b.create<DequantizeOp>(dq.getLoc(), dq.getResult().getType(),
+                                          dq.arg());
+      dq.getResult().replaceAllUsesWith(dcast);
+      dq.erase();
+    } else if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(op)) {
+      auto out_type = q.getResult().getType();
+      auto qcast = b.create<QuantizeOp>(q.getLoc(), out_type, q.arg(),
+                                        TypeAttr::get(out_type));
+      q.getResult().replaceAllUsesWith(qcast);
+      q.erase();
+    }
+  });
+}
+
+}  // namespace TFL
+}  // namespace mlir
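A hedged sketch of the intended round trip for the helpers above: lower tfl.quantize/tfl.dequantize into quant.qcast/quant.dcast, run logic that only understands the mlir.quant dialect, then restore the TFL ops (RunQuantizePasses is a hypothetical stand-in for such logic):

    #include "mlir/IR/Function.h"
    #include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"

    void QuantizeWithStdOps(mlir::FuncOp func) {
      mlir::TFL::ConvertTFLQuantOpsToMlirQuantOps(func);
      // RunQuantizePasses(func);  // hypothetical dialect-agnostic quantization
      mlir::TFL::ConvertMlirQuantOpsToTFLQuantOps(func);
    }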
34 tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h (new file)
@@ -0,0 +1,34 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
+
+#include "mlir/IR/Function.h"  // TF:llvm-project
+
+namespace mlir {
+namespace TFL {
+
+// Converts all the tfl.quantize/tfl.dequantize ops to the ops in the mlir.quant
+// dialect ones in the function.
+void ConvertTFLQuantOpsToMlirQuantOps(FuncOp func);
+
+// Converts all the mlir.quant dialect ops to the tfl.quantize/tfl.dequantize
+// ops in the function.
+void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func);
+
+}  // namespace TFL
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_
@@ -22,21 +22,6 @@ limitations under the License.
 include "mlir/IR/OpBase.td"
 include "mlir/Dialect/QuantOps/QuantPredicates.td"

-
-//===----------------------------------------------------------------------===//
-// Min-max range pair definitions.
-//===----------------------------------------------------------------------===//
-
-// A pair of floating point values which defines the min and max of a value
-// range for quantization. The attribute is allowed to be empty or
-// have 2 elements.
-def MinMaxAttr : Attr<Or<[CPred<"$_self.cast<ArrayAttr>().size() == 0">,
-                          CPred<"$_self.cast<ArrayAttr>().size() == 2">]>,
-    "min-max range pair"> {
-  let storageType = [{ ArrayAttr }];
-  let returnType = [{ ArrayRef<Attribute> }];
-}
-
 //===----------------------------------------------------------------------===//
 // QuantizedType definitions.
 //===----------------------------------------------------------------------===//
@ -23,6 +23,8 @@ limitations under the License.
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
@ -34,13 +36,14 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
#define DEBUG_TYPE "quantization-driver"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
namespace quant {
|
||||
namespace {
|
||||
static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); }
|
||||
|
||||
@ -281,6 +284,37 @@ class QuantizationDriver {
|
||||
cached.first->second = InitializeState(op, index, res, /*as_result=*/true);
|
||||
}
|
||||
|
||||
void DumpStates(Operation *current_op) {
|
||||
if (current_op) {
|
||||
llvm::errs() << "\n\n\n" << current_op->getName() << "\n";
|
||||
}
|
||||
fn_.walk([&](Operation *op) {
|
||||
if (llvm::isa<quant::QuantizeCastOp>(op) ||
|
||||
llvm::isa<quant::DequantizeCastOp>(op) || llvm::isa<ConstantOp>(op))
|
||||
return;
|
||||
if (current_op == op) llvm::errs() << "===>>>";
|
||||
llvm::errs() << op->getName() << " : (";
|
||||
for (auto i = 0; i < op->getNumOperands(); ++i) {
|
||||
if (auto params = GetOperandQuantState(op, i).params)
|
||||
params.print(llvm::errs());
|
||||
else
|
||||
op->getOperand(i).getType().cast<ShapedType>().getElementType().print(
|
||||
llvm::errs());
|
||||
llvm::errs() << ",";
|
||||
}
|
||||
llvm::errs() << ") -> (";
|
||||
for (auto i = 0; i < op->getNumResults(); ++i) {
|
||||
if (auto params = GetResultQuantState(op, i).params)
|
||||
params.print(llvm::errs());
|
||||
else
|
||||
op->getResult(i).getType().cast<ShapedType>().getElementType().print(
|
||||
llvm::errs());
|
||||
llvm::errs() << ",";
|
||||
}
|
||||
llvm::errs() << ")\n";
|
||||
});
|
||||
}
|
||||
|
||||
FuncOp fn_;
|
||||
OpBuilder builder_;
|
||||
bool is_signed_;
|
||||
@ -350,7 +384,7 @@ int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
|
||||
}
|
||||
|
||||
bool QuantizationDriver::SetConstantResultParams(Operation *op) {
|
||||
ElementsAttr attr;
|
||||
DenseFPElementsAttr attr;
|
||||
Value res = op->getResult(0);
|
||||
if (!matchPattern(res, m_Constant(&attr))) {
|
||||
return false;
|
||||
@ -457,11 +491,9 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
|
||||
// This value isn't an expressed type (float), skip.
|
||||
if (!new_type) return;
|
||||
|
||||
TypeAttr type_attr = TypeAttr::get(new_type);
|
||||
auto quantize =
|
||||
builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
|
||||
auto dequantize = builder_.create<TFL::DequantizeOp>(loc, expressed_type,
|
||||
quantize.output());
|
||||
auto quantize = builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
|
||||
auto dequantize = builder_.create<quant::DequantizeCastOp>(
|
||||
loc, expressed_type, quantize.getResult());
|
||||
// `original_result` has a use to `quantize`, so this will replace that use
|
||||
// by the result of `dequantize`. Remember to reset that use afterwards
|
||||
value.replaceAllUsesWith(dequantize);
|
||||
@ -475,7 +507,7 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
  Value value = op->getResult(index);
  if (state->pos == RequantizeState::ON_OUTPUT) {
    Operation *user = value.getUses().begin().getUser();
    if (llvm::isa<TFL::QuantizeOp>(user)) {
    if (llvm::isa<quant::QuantizeCastOp>(user)) {
      // The requantize op is inserted between `quantize` and `dequantize` ops.
      value = user->getResult(0);
      builder_.setInsertionPointAfter(user);
@ -490,8 +522,8 @@ void QuantizationDriver::RequantizeArg(BlockArgument arg,
  builder_.setInsertionPointToStart(arg.getOwner());
  if (value.hasOneUse()) {
    auto user = value.use_begin().getUser();
    if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
      value = q.output();
    if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
      value = q.getResult();
      builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
    }
  }
@ -518,9 +550,8 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
  // This value isn't an expressed type (float), skip.
  if (!new_type) return;

  TypeAttr type_attr = TypeAttr::get(new_type);
  auto requantize_op =
      builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
      builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
  value.replaceAllUsesWith(requantize_op);
  requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
}
@ -650,8 +681,8 @@ void QuantizationDriver::SetupAllStates() {
    // If the argument is quantized, it should only have one user.
    if (arg.hasOneUse()) {
      auto user = value.use_begin().getUser();
      if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
        value = q.output();
      if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
        value = q.getResult();
      }
    }
    InitializeArgState(arg, value, &value_to_state);
@ -659,7 +690,9 @@ void QuantizationDriver::SetupAllStates() {

  fn_.walk([&](Operation *op) {
    if (op->isKnownTerminator() ||
        op->hasTrait<OpTrait::quant::NoQuantizableResult>())
        op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
        llvm::isa<quant::DequantizeCastOp>(op) ||
        llvm::isa<quant::QuantizeCastOp>(op))
      return;
    work_list_.push_back(op);

@ -668,8 +701,8 @@ void QuantizationDriver::SetupAllStates() {
      if (auto *inst = operand.getDefiningOp()) {
        // If the operand comes from a tfl.dequantize op, we use the quantized
        // input of this tfl.dequantize op to set the state.
        if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
          operand = dq.input();
        if (auto dq = llvm::dyn_cast<quant::DequantizeCastOp>(inst)) {
          operand = dq.arg();
        }
      }
      InitializeOperandState(op, i, operand, &value_to_state);
@ -682,8 +715,8 @@ void QuantizationDriver::SetupAllStates() {
      // create the state and mark it immutable.
      if (result.hasOneUse()) {
        auto user = result.use_begin().getUser();
        if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
          result = q.output();
        if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(user)) {
          result = q.getResult();
        }
      }
      InitializeResultState(op, res, result, &value_to_state);
@ -713,6 +746,8 @@ bool QuantizationDriver::PropagateParams() {
    Operation *op = work_list_.back();
    work_list_.pop_back();

    LLVM_DEBUG(DumpStates(op));

    // This op has been quantized, so we should not consider it again.
    if (llvm::is_contained(quantized_, op)) continue;
    quantized_.insert(op);
@ -737,12 +772,23 @@ bool QuantizationDriver::PropagateParams() {
    }

    // Use the final state to set all the operands' parameters.
    for (int i = 0, e = op->getNumOperands(); i != e; ++i)
      changed |= SetOperandParams(op, i, params);
    for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
      if (auto type = op->getOperand(i).getType().dyn_cast<ShapedType>()) {
        // Without this check, it will accidentally propagate the quantization
        // information through the shared non-float tensors.
        if (type.getElementType().isa<FloatType>())
          changed |= SetOperandParams(op, i, params);
      }
    }

    // Use the final state to set all the results' parameters.
    for (int res = 0, e = op->getNumResults(); res != e; ++res)
      changed |= SetResultParams(op, res, params);
      if (auto type = op->getResult(res).getType().dyn_cast<ShapedType>()) {
        // Without this check, it will accidentally propagate the quantization
        // information through the shared non-float tensors.
        if (type.getElementType().isa<FloatType>())
          changed |= SetResultParams(op, res, params);
      }
    }

    // TODO(fengliuai): make the bit width configurable.
@ -821,5 +867,5 @@ void ApplyQuantizationParamsPropagation(
      .Run();
}

} // namespace TFL
} // namespace quant
} // namespace mlir

@ -70,7 +70,8 @@ class FixedResultUniformScale {
  QuantizedType GetResultQuantizedType(int index) {
    auto op = this->getOperation();
    auto result_type =
        op->getResult(index).getType().template cast<TensorType>();
        op->getResult(index).getType().template cast<ShapedType>();
    if (!result_type.getElementType().template isa<FloatType>()) return {};
    Builder builder(op->getContext());
    IntegerType storage_type = builder.getIntegerType(BitWidth);
    const double scale = static_cast<double>(ScaleMantissa) *

@ -30,10 +30,9 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/Support/LLVM.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"

namespace mlir {
namespace TFL {
namespace quant {

const float kNearZeroTolerance = 1.0e-6;

@ -400,7 +399,7 @@ static bool PreferResultScale(Operation* op) {
  for (auto operand : op->getOperands()) {
    if (auto operand_type = operand.getType().dyn_cast<ShapedType>()) {
      if (operand_type.getElementType().isa<FloatType>()) {
        if (float_operands++ > 1) return true;
        if (++float_operands > 1) return true;
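        // With the old post-increment form the comparison saw 0 for the first
        // float operand and 1 for the second, so the function only returned
        // true at a third float operand; pre-incrementing makes "more than one
        // float operand" trigger at the second, as the name suggests.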
      }
    }
  }
@ -460,7 +459,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
  }

  // Step 2: backward pass: For the ops skipped in the forward pass, propagate
  // its results scale backwards.
  // its results scale backwards as far as possible.
  func.walk([&](quant::StatisticsOp stats_op) {
    if (redundant_stats_ops.find(stats_op) == redundant_stats_ops.end()) {
      all_stats_ops.push_back(stats_op);
@ -472,8 +471,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
    all_stats_ops.pop_back();

    if (auto def = stats_op.arg().getDefiningOp()) {
      if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
          PreferResultScale(def)) {
      if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>()) {
        for (auto input : def->getOperands()) {
          if (auto next_stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(
                  input.getDefiningOp())) {
@ -496,5 +494,5 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
  // Returns false if the steps finish without errors.
  return false;
}
} // namespace TFL
} // namespace quant
} // namespace mlir

@ -38,7 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"

namespace mlir {
namespace TFL {
namespace quant {

using QuantParams = quant::QuantizedType;
using SignedInteger = std::pair<unsigned, unsigned>;  // bitwidth and sign
@ -113,8 +113,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {

    rewriter.setInsertionPointAfter(op);
    Type result_type = quant_type.castFromExpressedType(op.getType());
    auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg(),
                                TypeAttr::get(result_type));
    auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg());
    auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
    op.getResult().replaceAllUsesWith(dq);
    q.getOperation()->replaceUsesOfWith(dq, op.arg());
@ -168,9 +167,12 @@ struct QuantizationPattern : public RewritePattern {
      return matchFailure();
    }

    // If it is terminator or not quantizable, we shouldn't rewrite.
    // If it is a terminator, not quantizable, or any op from the mlir quant
    // ops dialect, we shouldn't rewrite.
    if (quantized_op->isKnownTerminator() ||
        quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
        quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
        llvm::isa<quant::QuantizeCastOp>(quantized_op) ||
        llvm::isa<quant::DequantizeCastOp>(quantized_op)) {
      return matchFailure();
    }

@ -316,7 +318,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {

  PatternMatchResult matchAndRewrite(Q op,
                                     PatternRewriter& rewriter) const override {
    Type output_type = op.output().getType();
    Type output_type = op.getResult().getType();
    auto qtype = QType::getQuantizedElementType(output_type);
    if (!qtype || qtype.isSigned()) return this->matchFailure();

@ -355,8 +357,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
    if (!new_qtype) return this->matchFailure();
    Type new_output_type = new_qtype.castFromExpressedType(
        QType::castToExpressedType(output_type));
    rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.input(),
                                   TypeAttr::get(new_output_type));
    rewriter.replaceOpWithNewOp<Q>(op, new_output_type, op.arg());
    return this->matchSuccess();
  }
};
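// For example, this pattern turns a !quant.uniform<u8:f32, s:z> type into
// !quant.uniform<i8:f32, s:z-128>: the storage range [0, 255] shifts to
// [-128, 127] while the represented real values stay the same.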
@ -444,7 +445,7 @@ void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed,
bool RemoveRedundantStatsOps(mlir::FuncOp func,
                             OpQuantSpecGetter op_quant_spec_getter);

} // namespace TFL
} // namespace quant
} // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_

tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD (new file, 36 lines)
@ -0,0 +1,36 @@
package(
    default_visibility = [
        ":friends",
    ],
    licenses = ["notice"],  # Apache 2.0
)

package_group(
    name = "friends",
    includes = ["//third_party/mlir:subpackages"],
    packages = [
        "//tensorflow/compiler/mlir/...",
        "//tensorflow/compiler/mlir/lite/...",
    ],
)

cc_library(
    name = "tf_to_quant",
    srcs = [
        "tf_to_quant.cc",
    ],
    hdrs = [
        "passes.h",
    ],
    deps = [
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
    ],
    alwayslink = 1,
)

tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h (new file, 32 lines; path inferred from the include guard below)
@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_

#include <memory>

#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project

namespace mlir {
namespace TF {

// Legalize the tf ops to the quant ops, so the quantization passes can work.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass();

} // namespace TF
} // namespace mlir
#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_

@ -0,0 +1,19 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")

package(licenses = ["notice"])

glob_lit_tests(
    data = [":test_utilities"],
    driver = "@llvm-project//mlir:run_lit.sh",
    test_file_exts = ["mlir"],
)

# Bundle together all of the test utilities that are used by tests.
filegroup(
    name = "test_utilities",
    testonly = True,
    data = [
        "//tensorflow/compiler/mlir:tf-opt",
        "@llvm-project//llvm:FileCheck",
    ],
)

@ -0,0 +1,148 @@
// RUN: tf-opt -tf-to-quant %s | FileCheck %s

// CHECK-LABEL: fakeQuantPerChannelForActivation
func @fakeQuantPerChannelForActivation(%arg0: tensor<8x3xf32>) -> (tensor<8x3xf32>) {
  %arg1 = constant dense<[0.0, -1.0, 1.0]> : tensor<3xf32>
  %arg2 = constant dense<[255.0, 254.0, 256.0]> : tensor<3xf32>
  %0 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<8x3xf32>
  return %0 : tensor<8x3xf32>

// CHECK: %[[fq:.*]] = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg0, %cst, %cst_0)
// CHECK: %[[q:.*]] = "quant.qcast"(%[[fq]]) : (tensor<8x3xf32>) -> tensor<8x3x!quant.uniform<i8:f32:1, {1.000000e+00:-128,1.000000e+00:-127,1.000000e+00:-128}>>
// CHECK: %[[dq:.*]] = "quant.dcast"(%[[q]])
// CHECK: return %[[dq]]
}

// CHECK-LABEL: fakeQuantForActivation
func @fakeQuantForActivation(tensor<8xf32>) -> (tensor<8xf32>) {
^bb0(%arg0: tensor<8xf32>):
  %arg1 = constant dense<0.0> : tensor<f32>
  %arg2 = constant dense<255.0> : tensor<f32>
  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
  return %0 : tensor<8xf32>

// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0)
// CHECK: %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %2 = "quant.dcast"(%1)
// CHECK: return %2
}

// CHECK-LABEL: fakeQuantForActivationNoDuplication
func @fakeQuantForActivationNoDuplication(tensor<8xf32>) -> (tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>) {
^bb0(%arg0: tensor<8xf32>):
  %arg1 = constant dense<0.0> : tensor<f32>
  %arg2 = constant dense<255.0> : tensor<f32>
  %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
  %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
  return %1 : tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>

// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64}
// CHECK: %1 = "quant.qcast"(%0) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: return %1
}

// CHECK-LABEL: fakeQuantFolded
func @fakeQuantFolded() -> (tensor<8xf32>) {
  %in = constant dense<0.0> : tensor<8xf32>
  %min = constant dense<0.0> : tensor<f32>
  %max = constant dense<255.0> : tensor<f32>
  %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
  %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
  %rst = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
  return %rst : tensor<8xf32>

// CHECK: %[[CONSTANT:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT]]) : (tensor<8xf32>) -> tensor<8x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: return %[[DEQUANTIZE]] : tensor<8xf32>
}

// CHECK-LABEL: fakeQuantNotFolded
func @fakeQuantNotFolded(tensor<8xf32>, tensor<f32>, tensor<f32>) -> (tensor<8xf32>) {
^bb0(%arg0: tensor<8xf32>, %arg3: tensor<f32>, %arg4: tensor<f32>):
  %1 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg3, %arg4) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor<f32>, tensor<f32>) -> tensor<8xf32>
  return %1 : tensor<8xf32>

// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2)
// CHECK: return %0 : tensor<8xf32>
}

// CHECK-LABEL: fakeQuantWithConv2D
func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
  %min = constant dense<0.0> : tensor<f32>
  %max = constant dense<255.0> : tensor<f32>
  %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
  %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
  %fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
  return %rst : tensor<256x30x30x16xf32>

// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]]
}

// CHECK-LABEL: perChannelFakeQuantWithConv2D
func @perChannelFakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
  %min = constant dense<0.0> : tensor<16xf32>
  %max = constant dense<255.0> : tensor<16xf32>
  %mini = "tf.Identity"(%min) : (tensor<16xf32>) -> tensor<16xf32>
  %maxi = "tf.Identity"(%max) : (tensor<16xf32>) -> tensor<16xf32>
  %fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<3x3x3x16xf32>
  %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
  return %rst : tensor<256x30x30x16xf32>

// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
// CHECK-SAME: {1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,
// CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]] : tensor<256x30x30x16xf32>
}

// CHECK-LABEL: fakeQuantWithDepthwiseConv2D
func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
  %min = constant dense<0.0> : tensor<f32>
  %max = constant dense<255.0> : tensor<f32>
  %mini = "tf.Identity"(%min) : (tensor<f32>) -> tensor<f32>
  %maxi = "tf.Identity"(%max) : (tensor<f32>) -> tensor<f32>
  %fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<f32>, tensor<f32>) -> tensor<3x3x3x16xf32>
  %rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
  return %rst : tensor<256x30x30x16xf32>

// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32, 1.000000e+00:-128>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]]
}

// CHECK-LABEL: perChannelFakeQuantWithDepthwiseConv2D
func @perChannelFakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) {
^bb0(%arg: tensor<256x32x32x3xf32>) :
  %in = constant dense<0.0> : tensor<3x3x3x16xf32>
  %min = constant dense<0.0> : tensor<16xf32>
  %max = constant dense<255.0> : tensor<16xf32>
  %mini = "tf.Identity"(%min) : (tensor<16xf32>) -> tensor<16xf32>
  %maxi = "tf.Identity"(%max) : (tensor<16xf32>) -> tensor<16xf32>
  %fq = "tf.FakeQuantWithMinMaxVarsPerChannel"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<3x3x3x16xf32>
  %rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32>
  return %rst : tensor<256x30x30x16xf32>

// CHECK: %[[CONSTANT0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3x3x3x16xf32>}
// CHECK: %[[QUANTIZE:.*]] = "quant.qcast"(%[[CONSTANT0]]) : (tensor<3x3x3x16xf32>) -> tensor<3x3x3x16x!quant.uniform<i8:f32:3,
// CHECK-SAME: {1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,
// CHECK-SAME: 1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128,1.000000e+00:-128}>>
// CHECK: %[[DEQUANTIZE:.*]] = "quant.dcast"(%[[QUANTIZE]])
// CHECK: %[[CONV:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZE]])
// CHECK: return %[[CONV]]
}

@ -0,0 +1,162 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TF {

//===----------------------------------------------------------------------===//
// The pass to legalize the quantization emulation ops from TF.
//
namespace {

// Legalize TF quantization emulation ops to that in Quant ops dialect.
struct LegalizeTFToQuant : public FunctionPass<LegalizeTFToQuant> {
  explicit LegalizeTFToQuant() = default;
  LegalizeTFToQuant(const LegalizeTFToQuant &) {}

  /// Performs the lowering to Quant ops dialect.
  void runOnFunction() override;
};

// TODO(fengliuai): move this rule to PreparePatterns.td
// TODO(b/140968741): propagate the sign from the command line. Currently all
// the FakeQuant is assumed to be targeting UINT8, but the per-channel kernel
// is actually INT8.
// Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the
// "tf.FakeQuantWithMinMaxVarsOp" to be constant folded. Since the constant
// folding logic will use a "std.constant" op to replace the
// "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve
// the quantization parameters as a TypeAttr and the "tfl.dequantize" op is
// used to convert the output type to the next op. Here are the
// transformations:
//
//   input   min cst       max cst              input   min cst       max cst
//    \       |             |                    \       |             |
//     \  (tf.Identity) (tf.Identity)   =>       \  (tf.Identity) (tf.Identity)
//      \     |             |                     \     |             |
//        tf.FakeQuantWithMinMaxVars               tf.FakeQuantWithMinMaxVars
//                   |                                        |
//                                                       tf.quantize
//                                                            |
//                                                      tf.dequantize
//                                                            |
// If the input is a constant, the result pattern will eventually be
// converted to:
//
//            quant-emulated input
//                   |
//               tf.quantize
//                   |
//              tf.dequantize
//                   |
template <typename TFFakeQuantOp, bool PerAxis>
struct InsertQuantOpsAfterTFFakeQuantOp
    : public OpRewritePattern<TFFakeQuantOp> {
  using BaseType = InsertQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>;

  explicit InsertQuantOpsAfterTFFakeQuantOp<TFFakeQuantOp, PerAxis>(
      MLIRContext *ctx)
      : OpRewritePattern<TFFakeQuantOp>(ctx) {}

  PatternMatchResult matchAndRewrite(TFFakeQuantOp tf_op,
                                     PatternRewriter &rewriter) const override {
    // We don't want to insert quantize/dequantize if the quantize op exists.
    auto res = tf_op.outputs();
    if (!res.hasOneUse() || isa<quant::QuantizeCastOp>(*res.user_begin()))
      return this->matchFailure();

    // Extract the min/max constant values from the operands. We also consider
    // a special case that there are tf.Identity ops between the min/max
    // constants and the tf.FakeQuantWithMinMaxVarsOp.
    Value min = tf_op.min(), max = tf_op.max();
    DenseFPElementsAttr min_value, max_value;
    if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp())) {
      id1.replaceAllUsesWith(id1.input());
      min = tf_op.min();
      rewriter.eraseOp(id1);
    }
    if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp())) {
      id2.replaceAllUsesWith(id2.input());
      max = tf_op.max();
      rewriter.eraseOp(id2);
    }
    if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
    if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();

    int quant_dim = -1;
    if (PerAxis) {
      // This is a special case that the quant_dim is the last dimension
      // according to the tf.FakeQuantWithMinMaxPerChannel.
      quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
    }
    // Use the min/max from the operands and the num_bits and narrow_range
    // attribute to create the quantization parameter for the new quantize op.
    rewriter.setInsertionPointAfter(tf_op);
    IntegerAttr num_bits =
        rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue());
    BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
    Type res_type = tf_op.getType();
    TypeAttr qtype = quant::GetQuantizedTypeAttr(
        rewriter, res_type, min_value, max_value, quant_dim, num_bits,
        narrow_range, /*is_signed=*/true);
    if (!qtype) return this->matchFailure();

    // Finally, use the quantization parameter to create the quantize and
    // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp
    // and its users.
    Value value = tf_op.outputs();
    auto quantize = rewriter.create<quant::QuantizeCastOp>(
        tf_op.getLoc(), qtype.getValue(), value);
    auto dequantize = rewriter.create<quant::DequantizeCastOp>(
        tf_op.getLoc(), res_type, quantize.getResult());
    value.replaceAllUsesWith(dequantize);
    quantize.getOperation()->replaceUsesOfWith(dequantize, value);

    return this->matchSuccess();
  }
};

using PreparePerTensorFakeQuant =
    InsertQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsOp, false>;

using PreparePerChannelFakeQuant =
    InsertQuantOpsAfterTFFakeQuantOp<TF::FakeQuantWithMinMaxVarsPerChannelOp,
                                     true>;

// TODO(fengliuai): add the support of the tf.QuantizeAndDequantize*
// legalization.

void LegalizeTFToQuant::runOnFunction() {
  OwningRewritePatternList patterns;
  auto func = getFunction();
  auto *ctx = func.getContext();
  patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);
  applyPatternsGreedily(func, patterns);
}
} // namespace

// Creates an instance of the TensorFlow dialect to QuantOps dialect pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateLegalizeTFToQuantPass() {
  return std::make_unique<LegalizeTFToQuant>();
}

static PassRegistration<LegalizeTFToQuant> pass(
    "tf-to-quant", "Legalize TF to quant ops dialect");

} // namespace TF
} // namespace mlir

@ -46,9 +46,9 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
  std::vector<Record *> defs = records.getAllDerivedDefinitions("Op");
  llvm::sort(defs, LessRecord());

  OUT(0) << "static std::unique_ptr<OpQuantSpec> "
  OUT(0) << "static std::unique_ptr<quant::OpQuantSpec> "
            "GetOpQuantSpec(mlir::Operation *op) {\n";
  OUT(2) << "auto spec = absl::make_unique<OpQuantSpec>();\n";
  OUT(2) << "auto spec = absl::make_unique<quant::OpQuantSpec>();\n";
  llvm::SmallVector<llvm::StringRef, 3> matches;
  for (auto *def : defs) {
    Operator op(def);
@ -74,7 +74,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) {
      if (acc_uniform_trait_regex.match(trait_str, &matches)) {
        OUT(4) << "spec->biases_params.emplace(std::make_pair(" << matches[1]
               << ", std::make_pair(tfl.GetAllNonBiasOperands(),"
               << "GetUniformQuantizedTypeForBias)));\n";
               << "quant::GetUniformQuantizedTypeForBias)));\n";
        matches.clear();
      }
      // There is a "QuantChannelDim" trait, set the quantization dimension.

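(The hand-written op_quant_spec.inc added for the xla directory further below follows the same shape as the code this TableGen writer emits: a GetOpQuantSpec(mlir::Operation *) factory returning a quant::OpQuantSpec.)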
tensorflow/compiler/mlir/lite/quantization/xla/BUILD (new file, 40 lines)
@ -0,0 +1,40 @@
package(
    default_visibility = [
        ":friends",
    ],
    licenses = ["notice"],  # Apache 2.0
)

package_group(
    name = "friends",
    includes = ["//third_party/mlir:subpackages"],
    packages = [
        "//tensorflow/compiler/mlir/...",
        "//tensorflow/compiler/mlir/lite/...",
    ],
)

cc_library(
    name = "hlo_xla_quantization_passes",
    srcs = [
        "materialize.cc",
        "op_quant_spec.inc",
        "propagate.cc",
    ],
    hdrs = [
        "passes.h",
    ],
    deps = [
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/xla:hlo",
        "//tensorflow/compiler/xla/client/lib:quantize",
        "@com_google_absl//absl/memory",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",
    ],
    alwayslink = 1,
)

tensorflow/compiler/mlir/lite/quantization/xla/materialize.cc (new file, 174 lines)
@ -0,0 +1,174 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass quantizes the constants and rewrites the
// quantization ops with xla_hlo primitive ops.
#include <cstdint>
#include <iterator>
#include <numeric>
#include <string>

#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"

//===----------------------------------------------------------------------===//
// The pass to materialize the quantization results by xla primitive ops.
//
namespace mlir {
namespace xla_hlo {

namespace {

// This pattern matches the "constant->qcast->dcast" pattern and replaces it
// with "quantized constant->xla_hlo.dequantize". If it only matches the
// "non-constant->qcast->dcast" pattern, it will remove both the "qcast" and
// the "dcast". We chain the pattern as a whole to bypass the type checks of
// the normal xla_hlo ops.
// TODO(fengliuai): make this pass work for bf16 input.
class RewriteDequantize : public OpRewritePattern<quant::DequantizeCastOp> {
 public:
  explicit RewriteDequantize(int64_t size, MLIRContext *context)
      : OpRewritePattern<quant::DequantizeCastOp>(context), size_(size) {}

  PatternMatchResult matchAndRewrite(quant::DequantizeCastOp op,
                                     PatternRewriter &rewriter) const override {
    // quant.dcast
    // xla_hlo dequantize only takes min/max, so let's recover them from
    // the quantization parameters.
    Value dcast = op.arg();
    auto type = quant::QuantizedType::getQuantizedElementType(dcast.getType());
    if (!type || !type.isa<quant::UniformQuantizedType>()) {
      return matchFailure();
    }
    auto qtype = type.cast<quant::UniformQuantizedType>();
    double scale = qtype.getScale();
    int64_t zero_point = qtype.getZeroPoint();
    float min = scale * (qtype.getStorageTypeMin() - zero_point);
    float max = scale * (qtype.getStorageTypeMax() - zero_point);
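    // For example (cf. the materialize test expectations below): with u8
    // storage range [0, 255], scale 2/255 ≈ 0.0078431373 and zero point 128,
    // this recovers min ≈ -1.00392163 and max ≈ 0.996078431.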

    // quant.qcast
    auto qcast =
        llvm::dyn_cast_or_null<quant::QuantizeCastOp>(dcast.getDefiningOp());
    if (!qcast) return matchFailure();

    // constant
    DenseFPElementsAttr attr;
    // If it isn't a floating-point constant or the size is too small, let's
    // remove the quantization. Also the last dimension size should be a
    // multiple of 4, so the shape isn't broken during packing and unpacking.
    if (!matchPattern(qcast.arg(), m_Constant(&attr)) ||
        attr.getNumElements() <= size_ ||
        attr.getType().getDimSize(attr.getType().getRank() - 1) % 4 != 0) {
      op.getResult().replaceAllUsesWith(qcast.arg());
      return matchSuccess();
    }
    // TODO(fengliuai): implement transpose if it has high dimension.

    // Create the quantized result
    auto quantized_result =
        quant::Quantize(attr, qtype).dyn_cast_or_null<DenseIntElementsAttr>();
    if (!quantized_result) {
      return matchFailure();
    }

    // Pack the uint8 bits to uint32. The shape is changed from
    // [n0, n1, ..., nk] to [n0, n1, ..., nk / 4].
    std::vector<uint8_t> raw_data;
    for (auto d : quantized_result.getValues<uint8_t>()) {
      raw_data.push_back(d);
    }
    // The packing might increase the data size by padding.
    auto packed_data = xla::PackToUint32<uint8_t>(raw_data);
    auto packed_shape = attr.getType().getShape().vec();
    int lower_dims = std::accumulate(
        packed_shape.begin(),
        std::next(packed_shape.begin(), packed_shape.size() - 1), 1,
        std::multiplies<int>());
    packed_shape[packed_shape.size() - 1] = packed_data.size() / lower_dims;
    auto packed_type =
        RankedTensorType::get(packed_shape, rewriter.getIntegerType(32));

    auto packed_quantized_result =
        DenseElementsAttr::get<uint32_t>(packed_type, packed_data);
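    // For example, a [2, 4] tensor of uint8 packs into a [2, 1] tensor of
    // uint32, which is the tensor<2x1xi32> constant expected by the
    // quantize_rewrite case in materialize.mlir below.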
    auto quantized_constant =
        rewriter.create<ConstantOp>(qcast.getLoc(), packed_quantized_result);

    // Create the xla dequantize op with bf16 output
    auto dequantized_type = RankedTensorType::get(attr.getType().getShape(),
                                                  rewriter.getBF16Type());
    auto dequantize = rewriter.create<DequantizeOp>(
        qcast.getLoc(), dequantized_type, quantized_constant,
        rewriter.getF32FloatAttr(min), rewriter.getF32FloatAttr(max),
        rewriter.getStringAttr("MIN_COMBINED"), rewriter.getBoolAttr(false),
        rewriter.getBoolAttr(false));

    // Convert bf16 output back to f32
    rewriter.replaceOpWithNewOp<ConvertOp>(op, op.getResult().getType(),
                                           dequantize);
    return matchSuccess();
  }

 private:
  int64_t size_;
};

// Materialize the quantization results by hlo primitive ops.
struct MaterializeToXlaPass : public FunctionPass<MaterializeToXlaPass> {
  explicit MaterializeToXlaPass() = default;
  MaterializeToXlaPass(const MaterializeToXlaPass &) {}

  void runOnFunction() override;
};

void MaterializeToXlaPass::runOnFunction() {
  FuncOp func = getFunction();
  MLIRContext *ctx = &getContext();

  OwningRewritePatternList patterns;
  // TODO(fengliuai): make the size 6 configurable.
  patterns.insert<RewriteDequantize>(6, ctx);

  applyPatternsGreedily(func, patterns);
}

} // namespace

// Creates an instance of the xla_hlo dialect quantization materialization
// pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializeToXlaPass() {
  return std::make_unique<MaterializeToXlaPass>();
}

static PassRegistration<MaterializeToXlaPass> pass(
    "xla-hlo-materialize-quant",
    "Materialize the quantization results by xla primitive ops");

} // namespace xla_hlo
} // namespace mlir

tensorflow/compiler/mlir/lite/quantization/xla/op_quant_spec.inc (new file, 7 lines; path inferred from the BUILD srcs above and the #include in propagate.cc)
@ -0,0 +1,7 @@
// TODO(fengliuai): automatically generate this file
// TODO(fengliuai): add all the xla_hlo ops

static std::unique_ptr<quant::OpQuantSpec> GetOpQuantSpec(mlir::Operation *op) {
  auto spec = absl::make_unique<quant::OpQuantSpec>();
  return spec;
}

tensorflow/compiler/mlir/lite/quantization/xla/passes.h (new file, 37 lines)
@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_

#include <memory>

#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project

namespace mlir {
namespace xla_hlo {

// Propagate the quantization information to all the tensors according to the
// op quant spec.
std::unique_ptr<OpPassBase<FuncOp>> CreatePropagateQuantPass();

// Rewrite the graph and quantize the constant.
std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializeToXlaPass();

} // namespace xla_hlo
} // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_XLA_PASSES_H_

tensorflow/compiler/mlir/lite/quantization/xla/propagate.cc (new file, 78 lines)
@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass applies quantization propagation on the xla_hlo
// dialect.
#include <iterator>
#include <string>

#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"

// NOLINTNEXTLINE
static llvm::cl::opt<bool> disable_per_channel(
    "xla-disable-per-channel", llvm::cl::value_desc("bool"),
    llvm::cl::desc("Whether to disable per-channel quantized weights."),
    llvm::cl::init(false));

//===----------------------------------------------------------------------===//
// The quantization propagation Pass.
//
namespace mlir {
namespace xla_hlo {

namespace {

// Applies the quantization propagation on the input function. During the
// propagation, two facts are respected:
// - The quantization type (params) of the ops in the function
// - The quantization spec for the ops
// The propagation results should assign quantization types to all the tensors
// such that both facts above are respected.
struct PropagateQuantPass : public FunctionPass<PropagateQuantPass> {
  explicit PropagateQuantPass() = default;
  PropagateQuantPass(const PropagateQuantPass &) {}

  void runOnFunction() override;
};

#include "tensorflow/compiler/mlir/lite/quantization/xla/op_quant_spec.inc"

void PropagateQuantPass::runOnFunction() {
  FuncOp func = getFunction();
  // XLA only supports uint8/uint16 quantization for now.
  ApplyQuantizationParamsPropagation(func, /*is_signed*/ false,
                                     disable_per_channel, GetOpQuantSpec);
}

} // namespace

// Creates an instance of the xla_hlo dialect quantization propagation pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePropagateQuantPass() {
  return std::make_unique<PropagateQuantPass>();
}

static PassRegistration<PropagateQuantPass> pass(
    "xla-hlo-propagate-quant", "Propagate quantization information");

} // namespace xla_hlo
} // namespace mlir

tensorflow/compiler/mlir/lite/quantization/xla/tests/BUILD (new file, 19 lines)
@ -0,0 +1,19 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")

package(licenses = ["notice"])

glob_lit_tests(
    data = [":test_utilities"],
    driver = "@llvm-project//mlir:run_lit.sh",
    test_file_exts = ["mlir"],
)

# Bundle together all of the test utilities that are used by tests.
filegroup(
    name = "test_utilities",
    testonly = True,
    data = [
        "//tensorflow/compiler/mlir:tf-opt",
        "@llvm-project//llvm:FileCheck",
    ],
)

@ -0,0 +1,54 @@
// RUN: tf-opt -xla-hlo-materialize-quant %s | FileCheck %s

// CHECK-LABEL: func @quantize_rewrite
func @quantize_rewrite(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK: %[[qcst:.*]] = constant dense<{{\[\[}}21004416], [-1056997248]]> : tensor<2x1xi32>
// CHECK-NEXT: %[[dq:.*]] = "xla_hlo.dequantize"(%[[qcst]]) {is_16bits = false, max_range = 0.996078431 : f32, min_range = -1.00392163 : f32,
// CHECK-SAME: mode = "MIN_COMBINED", transpose_output = false} : (tensor<2x1xi32>) -> tensor<2x4xbf16>
// CHECK-NEXT: %[[cast:.*]] = "xla_hlo.convert"(%[[dq]]) : (tensor<2x4xbf16>) -> tensor<2x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[cast]] : tensor<2x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>

  %w = constant dense<[[-1.0, -0.5, 0.0, 0.0], [0.5, 1.0, 0.0, 0.0]]> : tensor<2x4xf32>
  %q = "quant.qcast"(%w) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
  %dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
  %mul = xla_hlo.mul %arg0, %dq : tensor<2x4xf32>
  return %mul: tensor<2x4xf32>
}

// CHECK-LABEL: func @quantize_small
func @quantize_small(%arg0: tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<1x4xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[w]] : tensor<1x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<1x4xf32>

  %w = constant dense<1.0> : tensor<1x4xf32>
  %q = "quant.qcast"(%w) : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
  %dq = "quant.dcast"(%q) : (tensor<1x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<1x4xf32>
  %mul = xla_hlo.mul %arg0, %dq : tensor<1x4xf32>
  return %mul: tensor<1x4xf32>
}

// CHECK-LABEL: func @quantize_non_cst
func @quantize_non_cst(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> {
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %arg0 : tensor<2x4xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x4xf32>

  %q = "quant.qcast"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
  %dq = "quant.dcast"(%q) : (tensor<2x4x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x4xf32>
  %mul = xla_hlo.mul %arg0, %dq : tensor<2x4xf32>
  return %mul: tensor<2x4xf32>
}

// CHECK-LABEL: func @quantize_non_4x
func @quantize_non_4x(%arg0: tensor<2x5xf32>) -> tensor<2x5xf32> {
// CHECK: %[[w:.*]] = constant dense<1.000000e+00> : tensor<2x5xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[w]] : tensor<2x5xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x5xf32>

  %w = constant dense<1.0> : tensor<2x5xf32>
  %q = "quant.qcast"(%w) : (tensor<2x5xf32>) -> tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
  %dq = "quant.dcast"(%q) : (tensor<2x5x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x5xf32>
  %mul = xla_hlo.mul %arg0, %dq : tensor<2x5xf32>
  return %mul: tensor<2x5xf32>
}

@ -0,0 +1,25 @@
// RUN: tf-opt -xla-hlo-propagate-quant %s | FileCheck %s

// CHECK-LABEL: func @mul
func @mul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[w:.*]] = constant dense<{{\[\[}}-1.000000e+00, -5.000000e-01], [5.000000e-01, 1.000000e+00]]> : tensor<2x2xf32>
// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[w]]) : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>
// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x2x!quant.uniform<u8:f32, 0.0078431372549019607:128>>) -> tensor<2x2xf32>
// CHECK-NEXT: %[[mul:.*]] = xla_hlo.mul %arg0, %[[dq]] : tensor<2x2xf32>
// CHECK-NEXT: return %[[mul]] : tensor<2x2xf32>
  %w = constant dense<[[-1.0, -0.5], [0.5, 1.0]]> : tensor<2x2xf32>
  %mul = xla_hlo.mul %arg0, %w : tensor<2x2xf32>
  return %mul: tensor<2x2xf32>
}

// CHECK-LABEL: func @add
func @add(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[b:.*]] = constant dense<1.000000e+00> : tensor<2xf32>
// CHECK-NEXT: %[[q:.*]] = "quant.qcast"(%[[b]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform<u8:f32, 0.0039215686274509803>>
// CHECK-NEXT: %[[dq:.*]] = "quant.dcast"(%[[q]]) : (tensor<2x!quant.uniform<u8:f32, 0.0039215686274509803>>) -> tensor<2xf32>
// CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg0, %[[dq]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %[[add]] : tensor<2x2xf32>
  %b = constant dense<1.0> : tensor<2xf32>
  %add = "xla_hlo.add"(%arg0, %b) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2x2xf32>
  return %add: tensor<2x2xf32>
}

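// Note on the expected scales above: the weights span roughly [-1.0, 1.0], so
// the unsigned 8-bit scale is 2/255 ≈ 0.0078431372549 with zero point 128,
// while the bias constant covers [0.0, 1.0], giving 1/255 ≈ 0.0039215686275
// with zero point 0.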
tensorflow/compiler/mlir/lite/sparsity/BUILD (new file, 39 lines)
@ -0,0 +1,39 @@
package(
    default_visibility = [
        ":friends",
    ],
    licenses = ["notice"],  # Apache 2.0
)

package_group(
    name = "friends",
    includes = ["//third_party/mlir:subpackages"],
    packages = [
        "//learning/brain/experimental/mlir/...",
        "//tensorflow/lite/...",
    ],
)

cc_library(
    name = "sparsify_model",
    srcs = [
        "sparsify_model.cc",
    ],
    hdrs = [
        "sparsify_model.h",
    ],
    deps = [
        "//tensorflow/compiler/mlir/lite:common",
        "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
        "//tensorflow/compiler/mlir/tensorflow:error_util",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/lite:framework",
        "//tensorflow/lite/core/api",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
    ],
)

tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc (new file, 84 lines)
@ -0,0 +1,84 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h"

#include "absl/strings/string_view.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Location.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Module.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassManager.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/types.pb.h"

namespace mlir {
namespace lite {

TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
                           flatbuffers::FlatBufferBuilder* builder,
                           tflite::ErrorReporter* error_reporter) {
  MLIRContext context;
  StatusScopedDiagnosticHandler statusHandler(&context,
                                              /*propagate=*/true);

  // Import input_model to an MLIR module
  flatbuffers::FlatBufferBuilder input_builder;
  flatbuffers::Offset<tflite::Model> input_model_location =
      tflite::Model::Pack(input_builder, &input_model);
  tflite::FinishModelBuffer(input_builder, input_model_location);

  std::string serialized_model(
      reinterpret_cast<const char*>(input_builder.GetBufferPointer()),
      input_builder.GetSize());
  std::vector<std::string> output_arrays_order;

  OwningModuleRef module =
      tflite::FlatBufferToMlir(serialized_model, &context,
                               UnknownLoc::get(&context), output_arrays_order);
  if (!module) {
    error_reporter->Report("Couldn't import flatbuffer to MLIR.");
    return kTfLiteError;
  }

  PassManager pm(module->getContext());
|
||||
if (failed(pm.run(module.get()))) {
|
||||
const std::string& err = statusHandler.ConsumeStatus().error_message();
|
||||
error_reporter->Report("Failed to sparsify: %s", err.c_str());
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Export the results to the builder
|
||||
std::string result;
|
||||
if (tflite::MlirToFlatBufferTranslateFunction(
|
||||
module.get(), &result, /*emit_builtin_tflite_ops=*/true,
|
||||
/*emit_select_tf_ops=*/true, /*emit_custom_ops=*/true)) {
|
||||
error_reporter->Report("Failed to export MLIR to flatbuffer.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
builder->PushFlatBuffer(reinterpret_cast<const uint8_t*>(result.data()),
|
||||
result.size());
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mlir
|
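As noted in the comment above, the pass manager runs empty, so this entry point is plumbing only. A hedged sketch of how the actual sparsification pass would presumably be wired in once it exists; `CreateDenseToSparsePass` is a placeholder name for illustration, not an API present in this change:

  // Hypothetical wiring: register the sparsification pass before pm.run().
  // CreateDenseToSparsePass() is an assumed factory name, not real API here.
  PassManager pm(module->getContext());
  pm.addPass(mlir::TFL::CreateDenseToSparsePass());  // hypothetical
  if (failed(pm.run(module.get()))) return kTfLiteError;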
35  tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h  Normal file
@@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_SPARSITY_SPARSIFY_MODEL_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_SPARSITY_SPARSIFY_MODEL_H_

#include <memory>
#include <unordered_set>

#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace mlir {
namespace lite {

// Sparsifies `input_model` and writes the result to the flatbuffer `builder`.
TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
                           flatbuffers::FlatBufferBuilder* builder,
                           tflite::ErrorReporter* error_reporter);
}  // namespace lite
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_SPARSITY_SPARSIFY_MODEL_H_
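For orientation, a minimal sketch of how a caller might drive this entry point, assuming a model loaded through the standard TFLite C++ helpers; the surrounding setup (file path, reporter choice, helper name) is illustrative, not part of this change:

  #include "flatbuffers/flatbuffers.h"
  #include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h"
  #include "tensorflow/lite/model.h"
  #include "tensorflow/lite/stderr_reporter.h"

  // Illustrative helper: round-trips a .tflite file through SparsifyModel.
  TfLiteStatus SparsifyFile(const char* path,
                            flatbuffers::FlatBufferBuilder* builder) {
    auto fb_model = tflite::FlatBufferModel::BuildFromFile(path);
    if (!fb_model) return kTfLiteError;
    tflite::ModelT model;
    fb_model->GetModel()->UnPackTo(&model);  // flatbuffer -> mutable ModelT
    tflite::StderrReporter reporter;
    return mlir::lite::SparsifyModel(model, builder, &reporter);
  }

On success, `builder` holds the sparsified flatbuffer, which the caller can persist or hand to an interpreter.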
@@ -76,42 +76,6 @@ func @reshape_not_removeIdentity(%arg0: tensor<?xf32>, %arg1: tensor<3xi32>) ->
  // CHECK-NEXT: "tfl.reshape"
}

// Checks that tfl.fake_quant should be removed if all its users have valid
// "minmax" attributes.
func @fakequant_dropfakequant(tensor<i32>, f32, f32) -> tensor<i32> {
^bb0(%arg0: tensor<i32>, %arg1: f32, %arg2: f32):
  %0 = "tfl.fake_quant"(%arg0) {name = 0, minmax = [0.1, 0.2], num_bits = 4 : i32, narrow_range = false} : (tensor<i32>) -> tensor<i32>
  %1 = tfl.pow %arg0, %0 {minmax = [0.4, 0.6]} : tensor<i32>
  %2 = tfl.pow %1, %0 {minmax = [0.5, 0.7]} : tensor<i32>
  return %2 : tensor<i32>

// CHECK-LABEL: fakequant_dropfakequant
// CHECK-NEXT: %0 = tfl.pow %arg0, %arg0 {minmax = [4.000000e-01, 6.000000e-01]} : tensor<i32>
// CHECK-NEXT: %1 = tfl.pow %0, %arg0 {minmax = [5.000000e-01, 0.69999999999999996]} : tensor<i32>

// CHECK-NEXT: return %1 : tensor<i32>
}

// Checks that tfl.fake_quant should not be removed if some of its users or
// itself don't have valid "minmax" attributes.
func @fakequant_notdropfakequant(tensor<i32>, f32, f32) -> tensor<i32> {
^bb0(%arg0: tensor<i32>, %arg1: f32, %arg2: f32):
  %0 = "tfl.fake_quant"(%arg0) {name = 0, minmax = [], num_bits = 4 : i32, narrow_range = false} : (tensor<i32>) -> tensor<i32>
  %1 = tfl.pow %arg0, %0 : tensor<i32>
  %2 = tfl.pow %1, %0 : tensor<i32>

  %5 = "tfl.fake_quant"(%arg0) {name = 1, minmax = [0.1, 0.2], num_bits = 4 : i32, narrow_range = false} : (tensor<i32>) -> tensor<i32>
  %6 = tfl.pow %arg0, %5 : tensor<i32>
  %7 = tfl.pow %6, %5 : tensor<i32>

  %11 = addi %2, %7 : tensor<i32>
  return %11 : tensor<i32>

// CHECK-LABEL: fakequant_notdropfakequant
// CHECK: %0 = "tfl.fake_quant"(%arg0) {minmax = [], name = 0 : i64, narrow_range = false, num_bits = 4 : i32} : (tensor<i32>) -> tensor<i32>
// CHECK: %3 = "tfl.fake_quant"(%arg0) {minmax = [1.000000e-01, 2.000000e-01], name = 1 : i64, narrow_range = false, num_bits = 4 : i32} : (tensor<i32>) -> tensor<i32>
}

// -----

// CHECK-LABEL: @RemoveRedundantUnpackPack
231  tensorflow/compiler/mlir/lite/tests/dilated-conv.mlir  Normal file
@@ -0,0 +1,231 @@
// RUN: tf-opt %s -tfl-identify-dilated-conv | FileCheck %s --dump-input-on-failure

func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
  %cst = constant dense<[2, 2]> : tensor<2xi32>
  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
  %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
  %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
  return %2 : tensor<1x128x128x8xf32>

// CHECK-LABEL: testDilatedConv
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}
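A quick shape check explains why this rewrite is legal: SpaceToBatchND with block shape [2, 2] pads 128 up to 136 and splits it into 4 batches of 68; the 5x5 VALID convolution yields 68 - 5 + 1 = 64; BatchToSpaceND reassembles 2 x 64 = 128. A single 5x5 convolution with dilations [1, 2, 2, 1] computes the same result directly, which is exactly what the pass emits above.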

func @testDilatedConvWithNonZeroSTBPadding(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
  %cst = constant dense<[2, 2]> : tensor<2xi32>
  %cst_0 = constant dense<2> : tensor<2x2xi32>
  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
  %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
  %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
  return %2 : tensor<1x128x128x8xf32>

// CHECK-LABEL: testDilatedConvWithNonZeroSTBPadding
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
// CHECK-NEXT: [[RESULT:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}

func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
  %cst = constant dense<[2, 2]> : tensor<2xi32>
  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
  %1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
  %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
  return %2 : tensor<1x128x128x8xf32>

// CHECK-LABEL: testDilatedDepthWiseConv
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>)
// CHECK-NEXT: [[RESULT:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}

func @testDilatedConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
  %cst = constant dense<[2, 2]> : tensor<2xi32>
  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
  %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
  %2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32>
  %3 = "tf.BatchToSpaceND"(%2, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
  %4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
  return %4 : tensor<1x128x128x8xf32>

// CHECK-LABEL: testDilatedConvWithPad
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}

func @testDilatedDepthWiseConvWithPad(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
  %cst = constant dense<[2, 2]> : tensor<2xi32>
  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
  %1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
  %2 = "tf.Pad"(%1, %arg1) : (tensor<4x64x64x8xf32>, tensor<2x2xi32>) -> tensor<4x64x64x8xf32>
  %3 = "tf.BatchToSpaceND"(%2, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
  %4 = "tf.BiasAdd"(%3, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
  return %4 : tensor<1x128x128x8xf32>

// CHECK-LABEL: testDilatedDepthWiseConvWithPad
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}

func @testDilatedConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
  %cst = constant dense<[2, 2]> : tensor<2xi32>
  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
  %1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
  %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
  %3 = "tf.BiasAdd"(%2, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
  return %3 : tensor<1x128x128x8xf32>

// CHECK-LABEL: testDilatedConvWithBiasAdd
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}

func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>, %arg3: tensor<8xf32>) -> tensor<1x128x128x8xf32> {
  %cst = constant dense<[2, 2]> : tensor<2xi32>
  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
  %1 = "tf.DepthwiseConv2dNative"(%0, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
  %2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
  %3 = "tf.BiasAdd"(%2, %arg3) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
  return %3 : tensor<1x128x128x8xf32>

// CHECK-LABEL: testDilatedDepthWiseConvWithBiasAdd
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128x3xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x3x8xf32>, [[BIAS:%.*]]: tensor<8xf32>)
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[INPUT]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x3xf32>, tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[CONV]], [[BIAS]]) : (tensor<1x128x128x8xf32>, tensor<8xf32>) -> tensor<1x128x128x8xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
}

func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
  %cst = constant dense<[2, 2]> : tensor<2xi32>
  %cst_0 = constant dense<3> : tensor<i32>
  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
  %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
  %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
  %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
  %4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
  %5 = "tf.BiasAdd"(%4, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
  return %5 : tensor<1x128x128xf32>

// CHECK-LABEL: testDilatedConvWithExpandSqueeze1
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}

func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
  %cst = constant dense<[2, 2]> : tensor<2xi32>
  %cst_0 = constant dense<3> : tensor<i32>
  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
  %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
  %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
  %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
  %4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
  %5 = "tf.BiasAdd"(%4, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
  return %5 : tensor<1x128x128xf32>

// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze1
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}

func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
  %cst = constant dense<[2, 2]> : tensor<2xi32>
  %cst_0 = constant dense<3> : tensor<i32>
  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
  %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
  %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
  %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x?x?x1xf32>) -> tensor<4x?x?xf32>
  %4 = "tf.BiasAdd"(%3, %arg3) : (tensor<4x?x?xf32>, tensor<?xf32>) -> tensor<4x?x?xf32>
  %5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
  return %5 : tensor<1x128x128xf32>

// CHECK-LABEL: testDilatedConvWithExpandSqueeze2
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<?xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}

func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
  %cst = constant dense<[2, 2]> : tensor<2xi32>
  %cst_0 = constant dense<3> : tensor<i32>
  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
  %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
  %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
  %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x?x?x1xf32>) -> tensor<4x?x?xf32>
  %4 = "tf.BiasAdd"(%3, %arg3) : (tensor<4x?x?xf32>, tensor<?xf32>) -> tensor<4x?x?xf32>
  %5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x?x?xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
  return %5 : tensor<1x128x128xf32>

// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<?xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}

func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
  %cst = constant dense<[2, 2]> : tensor<2xi32>
  %cst_0 = constant dense<3> : tensor<i32>
  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
  %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
  %2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
  %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
  %4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32>
  %5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
  %6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
  return %6 : tensor<1x128x128xf32>

// CHECK-LABEL: testDilatedConvWithExpandSqueeze3
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}

func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
  %cst = constant dense<[2, 2]> : tensor<2xi32>
  %cst_0 = constant dense<3> : tensor<i32>
  %0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
  %1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
  %2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
  %3 = "tf.Squeeze"(%2) {squeeze_dims = [3]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64xf32>
  %4 = "tf.Pad"(%3, %arg1) : (tensor<4x64x64xf32>, tensor<2x2xi32>) -> tensor<4x64x64xf32>
  %5 = "tf.BatchToSpaceND"(%4, %cst, %arg1) : (tensor<4x64x64xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128xf32>
  %6 = "tf.BiasAdd"(%5, %arg3) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
  return %6 : tensor<1x128x128xf32>

// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze3
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
}
@@ -420,6 +420,15 @@ func @gatherV2VectorIndices(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>)
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32>
}

func @gatherV2VectorIndices_I64Axis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x3x5x20xf32> {
  %0 = "tf.Const"() { value = dense<[1]> : tensor<1xi64> } : () -> tensor<1xi64>
  %1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi64>) -> tensor<1x3x5x20xf32>
  return %1 : tensor<1x3x5x20xf32>

// CHECK-LABEL: gatherV2VectorIndices_I64Axis
// CHECK: "tfl.gather"(%arg0, %arg1) {axis = 1 : i32} : (tensor<1x2x20xf32>, tensor<3x5xi32>) -> tensor<1x3x5x20xf32>
}

func @gatherV2VectorIndicesNegAxis(%arg0 : tensor<1x2x20xf32>, %arg1 : tensor<3x5xi32>) -> tensor<1x2x3x5xf32> {
  %0 = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32>
  %1 = "tf.GatherV2"(%arg0, %arg1, %0) : (tensor<1x2x20xf32>, tensor<3x5xi32>, tensor<1xi32>) -> tensor<1x2x3x5xf32>
@@ -1074,6 +1083,14 @@ func @cast(%arg0: tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> {
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32>
}

func @castComplex(%arg0: tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>> {
  %0 = "tf.Cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>>
  return %0 : tensor<1x2x2x5xcomplex<f32>>

// CHECK-LABEL: castComplex
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xf32>) -> tensor<1x2x2x5xcomplex<f32>>
}

func @unique(%arg0: tensor<5xf32>) -> (tensor<?xf32>, tensor<?xi32>) {
  %0, %1 = "tf.Unique"(%arg0) : (tensor<5xf32>) -> (tensor<?xf32>, tensor<?xi32>)
  return %0, %1 : tensor<?xf32>, tensor<?xi32>
@@ -1,5 +1,26 @@
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s --dump-input-on-failure

// CHECK-LABEL: tensorlistConst
func @tensorlistConst(%arg0 : tensor<1xi32>) -> tensor<2x3xi32> {
  // CHECK: %[[ELEMENT0:.*]] = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
  // CHECK: %[[ELEMENT1:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi32>} : () -> tensor<3xi32>
  // CHECK: %[[LIST:.*]] = "tf.Pack"(%[[ELEMENT0]], %[[ELEMENT1]]) {axis = 0 : i64} : (tensor<3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
  %0 = "tf.Const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A2022485C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C3030335C3033325C725C3031305C3030335C3032325C3030345C3032325C3030325C3031305C3030333A5C3030335C3030305C3030315C3030325C3033325C725C3031305C3030335C3032325C3030345C3032325C3030325C3031305C3030333A5C3030335C3030335C3030345C30303522"> : tensor<!tf.variant>} : () -> tensor<!tf.variant<tensor<3xi32>>>

  // CHECK: return %[[LIST]]
  %1 = "tf.TensorListStack"(%0, %arg0) : (tensor<!tf.variant<tensor<3xi32>>>, tensor<1xi32>) -> tensor<2x3xi32>
  return %1 : tensor<2x3xi32>
}

func @emptyTensorlistConst(%arg0 : tensor<1xi32>) -> tensor<0x3xi32> {
  // CHECK: %[[LIST:.*]] = "tf.Const"() {value = dense<{{\[\[}}]]> : tensor<0x3xi32>} : () -> tensor<0x3xi32>
  %0 = "tf.Const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F56415249414E542074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A20222A5C6E5C30323674656E736F72666C6F773A3A54656E736F724C6973745C3032325C3032305C3030305C3030335C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3337375C3030315C3032325C3030325C3031305C30303322"> : tensor<!tf.variant>} : () -> tensor<!tf.variant<tensor<3xi32>>>

  // CHECK: return %[[LIST]]
  %1 = "tf.TensorListStack"(%0, %arg0) : (tensor<!tf.variant<tensor<3xi32>>>, tensor<1xi32>) -> tensor<0x3xi32>
  return %1 : tensor<0x3xi32>
}
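A decoding aside, for readers puzzled by the opaque constants: the hex blob is ASCII. Its leading bytes 0x74 66 74 65 6E 73 6F 72 decode to "tftensor", followed by a textual TensorProto (`dtype: DT_VARIANT tensor_shape { } tensor_content: "..."`), whose content in turn embeds a serialized `tensorflow::TensorList`. In the first test that list holds the elements [0, 1, 2] and [3, 4, 5] that the pass is expected to materialize as two tf.Const ops packed by tf.Pack.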

func @tensorlistGetItem(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
  %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
  %1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>, tensor<1xi32>) -> tensor<10xf32>
29  tensorflow/compiler/mlir/lite/tests/mlir2exec/BUILD  Normal file
@@ -0,0 +1,29 @@
# Description:
#   Integration tests of converter & interpreter.
#
# There should be few tests here, and it should be rare that behavior covered
# by these execution tests is not already covered by unit tests. They are
# useful for verifying some runtime behavior, but the majority of runtime
# tests should live on the TFLite side, with invariants verified only in the
# converter/compiler.

load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")

licenses(["notice"])

glob_lit_tests(
    data = [":test_utilities"],
    driver = "@llvm-project//mlir:run_lit.sh",
    test_file_exts = ["mlir"],
)

# Bundle together all of the test utilities that are used by tests.
filegroup(
    name = "test_utilities",
    testonly = True,
    data = [
        "//tensorflow/compiler/mlir:tf-opt",
        "//tensorflow/compiler/mlir/lite:mlir-tflite-runner",
        "@llvm-project//llvm:FileCheck",
        "@llvm-project//llvm:not",
    ],
)
@@ -0,0 +1,46 @@
// Test to verify translation & export work as intended with runtime.

// RUN: not mlir-tflite-runner --dump-interpreter-state %s 2>&1 | FileCheck %s --check-prefix ERROR --dump-input-on-failure
// RUN: tf-opt --mlir-print-debuginfo --canonicalize --tfl-while-loop-outline %s | mlir-tflite-runner --dump-interpreter-state 2>&1 | FileCheck %s --dump-input-on-failure

// ERROR: number of operands and results don't match

// Verify value computed:
// ----------------------
// CHECK: result: Tensor<type: FLOAT32, shape: 1, values: 96>

// Verify tensors in interpreter state:
// ------------------------------------
// CHECK: Tensor 0 dec kTfLiteInt32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 1 N kTfLiteInt32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 2 val kTfLiteFloat32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 3 std.constant kTfLiteInt32 kTfLiteMmapRo 4 bytes
// CHECK-NEXT: Tensor 4 tfl.while kTfLiteInt32 kTfLiteArenaRw 4 bytes
// CHECK-NEXT: Tensor 5 result kTfLiteFloat32 kTfLiteArenaRw 4 bytes
// CHECK-NEXT: Tensor 6 tfl.while:2 kTfLiteInt32 kTfLiteArenaRw 4 bytes
// CHECK-NEXT: Tensor 7 tfl.while:3 kTfLiteInt32 kTfLiteArenaRw 4 bytes

// Verify while was not folded away:
// ------------------------------------
// CHECK: Operator Builtin Code {{[0-9]*}} WHILE

func @main() -> tensor<1xf32>
    attributes {tf.entry_function = {outputs = "result"}} {
  %cst = constant dense<1> : tensor<i32> loc("dec")
  %arg0 = constant dense<5> : tensor<i32> loc("N")
  %arg1 = constant dense<3.0> : tensor<1xf32> loc("val")
  %0:2 = "tfl.while"(%arg0, %arg1, %cst) ( {
  ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>, %arg4: tensor<i32>):
    %cst_0 = constant dense<0> : tensor<i32>
    %1 = "tfl.greater"(%arg2, %cst_0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
    "tfl.yield"(%1) : (tensor<i1>) -> ()
  }, {
  ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>, %arg4: tensor<i32>):
    %1 = "tfl.sub"(%arg2, %arg4) {fused_activation_function = "NONE"} :
      (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
    %2 = tfl.add %arg3, %arg3 {fused_activation_function = "NONE"} : tensor<*xf32>
    "tfl.yield"(%1, %2, %arg4) : (tensor<*xi32>, tensor<*xf32>, tensor<i32>) -> ()
  }) : (tensor<i32>, tensor<1xf32>, tensor<i32>) -> (tensor<i32>, tensor<1xf32>)
  return %0#1 : tensor<1xf32>
}
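The expected value checks out by hand: the loop runs while N > 0, subtracting dec = 1 from N and doubling val each iteration, so after 5 trips val = 3.0 × 2^5 = 96, which is the `values: 96` the CHECK asserts. The first RUN line exercises the failure path: the raw `tfl.while` above has three operands but only two results, which triggers the "number of operands and results don't match" error; the second RUN canonicalizes and outlines the loop first, after which it executes cleanly.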
@@ -1,139 +1,56 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string -
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -tflite-flatbuffer-to-mlir - -o - | FileCheck --check-prefix=IMPORT %s

// TODO(b/141520199): Currently fake quant is not being written to flatbuffer
// since it is legalized to quantize and dequantize. Update this test and add
// fake_quant_v2.mlir when the op is being written to flatbuffer.
func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: SQUARED_DIFFERENCE,
// CHECK-NEXT: version: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: MUL,
// CHECK-NEXT: version: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: DIV,
// CHECK-NEXT: version: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: EXP,
// CHECK-NEXT: version: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: NEG,
// CHECK-NEXT: version: 1
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "Const",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "squared_difference",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "mul",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "div",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 6,
// CHECK-NEXT: name: "exp",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 7,
// CHECK-NEXT: name: "neg",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 6 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 0, 2 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: builtin_options_type: MulOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 2,
// CHECK-NEXT: inputs: [ 3, 2 ],
// CHECK-NEXT: outputs: [ 4 ],
// CHECK-NEXT: builtin_options_type: DivOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 3,
// CHECK-NEXT: inputs: [ 4 ],
// CHECK-NEXT: outputs: [ 5 ],
// CHECK-NEXT: builtin_options_type: ExpOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 4,
// CHECK-NEXT: inputs: [ 5 ],
// CHECK-NEXT: outputs: [ 6 ],
// CHECK-NEXT: builtin_options_type: NegOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ]
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63, 0, 0, 128, 63 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT: }
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: FAKE_QUANT,
// CHECK-NEXT: version: 1
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "tfl.fake_quant",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0 ],
// CHECK-NEXT: outputs: [ 1 ],
// CHECK-NEXT: builtin_options_type: FakeQuantOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-NEXT: min: 0.3,
// CHECK-NEXT: max: 1.4,
// CHECK-NEXT: num_bits: 6
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT: }

  %0 = "tfl.fake_quant"(%arg0) {num_bits = 6 : i32, narrow_range = false, minmax = [0.3, 1.4]} : (tensor<4 x f32>) -> tensor<4 x f32>
// IMPORT: "tfl.fake_quant"(%arg0) {max = 1.400000e+00 : f32, min = 3.000000e-01 : f32, narrow_range = false, num_bits = 6 : i32}

  %0 = "tfl.fake_quant"(%arg0) {num_bits = 6 : i32, narrow_range = false, min = 0.3:f32, max = 1.4:f32} : (tensor<4 x f32>) -> tensor<4 x f32>
  return %0 : tensor<4xf32>
}
@@ -0,0 +1,219 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s --dump-input-on-failure

// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: WHILE,
// CHECK-NEXT: version: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: GREATER,
// CHECK-NEXT: version: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: builtin_code: SUB,
// CHECK-NEXT: version: 1
// CHECK-NEXT: }, {
// CHECK-NEXT: version: 1
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "WhileOp",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "WhileOp1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2, 3 ],
// CHECK-NEXT: builtin_options_type: WhileOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-NEXT: cond_subgraph_index: 1,
// CHECK-NEXT: body_subgraph_index: 2
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: }, {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 6,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 7,
// CHECK-NEXT: name: "Const",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: BOOL,
// CHECK-NEXT: buffer: 8,
// CHECK-NEXT: name: "tfl.greater",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 0, 2 ],
// CHECK-NEXT: outputs: [ 3 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "WhileOp_cond"
// CHECK-NEXT: }, {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 9,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 10,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 11,
// CHECK-NEXT: name: "Const1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 12,
// CHECK-NEXT: name: "tfl.sub",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: buffer: 13,
// CHECK-NEXT: name: "tfl.add",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 3, 4 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: opcode_index: 2,
// CHECK-NEXT: inputs: [ 0, 2 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: builtin_options_type: SubOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 3,
// CHECK-NEXT: inputs: [ 1, 1 ],
// CHECK-NEXT: outputs: [ 4 ],
// CHECK-NEXT: builtin_options_type: AddOptions,
// CHECK-NEXT: builtin_options: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "WhileOp_body"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 1, 0, 0, 0 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: } ]
// CHECK-NEXT: }

func @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<i1> {
  %cst = constant dense<0> : tensor<i32> loc("Const")
  %0 = "tfl.greater"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
  return %0 : tensor<i1>
}

func @WhileOp_body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) {
  %cst = constant dense<1> : tensor<i32> loc("Const1")
  %0 = "tfl.sub"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
  %1 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
  return %0, %1 : tensor<*xi32>, tensor<*xf32>
}

func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
  %0:2 = "tfl.while"(%arg0, %arg1) ( {
  ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):  // no predecessors
    %1 = call @WhileOp_cond(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> tensor<i1>
    "tfl.yield"(%1) : (tensor<i1>) -> ()
  }, {
  ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):  // no predecessors
    %1:2 = call @WhileOp_body(%arg2, %arg3) : (tensor<*xi32>, tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>)
    "tfl.yield"(%1#0, %1#1) : (tensor<*xi32>, tensor<*xf32>) -> ()
  }) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>) loc("WhileOp")
  return %0#1 : tensor<1xf32>
}
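Two details in the dump are worth decoding. The cond and body functions become subgraphs 1 and 2, referenced from the WHILE operator through `cond_subgraph_index`/`body_subgraph_index` in its builtin options. And the constant buffers are serialized little-endian: `data: [ 0, 0, 0, 0 ]` is the int32 0 from `Const` in the cond subgraph, while `data: [ 1, 0, 0, 0 ]` is the int32 1 from `Const1` in the body.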
@@ -355,10 +355,8 @@ func @testConv2DNoBias(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf3
// CHECK-LABEL: testFakeQuant
func @testFakeQuant(tensor<? x f32>, f32, f32) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>, %arg1: f32, %arg2: f32):
  // CHECK: %0 = "tfl.fake_quant"(%arg0) {minmax = [], narrow_range = true, num_bits = 2 : i32} : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "tfl.fake_quant"(%arg0) {minmax = [], num_bits = 2 : i32, narrow_range = true} : (tensor<? x f32>) -> tensor<? x f32>
  // CHECK: %1 = "tfl.fake_quant"(%0) {minmax = [3.000000e-01, 1.400000e+00], narrow_range = false, num_bits = 6 : i32} : (tensor<?xf32>) -> tensor<?xf32>
  %1 = "tfl.fake_quant"(%0) {num_bits = 6 : i32, narrow_range = false, minmax = [0.3, 1.4]} : (tensor<? x f32>) -> tensor<? x f32>
  // CHECK: "tfl.fake_quant"(%arg0) {max = 1.400000e+00 : f32, min = 3.000000e-01 : f32, narrow_range = false, num_bits = 6 : i32} : (tensor<?xf32>) -> tensor<?xf32>
  %1 = "tfl.fake_quant"(%arg0) {num_bits = 6 : i32, narrow_range = false, min = 0.3:f32, max = 1.4:f32} : (tensor<? x f32>) -> tensor<? x f32>
  return %1 : tensor<? x f32>
}
@@ -1,5 +1,5 @@
// Run optimize pass only and check the results.
// RUN: tf-opt %s -tfl-optimize | FileCheck %s
// RUN: tf-opt %s -tfl-optimize | FileCheck %s --dump-input-on-failure
// Run optimize pass and then canonicalize pass, and make sure some folding is applied.
// RUN: tf-opt %s -tfl-optimize -canonicalize | FileCheck --check-prefix=FOLD %s
||||
@ -294,8 +294,68 @@ func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x
|
||||
// CHECK: return %1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @FuseFullyConnectedAddUnit
|
||||
func @FuseFullyConnectedAddUnit(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
// CHECK-LABEL: @FuseFullyConnectedAddWithNoBias
|
||||
func @FuseFullyConnectedAddWithNoBias(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant unit
|
||||
%cst2 = constant dense<2.0> : tensor<40xf32>
|
||||
|
||||
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
|
||||
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40xf32>) -> tensor<40x40xf32>
|
||||
|
||||
return %1 : tensor<40x40xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<2.000000e+00> : tensor<40xf32>
|
||||
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %cst)
|
||||
// CHECK: return %[[fc]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @FuseFullyConnectedAddWithExistingBias
|
||||
func @FuseFullyConnectedAddWithExistingBias(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant dense<3.0> : tensor<40xf32>
|
||||
%cst2 = constant dense<2.0> : tensor<40xf32>
|
||||
|
||||
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40xf32>) -> (tensor<40x40xf32>)
|
||||
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40xf32>) -> tensor<40x40xf32>
|
||||
|
||||
return %1 : tensor<40x40xf32>
|
||||
|
||||
// CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40xf32>
|
||||
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
|
||||
// CHECK: return %[[fc]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @FuseFullyConnectedAddWithNoBiasAndScalarRhs
|
||||
func @FuseFullyConnectedAddWithNoBiasAndScalarRhs(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant unit
|
||||
%cst2 = constant dense<2.0> : tensor<f32>
|
||||
|
||||
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
|
||||
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<f32>) -> tensor<40x40xf32>
|
||||
|
||||
return %1 : tensor<40x40xf32>
|
||||
|
||||
// CHECK: %[[cst:.*]] = constant dense<2.000000e+00> : tensor<40xf32>
|
||||
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
|
||||
// CHECK: return %[[fc]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @FuseFullyConnectedAddWithScalarRhs
|
||||
func @FuseFullyConnectedAddWithScalarRhs(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant dense<3.0> : tensor<40xf32>
|
||||
%cst2 = constant dense<2.0> : tensor<f32>
|
||||
|
||||
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40xf32>) -> (tensor<40x40xf32>)
|
||||
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<f32>) -> tensor<40x40xf32>
|
||||
|
||||
return %1 : tensor<40x40xf32>
|
||||
|
||||
// CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40xf32>
|
||||
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
|
||||
// CHECK: return %[[fc]]
|
||||
}
|
||||
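All four cases above exercise the same fold: a tfl.add of a constant rhs into the fully-connected bias, i.e. new_bias = old_bias + rhs broadcast to the bias shape. With bias 3.0 and rhs 2.0 this yields the dense<5.000000e+00> the CHECKs expect; with no bias (`constant unit`) the rhs itself becomes the bias, broadcast from a scalar up to tensor<40xf32> where needed. The next test shows the guard on that broadcast: a tensor<40x40xf32> rhs cannot be folded into a tensor<40xf32> bias.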

// CHECK-LABEL: @FuseFullyConnectedAddWithUnfusableRhs
func @FuseFullyConnectedAddWithUnfusableRhs(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
  %cst = constant unit
  %cst2 = constant dense<2.0> : tensor<40x40xf32>

@@ -304,24 +364,11 @@ func @FuseFullyConnectedAddUnit(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf3

  return %1 : tensor<40x40xf32>

  // CHECK: %cst = constant dense<2.000000e+00> : tensor<40x40xf32>
  // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %cst)
  // CHECK: return %[[fc]]
}

// CHECK-LABEL: @FuseFullyConnectedAddConst
func @FuseFullyConnectedAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
  %cst = constant dense<3.0> : tensor<40x40xf32>
  %cst2 = constant dense<2.0> : tensor<40x40xf32>

  %0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>)
  %1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>

  return %1 : tensor<40x40xf32>

  // CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
  // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
  // CHECK: return %[[fc]]
  // CHECK: %[[unit:.*]] = constant unit
  // CHECK: %[[filter:.*]] = constant dense<2.000000e+00> : tensor<40x40xf32>
  // CHECK: %[[fc_result:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[unit]])
  // CHECK: %[[add_result:.*]] = tfl.add %[[fc_result]], %[[filter]]
  // CHECK: return %[[add_result]]
}

// CHECK-LABEL: @FuseFullyConnectedReshapeAddConst
@ -690,6 +737,54 @@ func @fuse_relu_to_add(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor
|
||||
// CHECK: return %[[RES]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: leaky_relu_fusion
|
||||
func @leaky_relu_fusion(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%alpha = constant dense<0.2> : tensor<f32>
|
||||
%0 = "tfl.mul"(%arg0, %alpha) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
|
||||
%1 = "tfl.maximum"(%0, %arg0) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
return %1 : tensor<2x3xf32>
|
||||
|
||||
// CHECK: %[[RESULT:[0-9].*]] = "tfl.leaky_relu"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: leaky_relu_not_fused
|
||||
// Should not fuse to LeakyRelu, since alpha > 1.
|
||||
func @leaky_relu_not_fused(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%alpha = constant dense<1.2> : tensor<f32>
|
||||
%0 = "tfl.mul"(%arg0, %alpha) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<f32>) -> tensor<2x3xf32>
|
||||
%1 = "tfl.maximum"(%0, %arg0) : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
return %1 : tensor<2x3xf32>
|
||||
|
||||
// CHECK: %[[RESULT:[0-9].*]] = "tfl.maximum"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: prelu_fusion
|
||||
func @prelu_fusion(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%alpha = constant dense<-0.2> : tensor<3xf32>
|
||||
%0 = "tfl.relu"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%1 = "tfl.neg"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%2 = "tfl.relu"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%3 = "tfl.mul"(%alpha, %2) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%4 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
return %4 : tensor<2x3xf32>
|
||||
|
||||
// CHECK: %[[RESULT:[0-9].*]] = "tfl.prelu"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: prelu_not_fused
|
||||
// Rank of alpha should be one less than input for PReLU, which is not the case.
|
||||
func @prelu_not_fused(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%alpha = constant dense<-0.2> : tensor<f32>
|
||||
%0 = "tfl.relu"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%1 = "tfl.neg"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%2 = "tfl.relu"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%3 = "tfl.mul"(%alpha, %2) {fused_activation_function = "NONE"} : (tensor<f32>, tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
%4 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
return %4 : tensor<2x3xf32>
|
||||
|
||||
// CHECK: %[[RESULT:[0-9].*]] = "tfl.relu"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: NotfuseAddIntoConv2d_MultipleUsers
|
||||
func @NotfuseAddIntoConv2d_MultipleUsers(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
|
||||
%cst = constant dense<1.5> : tensor<16xf32>
|
||||
|
@ -242,6 +242,22 @@ func @QuantizePad(tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, tensor<3x2xi32>) ->
  // CHECK: return %3 : tensor<?xf32>
}

// CHECK-LABEL: QuantizePad2
// only the second tfl.pad has sufficient quantization information.
func @QuantizePad2(tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, tensor<2x1x3xf32>, tensor<3x2xi32>) -> (tensor<?xf32>, tensor<?xf32>) {
^bb0(%arg0: tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<2x1x3xf32>, %arg2: tensor<3x2xi32>):
  %0 = "tfl.dequantize"(%arg0) : (tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x1x3xf32>
  %1 = "tfl.pad"(%arg1, %arg2) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
  %2 = "tfl.pad"(%0, %arg2) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<?xf32>
  return %1, %2 : tensor<?xf32>, tensor<?xf32>

  // CHECK: %[[dq:.*]] = "tfl.dequantize"(%arg0)
  // CHECK: %[[pad1:.*]] = "tfl.pad"(%arg1, %arg2)
  // CHECK: %[[pad2:.*]] = "tfl.pad"(%[[dq]], %arg2)
  // CHECK: %[[q2:.*]] = "tfl.quantize"(%[[pad2]])
  // CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]])
}

// CHECK-LABEL: QuantizeReshape2D
func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>) -> tensor<1x36x16xf32> {
^bb0(%arg0: tensor<1x6x6x16x!quant.uniform<u8:f32, 7.812500e-03:128>>):
@ -418,16 +434,15 @@ func @QuantizeConcatResToAllRequantizeArg(tensor<1x2x!quant.uniform<u8:f32, 2.0:
  // CHECK: return %[[Q]] : tensor<2x2x!quant.uniform<u8:f32, 1.000000e-01:128>>
}

// CHECK-LABEL: RequantizeAlreadyQuantizedModel
func @RequantizeAlreadyQuantizedModel(%arg0: tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>> {
// CHECK-LABEL: NotRequantizeAlreadyQuantizedModel
func @NotRequantizeAlreadyQuantizedModel(%arg0: tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, %arg1: tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>> {
  %9 = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.0>>
  %10 = "tfl.concatenation"(%arg0, %9) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.0>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 2.0>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>>
  return %10 : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.0>>

  // CHECK: %0 = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>
  // CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<1x73x73x96x!quant.uniform<u8:f32, 1.000000e+00>>} : (tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 1.000000e+00>>
  // CHECK: %2 = "tfl.concatenation"(%arg0, %1) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.000000e+00>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 1.000000e+00>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
  // CHECK: return %2 : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
  // CHECK: %[[max:.*]] = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>
  // CHECK: %[[cat:.*]] = "tfl.concatenation"(%arg0, %[[max]]) {axis = 3 : i32, fused_activation_function = "NONE"} : (tensor<1x73x73x64x!quant.uniform<u8:f32, 1.000000e+00>>, tensor<1x73x73x96x!quant.uniform<u8:f32, 2.000000e+00>>) -> tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
  // CHECK: return %[[cat]] : tensor<1x73x73x160x!quant.uniform<u8:f32, 1.000000e+00>>
}

// CHECK-LABEL: QuantizeChain
26  tensorflow/compiler/mlir/lite/tests/tfl_while_op_licm.mlir  Normal file
@ -0,0 +1,26 @@
// RUN: tf-opt -loop-invariant-code-motion %s -o - | FileCheck %s --dump-input-on-failure

// CHECK: while_1([[ARG0:%[^ :]*]]: tensor<i32>, [[ARG1:%[^ :]*]]: tensor<1xf32>)
func @while_1(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
  // CHECK: [[CST:%[^ ]*]] = constant dense<1> : tensor<i32>
  // CHECK: "tfl.while"([[ARG0]], [[ARG1]])
  // CHECK: (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>)
  %0:2 = "tfl.while"(%arg0, %arg1) (
    // cond
    {
    ^bb0(%condArg0: tensor<*xi32>, %condArg1: tensor<*xf32>):
      %0 = "std.constant" () {value = dense<0> : tensor<i32>} : () -> tensor<i32> loc("Const")
      %1 = "tfl.greater"(%condArg0, %0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
      "tfl.yield"(%1) : (tensor<i1>) -> ()
    },
    // body
    {
    ^bb0(%bodyArg0: tensor<*xi32>, %bodyArg1: tensor<*xf32>):
      %0 = "std.constant" () {value = dense<1> : tensor<i32>} : () -> tensor<i32> loc("Const")
      %1 = "tfl.sub"(%bodyArg0, %0) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
      %2 = tfl.add %bodyArg1, %bodyArg1 {fused_activation_function = "NONE"} : tensor<*xf32>
      "tfl.yield"(%1, %2) : (tensor<*xi32>, tensor<*xf32>) -> ()
    }
  ) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>) loc("WhileOp")
  return %0#1 : tensor<1xf32>
}
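A note on the test above: only the body's `dense<1>` constant is loop-invariant here, which is why the CHECK lines pin `[[CST]]` ahead of the `tfl.while`; `tfl.greater` and `tfl.sub` consume loop-carried block arguments and must stay inside the regions. A sketch of the expected post-pass IR (illustrative only, not CHECK-verified):

// Post-LICM sketch (hypothetical rendering):
//   %cst = constant dense<1> : tensor<i32>   // hoisted out of the body region
//   %0:2 = "tfl.while"(%arg0, %arg1) ( { ...cond unchanged... }, {
//     %1 = "tfl.sub"(%bodyArg0, %cst) ...    // now uses the hoisted constant
//   }) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>)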
31  tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir  Normal file
@ -0,0 +1,31 @@
// Test to verify loop outlining.

// RUN: tf-opt --tfl-while-loop-outline %s | FileCheck %s --dump-input-on-failure

// CHECK-LABEL: func @while
func @while() -> tensor<1xf32>
    attributes {tf.entry_function = {outputs = "result"}} {
  %cst = constant dense<1> : tensor<i32> loc("dec")
  %arg0 = constant dense<5> : tensor<i32> loc("N")
  %arg1 = constant dense<3.0> : tensor<1xf32> loc("val")
  %0:2 = "tfl.while"(%arg0, %arg1) ( {
  ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
    // CHECK: call @WhileOp_cond
    %cst_0 = constant dense<0> : tensor<i32>
    %1 = "tfl.greater"(%arg2, %cst_0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
    "tfl.yield"(%1) : (tensor<i1>) -> ()
  }, {
  ^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
    // CHECK: call @WhileOp_body
    %1 = "tfl.sub"(%arg2, %cst) {fused_activation_function = "NONE"} :
      (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
    %2 = tfl.add %arg3, %arg3 {fused_activation_function = "NONE"} : tensor<*xf32>
    "tfl.yield"(%1, %2) : (tensor<*xi32>, tensor<*xf32>) -> ()
  }) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>) loc("WhileOp")
  return %0#1 : tensor<1xf32>
}
// CHECK-LABEL: func @WhileOp_cond(
// CHECK: tfl.greater
// CHECK-LABEL: func @WhileOp_body(
// CHECK: tfl.sub
// CHECK: tfl.add
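For orientation: the outlining pass is expected to extract the two inline regions into the standalone functions named by the trailing CHECK-LABELs, leaving calls behind. A sketch of the outlined form (signatures assumed, not taken from the pass output):

// Outlined-form sketch (hypothetical):
//   func @WhileOp_cond(%a: tensor<*xi32>, %b: tensor<*xf32>) -> tensor<i1>   // wraps tfl.greater
//   func @WhileOp_body(%a: tensor<*xi32>, %b: tensor<*xf32>)
//       -> (tensor<*xi32>, tensor<*xf32>)                                    // wraps tfl.sub / tfl.add
// with the while's regions reduced to call @WhileOp_cond / call @WhileOp_body.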
@ -80,10 +80,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
  }

  if (pass_config.lower_tensor_list_ops) {
    // Execute this pass before `CanonicalizerPass` in case some TensorList
    // ops are constant folded into variant types.
    // TODO(b/137125056): Move this pass after `CanonicalizerPass` after we
    // handle constant ops that produce `TensorList`.
    // TODO(haoliang): Add this pass by default.
    pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
  }
@ -26,6 +26,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h"  // TF:llvm-project
#include "mlir/Dialect/QuantOps/QuantOps.h"  // TF:llvm-project
#include "mlir/IR/Location.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
@ -65,23 +66,23 @@ class DefaultQuantParamsPass : public FunctionPass<DefaultQuantParamsPass> {
  // Uses `quant_params` to quantize `value` and inserts a pair of
  // tfl.quantize and tfl.dequantize ops for this `value`.
  void QuantizeValue(OpBuilder builder, Value value,
                     TFL::QuantParams quant_params);
                     quant::QuantParams quant_params);

  // If the value hasn't been quantized, the function adds it to `values`.
  void AddToWorkListIfUnquantized(Value value, std::vector<Value> *values);

  // Converts the default min/max to the default quantization parameters.
  TFL::QuantParams GetDefaultQuantParams(Builder builder);
  quant::QuantParams GetDefaultQuantParams(Builder builder);

  // Gets the quantization parameters for the bias of an operation by using the
  // quantization parameters from the non-bias operands.
  TFL::QuantParams GetQuantParamsForBias(Operation *op, int bias,
                                         const std::vector<int> &non_biases,
                                         TFL::AccumulatorScaleFunc func);
  quant::QuantParams GetQuantParamsForBias(Operation *op, int bias,
                                           const std::vector<int> &non_biases,
                                           quant::AccumulatorScaleFunc func);

  double default_min_;
  double default_max_;
  TFL::QuantParams default_quant_params_;
  quant::QuantParams default_quant_params_;
};
}  // namespace

@ -104,7 +105,9 @@ void DefaultQuantParamsPass::runOnFunction() {

  func.walk([&](Operation *op) {
    if (op->isKnownTerminator() ||
        op->hasTrait<OpTrait::quant::NoQuantizableResult>())
        op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
        llvm::isa<quant::QuantizeCastOp>(op) ||
        llvm::isa<quant::DequantizeCastOp>(op))
      return;

    for (auto res : op->getResults()) {
@ -117,7 +120,7 @@ void DefaultQuantParamsPass::runOnFunction() {
  });

  // Apply the default quantization parameters for these activation values.
  TFL::QuantParams default_params = GetDefaultQuantParams(builder);
  quant::QuantParams default_params = GetDefaultQuantParams(builder);
  for (Value value : activation_values) {
    QuantizeValue(builder, value, default_params);
  }
@ -128,7 +131,7 @@ void DefaultQuantParamsPass::runOnFunction() {
  Operation *op = *bias.user_begin();
  auto spec = TFL::GetOpQuantSpec(op);
  for (auto &it : spec->biases_params) {
    TFL::QuantParams bias_params = GetQuantParamsForBias(
    quant::QuantParams bias_params = GetQuantParamsForBias(
        op, it.first, it.second.first, it.second.second);
    if (!bias_params) continue;
    QuantizeValue(builder, bias, bias_params);
@ -157,7 +160,7 @@ void DefaultQuantParamsPass::AddToWorkListIfUnquantized(
}

void DefaultQuantParamsPass::QuantizeValue(OpBuilder builder, Value value,
                                           TFL::QuantParams quant_params) {
                                           quant::QuantParams quant_params) {
  Type expressed_type = value.getType();
  Type new_type = quant_params.castFromExpressedType(expressed_type);
  // This value isn't an expressed type (float), skip.
@ -182,9 +185,9 @@ void DefaultQuantParamsPass::QuantizeValue(OpBuilder builder, Value value,
  quantize.getOperation()->replaceUsesOfWith(dequantize, value);
}
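// For orientation, a sketch of QuantizeValue's net effect on a selected float
// activation (my summary of the surrounding code, not code from this change):
//   %v : tensor<...xf32>                          // unquantized value
// is rewired to
//   %q  = "tfl.quantize"(%v) {qtype = <default quantization parameters>}
//   %dq = "tfl.dequantize"(%q)
// with %v's other users redirected to %dq, so later passes see an explicit
// quantize/dequantize pair around every value that lacked quantization info.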
TFL::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias(
quant::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias(
    Operation *op, int bias, const std::vector<int> &non_biases,
    TFL::AccumulatorScaleFunc func) {
    quant::AccumulatorScaleFunc func) {
  std::vector<quant::QuantizedType> non_bias_types;
  non_bias_types.reserve(non_biases.size());
  for (int non_bias : non_biases) {
@ -205,7 +208,7 @@ TFL::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias(
  return func(non_bias_types);
}

TFL::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
    Builder builder) {
  if (!default_quant_params_) {
    default_quant_params_ = quant::fakeQuantAttrsToType(
74  tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc  Normal file
@ -0,0 +1,74 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass converts dense tensors to sparse format.

#include "absl/memory/memory.h"
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

//===----------------------------------------------------------------------===//
// The DenseToSparse Pass.
//
namespace mlir {
namespace TFL {

namespace {

struct DenseToSparse : public FunctionPass<DenseToSparse> {
  void runOnFunction() override;
};

void DenseToSparse::runOnFunction() {
  FuncOp func = getFunction();
  OpBuilder builder(func);

  func.walk([&](SparseOpInterface sparse_op) {
    const auto& sparse_operands = sparse_op.GetSparseOperands();
    for (const int operand : sparse_operands) {
      auto* op = sparse_op.getOperation();
      const auto& value = op->getOperand(operand);
      builder.setInsertionPoint(op);
      if (auto* inst = value.getDefiningOp()) {
        // Replace defining op with SparseConst or SparseQConst.
        // TODO(yunluli): Implement.
      }

      // TODO(yunluli): Implement.
      bool needs_densify = false;

      if (needs_densify) {
        auto densify = builder.create<DensifyOp>(op->getLoc(), value);
        value.replaceAllUsesWith(densify);
        densify.setOperand(value);
      }
    }
  });
}

}  // namespace

// Creates an instance of the TensorFlow Lite dialect DenseToSparse pass.
std::unique_ptr<OpPassBase<FuncOp>> CreateDenseToSparsePass() {
  return absl::make_unique<DenseToSparse>();
}

static PassRegistration<DenseToSparse> pass(
    "tfl-dense-to-sparse", "Convert dense tensor to sparse format.");

}  // namespace TFL
}  // namespace mlir
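For ad-hoc testing, the `PassRegistration` above suggests the pass is reachable from tf-opt under the registered flag; a hypothetical RUN line in the style of the test files elsewhere in this change (the input file and any CHECKs would still need to be written):

// RUN: tf-opt -tfl-dense-to-sparse %s | FileCheck %s --dump-input-on-failure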
41  tensorflow/compiler/mlir/lite/transforms/dilated_conv.cc  Normal file
@ -0,0 +1,41 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"

namespace mlir {
namespace TFL {
namespace {

struct IdentifyDilatedConvPass : public FunctionPass<IdentifyDilatedConvPass> {
  void runOnFunction() override;
};

void IdentifyDilatedConvPass::runOnFunction() {
  OwningRewritePatternList patterns;
  auto func = getFunction();

  patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>,
                  ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(
      &getContext());
  applyPatternsGreedily(func, patterns);
}
}  // namespace

static PassRegistration<IdentifyDilatedConvPass> pass(
    "tfl-identify-dilated-conv",
    "Identify and replace patterns for dilated convolution.");

}  // namespace TFL
}  // namespace mlir
234  tensorflow/compiler/mlir/lite/transforms/dilated_conv.h  Normal file
@ -0,0 +1,234 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass identifies patterns for dilated convolution and replaces them
// with a real convolution op.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_

#include <cstdint>

#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Matchers.h"  // TF:llvm-project
#include "mlir/IR/PatternMatch.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {

// A dilated convolution can be emulated with a regular convolution by chaining
// SpaceToBatch and BatchToSpace ops before and after it:
//
//     SpaceToBatchND -> Conv2D -> BatchToSpaceND
//
// This method was common before Conv2D fully supported dilated convolution in
// TensorFlow. This transformation detects this "emulation", and replaces it
// with a true dilated convolution, eliminating the SpaceToBatch and
// BatchToSpace ops.
//
// Detecting this alone would be relatively easy. However, in practice some
// extra ops are used, so we detect the following patterns:
//
//   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND -> BiasAdd
//
//   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> Pad -> BatchToSpaceND ->
//   BiasAdd
//
//   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BiasAdd -> BatchToSpaceND
//
//   SpaceToBatchND -> Conv2D -> Pad -> BatchToSpaceND -> BiasAdd
//
//   SpaceToBatchND -> Conv2D -> BatchToSpaceND -> BiasAdd
//
// The Expand/Squeeze combination is used to adapt a 3D array (such as in
// WaveNet) to the 4D arrays that Conv2D requires. Padding and BiasAdd are
// thrown in just for the extra headache. Padding adapts non-conforming input
// sizes, and can be discarded. The bias is necessary, so is kept.
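// As an illustrative sketch (shapes, attribute values, and a block_shape of
// [2, 2] are assumptions for this example, not taken from this change), the
// simplest matched pattern:
//
//   %0 = "tf.SpaceToBatchND"(%input, %block_shape, %paddings)
//   %1 = "tf.Conv2D"(%0, %filter) {strides = [1, 1, 1, 1], padding = "VALID"}
//   %2 = "tf.BatchToSpaceND"(%1, %block_shape, %crops)
//
// is rewritten into a single convolution on the original input:
//
//   %0 = "tf.Conv2D"(%input, %filter)
//          {strides = [1, 1, 1, 1], padding = "VALID",
//           dilations = [1, 2, 2, 1]}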
template <typename Conv2dOpTy>
class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
 private:
  using OpRewritePattern<Conv2dOpTy>::OpRewritePattern;

  // Extract the dilation factor from `block_shape` and pack it in an ArrayAttr.
  llvm::Optional<ArrayAttr> ExtractDilationsAttrFromBlockShape(
      Value stb_block_shape, Value bts_block_shape,
      PatternRewriter& rewriter) const;

 public:
  PatternMatchResult matchAndRewrite(Conv2dOpTy op,
                                     PatternRewriter& rewriter) const override;
};

template <typename Conv2dOpTy>
PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
    Conv2dOpTy op, PatternRewriter& rewriter) const {
  // Check if the ConvOp is preceded by an `Expand` op and succeeded by a
  // `Squeeze` op.
  Operation* prev_op = op.getOperation()->getPrevNode();
  if (!prev_op) return Pattern::matchFailure();

  Operation* next_op = op.getOperation()->getNextNode();
  if (!next_op) return Pattern::matchFailure();

  TF::ExpandDimsOp expand_op;
  TF::SqueezeOp squeeze_op;
  // Expand + Squeeze op.
  if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
    if (!llvm::isa<TF::SqueezeOp>(next_op)) {
      // Expand/Squeeze ops must come in pairs.
      return Pattern::matchFailure();
    }
    expand_op = llvm::cast<TF::ExpandDimsOp>(prev_op);
    squeeze_op = llvm::cast<TF::SqueezeOp>(next_op);

    // Update previous/next op pointers.
    prev_op = prev_op->getPrevNode();
    if (!prev_op) return Pattern::matchFailure();
    next_op = next_op->getNextNode();
    if (!next_op) return Pattern::matchFailure();
  }

  // SpaceToBatchND op.
  if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return Pattern::matchFailure();
  TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(prev_op);

  // Pad op.
  TF::PadOp pad_op;
  if (llvm::isa<TF::PadOp>(next_op)) {
    pad_op = llvm::cast<TF::PadOp>(next_op);
    next_op = next_op->getNextNode();
    if (!next_op) return Pattern::matchFailure();
  }

  // BatchToSpaceND + BiasAdd.
  TF::BatchToSpaceNDOp bts_op;
  TF::BiasAddOp biasadd_op;
  bool final_op_is_bts = true;
  if (llvm::isa<TF::BiasAddOp>(next_op)) {
    // Must be BiasAdd + BatchToSpaceND.
    biasadd_op = llvm::cast<TF::BiasAddOp>(next_op);
    next_op = next_op->getNextNode();
    if (!next_op || !llvm::isa<TF::BatchToSpaceNDOp>(next_op))
      return Pattern::matchFailure();
    bts_op = llvm::cast<TF::BatchToSpaceNDOp>(next_op);
  } else if (llvm::isa<TF::BatchToSpaceNDOp>(next_op)) {
    // BatchToSpaceND + (optional) BiasAdd.
    bts_op = llvm::cast<TF::BatchToSpaceNDOp>(next_op);
    next_op = next_op->getNextNode();
    if (next_op && llvm::isa<TF::BiasAddOp>(next_op)) {
      biasadd_op = llvm::cast<TF::BiasAddOp>(next_op);
      final_op_is_bts = false;
    }
  } else {
    return Pattern::matchFailure();
  }

  llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
      stb_op.block_shape(), bts_op.block_shape(), rewriter);
  if (!dilations_attr.hasValue()) return Pattern::matchFailure();
  op.setAttr("dilations", dilations_attr.getValue());

  // Here we need to set the correct padding for the Conv op. In TF, the conv
  // op inserted after 'SpaceToBatch' always has 'VALID' padding. This might
  // become a problem here if the original Conv op has 'SAME' padding. When
  // the original conv has 'SAME' padding, TF will set a non-zero padding for
  // the 'SpaceToBatch' op, so we rely on this information to check if we need
  // to change the padding from 'VALID' to 'SAME' (a.k.a. when we see non-zero
  // values in `stb_op.paddings`, we change the current Conv's padding to
  // 'SAME').
  auto stb_paddings = stb_op.paddings();
  ElementsAttr stb_paddings_attr;
  if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr))) {
    if (llvm::any_of(stb_paddings_attr.getValues<IntegerAttr>(),
                     [](IntegerAttr attr) { return attr.getInt() != 0; })) {
      op.setAttr("padding", rewriter.getStringAttr("SAME"));
    }
  }

  if (expand_op) {
    // If there is an `expand_op`, we need to rewire the inputs to bypass the
    // `SpaceToBatch`, `BatchToSpace` and `Pad` ops. E.g., turning
    // 'SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND ->
    // BiasAdd' into 'Expand -> Conv2D -> Squeeze -> BiasAdd'.

    // Connect `expand_op` with the input of `stb_op`.
    expand_op.setOperand(0, stb_op.input());
    // Calculate the shape for expand.
    auto input_shape = stb_op.input().getType().cast<ShapedType>().getShape();
    SmallVector<int64_t, 4> expand_shape(input_shape.begin(),
                                         input_shape.end());
    expand_shape.push_back(1);
    auto expand_result_type = RankedTensorType::get(
        expand_shape, getElementTypeOrSelf(stb_op.input()));
    expand_op.getResult().setType(expand_result_type);
    op.getResult().setType(expand_result_type);

    squeeze_op.getResult().setType(bts_op.output().getType());

    // Connect `biasadd_op` with the output of `squeeze_op`.
    biasadd_op.setOperand(0, squeeze_op.output());
    biasadd_op.output().setType(squeeze_op.output().getType());
  } else {
    if (biasadd_op) biasadd_op.setOperand(0, op.output());
    op.setOperand(0, stb_op.input());
    op.getResult().setType(bts_op.getResult().getType());
  }

  if (final_op_is_bts) {
    bts_op.getResult().replaceAllUsesWith(bts_op.input());
  }

  stb_op.getResult().dropAllUses();
  return Pattern::matchSuccess();
}

template <typename Conv2dOpTy>
llvm::Optional<ArrayAttr>
ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
    Value stb_block_shape, Value bts_block_shape,
    PatternRewriter& rewriter) const {
  ElementsAttr stb_bs_attr, bts_bs_attr;
  if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
      !matchPattern(bts_block_shape, m_Constant(&bts_bs_attr))) {
    // Returns failure status if block_shape is not a constant.
    return {};
  }
  // Check that the block_shape of `stb_op` and `bts_op` are equal.
  if (stb_bs_attr.getNumElements() != bts_bs_attr.getNumElements()) return {};
  for (uint64_t i = 0; i < stb_bs_attr.getNumElements(); ++i) {
    if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
  }

  // TODO(haoliang): support 1-D dilated conv.
  if (stb_bs_attr.getNumElements() < 2) return {};

  int dilation_h_factor =
      stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
  int dilation_w_factor =
      stb_bs_attr.getValue({1}).cast<IntegerAttr>().getInt();

  return rewriter.getI64ArrayAttr({1, dilation_h_factor, dilation_w_factor, 1});
}
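// Worked example (my reading of the code above; values illustrative): if both
// ops carry a constant block_shape of [2, 2], this returns the NHWC attribute
// dilations = [1, 2, 2, 1]. Non-constant or mismatched block shapes, and
// block shapes with fewer than two elements, all yield an empty Optional.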

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_
@ -39,7 +39,7 @@ def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
// Use the tensor type information from $0 and convert min $1, max $2 and
// numBits $3 and narrowRange $4 to a QuantizedType.
def ConvertToQuantTypeFromAttrs : NativeCodeCall<
    "GetQuantizedTypeAttr($_builder, $0.getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;
    "quant::GetQuantizedTypeAttr($_builder, $0.getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;

// Converts an integer attribute $0 to 32-bit with builder.
def convertIntAttrTo32Bit : NativeCodeCall<
@ -49,11 +49,19 @@ def convertIntAttrTo32Bit : NativeCodeCall<
def ExtractSingleElementAsInteger : NativeCodeCall<
    "ExtractSingleElementAsInteger($_self.cast<ElementsAttr>())">;

// Extracts the single int32 element from $_self.
def ExtractSingleElementAsInt32 : NativeCodeCall<
    "$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger($_self.cast<ElementsAttr>()).getInt())">;

// Checks whether the given operation has static shapes and same shapes of all inputs.
def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">;
def HasSameStaticShapes : Constraint<HasSameStaticShapesPred, "op must have static same input shapes">;
def HasNotSameStaticShapes : Constraint<Neg<HasSameStaticShapesPred>, "op must have not static same input shapes">;

// Checks if the value has only one user.
// TODO(karimnosseir): Move to a common place?
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;

//===----------------------------------------------------------------------===//
// Nullary ops patterns.
//===----------------------------------------------------------------------===//
@ -186,7 +194,7 @@ def : Pat<(TF_GatherV2Op $params, $indices,
              (ConstantOp ElementsAttr:$axis),
              ConstantAttr<I64Attr, "0">:$batch_dims),
          (TFL_GatherOp $params, $indices,
              ExtractSingleElementAsInteger:$axis)>;
              ExtractSingleElementAsInt32:$axis)>;

def : Pat<(TF_FloorDivOp $l, $r), (TFL_FloorDivOp $l, $r)>;

@ -198,16 +206,20 @@ def : Pat<(TF_LogicalAndOp $l, $r), (TFL_LogicalAndOp $l, $r)>;
def : Pat<(TF_LogicalOrOp $l, $r), (TFL_LogicalOrOp $l, $r)>;

// Multi-pattern consisting of matching stand-alone op or op followed by relu.
// TODO(karimnosseir): Can the activation part here be removed by modifying the
// very similar pass in optimize_patterns.td?
multiclass FusedBinaryActivationFuncOpPat<dag FromOp, dag ToOp> {
  def : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
            (ToOp $l, $r, TFL_AF_None)>;
  foreach actFnPair = [[TF_ReluOp, TFL_AF_Relu],
                       [TF_Relu6Op, TFL_AF_Relu6]] in {
    def : Pat<(actFnPair[0] (FromOp $lhs, $rhs)),
              (ToOp $lhs, $rhs, actFnPair[1])>;
    def : Pat<(actFnPair[0] (FromOp:$bin_out $lhs, $rhs)),
              (ToOp $lhs, $rhs, actFnPair[1]),
              [(HasOneUse $bin_out)]>;
    // TODO: Maybe move these below to general pass?
    def : Pat<(actFnPair[0] (ToOp $lhs, $rhs, TFL_AF_None)),
              (ToOp $lhs, $rhs, actFnPair[1])>;
    def : Pat<(actFnPair[0] (ToOp:$bin_out $lhs, $rhs, TFL_AF_None)),
              (ToOp $lhs, $rhs, actFnPair[1]),
              [(HasOneUse $bin_out)]>;
  }
}
@ -23,9 +23,11 @@ limitations under the License.
#include <climits>
#include <cstdint>

#include "absl/container/inlined_vector.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
@ -57,6 +59,10 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/tensor_list.h"

#define DEBUG_TYPE "tf-tfl-legalization"

@ -162,6 +168,86 @@ TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
                                       start_position, slice_size);
}

// Converts a tf.Const containing a variant of type TensorList to a tensor of
// primitive element types. Each of the individual tensors in the list is
// converted to an ElementsAttr, and then those are packed together using a
// tf.Pack op.
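// A hypothetical before/after (list contents invented for illustration):
//   %list = "tf.Const"() {value = opaque<"tf", "...">}
//           : () -> tensor<!tf.variant<tensor<2xf32>>>
// holding the two elements [1.0, 2.0] and [3.0, 4.0] would become
//   %e0 = "tf.Const"() {value = dense<[1.0, 2.0]> : tensor<2xf32>}
//   %e1 = "tf.Const"() {value = dense<[3.0, 4.0]> : tensor<2xf32>}
//   %r  = "tf.Pack"(%e0, %e1) {axis = 0 : i64}
//         : (tensor<2xf32>, tensor<2xf32>) -> tensor<2x2xf32>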
struct ConvertConst : public OpConversionPattern<TF::ConstOp> {
  using OpConversionPattern::OpConversionPattern;

  PatternMatchResult matchAndRewrite(
      TF::ConstOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Verify that the opaque elements attribute contains a tensor of variant
    // type with scalar shape. The variant type should hold a TensorList.
    auto opaque_attr = op.value().dyn_cast<OpaqueElementsAttr>();
    if (!opaque_attr) return matchFailure();
    tensorflow::Tensor tensor;
    if (!tensorflow::ConvertToTensor(opaque_attr, &tensor).ok())
      return matchFailure();
    if (tensor.dtype() != tensorflow::DT_VARIANT) return matchFailure();
    if (!tensorflow::TensorShapeUtils::IsScalar(tensor.shape()))
      return matchFailure();

    const tensorflow::TensorList *list =
        tensor.scalar<tensorflow::Variant>()().get<tensorflow::TensorList>();
    if (!list) return matchFailure();

    // Verify that the output type is variant and contains exactly one ranked
    // subtype.
    auto variant_ty =
        getElementTypeOrSelf(op.getType()).dyn_cast<TF::VariantType>();
    if (!variant_ty) return matchFailure();
    ArrayRef<TensorType> subtypes = variant_ty.getSubtypes();
    if (subtypes.size() != 1) return matchFailure();
    RankedTensorType list_element_ty =
        subtypes.front().dyn_cast<RankedTensorType>();
    if (!list_element_ty) return matchFailure();

    // Extract tensor elements for the TensorList and construct the result
    // type based on the number of elements and the element shape.
    const std::vector<tensorflow::Tensor> &tensors = list->tensors();
    llvm::SmallVector<int64_t, 4> result_shape = {
        static_cast<int64_t>(tensors.size())};
    result_shape.append(list_element_ty.getShape().begin(),
                        list_element_ty.getShape().end());
    auto result_ty =
        RankedTensorType::get(result_shape, list_element_ty.getElementType());

    // If the list is empty, directly create the final result instead of
    // creating the tf.Pack op. The tf.Pack op requires at least one operand.
    if (tensors.empty()) {
      absl::InlinedVector<tensorflow::int64, 4> tf_shape;
      tf_shape.reserve(result_shape.size());
      for (int64_t dim : result_shape) {
        tf_shape.push_back(dim);
      }

      tensorflow::Tensor tensor(list->element_dtype,
                                tensorflow::TensorShape(tf_shape));
      auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
      if (!attr_or.ok()) return matchFailure();
      rewriter.replaceOpWithNewOp<TF::ConstOp>(op, attr_or.ValueOrDie());
      return matchSuccess();
    }

    // Extract the individual tensor list elements and combine them using the
    // tf.Pack op.
    Location loc = op.getLoc();
    llvm::SmallVector<Value, 4> values;
    values.reserve(tensors.size());
    for (const tensorflow::Tensor &tensor : tensors) {
      auto attr_or = tensorflow::ConvertTensor(tensor, &rewriter);
      if (!attr_or.ok()) return matchFailure();

      auto value = rewriter.create<TF::ConstOp>(loc, attr_or.ValueOrDie());
      values.push_back(value);
    }
    rewriter.replaceOpWithNewOp<TF::PackOp>(
        op, result_ty, values, /*axis=*/rewriter.getI64IntegerAttr(0));
    return matchSuccess();
  }
};

struct ConvertTensorListSetItem
    : public OpConversionPattern<TF::TensorListSetItemOp> {
  using OpConversionPattern::OpConversionPattern;
@ -615,7 +701,7 @@ struct ConvertTensorListStack
    if ((ranked_type && ranked_type.getRank() == 0) ||
        !matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
      // If no constant is spotted, just forward the operand.
      rewriter.replaceOp(op, {input}, llvm::None);
      rewriter.replaceOp(op, {input});
      return matchSuccess();
    }

@ -768,7 +854,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(

  OwningRewritePatternList patterns;
  patterns
      .insert<ConvertEmptyTensorList, ConvertIdentity,
      .insert<ConvertConst, ConvertEmptyTensorList, ConvertIdentity,
              ConvertTensorListFromTensor, ConvertTensorListGetItem,
              ConvertTensorListLength, ConvertTensorListPushBack,
              ConvertTensorListReserve, ConvertTensorListSetItem,
@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

@ -173,7 +174,7 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
}

TypeAttr RescaleQtype(Type input, Attribute factor) {
  return TFL::RescaleQuantizedType(input, factor);
  return quant::RescaleQuantizedType(input, factor);
}

// Returns shape of a ranked tensor.
@ -201,34 +202,80 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {

  PatternMatchResult matchAndRewrite(TFL::AddOp add_op,
                                     PatternRewriter &rewriter) const override {
    // Add.
    // Match Add.
    DenseElementsAttr added_value;
    Value constant_val = add_op.rhs();
    if (!matchPattern(constant_val, m_Constant(&added_value)))
      return matchFailure();

    // Fully Connected.
    // Match Fully Connected.
    auto fc_op =
        dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
    if (!fc_op) return matchFailure();

    // Check if the constant RHS is either 0D (scalar), or a 1D with
    // `{num_channels}` shape.
    auto constant_val_type = constant_val.getType().cast<TensorType>();

    // In the TFLite FullyConnected definition, bias must be a 1D tensor where
    // the number of elements is equal to the number of channels.
    // If it's not 1D or 0D (which can be broadcast to 1D), reject the
    // matching.
    bool is_scalar_rhs = false;
    if (constant_val_type.getRank() == 0) {
      is_scalar_rhs = true;
    } else if (constant_val_type.getRank() != 1) {
      return matchFailure();
    }

    Value filter = fc_op.filter();
    Value bias = fc_op.bias();
    ElementsAttr bias_value;
    const bool is_none_bias = bias.getType().isa<NoneType>();
    if (fc_op.fused_activation_function() != "NONE") return matchFailure();

    if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
      return matchFailure();
    if (fc_op.fused_activation_function() != "NONE") return matchFailure();

    // Rewrite
    Location loc = fc_op.getLoc();
    // If the bias isn't None, it needs to be added as well.

    if (is_none_bias) {
      bias = constant_val;
      if (is_scalar_rhs) {
        // If the `constant_val` is scalar, we need the shape of the filter
        // to properly broadcast the scalar to a `{num_channels}` shape.

        // Get the number of channels if possible.
        auto filter_type = filter.getType().cast<ShapedType>();
        // Filter must be a `2D` tensor with `{num_channels, num_features}`
        // shape. The following check rejects unknown rank (-1).
        if (filter_type.getRank() != 2) {
          return matchFailure();
        }
        int num_channels = filter_type.getShape()[0];

        // Create a zero tensor with shape {num_channels}; its type needs to
        // be the same as constant_val's type.
        // This is a way to gracefully handle a scalar tensor. The Add will
        // always be constant-folded away regardless of whether `constant_val`
        // is a scalar or not.
        RankedTensorType type = RankedTensorType::get(
            {num_channels}, constant_val_type.getElementType());
        auto attr = rewriter.getZeroAttr(type);
        bias = rewriter.create<ConstantOp>(loc, type, attr);
        auto none_af = rewriter.getStringAttr("NONE");
        bias =
            rewriter.create<AddOp>(loc, bias, constant_val, none_af).output();
      } else {
        // If there is no pre-existing bias and the `constant_val` is 1D,
        // simply use `constant_val` as the bias.
        bias = constant_val;
      }
    } else {
      auto none_af = rewriter.getStringAttr("NONE");
      bias = rewriter.create<AddOp>(loc, bias, constant_val, none_af).output();
    }

    rewriter.replaceOpWithNewOp<TFL::FullyConnectedOp>(
        add_op, add_op.getType(),
        /*input=*/fc_op.input(),
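    // A sketch of the scalar-RHS path's net effect (my illustration; the
    // values mirror the @FuseFullyConnectedAddWithNoBiasAndScalarRhs test
    // above and are not taken from this code): with a none bias, a
    // tensor<40x37xf32> filter, and a scalar dense<2.0> addend, the rewrite
    // builds zeros : tensor<40xf32> plus 2.0; that Add constant-folds to
    // dense<2.0> : tensor<40xf32>, which becomes the new bias operand.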
@ -23,26 +23,34 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
def F32ElementsAttr : ElementsAttrBase<
    CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;

def ExtractSingleElementAsFloat : NativeCodeCall<
    "ExtractSingleElementAsFloat($_self.cast<ElementsAttr>())">;

// Checks if the value has only one user.
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;

//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
// Multi-pattern consisting of matching stand-alone convolution op followed by
// activation op.
multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
  def : Pat<(ActFnOp (TFL_Conv2DOp $input, $filter, $bias,
  def : Pat<(ActFnOp (TFL_Conv2DOp:$conv_out $input, $filter, $bias,
                         $h_factor, $w_factor, TFL_AF_None,
                         $padding, $stride_h, $stride_w)),
            (TFL_Conv2DOp $input, $filter, $bias,
                $h_factor, $w_factor, ActFnAttr,
                $padding, $stride_h, $stride_w)>;
  def : Pat<(ActFnOp (TFL_DepthwiseConv2DOp $input, $filter, $bias,
                $padding, $stride_h, $stride_w),
            [(HasOneUse $conv_out)]>;
  def : Pat<(ActFnOp (TFL_DepthwiseConv2DOp:$conv_out $input, $filter, $bias,
                         $h_factor, $w_factor, TFL_AF_None,
                         $padding, $stride_h, $stride_w,
                         $multiplier)),
            (TFL_DepthwiseConv2DOp $input, $filter, $bias,
                $h_factor, $w_factor, ActFnAttr,
                $padding, $stride_h, $stride_w,
                $multiplier)>;
                $multiplier),
            [(HasOneUse $conv_out)]>;
}

// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
@ -58,9 +66,6 @@ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
    CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>;

// Checks if the value has only one user.
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;

// If we see a binary op (add, sub) op adding a constant value to a convolution
// op with constant bias, we can fuse the binary op into the convolution op by
// constant folding the bias and the binary op's constant operand. The following
@ -288,8 +293,9 @@ multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
  foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                       [TFL_Relu6Op, TFL_AF_Relu6],
                       [TFL_Relu1Op, TFL_AF_Relu1]] in {
    def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
              (BinaryOp $lhs, $rhs, actFnPair[1])>;
    def : Pat<(actFnPair[0] (BinaryOp:$binary_out $lhs, $rhs, TFL_AF_None)),
              (BinaryOp $lhs, $rhs, actFnPair[1]),
              [(HasOneUse $binary_out)]>;
  }
}

@ -358,6 +364,7 @@ class ValueEquals<string val> : Constraint<CPred<
    "$0.cast<DenseElementsAttr>().getNumElements() == 1 &&"
    "*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>;

// ReLU patterns
def : Pat<(TFL_MinimumOp (TFL_MaximumOp $input,
                             (ConstantOp $NegOne)),
                         (ConstantOp $One)),
@ -370,6 +377,35 @@ def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
          (TFL_Relu1Op $input),
          [(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;

def : Pat<(TFL_MaximumOp (TFL_MulOp:$mul_out $input1,
                             (ConstantOp F32ElementsAttr:$alpha), TFL_AF_None),
                         $input2),
          (TFL_LeakyReluOp $input1, ExtractSingleElementAsFloat:$alpha),
          [(ConstDoubleValueLessThan<"1"> $alpha),
           (EqualOperands $input1, $input2),
           (HasOneUse $mul_out)]>;

// Checks if the operand0's rank is one less than operand1's rank.
def PReluAlphaRankCheck : Constraint<
    CPred<"$0.getType().cast<ShapedType>().getRank() == "
          "$1.getType().cast<ShapedType>().getRank() - 1">>;

// PReLU pattern from Keras:
// f(x) = Relu(x) + (-alpha * Relu(-x))
def : Pat<(TFL_AddOp
              (TFL_ReluOp:$relu_out $input1),
              (TFL_MulOp:$mul_out
                  (TFL_ReluOp (TFL_NegOp:$input_neg_out $input2)),
                  $neg_alpha,
                  TFL_AF_None),
              TFL_AF_None),
          (TFL_PReluOp $input1, (TFL_NegOp $neg_alpha)),
          [(EqualOperands $input1, $input2),
           (PReluAlphaRankCheck $neg_alpha, $input1),
           (HasOneUse $relu_out),
           (HasOneUse $mul_out),
           (HasOneUse $input_neg_out)]>;

// The constant folding in this pass might produce constants in the tf dialect.
// This rule legalizes these constants to the tfl dialect.
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
@ -50,7 +50,7 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateQuantizePass();
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareQuantizePass(
    const QuantizationSpecs& quant_specs);

// Creates a instance of the TensorFlow Lite dialect PostQuantize pass.
// Creates an instance of the TensorFlow Lite dialect PostQuantize pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePostQuantizePass(
    bool emit_quant_adaptor_ops);

@ -70,14 +70,20 @@ std::unique_ptr<OpPassBase<ModuleOp>> CreateExtractOphintPass();
// pass. The composite op is created from the ophint extraction pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateLegalizeOphintFuncOpPass();

// Creates an instance of TensorFlow Lite dialect SplitMergedOperandsPass.
// Creates an instance of the TensorFlow Lite dialect SplitMergedOperandsPass.
std::unique_ptr<OpPassBase<FuncOp>> CreateSplitMergedOperandsPass();

// Creates an instance of the TensorFlow Lite dialect OptimizeFunctionalOpsPass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateOptimizeFunctionalOpsPass();

// Creates an instance pass to add default quantization parameters.
// Creates an instance of the TensorFlow Lite dialect pass to add default
// quantization parameters.
std::unique_ptr<OpPassBase<FuncOp>> CreateDefaultQuantParamsPass(
    double default_min, double default_max);

// Creates an instance of the TensorFlow Lite dialect pass to convert dense
// tensor to sparse format.
std::unique_ptr<OpPassBase<FuncOp>> CreateDenseToSparsePass();
}  // namespace TFL

}  // namespace mlir
@ -136,7 +136,7 @@ def : Pat<(TF_ReshapeOp
// Casts result type of $1 to a quantized type by using the quantization
// parameters from the type in $0.
class UpdateShapeWithAxis<int i> : NativeCodeCall<
    "CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1.getType(), " # i # ")">;
    "quant::CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1.getType(), " # i # ")">;

class UsedBy<string op> : Constraint<
    CPred<"llvm::isa<mlir::TFL::" # op # "Op>(*$0.getUsers().begin())">>;
@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
@ -144,16 +145,16 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
    if (auto shaped = input_type.dyn_cast<ShapedType>()) {
      if (shaped.getElementType().isa<FloatType>()) {
        auto min_max = GetMinMaxValuesForArgument(func_name, i);
        TypeAttr params = GetQuantizedTypeAttr(
        TypeAttr params = quant::GetQuantizedTypeAttr(
            builder, input_type, builder.getF64FloatAttr(min_max.first),
            builder.getF64FloatAttr(min_max.second), /*quant_dim=*/-1, num_bits,
            narrow_range, is_signed);
        builder.setInsertionPoint(block, insertion_point);
        auto q_op = builder.create<TFL::QuantizeOp>(loc, params.getValue(), arg,
                                                    params);
        auto dq_op =
            builder.create<TFL::DequantizeOp>(loc, input_type, q_op.output());
        arg.replaceAllUsesWith(dq_op.output());
        auto q_op =
            builder.create<quant::QuantizeCastOp>(loc, params.getValue(), arg);
        auto dq_op = builder.create<quant::DequantizeCastOp>(loc, input_type,
                                                             q_op.getResult());
        arg.replaceAllUsesWith(dq_op.getResult());
        q_op.setOperand(arg);
      }
    }
@ -176,12 +177,14 @@ bool PrepareQuantizePass::RemoveRedundantStats(FuncOp func) {
}

using PrepareQuantStats =
    TFL::ConvertStatsToQDQs<TFL::QuantizeOp, TFL::DequantizeOp>;
    quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;

void PrepareQuantizePass::runOnFunction() {
  FuncOp func = getFunction();
  MLIRContext* ctx = func.getContext();

  ConvertTFLQuantOpsToMlirQuantOps(func);

  if (quant_specs_.post_training_quantization) {
    RemoveRedundantStats(func);
  } else {
@ -198,7 +201,7 @@ void PrepareQuantizePass::runOnFunction() {
  OwningRewritePatternList patterns;
  bool is_signed = quant_specs_.IsSignedInferenceType();
  if (is_signed) {
    patterns.insert<ConvertUnsignedToSigned<TFL::QuantizeOp>>(ctx);
    patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
    // Convert quant stats to int8 quantization parameters.
    // Currently, only activation stats are imported, so narrow_range = false.
    patterns.insert<PrepareQuantStats>(8, false, true, ctx);
@ -213,6 +216,8 @@ void PrepareQuantizePass::runOnFunction() {
  // values (tensors).
  ApplyQuantizationParamsPropagation(func, is_signed, disable_per_channel,
                                     GetOpQuantSpec);

  ConvertMlirQuantOpsToTFLQuantOps(func);
}

}  // namespace
@ -51,6 +51,7 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
@ -81,6 +82,7 @@ class PrepareTFPass : public FunctionPass<PrepareTFPass> {
};

// TODO(fengliuai): move this rule to PreparePatterns.td
// TODO(fengliuai): reuse the quantization/tensorflow/tf_to_quant pass.
// TODO(b/140968741): propagate the sign from the command line. Currently all
// the FakeQuant ops are assumed to be targeting UINT8, but the per-channel
// kernel is actually INT8.
@ -149,9 +151,9 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
        rewriter.getI64IntegerAttr(tf_op.num_bits().getSExtValue());
    BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range());
    Type res_type = tf_op.getType();
    TypeAttr qtype = GetQuantizedTypeAttr(rewriter, res_type, min_value,
                                          max_value, quant_dim, num_bits,
                                          narrow_range, /*is_signed=*/false);
    TypeAttr qtype = quant::GetQuantizedTypeAttr(
        rewriter, res_type, min_value, max_value, quant_dim, num_bits,
        narrow_range, /*is_signed=*/false);
    if (!qtype) return this->matchFailure();

    // Finally, use the quantization parameter to create the quantize and
@ -503,6 +505,12 @@ void PrepareTFPass::runOnFunction() {
  // first `applyPatternsGreedily` method, which would otherwise remove the
  // TF FakeQuant ops by constant folding.
  patterns.insert<PreparePerTensorFakeQuant, PreparePerChannelFakeQuant>(ctx);

  // This pattern tries to identify and optimize dilated convolutions.
  // e.g. patterns like "SpaceToBatchND -> Conv2D -> BatchToSpaceND" will be
  // replaced with a single Conv op with a dilation parameter.
  patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>,
                  ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);
  TFL::populateWithGenerated(ctx, &patterns);
  // TODO(karimnosseir): Split into a separate pass, probably after
  // deciding on a long-term plan for this optimization.
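The rationale for the dilated-convolution pattern is the standard identity that a convolution with dilation can be emulated by SpaceToBatchND, an undilated convolution, and BatchToSpaceND; TensorFlow frontends often emit that expanded form, and the pattern folds it back into one op. A rough sketch of the detection half of such a pattern (illustrative only; the real ConvertTFDilatedConvOp in dilated_conv.h also validates block_shape, paddings, and crops, and performs the actual rewrite):

#include "llvm/Support/Casting.h"
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

// Returns true if `conv_op` sits in a SpaceToBatchND -> Conv2D ->
// BatchToSpaceND sandwich, i.e. it is a dilated convolution in expanded
// form. Illustrative helper, not the actual implementation.
bool IsExpandedDilatedConv(mlir::TF::Conv2DOp conv_op) {
  mlir::Operation* producer = conv_op.input().getDefiningOp();
  if (!producer || !llvm::isa<mlir::TF::SpaceToBatchNDOp>(producer))
    return false;
  // The convolution must feed exactly one consumer: a BatchToSpaceND op.
  if (!conv_op.getResult().hasOneUse()) return false;
  return llvm::isa<mlir::TF::BatchToSpaceNDOp>(
      *conv_op.getResult().getUsers().begin());
}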
@ -65,8 +65,8 @@ namespace {

// Full integer quantization rewrite pattern for TFLite.
struct TFLFullQuantization
    : public QuantizationPattern<TFLFullQuantization, QuantizeOp, DequantizeOp,
                                 NumericVerifyOp> {
    : public quant::QuantizationPattern<TFLFullQuantization, QuantizeOp,
                                        DequantizeOp, NumericVerifyOp> {
  explicit TFLFullQuantization(MLIRContext* ctx, bool verify_numeric,
                               float tolerance, bool verify_single_layer)
      : BaseType(ctx, verify_numeric, tolerance, verify_single_layer) {}
@ -20,7 +20,7 @@ include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"

// Quantize attribute $0 by using the quantization parameter from $1.
def QuantizeByQuantizedType : NativeCodeCall<"Quantize($0, $1.getValue())">;
def QuantizeByQuantizedType : NativeCodeCall<"quant::Quantize($0, $1.getValue())">;

// Squash tfl.dequantize and tfl.quantize pairs.
// TODO(fengliuai): Compare the scale of input and output. This can also be
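For context, NativeCodeCall splices a C++ expression into the generated rewrite function, with $0, $1, ... replaced by the pattern's bound arguments; the change above simply namespace-qualifies the Quantize utility it calls. A sketch of what the substituted call looks like, under the assumption that $0 binds the float weight attribute and $1 the TypeAttr carrying the target quantized type (the names and exact return type here are illustrative):

#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"

// Sketch of the expanded NativeCodeCall: re-express `value_attr` in the
// quantized type carried by `qtype_attr`. Assumes quant::Quantize has the
// (Attribute, Type) shape implied by the pattern above.
mlir::Attribute QuantizeWeights(mlir::Attribute value_attr,
                                mlir::TypeAttr qtype_attr) {
  return quant::Quantize(value_attr, qtype_attr.getValue());
}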
@ -263,6 +263,18 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
    return this->matchSuccess();
  }

  // Input dimensions must be defined. MatMulBCast does not support partial
  // shapes.
  for (auto dim : lhs_shape) {
    if (dim == -1) {
      return this->matchFailure();
    }
  }
  for (auto dim : rhs_shape) {
    if (dim == -1) {
      return this->matchFailure();
    }
  }
  // Ensure that batch shapes are broadcastable.
  tensorflow::MatMulBCast bcast(absl::InlinedVector<tensorflow::int64, 4>(
                                    lhs_shape.begin(), lhs_shape.end()),
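The new guard simply refuses to unroll batch matmuls whose operand shapes are not fully static, since MatMulBCast cannot broadcast partial shapes. A standalone sketch of the same check, assuming shapes are carried as vectors of int64_t with -1 marking a dynamic dimension:

#include <cstdint>
#include <vector>

// Returns true only if every dimension of `shape` is statically known.
// Mirrors the loop added in the diff: a -1 (dynamic) dimension means the
// pattern must bail out with matchFailure().
bool AllDimsStatic(const std::vector<int64_t>& shape) {
  for (int64_t dim : shape) {
    if (dim == -1) return false;
  }
  return true;
}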
Some files were not shown because too many files have changed in this diff.