Update from master
This commit is contained in:
commit
7ba6a92ff9
21
.bazelrc
21
.bazelrc
@ -37,7 +37,6 @@
|
|||||||
# v2: Build TF v2
|
# v2: Build TF v2
|
||||||
#
|
#
|
||||||
# Feature and Third party library support options:
|
# Feature and Third party library support options:
|
||||||
# xla: Build TF with XLA
|
|
||||||
# using_cuda: CUDA is available to build system.
|
# using_cuda: CUDA is available to build system.
|
||||||
# cuda: Build with full cuda support.
|
# cuda: Build with full cuda support.
|
||||||
# rocm: Build with AMD GPU support (rocm).
|
# rocm: Build with AMD GPU support (rocm).
|
||||||
@ -227,6 +226,14 @@ build --noincompatible_remove_legacy_whole_archive
|
|||||||
# https://github.com/tensorflow/community/pull/179
|
# https://github.com/tensorflow/community/pull/179
|
||||||
build --noincompatible_prohibit_aapt1
|
build --noincompatible_prohibit_aapt1
|
||||||
|
|
||||||
|
# Enable XLA
|
||||||
|
build --action_env=TF_ENABLE_XLA=1
|
||||||
|
build --define=with_xla_support=true
|
||||||
|
|
||||||
|
# Keep config XLA until all build scripts are cleaned up.
|
||||||
|
build:xla --action_env=TF_ENABLE_XLA=1
|
||||||
|
build:xla --define=with_xla_support=true
|
||||||
|
|
||||||
# Modular TF build options
|
# Modular TF build options
|
||||||
build:dynamic_kernels --define=dynamic_loaded_kernels=true
|
build:dynamic_kernels --define=dynamic_loaded_kernels=true
|
||||||
build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
|
build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
|
||||||
@ -312,10 +319,6 @@ build:v2 --action_env=TF2_BEHAVIOR=1
|
|||||||
build --config=v2
|
build --config=v2
|
||||||
test --config=v2
|
test --config=v2
|
||||||
|
|
||||||
# Enable XLA
|
|
||||||
build:xla --action_env=TF_ENABLE_XLA=1
|
|
||||||
build:xla --define=with_xla_support=true
|
|
||||||
|
|
||||||
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
|
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
|
||||||
# Options when using remote execution
|
# Options when using remote execution
|
||||||
# WARNING: THESE OPTIONS WONT WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS
|
# WARNING: THESE OPTIONS WONT WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS
|
||||||
@ -348,7 +351,6 @@ build:rbe_linux --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
|
|||||||
build:rbe_linux --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
|
build:rbe_linux --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
|
||||||
|
|
||||||
# Non-rbe settings we should include because we do not run configure
|
# Non-rbe settings we should include because we do not run configure
|
||||||
build:rbe_linux --config=xla
|
|
||||||
build:rbe_linux --config=avx_linux
|
build:rbe_linux --config=avx_linux
|
||||||
build:rbe_linux --config=short_logs
|
build:rbe_linux --config=short_logs
|
||||||
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
|
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
|
||||||
@ -386,9 +388,8 @@ build:rbe_linux_py2 --python_path="/usr/bin/python2"
|
|||||||
build:rbe_linux_py2 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py"
|
build:rbe_linux_py2 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py"
|
||||||
|
|
||||||
build:rbe_linux_py3 --config=rbe_linux
|
build:rbe_linux_py3 --config=rbe_linux
|
||||||
build:rbe_linux_py3 --repo_env=PYTHON_BIN_PATH="/usr/bin/python3"
|
|
||||||
build:rbe_linux_py3 --python_path="/usr/bin/python3"
|
build:rbe_linux_py3 --python_path="/usr/bin/python3"
|
||||||
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py3"
|
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-manylinux2010-py3_config_python"
|
||||||
|
|
||||||
build:rbe_win --config=rbe
|
build:rbe_win --config=rbe
|
||||||
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:toolchain"
|
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:toolchain"
|
||||||
@ -405,9 +406,7 @@ build:rbe_win --define=override_eigen_strong_inline=true
|
|||||||
build:rbe_win --jobs=500
|
build:rbe_win --jobs=500
|
||||||
|
|
||||||
build:rbe_win_py37 --config=rbe
|
build:rbe_win_py37 --config=rbe
|
||||||
build:rbe_win_py37 --repo_env=PYTHON_BIN_PATH=C:\\Python37\\python.exe
|
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO="@windows_py37_config_python"
|
||||||
build:rbe_win_py37 --repo_env=PYTHON_LIB_PATH=C:\\Python37\\lib\\site-packages
|
|
||||||
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py37
|
|
||||||
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe
|
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe
|
||||||
|
|
||||||
build:rbe_win_py38 --config=rbe
|
build:rbe_win_py38 --config=rbe
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -22,6 +22,7 @@ tensorflow/contrib/cmake/_build/
|
|||||||
/tensorflow/python/framework/fast_tensor_util.cpp
|
/tensorflow/python/framework/fast_tensor_util.cpp
|
||||||
/tensorflow/lite/gen/**
|
/tensorflow/lite/gen/**
|
||||||
/tensorflow/lite/tools/make/downloads/**
|
/tensorflow/lite/tools/make/downloads/**
|
||||||
|
/tensorflow/lite/tools/make/gen/**
|
||||||
/api_init_files_list.txt
|
/api_init_files_list.txt
|
||||||
/estimator_api_init_files_list.txt
|
/estimator_api_init_files_list.txt
|
||||||
*.whl
|
*.whl
|
||||||
|
@ -70,7 +70,7 @@ $ python
|
|||||||
3
|
3
|
||||||
>>> hello = tf.constant('Hello, TensorFlow!')
|
>>> hello = tf.constant('Hello, TensorFlow!')
|
||||||
>>> hello.numpy()
|
>>> hello.numpy()
|
||||||
'Hello, TensorFlow!'
|
b'Hello, TensorFlow!'
|
||||||
```
|
```
|
||||||
|
|
||||||
For more examples, see the
|
For more examples, see the
|
||||||
|
@ -1390,10 +1390,6 @@ def main():
|
|||||||
else:
|
else:
|
||||||
environ_cp['TF_CONFIGURE_IOS'] = '0'
|
environ_cp['TF_CONFIGURE_IOS'] = '0'
|
||||||
|
|
||||||
xla_enabled_by_default = is_linux() or is_macos()
|
|
||||||
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
|
|
||||||
xla_enabled_by_default, 'xla')
|
|
||||||
|
|
||||||
set_action_env_var(
|
set_action_env_var(
|
||||||
environ_cp,
|
environ_cp,
|
||||||
'TF_NEED_OPENCL_SYCL',
|
'TF_NEED_OPENCL_SYCL',
|
||||||
|
@ -205,6 +205,7 @@ tf_cuda_cc_test(
|
|||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||||
|
"//tensorflow/core/platform:casts",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -874,12 +874,12 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
|||||||
#endif // !IS_MOBILE_PLATFORM
|
#endif // !IS_MOBILE_PLATFORM
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
|
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
#if defined(IS_MOBILE_PLATFORM)
|
#if defined(IS_MOBILE_PLATFORM)
|
||||||
status->status = tensorflow::Status::OK();
|
status->status = tensorflow::Status::OK();
|
||||||
#else // !defined(IS_MOBILE_PLATFORM)
|
#else // !defined(IS_MOBILE_PLATFORM)
|
||||||
status->status = ctx->context->ClearRemoteExecutors();
|
status->status = ctx->context->SyncExecutors();
|
||||||
#endif // !IS_MOBILE_PLATFORM
|
#endif // !IS_MOBILE_PLATFORM
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1450,6 +1450,25 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
|
||||||
|
const void* proto, size_t proto_len,
|
||||||
|
TF_Status* status) {
|
||||||
|
tensorflow::AttrValue attr_value;
|
||||||
|
if (!attr_value.ParseFromArray(proto, proto_len)) {
|
||||||
|
status->status =
|
||||||
|
tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (op == nullptr || op->operation == nullptr) {
|
||||||
|
status->status = tensorflow::errors::InvalidArgument(
|
||||||
|
"Got a null or uninitialized `op` argument");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||||
|
op->operation.get());
|
||||||
|
operation->MutableAttrs()->Set(attr_name, attr_value);
|
||||||
|
}
|
||||||
|
|
||||||
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
|
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
|
||||||
const char* input_name,
|
const char* input_name,
|
||||||
TF_Status* status) {
|
TF_Status* status) {
|
||||||
@ -1606,7 +1625,7 @@ void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
|
|||||||
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
|
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
|
||||||
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||||
op->operation.get());
|
op->operation.get());
|
||||||
*attrs = TFE_OpAttrs(&operation->Attrs());
|
*attrs = TFE_OpAttrs(&operation->Attrs(), op->operation->Name().c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
||||||
@ -1620,6 +1639,14 @@ void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
|
||||||
|
TF_Status* status) {
|
||||||
|
tensorflow::NameAttrList name_and_attrs;
|
||||||
|
attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr());
|
||||||
|
name_and_attrs.set_name(attrs->name);
|
||||||
|
status->status = MessageToBuffer(name_and_attrs, buf);
|
||||||
|
}
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||||
const tensorflow::AttrValue& default_value,
|
const tensorflow::AttrValue& default_value,
|
||||||
@ -1740,7 +1767,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
|||||||
}
|
}
|
||||||
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
||||||
TF_Status status;
|
TF_Status status;
|
||||||
TFE_OpAttrs attributes(&op->Attrs());
|
TFE_OpAttrs attributes(&op->Attrs(), op->Name().c_str());
|
||||||
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
|
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
|
||||||
&attributes, num_retvals, outputs.data(), &status, info_);
|
&attributes, num_retvals, outputs.data(), &status, info_);
|
||||||
if (status.status.ok()) {
|
if (status.status.ok()) {
|
||||||
|
@ -382,9 +382,11 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
|||||||
const char* worker_name,
|
const char* worker_name,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
|
|
||||||
// Clear pending streaming requests and error statuses on remote executors.
|
// Sync pending nodes in local executors (including the context default executor
|
||||||
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
|
// and thread executors) and streaming requests to remote executors, and get the
|
||||||
TF_Status* status);
|
// combined status.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||||
|
TF_Status* status);
|
||||||
|
|
||||||
// If the TensorHandle is copied to another device as part of an op execution,
|
// If the TensorHandle is copied to another device as part of an op execution,
|
||||||
// the copy is destroyed after the op has executed. Enabling implicit mirroring
|
// the copy is destroyed after the op has executed. Enabling implicit mirroring
|
||||||
@ -441,6 +443,21 @@ TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
|
|||||||
// Does not overwrite or update existing attributes, but adds new ones.
|
// Does not overwrite or update existing attributes, but adds new ones.
|
||||||
TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
|
TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
|
||||||
|
|
||||||
|
// Serialize `attrs` as a tensorflow::NameAttrList protocol buffer (into `buf`),
|
||||||
|
// containing the op name and a map of its attributes.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs,
|
||||||
|
TF_Buffer* buf,
|
||||||
|
TF_Status* status);
|
||||||
|
|
||||||
|
// Set an op's attribute from a serialized AttrValue protocol buffer.
|
||||||
|
//
|
||||||
|
// Analogous to TF_SetAttrValueProto for building graph operations.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
|
||||||
|
const char* attr_name,
|
||||||
|
const void* proto,
|
||||||
|
size_t proto_len,
|
||||||
|
TF_Status* status);
|
||||||
|
|
||||||
#define TFE_CUSTOM_DEVICE_VERSION 1
|
#define TFE_CUSTOM_DEVICE_VERSION 1
|
||||||
|
|
||||||
// Struct to be filled in
|
// Struct to be filled in
|
||||||
|
@ -236,12 +236,16 @@ struct TFE_Executor {
|
|||||||
tensorflow::EagerExecutor* unowned_executor;
|
tensorflow::EagerExecutor* unowned_executor;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
|
||||||
|
// that sometimes do not require serialization.
|
||||||
struct TFE_OpAttrs {
|
struct TFE_OpAttrs {
|
||||||
explicit TFE_OpAttrs() : attributes(nullptr) {}
|
explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
|
||||||
|
|
||||||
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value)
|
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
|
||||||
: attributes(value) {}
|
const char* op_name)
|
||||||
|
: name(op_name), attributes(value) {}
|
||||||
|
|
||||||
|
const char* name;
|
||||||
const tensorflow::AttrBuilder* attributes;
|
const tensorflow::AttrBuilder* attributes;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||||
@ -127,7 +128,7 @@ void TestRemoteExecute(bool async) {
|
|||||||
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
|
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
|
||||||
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
|
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
|
||||||
|
|
||||||
void TestRemoteExecuteSilentCopies(bool async) {
|
void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||||
|
|
||||||
// This server def has the task index set to 0.
|
// This server def has the task index set to 0.
|
||||||
@ -166,10 +167,14 @@ void TestRemoteExecuteSilentCopies(bool async) {
|
|||||||
auto* h1_task2 =
|
auto* h1_task2 =
|
||||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
|
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_TensorHandleEnableImplicitMirroring(h1_task2, status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
// Handles are on task0 (local), and task2, but op is on task1.
|
// Handles are on task0 (local), and task2, but op is on task1.
|
||||||
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
|
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
|
||||||
TFE_OpSetDevice(matmul, task1_name, status);
|
if (remote) {
|
||||||
|
TFE_OpSetDevice(matmul, task1_name, status);
|
||||||
|
}
|
||||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
TFE_TensorHandle* retvals[1];
|
TFE_TensorHandle* retvals[1];
|
||||||
@ -177,6 +182,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
|
|||||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
// TODO(gjn): Add support for waiting on async local mirrors
|
||||||
|
if (!async) {
|
||||||
|
auto remote_arg = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||||
|
h1_task2->handle.get())
|
||||||
|
->Handle();
|
||||||
|
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||||
|
matmul->operation.get());
|
||||||
|
// The input handles should never change since they have been mirrored.
|
||||||
|
ASSERT_EQ(op->GetInput(1), remote_arg);
|
||||||
|
}
|
||||||
|
|
||||||
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
|
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
|
||||||
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
|
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
@ -213,9 +229,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
|
|||||||
worker_server2.release();
|
worker_server2.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }
|
TEST(CAPI, RemoteExecuteSilentCopies) {
|
||||||
|
TestRemoteExecuteSilentCopies(false, true);
|
||||||
|
}
|
||||||
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
|
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
|
||||||
TestRemoteExecuteSilentCopies(true);
|
TestRemoteExecuteSilentCopies(true, true);
|
||||||
|
}
|
||||||
|
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
|
||||||
|
TestRemoteExecuteSilentCopies(false, false);
|
||||||
|
}
|
||||||
|
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
|
||||||
|
TestRemoteExecuteSilentCopies(true, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||||
|
@ -416,12 +416,23 @@ void TensorHandleSilentCopy(bool async,
|
|||||||
hgpu->handle.get())
|
hgpu->handle.get())
|
||||||
->Handle();
|
->Handle();
|
||||||
|
|
||||||
// The input handles should never change since they have been mirrored.
|
|
||||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||||
matmul->operation.get());
|
matmul->operation.get());
|
||||||
ASSERT_EQ(op->GetInput(0), arg0);
|
if (!async) {
|
||||||
ASSERT_EQ(op->GetInput(1), arg1);
|
// The input handles should never change since they have been mirrored.
|
||||||
|
ASSERT_EQ(op->GetInput(0), arg0);
|
||||||
|
ASSERT_EQ(op->GetInput(1), arg1);
|
||||||
|
} else {
|
||||||
|
if (cpu_op) {
|
||||||
|
ASSERT_EQ(op->GetInput(0), arg0);
|
||||||
|
// The GPU handle should be replaced with a CPU copy
|
||||||
|
ASSERT_NE(op->GetInput(1), arg1);
|
||||||
|
} else {
|
||||||
|
// The CPU handle should be replaced with a GPU copy
|
||||||
|
ASSERT_NE(op->GetInput(0), arg0);
|
||||||
|
ASSERT_EQ(op->GetInput(1), arg1);
|
||||||
|
}
|
||||||
|
}
|
||||||
TFE_DeleteOp(matmul);
|
TFE_DeleteOp(matmul);
|
||||||
TFE_DeleteTensorHandle(retvals[0]);
|
TFE_DeleteTensorHandle(retvals[0]);
|
||||||
TFE_DeleteTensorHandle(hgpu);
|
TFE_DeleteTensorHandle(hgpu);
|
||||||
@ -1578,4 +1589,52 @@ TEST(CAPI, TestTFE_OpGetAttrs) {
|
|||||||
TFE_DeleteContext(ctx);
|
TFE_DeleteContext(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
|
||||||
|
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||||
|
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||||
|
TFE_OpAttrs attributes;
|
||||||
|
TFE_OpGetAttrs(var_op, &attributes);
|
||||||
|
|
||||||
|
TF_Buffer* serialized_attr_values = TF_NewBuffer();
|
||||||
|
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
tensorflow::NameAttrList name_and_attrs;
|
||||||
|
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
|
||||||
|
serialized_attr_values->length));
|
||||||
|
ASSERT_EQ("VarHandleOp", name_and_attrs.name());
|
||||||
|
ASSERT_EQ(tensorflow::DT_INT64,
|
||||||
|
name_and_attrs.attr().find("dtype")->second.type());
|
||||||
|
TF_DeleteBuffer(serialized_attr_values);
|
||||||
|
|
||||||
|
TFE_Op* second_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||||
|
|
||||||
|
string serialized_dtype;
|
||||||
|
ASSERT_TRUE(name_and_attrs.attr().find("dtype")->second.SerializeToString(
|
||||||
|
&serialized_dtype));
|
||||||
|
TFE_OpSetAttrValueProto(
|
||||||
|
second_var_op, "dtype",
|
||||||
|
reinterpret_cast<const void*>(serialized_dtype.c_str()),
|
||||||
|
serialized_dtype.length(), status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
tensorflow::AttrValueMap attr_values;
|
||||||
|
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||||
|
second_var_op->operation.get());
|
||||||
|
op->Attrs().FillAttrValueMap(&attr_values);
|
||||||
|
EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());
|
||||||
|
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
TFE_DeleteOp(var_op);
|
||||||
|
TFE_DeleteOp(second_var_op);
|
||||||
|
TFE_DeleteContext(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -68,6 +68,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
|
"//tensorflow/core/platform:resource_loader",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -21,15 +21,22 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/platform/path.h"
|
||||||
|
#include "tensorflow/core/platform/resource_loader.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr char kTestDataPbTxt[] =
|
string TestDataPbTxt() {
|
||||||
"cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
|
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
|
||||||
constexpr char kTestDataSharded[] =
|
"half_plus_two_pbtxt", "00000123");
|
||||||
"cc/saved_model/testdata/half_plus_two/00000123";
|
}
|
||||||
|
|
||||||
|
string TestDataSharded() {
|
||||||
|
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
|
||||||
|
"half_plus_two", "00000123");
|
||||||
|
}
|
||||||
|
|
||||||
class ReaderTest : public ::testing::Test {
|
class ReaderTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
@ -49,8 +56,7 @@ class ReaderTest : public ::testing::Test {
|
|||||||
TEST_F(ReaderTest, TagMatch) {
|
TEST_F(ReaderTest, TagMatch) {
|
||||||
MetaGraphDef meta_graph_def;
|
MetaGraphDef meta_graph_def;
|
||||||
|
|
||||||
const string export_dir =
|
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
|
||||||
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||||
&meta_graph_def));
|
&meta_graph_def));
|
||||||
CheckMetaGraphDef(meta_graph_def);
|
CheckMetaGraphDef(meta_graph_def);
|
||||||
@ -59,8 +65,7 @@ TEST_F(ReaderTest, TagMatch) {
|
|||||||
TEST_F(ReaderTest, NoTagMatch) {
|
TEST_F(ReaderTest, NoTagMatch) {
|
||||||
MetaGraphDef meta_graph_def;
|
MetaGraphDef meta_graph_def;
|
||||||
|
|
||||||
const string export_dir =
|
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
|
||||||
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
|
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
|
||||||
&meta_graph_def);
|
&meta_graph_def);
|
||||||
EXPECT_FALSE(st.ok());
|
EXPECT_FALSE(st.ok());
|
||||||
@ -73,8 +78,7 @@ TEST_F(ReaderTest, NoTagMatch) {
|
|||||||
TEST_F(ReaderTest, NoTagMatchMultiple) {
|
TEST_F(ReaderTest, NoTagMatchMultiple) {
|
||||||
MetaGraphDef meta_graph_def;
|
MetaGraphDef meta_graph_def;
|
||||||
|
|
||||||
const string export_dir =
|
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
|
||||||
Status st = ReadMetaGraphDefFromSavedModel(
|
Status st = ReadMetaGraphDefFromSavedModel(
|
||||||
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
|
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
|
||||||
EXPECT_FALSE(st.ok());
|
EXPECT_FALSE(st.ok());
|
||||||
@ -87,8 +91,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
|
|||||||
TEST_F(ReaderTest, PbtxtFormat) {
|
TEST_F(ReaderTest, PbtxtFormat) {
|
||||||
MetaGraphDef meta_graph_def;
|
MetaGraphDef meta_graph_def;
|
||||||
|
|
||||||
const string export_dir =
|
const string export_dir = GetDataDependencyFilepath(TestDataPbTxt());
|
||||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
|
|
||||||
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||||
&meta_graph_def));
|
&meta_graph_def));
|
||||||
CheckMetaGraphDef(meta_graph_def);
|
CheckMetaGraphDef(meta_graph_def);
|
||||||
@ -97,8 +100,7 @@ TEST_F(ReaderTest, PbtxtFormat) {
|
|||||||
TEST_F(ReaderTest, InvalidExportPath) {
|
TEST_F(ReaderTest, InvalidExportPath) {
|
||||||
MetaGraphDef meta_graph_def;
|
MetaGraphDef meta_graph_def;
|
||||||
|
|
||||||
const string export_dir =
|
const string export_dir = GetDataDependencyFilepath("missing-path");
|
||||||
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
|
|
||||||
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||||
&meta_graph_def);
|
&meta_graph_def);
|
||||||
EXPECT_FALSE(st.ok());
|
EXPECT_FALSE(st.ok());
|
||||||
|
3
tensorflow/compiler/mlir/g3doc/README.md
Normal file
3
tensorflow/compiler/mlir/g3doc/README.md
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# TensorFlow MLIR
|
||||||
|
|
||||||
|
These are the docs for: https://www.tensorflow.org/mlir
|
26
tensorflow/compiler/mlir/g3doc/_book.yaml
Normal file
26
tensorflow/compiler/mlir/g3doc/_book.yaml
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
upper_tabs:
|
||||||
|
# Tabs left of dropdown menu
|
||||||
|
- include: /_upper_tabs_left.yaml
|
||||||
|
- include: /api_docs/_upper_tabs_api.yaml
|
||||||
|
# Dropdown menu
|
||||||
|
- name: Resources
|
||||||
|
path: /resources
|
||||||
|
is_default: true
|
||||||
|
menu:
|
||||||
|
- include: /resources/_menu_toc.yaml
|
||||||
|
lower_tabs:
|
||||||
|
# Subsite tabs
|
||||||
|
other:
|
||||||
|
- name: Guide
|
||||||
|
contents:
|
||||||
|
- title: Overview
|
||||||
|
path: /mlir/overview
|
||||||
|
- heading: Dialects
|
||||||
|
- title: Overview
|
||||||
|
path: /mlir/dialects
|
||||||
|
- title: TensorFlow
|
||||||
|
path: /mlir/tf_ops
|
||||||
|
- title: TensorFlow Lite
|
||||||
|
path: /mlir/tfl_ops
|
||||||
|
|
||||||
|
- include: /_upper_tabs_right.yaml
|
54
tensorflow/compiler/mlir/g3doc/_index.yaml
Normal file
54
tensorflow/compiler/mlir/g3doc/_index.yaml
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
book_path: /mlir/_book.yaml
|
||||||
|
project_path: /mlir/_project.yaml
|
||||||
|
description: <!--no description-->
|
||||||
|
landing_page:
|
||||||
|
custom_css_path: /site-assets/css/style.css
|
||||||
|
rows:
|
||||||
|
- heading: MLIR unifies the infrastructure for high-performance ML models in TensorFlow.
|
||||||
|
items:
|
||||||
|
- description: >
|
||||||
|
The <a href="https://mlir.llvm.org/" class="external">MLIR</a> project defines a common
|
||||||
|
intermediate representation (IR) that unifies the infrastructure required to execute high
|
||||||
|
performance machine learning models in TensorFlow and similar ML frameworks. This project
|
||||||
|
will include the application of HPC techniques, along with integration of
|
||||||
|
search algorithms like reinforcement learning. MLIR aims to reduce the
|
||||||
|
cost to bring up new hardware, and improve usability for existing
|
||||||
|
TensorFlow users.
|
||||||
|
|
||||||
|
- code_block: |
|
||||||
|
<pre class = "prettyprint">
|
||||||
|
// Syntactically similar to LLVM:
|
||||||
|
func @testFunction(%arg0: i32) {
|
||||||
|
%x = call @thingToCall(%arg0) : (i32) -> i32
|
||||||
|
br ^bb1
|
||||||
|
^bb1:
|
||||||
|
%y = addi %x, %x : i32
|
||||||
|
return %y : i32
|
||||||
|
}
|
||||||
|
</pre>
|
||||||
|
|
||||||
|
- classname: devsite-landing-row-cards
|
||||||
|
items:
|
||||||
|
- heading: "Multi-Level Intermediate Representation for Compiler Infrastructure"
|
||||||
|
youtube_id: qzljG6DKgic
|
||||||
|
buttons:
|
||||||
|
- label: Watch the video
|
||||||
|
path: https://www.youtube.com/watch?v=qzljG6DKgic
|
||||||
|
- heading: "A new intermediate representation and compiler framework"
|
||||||
|
image_path: /resources/images/tf-logo-card-16x9.png
|
||||||
|
path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html
|
||||||
|
buttons:
|
||||||
|
- label: Read on TensorFlow blog
|
||||||
|
path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html
|
||||||
|
- heading: MLIR on GitHub
|
||||||
|
image_path: /resources/images/github-card-16x9.png
|
||||||
|
path: https://github.com/llvm/llvm-project/tree/master/mlir
|
||||||
|
buttons:
|
||||||
|
- label: View on GitHub
|
||||||
|
path: https://github.com/llvm/llvm-project/tree/master/mlir
|
||||||
|
- heading: TensorFlow MLIR on GitHub
|
||||||
|
image_path: /resources/images/github-card-16x9.png
|
||||||
|
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir
|
||||||
|
buttons:
|
||||||
|
- label: View on GitHub
|
||||||
|
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir
|
37
tensorflow/compiler/mlir/g3doc/dialects.md
Normal file
37
tensorflow/compiler/mlir/g3doc/dialects.md
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
# MLIR dialects
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
|
||||||
|
To separate different hardware and software targets, MLIR has “dialects”,
|
||||||
|
including:
|
||||||
|
|
||||||
|
* TensorFlow IR, which represents all things possible in TensorFlow graphs.
|
||||||
|
* XLA HLO IR, which is designed to take advantage of XLA’s compilation
|
||||||
|
abilities (with output to, among other things, TPUs).
|
||||||
|
* An experimental affine dialect, which focuses on
|
||||||
|
[polyhedral representations](https://en.wikipedia.org/wiki/Polytope_model)
|
||||||
|
and optimizations.
|
||||||
|
* LLVM IR, which has a 1:1 mapping between it and LLVM’s own representation,
|
||||||
|
allowing MLIR to emit GPU and CPU code through LLVM.
|
||||||
|
* TensorFlow Lite, which will translate to running code on mobile platforms.
|
||||||
|
|
||||||
|
Each dialect consists of a set of defined operations which have invariants
|
||||||
|
placed on them, like: “This is a binary operator, and the inputs and outputs
|
||||||
|
have the same types.”
|
||||||
|
|
||||||
|
## Adding to MLIR
|
||||||
|
|
||||||
|
MLIR has no fixed/built-in list of globally known operations (no “intrinsics”).
|
||||||
|
Dialects can define entirely custom types, which is how MLIR can model things
|
||||||
|
like the LLVM IR type system (which has first class aggregates), domain
|
||||||
|
abstractions important for ML-optimized accelerators like quantized types, and
|
||||||
|
even the Swift or Clang type systems (which are built around Swift/Clang
|
||||||
|
declaration nodes) in the future.
|
||||||
|
|
||||||
|
If you want to connect a new low-level compiler, you would create a new dialect
|
||||||
|
and the lowerings between the TensorFlow Graph dialect and your dialect.
|
||||||
|
This smooths the path for hardware and compiler makers. You can even target
|
||||||
|
dialects at different levels in the same model; the higher-level optimizers
|
||||||
|
will respect the unfamiliar parts of the IR and wait for a lower level to handle
|
||||||
|
them.
|
1
tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg
Normal file
1
tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg
Normal file
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 148 KiB |
36
tensorflow/compiler/mlir/g3doc/overview.md
Normal file
36
tensorflow/compiler/mlir/g3doc/overview.md
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
# MLIR
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
MLIR, or Multi-Level Intermediate Representation, is a representation format
|
||||||
|
and library of compiler utilities that sits between the model representation
|
||||||
|
and low-level compilers/executors that generate hardware-specific code.
|
||||||
|
|
||||||
|
MLIR is, at its heart, a flexible infrastructure for modern optimizing
|
||||||
|
compilers. This means it consists of a specification for intermediate
|
||||||
|
representations (IR) and a code toolkit to perform transformations on that
|
||||||
|
representation. (In compiler parlance, as you move from higher-level
|
||||||
|
representations to lower-level representations, these transformations can be
|
||||||
|
called “lowerings”.)
|
||||||
|
|
||||||
|
MLIR is highly influenced by [LLVM](https://llvm.org/) and unabashedly reuses
|
||||||
|
many great ideas from it. It has a flexible type system, and allows
|
||||||
|
representing, analyzing and transforming graphs combining multiple levels of
|
||||||
|
abstraction in the same compilation unit. These abstractions include TensorFlow
|
||||||
|
operations, nested polyhedral loop regions, and even LLVM instructions and fixed
|
||||||
|
hardware operations and types.
|
||||||
|
|
||||||
|
We expect MLIR to be of interest to many groups, including:
|
||||||
|
|
||||||
|
* Compiler researchers and implementers looking to optimize performance and
|
||||||
|
memory consumption of machine learning models
|
||||||
|
* Hardware makers looking for a way to connect their hardware to TensorFlow,
|
||||||
|
such as TPUs, portable neural hardware in phones, and other custom ASICs
|
||||||
|
* People writing language bindings who want to take advantage of optimizing
|
||||||
|
compilers and hardware acceleration.
|
||||||
|
|
||||||
|
The TensorFlow ecosystem contains a number of compilers and optimizers that
|
||||||
|
operate at multiple levels of the software and hardware stack. We expect the
|
||||||
|
gradual adoption of MLIR to simplify every aspect of this stack.
|
||||||
|
|
||||||
|
<img alt="MLIR overview diagram" src="./images/mlir-infra.svg"/>
|
@ -602,6 +602,7 @@ tf_cc_binary(
|
|||||||
name = "flatbuffer_translate",
|
name = "flatbuffer_translate",
|
||||||
deps = [
|
deps = [
|
||||||
":flatbuffer_translate_lib",
|
":flatbuffer_translate_lib",
|
||||||
|
"@llvm-project//mlir:LoopOpsTransforms",
|
||||||
"@llvm-project//mlir:MlirTranslateMain",
|
"@llvm-project//mlir:MlirTranslateMain",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -46,7 +46,7 @@ limitations under the License.
|
|||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
|
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
|
||||||
@ -76,6 +76,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/lite/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
|
||||||
@ -124,6 +125,20 @@ static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
|
|||||||
llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
|
llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
|
||||||
llvm::cl::init(false));
|
llvm::cl::init(false));
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE
|
||||||
|
static opt<std::string> input_arrays_flag(
|
||||||
|
"input-arrays",
|
||||||
|
llvm::cl::desc(
|
||||||
|
"List of input tensors, if different from the default inputs"),
|
||||||
|
llvm::cl::init(""));
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE
|
||||||
|
static opt<std::string> output_arrays_flag(
|
||||||
|
"output-arrays",
|
||||||
|
llvm::cl::desc(
|
||||||
|
"List of output tensors, if different from the default outputs"),
|
||||||
|
llvm::cl::init(""));
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
bool IsScalar(const TensorT& tensor) {
|
bool IsScalar(const TensorT& tensor) {
|
||||||
// TODO(b/138222071) We can't distinguish scalars and unranked tensors
|
// TODO(b/138222071) We can't distinguish scalars and unranked tensors
|
||||||
@ -590,6 +605,11 @@ StatusOr<Operation*> ConvertOp(
|
|||||||
op_state.addTypes({type});
|
op_state.addTypes({type});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (op_name == "tfl.lstm") {
|
||||||
|
// TODO(b/147587779): add the right region if region is empty.
|
||||||
|
op_state.addRegion();
|
||||||
|
}
|
||||||
|
|
||||||
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
|
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
|
||||||
if (IsCustomOp(op_name)) {
|
if (IsCustomOp(op_name)) {
|
||||||
auto status = mlir::CustomOptionsToAttributes(op_name, op.custom_options,
|
auto status = mlir::CustomOptionsToAttributes(op_name, op.custom_options,
|
||||||
@ -610,43 +630,30 @@ StatusOr<Operation*> ConvertOp(
|
|||||||
return builder.createOperation(op_state);
|
return builder.createOperation(op_state);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the output tensor indices for the given subgraph. If
|
// Returns indices of the given tensors in the subgraph. Returns error if a
|
||||||
// ordered_output_arrays is provided, then return the tensor indices in
|
// tensor name cannot be found in the subgraph.
|
||||||
// ordered_output_arrays.
|
StatusOr<std::vector<int>> GetTensorIndices(
|
||||||
StatusOr<llvm::SmallVector<int32_t, 4>> GetOutputTensorIndices(
|
const tflite::SubGraphT& subgraph,
|
||||||
const tflite::SubGraphT& subgraph, Location base_loc,
|
const std::vector<std::string>& tensor_names) {
|
||||||
const std::vector<std::string>& ordered_output_arrays) {
|
absl::flat_hash_map<std::string, int> name_to_index;
|
||||||
if (ordered_output_arrays.empty()) {
|
for (auto index_and_tensor : llvm::enumerate(subgraph.tensors)) {
|
||||||
return llvm::SmallVector<int32_t, 4>(subgraph.outputs.begin(),
|
name_to_index[index_and_tensor.value()->name] = index_and_tensor.index();
|
||||||
subgraph.outputs.end());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallVector<int32_t, 4> outputs;
|
std::vector<int> indices;
|
||||||
outputs.resize(ordered_output_arrays.size());
|
indices.reserve(tensor_names.size());
|
||||||
absl::flat_hash_map<std::string, int> output_order_map;
|
|
||||||
for (auto output : llvm::enumerate(ordered_output_arrays)) {
|
|
||||||
output_order_map[output.value()] = output.index();
|
|
||||||
}
|
|
||||||
|
|
||||||
int tensor_index = 0;
|
for (const auto& name : tensor_names) {
|
||||||
int found_output_tensors = 0;
|
auto found = name_to_index.find(name);
|
||||||
for (const auto& tensor : subgraph.tensors) {
|
if (found != name_to_index.end()) {
|
||||||
auto found = output_order_map.find(tensor->name);
|
indices.push_back(found->second);
|
||||||
if (found != output_order_map.end()) {
|
} else {
|
||||||
const int output_index = found->second;
|
return errors::InvalidArgument("could not find tensor in subgraph: ",
|
||||||
outputs[output_index] = tensor_index;
|
name);
|
||||||
++found_output_tensors;
|
|
||||||
}
|
}
|
||||||
++tensor_index;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (found_output_tensors != ordered_output_arrays.size()) {
|
return indices;
|
||||||
auto err = errors::InvalidArgument(
|
|
||||||
"cannot find all nodes in ordered_output_arrays");
|
|
||||||
return emitError(base_loc, err.ToString()), err;
|
|
||||||
}
|
|
||||||
|
|
||||||
return outputs;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Given a list of tensor indices, returns a string of concatenated tensor names
|
// Given a list of tensor indices, returns a string of concatenated tensor names
|
||||||
@ -661,15 +668,18 @@ mlir::NamedAttribute BuildTFEntryFunctionAttribute(
|
|||||||
name, builder->getStringAttr(llvm::join(tensor_names, ",")));
|
name, builder->getStringAttr(llvm::join(tensor_names, ",")));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Given a list of output indices, traverses the subgraph and returns the set of
|
// Traverses the subgraph from output_indices to input_indices and returns the
|
||||||
// ops that are ancestors of the output tensors.
|
// set of ops that are visited.
|
||||||
StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
|
StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
|
||||||
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> output_indices) {
|
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> input_indices,
|
||||||
|
ArrayRef<int32_t> output_indices) {
|
||||||
// Create a map from tensor index to defining op.
|
// Create a map from tensor index to defining op.
|
||||||
absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
|
absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
|
||||||
for (const auto& op : subgraph.operators) {
|
for (const auto& op : subgraph.operators) {
|
||||||
for (int32_t output : op->outputs) {
|
for (int32_t output : op->outputs) {
|
||||||
defining_op[output] = op.get();
|
if (!llvm::is_contained(input_indices, output)) {
|
||||||
|
defining_op[output] = op.get();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -718,18 +728,40 @@ StatusOr<FuncOp> ConvertSubgraph(
|
|||||||
const std::vector<std::string>& op_names,
|
const std::vector<std::string>& op_names,
|
||||||
const std::vector<std::string>& func_names,
|
const std::vector<std::string>& func_names,
|
||||||
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
|
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
|
||||||
Location base_loc, Builder builder,
|
Location base_loc, Builder builder, bool is_entry_point,
|
||||||
const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
|
|
||||||
bool use_external_constant,
|
bool use_external_constant,
|
||||||
|
const std::vector<std::string>& ordered_input_arrays,
|
||||||
|
const std::vector<std::string>& ordered_output_arrays,
|
||||||
bool experimental_prune_unreachable_nodes_unconditionally) {
|
bool experimental_prune_unreachable_nodes_unconditionally) {
|
||||||
llvm::SmallVector<mlir::Type, 2> ret_types;
|
llvm::SmallVector<mlir::Type, 2> ret_types;
|
||||||
llvm::SmallVector<mlir::Type, 4> input_types;
|
llvm::SmallVector<mlir::Type, 4> input_types;
|
||||||
|
|
||||||
auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc);
|
auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc);
|
||||||
|
|
||||||
// Construct function type
|
std::vector<int> func_inputs = subgraph.inputs;
|
||||||
for (auto input : subgraph.inputs) {
|
if (is_entry_point && !ordered_input_arrays.empty()) {
|
||||||
auto& tensor = *subgraph.tensors.at(input);
|
if (!experimental_prune_unreachable_nodes_unconditionally) {
|
||||||
|
// TODO(b/149922113): Resolve input-arrays/pruning flags interaction.
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"input-arrays should be used with experimental pruning flag");
|
||||||
|
}
|
||||||
|
TF_ASSIGN_OR_RETURN(func_inputs,
|
||||||
|
GetTensorIndices(subgraph, ordered_input_arrays));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add state variables to inputs.
|
||||||
|
absl::flat_hash_set<int32_t> input_index_set(func_inputs.begin(),
|
||||||
|
func_inputs.end());
|
||||||
|
for (int i = 0; i < subgraph.tensors.size(); i++) {
|
||||||
|
auto& tensor = *subgraph.tensors.at(i);
|
||||||
|
if (tensor.is_variable && !input_index_set.contains(i)) {
|
||||||
|
func_inputs.emplace_back(i);
|
||||||
|
input_index_set.insert(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto input_or_variable : func_inputs) {
|
||||||
|
auto& tensor = *subgraph.tensors.at(input_or_variable);
|
||||||
// TODO(b/138222071) Graph inputs must have static shape per the exporter,
|
// TODO(b/138222071) Graph inputs must have static shape per the exporter,
|
||||||
// but we cannot differentiate scalars from unranked tensors.
|
// but we cannot differentiate scalars from unranked tensors.
|
||||||
// Here we reverse the default assumption that shape = [] means unranked.
|
// Here we reverse the default assumption that shape = [] means unranked.
|
||||||
@ -753,9 +785,11 @@ StatusOr<FuncOp> ConvertSubgraph(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
std::vector<int> func_outputs = subgraph.outputs;
|
||||||
auto func_outputs,
|
if (is_entry_point && !ordered_output_arrays.empty()) {
|
||||||
GetOutputTensorIndices(subgraph, base_loc, ordered_output_arrays));
|
TF_ASSIGN_OR_RETURN(func_outputs,
|
||||||
|
GetTensorIndices(subgraph, ordered_output_arrays));
|
||||||
|
}
|
||||||
|
|
||||||
for (auto output : func_outputs) {
|
for (auto output : func_outputs) {
|
||||||
bool is_constant = !is_op_output[output];
|
bool is_constant = !is_op_output[output];
|
||||||
@ -782,8 +816,8 @@ StatusOr<FuncOp> ConvertSubgraph(
|
|||||||
Value maybe_optional_arg_marker = nullptr;
|
Value maybe_optional_arg_marker = nullptr;
|
||||||
|
|
||||||
// Get or construct MLIR values for each input
|
// Get or construct MLIR values for each input
|
||||||
for (int i = 0, e = subgraph.inputs.size(); i < e; i++) {
|
for (int i = 0, e = func_inputs.size(); i < e; i++) {
|
||||||
auto input_tensor = subgraph.inputs[i];
|
auto input_tensor = func_inputs[i];
|
||||||
const auto& tensor = *subgraph.tensors.at(input_tensor);
|
const auto& tensor = *subgraph.tensors.at(input_tensor);
|
||||||
auto loc = TensorLoc(tensor, builder, base_loc);
|
auto loc = TensorLoc(tensor, builder, base_loc);
|
||||||
if (vals_map[input_tensor]) {
|
if (vals_map[input_tensor]) {
|
||||||
@ -806,9 +840,9 @@ StatusOr<FuncOp> ConvertSubgraph(
|
|||||||
// Set tf.entry_function attribute
|
// Set tf.entry_function attribute
|
||||||
if (is_entry_point) {
|
if (is_entry_point) {
|
||||||
llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
|
llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
|
||||||
if (!subgraph.inputs.empty()) {
|
if (!func_inputs.empty()) {
|
||||||
attributes.push_back(BuildTFEntryFunctionAttribute(
|
attributes.push_back(BuildTFEntryFunctionAttribute(
|
||||||
subgraph, &builder, "inputs", subgraph.inputs));
|
subgraph, &builder, "inputs", func_inputs));
|
||||||
}
|
}
|
||||||
if (!func_outputs.empty()) {
|
if (!func_outputs.empty()) {
|
||||||
attributes.push_back(BuildTFEntryFunctionAttribute(
|
attributes.push_back(BuildTFEntryFunctionAttribute(
|
||||||
@ -820,7 +854,7 @@ StatusOr<FuncOp> ConvertSubgraph(
|
|||||||
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
|
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
|
||||||
if (experimental_prune_unreachable_nodes_unconditionally) {
|
if (experimental_prune_unreachable_nodes_unconditionally) {
|
||||||
TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
|
TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
|
||||||
PruneSubgraph(subgraph, func_outputs));
|
PruneSubgraph(subgraph, func_inputs, func_outputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Construct MLIR operators from TFLite operators
|
// Construct MLIR operators from TFLite operators
|
||||||
@ -931,8 +965,9 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
|
|||||||
|
|
||||||
OwningModuleRef tflite::FlatBufferToMlir(
|
OwningModuleRef tflite::FlatBufferToMlir(
|
||||||
absl::string_view buffer, MLIRContext* context, Location base_loc,
|
absl::string_view buffer, MLIRContext* context, Location base_loc,
|
||||||
const std::vector<std::string>& ordered_output_arrays,
|
|
||||||
bool use_external_constant,
|
bool use_external_constant,
|
||||||
|
const std::vector<std::string>& ordered_input_arrays,
|
||||||
|
const std::vector<std::string>& ordered_output_arrays,
|
||||||
bool experimental_prune_unreachable_nodes_unconditionally) {
|
bool experimental_prune_unreachable_nodes_unconditionally) {
|
||||||
auto model_ptr =
|
auto model_ptr =
|
||||||
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
|
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
|
||||||
@ -971,33 +1006,25 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
|||||||
builder.getStringAttr(model->description));
|
builder.getStringAttr(model->description));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!ordered_output_arrays.empty() && model->subgraphs.size() > 1) {
|
|
||||||
// TODO(b/141485522): support more than one subgraph.
|
|
||||||
return emitError(base_loc,
|
|
||||||
"ordered_output_arrays does not support more than one "
|
|
||||||
"subgraph yet"),
|
|
||||||
nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto e : llvm::enumerate(model->subgraphs)) {
|
for (auto e : llvm::enumerate(model->subgraphs)) {
|
||||||
auto& subgraph = e.value();
|
auto& subgraph = e.value();
|
||||||
std::string name = SubgraphName(e.index(), *subgraph);
|
std::string name = SubgraphName(e.index(), *subgraph);
|
||||||
auto func_or_error = ConvertSubgraph(
|
auto func_or_error = ConvertSubgraph(
|
||||||
*subgraph, name, operator_names, func_names, model->buffers, base_loc,
|
*subgraph, name, operator_names, func_names, model->buffers, base_loc,
|
||||||
// Only the entry point needs pseudo_input_ops
|
builder,
|
||||||
// TODO(b/131175224,b/132239787) Support multiple entry points
|
// TODO(b/131175224,b/132239787) Support multiple entry points
|
||||||
builder, ordered_output_arrays,
|
|
||||||
/*is_entry_point=*/e.index() == 0,
|
/*is_entry_point=*/e.index() == 0,
|
||||||
/*use_external_constant=*/use_external_constant,
|
/*use_external_constant=*/use_external_constant, ordered_input_arrays,
|
||||||
|
ordered_output_arrays,
|
||||||
experimental_prune_unreachable_nodes_unconditionally);
|
experimental_prune_unreachable_nodes_unconditionally);
|
||||||
if (!func_or_error.ok()) {
|
if (!func_or_error.ok()) {
|
||||||
return emitError(base_loc, "could not translate function ")
|
return emitError(base_loc, "could not translate function ")
|
||||||
<< subgraph->name,
|
<< subgraph->name << ": "
|
||||||
|
<< func_or_error.status().error_message(),
|
||||||
nullptr;
|
nullptr;
|
||||||
}
|
}
|
||||||
module.push_back(func_or_error.ConsumeValueOrDie());
|
module.push_back(func_or_error.ConsumeValueOrDie());
|
||||||
}
|
}
|
||||||
// TFLite subgraphs do not necessarily have names,
|
|
||||||
|
|
||||||
return OwningModuleRef(module);
|
return OwningModuleRef(module);
|
||||||
}
|
}
|
||||||
@ -1012,17 +1039,24 @@ static OwningModuleRef FlatBufferFileToMlirTrans(
|
|||||||
auto loc =
|
auto loc =
|
||||||
mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
|
mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
|
||||||
|
|
||||||
// Parses output_arrays_order from command line option.
|
// Parses input/output names from command line options.
|
||||||
|
std::vector<std::string> inputs;
|
||||||
std::vector<std::string> outputs;
|
std::vector<std::string> outputs;
|
||||||
if (!tensorflow::ParseOutputArrayInfo(output_arrays_string, &outputs).ok()) {
|
// Use output parser since we only have tensor names.
|
||||||
|
if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) {
|
||||||
|
return emitError(loc, "parsing input array info failed ")
|
||||||
|
<< input_arrays_flag,
|
||||||
|
nullptr;
|
||||||
|
}
|
||||||
|
if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) {
|
||||||
return emitError(loc, "parsing output array info failed ")
|
return emitError(loc, "parsing output array info failed ")
|
||||||
<< output_arrays_string,
|
<< output_arrays_flag,
|
||||||
nullptr;
|
nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
return tflite::FlatBufferToMlir(
|
return tflite::FlatBufferToMlir(
|
||||||
absl::string_view(input->getBufferStart(), input->getBufferSize()),
|
absl::string_view(input->getBufferStart(), input->getBufferSize()),
|
||||||
context, loc, outputs, use_external_constant,
|
context, loc, use_external_constant, inputs, outputs,
|
||||||
experimental_prune_unreachable_nodes_unconditionally);
|
experimental_prune_unreachable_nodes_unconditionally);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,9 +35,9 @@ namespace tflite {
|
|||||||
// are not ancestors of the output nodes will be pruned.
|
// are not ancestors of the output nodes will be pruned.
|
||||||
mlir::OwningModuleRef FlatBufferToMlir(
|
mlir::OwningModuleRef FlatBufferToMlir(
|
||||||
absl::string_view buffer, mlir::MLIRContext* context,
|
absl::string_view buffer, mlir::MLIRContext* context,
|
||||||
mlir::Location base_loc,
|
mlir::Location base_loc, bool use_external_constant = false,
|
||||||
const std::vector<std::string>& ordered_output_arrays,
|
const std::vector<std::string>& ordered_input_arrays = {},
|
||||||
bool use_external_constant = false,
|
const std::vector<std::string>& ordered_output_arrays = {},
|
||||||
bool experimental_prune_unreachable_nodes_unconditionally = false);
|
bool experimental_prune_unreachable_nodes_unconditionally = false);
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ limitations under the License.
|
|||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
#include "llvm/Support/ToolOutputFile.h"
|
#include "llvm/Support/ToolOutputFile.h"
|
||||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||||
@ -122,8 +122,6 @@ bool emit_custom_ops;
|
|||||||
bool emit_select_tf_ops;
|
bool emit_select_tf_ops;
|
||||||
bool lower_tensor_list_ops;
|
bool lower_tensor_list_ops;
|
||||||
bool strip_debug_info;
|
bool strip_debug_info;
|
||||||
// NOLINTNEXTLINE
|
|
||||||
std::string output_arrays_string;
|
|
||||||
|
|
||||||
// NOLINTNEXTLINE
|
// NOLINTNEXTLINE
|
||||||
static opt<bool, true> emit_builtin_tflite_ops_flag(
|
static opt<bool, true> emit_builtin_tflite_ops_flag(
|
||||||
@ -156,11 +154,6 @@ static opt<bool, true> strip_debug_info_flag(
|
|||||||
"strip-debug-info", llvm::cl::desc("Strip debug info during export"),
|
"strip-debug-info", llvm::cl::desc("Strip debug info during export"),
|
||||||
llvm::cl::location(strip_debug_info), llvm::cl::init(false));
|
llvm::cl::location(strip_debug_info), llvm::cl::init(false));
|
||||||
|
|
||||||
// NOLINTNEXTLINE
|
|
||||||
static opt<std::string, true> output_arrays_flag(
|
|
||||||
"output-arrays", llvm::cl::desc("List of output tensors"),
|
|
||||||
llvm::cl::location(output_arrays_string), llvm::cl::init(""));
|
|
||||||
|
|
||||||
ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex";
|
ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex";
|
||||||
|
|
||||||
// Use initial buffer size in flatbuffer builder to be same as the initial size
|
// Use initial buffer size in flatbuffer builder to be same as the initial size
|
||||||
@ -172,7 +165,7 @@ constexpr size_t kInitialBufferSize = 10240;
|
|||||||
// `isSigned` is set to false for other types.
|
// `isSigned` is set to false for other types.
|
||||||
static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
||||||
bool is_signed = true) {
|
bool is_signed = true) {
|
||||||
if (!is_signed && type.isInteger(8)) {
|
if (!is_signed && type.isSignlessInteger(8)) {
|
||||||
return tflite::TensorType_UINT8;
|
return tflite::TensorType_UINT8;
|
||||||
}
|
}
|
||||||
if (!is_signed) {
|
if (!is_signed) {
|
||||||
|
@ -27,7 +27,5 @@ extern bool emit_custom_ops;
|
|||||||
extern bool lower_tensor_list_ops;
|
extern bool lower_tensor_list_ops;
|
||||||
// The flag to control whether debug info gets stripped on export.
|
// The flag to control whether debug info gets stripped on export.
|
||||||
extern bool strip_debug_info;
|
extern bool strip_debug_info;
|
||||||
// The flag to control the output array info of tflite graph.
|
|
||||||
extern std::string output_arrays_string;
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_
|
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_
|
||||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Matchers.h" // TF:llvm-project
|
#include "mlir/IR/Matchers.h" // TF:llvm-project
|
||||||
@ -275,7 +275,7 @@ Attribute ConstFoldBinaryOp(
|
|||||||
return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
|
return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
|
||||||
float_calculate, is_commutative);
|
float_calculate, is_commutative);
|
||||||
|
|
||||||
if (elemType.isa<IntegerType>())
|
if (elemType.isSignlessInteger())
|
||||||
return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0], operands[1],
|
return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0], operands[1],
|
||||||
int_calculate, is_commutative);
|
int_calculate, is_commutative);
|
||||||
|
|
||||||
@ -723,12 +723,11 @@ static LogicalResult Verify(PackOp op) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Make sure all inputs have the same shape and element type.
|
// Make sure all inputs have the same shape and element type.
|
||||||
// TODO(rahulsp): Simplify once b/135032064 is fixed.
|
// TODO(b/135032063): Simplify once fixed.
|
||||||
for (Value operand : op.getOperands()) {
|
for (Type operand_type : op.getOperandTypes()) {
|
||||||
auto other_type = operand.getType().cast<ShapedType>();
|
if (failed(mlir::verifyCompatibleShape(input_type, operand_type)))
|
||||||
if (input_type != other_type)
|
|
||||||
return op.emitOpError("operands should be of the same type. got ")
|
return op.emitOpError("operands should be of the same type. got ")
|
||||||
<< input_type << ", " << other_type;
|
<< input_type << ", " << operand_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
@ -1561,7 +1560,7 @@ OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
limit_tensor.getType().getRank() == 0 &&
|
limit_tensor.getType().getRank() == 0 &&
|
||||||
delta_tensor.getType().getRank() == 0);
|
delta_tensor.getType().getRank() == 0);
|
||||||
Type elem_type = getType().cast<ShapedType>().getElementType();
|
Type elem_type = getType().cast<ShapedType>().getElementType();
|
||||||
if (elem_type.isa<IntegerType>()) {
|
if (elem_type.isSignlessInteger()) {
|
||||||
auto start_attr = start_tensor.getValue<IntegerAttr>({});
|
auto start_attr = start_tensor.getValue<IntegerAttr>({});
|
||||||
auto limit_attr = limit_tensor.getValue<IntegerAttr>({});
|
auto limit_attr = limit_tensor.getValue<IntegerAttr>({});
|
||||||
auto delta_attr = delta_tensor.getValue<IntegerAttr>({});
|
auto delta_attr = delta_tensor.getValue<IntegerAttr>({});
|
||||||
@ -1663,7 +1662,7 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
|
|
||||||
// Do not try to fold elements attr of a quant type because
|
// Do not try to fold elements attr of a quant type because
|
||||||
// DenseElementsAttr does not support it.
|
// DenseElementsAttr does not support it.
|
||||||
if (!getType().cast<ShapedType>().getElementType().isIntOrFloat())
|
if (!getType().cast<ShapedType>().getElementType().isSignlessIntOrFloat())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
assert(perm_tensor.getType().getRank() == 1);
|
assert(perm_tensor.getType().getRank() == 1);
|
||||||
|
@ -1656,7 +1656,7 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
|
|||||||
let hasOptions = 0;
|
let hasOptions = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, SameOperandsAndResultsScale]> {
|
def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect]> {
|
||||||
let summary = "Mean operator";
|
let summary = "Mean operator";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@ -2482,11 +2482,11 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TFL_TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$input,
|
TFL_TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8, TFL_Str]>:$input,
|
||||||
TFL_I32OrI64Tensor:$multiples);
|
TFL_I32OrI64Tensor:$multiples);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
TFL_TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$output);
|
TFL_TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8, TFL_Str]>:$output);
|
||||||
|
|
||||||
let hasOptions = 0;
|
let hasOptions = 0;
|
||||||
}
|
}
|
||||||
|
@ -63,6 +63,41 @@ const char kDetectionPostProcessOp[] =
|
|||||||
"'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
|
"'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
|
||||||
"name: 'use_regular_nms' type: 'bool' default_value { b : false }}";
|
"name: 'use_regular_nms' type: 'bool' default_value { b : false }}";
|
||||||
|
|
||||||
|
const char kUnidirectionalSequenceLstmOp[] =
|
||||||
|
"name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: "
|
||||||
|
"DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } "
|
||||||
|
"input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { "
|
||||||
|
"name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: "
|
||||||
|
"'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: "
|
||||||
|
"'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: "
|
||||||
|
"'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: "
|
||||||
|
"'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: "
|
||||||
|
"'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: "
|
||||||
|
"'CellToInputWeights' type: DT_FLOAT} input_arg: { name: "
|
||||||
|
"'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: "
|
||||||
|
"'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' "
|
||||||
|
"type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } "
|
||||||
|
"input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: "
|
||||||
|
"'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' "
|
||||||
|
"type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } "
|
||||||
|
"input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { "
|
||||||
|
"name: 'InputCellStateTensor' type: DT_FLOAT } "
|
||||||
|
"output_arg: { name: 'Concat' type: DT_FLOAT} "
|
||||||
|
"output_arg: { name: "
|
||||||
|
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
|
||||||
|
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
|
||||||
|
|
||||||
|
const char kUnidirectionalSequenceRnnOp[] =
|
||||||
|
"name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
|
||||||
|
"DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
|
||||||
|
"input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
|
||||||
|
"name: 'Bias' type: DT_FLOAT} "
|
||||||
|
"input_arg: { name: 'HiddenState' type: DT_FLOAT} "
|
||||||
|
"output_arg: { name: "
|
||||||
|
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
|
||||||
|
"DT_FLOAT} "
|
||||||
|
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
|
||||||
|
|
||||||
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
|
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
|
||||||
// conversion mapping for constants defined in TFLite Python API.
|
// conversion mapping for constants defined in TFLite Python API.
|
||||||
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
||||||
@ -260,6 +295,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
|||||||
std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
|
std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
|
||||||
toco_flags.custom_opdefs().end());
|
toco_flags.custom_opdefs().end());
|
||||||
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
|
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
|
||||||
|
extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
|
||||||
|
extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
|
||||||
TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps(extra_tf_opdefs));
|
TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps(extra_tf_opdefs));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
|||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/AffineExpr.h" // TF:llvm-project
|
#include "mlir/IR/AffineExpr.h" // TF:llvm-project
|
||||||
#include "mlir/IR/AffineMap.h" // TF:llvm-project
|
#include "mlir/IR/AffineMap.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
|
@ -61,11 +61,9 @@ TfLiteStatus QuantizeModel(
|
|||||||
std::string serialized_model(
|
std::string serialized_model(
|
||||||
reinterpret_cast<const char*>(input_builder.GetBufferPointer()),
|
reinterpret_cast<const char*>(input_builder.GetBufferPointer()),
|
||||||
input_builder.GetSize());
|
input_builder.GetSize());
|
||||||
std::vector<std::string> output_arrays_order;
|
|
||||||
|
|
||||||
OwningModuleRef module =
|
OwningModuleRef module = tflite::FlatBufferToMlir(serialized_model, &context,
|
||||||
tflite::FlatBufferToMlir(serialized_model, &context,
|
UnknownLoc::get(&context));
|
||||||
UnknownLoc::get(&context), output_arrays_order);
|
|
||||||
if (!module) {
|
if (!module) {
|
||||||
error_reporter->Report("Couldn't import flatbuffer to MLIR.");
|
error_reporter->Report("Couldn't import flatbuffer to MLIR.");
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
|||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
|||||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
|
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||||
@ -191,7 +191,7 @@ struct QuantizationPattern : public RewritePattern {
|
|||||||
auto ele_type = operand.getType().cast<TensorType>().getElementType();
|
auto ele_type = operand.getType().cast<TensorType>().getElementType();
|
||||||
if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
|
if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
|
||||||
inputs.push_back(op_inst.input());
|
inputs.push_back(op_inst.input());
|
||||||
} else if (ele_type.isa<IntegerType>()) {
|
} else if (ele_type.isSignlessInteger()) {
|
||||||
// If the operand is an integer tensor, then it doesn't require the
|
// If the operand is an integer tensor, then it doesn't require the
|
||||||
// DQ op in the pattern.
|
// DQ op in the pattern.
|
||||||
inputs.push_back(operand);
|
inputs.push_back(operand);
|
||||||
@ -225,7 +225,7 @@ struct QuantizationPattern : public RewritePattern {
|
|||||||
auto user = llvm::cast<Q>(*result.user_begin());
|
auto user = llvm::cast<Q>(*result.user_begin());
|
||||||
outputs_replaced.insert({user.output(), enumerated_result.index()});
|
outputs_replaced.insert({user.output(), enumerated_result.index()});
|
||||||
output_types.push_back(user.getType());
|
output_types.push_back(user.getType());
|
||||||
} else if (result_ele_type.template isa<IntegerType>()) {
|
} else if (result_ele_type.isSignlessInteger()) {
|
||||||
// If the result is an integer tensor, then it doesn't require the
|
// If the result is an integer tensor, then it doesn't require the
|
||||||
// D op in the pattern.
|
// D op in the pattern.
|
||||||
outputs_replaced.insert({result, enumerated_result.index()});
|
outputs_replaced.insert({result, enumerated_result.index()});
|
||||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
|||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/CommandLine.h"
|
#include "llvm/Support/CommandLine.h"
|
||||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||||
|
@ -48,11 +48,9 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
|
|||||||
std::string serialized_model(
|
std::string serialized_model(
|
||||||
reinterpret_cast<const char*>(input_builder.GetBufferPointer()),
|
reinterpret_cast<const char*>(input_builder.GetBufferPointer()),
|
||||||
input_builder.GetSize());
|
input_builder.GetSize());
|
||||||
std::vector<std::string> output_arrays_order;
|
|
||||||
|
|
||||||
OwningModuleRef module =
|
OwningModuleRef module = tflite::FlatBufferToMlir(serialized_model, &context,
|
||||||
tflite::FlatBufferToMlir(serialized_model, &context,
|
UnknownLoc::get(&context));
|
||||||
UnknownLoc::get(&context), output_arrays_order);
|
|
||||||
if (!module) {
|
if (!module) {
|
||||||
error_reporter->Report("Couldn't import flatbuffer to MLIR.");
|
error_reporter->Report("Couldn't import flatbuffer to MLIR.");
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
|
@ -27,6 +27,20 @@ func @testDilatedConvWithNonZeroSTBPadding(%arg0: tensor<1x128x128x3xf32>, %arg1
|
|||||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
|
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
|
||||||
|
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||||
|
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||||
|
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||||
|
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||||
|
return %2 : tensor<1x128x128x8xf32>
|
||||||
|
|
||||||
|
// CHECK-LABEL: testDilatedConvWithNonTrivialDilations
|
||||||
|
// CHECK: [[STB:%.*]] = "tf.SpaceToBatchND"
|
||||||
|
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"
|
||||||
|
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
|
||||||
|
// CHECK-NEXT: return [[RESULT]]
|
||||||
|
}
|
||||||
|
|
||||||
func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
|
func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
|
||||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||||
@ -104,7 +118,7 @@ func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1:
|
|||||||
|
|
||||||
func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||||
%cst_0 = constant dense<3> : tensor<i32>
|
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||||
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||||
@ -115,7 +129,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
|||||||
|
|
||||||
// CHECK-LABEL: testDilatedConvWithExpandSqueeze1
|
// CHECK-LABEL: testDilatedConvWithExpandSqueeze1
|
||||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||||
@ -125,7 +139,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
|||||||
|
|
||||||
func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||||
%cst_0 = constant dense<3> : tensor<i32>
|
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||||
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||||
@ -136,7 +150,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %
|
|||||||
|
|
||||||
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze1
|
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze1
|
||||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||||
@ -146,7 +160,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %
|
|||||||
|
|
||||||
func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
|
func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
|
||||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||||
%cst_0 = constant dense<3> : tensor<i32>
|
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
|
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
|
||||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
|
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
|
||||||
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
|
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
|
||||||
@ -157,7 +171,7 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
|||||||
|
|
||||||
// CHECK-LABEL: testDilatedConvWithExpandSqueeze2
|
// CHECK-LABEL: testDilatedConvWithExpandSqueeze2
|
||||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
|
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
|
||||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||||
@ -167,7 +181,7 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
|||||||
|
|
||||||
func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
|
func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
|
||||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||||
%cst_0 = constant dense<3> : tensor<i32>
|
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
|
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
|
||||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
|
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
|
||||||
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
|
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
|
||||||
@ -178,7 +192,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %
|
|||||||
|
|
||||||
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2
|
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2
|
||||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
|
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
|
||||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||||
@ -188,7 +202,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %
|
|||||||
|
|
||||||
func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||||
%cst_0 = constant dense<3> : tensor<i32>
|
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||||
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||||
@ -200,7 +214,7 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
|||||||
|
|
||||||
// CHECK-LABEL: testDilatedConvWithExpandSqueeze3
|
// CHECK-LABEL: testDilatedConvWithExpandSqueeze3
|
||||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||||
@ -210,7 +224,7 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
|||||||
|
|
||||||
func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||||
%cst_0 = constant dense<3> : tensor<i32>
|
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||||
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||||
@ -222,10 +236,29 @@ func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %
|
|||||||
|
|
||||||
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze3
|
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze3
|
||||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
|
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
|
||||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
|
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128x1xf32> {
|
||||||
|
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||||
|
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||||
|
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||||
|
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||||
|
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||||
|
%3 = "tf.Squeeze"(%2) {squeeze_dims = [2]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64x1xf32>
|
||||||
|
%4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32>
|
||||||
|
return %4 : tensor<1x128x128x1xf32>
|
||||||
|
|
||||||
|
// CHECK-LABEL: testDilatedConvWithDifferentExpandSqueezeAxis
|
||||||
|
// CHECK: [[STB:%.*]] = "tf.SpaceToBatchND"
|
||||||
|
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"
|
||||||
|
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"
|
||||||
|
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"
|
||||||
|
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
|
||||||
|
// CHECK-NEXT: return [[RESULT]]
|
||||||
|
}
|
||||||
|
@ -0,0 +1,13 @@
|
|||||||
|
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -input-arrays=squared_difference --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
|
||||||
|
// Tests -input-arrays flag.
|
||||||
|
|
||||||
|
func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
|
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||||
|
%1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference")
|
||||||
|
%2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
|
||||||
|
return %2 : tensor<4xf32>
|
||||||
|
|
||||||
|
// CHECK-LABEL: main
|
||||||
|
// CHECK-NOT: tfl.squared_difference
|
||||||
|
// CHECK: tfl.mul %[[CONST:.*]], %arg0
|
||||||
|
}
|
@ -0,0 +1,15 @@
|
|||||||
|
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
|
||||||
|
// Ensure lstm roundtrips exactly.
|
||||||
|
|
||||||
|
func @main(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>) -> tensor<4 x f32> {
|
||||||
|
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||||
|
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||||
|
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
|
return %24 : tensor<4xf32>
|
||||||
|
// CHECK-LABEL: main
|
||||||
|
// Separate lines since there is no region for this op. third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252
|
||||||
|
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( {
|
||||||
|
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||||
|
// CHECK: return %[[RES0]]
|
||||||
|
|
||||||
|
}
|
@ -123,6 +123,17 @@ func @softmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
|||||||
// CHECK: "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
// CHECK: "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @softplus(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||||
|
%0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||||
|
return %0 : tensor<8x16xf32>
|
||||||
|
|
||||||
|
// CHECK-LABEL: softplus
|
||||||
|
// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<f32>
|
||||||
|
// CHECK-NEXT: %[[exp:.*]] = "tfl.exp"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||||
|
// CHECK-NEXT: %[[add:.*]] = "tfl.add"(%[[exp]], %[[cst]]) {fused_activation_function = "NONE"} : (tensor<8x16xf32>, tensor<f32>) -> tensor<8x16xf32>
|
||||||
|
// CHECK-NEXT: %[[log:.*]] = "tfl.log"(%[[add]]) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||||
|
}
|
||||||
|
|
||||||
func @fakeQuantArgsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
|
func @fakeQuantArgsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
|
||||||
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {min = -0.1 : f32, max = 0.2 : f32, num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
|
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {min = -0.1 : f32, max = 0.2 : f32, num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
|
||||||
return %0 : tensor<8x8x8x8xf32>
|
return %0 : tensor<8x8x8x8xf32>
|
||||||
@ -1453,3 +1464,19 @@ func @LstmWithProjection(%arg: tensor<28x1x16xf32>) -> (tensor<28x1x8xf32>) {
|
|||||||
// CHECK: [[VAL_15:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_12]], [[VAL_14]], [[VAL_13]], [[VAL_11]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_14]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, none, none, none, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, none, tensor<1x8xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x8xf32>
|
// CHECK: [[VAL_15:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_12]], [[VAL_14]], [[VAL_13]], [[VAL_11]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_14]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, none, none, none, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, none, tensor<1x8xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x8xf32>
|
||||||
// CHECK: return [[VAL_15]] : tensor<28x1x8xf32>
|
// CHECK: return [[VAL_15]] : tensor<28x1x8xf32>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
|
||||||
|
func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) {
|
||||||
|
%1 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<28x28xf32>} : () -> tensor<28x28xf32>
|
||||||
|
%2 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<28xf32>} : () -> tensor<28xf32>
|
||||||
|
%3 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x28xf32>} : () -> tensor<1x28xf32>
|
||||||
|
%4:2 = "tf.UnidirectionalSequenceRnn"(%arg, %1, %1, %2, %3) {_tflite_input_indices = [0, 1, 2, 3, 4], device = ""} : (tensor<28x1x28xf32>, tensor<28x28xf32>, tensor<28x28xf32>, tensor<28xf32>, tensor<1x28xf32>) -> (tensor<*xf32>, tensor<28x1x28xf32>)
|
||||||
|
return %4#1 : tensor<28x1x28xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: func @UnidirectionalRnn([[VAL_0:%.*]]: tensor<28x1x28xf32>) -> tensor<28x1x28xf32> {
|
||||||
|
// CHECK: [[VAL_1:%.*]] = constant dense<0.000000e+00> : tensor<28x28xf32>
|
||||||
|
// CHECK: [[VAL_2:%.*]] = constant dense<0.000000e+00> : tensor<28xf32>
|
||||||
|
// CHECK: [[VAL_3:%.*]] = constant dense<0.000000e+00> : tensor<1x28xf32>
|
||||||
|
// CHECK: [[VAL_4:%.*]] = "tfl.unidirectional_sequence_rnn"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_3]]) {fused_activation_function = "TANH", time_major = true} : (tensor<28x1x28xf32>, tensor<28x28xf32>, tensor<28x28xf32>, tensor<28xf32>, tensor<1x28xf32>) -> tensor<28x1x28xf32>
|
||||||
|
// CHECK: return [[VAL_4]] : tensor<28x1x28xf32>
|
||||||
|
// CHECK: }
|
||||||
|
@ -878,6 +878,14 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @packUnranked(%arg0: tensor<2xi32>, %arg1: tensor<*xi32>) -> tensor<2x2xi32> {
|
||||||
|
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
|
||||||
|
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<*xi32>) -> tensor<2x2xi32>
|
||||||
|
return %0 : tensor<2x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x4x2xi32> {
|
func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x4x2xi32> {
|
||||||
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32}
|
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32}
|
||||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x4x2xi32>
|
%0 = "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x4x2xi32>
|
||||||
|
@ -154,7 +154,7 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf3
|
|||||||
// -----
|
// -----
|
||||||
|
|
||||||
module {
|
module {
|
||||||
func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} {
|
func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
|
||||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
|
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
|
||||||
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
|
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
|
||||||
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
|
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
|
||||||
@ -165,7 +165,7 @@ func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10x
|
|||||||
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK: func @inference_standard_lstm_7410([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<?x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} {
|
// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
|
||||||
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||||
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
|
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
|
||||||
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||||
@ -181,7 +181,46 @@ func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10x
|
|||||||
// CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
|
// CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
|
||||||
// CHECK: [[VAL_19:%.*]] = constant unit
|
// CHECK: [[VAL_19:%.*]] = constant unit
|
||||||
// CHECK: [[VAL_20:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) ( {
|
// CHECK: [[VAL_20:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) ( {
|
||||||
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
|
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x?x10xf32>
|
||||||
// CHECK: return [[VAL_21:%.*]] : tensor<?x8x10xf32>
|
// CHECK: return [[VAL_21:%.*]] : tensor<8x?x10xf32>
|
||||||
|
// CHECK: }
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
module {
|
||||||
|
func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
|
||||||
|
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
|
||||||
|
%1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
|
||||||
|
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32>
|
||||||
|
%3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
|
||||||
|
%4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
|
||||||
|
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
|
||||||
|
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||||
|
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
|
||||||
|
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
|
||||||
|
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32>
|
||||||
|
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||||
|
// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
|
||||||
|
// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||||
|
// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
|
||||||
|
// CHECK: [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
|
// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||||
|
// CHECK: [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
|
||||||
|
// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
|
// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||||
|
// CHECK: [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
|
||||||
|
// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
|
// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||||
|
// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
|
||||||
|
// CHECK: [[VAL_21:%.*]] = constant unit
|
||||||
|
// CHECK: [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( {
|
||||||
|
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
|
||||||
|
// CHECK: [[VAL_23:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
|
||||||
|
// CHECK: [[VAL_24:%.*]] = "tf.Transpose"([[VAL_25:%.*]], [[VAL_23]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
|
||||||
|
// CHECK: return [[VAL_24]] : tensor<8x8x10xf32>
|
||||||
|
// CHECK: }
|
||||||
}
|
}
|
||||||
|
@ -622,3 +622,16 @@ func @QuantizeSharedBiases2(
|
|||||||
// CHECK: %{{.*}} = tfl.add %{{.*}}, %[[dq_0]]
|
// CHECK: %{{.*}} = tfl.add %{{.*}}, %[[dq_0]]
|
||||||
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
|
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: ReturnQuantizedResult
|
||||||
|
func @ReturnQuantizedResult(%arg0: tensor<1x224x224x3xf32>, %arg1: tensor<32x3x3x3xf32>, %arg2: tensor<32xf32>) -> (tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>) {
|
||||||
|
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
|
||||||
|
%1 = "tfl.quantize"(%0) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||||
|
%2 = "tfl.dequantize"(%1) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>) -> (tensor<1x112x112x32xf32>)
|
||||||
|
return %0, %2 : tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[dw:.*]] = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2)
|
||||||
|
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[dw]])
|
||||||
|
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
|
||||||
|
// CHECK: return %[[dq]], %[[dq]]
|
||||||
|
}
|
||||||
|
@ -1,22 +1,28 @@
|
|||||||
// Test to verify loop outlining.
|
// Test to verify loop outlining.
|
||||||
|
|
||||||
// RUN: tf-opt --split-input-file --tfl-while-loop-outline %s | FileCheck %s --dump-input-on-failure
|
// RUN: tf-opt --split-input-file --tfl-while-loop-outline %s | FileCheck %s --dump-input-on-failure
|
||||||
|
// Check that while loop outlining is a no-op if re-run.
|
||||||
|
// RUN: tf-opt --tfl-while-loop-outline %s -o %t1
|
||||||
|
// RUN: tf-opt --tfl-while-loop-outline %t1 -o %t2
|
||||||
|
// RUN: diff %t1 %t2
|
||||||
|
|
||||||
// CHECK-LABEL: func @while
|
// CHECK-LABEL: func @while
|
||||||
func @while() -> tensor<1xf32>
|
func @while() -> tensor<1xf32>
|
||||||
attributes {tf.entry_function = {outputs = "result"}} {
|
attributes {tf.entry_function = {outputs = "result"}} {
|
||||||
%cst = constant dense<1> : tensor<i32> loc("dec")
|
%cst = constant dense<1> : tensor<i32> loc("dec")
|
||||||
%arg0 = constant dense<5> : tensor<i32> loc("N")
|
%cst0 = constant dense<5> : tensor<i32> loc("N")
|
||||||
%arg1 = constant dense<3.0> : tensor<1xf32> loc("val")
|
%cst1 = constant dense<3.0> : tensor<1xf32> loc("val")
|
||||||
%0:2 = "tfl.while"(%arg0, %arg1) ( {
|
%0:2 = "tfl.while"(%cst0, %cst1) ( {
|
||||||
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
|
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
|
||||||
// CHECK: call @WhileOp_cond
|
// CHECK: call @WhileOp_cond
|
||||||
|
// CHECK-SAME: (tensor<*xi32>, tensor<*xf32>, tensor<i32>)
|
||||||
%cst_0 = constant dense<0> : tensor<i32>
|
%cst_0 = constant dense<0> : tensor<i32>
|
||||||
%1 = "tfl.greater"(%arg2, %cst_0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
|
%1 = "tfl.greater"(%arg2, %cst_0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
|
||||||
"tfl.yield"(%1) : (tensor<i1>) -> ()
|
"tfl.yield"(%1) : (tensor<i1>) -> ()
|
||||||
}, {
|
}, {
|
||||||
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
|
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
|
||||||
// CHECK: call @WhileOp_body
|
// CHECK: call @WhileOp_body
|
||||||
|
// CHECK-SAME: (tensor<*xi32>, tensor<*xf32>, tensor<i32>)
|
||||||
%1 = "tfl.sub"(%arg2, %cst) {fused_activation_function = "NONE"} :
|
%1 = "tfl.sub"(%arg2, %cst) {fused_activation_function = "NONE"} :
|
||||||
(tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
(tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||||
%2 = tfl.add %arg3, %arg3 {fused_activation_function = "NONE"} : tensor<*xf32>
|
%2 = tfl.add %arg3, %arg3 {fused_activation_function = "NONE"} : tensor<*xf32>
|
||||||
@ -32,6 +38,52 @@ func @while() -> tensor<1xf32>
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @while2
|
||||||
|
// Verify that a while body/cond with implicitly captured values results in changing the while operands/results.
|
||||||
|
func @while2() -> tensor<1xf32> attributes {tf.entry_function = {outputs = "result"}} {
|
||||||
|
%cst = constant dense<1> : tensor<i32>
|
||||||
|
%cst_0 = constant dense<5> : tensor<i32>
|
||||||
|
%cst_1 = constant dense<3.000000e+00> : tensor<1xf32>
|
||||||
|
// Verifies 3 operands post outlining.
|
||||||
|
// CHECK: "tfl.while"({{.*}}, {{.*}}, {{.*}}) (
|
||||||
|
%0:2 = "tfl.while"(%cst_0, %cst_1) ( {
|
||||||
|
^bb0(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>): // no predecessors
|
||||||
|
// CHECK: call @WhileOp_cond
|
||||||
|
// CHECK-SAME: (tensor<*xi32>, tensor<*xf32>, tensor<i32>)
|
||||||
|
%1 = call @WhileOp_cond(%arg0, %arg1, %cst) : (tensor<*xi32>, tensor<*xf32>, tensor<i32>) -> tensor<i1>
|
||||||
|
"tfl.yield"(%1) : (tensor<i1>) -> ()
|
||||||
|
}, {
|
||||||
|
^bb0(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>): // no predecessors
|
||||||
|
// CHECK: call @WhileOp_body
|
||||||
|
// CHECK-SAME: (tensor<*xi32>, tensor<*xf32>, tensor<i32>)
|
||||||
|
%1:3 = call @WhileOp_body(%arg0, %arg1, %cst) : (tensor<*xi32>, tensor<*xf32>, tensor<i32>) -> (tensor<*xi32>, tensor<*xf32>, tensor<i32>)
|
||||||
|
"tfl.yield"(%1#0, %1#1) : (tensor<*xi32>, tensor<*xf32>) -> ()
|
||||||
|
}) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>) loc("WhileOp")
|
||||||
|
// CHECK: (tensor<i32>, tensor<1xf32>, tensor<i32>) ->
|
||||||
|
// CHECK-SAME: (tensor<i32>, tensor<1xf32>, tensor<i32>)
|
||||||
|
return %0#1 : tensor<1xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> tensor<i1> attributes {sym_visibility = "private"} {
|
||||||
|
%cst = constant dense<0> : tensor<i32>
|
||||||
|
%0 = "tfl.greater"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
|
||||||
|
return %0 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @WhileOp_body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> (tensor<*xi32>, tensor<*xf32>, tensor<i32>) attributes {sym_visibility = "private"} {
|
||||||
|
%0 = "tfl.sub"(%arg0, %arg2) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||||
|
%1 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
|
||||||
|
return %0, %1, %arg2 : tensor<*xi32>, tensor<*xf32>, tensor<i32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @WhileOp_cond(
|
||||||
|
// CHECK: tfl.greater
|
||||||
|
// CHECK-LABEL: func @WhileOp_body(
|
||||||
|
// CHECK: tfl.sub
|
||||||
|
// CHECK: tfl.add
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @rnn(%arg0: tensor<4x4x3xf32> {tf.device = "/device:CPU:0"}) -> tensor<4x?x2xf32> attributes {tf.entry_function = {inputs = "Placeholder", outputs = "rnn/transpose_1"}} {
|
func @rnn(%arg0: tensor<4x4x3xf32> {tf.device = "/device:CPU:0"}) -> tensor<4x?x2xf32> attributes {tf.entry_function = {inputs = "Placeholder", outputs = "rnn/transpose_1"}} {
|
||||||
%cst = constant dense<0.000000e+00> : tensor<4x2xf32>
|
%cst = constant dense<0.000000e+00> : tensor<4x2xf32>
|
||||||
%cst_0 = constant dense<0.000000e+00> : tensor<8xf32>
|
%cst_0 = constant dense<0.000000e+00> : tensor<8xf32>
|
||||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
#include "mlir/IR/AffineMap.h"
|
#include "mlir/IR/AffineMap.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
|
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
|
||||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||||
|
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
@ -80,6 +81,17 @@ class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
|
|||||||
template <typename Conv2dOpTy>
|
template <typename Conv2dOpTy>
|
||||||
PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||||
Conv2dOpTy op, PatternRewriter& rewriter) const {
|
Conv2dOpTy op, PatternRewriter& rewriter) const {
|
||||||
|
// Make sure Conv2D has 'VALID' padding.
|
||||||
|
if (op.template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
|
||||||
|
return Pattern::matchFailure();
|
||||||
|
}
|
||||||
|
// Make sure dilations are all ones if set.
|
||||||
|
const ArrayAttr& dilations =
|
||||||
|
op.template getAttrOfType<ArrayAttr>("dilations");
|
||||||
|
if (dilations && !TFIntListIsAllOnes(dilations)) {
|
||||||
|
return Pattern::matchFailure();
|
||||||
|
}
|
||||||
|
|
||||||
// Check if the ConvOp is preceded by a `Expand` op and succeeded by a
|
// Check if the ConvOp is preceded by a `Expand` op and succeeded by a
|
||||||
// `Squeeze` op.
|
// `Squeeze` op.
|
||||||
Operation* prev_op = op.getOperation()->getPrevNode();
|
Operation* prev_op = op.getOperation()->getPrevNode();
|
||||||
@ -90,6 +102,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
|||||||
|
|
||||||
TF::ExpandDimsOp expand_op;
|
TF::ExpandDimsOp expand_op;
|
||||||
TF::SqueezeOp squeeze_op;
|
TF::SqueezeOp squeeze_op;
|
||||||
|
int64_t expand_axis;
|
||||||
// Expand + Squeeze op.
|
// Expand + Squeeze op.
|
||||||
if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
|
if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
|
||||||
if (!llvm::isa<TF::SqueezeOp>(next_op)) {
|
if (!llvm::isa<TF::SqueezeOp>(next_op)) {
|
||||||
@ -99,6 +112,22 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
|||||||
expand_op = llvm::cast<TF::ExpandDimsOp>(prev_op);
|
expand_op = llvm::cast<TF::ExpandDimsOp>(prev_op);
|
||||||
squeeze_op = llvm::cast<TF::SqueezeOp>(next_op);
|
squeeze_op = llvm::cast<TF::SqueezeOp>(next_op);
|
||||||
|
|
||||||
|
// Make sure that the axis in `expand_op` is constant.
|
||||||
|
if (auto const_op =
|
||||||
|
llvm::dyn_cast<TF::ConstOp>(expand_op.dim().getDefiningOp())) {
|
||||||
|
expand_axis =
|
||||||
|
(*const_op.value().cast<DenseElementsAttr>().getIntValues().begin())
|
||||||
|
.getSExtValue();
|
||||||
|
} else {
|
||||||
|
return Pattern::matchFailure();
|
||||||
|
}
|
||||||
|
// Make sure that the `squeeze_dims` is equal to `expand_axis`.
|
||||||
|
auto squeeze_dims = squeeze_op.squeeze_dims();
|
||||||
|
if (squeeze_dims.size() != 1 ||
|
||||||
|
squeeze_dims[0].cast<IntegerAttr>().getInt() != expand_axis) {
|
||||||
|
return Pattern::matchFailure();
|
||||||
|
}
|
||||||
|
|
||||||
// Update previous/next op pointer.
|
// Update previous/next op pointer.
|
||||||
prev_op = prev_op->getPrevNode();
|
prev_op = prev_op->getPrevNode();
|
||||||
if (!prev_op) return Pattern::matchFailure();
|
if (!prev_op) return Pattern::matchFailure();
|
||||||
@ -108,10 +137,14 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
|||||||
|
|
||||||
// SpaceToBatchND op.
|
// SpaceToBatchND op.
|
||||||
if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return Pattern::matchFailure();
|
if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return Pattern::matchFailure();
|
||||||
|
// TODO(b/149936532): Check `padding` input, currently ignored.
|
||||||
TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(prev_op);
|
TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(prev_op);
|
||||||
|
|
||||||
// Pad op.
|
// Pad op.
|
||||||
TF::PadOp pad_op;
|
TF::PadOp pad_op;
|
||||||
|
// TODO(b/149936532): Currently we just ignore the PadOp. However note that
|
||||||
|
// in real scenarios this may not always be correct: user can put a PadOp here
|
||||||
|
// with non-trivial consequences.
|
||||||
if (llvm::isa<TF::PadOp>(next_op)) {
|
if (llvm::isa<TF::PadOp>(next_op)) {
|
||||||
pad_op = llvm::cast<TF::PadOp>(next_op);
|
pad_op = llvm::cast<TF::PadOp>(next_op);
|
||||||
next_op = next_op->getNextNode();
|
next_op = next_op->getNextNode();
|
||||||
@ -119,6 +152,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// BatchToSpaceND + BiasAdd.
|
// BatchToSpaceND + BiasAdd.
|
||||||
|
// TODO(b/149936532): Check the `crops` input, currently ignored.
|
||||||
TF::BatchToSpaceNDOp bts_op;
|
TF::BatchToSpaceNDOp bts_op;
|
||||||
TF::BiasAddOp biasadd_op;
|
TF::BiasAddOp biasadd_op;
|
||||||
bool final_op_is_bts = true;
|
bool final_op_is_bts = true;
|
||||||
@ -146,14 +180,10 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
|||||||
if (!dilations_attr.hasValue()) return Pattern::matchFailure();
|
if (!dilations_attr.hasValue()) return Pattern::matchFailure();
|
||||||
op.setAttr("dilations", dilations_attr.getValue());
|
op.setAttr("dilations", dilations_attr.getValue());
|
||||||
|
|
||||||
// Here we need to set the correct padding for Conv op. In TF, the conv op
|
// Padding is set to 'SAME' when `stb_op` has non-zero paddings.
|
||||||
// inserted after 'SpaceToBatch' always has 'VALID' padding. This might
|
// TODO(b/149936532): This assumption only holds when the input width & height
|
||||||
// become a problem here if the original Conv op has 'SAME' padding. When
|
// are multiples of the dilation width & height. We should fix it in order to
|
||||||
// the original conv has 'SAME' padding, TF will set a non-zero padding for
|
// support other use cases.
|
||||||
// the 'SpaceToBatch' op, so we rely on this information to check if we need
|
|
||||||
// to change the padding from 'VALID' to 'SAME' (a.k.a when we see non-zero
|
|
||||||
// values in `stb_op.paddings`, we change the current Conv's padding to
|
|
||||||
// 'SAME').
|
|
||||||
auto stb_paddings = stb_op.paddings();
|
auto stb_paddings = stb_op.paddings();
|
||||||
ElementsAttr stb_paddings_attr;
|
ElementsAttr stb_paddings_attr;
|
||||||
if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr))) {
|
if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr))) {
|
||||||
@ -175,7 +205,8 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
|||||||
auto input_shape = stb_op.input().getType().cast<ShapedType>().getShape();
|
auto input_shape = stb_op.input().getType().cast<ShapedType>().getShape();
|
||||||
SmallVector<int64_t, 4> expand_shape(input_shape.begin(),
|
SmallVector<int64_t, 4> expand_shape(input_shape.begin(),
|
||||||
input_shape.end());
|
input_shape.end());
|
||||||
expand_shape.push_back(1);
|
expand_shape.insert(expand_shape.begin() + expand_axis, 1);
|
||||||
|
|
||||||
auto expand_result_type = RankedTensorType::get(
|
auto expand_result_type = RankedTensorType::get(
|
||||||
expand_shape, getElementTypeOrSelf(stb_op.input()));
|
expand_shape, getElementTypeOrSelf(stb_op.input()));
|
||||||
expand_op.getResult().setType(expand_result_type);
|
expand_op.getResult().setType(expand_result_type);
|
||||||
@ -208,7 +239,7 @@ ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
|
|||||||
ElementsAttr stb_bs_attr, bts_bs_attr;
|
ElementsAttr stb_bs_attr, bts_bs_attr;
|
||||||
if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
|
if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
|
||||||
!matchPattern(bts_block_shape, m_Constant(&bts_bs_attr))) {
|
!matchPattern(bts_block_shape, m_Constant(&bts_bs_attr))) {
|
||||||
// Returns failure status if block shape is not a constant.
|
// Returns failure status if block_shape is not a constant.
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
// Check that the block_shape of `stb_op` and `bts_op` are equal.
|
// Check that the block_shape of `stb_op` and `bts_op` are equal.
|
||||||
@ -217,9 +248,8 @@ ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
|
|||||||
if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
|
if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(haoliang): support 1-D dilated conv.
|
// Set dilation factor.
|
||||||
if (stb_bs_attr.getNumElements() < 2) return {};
|
if (stb_bs_attr.getNumElements() < 2) return {};
|
||||||
|
|
||||||
int dilation_h_factor =
|
int dilation_h_factor =
|
||||||
stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
|
stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
|
||||||
int dilation_w_factor =
|
int dilation_w_factor =
|
||||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/StringSwitch.h"
|
#include "llvm/ADT/StringSwitch.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
|
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/StringMap.h"
|
#include "llvm/ADT/StringMap.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||||||
// TFLite legalization patterns
|
// TFLite legalization patterns
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
include "mlir/Dialect/StandardOps/Ops.td"
|
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||||
|
|
||||||
@ -167,6 +167,7 @@ def : Pat<(TF_SigmoidOp $arg), (TFL_LogisticOp $arg)>;
|
|||||||
def : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>;
|
def : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>;
|
||||||
def : Pat<(TF_SliceOp $input, $begin, $size), (TFL_SliceOp $input, $begin, $size)>;
|
def : Pat<(TF_SliceOp $input, $begin, $size), (TFL_SliceOp $input, $begin, $size)>;
|
||||||
def : Pat<(TF_SoftmaxOp $arg), (TFL_SoftmaxOp $arg, ConstF32Attr<"1.0">)>;
|
def : Pat<(TF_SoftmaxOp $arg), (TFL_SoftmaxOp $arg, ConstF32Attr<"1.0">)>;
|
||||||
|
def : Pat<(TF_SoftplusOp F32Tensor:$arg0), (TFL_LogOp (TFL_AddOp (TFL_ExpOp $arg0), (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">), TFL_AF_None))>;
|
||||||
def : Pat<(TF_SqueezeOp $arg, $squeeze_dims), (TFL_SqueezeOp $arg, $squeeze_dims)>;
|
def : Pat<(TF_SqueezeOp $arg, $squeeze_dims), (TFL_SqueezeOp $arg, $squeeze_dims)>;
|
||||||
def : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
|
def : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
|
||||||
def : Pat<(TF_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, $perm)>;
|
def : Pat<(TF_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, $perm)>;
|
||||||
@ -340,7 +341,7 @@ def : Pat<(TF_MatrixDiagOp $diagonal), (TFL_MatrixDiagOp $diagonal)>;
|
|||||||
class I32VectorElementsAttr<int len> : ElementsAttrBase<
|
class I32VectorElementsAttr<int len> : ElementsAttrBase<
|
||||||
CPred<"$_self.isa<DenseIntElementsAttr>() &&"
|
CPred<"$_self.isa<DenseIntElementsAttr>() &&"
|
||||||
"$_self.cast<DenseIntElementsAttr>().getType()."
|
"$_self.cast<DenseIntElementsAttr>().getType()."
|
||||||
"getElementType().isInteger(32)">,
|
"getElementType().isSignlessInteger(32)">,
|
||||||
"32-bit int elements attribute of shape [" # len # "]"> {
|
"32-bit int elements attribute of shape [" # len # "]"> {
|
||||||
|
|
||||||
let storageType = [{ DenseIntElementsAttr }];
|
let storageType = [{ DenseIntElementsAttr }];
|
||||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
|||||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
|
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||||
@ -64,6 +65,7 @@ using xla::Status;
|
|||||||
using xla::StatusOr;
|
using xla::StatusOr;
|
||||||
|
|
||||||
constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm";
|
constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm";
|
||||||
|
constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn";
|
||||||
constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
|
constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
|
||||||
|
|
||||||
// Legalize operations in functions.
|
// Legalize operations in functions.
|
||||||
@ -253,7 +255,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
|
|||||||
|
|
||||||
ShapedType shape_type = shape.getType().cast<ShapedType>();
|
ShapedType shape_type = shape.getType().cast<ShapedType>();
|
||||||
// The tfl reshape's #2 operand needs to i32 tensor type, so we have to cast.
|
// The tfl reshape's #2 operand needs to i32 tensor type, so we have to cast.
|
||||||
if (!shape_type.getElementType().isInteger(32)) {
|
if (!shape_type.getElementType().isSignlessInteger(32)) {
|
||||||
auto new_shape = shape_type.getShape();
|
auto new_shape = shape_type.getShape();
|
||||||
IntegerType new_ele_type = rewriter.getIntegerType(32);
|
IntegerType new_ele_type = rewriter.getIntegerType(32);
|
||||||
ShapedType new_type = RankedTensorType::get(new_shape, new_ele_type);
|
ShapedType new_type = RankedTensorType::get(new_shape, new_ele_type);
|
||||||
@ -632,6 +634,66 @@ struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Legalize unidirectional seqeucen rnn.
|
||||||
|
struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
|
||||||
|
explicit LegalizeUnidirectionalSequenceRnn(MLIRContext* context)
|
||||||
|
: RewritePattern(kUnidirectionalSequenceRnn, 1, context) {}
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(Operation* op,
|
||||||
|
PatternRewriter& rewriter) const override {
|
||||||
|
auto tflite_indices_attr =
|
||||||
|
op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices);
|
||||||
|
if (!tflite_indices_attr) return matchFailure();
|
||||||
|
|
||||||
|
if (op->getNumOperands() != 5) {
|
||||||
|
op->emitError()
|
||||||
|
<< "We're expecting 5 inputs for UnidirectionalSequenceRNN, only "
|
||||||
|
<< op->getNumOperands() << " provided";
|
||||||
|
return matchFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op->getNumResults() != 2) {
|
||||||
|
op->emitError()
|
||||||
|
<< "We're expecting 2 inputs for UnidirectionalSequenceRNN, only "
|
||||||
|
<< op->getNumResults() << " found";
|
||||||
|
return matchFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate inputs.
|
||||||
|
// UnidirectionalSequenceRnn is expected to have 5 inputs, and none of them
|
||||||
|
// are optional inputs.
|
||||||
|
SmallVector<Value, 5> inputs;
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
inputs.push_back(op->getOperand(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Populate outputs.
|
||||||
|
// UnidirectionalSequenceRnn should only have 1 output, and that is the
|
||||||
|
// original ophint converted node's 2nd output.
|
||||||
|
SmallVector<Type, 4> result_types;
|
||||||
|
result_types.push_back(op->getOpResult(1).getType());
|
||||||
|
|
||||||
|
// Populate attributes.
|
||||||
|
SmallVector<NamedAttribute, 2> attributes;
|
||||||
|
// Activation will always be tanh.
|
||||||
|
attributes.push_back(rewriter.getNamedAttr("fused_activation_function",
|
||||||
|
rewriter.getStringAttr("TANH")));
|
||||||
|
|
||||||
|
// will always be time_majored.
|
||||||
|
attributes.push_back(
|
||||||
|
rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));
|
||||||
|
|
||||||
|
auto rnn_op = rewriter.create<TFL::UnidirectionalSequenceRNNOp>(
|
||||||
|
op->getLoc(), result_types, inputs, attributes);
|
||||||
|
|
||||||
|
// Rewire the output.
|
||||||
|
op->getResult(1).replaceAllUsesWith(rnn_op.getResult());
|
||||||
|
op->erase();
|
||||||
|
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void LegalizeTF::runOnFunction() {
|
void LegalizeTF::runOnFunction() {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
auto* ctx = &getContext();
|
auto* ctx = &getContext();
|
||||||
@ -647,7 +709,8 @@ void LegalizeTF::runOnFunction() {
|
|||||||
ConvertTFReciprocalOp, ConvertTFRandomUniformOp>(ctx);
|
ConvertTFReciprocalOp, ConvertTFRandomUniformOp>(ctx);
|
||||||
|
|
||||||
// Ophint python converter converted tf node pattern.
|
// Ophint python converter converted tf node pattern.
|
||||||
patterns.insert<LegalizeUnidirectionalSequenceLstm>(ctx);
|
patterns.insert<LegalizeUnidirectionalSequenceLstm,
|
||||||
|
LegalizeUnidirectionalSequenceRnn>(ctx);
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsGreedily(func, patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
// Converts TF While to TFL While with single call in body and cond.
|
// Converts TF While to TFL While with single call in body and cond.
|
||||||
|
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/None.h"
|
#include "llvm/ADT/None.h"
|
||||||
#include "llvm/ADT/Optional.h"
|
#include "llvm/ADT/Optional.h"
|
||||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||||
|
@ -32,7 +32,7 @@ limitations under the License.
|
|||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
|
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||||
@ -335,8 +335,9 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
|
|||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Type dtype = op.element_dtype();
|
Type dtype = op.element_dtype();
|
||||||
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
|
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
|
||||||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
|
dtype.isInteger(1) || dtype.isSignlessInteger(8) ||
|
||||||
dtype.isInteger(32) || dtype.isInteger(64))) {
|
dtype.isSignlessInteger(16) || dtype.isSignlessInteger(32) ||
|
||||||
|
dtype.isSignlessInteger(64))) {
|
||||||
op.emitError(
|
op.emitError(
|
||||||
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
|
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
|
||||||
"integer or 16-bit/32-bit/64-bit float type during TF Lite "
|
"integer or 16-bit/32-bit/64-bit float type during TF Lite "
|
||||||
|
@ -31,7 +31,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/ADT/StringSwitch.h"
|
#include "llvm/ADT/StringSwitch.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Matchers.h" // TF:llvm-project
|
#include "mlir/IR/Matchers.h" // TF:llvm-project
|
||||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||||||
// This is the optimization pattern definition file for TensorFlow Lite.
|
// This is the optimization pattern definition file for TensorFlow Lite.
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
include "mlir/Dialect/StandardOps/Ops.td"
|
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||||||
// This is the quantization pattern definition file for TensorFlow Lite.
|
// This is the quantization pattern definition file for TensorFlow Lite.
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
include "mlir/Dialect/StandardOps/Ops.td"
|
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||||
|
|
||||||
// Both Quantize and Dequantize ops have side effects, so we have to define
|
// Both Quantize and Dequantize ops have side effects, so we have to define
|
||||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
|||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/CommandLine.h"
|
#include "llvm/Support/CommandLine.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/Support/CommandLine.h"
|
#include "llvm/Support/CommandLine.h"
|
||||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||||
@ -115,6 +116,10 @@ class PrepareQuantizePass : public FunctionPass<PrepareQuantizePass> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Applies sanity checks and reports warnings for ops that don't follow
|
||||||
|
// the best quantization practice. This also fixes some simple violations.
|
||||||
|
void SanityCheckAndAdjustment(FuncOp func);
|
||||||
|
|
||||||
QuantizationSpecs quant_specs_;
|
QuantizationSpecs quant_specs_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -184,13 +189,56 @@ bool PrepareQuantizePass::RemoveRedundantStats(FuncOp func) {
|
|||||||
return RemoveRedundantStatsOps(func, GetOpQuantSpec);
|
return RemoveRedundantStatsOps(func, GetOpQuantSpec);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the result of the dequantize op that is the first user of `user`,
// when `user` is a quant::QuantizeCastOp and that first user is a
// quant::DequantizeCastOp. Returns a null Value in every other case,
// including a null `user` and a quantize op whose result has no users.
static Value Quantized(Operation* user) {
  auto q = llvm::dyn_cast_or_null<quant::QuantizeCastOp>(user);
  if (!q) return {};

  Value quantized = q.getResult();
  // With no users, user_begin() equals the end iterator and must not be
  // dereferenced.
  if (quantized.use_empty()) return {};

  if (auto dq = llvm::dyn_cast_or_null<quant::DequantizeCastOp>(
          *quantized.user_begin())) {
    return dq.getResult();
  }
  return {};
}
|
||||||
|
|
||||||
|
// Walks `func` twice: first to redirect returned values to their quantized
// (quantize -> dequantize) counterpart when exactly one exists, then to warn
// about concatenation ops whose output is not followed by such a pair.
void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) {
  // When a returned value is also consumed by a quantize op, return the
  // quantized result instead so the producing op can be quantized. This is
  // limited to returned results because the error will not be accumulated.
  func.walk([&](ReturnOp ret) {
    for (unsigned idx = 0, end = ret.getNumOperands(); idx != end; ++idx) {
      Value returned = ret.getOperand(idx);

      // Gather the dequantized value behind every quantize user.
      llvm::SmallVector<Value, 4> dequantized_values;
      for (Operation* user : returned.getUsers()) {
        Value dequantized = Quantized(user);
        if (dequantized) dequantized_values.push_back(dequantized);
      }

      // Only rewrite when the replacement is unambiguous.
      if (dequantized_values.size() != 1) continue;
      ret.setOperand(idx, dequantized_values.front());
    }
  });

  // We prefer placing quantization emulation ops on the results of the
  // concat ops, so warn whenever that is not the case.
  func.walk([&](ConcatenationOp concat) {
    Value output = concat.output();
    const bool output_is_quantized =
        output.hasOneUse() && Quantized(*output.user_begin());
    if (!output_is_quantized) {
      concat.emitWarning(
          "Missing quantization parameter on the output might introduce "
          "quantization error!");
    }
  });
}
|
||||||
|
|
||||||
using PrepareQuantStats =
|
using PrepareQuantStats =
|
||||||
quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
|
quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
|
||||||
|
|
||||||
void PrepareQuantizePass::runOnFunction() {
|
void PrepareQuantizePass::runOnFunction() {
|
||||||
FuncOp func = getFunction();
|
FuncOp func = getFunction();
|
||||||
MLIRContext* ctx = func.getContext();
|
MLIRContext* ctx = func.getContext();
|
||||||
|
|
||||||
ConvertTFLQuantOpsToMlirQuantOps(func);
|
ConvertTFLQuantOpsToMlirQuantOps(func);
|
||||||
|
|
||||||
if (quant_specs_.post_training_quantization) {
|
if (quant_specs_.post_training_quantization) {
|
||||||
@ -220,6 +268,8 @@ void PrepareQuantizePass::runOnFunction() {
|
|||||||
}
|
}
|
||||||
applyPatternsGreedily(func, patterns);
|
applyPatternsGreedily(func, patterns);
|
||||||
|
|
||||||
|
SanityCheckAndAdjustment(func);
|
||||||
|
|
||||||
// Finally, the quantization parameters can be propagated to the rest of the
|
// Finally, the quantization parameters can be propagated to the rest of the
|
||||||
// values (tensors).
|
// values (tensors).
|
||||||
ApplyQuantizationParamsPropagation(func, is_signed, disable_per_channel,
|
ApplyQuantizationParamsPropagation(func, is_signed, disable_per_channel,
|
||||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||||||
// This is the quantization pattern definition file for TensorFlow Lite.
|
// This is the quantization pattern definition file for TensorFlow Lite.
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
include "mlir/Dialect/StandardOps/Ops.td"
|
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||||
|
|
||||||
// Quantize attribute $0 by using quantization parameter from %1.
|
// Quantize attribute $0 by using quantization parameter from %1.
|
||||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/DenseSet.h"
|
#include "llvm/ADT/DenseSet.h"
|
||||||
#include "llvm/ADT/StringMap.h"
|
#include "llvm/ADT/StringMap.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
|
@ -14,7 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
include "mlir/Dialect/StandardOps/Ops.td"
|
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "llvm/Support/CommandLine.h"
|
#include "llvm/Support/CommandLine.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Identifier.h" // TF:llvm-project
|
#include "mlir/IR/Identifier.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "llvm/Support/CommandLine.h"
|
#include "llvm/Support/CommandLine.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Identifier.h" // TF:llvm-project
|
#include "mlir/IR/Identifier.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||||
@ -52,25 +52,47 @@ class WhileOutlinePass : public mlir::ModulePass<WhileOutlinePass> {
|
|||||||
|
|
||||||
tensorflow::OpOrArgLocNameMapper mapper_;
|
tensorflow::OpOrArgLocNameMapper mapper_;
|
||||||
};
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
|
std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
|
||||||
return (mapper_.GetUniqueName(op) + suffix).str();
|
return (mapper_.GetUniqueName(op) + suffix).str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns whether the WhileOp is already outlined (i.e., only consists of calls
|
||||||
|
// to functions).
|
||||||
|
static bool IsAlreadyOutlinedd(WhileOp while_op) {
|
||||||
|
auto just_call = [](Region& region) {
|
||||||
|
auto it = region.front().begin();
|
||||||
|
if (!isa<CallOp>(*it)) return false;
|
||||||
|
++it;
|
||||||
|
if (!isa<YieldOp>(*it)) return false;
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
return just_call(while_op.body()) && just_call(while_op.cond());
|
||||||
|
}
|
||||||
|
|
||||||
void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
||||||
OpBuilder builder(&getContext());
|
OpBuilder builder(&getContext());
|
||||||
// Colect external values used. Note: if an external value is also passed in
|
// Collect external values used.
|
||||||
// via argument, then it could end up being passed in multiple times. In the
|
|
||||||
// case where the value was already just passed through, this will result in
|
|
||||||
// redundancy.
|
|
||||||
llvm::SetVector<Value> extern_values;
|
llvm::SetVector<Value> extern_values;
|
||||||
|
|
||||||
// Sink down none type constants into the functions.
|
// The basic block arguments correspond to values that are loop carried, while
|
||||||
|
// all those after them are loop independent. Initialize extern_values with while_op
|
||||||
|
// not loop carried operands.
|
||||||
|
auto num_loop_carried = while_op.cond().front().getNumArguments();
|
||||||
|
auto not_carried_operands =
|
||||||
|
while_op.getOperands().drop_front(num_loop_carried);
|
||||||
|
extern_values.insert(not_carried_operands.begin(),
|
||||||
|
not_carried_operands.end());
|
||||||
|
auto old_extern_values_size = extern_values.size();
|
||||||
|
|
||||||
llvm::SmallVector<Region*, 2> regions{&while_op.cond(), &while_op.body()};
|
llvm::SmallVector<Region*, 2> regions{&while_op.cond(), &while_op.body()};
|
||||||
for (auto it : llvm::enumerate(regions)) {
|
for (auto it : llvm::enumerate(regions)) {
|
||||||
llvm::SetVector<Value> region_extern_values;
|
llvm::SetVector<Value> region_extern_values;
|
||||||
Value const_none = nullptr;
|
Value const_none = nullptr;
|
||||||
getUsedValuesDefinedAbove(*it.value(), region_extern_values);
|
getUsedValuesDefinedAbove(*it.value(), region_extern_values);
|
||||||
|
|
||||||
|
// Sink down none type constants into the functions.
|
||||||
for (auto extern_value : region_extern_values) {
|
for (auto extern_value : region_extern_values) {
|
||||||
if (!extern_value.getType().isa<NoneType>()) {
|
if (!extern_value.getType().isa<NoneType>()) {
|
||||||
extern_values.insert(extern_value);
|
extern_values.insert(extern_value);
|
||||||
@ -89,12 +111,23 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Colect new types.
|
bool has_extra_extern_values = old_extern_values_size != extern_values.size();
|
||||||
|
// If an extern value is already an operand post the loop carried operands,
|
||||||
|
// then it need not be passed in again.
|
||||||
|
// Compute all the extra operands that have to be added to the while.
|
||||||
|
llvm::SetVector<Value> extra_operands;
|
||||||
|
if (has_extra_extern_values) {
|
||||||
|
auto new_extern =
|
||||||
|
extern_values.getArrayRef().drop_front(old_extern_values_size);
|
||||||
|
extra_operands.insert(new_extern.begin(), new_extern.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if already just calls.
|
||||||
|
if (extra_operands.empty() && IsAlreadyOutlinedd(while_op)) return;
|
||||||
|
|
||||||
|
// Collect new types.
|
||||||
SmallVector<Type, 4> types;
|
SmallVector<Type, 4> types;
|
||||||
types.reserve(extern_values.size() +
|
types.reserve(extra_operands.size() + while_op.getNumOperands());
|
||||||
while_op.cond().front().getNumArguments());
|
|
||||||
// Type of block arguments are used as these could differ from those of While
|
|
||||||
// op, but has to match between cond and body.
|
|
||||||
for (BlockArgument ba : while_op.cond().front().getArguments())
|
for (BlockArgument ba : while_op.cond().front().getArguments())
|
||||||
types.push_back(ba.getType());
|
types.push_back(ba.getType());
|
||||||
for (Value operand : extern_values) types.push_back(operand.getType());
|
for (Value operand : extern_values) types.push_back(operand.getType());
|
||||||
@ -119,7 +152,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
|||||||
outlined_func.getBody().takeBody(region);
|
outlined_func.getBody().takeBody(region);
|
||||||
Region& func_region = outlined_func.getBody();
|
Region& func_region = outlined_func.getBody();
|
||||||
|
|
||||||
// Replace all external uses with block args and update uses..
|
// Replace all external uses with block args and update uses.
|
||||||
llvm::SmallVector<Value, 4> new_args;
|
llvm::SmallVector<Value, 4> new_args;
|
||||||
new_args.reserve(extern_values.size());
|
new_args.reserve(extern_values.size());
|
||||||
Block& block = func_region.front();
|
Block& block = func_region.front();
|
||||||
@ -133,10 +166,12 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
|||||||
Operation* yield_op = outlined_func.getBody().front().getTerminator();
|
Operation* yield_op = outlined_func.getBody().front().getTerminator();
|
||||||
OpBuilder b(yield_op);
|
OpBuilder b(yield_op);
|
||||||
llvm::SmallVector<Value, 4> args;
|
llvm::SmallVector<Value, 4> args;
|
||||||
args.reserve(yield_op->getNumOperands() + new_args.size());
|
auto loop_carried_yield_operands =
|
||||||
|
yield_op->getOperands().take_front(num_loop_carried);
|
||||||
|
args.reserve(loop_carried_yield_operands.size() + new_args.size());
|
||||||
if (passthru_extra_args) {
|
if (passthru_extra_args) {
|
||||||
// Add operands of yield to the return, inserting casts if needed.
|
// Add operands of yield to the return, inserting casts if needed.
|
||||||
for (auto it : llvm::zip(yield_op->getOperands(), types)) {
|
for (auto it : llvm::zip_first(loop_carried_yield_operands, types)) {
|
||||||
auto value = std::get<0>(it);
|
auto value = std::get<0>(it);
|
||||||
auto type = std::get<1>(it);
|
auto type = std::get<1>(it);
|
||||||
if (value.getType() == type) {
|
if (value.getType() == type) {
|
||||||
@ -160,11 +195,6 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
|||||||
// Replace region with call to outline function.
|
// Replace region with call to outline function.
|
||||||
auto replace_with_call = [&](StringRef name, Region& region,
|
auto replace_with_call = [&](StringRef name, Region& region,
|
||||||
bool passthru_extra_args) {
|
bool passthru_extra_args) {
|
||||||
// Skip if already only a call.
|
|
||||||
if (region.front().getOperations().size() == 2 &&
|
|
||||||
isa<mlir::CallOp>(region.front().front()))
|
|
||||||
return;
|
|
||||||
|
|
||||||
auto func = create_outline_func(name, region, passthru_extra_args);
|
auto func = create_outline_func(name, region, passthru_extra_args);
|
||||||
OpBuilder b(region);
|
OpBuilder b(region);
|
||||||
// The body of the region is empty/has been outlined into the function.
|
// The body of the region is empty/has been outlined into the function.
|
||||||
@ -185,19 +215,19 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
|||||||
|
|
||||||
// If there are extern values used then the result type of the while has to
|
// If there are extern values used then the result type of the while has to
|
||||||
// change, so replace with new while op.
|
// change, so replace with new while op.
|
||||||
if (extern_values.empty()) return;
|
if (extra_operands.empty()) return;
|
||||||
|
|
||||||
Operation* op = while_op.getOperation();
|
Operation* op = while_op.getOperation();
|
||||||
SmallVector<Value, 4> operands;
|
SmallVector<Value, 4> operands;
|
||||||
SmallVector<Type, 4> new_types;
|
SmallVector<Type, 4> new_types;
|
||||||
operands.reserve(op->getNumOperands() + extern_values.size());
|
operands.reserve(types.size());
|
||||||
new_types.reserve(operands.size());
|
new_types.reserve(operands.size());
|
||||||
auto add_operand = [&](Value v) {
|
auto add_operand = [&](Value v) {
|
||||||
operands.push_back(v);
|
operands.push_back(v);
|
||||||
new_types.push_back(v.getType());
|
new_types.push_back(v.getType());
|
||||||
};
|
};
|
||||||
for (auto operand : op->getOperands()) add_operand(operand);
|
for (auto operand : op->getOperands()) add_operand(operand);
|
||||||
for (auto operand : extern_values) add_operand(operand);
|
for (auto operand : extra_operands) add_operand(operand);
|
||||||
|
|
||||||
Operation* new_op = OpBuilder(op).insert(Operation::create(
|
Operation* new_op = OpBuilder(op).insert(Operation::create(
|
||||||
op->getLoc(), op->getName(), new_types, operands, op->getAttrs(),
|
op->getLoc(), op->getName(), new_types, operands, op->getAttrs(),
|
||||||
@ -212,7 +242,6 @@ void WhileOutlinePass::runOnModule() {
|
|||||||
getModule().walk(
|
getModule().walk(
|
||||||
[&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
|
[&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
|
||||||
}
|
}
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
|
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
|
||||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass() {
|
std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass() {
|
||||||
|
@ -38,7 +38,7 @@ FloatAttr GetSingleElementAsFloatOrSelf(Attribute attr) {
|
|||||||
|
|
||||||
IntegerAttr ExtractSingleElementAsInteger(ElementsAttr attr) {
|
IntegerAttr ExtractSingleElementAsInteger(ElementsAttr attr) {
|
||||||
if (attr.getType().getNumElements() != 1 ||
|
if (attr.getType().getNumElements() != 1 ||
|
||||||
!attr.getType().getElementType().isa<IntegerType>()) {
|
!attr.getType().getElementType().isSignlessInteger()) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
SmallVector<uint64_t, 8> index(attr.getType().getRank(), 0);
|
SmallVector<uint64_t, 8> index(attr.getType().getRank(), 0);
|
||||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
|
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
|
||||||
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
|
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
|
||||||
|
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace TFL {
|
namespace TFL {
|
||||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||||
@ -70,14 +70,14 @@ Value CreateNoneValue(OpBuilder* builder, mlir::Location location) {
|
|||||||
builder->getUnitAttr());
|
builder->getUnitAttr());
|
||||||
}
|
}
|
||||||
|
|
||||||
Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
|
Value Transpose(OpBuilder* builder, Value value_to_transpose,
|
||||||
RankedTensorType type, mlir::Location location) {
|
SmallVector<int64_t, 4> perm, RankedTensorType original_type,
|
||||||
|
mlir::Location location) {
|
||||||
// Create a constant op for transpose permutation.
|
// Create a constant op for transpose permutation.
|
||||||
SmallVector<int64_t, 2> perm = {1, 0};
|
|
||||||
auto perm_op = CreateI64DenseConst(builder, perm, perm, location);
|
auto perm_op = CreateI64DenseConst(builder, perm, perm, location);
|
||||||
|
|
||||||
// Create tensor type for the transpose result.
|
// Create tensor type for the transpose result.
|
||||||
auto transpose_type = type;
|
auto transpose_type = original_type;
|
||||||
auto transpose_shape = functional::map(
|
auto transpose_shape = functional::map(
|
||||||
[transpose_type](int64_t dim) { return transpose_type.getDimSize(dim); },
|
[transpose_type](int64_t dim) { return transpose_type.getDimSize(dim); },
|
||||||
perm);
|
perm);
|
||||||
@ -88,6 +88,13 @@ Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
|
|||||||
value_to_transpose, perm_op);
|
value_to_transpose, perm_op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
|
||||||
|
RankedTensorType type, mlir::Location location) {
|
||||||
|
// Create a constant op for transpose permutation.
|
||||||
|
SmallVector<int64_t, 4> perm = {1, 0};
|
||||||
|
return Transpose(builder, value_to_transpose, perm, type, location);
|
||||||
|
}
|
||||||
|
|
||||||
ArrayRef<int64_t> GetRankedTensorShape(Value value) {
|
ArrayRef<int64_t> GetRankedTensorShape(Value value) {
|
||||||
return value.getType().cast<RankedTensorType>().getShape();
|
return value.getType().cast<RankedTensorType>().getShape();
|
||||||
}
|
}
|
||||||
@ -586,15 +593,30 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
|
|||||||
Value recurrent_kernel = func_op.getArgument(4);
|
Value recurrent_kernel = func_op.getArgument(4);
|
||||||
Value bias = func_op.getArgument(5);
|
Value bias = func_op.getArgument(5);
|
||||||
|
|
||||||
// Assume it's batch majored.
|
// TFL lstm only supports time-majored inputs, so if it's not time-majored,
|
||||||
|
// we will transpose the inputs and outputs.
|
||||||
|
auto time_major_attr = func_op.getAttrOfType<BoolAttr>("tf.time_major");
|
||||||
|
if (time_major_attr == nullptr) return failure();
|
||||||
|
|
||||||
|
bool time_majored = time_major_attr.getValue();
|
||||||
auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
|
auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
|
||||||
if (!input_type) {
|
if (!input_type) {
|
||||||
func_op.emitError() << "Input type is not a ranked tensor type";
|
func_op.emitError() << "Input type is not a ranked tensor type";
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
int batch = input_type.getDimSize(0);
|
auto final_inputs = input;
|
||||||
int time = input_type.getDimSize(1);
|
auto final_input_type = input_type;
|
||||||
|
// We will transpose the inputs.
|
||||||
|
if (!time_majored) {
|
||||||
|
SmallVector<int64_t, 4> perm = {1, 0, 2};
|
||||||
|
final_inputs =
|
||||||
|
Transpose(builder, final_inputs, perm, input_type, func_op.getLoc());
|
||||||
|
final_input_type = final_inputs.getType().dyn_cast<RankedTensorType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
int batch = final_input_type.getDimSize(1);
|
||||||
|
int time = final_input_type.getDimSize(0);
|
||||||
|
|
||||||
// Setup correct weights.
|
// Setup correct weights.
|
||||||
RankedTensorType weight_type =
|
RankedTensorType weight_type =
|
||||||
@ -672,7 +694,13 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
|
|||||||
builder->getF32FloatAttr(10.0), builder->getF32FloatAttr(0.0),
|
builder->getF32FloatAttr(10.0), builder->getF32FloatAttr(0.0),
|
||||||
builder->getStringAttr("FULL"));
|
builder->getStringAttr("FULL"));
|
||||||
|
|
||||||
builder->create<mlir::ReturnOp>(func_op.getLoc(), lstm.getResult());
|
auto final_output = lstm.getResult();
|
||||||
|
if (!time_majored) {
|
||||||
|
SmallVector<int64_t, 4> perm = {1, 0, 2};
|
||||||
|
final_output =
|
||||||
|
Transpose(builder, final_output, perm, result_type, func_op.getLoc());
|
||||||
|
}
|
||||||
|
builder->create<mlir::ReturnOp>(func_op.getLoc(), final_output);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringExtras.h"
|
#include "llvm/ADT/StringExtras.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_
|
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_
|
||||||
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_
|
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_
|
||||||
|
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace TFL {
|
namespace TFL {
|
||||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
|
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
|
||||||
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
|
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
|
||||||
|
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
@ -90,7 +90,7 @@ gentbl(
|
|||||||
td_file = "ir/tf_saved_model_ops.td",
|
td_file = "ir/tf_saved_model_ops.td",
|
||||||
td_srcs = [
|
td_srcs = [
|
||||||
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
||||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td",
|
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -114,7 +114,7 @@ gentbl(
|
|||||||
td_file = "ir/tf_executor_ops.td",
|
td_file = "ir/tf_executor_ops.td",
|
||||||
td_srcs = [
|
td_srcs = [
|
||||||
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
||||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td",
|
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -138,7 +138,7 @@ gentbl(
|
|||||||
td_file = "ir/tf_device_ops.td",
|
td_file = "ir/tf_device_ops.td",
|
||||||
td_srcs = [
|
td_srcs = [
|
||||||
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
||||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td",
|
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -281,12 +281,12 @@ cc_library(
|
|||||||
"transforms/generated_canonicalize.inc",
|
"transforms/generated_canonicalize.inc",
|
||||||
"transforms/generated_optimize.inc",
|
"transforms/generated_optimize.inc",
|
||||||
"transforms/graph_pruning.cc",
|
"transforms/graph_pruning.cc",
|
||||||
"transforms/inline_global_tensors.cc",
|
|
||||||
"transforms/layout_optimization.cc",
|
"transforms/layout_optimization.cc",
|
||||||
"transforms/mark_function_visibility.cc",
|
"transforms/mark_function_visibility.cc",
|
||||||
"transforms/materialize_mlir_passthrough_op.cc",
|
"transforms/materialize_mlir_passthrough_op.cc",
|
||||||
"transforms/optimize.cc",
|
"transforms/optimize.cc",
|
||||||
"transforms/optimize_global_tensors.cc",
|
"transforms/optimize_global_tensors.cc",
|
||||||
|
"transforms/parallel_execute_to_islands.cc",
|
||||||
"transforms/promote_resources_to_args.cc",
|
"transforms/promote_resources_to_args.cc",
|
||||||
"transforms/raise_control_flow.cc",
|
"transforms/raise_control_flow.cc",
|
||||||
"transforms/replicate_invariant_op_hoisting.cc",
|
"transforms/replicate_invariant_op_hoisting.cc",
|
||||||
@ -376,6 +376,7 @@ cc_library(
|
|||||||
":tensorflow",
|
":tensorflow",
|
||||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:LoopOpsTransforms",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
@ -1000,8 +1001,13 @@ cc_library(
|
|||||||
srcs = ["utils/tpu_rewrite_device_util.cc"],
|
srcs = ["utils/tpu_rewrite_device_util.cc"],
|
||||||
hdrs = ["utils/tpu_rewrite_device_util.h"],
|
hdrs = ["utils/tpu_rewrite_device_util.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/compiler/xla:array3d",
|
||||||
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
|
"//tensorflow/compiler/xla/service:computation_placer",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
|
||||||
|
"//tensorflow/stream_executor/lib",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
],
|
],
|
||||||
@ -1016,6 +1022,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
|
||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/StringSwitch.h"
|
#include "llvm/ADT/StringSwitch.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/Traits.h" // TF:llvm-project
|
#include "mlir/Dialect/Traits.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
@ -573,9 +573,9 @@ void Print(SwitchNOp switchn, OpAsmPrinter &p) {
|
|||||||
|
|
||||||
ParseResult ParseSwitchNOp(OpAsmParser &parser, OperationState &result) {
|
ParseResult ParseSwitchNOp(OpAsmParser &parser, OperationState &result) {
|
||||||
// Parsing:
|
// Parsing:
|
||||||
// %2:6 = tf_executor.SwitchN %0, %1 by 5 : tensor<??xf32>
|
// %2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<??xf32>
|
||||||
// Where the first operand is the data to replicate, the second is an i32
|
// Where the first operand is the data to replicate, the second is an i32
|
||||||
// indicating which output to populate, followed by the keyword `by` and the
|
// indicating which output to populate, followed by the keyword `of` and the
|
||||||
// number of outputs (+1 for the control token).
|
// number of outputs (+1 for the control token).
|
||||||
SmallVector<OpAsmParser::OperandType, 2> op_infos;
|
SmallVector<OpAsmParser::OperandType, 2> op_infos;
|
||||||
SmallVector<Type, 1> types;
|
SmallVector<Type, 1> types;
|
||||||
|
@ -165,7 +165,7 @@ def TfExecutor_IslandOp : TfExecutor_Op<"island",
|
|||||||
The `tf_executor.island` operation has a single region with a single block
|
The `tf_executor.island` operation has a single region with a single block
|
||||||
attached (only functional control flow is allowed). The block is terminated
|
attached (only functional control flow is allowed). The block is terminated
|
||||||
by a `tf_executor.yield` operation. The operands of the terminator
|
by a `tf_executor.yield` operation. The operands of the terminator
|
||||||
correspond to the result values of the `tf_executor.graph` operation. An
|
correspond to the result values of the `tf_executor.island` operation. An
|
||||||
extra result of type `!tf_executor.control` is always produced by every
|
extra result of type `!tf_executor.control` is always produced by every
|
||||||
`tf_executor.island`.
|
`tf_executor.island`.
|
||||||
Within an island, execution semantics follow standard sequential behavior as
|
Within an island, execution semantics follow standard sequential behavior as
|
||||||
@ -299,7 +299,7 @@ def TfExecutor_SwitchNOp : TfExecutor_Op<"SwitchN",
|
|||||||
.SetShapeFn(SwitchNShape);
|
.SetShapeFn(SwitchNShape);
|
||||||
|
|
||||||
For example:
|
For example:
|
||||||
%2:6 = tf_executor.SwitchN %0, %1 by 5 : tensor<??xf32>
|
%2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<??xf32>
|
||||||
|
|
||||||
Note: One additional result corresponds to the control output.
|
Note: One additional result corresponds to the control output.
|
||||||
}];
|
}];
|
||||||
|
@ -510,6 +510,7 @@ Broadcasting is supported, so `value` may have any number of dimensions.
|
|||||||
// TF_LayoutSensitiveInterface:
|
// TF_LayoutSensitiveInterface:
|
||||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||||
|
LogicalResult UpdateDataFormat(StringRef data_format);
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -980,7 +981,7 @@ tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
|
|||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect]> {
|
def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect, TF_LayoutSensitiveInterface]> {
|
||||||
let summary = [{
|
let summary = [{
|
||||||
Computes a 2-D convolution given 4-D `input` and `filter` tensors.
|
Computes a 2-D convolution given 4-D `input` and `filter` tensors.
|
||||||
}];
|
}];
|
||||||
@ -1030,6 +1031,13 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
|
|||||||
let verifier = [{
|
let verifier = [{
|
||||||
return Verify(*this);
|
return Verify(*this);
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// TF_LayoutSensitiveInterface:
|
||||||
|
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||||
|
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||||
|
LogicalResult UpdateDataFormat(StringRef data_format);
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_Conv2DBackpropFilterOp : TF_Op<"Conv2DBackpropFilter", [NoSideEffect]> {
|
def TF_Conv2DBackpropFilterOp : TF_Op<"Conv2DBackpropFilter", [NoSideEffect]> {
|
||||||
@ -2091,7 +2099,7 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
|||||||
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>;
|
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect]> {
|
def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
|
||||||
let summary = "Batch normalization.";
|
let summary = "Batch normalization.";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@ -2122,6 +2130,13 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
|||||||
|
|
||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
|
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// TF_FoldOperandsTransposeInterface:
|
||||||
|
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||||
|
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||||
|
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> {
|
def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> {
|
||||||
@ -3096,6 +3111,70 @@ cublas.
|
|||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [NoSideEffect, AllTypesMatch<["input", "band"]>]> {
|
||||||
|
let summary = [{
|
||||||
|
Copy a tensor setting everything outside a central band in each innermost matrix
|
||||||
|
to zero.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
The `band` part is computed as follows:
|
||||||
|
Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a
|
||||||
|
tensor with the same shape where
|
||||||
|
|
||||||
|
`band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
|
||||||
|
|
||||||
|
The indicator function
|
||||||
|
|
||||||
|
`in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) &&
|
||||||
|
(num_upper < 0 || (n-m) <= num_upper)`.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
|
||||||
|
```
|
||||||
|
# if 'input' is [[ 0, 1, 2, 3]
|
||||||
|
[-1, 0, 1, 2]
|
||||||
|
[-2, -1, 0, 1]
|
||||||
|
[-3, -2, -1, 0]],
|
||||||
|
|
||||||
|
tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3]
|
||||||
|
[-1, 0, 1, 2]
|
||||||
|
[ 0, -1, 0, 1]
|
||||||
|
[ 0, 0, -1, 0]],
|
||||||
|
|
||||||
|
tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0]
|
||||||
|
[-1, 0, 1, 0]
|
||||||
|
[-2, -1, 0, 1]
|
||||||
|
[ 0, -2, -1, 0]]
|
||||||
|
```
|
||||||
|
|
||||||
|
Useful special cases:
|
||||||
|
|
||||||
|
```
|
||||||
|
tf.matrix_band_part(input, 0, -1) ==> Upper triangular part.
|
||||||
|
tf.matrix_band_part(input, -1, 0) ==> Lower triangular part.
|
||||||
|
tf.matrix_band_part(input, 0, 0) ==> Diagonal.
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_Tensor:$input,
|
||||||
|
TF_I32OrI64Tensor:$num_lower,
|
||||||
|
TF_I32OrI64Tensor:$num_upper
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TF_Tensor:$band
|
||||||
|
);
|
||||||
|
|
||||||
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
TF_DerivedOperandTypeAttr Tindex = TF_DerivedOperandTypeAttr<1>;
|
||||||
|
|
||||||
|
let verifier = [{
|
||||||
|
return Verify(*this);
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def TF_MatrixDiagOp : TF_Op<"MatrixDiag", [NoSideEffect]> {
|
def TF_MatrixDiagOp : TF_Op<"MatrixDiag", [NoSideEffect]> {
|
||||||
let summary = [{
|
let summary = [{
|
||||||
Returns a batched diagonal tensor with a given batched diagonal values.
|
Returns a batched diagonal tensor with a given batched diagonal values.
|
||||||
@ -4278,7 +4357,7 @@ This is the opposite of `unpack`.
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_PadOp : TF_Op<"Pad", [NoSideEffect]> {
|
def TF_PadOp : TF_Op<"Pad", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
|
||||||
let summary = "Pads a tensor with zeros.";
|
let summary = "Pads a tensor with zeros.";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@ -4317,6 +4396,13 @@ pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
|
|||||||
|
|
||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>;
|
TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>;
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// TF_FoldOperandsTransposeInterface:
|
||||||
|
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||||
|
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||||
|
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_PadV2Op : TF_Op<"PadV2", [NoSideEffect]> {
|
def TF_PadV2Op : TF_Op<"PadV2", [NoSideEffect]> {
|
||||||
@ -4845,7 +4931,7 @@ I.e., \\(y = 1 / x\\).
|
|||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType]> {
|
def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> {
|
||||||
let summary = "Computes rectified linear: `max(features, 0)`.";
|
let summary = "Computes rectified linear: `max(features, 0)`.";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -85,7 +85,7 @@ class TF_TensorFlowType <string name, string description> :
|
|||||||
|
|
||||||
// Any tensor element type allowed in TensorFlow ops
|
// Any tensor element type allowed in TensorFlow ops
|
||||||
def TF_ElementType : Type<Or<[AnyFloat.predicate,
|
def TF_ElementType : Type<Or<[AnyFloat.predicate,
|
||||||
AnyInteger.predicate,
|
AnySignlessInteger.predicate,
|
||||||
AnyComplex.predicate,
|
AnyComplex.predicate,
|
||||||
TF_TFDialectType.predicate]>,
|
TF_TFDialectType.predicate]>,
|
||||||
"tf.dtype">;
|
"tf.dtype">;
|
||||||
|
@ -50,6 +50,12 @@ def TF_LayoutSensitiveInterface : OpInterface<"LayoutSensitiveInterface"> {
|
|||||||
[{Returns indices of layout dependent results.}],
|
[{Returns indices of layout dependent results.}],
|
||||||
"SmallVector<unsigned, 4>", "GetLayoutDependentResults", (ins)
|
"SmallVector<unsigned, 4>", "GetLayoutDependentResults", (ins)
|
||||||
>,
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
[{Updates operation attributes and operands to account for the updated
|
||||||
|
data format. If data format is not supported, must return failure.}],
|
||||||
|
"LogicalResult", "UpdateDataFormat",
|
||||||
|
(ins "StringRef":$data_format)
|
||||||
|
>,
|
||||||
];
|
];
|
||||||
|
|
||||||
let verify = [{
|
let verify = [{
|
||||||
|
@ -35,7 +35,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/StringSwitch.h"
|
#include "llvm/ADT/StringSwitch.h"
|
||||||
#include "llvm/ADT/iterator_range.h"
|
#include "llvm/ADT/iterator_range.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/Traits.h" // TF:llvm-project
|
#include "mlir/Dialect/Traits.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
@ -151,26 +151,6 @@ static bool AreCastCompatible(Type a, Type b) {
|
|||||||
b_kind == TensorFlowTypes::VARIANT;
|
b_kind == TensorFlowTypes::VARIANT;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool AreCancellablePermutations(DenseIntElementsAttr perm0,
|
|
||||||
DenseIntElementsAttr perm1) {
|
|
||||||
if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false;
|
|
||||||
if (perm0.getNumElements() != perm1.getNumElements()) return false;
|
|
||||||
|
|
||||||
SmallVector<int64_t, 8> perm0_values;
|
|
||||||
for (auto value : perm0.getIntValues())
|
|
||||||
perm0_values.push_back(value.getSExtValue());
|
|
||||||
|
|
||||||
SmallVector<int64_t, 8> perm1_values;
|
|
||||||
for (auto value : perm1.getIntValues())
|
|
||||||
perm1_values.push_back(value.getSExtValue());
|
|
||||||
|
|
||||||
for (int i = 0; i < perm0_values.size(); ++i) {
|
|
||||||
if (perm0_values[perm1_values[i]] != i) return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
|
static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
|
||||||
return dim_or_rank == -1;
|
return dim_or_rank == -1;
|
||||||
}
|
}
|
||||||
@ -312,6 +292,164 @@ static LogicalResult VerifyTypesCompatibility(
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TF op helper functions to work with layout transformation.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> ReversePermutation(ArrayRef<int64_t> permutation) {
|
||||||
|
SmallVector<int64_t, 4> reverse(permutation.size());
|
||||||
|
for (size_t i = 0; i < permutation.size(); ++i) {
|
||||||
|
reverse[permutation[i]] = i;
|
||||||
|
}
|
||||||
|
return reverse;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> GetDataFormatPermutation(StringRef from, StringRef to) {
|
||||||
|
if (from == "NHWC" && to == "NCHW") {
|
||||||
|
return {0, 3, 1, 2};
|
||||||
|
} else if (from == "NCHW" && to == "NHWC") {
|
||||||
|
return {0, 2, 3, 1};
|
||||||
|
} else {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shuffle elements in the `attr` according to the permutation. Optional
|
||||||
|
// `inner_size` allows to shuffle array attributes created from rank 2 tensors
|
||||||
|
// on outer dimension only.
|
||||||
|
ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef<int64_t> permutation,
|
||||||
|
int inner_size = 1) {
|
||||||
|
if (attr.size() == 0) return attr;
|
||||||
|
|
||||||
|
assert(attr.size() % inner_size == 0);
|
||||||
|
assert(attr.size() / inner_size == permutation.size());
|
||||||
|
|
||||||
|
SmallVector<Attribute, 8> values{attr.begin(), attr.end()};
|
||||||
|
SmallVector<Attribute, 8> shuffled(values.size());
|
||||||
|
|
||||||
|
for (size_t i = 0; i < permutation.size(); ++i) {
|
||||||
|
for (size_t j = 0; j < inner_size; ++j) {
|
||||||
|
shuffled[i * inner_size + j] = values[permutation[i] * inner_size + j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ArrayAttr::get(shuffled, attr.getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shuffle ranked tensor dimensions according to the permutation.
|
||||||
|
Type ShuffleRankedTensorType(Type type, ArrayRef<int64_t> permutation) {
|
||||||
|
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
|
||||||
|
ArrayRef<int64_t> shape = ranked_type.getShape();
|
||||||
|
assert(permutation.size() == shape.size());
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> new_shape(permutation.size());
|
||||||
|
for (size_t i = 0; i < permutation.size(); ++i)
|
||||||
|
new_shape[i] = shape[permutation[i]];
|
||||||
|
|
||||||
|
return RankedTensorType::get(new_shape, ranked_type.getElementType());
|
||||||
|
}
|
||||||
|
|
||||||
|
return type;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool AreCancellablePermutations(DenseIntElementsAttr perm0,
|
||||||
|
DenseIntElementsAttr perm1) {
|
||||||
|
if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false;
|
||||||
|
if (perm0.getNumElements() != perm1.getNumElements()) return false;
|
||||||
|
|
||||||
|
SmallVector<int64_t, 8> perm0_values;
|
||||||
|
for (auto value : perm0.getIntValues())
|
||||||
|
perm0_values.push_back(value.getSExtValue());
|
||||||
|
|
||||||
|
SmallVector<int64_t, 8> perm1_values;
|
||||||
|
for (auto value : perm1.getIntValues())
|
||||||
|
perm1_values.push_back(value.getSExtValue());
|
||||||
|
|
||||||
|
for (int i = 0; i < perm0_values.size(); ++i) {
|
||||||
|
if (perm0_values[perm1_values[i]] != i) return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default implementation of `LayoutSensitiveInterface::UpdateDataFormat` for
|
||||||
|
// layout sensitive operations that do not have any additional layout dependent
|
||||||
|
// attributes besides `data_format` string.
|
||||||
|
template <typename Op>
|
||||||
|
LogicalResult UpdateDataFormat(StringRef data_format, Op *op) {
|
||||||
|
auto perm = GetDataFormatPermutation(op->data_format(), data_format);
|
||||||
|
if (perm.empty()) return failure();
|
||||||
|
|
||||||
|
// Update data format attribute.
|
||||||
|
op->setAttr("data_format", StringAttr::get(data_format, op->getContext()));
|
||||||
|
|
||||||
|
// Update types for all layout sensitive results.
|
||||||
|
auto layout_sensitive = cast<LayoutSensitiveInterface>(op->getOperation());
|
||||||
|
for (unsigned idx : layout_sensitive.GetLayoutDependentResults()) {
|
||||||
|
OpResult result = op->getOperation()->getResult(idx);
|
||||||
|
result.setType(ShuffleRankedTensorType(result.getType(), perm));
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default implementation for folding operand transpose into the operation.
|
||||||
|
// See `FoldOperandsTransposeInterface::FoldOperandsPermutation`.
|
||||||
|
template <typename Op>
|
||||||
|
LogicalResult FoldOperandsPermutation(
|
||||||
|
ArrayRef<int64_t> permutation, Op *op,
|
||||||
|
ArrayRef<std::pair<StringRef, ArrayAttr>> shuffle_attrs = {}) {
|
||||||
|
MLIRContext *context = op->template getParentOfType<ModuleOp>().getContext();
|
||||||
|
|
||||||
|
// We only support NHWC <-> NCHW permutations.
|
||||||
|
static constexpr std::array<int64_t, 4> kNchwToNhwc = {0, 2, 3, 1};
|
||||||
|
static constexpr std::array<int64_t, 4> kNhwcToNchw = {0, 3, 1, 2};
|
||||||
|
|
||||||
|
// Operation data format after folding `permutation`.
|
||||||
|
StringRef target_data_format = [&]() -> StringRef {
|
||||||
|
if (op->data_format() == "NHWC" && permutation.equals(kNchwToNhwc)) {
|
||||||
|
return "NCHW"; // cancel NCHW->NHWC operand permutation
|
||||||
|
} else if (op->data_format() == "NCHW" && permutation.equals(kNhwcToNchw)) {
|
||||||
|
return "NHWC"; // cancel NHWC->NCHW operand permutation
|
||||||
|
} else {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
}();
|
||||||
|
if (target_data_format.empty()) return failure();
|
||||||
|
|
||||||
|
// To fold operand `permutation` into the `op` we need shuffle all layout
|
||||||
|
// dependent attributes and types with a reverse permutation, and change
|
||||||
|
// operation data format to `target_data_format`.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// %1 = SomeOp(...) {data_format = NHWC}
|
||||||
|
// %2 = Transpose(%1) {permutation = NHWC->NCHW}
|
||||||
|
// %3 = Op(%2) {data_format = NCHW}
|
||||||
|
//
|
||||||
|
// To bypass %2 we have to change data format to shuffle data format from NCHW
|
||||||
|
// to NHWC, which is the reverse of operand permutation (function argument).
|
||||||
|
auto reverse_permutation =
|
||||||
|
GetDataFormatPermutation(op->data_format(), target_data_format);
|
||||||
|
if (reverse_permutation.empty()) return failure();
|
||||||
|
|
||||||
|
op->setAttr("data_format", StringAttr::get(target_data_format, context));
|
||||||
|
|
||||||
|
for (auto pair : shuffle_attrs) {
|
||||||
|
StringRef attr_name = pair.first;
|
||||||
|
ArrayAttr attr_value = pair.second;
|
||||||
|
op->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto fold = cast<FoldOperandsTransposeInterface>(op->getOperation());
|
||||||
|
for (unsigned idx : fold.GetLayoutDependentResults()) {
|
||||||
|
OpResult result = op->getOperation()->getResult(idx);
|
||||||
|
result.setType(
|
||||||
|
ShuffleRankedTensorType(result.getType(), reverse_permutation));
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -479,6 +617,15 @@ static LogicalResult Verify(BiasAddOp op) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(ezhulenev): BiasAddOp is not really layout sensitive, it must only
|
||||||
|
// support folding operand transposes.
|
||||||
|
LogicalResult BiasAddOp::UpdateDataFormat(StringRef data_format) {
|
||||||
|
auto ranked = value().getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!ranked || ranked.getRank() != 4) return failure();
|
||||||
|
|
||||||
|
return ::mlir::TF::UpdateDataFormat(data_format, this);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// BiasAddGradOp
|
// BiasAddGradOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -837,6 +984,21 @@ static LogicalResult Verify(OpT op) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) {
|
||||||
|
auto perm = GetDataFormatPermutation(this->data_format(), data_format);
|
||||||
|
if (perm.empty()) return failure();
|
||||||
|
|
||||||
|
// Update data_format attribute and result types.
|
||||||
|
if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
|
||||||
|
|
||||||
|
// Update convolution attributes.
|
||||||
|
setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
|
||||||
|
setAttr("strides", ShuffleArrayAttr(strides(), perm));
|
||||||
|
setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2));
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Conv2dBackpropInputOp
|
// Conv2dBackpropInputOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1158,6 +1320,11 @@ static LogicalResult Verify(FusedBatchNormOp op) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation(
|
||||||
|
ArrayRef<int64_t> permutation) {
|
||||||
|
return ::mlir::TF::FoldOperandsPermutation(permutation, this);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// GatherV2Op
|
// GatherV2Op
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1339,6 +1506,29 @@ void LogicalNotOp::getCanonicalizationPatterns(
|
|||||||
LogicalNotOfLess, LogicalNotOfLessEqual>(context);
|
LogicalNotOfLess, LogicalNotOfLessEqual>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// MatrixBandPartOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static LogicalResult Verify(MatrixBandPartOp op) {
|
||||||
|
if (!HasRankAtLeast(op.input(), 2)) {
|
||||||
|
return op.emitOpError()
|
||||||
|
<< "requires `input` to have rank of at least 2, but found "
|
||||||
|
<< op.input().getType();
|
||||||
|
}
|
||||||
|
if (!IsOfRankOrUnranked(op.num_lower(), 0)) {
|
||||||
|
return op.emitOpError()
|
||||||
|
<< "requires `num_lower` to have 0 dimensions, but found "
|
||||||
|
<< op.num_lower().getType();
|
||||||
|
}
|
||||||
|
if (!IsOfRankOrUnranked(op.num_upper(), 0)) {
|
||||||
|
return op.emitOpError()
|
||||||
|
<< "requires `num_upper` to have 0 dimensions, but found "
|
||||||
|
<< op.num_upper().getType();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// MaxOp
|
// MaxOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1356,57 +1546,8 @@ void MaxOp::build(Builder *builder, OperationState &result, Value input,
|
|||||||
|
|
||||||
LogicalResult MaxPoolOp::FoldOperandsPermutation(
|
LogicalResult MaxPoolOp::FoldOperandsPermutation(
|
||||||
ArrayRef<int64_t> permutation) {
|
ArrayRef<int64_t> permutation) {
|
||||||
MLIRContext *context = getParentOfType<ModuleOp>().getContext();
|
return ::mlir::TF::FoldOperandsPermutation(
|
||||||
|
permutation, this, {{"strides", strides()}, {"ksize", ksize()}});
|
||||||
// For now we only support folding of NCHW->NHWC and NHWC->NCHW permutations.
|
|
||||||
if (data_format() == "NHWC") {
|
|
||||||
static constexpr std::array<int64_t, 4> kPerm = {0, 2, 3, 1}; // to NHWC
|
|
||||||
if (permutation != ArrayRef<int64_t>(kPerm)) return failure();
|
|
||||||
|
|
||||||
setAttr("data_format", StringAttr::get("NCHW", context));
|
|
||||||
|
|
||||||
} else if (data_format() == "NCHW") {
|
|
||||||
static constexpr std::array<int64_t, 4> kPerm = {0, 3, 1, 2}; // to NCHW
|
|
||||||
if (permutation != ArrayRef<int64_t>(kPerm)) return failure();
|
|
||||||
|
|
||||||
setAttr("data_format", StringAttr::get("NHWC", context));
|
|
||||||
|
|
||||||
} else {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto shuffle_attr = [&](ArrayAttr attr) -> ArrayAttr {
|
|
||||||
SmallVector<Attribute, 4> values{attr.begin(), attr.end()};
|
|
||||||
SmallVector<Attribute, 4> shuffled(values.size());
|
|
||||||
|
|
||||||
for (size_t i = 0; i < permutation.size(); ++i)
|
|
||||||
shuffled[permutation[i]] = values[i];
|
|
||||||
|
|
||||||
return ArrayAttr::get(shuffled, context);
|
|
||||||
};
|
|
||||||
|
|
||||||
setAttr("strides", shuffle_attr(strides()));
|
|
||||||
setAttr("ksize", shuffle_attr(ksize()));
|
|
||||||
|
|
||||||
auto shuffle_type = [&](Type type) -> Type {
|
|
||||||
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
|
|
||||||
ArrayRef<int64_t> shape = ranked_type.getShape();
|
|
||||||
assert(permutation.size() == shape.size());
|
|
||||||
|
|
||||||
SmallVector<int64_t, 4> new_shape(permutation.size());
|
|
||||||
for (size_t i = 0; i < permutation.size(); ++i)
|
|
||||||
new_shape[permutation[i]] = shape[i];
|
|
||||||
|
|
||||||
return RankedTensorType::get(new_shape, ranked_type.getElementType());
|
|
||||||
}
|
|
||||||
|
|
||||||
return type;
|
|
||||||
};
|
|
||||||
|
|
||||||
OpResult result = getOperation()->getResult(0);
|
|
||||||
result.setType(shuffle_type(result.getType()));
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1426,6 +1567,38 @@ static LogicalResult Verify(MaxPoolGradOp op) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// MeanOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
|
||||||
|
// Reduction indices must be defined by a constant operation.
|
||||||
|
auto reduction_op =
|
||||||
|
dyn_cast_or_null<TF::ConstOp>(reduction_indices().getDefiningOp());
|
||||||
|
if (!reduction_op) return failure();
|
||||||
|
|
||||||
|
auto reductions_value = reduction_op.value().dyn_cast<DenseElementsAttr>();
|
||||||
|
if (!reductions_value) return failure();
|
||||||
|
|
||||||
|
// Prepare new reduction indices according to operand permutation.
|
||||||
|
SmallVector<int32_t, 4> shuffled_reduction;
|
||||||
|
llvm::transform(reductions_value.getIntValues(),
|
||||||
|
std::back_inserter(shuffled_reduction),
|
||||||
|
[&](APInt idx) { return permutation[idx.getSExtValue()]; });
|
||||||
|
|
||||||
|
// Add constant operation with a new reduction indices.
|
||||||
|
OpBuilder builder(getOperation());
|
||||||
|
auto type = mlir::RankedTensorType::get(shuffled_reduction.size(),
|
||||||
|
builder.getIntegerType(32));
|
||||||
|
auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction);
|
||||||
|
auto shuffled_reduction_op = builder.create<TF::ConstOp>(getLoc(), values);
|
||||||
|
|
||||||
|
// Use new reduction indices.
|
||||||
|
setOperand(1, shuffled_reduction_op);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// NegOp
|
// NegOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1568,6 +1741,46 @@ static LogicalResult Verify(PackOp op) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// PadOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult PadOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
|
||||||
|
// Paddings must be defined by a constant operation.
|
||||||
|
auto paddings_op = dyn_cast_or_null<TF::ConstOp>(paddings().getDefiningOp());
|
||||||
|
if (!paddings_op) return failure();
|
||||||
|
|
||||||
|
auto paddings_value = paddings_op.value().dyn_cast<DenseElementsAttr>();
|
||||||
|
if (!paddings_value ||
|
||||||
|
paddings_value.getNumElements() != permutation.size() * 2)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<int32_t, 8> shuffled_paddings(paddings_value.getNumElements());
|
||||||
|
for (auto index_pair : llvm::enumerate(paddings_value.getIntValues())) {
|
||||||
|
size_t outer_idx = index_pair.index() / 2;
|
||||||
|
size_t inner_idx = index_pair.index() % 2;
|
||||||
|
|
||||||
|
shuffled_paddings[permutation[outer_idx] * 2 + inner_idx] =
|
||||||
|
index_pair.value().getSExtValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add constant operation with a new paddings.
|
||||||
|
OpBuilder builder(getOperation());
|
||||||
|
auto type = mlir::RankedTensorType::get(paddings_value.getType().getShape(),
|
||||||
|
builder.getIntegerType(32));
|
||||||
|
auto values = mlir::DenseIntElementsAttr::get(type, shuffled_paddings);
|
||||||
|
auto shuffled_paddings_op = builder.create<TF::ConstOp>(getLoc(), values);
|
||||||
|
|
||||||
|
// Use new paddings.
|
||||||
|
setOperand(1, shuffled_paddings_op);
|
||||||
|
|
||||||
|
// Change the result type.
|
||||||
|
getResult().setType(ShuffleRankedTensorType(getResult().getType(),
|
||||||
|
ReversePermutation(permutation)));
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ParseExampleV2Op
|
// ParseExampleV2Op
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1914,7 +2127,8 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Type element_type = result_ranked_type.getElementType();
|
Type element_type = result_ranked_type.getElementType();
|
||||||
if (!element_type.isInteger(32) && !element_type.isInteger(64))
|
if (!element_type.isSignlessInteger(32) &&
|
||||||
|
!element_type.isSignlessInteger(64))
|
||||||
return op->emitOpError("requires int32 or int64 return type for result")
|
return op->emitOpError("requires int32 or int64 return type for result")
|
||||||
<< variadic_idx_str;
|
<< variadic_idx_str;
|
||||||
|
|
||||||
|
@ -172,7 +172,7 @@ else_branch: A function that takes 'inputs' and returns a list of
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_MeanOp : TF_Op<"Mean", [NoSideEffect]> {
|
def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
|
||||||
let summary = "Computes the mean of elements across dimensions of a tensor.";
|
let summary = "Computes the mean of elements across dimensions of a tensor.";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
@ -195,6 +195,13 @@ retained with length 1.
|
|||||||
|
|
||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
|
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// TF_FoldOperandsTransposeInterface:
|
||||||
|
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||||
|
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {}; }
|
||||||
|
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_LegacyCallOp : TF_Op<"LegacyCall",
|
def TF_LegacyCallOp : TF_Op<"LegacyCall",
|
||||||
|
@ -112,24 +112,20 @@ static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) {
|
|||||||
return mlir::success();
|
return mlir::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Type GetBoundInputArgTypeFor(GlobalTensorOp global_tensor) {
|
||||||
|
auto type = global_tensor.type().cast<TensorType>();
|
||||||
|
return RankedTensorType::get(
|
||||||
|
{}, TF::ResourceType::get({type}, type.getContext()));
|
||||||
|
}
|
||||||
|
|
||||||
static LogicalResult VerifyBoundInputArgType(Operation *op_for_diagnostics,
|
static LogicalResult VerifyBoundInputArgType(Operation *op_for_diagnostics,
|
||||||
Type arg_type,
|
Type arg_type,
|
||||||
GlobalTensorOp global_tensor) {
|
GlobalTensorOp global_tensor) {
|
||||||
if (global_tensor.is_mutable()) {
|
auto expected_type = GetBoundInputArgTypeFor(global_tensor);
|
||||||
auto expected_type = RankedTensorType::get(
|
if (arg_type != expected_type) {
|
||||||
{}, TF::ResourceType::get({global_tensor.type().cast<TensorType>()},
|
return op_for_diagnostics->emitError()
|
||||||
arg_type.getContext()));
|
<< "bound input with type " << arg_type << " expected to have type "
|
||||||
if (arg_type != expected_type) {
|
<< expected_type;
|
||||||
return op_for_diagnostics->emitError()
|
|
||||||
<< "mutable bound input with type " << arg_type
|
|
||||||
<< " expected to have type " << expected_type;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (arg_type != global_tensor.type()) {
|
|
||||||
return op_for_diagnostics->emitError()
|
|
||||||
<< "bound input for immutable 'tf_saved_model.global_tensor' must "
|
|
||||||
"match the global tensor's type";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -57,6 +57,10 @@ bool HasTfSavedModelSemantics(ModuleOp module);
|
|||||||
GlobalTensorOp LookupBoundInput(FuncOp func, int arg_index,
|
GlobalTensorOp LookupBoundInput(FuncOp func, int arg_index,
|
||||||
const SymbolTable &symbol_table);
|
const SymbolTable &symbol_table);
|
||||||
|
|
||||||
|
// Gets the type that an exported function arg that is bound to `global_tensor`
|
||||||
|
// should have.
|
||||||
|
Type GetBoundInputArgTypeFor(GlobalTensorOp global_tensor);
|
||||||
|
|
||||||
} // namespace tf_saved_model
|
} // namespace tf_saved_model
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
@ -91,7 +91,7 @@ class TensorFlowType : public Type {
|
|||||||
// Returns true if the specified type is a valid TensorFlow element type.
|
// Returns true if the specified type is a valid TensorFlow element type.
|
||||||
static inline bool IsValidTFElementType(Type type) {
|
static inline bool IsValidTFElementType(Type type) {
|
||||||
return type.isa<ComplexType>() || type.isa<FloatType>() ||
|
return type.isa<ComplexType>() || type.isa<FloatType>() ||
|
||||||
type.isa<IntegerType>() || type.isa<TensorFlowType>();
|
type.isSignlessInteger() || type.isa<TensorFlowType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if this is a valid TensorFlow tensor type.
|
// Returns true if this is a valid TensorFlow tensor type.
|
||||||
@ -141,20 +141,16 @@ class TensorFlowRefType : public TensorFlowType {
|
|||||||
static TensorFlowType get(Type type);
|
static TensorFlowType get(Type type);
|
||||||
static TensorFlowType getChecked(Type type, MLIRContext* context,
|
static TensorFlowType getChecked(Type type, MLIRContext* context,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
if (failed(verifyConstructionInvariants(loc, context, type))) {
|
if (failed(verifyConstructionInvariants(loc, type))) {
|
||||||
return TensorFlowRefType();
|
return TensorFlowRefType();
|
||||||
}
|
}
|
||||||
return get(type);
|
return get(type);
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyConstructionInvariants(
|
static LogicalResult verifyConstructionInvariants(Location loc, Type type) {
|
||||||
llvm::Optional<Location> loc, MLIRContext* context, Type type) {
|
|
||||||
// type should be a valid TensorFlow type.
|
// type should be a valid TensorFlow type.
|
||||||
if (!IsValidTFTensorType(type)) {
|
if (!IsValidTFTensorType(type)) {
|
||||||
if (loc) {
|
return emitError(loc) << "invalid TensorFlow type: " << type;
|
||||||
emitError(*loc) << "invalid TensorFlow type: " << type;
|
|
||||||
}
|
|
||||||
return failure();
|
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@ -230,7 +226,7 @@ class TypeWithSubtypeImpl
|
|||||||
|
|
||||||
static Derived getChecked(ArrayRef<TensorType> subtypes, MLIRContext* context,
|
static Derived getChecked(ArrayRef<TensorType> subtypes, MLIRContext* context,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
return Base::getChecked(loc, context, Derived::getTypeKind(), subtypes);
|
return Base::getChecked(loc, Derived::getTypeKind(), subtypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Derived get(MLIRContext* context) { return get({}, context); }
|
static Derived get(MLIRContext* context) { return get({}, context); }
|
||||||
@ -239,16 +235,12 @@ class TypeWithSubtypeImpl
|
|||||||
static bool kindof(unsigned kind) { return kind == Derived::getTypeKind(); }
|
static bool kindof(unsigned kind) { return kind == Derived::getTypeKind(); }
|
||||||
|
|
||||||
static LogicalResult verifyConstructionInvariants(
|
static LogicalResult verifyConstructionInvariants(
|
||||||
llvm::Optional<Location> loc, MLIRContext* context,
|
Location loc, ArrayRef<TensorType> subtypes) {
|
||||||
ArrayRef<TensorType> subtypes) {
|
|
||||||
// Each of the subtypes should be a valid TensorFlow type.
|
// Each of the subtypes should be a valid TensorFlow type.
|
||||||
for (TensorType subtype : subtypes) {
|
for (TensorType subtype : subtypes) {
|
||||||
if (!IsValidTFTensorType(subtype)) {
|
if (!IsValidTFTensorType(subtype)) {
|
||||||
if (loc) {
|
return emitError(loc) << "invalid " << Derived::getTypeName()
|
||||||
emitError(*loc) << "invalid " << Derived::getTypeName()
|
<< " subtype: " << subtype;
|
||||||
<< " subtype: " << subtype;
|
|
||||||
}
|
|
||||||
return failure();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
|
@ -280,3 +280,67 @@ func @empty_island_multiple_data_results(%arg0: tensor<*xf32>, %arg1: tensor<*xi
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The following tests check that certain control dependencies between islands
|
||||||
|
// and certain tf_executor ops are added correctly.
|
||||||
|
|
||||||
|
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
|
||||||
|
// CHECK: tf_executor.NextIteration.Sink [{{.*}}] {{.*}}, %[[CONTROL]]
|
||||||
|
func @next_iteration_sink_control_input() {
|
||||||
|
tf_executor.graph {
|
||||||
|
%source:3 = tf_executor.NextIteration.Source : tensor<*xi32>
|
||||||
|
%island:2 = tf_executor.island {
|
||||||
|
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
|
||||||
|
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||||
|
tf_executor.yield %const : tensor<*xi32>
|
||||||
|
}
|
||||||
|
tf_executor.NextIteration.Sink[%source#1] %island#0 : tensor<*xi32>
|
||||||
|
tf_executor.fetch %island#0 : tensor<*xi32>
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
|
||||||
|
// CHECK: tf_executor.LoopCond {{.*}}, %[[CONTROL]]
|
||||||
|
func @loop_cond_control_input() {
|
||||||
|
tf_executor.graph {
|
||||||
|
%island:2 = tf_executor.island {
|
||||||
|
%const = "tf.Const"() {value = dense<1> : tensor<i1>} : () -> tensor<*xi1>
|
||||||
|
%print = "tf.Print"(%const) : (tensor<*xi1>) -> (tensor<*xi1>)
|
||||||
|
tf_executor.yield %const : tensor<*xi1>
|
||||||
|
}
|
||||||
|
%loop_cond:2 = tf_executor.LoopCond %island#0 : tensor<*xi1>
|
||||||
|
tf_executor.fetch %loop_cond#0 : tensor<*xi1>
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
|
||||||
|
// CHECK: tf_executor.Enter {{.*}}, %[[CONTROL]]
|
||||||
|
func @enter_control_input() {
|
||||||
|
tf_executor.graph {
|
||||||
|
%island:2 = tf_executor.island {
|
||||||
|
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
|
||||||
|
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||||
|
tf_executor.yield %const : tensor<*xi32>
|
||||||
|
}
|
||||||
|
%enter:2 = tf_executor.Enter %island#0 frame "some/frame" : tensor<*xi32>
|
||||||
|
tf_executor.fetch %enter#0 : tensor<*xi32>
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
|
||||||
|
// CHECK: tf_executor.SwitchN {{.*}}, {{.*}} of {{[0-9]*}} (%[[CONTROL]])
|
||||||
|
func @switchn_control_input(%arg1: tensor<i32>) {
|
||||||
|
tf_executor.graph {
|
||||||
|
%island:2 = tf_executor.island {
|
||||||
|
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
|
||||||
|
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||||
|
tf_executor.yield %const : tensor<*xi32>
|
||||||
|
}
|
||||||
|
%switchn:4 = tf_executor.SwitchN %island#0, %arg1 of 3: tensor<*xi32>
|
||||||
|
tf_executor.fetch %switchn#0 : tensor<*xi32>
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -9,5 +9,7 @@ func @device_test(%arg0: tensor<3x1xf32>) -> (tensor<3x3xf32>) {
|
|||||||
%1 = "tf.MatMul"(%arg0, %0) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
|
%1 = "tf.MatMul"(%arg0, %0) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
|
||||||
// CHECK: device = "cpu"
|
// CHECK: device = "cpu"
|
||||||
%2 = "tf.Relu"(%1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "cpu"} : (tensor<3x3xf32>) -> tensor<3x3xf32>
|
%2 = "tf.Relu"(%1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "cpu"} : (tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||||
return %2 : tensor<3x3xf32>
|
// CHECK: device = "gpu"
|
||||||
|
%3 = "tf.Relu"(%2) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"]} : (tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||||
|
return %3 : tensor<3x3xf32>
|
||||||
}
|
}
|
||||||
|
@ -6,10 +6,10 @@ func @transposeBiasAdd(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tens
|
|||||||
// Check that BiasAdd was converted to forced data format, and layout
|
// Check that BiasAdd was converted to forced data format, and layout
|
||||||
// dependent arguments and results passed through transpose nodes.
|
// dependent arguments and results passed through transpose nodes.
|
||||||
|
|
||||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||||
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
|
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
|
||||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
|
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
|
||||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]])
|
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]])
|
||||||
// CHECK: return %[[RES_TRANSPOSE]]
|
// CHECK: return %[[RES_TRANSPOSE]]
|
||||||
%0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
|
%0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
|
||||||
@ -20,10 +20,10 @@ func @transposeBiasAdd(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tens
|
|||||||
// CHECK-LABEL: func @transposeBiasAddWithDefaultAttr
|
// CHECK-LABEL: func @transposeBiasAddWithDefaultAttr
|
||||||
func @transposeBiasAddWithDefaultAttr(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
|
func @transposeBiasAddWithDefaultAttr(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
|
||||||
|
|
||||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||||
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
|
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
|
||||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
|
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
|
||||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]])
|
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]])
|
||||||
// CHECK: return %[[RES_TRANSPOSE]]
|
// CHECK: return %[[RES_TRANSPOSE]]
|
||||||
%0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
|
%0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
|
||||||
@ -38,4 +38,38 @@ func @transposeBiasWithUnknownShape(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8x
|
|||||||
%0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<*xf32>
|
%0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<*xf32>
|
||||||
|
|
||||||
return %0 : tensor<*xf32>
|
return %0 : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @transposeConv2D
|
||||||
|
func @transposeConv2D(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32> {
|
||||||
|
|
||||||
|
// IMPORTANT: Tensor shapes do not match convolution parameters (stride,
|
||||||
|
// dilations, etc...). This test only verifies that changing convolution data
|
||||||
|
// layout will update all the attributes.
|
||||||
|
|
||||||
|
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||||
|
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||||
|
|
||||||
|
// CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1)
|
||||||
|
// CHECK-SAME: data_format = "NCHW"
|
||||||
|
// CHECK-SAME: dilations = [1, 4, 2, 3]
|
||||||
|
// CHECK-SAME: explicit_paddings = [1, 2, 7, 8, 3, 4, 5, 6]
|
||||||
|
// CHECK-SAME: padding = "EXPLICIT"
|
||||||
|
// CHECK-SAME: strides = [5, 8, 6, 7]
|
||||||
|
// CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
|
||||||
|
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]])
|
||||||
|
// CHECK: return %[[RES_TRANSPOSE]]
|
||||||
|
|
||||||
|
%0 = "tf.Conv2D"(%input, %filter)
|
||||||
|
{
|
||||||
|
data_format = "NHWC",
|
||||||
|
dilations = [1, 2, 3, 4],
|
||||||
|
explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8],
|
||||||
|
padding = "EXPLICIT",
|
||||||
|
strides = [5, 6, 7, 8]
|
||||||
|
} : (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32>
|
||||||
|
|
||||||
|
return %0 : tensor<1x32x32x8xf32>
|
||||||
|
}
|
@ -0,0 +1,35 @@
|
|||||||
|
// RUN: tf-opt %s -tf-layout-assignment=force-data-format=NHWC -verify-diagnostics | FileCheck %s --dump-input=always
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @transposeConv2D
|
||||||
|
func @transposeConv2D(%input: tensor<1x3x32x32xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32> {
|
||||||
|
|
||||||
|
// IMPORTANT: Tensor shapes do not match convolution parameters (stride,
|
||||||
|
// dilations, etc...). This test only verifies that changing convolution data
|
||||||
|
// layout will update all the attributes.
|
||||||
|
|
||||||
|
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
|
||||||
|
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||||
|
|
||||||
|
// CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1)
|
||||||
|
// CHECK-SAME: data_format = "NHWC"
|
||||||
|
// CHECK-SAME: dilations = [1, 3, 4, 2]
|
||||||
|
// CHECK-SAME: explicit_paddings = [1, 2, 5, 6, 7, 8, 3, 4]
|
||||||
|
// CHECK-SAME: padding = "EXPLICIT"
|
||||||
|
// CHECK-SAME: strides = [5, 7, 8, 6]
|
||||||
|
// CHECK-SAME: (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||||
|
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]])
|
||||||
|
// CHECK: return %[[RES_TRANSPOSE]]
|
||||||
|
|
||||||
|
%0 = "tf.Conv2D"(%input, %filter)
|
||||||
|
{
|
||||||
|
data_format = "NCHW",
|
||||||
|
dilations = [1, 2, 3, 4],
|
||||||
|
explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8],
|
||||||
|
padding = "EXPLICIT",
|
||||||
|
strides = [5, 6, 7, 8]
|
||||||
|
} : (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32>
|
||||||
|
|
||||||
|
return %0 : tensor<1x8x32x32xf32>
|
||||||
|
}
|
@ -3,14 +3,14 @@
|
|||||||
// CHECK-LABEL: func @move_across_single_op
|
// CHECK-LABEL: func @move_across_single_op
|
||||||
func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||||
|
|
||||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||||
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
||||||
// CHECK: return %[[TANH]]
|
// CHECK: return %[[TANH]]
|
||||||
|
|
||||||
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||||
%1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
%1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
%2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
%2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||||
|
|
||||||
return %2 : tensor<1x8x4x4xf32>
|
return %2 : tensor<1x8x4x4xf32>
|
||||||
}
|
}
|
||||||
@ -18,17 +18,17 @@ func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
|||||||
// CHECK-LABEL: func @move_across_multiple_ops
|
// CHECK-LABEL: func @move_across_multiple_ops
|
||||||
func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||||
|
|
||||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||||
// CHECK: %[[TANH0:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
||||||
// CHECK: %[[TANH1:[0-9]*]] = "tf.Tanh"(%[[TANH0]]) {{.*}} tensor<1x8x4x4xf32>
|
// CHECK: %[[RELU:[0-9]*]] = "tf.Relu"(%[[TANH]]) {{.*}} tensor<1x8x4x4xf32>
|
||||||
// CHECK: return %[[TANH1]]
|
// CHECK: return %[[RELU]]
|
||||||
|
|
||||||
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||||
%1 = "tf.Tanh"(%0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
%1 = "tf.Relu"(%0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||||
|
|
||||||
%2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
%2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
%3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
%3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||||
|
|
||||||
return %3 : tensor<1x8x4x4xf32>
|
return %3 : tensor<1x8x4x4xf32>
|
||||||
}
|
}
|
||||||
@ -36,15 +36,15 @@ func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32
|
|||||||
// CHECK-LABEL: func @move_across_multi_operand_op
|
// CHECK-LABEL: func @move_across_multi_operand_op
|
||||||
func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||||
|
|
||||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||||
// CHECK: %[[ARG0_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
// CHECK: %[[ARG0_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||||
// CHECK: %[[ARG1_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg1, %[[ARG_PERM]])
|
// CHECK: %[[ARG1_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg1, %[[ARG_PERM]])
|
||||||
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[ARG0_TRANSPOSE]], %[[ARG1_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[ARG0_TRANSPOSE]], %[[ARG1_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
||||||
// CHECK: return %[[ADD]]
|
// CHECK: return %[[ADD]]
|
||||||
|
|
||||||
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||||
%1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
%1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
%2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
%2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||||
|
|
||||||
return %2 : tensor<1x8x4x4xf32>
|
return %2 : tensor<1x8x4x4xf32>
|
||||||
}
|
}
|
||||||
@ -52,7 +52,7 @@ func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4
|
|||||||
// CHECK-LABEL: func @move_with_multiple_uses
|
// CHECK-LABEL: func @move_with_multiple_uses
|
||||||
func @move_with_multiple_uses(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
func @move_with_multiple_uses(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||||
|
|
||||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||||
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
||||||
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[TANH]], %[[TANH]]) {{.*}} tensor<1x8x4x4xf32>
|
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[TANH]], %[[TANH]]) {{.*}} tensor<1x8x4x4xf32>
|
||||||
@ -60,8 +60,8 @@ func @move_with_multiple_uses(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32>
|
|||||||
|
|
||||||
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||||
%1 = "tf.AddV2"(%0, %0) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
%1 = "tf.AddV2"(%0, %0) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||||
%2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
%2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
%3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
%3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||||
|
|
||||||
return %3 : tensor<1x8x4x4xf32>
|
return %3 : tensor<1x8x4x4xf32>
|
||||||
}
|
}
|
||||||
|
@ -3,13 +3,13 @@
|
|||||||
// CHECK-LABEL: func @move_across_single_op
|
// CHECK-LABEL: func @move_across_single_op
|
||||||
func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||||
|
|
||||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||||
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%arg0) {{.*}} tensor<1x4x4x8xf32>
|
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%arg0) {{.*}} tensor<1x4x4x8xf32>
|
||||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[TANH]], %[[RES_PERM]]) {{.*}} tensor<1x8x4x4xf32>
|
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[TANH]], %[[RES_PERM]]) {{.*}} tensor<1x8x4x4xf32>
|
||||||
// CHECK: return %[[RES_TRANSPOSE]]
|
// CHECK: return %[[RES_TRANSPOSE]]
|
||||||
|
|
||||||
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||||
%2 = "tf.Tanh"(%1) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
%2 = "tf.Tanh"(%1) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
||||||
|
|
||||||
return %2 : tensor<1x8x4x4xf32>
|
return %2 : tensor<1x8x4x4xf32>
|
||||||
@ -18,16 +18,16 @@ func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
|||||||
// CHECK-LABEL: func @move_across_multiple_ops
|
// CHECK-LABEL: func @move_across_multiple_ops
|
||||||
func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||||
|
|
||||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||||
// CHECK: %[[TANH0:[0-9]*]] = "tf.Tanh"(%arg0) {{.*}} tensor<1x4x4x8xf32>
|
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%arg0) {{.*}} tensor<1x4x4x8xf32>
|
||||||
// CHECK: %[[TANH1:[0-9]*]] = "tf.Tanh"(%[[TANH0]]) {{.*}} tensor<1x4x4x8xf32>
|
// CHECK: %[[RELU:[0-9]*]] = "tf.Relu"(%[[TANH]]) {{.*}} tensor<1x4x4x8xf32>
|
||||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[TANH1]], %[[RES_PERM]])
|
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[RELU]], %[[RES_PERM]])
|
||||||
// CHECK: return %[[RES_TRANSPOSE]]
|
// CHECK: return %[[RES_TRANSPOSE]]
|
||||||
|
|
||||||
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||||
%2 = "tf.Tanh"(%1) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
%2 = "tf.Tanh"(%1) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
||||||
%3 = "tf.Tanh"(%2) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
%3 = "tf.Relu"(%2) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
||||||
|
|
||||||
return %3 : tensor<1x8x4x4xf32>
|
return %3 : tensor<1x8x4x4xf32>
|
||||||
}
|
}
|
||||||
@ -35,14 +35,14 @@ func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32
|
|||||||
// CHECK-LABEL: func @move_across_multi_operand_op
|
// CHECK-LABEL: func @move_across_multi_operand_op
|
||||||
func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||||
|
|
||||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||||
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%arg0, %arg1) {{.*}} tensor<1x4x4x8xf32>
|
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%arg0, %arg1) {{.*}} tensor<1x4x4x8xf32>
|
||||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[ADD]], %[[RES_PERM]])
|
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[ADD]], %[[RES_PERM]])
|
||||||
// CHECK: return %[[RES_TRANSPOSE]]
|
// CHECK: return %[[RES_TRANSPOSE]]
|
||||||
|
|
||||||
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||||
%2 = "tf.Transpose"(%arg1, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
%2 = "tf.Transpose"(%arg1, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||||
%3 = "tf.AddV2"(%1, %2) : (tensor<1x8x4x4xf32>, tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
%3 = "tf.AddV2"(%1, %2) : (tensor<1x8x4x4xf32>, tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
||||||
|
|
||||||
return %3 : tensor<1x8x4x4xf32>
|
return %3 : tensor<1x8x4x4xf32>
|
||||||
@ -54,14 +54,14 @@ func @fold_into_max_pool(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x56x56x64xf
|
|||||||
// MaxPool operand transpose must be folded into the op and MaxPool
|
// MaxPool operand transpose must be folded into the op and MaxPool
|
||||||
// must use NCHW data format with updated kernel size and strides.
|
// must use NCHW data format with updated kernel size and strides.
|
||||||
|
|
||||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
|
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
|
||||||
// CHECK: %[[MAX_POOL:[0-9]*]] = "tf.MaxPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 3, 3], padding = "SAME", strides = [1, 1, 2, 2]} : (tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32>
|
// CHECK: %[[MAX_POOL:[0-9]*]] = "tf.MaxPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 3, 3], padding = "SAME", strides = [1, 1, 2, 2]} : (tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32>
|
||||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[ADD]], %[[RES_PERM]])
|
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[MAX_POOL]], %[[RES_PERM]])
|
||||||
// CHECK: return %[[RES_TRANSPOSE]]
|
// CHECK: return %[[RES_TRANSPOSE]]
|
||||||
|
|
||||||
// Transpose NCHW -> NHWC
|
// Transpose NCHW -> NHWC
|
||||||
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
|
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32>
|
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32>
|
||||||
|
|
||||||
// Compute MaxPool in NHWC format
|
// Compute MaxPool in NHWC format
|
||||||
%2 = "tf.MaxPool"(%1)
|
%2 = "tf.MaxPool"(%1)
|
||||||
@ -72,3 +72,49 @@ func @fold_into_max_pool(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x56x56x64xf
|
|||||||
|
|
||||||
return %2 : tensor<1x56x56x64xf32>
|
return %2 : tensor<1x56x56x64xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_into_mean
|
||||||
|
func @fold_into_mean(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64xf32> {
|
||||||
|
|
||||||
|
// CHECK: %[[RED_IDX:[0-9]*]] = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi32>}
|
||||||
|
// CHECK: %[[MEAN:[0-9]*]] = "tf.Mean"(%arg0, %[[RED_IDX]])
|
||||||
|
// CHECK-SAME: (tensor<1x64x112x112xf32>, tensor<2xi32>) -> tensor<1x64xf32>
|
||||||
|
// CHECK: return %[[MEAN]]
|
||||||
|
|
||||||
|
// Transpose NCHW -> NHWC
|
||||||
|
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
|
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32>
|
||||||
|
|
||||||
|
// Compute Mean over spatial dimensions in NHWC format.
|
||||||
|
%2 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||||
|
%3 = "tf.Mean"(%1, %2) : (tensor<1x112x112x64xf32>, tensor<2xi32>) -> tensor<1x64xf32>
|
||||||
|
|
||||||
|
return %3 : tensor<1x64xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @fold_into_fused_batch_norm
|
||||||
|
func @fold_into_fused_batch_norm(%arg0: tensor<1x64x112x112xf32>, %arg1: tensor<64xf32>) -> tensor<1x112x112x64xf32> {
|
||||||
|
|
||||||
|
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
|
||||||
|
// CHECK: "tf.FusedBatchNormV3"(%arg0, {{.*}} {data_format = "NCHW"
|
||||||
|
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]])
|
||||||
|
// CHECK: return %[[RES_TRANSPOSE]]
|
||||||
|
|
||||||
|
// Transpose NCHW -> NHWC
|
||||||
|
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
|
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32>
|
||||||
|
|
||||||
|
// Compute FusedBatchNormV3 in NHWC format
|
||||||
|
%2, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
|
||||||
|
= "tf.FusedBatchNormV3"(%1, %arg1, %arg1, %arg1, %arg1)
|
||||||
|
{
|
||||||
|
data_format = "NHWC",
|
||||||
|
epsilon = 1.001 : f32,
|
||||||
|
exponential_avg_factor = 1.0 : f32,
|
||||||
|
is_training = false
|
||||||
|
}
|
||||||
|
: (tensor<1x112x112x64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
|
||||||
|
-> (tensor<1x112x112x64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
|
||||||
|
|
||||||
|
return %2#0 : tensor<1x112x112x64xf32>
|
||||||
|
}
|
||||||
|
@ -4,15 +4,15 @@
|
|||||||
func @transposeBiasAdd(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<8xf32>) -> tensor<1x8x4x4xf32> {
|
func @transposeBiasAdd(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<8xf32>) -> tensor<1x8x4x4xf32> {
|
||||||
|
|
||||||
// Convert input: NCHW -> NHWC
|
// Convert input: NCHW -> NHWC
|
||||||
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
|
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x8x4x4xf32>, tensor<4xi64>) -> tensor<1x4x4x8xf32>
|
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32>
|
||||||
|
|
||||||
// Compute in NHWC
|
// Compute in NHWC
|
||||||
%2 = "tf.BiasAdd"(%1, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
|
%2 = "tf.BiasAdd"(%1, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
|
||||||
|
|
||||||
// Convert result back: NHWC -> NCHW
|
// Convert result back: NHWC -> NCHW
|
||||||
%3 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
%3 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
%4 = "tf.Transpose"(%2, %3) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
%4 = "tf.Transpose"(%2, %3) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||||
|
|
||||||
// Check that BiasAdd computed in NCHW format, and all redundant transpose
|
// Check that BiasAdd computed in NCHW format, and all redundant transpose
|
||||||
// operations removed from the function.
|
// operations removed from the function.
|
@ -0,0 +1,156 @@
|
|||||||
|
// RUN: tf-opt %s -tf-layout-optimization=force-data-format=NHWC -verify-diagnostics | FileCheck %s --dump-input=always
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @transpose_resnet_layer
|
||||||
|
func @transpose_resnet_layer(%arg0: tensor<?x224x224x3xf32>, // input
|
||||||
|
%arg1: tensor<64xf32>, // batch_norm args
|
||||||
|
%arg2: tensor<256xf32>, // batch_norm args
|
||||||
|
%arg3: tensor<7x7x3x64xf32>, // conv filter #0
|
||||||
|
%arg4: tensor<1x1x64x256xf32> // conv filter #1
|
||||||
|
) -> tensor<?x256xf32> {
|
||||||
|
|
||||||
|
// This is a simplified ResNet layer that gets input in NHWC format, converts
|
||||||
|
// it to NCHW before padding, and does all computations in NCHW (this is the
|
||||||
|
// default setup for ResNet model trained in fp32 on GPU).
|
||||||
|
//
|
||||||
|
// To be able to use Tensor Cores on latest NVIDIA GPUs this model has to be
|
||||||
|
// converted to NHWC data format.
|
||||||
|
|
||||||
|
// Padding in spatial dimension (NCHW)
|
||||||
|
%0 = "tf.Const"() {value = dense<[[0, 0], [0, 0], [3, 3], [3, 3]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32>
|
||||||
|
|
||||||
|
// Reduce over spatial dimensions (NCHW)
|
||||||
|
%1 = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||||
|
|
||||||
|
// Transpose input: NHWC -> NCHW
|
||||||
|
%2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
|
%3 = "tf.Transpose"(%arg0, %2) : (tensor<?x224x224x3xf32>, tensor<4xi32>) -> tensor<?x3x224x224xf32>
|
||||||
|
|
||||||
|
// Pad spatial dimensions.
|
||||||
|
%4 = "tf.Pad"(%3, %0) : (tensor<?x3x224x224xf32>, tensor<4x2xi32>) -> tensor<?x3x230x230xf32>
|
||||||
|
|
||||||
|
// Shuffled paddings.
|
||||||
|
// CHECK: %[[PADDINGS:[0-9]*]] = "tf.Const"(){{.*}}[0, 0], [3, 3], [3, 3], [0, 0]
|
||||||
|
|
||||||
|
// Pad input with new paddings.
|
||||||
|
// CHECK: %[[PAD:[0-9]*]] = "tf.Pad"(%arg0, %[[PADDINGS]])
|
||||||
|
// CHECK-SAME: (tensor<?x224x224x3xf32>, tensor<4x2xi32>) -> tensor<?x230x230x3xf32>
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------------ //
|
||||||
|
// Convolution layer #0.
|
||||||
|
// ------------------------------------------------------------------------ //
|
||||||
|
%5 = "tf.Conv2D"(%4, %arg3)
|
||||||
|
{
|
||||||
|
data_format = "NCHW",
|
||||||
|
dilations = [1, 1, 1, 1],
|
||||||
|
explicit_paddings = [],
|
||||||
|
padding = "VALID",
|
||||||
|
strides = [1, 1, 2, 2]
|
||||||
|
} : (tensor<?x3x230x230xf32>, tensor<7x7x3x64xf32>) -> tensor<?x64x112x112xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[CONV0:[0-9]*]] = "tf.Conv2D"
|
||||||
|
// CHECK-SAME %[[PAD]]
|
||||||
|
// CHECK-SAME: data_format = "NHWC"
|
||||||
|
// CHECK-SAME: strides = [1, 2, 2, 1]
|
||||||
|
|
||||||
|
%6, %batch_mean, %batch_variance, %reserved_1, %reserved_2, %reserved_3 =
|
||||||
|
"tf.FusedBatchNormV3"(%5, %arg1, %arg1, %arg1, %arg1)
|
||||||
|
{
|
||||||
|
data_format = "NCHW",
|
||||||
|
epsilon = 1.001000e-05 : f32,
|
||||||
|
is_training = false
|
||||||
|
} : (tensor<?x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
|
||||||
|
-> (tensor<?x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>)
|
||||||
|
|
||||||
|
// CHECK: "tf.FusedBatchNormV3"
|
||||||
|
// CHECK-SAME: data_format = "NHWC"
|
||||||
|
|
||||||
|
%7 = "tf.Relu"(%6) : (tensor<?x64x112x112xf32>) -> tensor<?x64x112x112xf32>
|
||||||
|
%8 = "tf.MaxPool"(%7)
|
||||||
|
{
|
||||||
|
data_format = "NCHW",
|
||||||
|
ksize = [1, 1, 3, 3],
|
||||||
|
padding = "SAME",
|
||||||
|
strides = [1, 1, 2, 2]
|
||||||
|
} : (tensor<?x64x112x112xf32>) -> tensor<?x64x56x56xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[MAX_POOL:[0-9]*]] = "tf.MaxPool"
|
||||||
|
// CHECK-SAME: data_format = "NHWC"
|
||||||
|
// CHECK-SAME: ksize = [1, 3, 3, 1]
|
||||||
|
// CHECK-SAME: strides = [1, 2, 2, 1]
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------------ //
|
||||||
|
// Convolution layer #1.
|
||||||
|
// ------------------------------------------------------------------------ //
|
||||||
|
%9 = "tf.Conv2D"(%8, %arg4)
|
||||||
|
{
|
||||||
|
data_format = "NCHW",
|
||||||
|
dilations = [1, 1, 1, 1],
|
||||||
|
explicit_paddings = [],
|
||||||
|
padding = "VALID",
|
||||||
|
strides = [1, 1, 1, 1]
|
||||||
|
} : (tensor<?x64x56x56xf32>, tensor<1x1x64x256xf32>) -> tensor<?x256x56x56xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[CONV1:[0-9]*]] = "tf.Conv2D"(%[[MAX_POOL]], %arg4)
|
||||||
|
// CHECK-SAME: data_format = "NHWC"
|
||||||
|
|
||||||
|
%10, %batch_mean_1, %batch_variance_1, %reserved_1_1, %reserved_1_2, %reserved_1_3 =
|
||||||
|
"tf.FusedBatchNormV3"(%9, %arg2, %arg2, %arg2, %arg2)
|
||||||
|
{
|
||||||
|
data_format = "NCHW",
|
||||||
|
epsilon = 1.001000e-05 : f32
|
||||||
|
} : (tensor<?x256x56x56xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>)
|
||||||
|
-> (tensor<?x256x56x56xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<*xf32>)
|
||||||
|
|
||||||
|
// CHECK: %[[BATCH_NORM1:[_a-z0-9]*]], {{.*}} = "tf.FusedBatchNormV3"
|
||||||
|
// CHECK-SAME: %[[CONV1]]
|
||||||
|
// CHECK-SAME: data_format = "NHWC"
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------------ //
|
||||||
|
// Convolution layer #2.
|
||||||
|
// ------------------------------------------------------------------------ //
|
||||||
|
%11 = "tf.Conv2D"(%8, %arg4)
|
||||||
|
{
|
||||||
|
data_format = "NCHW",
|
||||||
|
dilations = [1, 1, 1, 1],
|
||||||
|
explicit_paddings = [],
|
||||||
|
padding = "VALID",
|
||||||
|
strides = [1, 1, 1, 1]
|
||||||
|
} : (tensor<?x64x56x56xf32>, tensor<1x1x64x256xf32>) -> tensor<?x256x56x56xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[CONV2:[0-9]*]] = "tf.Conv2D"(%[[MAX_POOL]], %arg4)
|
||||||
|
// CHECK-SAME: data_format = "NHWC"
|
||||||
|
|
||||||
|
%12, %batch_mean_2, %batch_variance_2, %reserved_2_1, %reserved_2_2, %reserved_2_3 =
|
||||||
|
"tf.FusedBatchNormV3"(%11, %arg2, %arg2, %arg2, %arg2)
|
||||||
|
{
|
||||||
|
data_format = "NCHW",
|
||||||
|
epsilon = 1.001000e-05 : f32
|
||||||
|
} : (tensor<?x256x56x56xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>)
|
||||||
|
-> (tensor<?x256x56x56xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<*xf32>)
|
||||||
|
|
||||||
|
// CHECK: %[[BATCH_NORM2:[_a-z0-9]*]], {{.*}} = "tf.FusedBatchNormV3"
|
||||||
|
// CHECK-SAME: %[[CONV2]]
|
||||||
|
// CHECK-SAME: data_format = "NHWC"
|
||||||
|
|
||||||
|
// ------------------------------------------------------------------------ //
|
||||||
|
// Add results of convolution layers #1 and #2.
|
||||||
|
// ------------------------------------------------------------------------ //
|
||||||
|
|
||||||
|
%14 = "tf.AddV2"(%10, %12) : (tensor<?x256x56x56xf32>, tensor<?x256x56x56xf32>) -> tensor<?x256x56x56xf32>
|
||||||
|
%15 = "tf.Relu"(%14) : (tensor<?x256x56x56xf32>) -> tensor<?x256x56x56xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[BATCH_NORM1]], %[[BATCH_NORM2]])
|
||||||
|
// CHECK: %[[RELU:[0-9]*]] = "tf.Relu"(%[[ADD]])
|
||||||
|
|
||||||
|
// Reduce spatial dimensions
|
||||||
|
%16 = "tf.Mean"(%15, %1) : (tensor<?x256x56x56xf32>, tensor<2xi32>) -> tensor<?x256xf32>
|
||||||
|
|
||||||
|
// Mean should compute reduction over NHWC spatial dimensions.
|
||||||
|
// CHECK: %[[MEAN_DIMS:[0-9]*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>}
|
||||||
|
// CHECK: %[[MEAN:[0-9]*]] = "tf.Mean"(%[[RELU]], %[[MEAN_DIMS]])
|
||||||
|
// CHECK-SAME: (tensor<?x56x56x256xf32>, tensor<2xi32>) -> tensor<?x256xf32>
|
||||||
|
// CHECK: return %[[MEAN]] : tensor<?x256xf32>
|
||||||
|
|
||||||
|
return %16 : tensor<?x256xf32>
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,194 @@
|
|||||||
|
// RUN: tf-opt %s -tf-parallel-execute-to-islands | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @check_regions_to_islands
|
||||||
|
func @check_regions_to_islands() {
|
||||||
|
tf_executor.graph {
|
||||||
|
tf_executor.island() {
|
||||||
|
"tf_device.parallel_execute"() ({
|
||||||
|
tf_device.return
|
||||||
|
},
|
||||||
|
{
|
||||||
|
tf_device.return
|
||||||
|
}) {} : () -> ()
|
||||||
|
tf_executor.yield
|
||||||
|
}
|
||||||
|
tf_executor.fetch
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: %[[ISLAND_INPUT_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||||
|
// CHECK-NEXT: tf_executor.yield
|
||||||
|
// CHECK: %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_INPUT_CTL]]) {
|
||||||
|
// CHECK: tf_executor.yield
|
||||||
|
// CHECK: %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_INPUT_CTL]]) {
|
||||||
|
// CHECK: tf_executor.yield
|
||||||
|
// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) {
|
||||||
|
// CHECK-NEXT: tf_executor.yield
|
||||||
|
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @check_regions_to_islands_with_inputs
|
||||||
|
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
|
||||||
|
func @check_regions_to_islands_with_inputs(%arg0 : tensor<i1>) {
|
||||||
|
tf_executor.graph {
|
||||||
|
%1:2 = tf_executor.island {
|
||||||
|
%2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
|
||||||
|
tf_executor.yield %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
tf_executor.island() {
|
||||||
|
"tf_device.parallel_execute"() ({
|
||||||
|
%3 = "tf.opB"(%1#0) : (tensor<i1>) -> tensor<i1>
|
||||||
|
tf_device.return %3 : tensor<i1>
|
||||||
|
},
|
||||||
|
{
|
||||||
|
%5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
|
||||||
|
tf_device.return %5 : tensor<i32>
|
||||||
|
}) {} : () -> (tensor<i1>, tensor<i32>)
|
||||||
|
tf_executor.yield
|
||||||
|
}
|
||||||
|
tf_executor.fetch
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||||
|
// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor<i1>) -> tensor<i1>
|
||||||
|
// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor<i1>
|
||||||
|
// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CONTROL:[a-z_0-9]*]] = tf_executor.island {
|
||||||
|
// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor<i1>
|
||||||
|
// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||||
|
// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"(%[[INPUT_0]]) : (tensor<i1>) -> tensor<i1>
|
||||||
|
// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor<i1>
|
||||||
|
// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||||
|
// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%outputs_0) : (tensor<i1>) -> tensor<i32>
|
||||||
|
// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor<i32>
|
||||||
|
// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) {
|
||||||
|
// CHECK-NEXT: tf_executor.yield
|
||||||
|
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @check_input_sink_island_forwards_control_inputs
|
||||||
|
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
|
||||||
|
func @check_input_sink_island_forwards_control_inputs(%arg0 : tensor<i1>) {
|
||||||
|
tf_executor.graph {
|
||||||
|
%1:2 = tf_executor.island {
|
||||||
|
%2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
|
||||||
|
tf_executor.yield %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
%7 = tf_executor.ControlTrigger {}
|
||||||
|
%8 = tf_executor.ControlTrigger {}
|
||||||
|
tf_executor.island(%7, %8) {
|
||||||
|
"tf_device.parallel_execute"() ({
|
||||||
|
%3 = "tf.opB"(%1#0) : (tensor<i1>) -> tensor<i1>
|
||||||
|
tf_device.return %3 : tensor<i1>
|
||||||
|
},
|
||||||
|
{
|
||||||
|
%5 = "tf.opC"() : () -> tensor<i32>
|
||||||
|
tf_device.return %5 : tensor<i32>
|
||||||
|
}) {} : () -> (tensor<i1>, tensor<i32>)
|
||||||
|
tf_executor.yield
|
||||||
|
}
|
||||||
|
tf_executor.fetch
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||||
|
// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor<i1>) -> tensor<i1>
|
||||||
|
// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor<i1>
|
||||||
|
// CHECK: %[[CT_0:[0-9]*]] = tf_executor.ControlTrigger
|
||||||
|
// CHECK: %[[CT_1:[0-9]*]] = tf_executor.ControlTrigger
|
||||||
|
// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CONTROL:[a-z_0-9]*]] = tf_executor.island(%[[CT_0]], %[[CT_1]]) {
|
||||||
|
// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor<i1>
|
||||||
|
// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||||
|
// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"(%[[INPUT_0]]) : (tensor<i1>) -> tensor<i1>
|
||||||
|
// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor<i1>
|
||||||
|
// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island(%[[INPUT_CONTROL]]) {
|
||||||
|
// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"() : () -> tensor<i32>
|
||||||
|
// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor<i32>
|
||||||
|
// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) {
|
||||||
|
// CHECK-NEXT: tf_executor.yield
|
||||||
|
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @check_control_dep_added_when_region_does_not_have_inputs
|
||||||
|
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
|
||||||
|
func @check_control_dep_added_when_region_does_not_have_inputs(%arg0 : tensor<i1>) {
|
||||||
|
tf_executor.graph {
|
||||||
|
%1:2 = tf_executor.island {
|
||||||
|
%2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
|
||||||
|
tf_executor.yield %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
%7:3 = tf_executor.island() {
|
||||||
|
%8:2 = "tf_device.parallel_execute"() (
|
||||||
|
{
|
||||||
|
%3 = "tf.opB"() : () -> tensor<i1>
|
||||||
|
tf_device.return %3 : tensor<i1>
|
||||||
|
},
|
||||||
|
{
|
||||||
|
%5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
|
||||||
|
tf_device.return %5 : tensor<i32>
|
||||||
|
}
|
||||||
|
) {} : () -> (tensor<i1>, tensor<i32>)
|
||||||
|
|
||||||
|
tf_executor.yield %8#0, %8#1 : tensor<i1>, tensor<i32>
|
||||||
|
}
|
||||||
|
|
||||||
|
tf_executor.island {
|
||||||
|
"tf.opD"(%7#0, %7#1) : (tensor<i1>, tensor<i32>) -> ()
|
||||||
|
tf_executor.yield
|
||||||
|
}
|
||||||
|
tf_executor.fetch
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||||
|
// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor<i1>) -> tensor<i1>
|
||||||
|
// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor<i1>
|
||||||
|
// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||||
|
// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor<i1>
|
||||||
|
// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island(%[[INPUT_CTL]]) {
|
||||||
|
// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"() : () -> tensor<i1>
|
||||||
|
// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor<i1>
|
||||||
|
// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||||
|
// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%outputs_0) : (tensor<i1>) -> tensor<i32>
|
||||||
|
// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor<i32>
|
||||||
|
// CHECK: %{{.*}} = tf_executor.island {
|
||||||
|
// CHECK-NEXT: tf_executor.yield %[[ISLAND_1_OUTPUT]], %[[ISLAND_2_OUTPUT]]
|
||||||
|
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @check_output_barrier_correctly_forwards_outputs
|
||||||
|
func @check_output_barrier_correctly_forwards_outputs(%arg0 : tensor<i1>) -> tensor<i1> {
|
||||||
|
%0 = tf_executor.graph {
|
||||||
|
%1:2 = tf_executor.island {
|
||||||
|
%2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
|
||||||
|
tf_executor.yield %2 : tensor<i1>
|
||||||
|
}
|
||||||
|
%8:3 = tf_executor.island() {
|
||||||
|
%7:2 = "tf_device.parallel_execute"() ({
|
||||||
|
%3 = "tf.opB"() : () -> tensor<i1>
|
||||||
|
tf_device.return %3 : tensor<i1>
|
||||||
|
},
|
||||||
|
{
|
||||||
|
%5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
|
||||||
|
tf_device.return %5 : tensor<i32>
|
||||||
|
}) {} : () -> (tensor<i1>, tensor<i32>)
|
||||||
|
tf_executor.yield %7#0, %7#1 : tensor<i1>, tensor<i32>
|
||||||
|
}
|
||||||
|
tf_executor.fetch %8#0 : tensor<i1>
|
||||||
|
}
|
||||||
|
return %0 : tensor<i1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||||
|
// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor<i1>) -> tensor<i1>
|
||||||
|
// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor<i1>
|
||||||
|
// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||||
|
// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor<i1>
|
||||||
|
// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island(%[[INPUT_CTL]]) {
|
||||||
|
// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"() : () -> tensor<i1>
|
||||||
|
// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor<i1>
|
||||||
|
// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||||
|
// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%[[INPUT_0]]) : (tensor<i1>) -> tensor<i32>
|
||||||
|
// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor<i32>
|
||||||
|
// CHECK: %[[OUTPUT_SINK_OUTPUT:[a-z_0-9]*]]:2, %[[OUTPUT_SINK_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||||
|
// CHECK-NEXT: tf_executor.yield %[[ISLAND_1_OUTPUT]], %[[ISLAND_2_OUTPUT]] : tensor<i1>, tensor<i32>
|
@ -854,6 +854,78 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// Test valid tf.MatrixBandPart
|
||||||
|
// CHECK-LABEL: func @testValidMatrixBandPartOp
|
||||||
|
func @testValidMatrixBandPartOp(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
|
||||||
|
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
|
||||||
|
return %0 : tensor<64x64xbf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test valid tf.MatrixBandPart
|
||||||
|
// CHECK-LABEL: func @testValidMatrixBandPartOp3D
|
||||||
|
func @testValidMatrixBandPartOp3D(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64x64xbf16> {
|
||||||
|
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64x64xbf16>
|
||||||
|
return %0 : tensor<64x64x64xbf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test valid tf.MatrixBandPart
|
||||||
|
// CHECK-LABEL: func @testValidMatrixBandPartOpUnranked
|
||||||
|
func @testValidMatrixBandPartOpUnranked(%arg0: tensor<*xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
|
||||||
|
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<*xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
|
||||||
|
return %0 : tensor<*xbf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test invalid tf.MatrixBandPart
|
||||||
|
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
|
||||||
|
// expected-error @+1 {{op failed to verify that all of {input, band} have same type}}
|
||||||
|
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
|
||||||
|
return %0 : tensor<64x64xbf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test invalid tf.MatrixBandPart
|
||||||
|
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
|
||||||
|
// expected-error @+1 {{op failed to verify that all of {input, band} have same type}}
|
||||||
|
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
|
||||||
|
return %0 : tensor<*xbf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test invalid tf.MatrixBandPart
|
||||||
|
func @testInvalidMatrixBandPartOp(%arg0: tensor<i64>, %arg1: tensor<64x64xi64>, %arg2: tensor<i64>) -> tensor<i64> {
|
||||||
|
// expected-error @+1 {{op requires `input` to have rank of at least 2, but found 'tensor<i64>'}}
|
||||||
|
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<i64>, tensor<64x64xi64>, tensor<i64>) -> tensor<i64>
|
||||||
|
return %0 : tensor<i64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test invalid tf.MatrixBandPart
|
||||||
|
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64xi64>, %arg1: tensor<32xi64>, %arg2: tensor<i64>) -> tensor<64x64xi64> {
|
||||||
|
// expected-error @+1 {{op requires `num_lower` to have 0 dimensions, but found 'tensor<32xi64>'}}
|
||||||
|
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xi64>, tensor<32xi64>, tensor<i64>) -> tensor<64x64xi64>
|
||||||
|
return %0 : tensor<64x64xi64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test invalid tf.MatrixBandPart
|
||||||
|
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64xi64>, %arg1: tensor<i64>, %arg2: tensor<32xi64>) -> tensor<64x64xi64> {
|
||||||
|
// expected-error @+1 {{op requires `num_upper` to have 0 dimensions, but found 'tensor<32xi64>'}}
|
||||||
|
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xi64>, tensor<i64>, tensor<32xi64>) -> tensor<64x64xi64>
|
||||||
|
return %0 : tensor<64x64xi64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
// tf.{|Stateful}PartitionedCall
|
// tf.{|Stateful}PartitionedCall
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
@ -47,7 +47,7 @@ class TestModule(tf.Module):
|
|||||||
# CHECK: func {{@[a-zA-Z_0-9]+}}(
|
# CHECK: func {{@[a-zA-Z_0-9]+}}(
|
||||||
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
|
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
|
||||||
# CHECK-SAME: %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @[[VAR]]},
|
# CHECK-SAME: %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @[[VAR]]},
|
||||||
# CHECK-SAME: %arg2: tensor<f32> {tf_saved_model.bound_input = @[[CONST]]}) -> (
|
# CHECK-SAME: %arg2: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @[[CONST]]}) -> (
|
||||||
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = []})
|
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = []})
|
||||||
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["some_function"]
|
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["some_function"]
|
||||||
@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
|
@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
|
||||||
|
@ -1,53 +0,0 @@
|
|||||||
// RUN: tf-opt -tf-saved-model-inline-global-tensors -split-input-file %s | FileCheck %s --dump-input=fail
|
|
||||||
|
|
||||||
module attributes {tf_saved_model.semantics} {
|
|
||||||
|
|
||||||
// Test case: Simple case of inlining.
|
|
||||||
|
|
||||||
// CHECK-NOT: tf_saved_model.global_tensor
|
|
||||||
"tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<1.0> : tensor<f32> } : () -> ()
|
|
||||||
|
|
||||||
// CHECK: func @f()
|
|
||||||
func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @c})
|
|
||||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
|
||||||
// CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
module attributes {tf_saved_model.semantics} {
|
|
||||||
|
|
||||||
// Test case: Do not inline mutable global tensors.
|
|
||||||
|
|
||||||
// CHECK: tf_saved_model.global_tensor
|
|
||||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<1.0> : tensor<f32> } : () -> ()
|
|
||||||
|
|
||||||
// CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
|
||||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
|
||||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
|
||||||
// CHECK-NOT: tf.Const
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
module attributes {tf_saved_model.semantics} {
|
|
||||||
|
|
||||||
// Test case: Sanity check handling of non-bound inputs.
|
|
||||||
// The pass shouldn't do anything in this case.
|
|
||||||
|
|
||||||
// CHECK: func @f(%arg0: tensor<f32> {tf_saved_model.index_path = [0]})
|
|
||||||
func @f(%arg0: tensor<f32> {tf_saved_model.index_path = [0]})
|
|
||||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
|
||||||
// CHECK-NOT: tf.Const
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: have an arg that isn't a bound input.
|
|
@ -25,7 +25,7 @@ module attributes {tf_saved_model.semantics} {
|
|||||||
// CHECK: func @__concrete_function_run_computation
|
// CHECK: func @__concrete_function_run_computation
|
||||||
func @__concrete_function_run_computation(
|
func @__concrete_function_run_computation(
|
||||||
%arg0: tensor<f32> {tf_saved_model.index_path = [0, "foo"]},
|
%arg0: tensor<f32> {tf_saved_model.index_path = [0, "foo"]},
|
||||||
%arg1: tensor<1x64xf32> {tf_saved_model.bound_input = @some_constant},
|
%arg1: tensor<!tf.resource<tensor<1x64xf32>>> {tf_saved_model.bound_input = @some_constant},
|
||||||
%arg2: tensor<!tf.resource<tensor<?x64xf32>>> {tf_saved_model.bound_input = @some_variable}
|
%arg2: tensor<!tf.resource<tensor<?x64xf32>>> {tf_saved_model.bound_input = @some_variable}
|
||||||
) -> (
|
) -> (
|
||||||
tensor<f32> {tf_saved_model.index_path = [0, "bar"]}
|
tensor<f32> {tf_saved_model.index_path = [0, "bar"]}
|
||||||
|
@ -244,7 +244,7 @@ module attributes {tf_saved_model.semantics} {
|
|||||||
module attributes {tf_saved_model.semantics} {
|
module attributes {tf_saved_model.semantics} {
|
||||||
|
|
||||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
|
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
|
||||||
// expected-error@+1 {{mutable bound input with type 'tensor<f32>' expected to have type 'tensor<!tf.resource<tensor<?xf32>>>'}}
|
// expected-error@+1 {{bound input with type 'tensor<f32>' expected to have type 'tensor<!tf.resource<tensor<?xf32>>>'}}
|
||||||
func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
|
func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
|
||||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||||
return
|
return
|
||||||
@ -253,18 +253,6 @@ module attributes {tf_saved_model.semantics} {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
module attributes {tf_saved_model.semantics} {
|
|
||||||
|
|
||||||
"tf_saved_model.global_tensor"() { sym_name = "v", type = tensor<1xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
|
|
||||||
// expected-error@+1 {{bound input for immutable 'tf_saved_model.global_tensor' must match the global tensor's type}}
|
|
||||||
func @f(%arg0: tensor<!tf.resource<tensor<1xf32>>> {tf_saved_model.bound_input = @v})
|
|
||||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
module attributes {tf_saved_model.semantics} {
|
module attributes {tf_saved_model.semantics} {
|
||||||
|
|
||||||
// expected-error@+1 {{'type' attribute for immutable 'tf_saved_model.global_tensor' should have a static shape}}
|
// expected-error@+1 {{'type' attribute for immutable 'tf_saved_model.global_tensor' should have a static shape}}
|
||||||
|
@ -6,19 +6,16 @@
|
|||||||
|
|
||||||
module attributes {tf_saved_model.semantics} {
|
module attributes {tf_saved_model.semantics} {
|
||||||
|
|
||||||
// Test case: Basic test of freezing.
|
// Test case: Basic test of marking immutable.
|
||||||
|
|
||||||
// CHECK: "tf_saved_model.global_tensor"() {
|
// CHECK: "tf_saved_model.global_tensor"() {
|
||||||
// CHECK-NOT: is_mutable
|
// CHECK-NOT: is_mutable
|
||||||
// CHECK-SAME: } : () -> ()
|
// CHECK-SAME: } : () -> ()
|
||||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||||
|
|
||||||
// CHECK: func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
|
|
||||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||||
// CHECK-NOT: tf.ReadVariableOp
|
|
||||||
%val = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
%val = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||||
// CHECK: return %arg0
|
|
||||||
return %val : tensor<f32>
|
return %val : tensor<f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -28,18 +25,16 @@ module attributes {tf_saved_model.semantics} {
|
|||||||
|
|
||||||
module attributes {tf_saved_model.semantics} {
|
module attributes {tf_saved_model.semantics} {
|
||||||
|
|
||||||
// Test case: Don't freeze if the variable is mutated.
|
// Test case: Don't mark immutable if the variable is mutated.
|
||||||
|
|
||||||
// CHECK: "tf_saved_model.global_tensor"() {
|
// CHECK: "tf_saved_model.global_tensor"() {
|
||||||
// CHECK-SAME: is_mutable
|
// CHECK-SAME: is_mutable
|
||||||
// CHECK-SAME: } : () -> ()
|
// CHECK-SAME: } : () -> ()
|
||||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||||
|
|
||||||
// CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
|
||||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||||
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
|
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
|
||||||
// CHECK: tf.AssignVariableOp
|
|
||||||
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
|
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -50,14 +45,13 @@ module attributes {tf_saved_model.semantics} {
|
|||||||
|
|
||||||
module attributes {tf_saved_model.semantics} {
|
module attributes {tf_saved_model.semantics} {
|
||||||
|
|
||||||
// Test case: Don't freeze if the variable is exported.
|
// Test case: Don't mark immutable if the variable is exported.
|
||||||
|
|
||||||
// CHECK: "tf_saved_model.global_tensor"() {
|
// CHECK: "tf_saved_model.global_tensor"() {
|
||||||
// CHECK: is_mutable
|
// CHECK: is_mutable
|
||||||
// CHECK-SAME: } : () -> ()
|
// CHECK-SAME: } : () -> ()
|
||||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", tf_saved_model.exported_names = ["v"], type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", tf_saved_model.exported_names = ["v"], type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||||
|
|
||||||
// CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
|
||||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||||
%val = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
%val = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||||
@ -71,7 +65,7 @@ module attributes {tf_saved_model.semantics} {
|
|||||||
|
|
||||||
module attributes {tf_saved_model.semantics} {
|
module attributes {tf_saved_model.semantics} {
|
||||||
|
|
||||||
// Test case: Check that a non-bound input is not modified.
|
// Test case: Check that a non-bound input is left unchanged.
|
||||||
|
|
||||||
// CHECK: func @g
|
// CHECK: func @g
|
||||||
func @g(%arg0: tensor<f32> {tf_saved_model.index_path = [0]}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
func @g(%arg0: tensor<f32> {tf_saved_model.index_path = [0]}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||||
@ -86,14 +80,16 @@ module attributes {tf_saved_model.semantics} {
|
|||||||
|
|
||||||
module attributes {tf_saved_model.semantics} {
|
module attributes {tf_saved_model.semantics} {
|
||||||
|
|
||||||
// Test case: Check that an immutable bound input isn't modified.
|
// Test case: Check that no change is made for a global tensor that is already
|
||||||
|
// immutable.
|
||||||
|
|
||||||
"tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
"tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||||
|
|
||||||
// CHECK: func @h(%arg0: tensor<f32> {tf_saved_model.bound_input = @c})
|
// CHECK: func @h(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @c})
|
||||||
func @h(%arg0: tensor<f32> {tf_saved_model.bound_input = @c}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
func @h(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @c})
|
||||||
attributes {tf_saved_model.exported_names = ["h"]} {
|
attributes {tf_saved_model.exported_names = ["h"]} {
|
||||||
return %arg0 : tensor<f32>
|
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -134,7 +130,7 @@ module attributes {tf_saved_model.semantics} {
|
|||||||
"tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
"tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||||
|
|
||||||
// CHECK: func @f()
|
// CHECK: func @f()
|
||||||
func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @c})
|
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @c})
|
||||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -102,18 +102,19 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
|||||||
func @main(%arg0: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
func @main(%arg0: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||||
%arg2: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
%arg2: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"}) {
|
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"},
|
||||||
|
%arg4: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||||
|
%arg5: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"}) {
|
||||||
%0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
|
%0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
|
||||||
%1:5 = "tf.While"(%0, %arg0, %arg1, %arg2, %arg3)
|
%1:7 = "tf.While"(%0, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
|
||||||
{T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE",
|
{body = @while_body_7560,
|
||||||
"tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE",
|
cond = @while_cond_7550, device = "", is_stateless = false}
|
||||||
"tfdtype$DT_RESOURCE"], body = @while_body_7560,
|
|
||||||
cond = @while_cond_7550, device = "", is_stateless = false,
|
|
||||||
output_shapes = ["tfshape$", "tfshape$", "tfshape$", "tfshape$", "tfshape$"]}
|
|
||||||
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||||
|
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
|
||||||
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||||
|
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK: func @while_body_7560
|
// CHECK: func @while_body_7560
|
||||||
@ -122,9 +123,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
|||||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||||
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||||
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
|
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"},
|
||||||
|
%arg5: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||||
|
%arg6: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"})
|
||||||
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) {
|
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||||
|
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>) {
|
||||||
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||||
%1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
%1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||||
%2:2 = "tf._TPUCompileMlir"() {
|
%2:2 = "tf._TPUCompileMlir"() {
|
||||||
@ -133,27 +137,33 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
|||||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||||
"tf.TPUCompileSucceededAssert"(%2#0) : (tensor<!tf.string>) -> ()
|
"tf.TPUCompileSucceededAssert"(%2#0) : (tensor<!tf.string>) -> ()
|
||||||
%new_var = "tf._UnknownOp0_"(%arg3) : (tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) -> tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
|
%id0 = "tf.Identity"(%arg3) : (tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) -> tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
|
||||||
|
"tf._Unknown_"(%id0) : (tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) -> ()
|
||||||
|
%newvar = "tf._SomeOp"() : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||||
tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
|
tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
|
||||||
[%new_var, %arg4] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
[%arg3, %arg4] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||||
|
[%newvar, %arg6] as %arg32: tensor<*x!tf.resource<tensor<f32>>>)
|
||||||
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
|
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
|
||||||
// %arg30 is used in the cond function, and %arg31 is not pass-through of
|
// %arg30 is used in the cond function, %arg31 has other uses (%id0), and
|
||||||
// while inputs, so neither should be formatted.
|
// %arg32 is not a pass-through.
|
||||||
"tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %2#1)
|
"tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %arg32, %2#1)
|
||||||
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
|
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
|
||||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<!tf.string>) -> ()
|
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||||
|
tensor<*x!tf.resource<tensor<f32>>>, tensor<!tf.string>) -> ()
|
||||||
tf_device.return
|
tf_device.return
|
||||||
}
|
}
|
||||||
return %1, %arg1, %arg2, %arg3, %arg4 : tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>,
|
return %1, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6 : tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
|
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>
|
||||||
}
|
}
|
||||||
// CHECK-LABEL: func @while_cond_7550
|
// CHECK-LABEL: func @while_cond_7550
|
||||||
func @while_cond_7550(%arg0: tensor<i32>,
|
func @while_cond_7550(%arg0: tensor<i32>,
|
||||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||||
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||||
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
|
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"},
|
||||||
|
%arg5: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||||
|
%arg6: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"})
|
||||||
-> tensor<i1> {
|
-> tensor<i1> {
|
||||||
%0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
%0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||||
%1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
%1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
|
@ -891,35 +891,3 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
|||||||
return %0 : tensor<?xi32>
|
return %0 : tensor<?xi32>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// Tests simple case of launch_func on TPU with replication with multiple logical cores.
|
|
||||||
|
|
||||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
|
|
||||||
// CHECK-LABEL: func @replicated_tpu_launch_func_with_multiple_logical_cores
|
|
||||||
func @replicated_tpu_launch_func_with_multiple_logical_cores(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
|
||||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
|
||||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
|
||||||
// CHECK-SAME: ([%[[A_OUTPUT]], %[[ARG_0]]] as %[[RI_0:[a-z0-9]*]]: tensor<?xi32>)
|
|
||||||
// CHECK-SAME: n = 2
|
|
||||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
|
||||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"
|
|
||||||
// CHECK: "tf_device.parallel_execute"()
|
|
||||||
// CHECK-NEXT: %[[LAUNCH:[0-9]*]] = "tf_device.launch"() ( {
|
|
||||||
// CHECK-NEXT: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[ARG_1]], %[[COMPILE_OUTPUT]]#1)
|
|
||||||
// CHECK: %[[LAUNCH:[0-9]*]] = "tf_device.launch"() ( {
|
|
||||||
// CHECK-NEXT: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[ARG_1]], %[[COMPILE_OUTPUT]]#1)
|
|
||||||
%2 = "tf_device.launch_func"(%ri_0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
|
|
||||||
tf_device.return %2 : tensor<?xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
%2 = "tf.C"(%1#1) : (tensor<?xi32>) -> tensor<?xi32>
|
|
||||||
return %2 : tensor<?xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
|
||||||
%0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
|
||||||
return %0 : tensor<?xi32>
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
|||||||
// `tf_device.launch` with equivalent `tf_device.launch_func` operations.
|
// `tf_device.launch` with equivalent `tf_device.launch_func` operations.
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
include "mlir/Dialect/StandardOps/Ops.td"
|
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||||
|
|
||||||
// Here, the element type can be any integer or float type. But, note that only
|
// Here, the element type can be any integer or float type. But, note that only
|
||||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/Twine.h"
|
#include "llvm/ADT/Twine.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Visitors.h" // TF:llvm-project
|
#include "mlir/IR/Visitors.h" // TF:llvm-project
|
||||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/Twine.h"
|
#include "llvm/ADT/Twine.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
|
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
|
||||||
|
@ -31,7 +31,7 @@ limitations under the License.
|
|||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
|
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||||||
// This transformation pass transforms functional control flow operations in the
|
// This transformation pass transforms functional control flow operations in the
|
||||||
// standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
|
// standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
|
||||||
|
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user