Update from master
This commit is contained in:
commit
7ba6a92ff9
21
.bazelrc
21
.bazelrc
@ -37,7 +37,6 @@
|
||||
# v2: Build TF v2
|
||||
#
|
||||
# Feature and Third party library support options:
|
||||
# xla: Build TF with XLA
|
||||
# using_cuda: CUDA is available to build system.
|
||||
# cuda: Build with full cuda support.
|
||||
# rocm: Build with AMD GPU support (rocm).
|
||||
@ -227,6 +226,14 @@ build --noincompatible_remove_legacy_whole_archive
|
||||
# https://github.com/tensorflow/community/pull/179
|
||||
build --noincompatible_prohibit_aapt1
|
||||
|
||||
# Enable XLA
|
||||
build --action_env=TF_ENABLE_XLA=1
|
||||
build --define=with_xla_support=true
|
||||
|
||||
# Keep config XLA until all build scripts are cleaned up.
|
||||
build:xla --action_env=TF_ENABLE_XLA=1
|
||||
build:xla --define=with_xla_support=true
|
||||
|
||||
# Modular TF build options
|
||||
build:dynamic_kernels --define=dynamic_loaded_kernels=true
|
||||
build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
|
||||
@ -312,10 +319,6 @@ build:v2 --action_env=TF2_BEHAVIOR=1
|
||||
build --config=v2
|
||||
test --config=v2
|
||||
|
||||
# Enable XLA
|
||||
build:xla --action_env=TF_ENABLE_XLA=1
|
||||
build:xla --define=with_xla_support=true
|
||||
|
||||
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
|
||||
# Options when using remote execution
|
||||
# WARNING: THESE OPTIONS WONT WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS
|
||||
@ -348,7 +351,6 @@ build:rbe_linux --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
|
||||
build:rbe_linux --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
|
||||
|
||||
# Non-rbe settings we should include because we do not run configure
|
||||
build:rbe_linux --config=xla
|
||||
build:rbe_linux --config=avx_linux
|
||||
build:rbe_linux --config=short_logs
|
||||
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
|
||||
@ -386,9 +388,8 @@ build:rbe_linux_py2 --python_path="/usr/bin/python2"
|
||||
build:rbe_linux_py2 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py"
|
||||
|
||||
build:rbe_linux_py3 --config=rbe_linux
|
||||
build:rbe_linux_py3 --repo_env=PYTHON_BIN_PATH="/usr/bin/python3"
|
||||
build:rbe_linux_py3 --python_path="/usr/bin/python3"
|
||||
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py3"
|
||||
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-manylinux2010-py3_config_python"
|
||||
|
||||
build:rbe_win --config=rbe
|
||||
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win_1803/bazel_121:toolchain"
|
||||
@ -405,9 +406,7 @@ build:rbe_win --define=override_eigen_strong_inline=true
|
||||
build:rbe_win --jobs=500
|
||||
|
||||
build:rbe_win_py37 --config=rbe
|
||||
build:rbe_win_py37 --repo_env=PYTHON_BIN_PATH=C:\\Python37\\python.exe
|
||||
build:rbe_win_py37 --repo_env=PYTHON_LIB_PATH=C:\\Python37\\lib\\site-packages
|
||||
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py37
|
||||
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO="@windows_py37_config_python"
|
||||
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe
|
||||
|
||||
build:rbe_win_py38 --config=rbe
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -22,6 +22,7 @@ tensorflow/contrib/cmake/_build/
|
||||
/tensorflow/python/framework/fast_tensor_util.cpp
|
||||
/tensorflow/lite/gen/**
|
||||
/tensorflow/lite/tools/make/downloads/**
|
||||
/tensorflow/lite/tools/make/gen/**
|
||||
/api_init_files_list.txt
|
||||
/estimator_api_init_files_list.txt
|
||||
*.whl
|
||||
|
@ -70,7 +70,7 @@ $ python
|
||||
3
|
||||
>>> hello = tf.constant('Hello, TensorFlow!')
|
||||
>>> hello.numpy()
|
||||
'Hello, TensorFlow!'
|
||||
b'Hello, TensorFlow!'
|
||||
```
|
||||
|
||||
For more examples, see the
|
||||
|
@ -1390,10 +1390,6 @@ def main():
|
||||
else:
|
||||
environ_cp['TF_CONFIGURE_IOS'] = '0'
|
||||
|
||||
xla_enabled_by_default = is_linux() or is_macos()
|
||||
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
|
||||
xla_enabled_by_default, 'xla')
|
||||
|
||||
set_action_env_var(
|
||||
environ_cp,
|
||||
'TF_NEED_OPENCL_SYCL',
|
||||
|
@ -205,6 +205,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/platform:casts",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -874,12 +874,12 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
|
||||
TF_Status* status) {
|
||||
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||
TF_Status* status) {
|
||||
#if defined(IS_MOBILE_PLATFORM)
|
||||
status->status = tensorflow::Status::OK();
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
status->status = ctx->context->ClearRemoteExecutors();
|
||||
status->status = ctx->context->SyncExecutors();
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
@ -1450,6 +1450,25 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
|
||||
const void* proto, size_t proto_len,
|
||||
TF_Status* status) {
|
||||
tensorflow::AttrValue attr_value;
|
||||
if (!attr_value.ParseFromArray(proto, proto_len)) {
|
||||
status->status =
|
||||
tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
|
||||
return;
|
||||
}
|
||||
if (op == nullptr || op->operation == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Got a null or uninitialized `op` argument");
|
||||
return;
|
||||
}
|
||||
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
op->operation.get());
|
||||
operation->MutableAttrs()->Set(attr_name, attr_value);
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
|
||||
const char* input_name,
|
||||
TF_Status* status) {
|
||||
@ -1606,7 +1625,7 @@ void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
|
||||
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
|
||||
auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
op->operation.get());
|
||||
*attrs = TFE_OpAttrs(&operation->Attrs());
|
||||
*attrs = TFE_OpAttrs(&operation->Attrs(), op->operation->Name().c_str());
|
||||
}
|
||||
|
||||
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
||||
@ -1620,6 +1639,14 @@ void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
|
||||
}
|
||||
}
|
||||
|
||||
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
tensorflow::NameAttrList name_and_attrs;
|
||||
attrs->attributes->FillAttrValueMap(name_and_attrs.mutable_attr());
|
||||
name_and_attrs.set_name(attrs->name);
|
||||
status->status = MessageToBuffer(name_and_attrs, buf);
|
||||
}
|
||||
|
||||
namespace tensorflow {
|
||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
const tensorflow::AttrValue& default_value,
|
||||
@ -1740,7 +1767,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
||||
TF_Status status;
|
||||
TFE_OpAttrs attributes(&op->Attrs());
|
||||
TFE_OpAttrs attributes(&op->Attrs(), op->Name().c_str());
|
||||
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
|
||||
&attributes, num_retvals, outputs.data(), &status, info_);
|
||||
if (status.status.ok()) {
|
||||
|
@ -382,9 +382,11 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
const char* worker_name,
|
||||
TF_Status* status);
|
||||
|
||||
// Clear pending streaming requests and error statuses on remote executors.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextClearRemoteExecutors(TFE_Context* ctx,
|
||||
TF_Status* status);
|
||||
// Sync pending nodes in local executors (including the context default executor
|
||||
// and thread executors) and streaming requests to remote executors, and get the
|
||||
// combined status.
|
||||
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||
TF_Status* status);
|
||||
|
||||
// If the TensorHandle is copied to another device as part of an op execution,
|
||||
// the copy is destroyed after the op has executed. Enabling implicit mirroring
|
||||
@ -441,6 +443,21 @@ TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
|
||||
// Does not overwrite or update existing attributes, but adds new ones.
|
||||
TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
|
||||
|
||||
// Serialize `attrs` as a tensorflow::NameAttrList protocol buffer (into `buf`),
|
||||
// containing the op name and a map of its attributes.
|
||||
TF_CAPI_EXPORT extern void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs,
|
||||
TF_Buffer* buf,
|
||||
TF_Status* status);
|
||||
|
||||
// Set an op's attribute from a serialized AttrValue protocol buffer.
|
||||
//
|
||||
// Analogous to TF_SetAttrValueProto for building graph operations.
|
||||
TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
|
||||
const char* attr_name,
|
||||
const void* proto,
|
||||
size_t proto_len,
|
||||
TF_Status* status);
|
||||
|
||||
#define TFE_CUSTOM_DEVICE_VERSION 1
|
||||
|
||||
// Struct to be filled in
|
||||
|
@ -236,12 +236,16 @@ struct TFE_Executor {
|
||||
tensorflow::EagerExecutor* unowned_executor;
|
||||
};
|
||||
|
||||
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
|
||||
// that sometimes do not require serialization.
|
||||
struct TFE_OpAttrs {
|
||||
explicit TFE_OpAttrs() : attributes(nullptr) {}
|
||||
explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
|
||||
|
||||
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value)
|
||||
: attributes(value) {}
|
||||
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
|
||||
const char* op_name)
|
||||
: name(op_name), attributes(value) {}
|
||||
|
||||
const char* name;
|
||||
const tensorflow::AttrBuilder* attributes;
|
||||
};
|
||||
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
@ -127,7 +128,7 @@ void TestRemoteExecute(bool async) {
|
||||
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
|
||||
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
|
||||
|
||||
void TestRemoteExecuteSilentCopies(bool async) {
|
||||
void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
@ -166,10 +167,14 @@ void TestRemoteExecuteSilentCopies(bool async) {
|
||||
auto* h1_task2 =
|
||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandleEnableImplicitMirroring(h1_task2, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Handles are on task0 (local), and task2, but op is on task1.
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
|
||||
TFE_OpSetDevice(matmul, task1_name, status);
|
||||
if (remote) {
|
||||
TFE_OpSetDevice(matmul, task1_name, status);
|
||||
}
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
@ -177,6 +182,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// TODO(gjn): Add support for waiting on async local mirrors
|
||||
if (!async) {
|
||||
auto remote_arg = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
h1_task2->handle.get())
|
||||
->Handle();
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
matmul->operation.get());
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_EQ(op->GetInput(1), remote_arg);
|
||||
}
|
||||
|
||||
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
|
||||
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
@ -213,9 +229,17 @@ void TestRemoteExecuteSilentCopies(bool async) {
|
||||
worker_server2.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }
|
||||
TEST(CAPI, RemoteExecuteSilentCopies) {
|
||||
TestRemoteExecuteSilentCopies(false, true);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
|
||||
TestRemoteExecuteSilentCopies(true);
|
||||
TestRemoteExecuteSilentCopies(true, true);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
|
||||
TestRemoteExecuteSilentCopies(false, false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
|
||||
TestRemoteExecuteSilentCopies(true, false);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
|
@ -416,12 +416,23 @@ void TensorHandleSilentCopy(bool async,
|
||||
hgpu->handle.get())
|
||||
->Handle();
|
||||
|
||||
// The input handles should never change since they have been mirrored.
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
matmul->operation.get());
|
||||
ASSERT_EQ(op->GetInput(0), arg0);
|
||||
ASSERT_EQ(op->GetInput(1), arg1);
|
||||
|
||||
if (!async) {
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_EQ(op->GetInput(0), arg0);
|
||||
ASSERT_EQ(op->GetInput(1), arg1);
|
||||
} else {
|
||||
if (cpu_op) {
|
||||
ASSERT_EQ(op->GetInput(0), arg0);
|
||||
// The GPU handle should be replaced with a CPU copy
|
||||
ASSERT_NE(op->GetInput(1), arg1);
|
||||
} else {
|
||||
// The CPU handle should be replaced with a GPU copy
|
||||
ASSERT_NE(op->GetInput(0), arg0);
|
||||
ASSERT_EQ(op->GetInput(1), arg1);
|
||||
}
|
||||
}
|
||||
TFE_DeleteOp(matmul);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
TFE_DeleteTensorHandle(hgpu);
|
||||
@ -1578,4 +1589,52 @@ TEST(CAPI, TestTFE_OpGetAttrs) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||
TFE_OpAttrs attributes;
|
||||
TFE_OpGetAttrs(var_op, &attributes);
|
||||
|
||||
TF_Buffer* serialized_attr_values = TF_NewBuffer();
|
||||
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
tensorflow::NameAttrList name_and_attrs;
|
||||
ASSERT_TRUE(name_and_attrs.ParseFromArray(serialized_attr_values->data,
|
||||
serialized_attr_values->length));
|
||||
ASSERT_EQ("VarHandleOp", name_and_attrs.name());
|
||||
ASSERT_EQ(tensorflow::DT_INT64,
|
||||
name_and_attrs.attr().find("dtype")->second.type());
|
||||
TF_DeleteBuffer(serialized_attr_values);
|
||||
|
||||
TFE_Op* second_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
|
||||
string serialized_dtype;
|
||||
ASSERT_TRUE(name_and_attrs.attr().find("dtype")->second.SerializeToString(
|
||||
&serialized_dtype));
|
||||
TFE_OpSetAttrValueProto(
|
||||
second_var_op, "dtype",
|
||||
reinterpret_cast<const void*>(serialized_dtype.c_str()),
|
||||
serialized_dtype.length(), status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
second_var_op->operation.get());
|
||||
op->Attrs().FillAttrValueMap(&attr_values);
|
||||
EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteOp(var_op);
|
||||
TFE_DeleteOp(second_var_op);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -68,6 +68,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/platform:resource_loader",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -21,15 +21,22 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
#include "tensorflow/core/platform/resource_loader.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
constexpr char kTestDataPbTxt[] =
|
||||
"cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
|
||||
constexpr char kTestDataSharded[] =
|
||||
"cc/saved_model/testdata/half_plus_two/00000123";
|
||||
string TestDataPbTxt() {
|
||||
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
|
||||
"half_plus_two_pbtxt", "00000123");
|
||||
}
|
||||
|
||||
string TestDataSharded() {
|
||||
return io::JoinPath("tensorflow", "cc", "saved_model", "testdata",
|
||||
"half_plus_two", "00000123");
|
||||
}
|
||||
|
||||
class ReaderTest : public ::testing::Test {
|
||||
protected:
|
||||
@ -49,8 +56,7 @@ class ReaderTest : public ::testing::Test {
|
||||
TEST_F(ReaderTest, TagMatch) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
||||
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||
&meta_graph_def));
|
||||
CheckMetaGraphDef(meta_graph_def);
|
||||
@ -59,8 +65,7 @@ TEST_F(ReaderTest, TagMatch) {
|
||||
TEST_F(ReaderTest, NoTagMatch) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
||||
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
|
||||
&meta_graph_def);
|
||||
EXPECT_FALSE(st.ok());
|
||||
@ -73,8 +78,7 @@ TEST_F(ReaderTest, NoTagMatch) {
|
||||
TEST_F(ReaderTest, NoTagMatchMultiple) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
||||
const string export_dir = GetDataDependencyFilepath(TestDataSharded());
|
||||
Status st = ReadMetaGraphDefFromSavedModel(
|
||||
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
|
||||
EXPECT_FALSE(st.ok());
|
||||
@ -87,8 +91,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) {
|
||||
TEST_F(ReaderTest, PbtxtFormat) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
|
||||
const string export_dir = GetDataDependencyFilepath(TestDataPbTxt());
|
||||
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||
&meta_graph_def));
|
||||
CheckMetaGraphDef(meta_graph_def);
|
||||
@ -97,8 +100,7 @@ TEST_F(ReaderTest, PbtxtFormat) {
|
||||
TEST_F(ReaderTest, InvalidExportPath) {
|
||||
MetaGraphDef meta_graph_def;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
|
||||
const string export_dir = GetDataDependencyFilepath("missing-path");
|
||||
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||
&meta_graph_def);
|
||||
EXPECT_FALSE(st.ok());
|
||||
|
3
tensorflow/compiler/mlir/g3doc/README.md
Normal file
3
tensorflow/compiler/mlir/g3doc/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
# TensorFlow MLIR
|
||||
|
||||
These are the docs for: https://www.tensorflow.org/mlir
|
26
tensorflow/compiler/mlir/g3doc/_book.yaml
Normal file
26
tensorflow/compiler/mlir/g3doc/_book.yaml
Normal file
@ -0,0 +1,26 @@
|
||||
upper_tabs:
|
||||
# Tabs left of dropdown menu
|
||||
- include: /_upper_tabs_left.yaml
|
||||
- include: /api_docs/_upper_tabs_api.yaml
|
||||
# Dropdown menu
|
||||
- name: Resources
|
||||
path: /resources
|
||||
is_default: true
|
||||
menu:
|
||||
- include: /resources/_menu_toc.yaml
|
||||
lower_tabs:
|
||||
# Subsite tabs
|
||||
other:
|
||||
- name: Guide
|
||||
contents:
|
||||
- title: Overview
|
||||
path: /mlir/overview
|
||||
- heading: Dialects
|
||||
- title: Overview
|
||||
path: /mlir/dialects
|
||||
- title: TensorFlow
|
||||
path: /mlir/tf_ops
|
||||
- title: TensorFlow Lite
|
||||
path: /mlir/tfl_ops
|
||||
|
||||
- include: /_upper_tabs_right.yaml
|
54
tensorflow/compiler/mlir/g3doc/_index.yaml
Normal file
54
tensorflow/compiler/mlir/g3doc/_index.yaml
Normal file
@ -0,0 +1,54 @@
|
||||
book_path: /mlir/_book.yaml
|
||||
project_path: /mlir/_project.yaml
|
||||
description: <!--no description-->
|
||||
landing_page:
|
||||
custom_css_path: /site-assets/css/style.css
|
||||
rows:
|
||||
- heading: MLIR unifies the infrastructure for high-performance ML models in TensorFlow.
|
||||
items:
|
||||
- description: >
|
||||
The <a href="https://mlir.llvm.org/" class="external">MLIR</a> project defines a common
|
||||
intermediate representation (IR) that unifies the infrastructure required to execute high
|
||||
performance machine learning models in TensorFlow and similar ML frameworks. This project
|
||||
will include the application of HPC techniques, along with integration of
|
||||
search algorithms like reinforcement learning. MLIR aims to reduce the
|
||||
cost to bring up new hardware, and improve usability for existing
|
||||
TensorFlow users.
|
||||
|
||||
- code_block: |
|
||||
<pre class = "prettyprint">
|
||||
// Syntactically similar to LLVM:
|
||||
func @testFunction(%arg0: i32) {
|
||||
%x = call @thingToCall(%arg0) : (i32) -> i32
|
||||
br ^bb1
|
||||
^bb1:
|
||||
%y = addi %x, %x : i32
|
||||
return %y : i32
|
||||
}
|
||||
</pre>
|
||||
|
||||
- classname: devsite-landing-row-cards
|
||||
items:
|
||||
- heading: "Multi-Level Intermediate Representation for Compiler Infrastructure"
|
||||
youtube_id: qzljG6DKgic
|
||||
buttons:
|
||||
- label: Watch the video
|
||||
path: https://www.youtube.com/watch?v=qzljG6DKgic
|
||||
- heading: "A new intermediate representation and compiler framework"
|
||||
image_path: /resources/images/tf-logo-card-16x9.png
|
||||
path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html
|
||||
buttons:
|
||||
- label: Read on TensorFlow blog
|
||||
path: https://blog.tensorflow.org/2019/04/mlir-new-intermediate-representation.html
|
||||
- heading: MLIR on GitHub
|
||||
image_path: /resources/images/github-card-16x9.png
|
||||
path: https://github.com/llvm/llvm-project/tree/master/mlir
|
||||
buttons:
|
||||
- label: View on GitHub
|
||||
path: https://github.com/llvm/llvm-project/tree/master/mlir
|
||||
- heading: TensorFlow MLIR on GitHub
|
||||
image_path: /resources/images/github-card-16x9.png
|
||||
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir
|
||||
buttons:
|
||||
- label: View on GitHub
|
||||
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir
|
37
tensorflow/compiler/mlir/g3doc/dialects.md
Normal file
37
tensorflow/compiler/mlir/g3doc/dialects.md
Normal file
@ -0,0 +1,37 @@
|
||||
# MLIR dialects
|
||||
|
||||
## Overview
|
||||
|
||||
|
||||
To separate different hardware and software targets, MLIR has “dialects”,
|
||||
including:
|
||||
|
||||
* TensorFlow IR, which represents all things possible in TensorFlow graphs.
|
||||
* XLA HLO IR, which is designed to take advantage of XLA’s compilation
|
||||
abilities (with output to, among other things, TPUs).
|
||||
* An experimental affine dialect, which focuses on
|
||||
[polyhedral representations](https://en.wikipedia.org/wiki/Polytope_model)
|
||||
and optimizations.
|
||||
* LLVM IR, which has a 1:1 mapping between it and LLVM’s own representation,
|
||||
allowing MLIR to emit GPU and CPU code through LLVM.
|
||||
* TensorFlow Lite, which will translate to running code on mobile platforms.
|
||||
|
||||
Each dialect consists of a set of defined operations which have invariants
|
||||
placed on them, like: “This is a binary operator, and the inputs and outputs
|
||||
have the same types.”
|
||||
|
||||
## Adding to MLIR
|
||||
|
||||
MLIR has no fixed/built-in list of globally known operations (no “intrinsics”).
|
||||
Dialects can define entirely custom types, which is how MLIR can model things
|
||||
like the LLVM IR type system (which has first class aggregates), domain
|
||||
abstractions important for ML-optimized accelerators like quantized types, and
|
||||
even the Swift or Clang type systems (which are built around Swift/Clang
|
||||
declaration nodes) in the future.
|
||||
|
||||
If you want to connect a new low-level compiler, you would create a new dialect
|
||||
and the lowerings between the TensorFlow Graph dialect and your dialect.
|
||||
This smooths the path for hardware and compiler makers. You can even target
|
||||
dialects at different levels in the same model; the higher-level optimizers
|
||||
will respect the unfamiliar parts of the IR and wait for a lower level to handle
|
||||
it.
|
1
tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg
Normal file
1
tensorflow/compiler/mlir/g3doc/images/mlir-infra.svg
Normal file
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 148 KiB |
36
tensorflow/compiler/mlir/g3doc/overview.md
Normal file
36
tensorflow/compiler/mlir/g3doc/overview.md
Normal file
@ -0,0 +1,36 @@
|
||||
# MLIR
|
||||
|
||||
## Overview
|
||||
|
||||
MLIR, or Multi-Level Intermediate Representation, is a representation format
|
||||
and library of compiler utilities that sits between the model representation
|
||||
and low-level compilers/executors that generate hardware-specific code.
|
||||
|
||||
MLIR is, at its heart, a flexible infrastructure for modern optimizing
|
||||
compilers. This means it consists of a specification for intermediate
|
||||
representations (IR) and a code toolkit to perform transformations on that
|
||||
representation. (In compiler parlance, as you move from higher-level
|
||||
representations to lower-level representations, these transformations can be
|
||||
called “lowerings”)
|
||||
|
||||
MLIR is highly influenced by [LLVM](https://llvm.org/) and unabashedly reuses
|
||||
many great ideas from it. It has a flexible type system, and allows
|
||||
representing, analyzing and transforming graphs combining multiple levels of
|
||||
abstraction in the same compilation unit. These abstractions include TensorFlow
|
||||
operations, nested polyhedral loop regions, and even LLVM instructions and fixed
|
||||
hardware operations and types.
|
||||
|
||||
We expect MLIR to be of interest to many groups, including:
|
||||
|
||||
* Compiler researchers and implementers looking to optimize performance and
|
||||
memory consumption of machine learning models
|
||||
* Hardware makers looking for a way to connect their hardware to TensorFlow,
|
||||
such as TPUs, portable neural hardware in phones, and other custom ASICs
|
||||
* People writing language bindings that want to take advantage of optimizing
|
||||
compilers and hardware acceleration.
|
||||
|
||||
The TensorFlow ecosystem contains a number of compilers and optimizers that
|
||||
operate at multiple levels of the software and hardware stack. We expect the
|
||||
gradual adoption of MLIR to simplify every aspect of this stack.
|
||||
|
||||
<img alt="MLIR overview diagram" src="./images/mlir-infra.svg"/>
|
@ -602,6 +602,7 @@ tf_cc_binary(
|
||||
name = "flatbuffer_translate",
|
||||
deps = [
|
||||
":flatbuffer_translate_lib",
|
||||
"@llvm-project//mlir:LoopOpsTransforms",
|
||||
"@llvm-project//mlir:MlirTranslateMain",
|
||||
],
|
||||
)
|
||||
|
@ -46,7 +46,7 @@ limitations under the License.
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
|
||||
@ -76,6 +76,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
@ -124,6 +125,20 @@ static opt<bool, true> experimental_prune_unreachable_nodes_unconditionally_flg(
|
||||
llvm::cl::location(experimental_prune_unreachable_nodes_unconditionally),
|
||||
llvm::cl::init(false));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static opt<std::string> input_arrays_flag(
|
||||
"input-arrays",
|
||||
llvm::cl::desc(
|
||||
"List of input tensors, if different from the default inputs"),
|
||||
llvm::cl::init(""));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static opt<std::string> output_arrays_flag(
|
||||
"output-arrays",
|
||||
llvm::cl::desc(
|
||||
"List of output tensors, if different from the default outputs"),
|
||||
llvm::cl::init(""));
|
||||
|
||||
namespace {
|
||||
bool IsScalar(const TensorT& tensor) {
|
||||
// TODO(b/138222071) We can't distinguish scalars and unranked tensors
|
||||
@ -590,6 +605,11 @@ StatusOr<Operation*> ConvertOp(
|
||||
op_state.addTypes({type});
|
||||
}
|
||||
|
||||
if (op_name == "tfl.lstm") {
|
||||
// TODO(b/147587779): add the right region if region is empty.
|
||||
op_state.addRegion();
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::NamedAttribute, 2> attrs;
|
||||
if (IsCustomOp(op_name)) {
|
||||
auto status = mlir::CustomOptionsToAttributes(op_name, op.custom_options,
|
||||
@ -610,43 +630,30 @@ StatusOr<Operation*> ConvertOp(
|
||||
return builder.createOperation(op_state);
|
||||
}
|
||||
|
||||
// Returns the output tensor indices for the given subgraph. If
|
||||
// ordered_output_arrays is provided, then return the tensor indices in
|
||||
// ordered_output_arrays.
|
||||
StatusOr<llvm::SmallVector<int32_t, 4>> GetOutputTensorIndices(
|
||||
const tflite::SubGraphT& subgraph, Location base_loc,
|
||||
const std::vector<std::string>& ordered_output_arrays) {
|
||||
if (ordered_output_arrays.empty()) {
|
||||
return llvm::SmallVector<int32_t, 4>(subgraph.outputs.begin(),
|
||||
subgraph.outputs.end());
|
||||
// Returns indices of the given tensors in the subgraph. Returns error if a
|
||||
// tensor name cannot be found in the subgraph.
|
||||
StatusOr<std::vector<int>> GetTensorIndices(
|
||||
const tflite::SubGraphT& subgraph,
|
||||
const std::vector<std::string>& tensor_names) {
|
||||
absl::flat_hash_map<std::string, int> name_to_index;
|
||||
for (auto index_and_tensor : llvm::enumerate(subgraph.tensors)) {
|
||||
name_to_index[index_and_tensor.value()->name] = index_and_tensor.index();
|
||||
}
|
||||
|
||||
llvm::SmallVector<int32_t, 4> outputs;
|
||||
outputs.resize(ordered_output_arrays.size());
|
||||
absl::flat_hash_map<std::string, int> output_order_map;
|
||||
for (auto output : llvm::enumerate(ordered_output_arrays)) {
|
||||
output_order_map[output.value()] = output.index();
|
||||
}
|
||||
std::vector<int> indices;
|
||||
indices.reserve(tensor_names.size());
|
||||
|
||||
int tensor_index = 0;
|
||||
int found_output_tensors = 0;
|
||||
for (const auto& tensor : subgraph.tensors) {
|
||||
auto found = output_order_map.find(tensor->name);
|
||||
if (found != output_order_map.end()) {
|
||||
const int output_index = found->second;
|
||||
outputs[output_index] = tensor_index;
|
||||
++found_output_tensors;
|
||||
for (const auto& name : tensor_names) {
|
||||
auto found = name_to_index.find(name);
|
||||
if (found != name_to_index.end()) {
|
||||
indices.push_back(found->second);
|
||||
} else {
|
||||
return errors::InvalidArgument("could not find tensor in subgraph: ",
|
||||
name);
|
||||
}
|
||||
++tensor_index;
|
||||
}
|
||||
|
||||
if (found_output_tensors != ordered_output_arrays.size()) {
|
||||
auto err = errors::InvalidArgument(
|
||||
"cannot find all nodes in ordered_output_arrays");
|
||||
return emitError(base_loc, err.ToString()), err;
|
||||
}
|
||||
|
||||
return outputs;
|
||||
return indices;
|
||||
}
|
||||
|
||||
// Given a list of tensor indices, returns a string of concatenated tensor names
|
||||
@ -661,15 +668,18 @@ mlir::NamedAttribute BuildTFEntryFunctionAttribute(
|
||||
name, builder->getStringAttr(llvm::join(tensor_names, ",")));
|
||||
}
|
||||
|
||||
// Given a list of output indices, traverses the subgraph and returns the set of
|
||||
// ops that are ancestors of the output tensors.
|
||||
// Traverses the subgraph from output_indices to input_indices and returns the
|
||||
// set of ops that are visited.
|
||||
StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
|
||||
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> output_indices) {
|
||||
const tflite::SubGraphT& subgraph, ArrayRef<int32_t> input_indices,
|
||||
ArrayRef<int32_t> output_indices) {
|
||||
// Create a map from tensor index to defining op.
|
||||
absl::flat_hash_map<int32_t, const tflite::OperatorT*> defining_op;
|
||||
for (const auto& op : subgraph.operators) {
|
||||
for (int32_t output : op->outputs) {
|
||||
defining_op[output] = op.get();
|
||||
if (!llvm::is_contained(input_indices, output)) {
|
||||
defining_op[output] = op.get();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -718,18 +728,40 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
const std::vector<std::string>& op_names,
|
||||
const std::vector<std::string>& func_names,
|
||||
const std::vector<std::unique_ptr<tflite::BufferT>>& buffers,
|
||||
Location base_loc, Builder builder,
|
||||
const std::vector<std::string>& ordered_output_arrays, bool is_entry_point,
|
||||
Location base_loc, Builder builder, bool is_entry_point,
|
||||
bool use_external_constant,
|
||||
const std::vector<std::string>& ordered_input_arrays,
|
||||
const std::vector<std::string>& ordered_output_arrays,
|
||||
bool experimental_prune_unreachable_nodes_unconditionally) {
|
||||
llvm::SmallVector<mlir::Type, 2> ret_types;
|
||||
llvm::SmallVector<mlir::Type, 4> input_types;
|
||||
|
||||
auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc);
|
||||
|
||||
// Construct function type
|
||||
for (auto input : subgraph.inputs) {
|
||||
auto& tensor = *subgraph.tensors.at(input);
|
||||
std::vector<int> func_inputs = subgraph.inputs;
|
||||
if (is_entry_point && !ordered_input_arrays.empty()) {
|
||||
if (!experimental_prune_unreachable_nodes_unconditionally) {
|
||||
// TODO(b/149922113): Resolve input-arrays/pruning flags interaction.
|
||||
return errors::InvalidArgument(
|
||||
"input-arrays should be used with experimental pruning flag");
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(func_inputs,
|
||||
GetTensorIndices(subgraph, ordered_input_arrays));
|
||||
}
|
||||
|
||||
// Add state variables to inputs.
|
||||
absl::flat_hash_set<int32_t> input_index_set(func_inputs.begin(),
|
||||
func_inputs.end());
|
||||
for (int i = 0; i < subgraph.tensors.size(); i++) {
|
||||
auto& tensor = *subgraph.tensors.at(i);
|
||||
if (tensor.is_variable && !input_index_set.contains(i)) {
|
||||
func_inputs.emplace_back(i);
|
||||
input_index_set.insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto input_or_variable : func_inputs) {
|
||||
auto& tensor = *subgraph.tensors.at(input_or_variable);
|
||||
// TODO(b/138222071) Graph inputs must have static shape per the exporter,
|
||||
// but we cannot differentiate scalars from unranked tensors.
|
||||
// Here we reverse the default assumption that shape = [] means unranked.
|
||||
@ -753,9 +785,11 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
}
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto func_outputs,
|
||||
GetOutputTensorIndices(subgraph, base_loc, ordered_output_arrays));
|
||||
std::vector<int> func_outputs = subgraph.outputs;
|
||||
if (is_entry_point && !ordered_output_arrays.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(func_outputs,
|
||||
GetTensorIndices(subgraph, ordered_output_arrays));
|
||||
}
|
||||
|
||||
for (auto output : func_outputs) {
|
||||
bool is_constant = !is_op_output[output];
|
||||
@ -782,8 +816,8 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
Value maybe_optional_arg_marker = nullptr;
|
||||
|
||||
// Get or construct MLIR values for each input
|
||||
for (int i = 0, e = subgraph.inputs.size(); i < e; i++) {
|
||||
auto input_tensor = subgraph.inputs[i];
|
||||
for (int i = 0, e = func_inputs.size(); i < e; i++) {
|
||||
auto input_tensor = func_inputs[i];
|
||||
const auto& tensor = *subgraph.tensors.at(input_tensor);
|
||||
auto loc = TensorLoc(tensor, builder, base_loc);
|
||||
if (vals_map[input_tensor]) {
|
||||
@ -806,9 +840,9 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
// Set tf.entry_function attribute
|
||||
if (is_entry_point) {
|
||||
llvm::SmallVector<mlir::NamedAttribute, 2> attributes;
|
||||
if (!subgraph.inputs.empty()) {
|
||||
if (!func_inputs.empty()) {
|
||||
attributes.push_back(BuildTFEntryFunctionAttribute(
|
||||
subgraph, &builder, "inputs", subgraph.inputs));
|
||||
subgraph, &builder, "inputs", func_inputs));
|
||||
}
|
||||
if (!func_outputs.empty()) {
|
||||
attributes.push_back(BuildTFEntryFunctionAttribute(
|
||||
@ -820,7 +854,7 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
|
||||
if (experimental_prune_unreachable_nodes_unconditionally) {
|
||||
TF_ASSIGN_OR_RETURN(pruned_subgraph_ops,
|
||||
PruneSubgraph(subgraph, func_outputs));
|
||||
PruneSubgraph(subgraph, func_inputs, func_outputs));
|
||||
}
|
||||
|
||||
// Construct MLIR operators from TFLite operators
|
||||
@ -931,8 +965,9 @@ std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) {
|
||||
|
||||
OwningModuleRef tflite::FlatBufferToMlir(
|
||||
absl::string_view buffer, MLIRContext* context, Location base_loc,
|
||||
const std::vector<std::string>& ordered_output_arrays,
|
||||
bool use_external_constant,
|
||||
const std::vector<std::string>& ordered_input_arrays,
|
||||
const std::vector<std::string>& ordered_output_arrays,
|
||||
bool experimental_prune_unreachable_nodes_unconditionally) {
|
||||
auto model_ptr =
|
||||
FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
|
||||
@ -971,33 +1006,25 @@ OwningModuleRef tflite::FlatBufferToMlir(
|
||||
builder.getStringAttr(model->description));
|
||||
}
|
||||
|
||||
if (!ordered_output_arrays.empty() && model->subgraphs.size() > 1) {
|
||||
// TODO(b/141485522): support more than one subgraph.
|
||||
return emitError(base_loc,
|
||||
"ordered_output_arrays does not support more than one "
|
||||
"subgraph yet"),
|
||||
nullptr;
|
||||
}
|
||||
|
||||
for (auto e : llvm::enumerate(model->subgraphs)) {
|
||||
auto& subgraph = e.value();
|
||||
std::string name = SubgraphName(e.index(), *subgraph);
|
||||
auto func_or_error = ConvertSubgraph(
|
||||
*subgraph, name, operator_names, func_names, model->buffers, base_loc,
|
||||
// Only the entry point needs pseudo_input_ops
|
||||
builder,
|
||||
// TODO(b/131175224,b/132239787) Support multiple entry points
|
||||
builder, ordered_output_arrays,
|
||||
/*is_entry_point=*/e.index() == 0,
|
||||
/*use_external_constant=*/use_external_constant,
|
||||
/*use_external_constant=*/use_external_constant, ordered_input_arrays,
|
||||
ordered_output_arrays,
|
||||
experimental_prune_unreachable_nodes_unconditionally);
|
||||
if (!func_or_error.ok()) {
|
||||
return emitError(base_loc, "could not translate function ")
|
||||
<< subgraph->name,
|
||||
<< subgraph->name << ": "
|
||||
<< func_or_error.status().error_message(),
|
||||
nullptr;
|
||||
}
|
||||
module.push_back(func_or_error.ConsumeValueOrDie());
|
||||
}
|
||||
// TFLite subgraphs do not necessarily have names,
|
||||
|
||||
return OwningModuleRef(module);
|
||||
}
|
||||
@ -1012,17 +1039,24 @@ static OwningModuleRef FlatBufferFileToMlirTrans(
|
||||
auto loc =
|
||||
mlir::FileLineColLoc::get(input->getBufferIdentifier(), 0, 0, context);
|
||||
|
||||
// Parses output_arrays_order from command line option.
|
||||
// Parses input/output names from command line options.
|
||||
std::vector<std::string> inputs;
|
||||
std::vector<std::string> outputs;
|
||||
if (!tensorflow::ParseOutputArrayInfo(output_arrays_string, &outputs).ok()) {
|
||||
// Use output parser since we only have tensor names.
|
||||
if (!tensorflow::ParseOutputArrayInfo(input_arrays_flag, &inputs).ok()) {
|
||||
return emitError(loc, "parsing input array info failed ")
|
||||
<< input_arrays_flag,
|
||||
nullptr;
|
||||
}
|
||||
if (!tensorflow::ParseOutputArrayInfo(output_arrays_flag, &outputs).ok()) {
|
||||
return emitError(loc, "parsing output array info failed ")
|
||||
<< output_arrays_string,
|
||||
<< output_arrays_flag,
|
||||
nullptr;
|
||||
}
|
||||
|
||||
return tflite::FlatBufferToMlir(
|
||||
absl::string_view(input->getBufferStart(), input->getBufferSize()),
|
||||
context, loc, outputs, use_external_constant,
|
||||
context, loc, use_external_constant, inputs, outputs,
|
||||
experimental_prune_unreachable_nodes_unconditionally);
|
||||
}
|
||||
|
||||
|
@ -35,9 +35,9 @@ namespace tflite {
|
||||
// are not ancestors of the output nodes will be pruned.
|
||||
mlir::OwningModuleRef FlatBufferToMlir(
|
||||
absl::string_view buffer, mlir::MLIRContext* context,
|
||||
mlir::Location base_loc,
|
||||
const std::vector<std::string>& ordered_output_arrays,
|
||||
bool use_external_constant = false,
|
||||
mlir::Location base_loc, bool use_external_constant = false,
|
||||
const std::vector<std::string>& ordered_input_arrays = {},
|
||||
const std::vector<std::string>& ordered_output_arrays = {},
|
||||
bool experimental_prune_unreachable_nodes_unconditionally = false);
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -42,7 +42,7 @@ limitations under the License.
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
@ -122,8 +122,6 @@ bool emit_custom_ops;
|
||||
bool emit_select_tf_ops;
|
||||
bool lower_tensor_list_ops;
|
||||
bool strip_debug_info;
|
||||
// NOLINTNEXTLINE
|
||||
std::string output_arrays_string;
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static opt<bool, true> emit_builtin_tflite_ops_flag(
|
||||
@ -156,11 +154,6 @@ static opt<bool, true> strip_debug_info_flag(
|
||||
"strip-debug-info", llvm::cl::desc("Strip debug info during export"),
|
||||
llvm::cl::location(strip_debug_info), llvm::cl::init(false));
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static opt<std::string, true> output_arrays_flag(
|
||||
"output-arrays", llvm::cl::desc("List of output tensors"),
|
||||
llvm::cl::location(output_arrays_string), llvm::cl::init(""));
|
||||
|
||||
ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex";
|
||||
|
||||
// Use initial buffer size in flatbuffer builder to be same as the initial size
|
||||
@ -172,7 +165,7 @@ constexpr size_t kInitialBufferSize = 10240;
|
||||
// `isSigned` is set to false for other types.
|
||||
static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
||||
bool is_signed = true) {
|
||||
if (!is_signed && type.isInteger(8)) {
|
||||
if (!is_signed && type.isSignlessInteger(8)) {
|
||||
return tflite::TensorType_UINT8;
|
||||
}
|
||||
if (!is_signed) {
|
||||
|
@ -27,7 +27,5 @@ extern bool emit_custom_ops;
|
||||
extern bool lower_tensor_list_ops;
|
||||
// The flag to control whether debug info gets stripped on export.
|
||||
extern bool strip_debug_info;
|
||||
// The flag to control the output array info of tflite graph.
|
||||
extern std::string output_arrays_string;
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Matchers.h" // TF:llvm-project
|
||||
@ -275,7 +275,7 @@ Attribute ConstFoldBinaryOp(
|
||||
return ConstFoldBinaryOp<FloatAttr>(result_type, operands[0], operands[1],
|
||||
float_calculate, is_commutative);
|
||||
|
||||
if (elemType.isa<IntegerType>())
|
||||
if (elemType.isSignlessInteger())
|
||||
return ConstFoldBinaryOp<IntegerAttr>(result_type, operands[0], operands[1],
|
||||
int_calculate, is_commutative);
|
||||
|
||||
@ -723,12 +723,11 @@ static LogicalResult Verify(PackOp op) {
|
||||
}
|
||||
|
||||
// Make sure all inputs have the same shape and element type.
|
||||
// TODO(rahulsp): Simplify once b/135032064 is fixed.
|
||||
for (Value operand : op.getOperands()) {
|
||||
auto other_type = operand.getType().cast<ShapedType>();
|
||||
if (input_type != other_type)
|
||||
// TODO(b/135032063): Simplify once fixed.
|
||||
for (Type operand_type : op.getOperandTypes()) {
|
||||
if (failed(mlir::verifyCompatibleShape(input_type, operand_type)))
|
||||
return op.emitOpError("operands should be of the same type. got ")
|
||||
<< input_type << ", " << other_type;
|
||||
<< input_type << ", " << operand_type;
|
||||
}
|
||||
|
||||
return success();
|
||||
@ -1561,7 +1560,7 @@ OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
|
||||
limit_tensor.getType().getRank() == 0 &&
|
||||
delta_tensor.getType().getRank() == 0);
|
||||
Type elem_type = getType().cast<ShapedType>().getElementType();
|
||||
if (elem_type.isa<IntegerType>()) {
|
||||
if (elem_type.isSignlessInteger()) {
|
||||
auto start_attr = start_tensor.getValue<IntegerAttr>({});
|
||||
auto limit_attr = limit_tensor.getValue<IntegerAttr>({});
|
||||
auto delta_attr = delta_tensor.getValue<IntegerAttr>({});
|
||||
@ -1663,7 +1662,7 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
||||
// Do not try to fold elements attr of a quant type because
|
||||
// DenseElementsAttr does not support it.
|
||||
if (!getType().cast<ShapedType>().getElementType().isIntOrFloat())
|
||||
if (!getType().cast<ShapedType>().getElementType().isSignlessIntOrFloat())
|
||||
return nullptr;
|
||||
|
||||
assert(perm_tensor.getType().getRank() == 1);
|
||||
|
@ -1656,7 +1656,7 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
|
||||
let hasOptions = 0;
|
||||
}
|
||||
|
||||
def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, SameOperandsAndResultsScale]> {
|
||||
def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect]> {
|
||||
let summary = "Mean operator";
|
||||
|
||||
let description = [{
|
||||
@ -2482,11 +2482,11 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, SameOperandsAndResultsScale,
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$input,
|
||||
TFL_TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8, TFL_Str]>:$input,
|
||||
TFL_I32OrI64Tensor:$multiples);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8]>:$output);
|
||||
TFL_TensorOf<[F32, I1, I32, I64, TFL_Uint8, QUI8, TFL_Str]>:$output);
|
||||
|
||||
let hasOptions = 0;
|
||||
}
|
||||
|
@ -63,6 +63,41 @@ const char kDetectionPostProcessOp[] =
|
||||
"'detections_per_class' type: 'int' default_value { i : 100 }} attr { "
|
||||
"name: 'use_regular_nms' type: 'bool' default_value { b : false }}";
|
||||
|
||||
const char kUnidirectionalSequenceLstmOp[] =
|
||||
"name: 'UnidirectionalSequenceLstm' input_arg: {name: 'Input' type: "
|
||||
"DT_FLOAT} input_arg: { name: 'InputToInputWeights' type: DT_FLOAT } "
|
||||
"input_arg: { name: 'InputToForgetWeights' type: DT_FLOAT } input_arg: { "
|
||||
"name: 'InputToCellWeights' type: DT_FLOAT} input_arg: { name: "
|
||||
"'InputToOutputWeights' type: DT_FLOAT } input_arg: { name: "
|
||||
"'RecurrentToInputWeights' type: DT_FLOAT} input_arg: { name: "
|
||||
"'RecurrentToForgetWeights' type: DT_FLOAT} input_arg: { name: "
|
||||
"'RecurrentToCellWeights' type: DT_FLOAT } input_arg: { name: "
|
||||
"'RecurrentToOutputWeights' type: DT_FLOAT } input_arg: { name: "
|
||||
"'CellToInputWeights' type: DT_FLOAT} input_arg: { name: "
|
||||
"'CellToForgetWeights' type: DT_FLOAT } input_arg: { name: "
|
||||
"'CellToOutputWeights' type: DT_FLOAT } input_arg: { name: 'InputGateBias' "
|
||||
"type: DT_FLOAT } input_arg: { name: 'ForgetGateBias' type: DT_FLOAT } "
|
||||
"input_arg: { name: 'kCellGateBias' type: DT_FLOAT } input_arg: { name: "
|
||||
"'OutputGateBias' type: DT_FLOAT } input_arg: { name: 'ProjectionWeights' "
|
||||
"type: DT_FLOAT } input_arg: { name: 'ProjectionBias' type: DT_FLOAT } "
|
||||
"input_arg: { name: 'InputActivationState' type: DT_FLOAT} input_arg: { "
|
||||
"name: 'InputCellStateTensor' type: DT_FLOAT } "
|
||||
"output_arg: { name: 'Concat' type: DT_FLOAT} "
|
||||
"output_arg: { name: "
|
||||
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: DT_FLOAT} "
|
||||
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
|
||||
|
||||
const char kUnidirectionalSequenceRnnOp[] =
|
||||
"name: 'UnidirectionalSequenceRnn' input_arg: {name: 'Input' type: "
|
||||
"DT_FLOAT} input_arg: { name: 'Weights' type: DT_FLOAT } "
|
||||
"input_arg: { name: 'RecurrentWeights' type: DT_FLOAT } input_arg: { "
|
||||
"name: 'Bias' type: DT_FLOAT} "
|
||||
"input_arg: { name: 'HiddenState' type: DT_FLOAT} "
|
||||
"output_arg: { name: "
|
||||
"'LastState' type: DT_FLOAT } output_arg: { name: 'Output' type: "
|
||||
"DT_FLOAT} "
|
||||
"attr : { name: '_tflite_input_indices' type: 'list(int)'}";
|
||||
|
||||
// Converts the toco::IODataType to tensorflow::DataType. Only contains the
|
||||
// conversion mapping for constants defined in TFLite Python API.
|
||||
DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
||||
@ -260,6 +295,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
|
||||
toco_flags.custom_opdefs().end());
|
||||
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
|
||||
extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
|
||||
extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
|
||||
TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps(extra_tf_opdefs));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
|
@ -25,7 +25,7 @@ limitations under the License.
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/AffineExpr.h" // TF:llvm-project
|
||||
#include "mlir/IR/AffineMap.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
|
@ -61,11 +61,9 @@ TfLiteStatus QuantizeModel(
|
||||
std::string serialized_model(
|
||||
reinterpret_cast<const char*>(input_builder.GetBufferPointer()),
|
||||
input_builder.GetSize());
|
||||
std::vector<std::string> output_arrays_order;
|
||||
|
||||
OwningModuleRef module =
|
||||
tflite::FlatBufferToMlir(serialized_model, &context,
|
||||
UnknownLoc::get(&context), output_arrays_order);
|
||||
OwningModuleRef module = tflite::FlatBufferToMlir(serialized_model, &context,
|
||||
UnknownLoc::get(&context));
|
||||
if (!module) {
|
||||
error_reporter->Report("Couldn't import flatbuffer to MLIR.");
|
||||
return kTfLiteError;
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
@ -191,7 +191,7 @@ struct QuantizationPattern : public RewritePattern {
|
||||
auto ele_type = operand.getType().cast<TensorType>().getElementType();
|
||||
if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
|
||||
inputs.push_back(op_inst.input());
|
||||
} else if (ele_type.isa<IntegerType>()) {
|
||||
} else if (ele_type.isSignlessInteger()) {
|
||||
// If the operand is an integer tensor, then it doesn't require the
|
||||
// DQ op in the pattern.
|
||||
inputs.push_back(operand);
|
||||
@ -225,7 +225,7 @@ struct QuantizationPattern : public RewritePattern {
|
||||
auto user = llvm::cast<Q>(*result.user_begin());
|
||||
outputs_replaced.insert({user.output(), enumerated_result.index()});
|
||||
output_types.push_back(user.getType());
|
||||
} else if (result_ele_type.template isa<IntegerType>()) {
|
||||
} else if (result_ele_type.isSignlessInteger()) {
|
||||
// If the result is an integer tensor, then it doesn't require the
|
||||
// D op in the pattern.
|
||||
outputs_replaced.insert({result, enumerated_result.index()});
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
|
@ -48,11 +48,9 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model,
|
||||
std::string serialized_model(
|
||||
reinterpret_cast<const char*>(input_builder.GetBufferPointer()),
|
||||
input_builder.GetSize());
|
||||
std::vector<std::string> output_arrays_order;
|
||||
|
||||
OwningModuleRef module =
|
||||
tflite::FlatBufferToMlir(serialized_model, &context,
|
||||
UnknownLoc::get(&context), output_arrays_order);
|
||||
OwningModuleRef module = tflite::FlatBufferToMlir(serialized_model, &context,
|
||||
UnknownLoc::get(&context));
|
||||
if (!module) {
|
||||
error_reporter->Report("Couldn't import flatbuffer to MLIR.");
|
||||
return kTfLiteError;
|
||||
|
@ -27,6 +27,20 @@ func @testDilatedConvWithNonZeroSTBPadding(%arg0: tensor<1x128x128x3xf32>, %arg1
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128x8xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConvWithNonTrivialDilations(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
%1 = "tf.Conv2D"(%0, %arg2) {padding = "VALID", dilations = [1, 2, 2, 1], strides = [1, 1, 1, 1]} : (tensor<4x68x68x3xf32>, tensor<5x5x3x8xf32>) -> tensor<4x64x64x8xf32>
|
||||
%2 = "tf.BatchToSpaceND"(%1, %cst, %arg1) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32>
|
||||
return %2 : tensor<1x128x128x8xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithNonTrivialDilations
|
||||
// CHECK: [[STB:%.*]] = "tf.SpaceToBatchND"
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
|
||||
// CHECK-NEXT: return [[RESULT]]
|
||||
}
|
||||
|
||||
func @testDilatedDepthWiseConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32>
|
||||
@ -104,7 +118,7 @@ func @testDilatedDepthWiseConvWithBiasAdd(%arg0: tensor<1x128x128x3xf32>, %arg1:
|
||||
|
||||
func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = constant dense<3> : tensor<i32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
@ -115,7 +129,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithExpandSqueeze1
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||
@ -125,7 +139,7 @@ func @testDilatedConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
||||
|
||||
func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = constant dense<3> : tensor<i32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
@ -136,7 +150,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %
|
||||
|
||||
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze1
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||
@ -146,7 +160,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze1(%arg0: tensor<1x128x128xf32>, %
|
||||
|
||||
func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = constant dense<3> : tensor<i32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
|
||||
@ -157,7 +171,7 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithExpandSqueeze2
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||
@ -167,7 +181,7 @@ func @testDilatedConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
||||
|
||||
func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<?xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = constant dense<3> : tensor<i32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x?x?xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x?x?xf32>, tensor<i32>) -> tensor<4x?x?x1xf32>
|
||||
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x?x?x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x?x?x1xf32>
|
||||
@ -178,7 +192,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %
|
||||
|
||||
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze2
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<?xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||
@ -188,7 +202,7 @@ func @testDilatedDepthWiseConvWithExpandSqueeze2(%arg0: tensor<1x128x128xf32>, %
|
||||
|
||||
func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = constant dense<3> : tensor<i32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
@ -200,7 +214,7 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithExpandSqueeze3
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||
@ -210,7 +224,7 @@ func @testDilatedConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: ten
|
||||
|
||||
func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = constant dense<3> : tensor<i32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.DepthwiseConv2dNative"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
@ -222,10 +236,29 @@ func @testDilatedDepthWiseConvWithExpandSqueeze3(%arg0: tensor<1x128x128xf32>, %
|
||||
|
||||
// CHECK-LABEL: testDilatedDepthWiseConvWithExpandSqueeze3
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<1x128x128xf32>, [[PADDING:%.*]]: tensor<2x2xi32>, [[FILTER:%.*]]: tensor<5x5x1x1xf32>, [[BIAS:%.*]]: tensor<128xf32>)
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = constant dense<3> : tensor<i32>
|
||||
// CHECK-NEXT: [[AXIS:%.*]] = "tf.Const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"([[INPUT]], [[AXIS]]) : (tensor<1x128x128xf32>, tensor<i32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.DepthwiseConv2dNative"([[EXPAND]], [[FILTER]]) {dilations = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x128x128x1xf32>, tensor<5x5x1x1xf32>) -> tensor<1x128x128x1xf32>
|
||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"([[CONV]]) {squeeze_dims = [3]} : (tensor<1x128x128x1xf32>) -> tensor<1x128x128xf32>
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BiasAdd"([[SQUEEZE]], [[BIAS]]) : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32>
|
||||
// CHECK-NEXT: return [[RESULT]] : tensor<1x128x128xf32>
|
||||
}
|
||||
|
||||
func @testDilatedConvWithDifferentExpandSqueezeAxis(%arg0: tensor<1x128x128xf32>, %arg1: tensor<2x2xi32>, %arg2: tensor<5x5x1x1xf32>, %arg3: tensor<128xf32>) -> tensor<1x128x128x1xf32> {
|
||||
%cst = constant dense<[2, 2]> : tensor<2xi32>
|
||||
%cst_0 = "tf.Const"() { value = dense<3> : tensor<i32> } : () -> tensor<i32>
|
||||
%0 = "tf.SpaceToBatchND"(%arg0, %cst, %arg1) : (tensor<1x128x128xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68xf32>
|
||||
%1 = "tf.ExpandDims"(%0, %cst_0) : (tensor<4x68x68xf32>, tensor<i32>) -> tensor<4x68x68x1xf32>
|
||||
%2 = "tf.Conv2D"(%1, %arg2) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<4x68x68x1xf32>, tensor<5x5x1x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%3 = "tf.Squeeze"(%2) {squeeze_dims = [2]} : (tensor<4x64x64x1xf32>) -> tensor<4x64x64x1xf32>
|
||||
%4 = "tf.BatchToSpaceND"(%3, %cst, %arg1) : (tensor<4x64x64x1xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x1xf32>
|
||||
return %4 : tensor<1x128x128x1xf32>
|
||||
|
||||
// CHECK-LABEL: testDilatedConvWithDifferentExpandSqueezeAxis
|
||||
// CHECK: [[STB:%.*]] = "tf.SpaceToBatchND"
|
||||
// CHECK-NEXT: [[EXPAND:%.*]] = "tf.ExpandDims"
|
||||
// CHECK-NEXT: [[CONV:%.*]] = "tf.Conv2D"
|
||||
// CHECK-NEXT: [[SQUEEZE:%.*]] = "tf.Squeeze"
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = "tf.BatchToSpaceND"
|
||||
// CHECK-NEXT: return [[RESULT]]
|
||||
}
|
||||
|
@ -0,0 +1,13 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -input-arrays=squared_difference --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
|
||||
// Tests -input-arrays flag.
|
||||
|
||||
func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
|
||||
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
%1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference")
|
||||
%2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
|
||||
return %2 : tensor<4xf32>
|
||||
|
||||
// CHECK-LABEL: main
|
||||
// CHECK-NOT: tfl.squared_difference
|
||||
// CHECK: tfl.mul %[[CONST:.*]], %arg0
|
||||
}
|
@ -0,0 +1,15 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
|
||||
// Ensure lstm roundtrip exactly
|
||||
|
||||
func @main(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>) -> tensor<4 x f32> {
|
||||
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %24 : tensor<4xf32>
|
||||
// CHECK-LABEL: main
|
||||
// seperate lines since there is no region for this op. third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252
|
||||
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( {
|
||||
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK: return %[[RES0]]
|
||||
|
||||
}
|
@ -123,6 +123,17 @@ func @softmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
// CHECK: "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
}
|
||||
|
||||
func @softplus(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
return %0 : tensor<8x16xf32>
|
||||
|
||||
// CHECK-LABEL: softplus
|
||||
// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK-NEXT: %[[exp:.*]] = "tfl.exp"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
// CHECK-NEXT: %[[add:.*]] = "tfl.add"(%[[exp]], %[[cst]]) {fused_activation_function = "NONE"} : (tensor<8x16xf32>, tensor<f32>) -> tensor<8x16xf32>
|
||||
// CHECK-NEXT: %[[log:.*]] = "tfl.log"(%[[add]]) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
}
|
||||
|
||||
func @fakeQuantArgsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> {
|
||||
%0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {min = -0.1 : f32, max = 0.2 : f32, num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
|
||||
return %0 : tensor<8x8x8x8xf32>
|
||||
@ -1453,3 +1464,19 @@ func @LstmWithProjection(%arg: tensor<28x1x16xf32>) -> (tensor<28x1x8xf32>) {
|
||||
// CHECK: [[VAL_15:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_8]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_9]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_12]], [[VAL_14]], [[VAL_13]], [[VAL_11]], [[VAL_14]], [[VAL_14]], [[VAL_14]], [[VAL_14]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<28x1x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x16xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, tensor<16x8xf32>, none, none, none, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<8x16xf32>, none, tensor<1x8xf32>, tensor<1x16xf32>, none, none, none, none) -> tensor<28x1x8xf32>
|
||||
// CHECK: return [[VAL_15]] : tensor<28x1x8xf32>
|
||||
// CHECK: }
|
||||
|
||||
func @UnidirectionalRnn(%arg: tensor<28x1x28xf32>) -> (tensor<28x1x28xf32>) {
|
||||
%1 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<28x28xf32>} : () -> tensor<28x28xf32>
|
||||
%2 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<28xf32>} : () -> tensor<28xf32>
|
||||
%3 = "tf.Const"() {device = "", dtype = f32, value = dense<0.000000e+00>: tensor<1x28xf32>} : () -> tensor<1x28xf32>
|
||||
%4:2 = "tf.UnidirectionalSequenceRnn"(%arg, %1, %1, %2, %3) {_tflite_input_indices = [0, 1, 2, 3, 4], device = ""} : (tensor<28x1x28xf32>, tensor<28x28xf32>, tensor<28x28xf32>, tensor<28xf32>, tensor<1x28xf32>) -> (tensor<*xf32>, tensor<28x1x28xf32>)
|
||||
return %4#1 : tensor<28x1x28xf32>
|
||||
}
|
||||
|
||||
// CHECK: func @UnidirectionalRnn([[VAL_0:%.*]]: tensor<28x1x28xf32>) -> tensor<28x1x28xf32> {
|
||||
// CHECK: [[VAL_1:%.*]] = constant dense<0.000000e+00> : tensor<28x28xf32>
|
||||
// CHECK: [[VAL_2:%.*]] = constant dense<0.000000e+00> : tensor<28xf32>
|
||||
// CHECK: [[VAL_3:%.*]] = constant dense<0.000000e+00> : tensor<1x28xf32>
|
||||
// CHECK: [[VAL_4:%.*]] = "tfl.unidirectional_sequence_rnn"([[VAL_0]], [[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_3]]) {fused_activation_function = "TANH", time_major = true} : (tensor<28x1x28xf32>, tensor<28x28xf32>, tensor<28x28xf32>, tensor<28xf32>, tensor<1x28xf32>) -> tensor<28x1x28xf32>
|
||||
// CHECK: return [[VAL_4]] : tensor<28x1x28xf32>
|
||||
// CHECK: }
|
||||
|
@ -878,6 +878,14 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||
|
||||
// -----
|
||||
|
||||
func @packUnranked(%arg0: tensor<2xi32>, %arg1: tensor<*xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
|
||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<*xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x4x2xi32> {
|
||||
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32}
|
||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x4x2xi32>
|
||||
|
@ -154,7 +154,7 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf3
|
||||
// -----
|
||||
|
||||
module {
|
||||
func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} {
|
||||
func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
|
||||
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
|
||||
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
|
||||
@ -165,7 +165,7 @@ func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10x
|
||||
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x?x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @inference_standard_lstm_7410([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<?x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.signature.is_stateful} {
|
||||
// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x?x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
|
||||
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
|
||||
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||
@ -181,7 +181,46 @@ func @inference_standard_lstm_7410(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10x
|
||||
// CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
|
||||
// CHECK: [[VAL_19:%.*]] = constant unit
|
||||
// CHECK: [[VAL_20:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) ( {
|
||||
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
|
||||
// CHECK: return [[VAL_21:%.*]] : tensor<?x8x10xf32>
|
||||
|
||||
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x?x10xf32>
|
||||
// CHECK: return [[VAL_21:%.*]] : tensor<8x?x10xf32>
|
||||
// CHECK: }
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
|
||||
%1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
|
||||
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32>
|
||||
%3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
|
||||
%4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor<?x10xf32>) -> tensor<8x8x10xf32>
|
||||
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
|
||||
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> tensor<8x8x10xf32> attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
|
||||
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
|
||||
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<3xi64>) -> tensor<8x8x8xf32>
|
||||
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||
// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
|
||||
// CHECK: [[VAL_10:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||
// CHECK: [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
|
||||
// CHECK: [[VAL_12:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
|
||||
// CHECK: [[VAL_15:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
|
||||
// CHECK: [[VAL_18:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
|
||||
// CHECK: [[VAL_21:%.*]] = constant unit
|
||||
// CHECK: [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( {
|
||||
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
|
||||
// CHECK: [[VAL_23:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
|
||||
// CHECK: [[VAL_24:%.*]] = "tf.Transpose"([[VAL_25:%.*]], [[VAL_23]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
|
||||
// CHECK: return [[VAL_24]] : tensor<8x8x10xf32>
|
||||
// CHECK: }
|
||||
}
|
||||
|
@ -622,3 +622,16 @@ func @QuantizeSharedBiases2(
|
||||
// CHECK: %{{.*}} = tfl.add %{{.*}}, %[[dq_0]]
|
||||
// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]])
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ReturnQuantizedResult
|
||||
func @ReturnQuantizedResult(%arg0: tensor<1x224x224x3xf32>, %arg1: tensor<32x3x3x3xf32>, %arg2: tensor<32xf32>) -> (tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>) {
|
||||
%0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
|
||||
%1 = "tfl.quantize"(%0) {qtype = tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>
|
||||
%2 = "tfl.dequantize"(%1) : (tensor<1x112x112x32x!quant.uniform<u8:f32, 0.023528476789885875>>) -> (tensor<1x112x112x32xf32>)
|
||||
return %0, %2 : tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>
|
||||
|
||||
// CHECK: %[[dw:.*]] = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2)
|
||||
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[dw]])
|
||||
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]])
|
||||
// CHECK: return %[[dq]], %[[dq]]
|
||||
}
|
||||
|
@ -1,22 +1,28 @@
|
||||
// Test to verify loop outlining.
|
||||
|
||||
// RUN: tf-opt --split-input-file --tfl-while-loop-outline %s | FileCheck %s --dump-input-on-failure
|
||||
// Check that while loop outlining is nop if re-ran.
|
||||
// RUN: tf-opt --tfl-while-loop-outline %s -o %t1
|
||||
// RUN: tf-opt --tfl-while-loop-outline %t1 -o %t2
|
||||
// RUN: diff %t1 %t2
|
||||
|
||||
// CHECK-LABEL: func @while
|
||||
func @while() -> tensor<1xf32>
|
||||
attributes {tf.entry_function = {outputs = "result"}} {
|
||||
%cst = constant dense<1> : tensor<i32> loc("dec")
|
||||
%arg0 = constant dense<5> : tensor<i32> loc("N")
|
||||
%arg1 = constant dense<3.0> : tensor<1xf32> loc("val")
|
||||
%0:2 = "tfl.while"(%arg0, %arg1) ( {
|
||||
%cst0 = constant dense<5> : tensor<i32> loc("N")
|
||||
%cst1 = constant dense<3.0> : tensor<1xf32> loc("val")
|
||||
%0:2 = "tfl.while"(%cst0, %cst1) ( {
|
||||
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
|
||||
// CHECK: call @WhileOp_cond
|
||||
// CHECK-SAME: (tensor<*xi32>, tensor<*xf32>, tensor<i32>)
|
||||
%cst_0 = constant dense<0> : tensor<i32>
|
||||
%1 = "tfl.greater"(%arg2, %cst_0) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
|
||||
"tfl.yield"(%1) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^bb0(%arg2: tensor<*xi32>, %arg3: tensor<*xf32>):
|
||||
// CHECK: call @WhileOp_body
|
||||
// CHECK-SAME: (tensor<*xi32>, tensor<*xf32>, tensor<i32>)
|
||||
%1 = "tfl.sub"(%arg2, %cst) {fused_activation_function = "NONE"} :
|
||||
(tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%2 = tfl.add %arg3, %arg3 {fused_activation_function = "NONE"} : tensor<*xf32>
|
||||
@ -32,6 +38,52 @@ func @while() -> tensor<1xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @while2
|
||||
// Verify that while body//cond with implicitly captured values result in changing while operands/results.
|
||||
func @while2() -> tensor<1xf32> attributes {tf.entry_function = {outputs = "result"}} {
|
||||
%cst = constant dense<1> : tensor<i32>
|
||||
%cst_0 = constant dense<5> : tensor<i32>
|
||||
%cst_1 = constant dense<3.000000e+00> : tensor<1xf32>
|
||||
// Verifies 3 operands post outlining.
|
||||
// CHECK: "tfl.while"({{.*}}, {{.*}}, {{.*}}) (
|
||||
%0:2 = "tfl.while"(%cst_0, %cst_1) ( {
|
||||
^bb0(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>): // no predecessors
|
||||
// CHECK: call @WhileOp_cond
|
||||
// CHECK-SAME: (tensor<*xi32>, tensor<*xf32>, tensor<i32>)
|
||||
%1 = call @WhileOp_cond(%arg0, %arg1, %cst) : (tensor<*xi32>, tensor<*xf32>, tensor<i32>) -> tensor<i1>
|
||||
"tfl.yield"(%1) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^bb0(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>): // no predecessors
|
||||
// CHECK: call @WhileOp_body
|
||||
// CHECK-SAME: (tensor<*xi32>, tensor<*xf32>, tensor<i32>)
|
||||
%1:3 = call @WhileOp_body(%arg0, %arg1, %cst) : (tensor<*xi32>, tensor<*xf32>, tensor<i32>) -> (tensor<*xi32>, tensor<*xf32>, tensor<i32>)
|
||||
"tfl.yield"(%1#0, %1#1) : (tensor<*xi32>, tensor<*xf32>) -> ()
|
||||
}) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>) loc("WhileOp")
|
||||
// CHECK: (tensor<i32>, tensor<1xf32>, tensor<i32>) ->
|
||||
// CHECK-SAME: (tensor<i32>, tensor<1xf32>, tensor<i32>)
|
||||
return %0#1 : tensor<1xf32>
|
||||
}
|
||||
|
||||
func @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> tensor<i1> attributes {sym_visibility = "private"} {
|
||||
%cst = constant dense<0> : tensor<i32>
|
||||
%0 = "tfl.greater"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
func @WhileOp_body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> (tensor<*xi32>, tensor<*xf32>, tensor<i32>) attributes {sym_visibility = "private"} {
|
||||
%0 = "tfl.sub"(%arg0, %arg2) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%1 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
|
||||
return %0, %1, %arg2 : tensor<*xi32>, tensor<*xf32>, tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @WhileOp_cond(
|
||||
// CHECK: tfl.greater
|
||||
// CHECK-LABEL: func @WhileOp_body(
|
||||
// CHECK: tfl.sub
|
||||
// CHECK: tfl.add
|
||||
|
||||
// -----
|
||||
|
||||
func @rnn(%arg0: tensor<4x4x3xf32> {tf.device = "/device:CPU:0"}) -> tensor<4x?x2xf32> attributes {tf.entry_function = {inputs = "Placeholder", outputs = "rnn/transpose_1"}} {
|
||||
%cst = constant dense<0.000000e+00> : tensor<4x2xf32>
|
||||
%cst_0 = constant dense<0.000000e+00> : tensor<8xf32>
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -80,6 +81,17 @@ class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
|
||||
template <typename Conv2dOpTy>
|
||||
PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
Conv2dOpTy op, PatternRewriter& rewriter) const {
|
||||
// Make sure Conv2D has 'VALID' padding.
|
||||
if (op.template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
|
||||
return Pattern::matchFailure();
|
||||
}
|
||||
// Make sure dilations are all ones if set.
|
||||
const ArrayAttr& dilations =
|
||||
op.template getAttrOfType<ArrayAttr>("dilations");
|
||||
if (dilations && !TFIntListIsAllOnes(dilations)) {
|
||||
return Pattern::matchFailure();
|
||||
}
|
||||
|
||||
// Check if the ConvOp is preceded by a `Expand` op and succeeded by a
|
||||
// `Squeeze` op.
|
||||
Operation* prev_op = op.getOperation()->getPrevNode();
|
||||
@ -90,6 +102,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
|
||||
TF::ExpandDimsOp expand_op;
|
||||
TF::SqueezeOp squeeze_op;
|
||||
int64_t expand_axis;
|
||||
// Expand + Squeeze op.
|
||||
if (llvm::isa<TF::ExpandDimsOp>(prev_op)) {
|
||||
if (!llvm::isa<TF::SqueezeOp>(next_op)) {
|
||||
@ -99,6 +112,22 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
expand_op = llvm::cast<TF::ExpandDimsOp>(prev_op);
|
||||
squeeze_op = llvm::cast<TF::SqueezeOp>(next_op);
|
||||
|
||||
// Make sure that the axis in `expand_op` is constant.
|
||||
if (auto const_op =
|
||||
llvm::dyn_cast<TF::ConstOp>(expand_op.dim().getDefiningOp())) {
|
||||
expand_axis =
|
||||
(*const_op.value().cast<DenseElementsAttr>().getIntValues().begin())
|
||||
.getSExtValue();
|
||||
} else {
|
||||
return Pattern::matchFailure();
|
||||
}
|
||||
// Make sure that the `squeeze_dims` is equal to `expand_axis`.
|
||||
auto squeeze_dims = squeeze_op.squeeze_dims();
|
||||
if (squeeze_dims.size() != 1 ||
|
||||
squeeze_dims[0].cast<IntegerAttr>().getInt() != expand_axis) {
|
||||
return Pattern::matchFailure();
|
||||
}
|
||||
|
||||
// Update previous/next op pointer.
|
||||
prev_op = prev_op->getPrevNode();
|
||||
if (!prev_op) return Pattern::matchFailure();
|
||||
@ -108,10 +137,14 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
|
||||
// SpaceToBatchND op.
|
||||
if (!llvm::isa<TF::SpaceToBatchNDOp>(prev_op)) return Pattern::matchFailure();
|
||||
// TODO(b/149936532): Check `padding` input, currently ignored.
|
||||
TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(prev_op);
|
||||
|
||||
// Pad op.
|
||||
TF::PadOp pad_op;
|
||||
// TODO(b/149936532): Currently we just ignore the PadOp. However note that
|
||||
// in real scenarios this may not always be correct: user can put a PadOp here
|
||||
// with non-trivial consequences.
|
||||
if (llvm::isa<TF::PadOp>(next_op)) {
|
||||
pad_op = llvm::cast<TF::PadOp>(next_op);
|
||||
next_op = next_op->getNextNode();
|
||||
@ -119,6 +152,7 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
}
|
||||
|
||||
// BatchToSpaceND + BiasAdd.
|
||||
// TODO(b/149936532): Check the `crops` input, currently ignored.
|
||||
TF::BatchToSpaceNDOp bts_op;
|
||||
TF::BiasAddOp biasadd_op;
|
||||
bool final_op_is_bts = true;
|
||||
@ -146,14 +180,10 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
if (!dilations_attr.hasValue()) return Pattern::matchFailure();
|
||||
op.setAttr("dilations", dilations_attr.getValue());
|
||||
|
||||
// Here we need to set the correct padding for Conv op. In TF, the conv op
|
||||
// inserted after 'SpaceToBatch' always has 'VALID' padding. This might
|
||||
// become a problem here if the original Conv op has 'SAME' padding. When
|
||||
// the original conv has 'SAME' padding, TF will set a non-zero padding for
|
||||
// the 'SpaceToBatch' op, so we rely on this information to check if we need
|
||||
// to change the padding from 'VALID' to 'SAME' (a.k.a when we see non-zero
|
||||
// values in `stb_op.paddings`, we change the current Conv's padding to
|
||||
// 'SAME').
|
||||
// Padding is set to 'SAME' when `stb_op` has non-zero paddings.
|
||||
// TODO(b/149936532): This assumption only holds when the input width & height
|
||||
// is multiple of dilation width & height. We should fix it in order to
|
||||
// support other use cases.
|
||||
auto stb_paddings = stb_op.paddings();
|
||||
ElementsAttr stb_paddings_attr;
|
||||
if (matchPattern(stb_paddings, m_Constant(&stb_paddings_attr))) {
|
||||
@ -175,7 +205,8 @@ PatternMatchResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
|
||||
auto input_shape = stb_op.input().getType().cast<ShapedType>().getShape();
|
||||
SmallVector<int64_t, 4> expand_shape(input_shape.begin(),
|
||||
input_shape.end());
|
||||
expand_shape.push_back(1);
|
||||
expand_shape.insert(expand_shape.begin() + expand_axis, 1);
|
||||
|
||||
auto expand_result_type = RankedTensorType::get(
|
||||
expand_shape, getElementTypeOrSelf(stb_op.input()));
|
||||
expand_op.getResult().setType(expand_result_type);
|
||||
@ -208,7 +239,7 @@ ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
|
||||
ElementsAttr stb_bs_attr, bts_bs_attr;
|
||||
if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
|
||||
!matchPattern(bts_block_shape, m_Constant(&bts_bs_attr))) {
|
||||
// Returns failure status if block shape is not a constant.
|
||||
// Returns failure status if block_shape is not a constant.
|
||||
return {};
|
||||
}
|
||||
// Check that the block_shape of `stb_op` and `bts_op` are equal.
|
||||
@ -217,9 +248,8 @@ ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
|
||||
if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {};
|
||||
}
|
||||
|
||||
// TODO(haoliang): support 1-D dilated conv.
|
||||
// Set dilation factor.
|
||||
if (stb_bs_attr.getNumElements() < 2) return {};
|
||||
|
||||
int dilation_h_factor =
|
||||
stb_bs_attr.getValue({0}).cast<IntegerAttr>().getInt();
|
||||
int dilation_w_factor =
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
// TFLite legalization patterns
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
@ -167,6 +167,7 @@ def : Pat<(TF_SigmoidOp $arg), (TFL_LogisticOp $arg)>;
|
||||
def : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>;
|
||||
def : Pat<(TF_SliceOp $input, $begin, $size), (TFL_SliceOp $input, $begin, $size)>;
|
||||
def : Pat<(TF_SoftmaxOp $arg), (TFL_SoftmaxOp $arg, ConstF32Attr<"1.0">)>;
|
||||
def : Pat<(TF_SoftplusOp F32Tensor:$arg0), (TFL_LogOp (TFL_AddOp (TFL_ExpOp $arg0), (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">), TFL_AF_None))>;
|
||||
def : Pat<(TF_SqueezeOp $arg, $squeeze_dims), (TFL_SqueezeOp $arg, $squeeze_dims)>;
|
||||
def : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
|
||||
def : Pat<(TF_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, $perm)>;
|
||||
@ -340,7 +341,7 @@ def : Pat<(TF_MatrixDiagOp $diagonal), (TFL_MatrixDiagOp $diagonal)>;
|
||||
class I32VectorElementsAttr<int len> : ElementsAttrBase<
|
||||
CPred<"$_self.isa<DenseIntElementsAttr>() &&"
|
||||
"$_self.cast<DenseIntElementsAttr>().getType()."
|
||||
"getElementType().isInteger(32)">,
|
||||
"getElementType().isSignlessInteger(32)">,
|
||||
"32-bit int elements attribute of shape [" # len # "]"> {
|
||||
|
||||
let storageType = [{ DenseIntElementsAttr }];
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
@ -64,6 +65,7 @@ using xla::Status;
|
||||
using xla::StatusOr;
|
||||
|
||||
constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm";
|
||||
constexpr char kUnidirectionalSequenceRnn[] = "tf.UnidirectionalSequenceRnn";
|
||||
constexpr char kTfLiteInputIndices[] = "_tflite_input_indices";
|
||||
|
||||
// Legalize operations in functions.
|
||||
@ -253,7 +255,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
|
||||
|
||||
ShapedType shape_type = shape.getType().cast<ShapedType>();
|
||||
// The tfl reshape's #2 operand needs to i32 tensor type, so we have to cast.
|
||||
if (!shape_type.getElementType().isInteger(32)) {
|
||||
if (!shape_type.getElementType().isSignlessInteger(32)) {
|
||||
auto new_shape = shape_type.getShape();
|
||||
IntegerType new_ele_type = rewriter.getIntegerType(32);
|
||||
ShapedType new_type = RankedTensorType::get(new_shape, new_ele_type);
|
||||
@ -632,6 +634,66 @@ struct LegalizeUnidirectionalSequenceLstm : public RewritePattern {
|
||||
}
|
||||
};
|
||||
|
||||
// Legalize unidirectional seqeucen rnn.
|
||||
struct LegalizeUnidirectionalSequenceRnn : public RewritePattern {
|
||||
explicit LegalizeUnidirectionalSequenceRnn(MLIRContext* context)
|
||||
: RewritePattern(kUnidirectionalSequenceRnn, 1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation* op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
auto tflite_indices_attr =
|
||||
op->getAttrOfType<ArrayAttr>(kTfLiteInputIndices);
|
||||
if (!tflite_indices_attr) return matchFailure();
|
||||
|
||||
if (op->getNumOperands() != 5) {
|
||||
op->emitError()
|
||||
<< "We're expecting 5 inputs for UnidirectionalSequenceRNN, only "
|
||||
<< op->getNumOperands() << " provided";
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
if (op->getNumResults() != 2) {
|
||||
op->emitError()
|
||||
<< "We're expecting 2 inputs for UnidirectionalSequenceRNN, only "
|
||||
<< op->getNumResults() << " found";
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
// Populate inputs.
|
||||
// UnidirectionalSequenceRnn is expected to have 5 inputs, and none of them
|
||||
// are optional inputs.
|
||||
SmallVector<Value, 5> inputs;
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
inputs.push_back(op->getOperand(i));
|
||||
}
|
||||
|
||||
// Populate outputs.
|
||||
// UnidirectionalSequenceRnn should only have 1 output, and that is the
|
||||
// original ophint converted node's 2nd output.
|
||||
SmallVector<Type, 4> result_types;
|
||||
result_types.push_back(op->getOpResult(1).getType());
|
||||
|
||||
// Populate attributes.
|
||||
SmallVector<NamedAttribute, 2> attributes;
|
||||
// Activation will always be tanh.
|
||||
attributes.push_back(rewriter.getNamedAttr("fused_activation_function",
|
||||
rewriter.getStringAttr("TANH")));
|
||||
|
||||
// will always be time_majored.
|
||||
attributes.push_back(
|
||||
rewriter.getNamedAttr("time_major", rewriter.getBoolAttr(true)));
|
||||
|
||||
auto rnn_op = rewriter.create<TFL::UnidirectionalSequenceRNNOp>(
|
||||
op->getLoc(), result_types, inputs, attributes);
|
||||
|
||||
// Rewire the output.
|
||||
op->getResult(1).replaceAllUsesWith(rnn_op.getResult());
|
||||
op->erase();
|
||||
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
void LegalizeTF::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto* ctx = &getContext();
|
||||
@ -647,7 +709,8 @@ void LegalizeTF::runOnFunction() {
|
||||
ConvertTFReciprocalOp, ConvertTFRandomUniformOp>(ctx);
|
||||
|
||||
// Ophint python converter converted tf node pattern.
|
||||
patterns.insert<LegalizeUnidirectionalSequenceLstm>(ctx);
|
||||
patterns.insert<LegalizeUnidirectionalSequenceLstm,
|
||||
LegalizeUnidirectionalSequenceRnn>(ctx);
|
||||
applyPatternsGreedily(func, patterns);
|
||||
}
|
||||
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
// Converts TF While to TFL While with single call in body and cond.
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/None.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
|
@ -32,7 +32,7 @@ limitations under the License.
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
@ -335,8 +335,9 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type dtype = op.element_dtype();
|
||||
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
|
||||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
|
||||
dtype.isInteger(32) || dtype.isInteger(64))) {
|
||||
dtype.isInteger(1) || dtype.isSignlessInteger(8) ||
|
||||
dtype.isSignlessInteger(16) || dtype.isSignlessInteger(32) ||
|
||||
dtype.isSignlessInteger(64))) {
|
||||
op.emitError(
|
||||
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
|
||||
"integer or 16-bit/32-bit/64-bit float type during TF Lite "
|
||||
|
@ -31,7 +31,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Matchers.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
// This is the optimization pattern definition file for TensorFlow Lite.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
// This is the quantization pattern definition file for TensorFlow Lite.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||
|
||||
// Both Quantize and Dequantize ops have side effects, so we have to define
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "mlir/Dialect/QuantOps/QuantOps.h" // TF:llvm-project
|
||||
@ -115,6 +116,10 @@ class PrepareQuantizePass : public FunctionPass<PrepareQuantizePass> {
|
||||
}
|
||||
}
|
||||
|
||||
// Apply some sanity check and report some warnings for those don't follow
|
||||
// the best quantization practise. This also fixes some simple violations.
|
||||
void SanityCheckAndAdjustment(FuncOp func);
|
||||
|
||||
QuantizationSpecs quant_specs_;
|
||||
};
|
||||
|
||||
@ -184,13 +189,56 @@ bool PrepareQuantizePass::RemoveRedundantStats(FuncOp func) {
|
||||
return RemoveRedundantStatsOps(func, GetOpQuantSpec);
|
||||
}
|
||||
|
||||
static Value Quantized(Operation* user) {
|
||||
if (auto q = llvm::dyn_cast_or_null<quant::QuantizeCastOp>(user)) {
|
||||
if (auto dq = llvm::dyn_cast_or_null<quant::DequantizeCastOp>(
|
||||
*q.getResult().user_begin())) {
|
||||
return dq.getResult();
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) {
|
||||
// If an op output has two users: one of them is a quantize op and another
|
||||
// one is returned directly, we decide to return the quantized result instead,
|
||||
// so this op can be quantized. This is only applied on the returned result
|
||||
// because the error will not be accumulated.
|
||||
func.walk([&](ReturnOp ret) {
|
||||
int i = 0;
|
||||
for (Value returned : ret.operands()) {
|
||||
llvm::SmallVector<Value, 4> quantized;
|
||||
for (auto user : returned.getUsers()) {
|
||||
if (auto q = Quantized(user)) {
|
||||
quantized.push_back(q);
|
||||
}
|
||||
}
|
||||
if (quantized.size() == 1) {
|
||||
ret.setOperand(i, quantized.front());
|
||||
}
|
||||
i++;
|
||||
}
|
||||
});
|
||||
|
||||
// We prefer to placing quantization emulation ops on the results of the
|
||||
// concat ops.
|
||||
func.walk([&](ConcatenationOp concat) {
|
||||
if (concat.output().hasOneUse() &&
|
||||
Quantized(*concat.output().user_begin())) {
|
||||
return;
|
||||
}
|
||||
concat.emitWarning(
|
||||
"Missing quantization parameter on the output might introduce "
|
||||
"quantization error!");
|
||||
});
|
||||
}
|
||||
|
||||
using PrepareQuantStats =
|
||||
quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
|
||||
|
||||
void PrepareQuantizePass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
MLIRContext* ctx = func.getContext();
|
||||
|
||||
ConvertTFLQuantOpsToMlirQuantOps(func);
|
||||
|
||||
if (quant_specs_.post_training_quantization) {
|
||||
@ -220,6 +268,8 @@ void PrepareQuantizePass::runOnFunction() {
|
||||
}
|
||||
applyPatternsGreedily(func, patterns);
|
||||
|
||||
SanityCheckAndAdjustment(func);
|
||||
|
||||
// Finally, the quantization parameters can be propagated to the rest of the
|
||||
// values (tensors).
|
||||
ApplyQuantizationParamsPropagation(func, is_signed, disable_per_channel,
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
// This is the quantization pattern definition file for TensorFlow Lite.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||
|
||||
// Quantize attribute $0 by using quantization parameter from %1.
|
||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
|
@ -14,7 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Identifier.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Identifier.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
@ -52,25 +52,47 @@ class WhileOutlinePass : public mlir::ModulePass<WhileOutlinePass> {
|
||||
|
||||
tensorflow::OpOrArgLocNameMapper mapper_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::string WhileOutlinePass::GetName(Operation* op, StringRef suffix) {
|
||||
return (mapper_.GetUniqueName(op) + suffix).str();
|
||||
}
|
||||
|
||||
// Returns whether the WhileOp is already outlined (e.g., only consists of calls
|
||||
// to functions).
|
||||
static bool IsAlreadyOutlinedd(WhileOp while_op) {
|
||||
auto just_call = [](Region& region) {
|
||||
auto it = region.front().begin();
|
||||
if (!isa<CallOp>(*it)) return false;
|
||||
++it;
|
||||
if (!isa<YieldOp>(*it)) return false;
|
||||
return true;
|
||||
};
|
||||
return just_call(while_op.body()) && just_call(while_op.cond());
|
||||
}
|
||||
|
||||
void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
||||
OpBuilder builder(&getContext());
|
||||
// Colect external values used. Note: if an external value is also passed in
|
||||
// via argument, then it could end up being passed in multiple times. In the
|
||||
// case where the value was already just passed through, this will result in
|
||||
// redundancy.
|
||||
// Collect external values used.
|
||||
llvm::SetVector<Value> extern_values;
|
||||
|
||||
// Sink down none type constants into the functions.
|
||||
// The basic block arguments correspond to values that are loop carried, while
|
||||
// all those post are loop independent. Initialize extern_values with while_op
|
||||
// not loop carried operands.
|
||||
auto num_loop_carried = while_op.cond().front().getNumArguments();
|
||||
auto not_carried_operands =
|
||||
while_op.getOperands().drop_front(num_loop_carried);
|
||||
extern_values.insert(not_carried_operands.begin(),
|
||||
not_carried_operands.end());
|
||||
auto old_extern_values_size = extern_values.size();
|
||||
|
||||
llvm::SmallVector<Region*, 2> regions{&while_op.cond(), &while_op.body()};
|
||||
for (auto it : llvm::enumerate(regions)) {
|
||||
llvm::SetVector<Value> region_extern_values;
|
||||
Value const_none = nullptr;
|
||||
getUsedValuesDefinedAbove(*it.value(), region_extern_values);
|
||||
|
||||
// Sink down none type constants into the functions.
|
||||
for (auto extern_value : region_extern_values) {
|
||||
if (!extern_value.getType().isa<NoneType>()) {
|
||||
extern_values.insert(extern_value);
|
||||
@ -89,12 +111,23 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
||||
}
|
||||
}
|
||||
|
||||
// Colect new types.
|
||||
bool has_extra_extern_values = old_extern_values_size != extern_values.size();
|
||||
// If an extern value is already an operand post the loop carried operands,
|
||||
// then it need not be passed in again.
|
||||
// Compute all the extra operands that have to be added to the while.
|
||||
llvm::SetVector<Value> extra_operands;
|
||||
if (has_extra_extern_values) {
|
||||
auto new_extern =
|
||||
extern_values.getArrayRef().drop_front(old_extern_values_size);
|
||||
extra_operands.insert(new_extern.begin(), new_extern.end());
|
||||
}
|
||||
|
||||
// Skip if already just calls.
|
||||
if (extra_operands.empty() && IsAlreadyOutlinedd(while_op)) return;
|
||||
|
||||
// Collect new types.
|
||||
SmallVector<Type, 4> types;
|
||||
types.reserve(extern_values.size() +
|
||||
while_op.cond().front().getNumArguments());
|
||||
// Type of block arguments are used as these could differ from those of While
|
||||
// op, but has to match between cond and body.
|
||||
types.reserve(extra_operands.size() + while_op.getNumOperands());
|
||||
for (BlockArgument ba : while_op.cond().front().getArguments())
|
||||
types.push_back(ba.getType());
|
||||
for (Value operand : extern_values) types.push_back(operand.getType());
|
||||
@ -119,7 +152,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
||||
outlined_func.getBody().takeBody(region);
|
||||
Region& func_region = outlined_func.getBody();
|
||||
|
||||
// Replace all external uses with block args and update uses..
|
||||
// Replace all external uses with block args and update uses.
|
||||
llvm::SmallVector<Value, 4> new_args;
|
||||
new_args.reserve(extern_values.size());
|
||||
Block& block = func_region.front();
|
||||
@ -133,10 +166,12 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
||||
Operation* yield_op = outlined_func.getBody().front().getTerminator();
|
||||
OpBuilder b(yield_op);
|
||||
llvm::SmallVector<Value, 4> args;
|
||||
args.reserve(yield_op->getNumOperands() + new_args.size());
|
||||
auto loop_carried_yield_operands =
|
||||
yield_op->getOperands().take_front(num_loop_carried);
|
||||
args.reserve(loop_carried_yield_operands.size() + new_args.size());
|
||||
if (passthru_extra_args) {
|
||||
// Add operands of yield to the return, inserting casts if needed.
|
||||
for (auto it : llvm::zip(yield_op->getOperands(), types)) {
|
||||
for (auto it : llvm::zip_first(loop_carried_yield_operands, types)) {
|
||||
auto value = std::get<0>(it);
|
||||
auto type = std::get<1>(it);
|
||||
if (value.getType() == type) {
|
||||
@ -160,11 +195,6 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
||||
// Replace region with call to outline function.
|
||||
auto replace_with_call = [&](StringRef name, Region& region,
|
||||
bool passthru_extra_args) {
|
||||
// Skip if already only a call.
|
||||
if (region.front().getOperations().size() == 2 &&
|
||||
isa<mlir::CallOp>(region.front().front()))
|
||||
return;
|
||||
|
||||
auto func = create_outline_func(name, region, passthru_extra_args);
|
||||
OpBuilder b(region);
|
||||
// The body of the region is empty/has been outlined into the function.
|
||||
@ -185,19 +215,19 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
|
||||
|
||||
// If there are extern values used then the result type of the while has to
|
||||
// change, so replace with new while op.
|
||||
if (extern_values.empty()) return;
|
||||
if (extra_operands.empty()) return;
|
||||
|
||||
Operation* op = while_op.getOperation();
|
||||
SmallVector<Value, 4> operands;
|
||||
SmallVector<Type, 4> new_types;
|
||||
operands.reserve(op->getNumOperands() + extern_values.size());
|
||||
operands.reserve(types.size());
|
||||
new_types.reserve(operands.size());
|
||||
auto add_operand = [&](Value v) {
|
||||
operands.push_back(v);
|
||||
new_types.push_back(v.getType());
|
||||
};
|
||||
for (auto operand : op->getOperands()) add_operand(operand);
|
||||
for (auto operand : extern_values) add_operand(operand);
|
||||
for (auto operand : extra_operands) add_operand(operand);
|
||||
|
||||
Operation* new_op = OpBuilder(op).insert(Operation::create(
|
||||
op->getLoc(), op->getName(), new_types, operands, op->getAttrs(),
|
||||
@ -212,7 +242,6 @@ void WhileOutlinePass::runOnModule() {
|
||||
getModule().walk(
|
||||
[&](mlir::TFL::WhileOp while_op) { OutlineWhile(while_op); });
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass.
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateWhileOutlinePass() {
|
||||
|
@ -38,7 +38,7 @@ FloatAttr GetSingleElementAsFloatOrSelf(Attribute attr) {
|
||||
|
||||
IntegerAttr ExtractSingleElementAsInteger(ElementsAttr attr) {
|
||||
if (attr.getType().getNumElements() != 1 ||
|
||||
!attr.getType().getElementType().isa<IntegerType>()) {
|
||||
!attr.getType().getElementType().isSignlessInteger()) {
|
||||
return {};
|
||||
}
|
||||
SmallVector<uint64_t, 8> index(attr.getType().getRank(), 0);
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
@ -70,14 +70,14 @@ Value CreateNoneValue(OpBuilder* builder, mlir::Location location) {
|
||||
builder->getUnitAttr());
|
||||
}
|
||||
|
||||
Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
|
||||
RankedTensorType type, mlir::Location location) {
|
||||
Value Transpose(OpBuilder* builder, Value value_to_transpose,
|
||||
SmallVector<int64_t, 4> perm, RankedTensorType original_type,
|
||||
mlir::Location location) {
|
||||
// Create a constant op for transpose permutation.
|
||||
SmallVector<int64_t, 2> perm = {1, 0};
|
||||
auto perm_op = CreateI64DenseConst(builder, perm, perm, location);
|
||||
|
||||
// Create tensor type for the transpose result.
|
||||
auto transpose_type = type;
|
||||
auto transpose_type = original_type;
|
||||
auto transpose_shape = functional::map(
|
||||
[transpose_type](int64_t dim) { return transpose_type.getDimSize(dim); },
|
||||
perm);
|
||||
@ -88,6 +88,13 @@ Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
|
||||
value_to_transpose, perm_op);
|
||||
}
|
||||
|
||||
Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
|
||||
RankedTensorType type, mlir::Location location) {
|
||||
// Create a constant op for transpose permutation.
|
||||
SmallVector<int64_t, 4> perm = {1, 0};
|
||||
return Transpose(builder, value_to_transpose, perm, type, location);
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> GetRankedTensorShape(Value value) {
|
||||
return value.getType().cast<RankedTensorType>().getShape();
|
||||
}
|
||||
@ -586,15 +593,30 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
|
||||
Value recurrent_kernel = func_op.getArgument(4);
|
||||
Value bias = func_op.getArgument(5);
|
||||
|
||||
// Assume it's batch majored.
|
||||
// TFL lstm only supports time-majored inputs, so if it's not time-majored,
|
||||
// we will transpose the inputs and outputs.
|
||||
auto time_major_attr = func_op.getAttrOfType<BoolAttr>("tf.time_major");
|
||||
if (time_major_attr == nullptr) return failure();
|
||||
|
||||
bool time_majored = time_major_attr.getValue();
|
||||
auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
|
||||
if (!input_type) {
|
||||
func_op.emitError() << "Input type is not a ranked tensor type";
|
||||
return failure();
|
||||
}
|
||||
|
||||
int batch = input_type.getDimSize(0);
|
||||
int time = input_type.getDimSize(1);
|
||||
auto final_inputs = input;
|
||||
auto final_input_type = input_type;
|
||||
// We will transpose the inputs.
|
||||
if (!time_majored) {
|
||||
SmallVector<int64_t, 4> perm = {1, 0, 2};
|
||||
final_inputs =
|
||||
Transpose(builder, final_inputs, perm, input_type, func_op.getLoc());
|
||||
final_input_type = final_inputs.getType().dyn_cast<RankedTensorType>();
|
||||
}
|
||||
|
||||
int batch = final_input_type.getDimSize(1);
|
||||
int time = final_input_type.getDimSize(0);
|
||||
|
||||
// Setup correct weights.
|
||||
RankedTensorType weight_type =
|
||||
@ -672,7 +694,13 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
|
||||
builder->getF32FloatAttr(10.0), builder->getF32FloatAttr(0.0),
|
||||
builder->getStringAttr("FULL"));
|
||||
|
||||
builder->create<mlir::ReturnOp>(func_op.getLoc(), lstm.getResult());
|
||||
auto final_output = lstm.getResult();
|
||||
if (!time_majored) {
|
||||
SmallVector<int64_t, 4> perm = {1, 0, 2};
|
||||
final_output =
|
||||
Transpose(builder, final_output, perm, result_type, func_op.getLoc());
|
||||
}
|
||||
builder->create<mlir::ReturnOp>(func_op.getLoc(), final_output);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
|
||||
namespace mlir {
|
||||
|
@ -90,7 +90,7 @@ gentbl(
|
||||
td_file = "ir/tf_saved_model_ops.td",
|
||||
td_srcs = [
|
||||
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -114,7 +114,7 @@ gentbl(
|
||||
td_file = "ir/tf_executor_ops.td",
|
||||
td_srcs = [
|
||||
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -138,7 +138,7 @@ gentbl(
|
||||
td_file = "ir/tf_device_ops.td",
|
||||
td_srcs = [
|
||||
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/Ops.td",
|
||||
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
|
||||
],
|
||||
)
|
||||
|
||||
@ -281,12 +281,12 @@ cc_library(
|
||||
"transforms/generated_canonicalize.inc",
|
||||
"transforms/generated_optimize.inc",
|
||||
"transforms/graph_pruning.cc",
|
||||
"transforms/inline_global_tensors.cc",
|
||||
"transforms/layout_optimization.cc",
|
||||
"transforms/mark_function_visibility.cc",
|
||||
"transforms/materialize_mlir_passthrough_op.cc",
|
||||
"transforms/optimize.cc",
|
||||
"transforms/optimize_global_tensors.cc",
|
||||
"transforms/parallel_execute_to_islands.cc",
|
||||
"transforms/promote_resources_to_args.cc",
|
||||
"transforms/raise_control_flow.cc",
|
||||
"transforms/replicate_invariant_op_hoisting.cc",
|
||||
@ -376,6 +376,7 @@ cc_library(
|
||||
":tensorflow",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LoopOpsTransforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -1000,8 +1001,13 @@ cc_library(
|
||||
srcs = ["utils/tpu_rewrite_device_util.cc"],
|
||||
hdrs = ["utils/tpu_rewrite_device_util.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:array3d",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
],
|
||||
@ -1016,6 +1022,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
|
||||
"@llvm-project//llvm:support",
|
||||
],
|
||||
)
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/Traits.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
@ -573,9 +573,9 @@ void Print(SwitchNOp switchn, OpAsmPrinter &p) {
|
||||
|
||||
ParseResult ParseSwitchNOp(OpAsmParser &parser, OperationState &result) {
|
||||
// Parsing:
|
||||
// %2:6 = tf_executor.SwitchN %0, %1 by 5 : tensor<??xf32>
|
||||
// %2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<??xf32>
|
||||
// Where the first operand is the data to replicate, the second is an i32
|
||||
// indicating which output to populate, followed by the keyword `by` and the
|
||||
// indicating which output to populate, followed by the keyword `of` and the
|
||||
// number of outputs (+1 for the control token).
|
||||
SmallVector<OpAsmParser::OperandType, 2> op_infos;
|
||||
SmallVector<Type, 1> types;
|
||||
|
@ -165,7 +165,7 @@ def TfExecutor_IslandOp : TfExecutor_Op<"island",
|
||||
The `tf_executor.island` operation has a single region with a single block
|
||||
attached (only functional control flow is allowed). The block is terminated
|
||||
by a `tf_executor.yield` operation. The operands of the terminator
|
||||
correspond to the result values of the `tf_executor.graph` operation. An
|
||||
correspond to the result values of the `tf_executor.island` operation. An
|
||||
extra result of type `!tf_executor.control` is always produced by every
|
||||
`tf_executor.island`.
|
||||
Within an island, execution semantics follow standard sequential behavior as
|
||||
@ -299,7 +299,7 @@ def TfExecutor_SwitchNOp : TfExecutor_Op<"SwitchN",
|
||||
.SetShapeFn(SwitchNShape);
|
||||
|
||||
For example:
|
||||
%2:6 = tf_executor.SwitchN %0, %1 by 5 : tensor<??xf32>
|
||||
%2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<??xf32>
|
||||
|
||||
Note: One additional result corresponds to the control output.
|
||||
}];
|
||||
|
@ -510,6 +510,7 @@ Broadcasting is supported, so `value` may have any number of dimensions.
|
||||
// TF_LayoutSensitiveInterface:
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||
LogicalResult UpdateDataFormat(StringRef data_format);
|
||||
}];
|
||||
}
|
||||
|
||||
@ -980,7 +981,7 @@ tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect]> {
|
||||
def TF_Conv2DOp : TF_Op<"Conv2D", [NoSideEffect, TF_LayoutSensitiveInterface]> {
|
||||
let summary = [{
|
||||
Computes a 2-D convolution given 4-D `input` and `filter` tensors.
|
||||
}];
|
||||
@ -1030,6 +1031,13 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TF_LayoutSensitiveInterface:
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||
LogicalResult UpdateDataFormat(StringRef data_format);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_Conv2DBackpropFilterOp : TF_Op<"Conv2DBackpropFilter", [NoSideEffect]> {
|
||||
@ -2091,7 +2099,7 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
||||
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<3>;
|
||||
}
|
||||
|
||||
def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect]> {
|
||||
def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
|
||||
let summary = "Batch normalization.";
|
||||
|
||||
let description = [{
|
||||
@ -2122,6 +2130,13 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TF_FoldOperandsTransposeInterface:
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> {
|
||||
@ -3096,6 +3111,70 @@ cublas.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [NoSideEffect, AllTypesMatch<["input", "band"]>]> {
|
||||
let summary = [{
|
||||
Copy a tensor setting everything outside a central band in each innermost matrix
|
||||
to zero.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
The `band` part is computed as follows:
|
||||
Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a
|
||||
tensor with the same shape where
|
||||
|
||||
`band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
|
||||
|
||||
The indicator function
|
||||
|
||||
`in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) &&
|
||||
(num_upper < 0 || (n-m) <= num_upper)`.
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
# if 'input' is [[ 0, 1, 2, 3]
|
||||
[-1, 0, 1, 2]
|
||||
[-2, -1, 0, 1]
|
||||
[-3, -2, -1, 0]],
|
||||
|
||||
tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3]
|
||||
[-1, 0, 1, 2]
|
||||
[ 0, -1, 0, 1]
|
||||
[ 0, 0, -1, 0]],
|
||||
|
||||
tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0]
|
||||
[-1, 0, 1, 0]
|
||||
[-2, -1, 0, 1]
|
||||
[ 0, -2, -1, 0]]
|
||||
```
|
||||
|
||||
Useful special cases:
|
||||
|
||||
```
|
||||
tf.matrix_band_part(input, 0, -1) ==> Upper triangular part.
|
||||
tf.matrix_band_part(input, -1, 0) ==> Lower triangular part.
|
||||
tf.matrix_band_part(input, 0, 0) ==> Diagonal.
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$input,
|
||||
TF_I32OrI64Tensor:$num_lower,
|
||||
TF_I32OrI64Tensor:$num_upper
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$band
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tindex = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_MatrixDiagOp : TF_Op<"MatrixDiag", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Returns a batched diagonal tensor with a given batched diagonal values.
|
||||
@ -4278,7 +4357,7 @@ This is the opposite of `unpack`.
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_PadOp : TF_Op<"Pad", [NoSideEffect]> {
|
||||
def TF_PadOp : TF_Op<"Pad", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
|
||||
let summary = "Pads a tensor with zeros.";
|
||||
|
||||
let description = [{
|
||||
@ -4317,6 +4396,13 @@ pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TF_FoldOperandsTransposeInterface:
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {0}; }
|
||||
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_PadV2Op : TF_Op<"PadV2", [NoSideEffect]> {
|
||||
@ -4845,7 +4931,7 @@ I.e., \\(y = 1 / x\\).
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_ReluOp : TF_Op<"Relu", [NoSideEffect, SameOperandsAndResultType, TF_LayoutAgnostic]> {
|
||||
let summary = "Computes rectified linear: `max(features, 0)`.";
|
||||
|
||||
let description = [{
|
||||
|
@ -85,7 +85,7 @@ class TF_TensorFlowType <string name, string description> :
|
||||
|
||||
// Any tensor element type allowed in TensorFlow ops
|
||||
def TF_ElementType : Type<Or<[AnyFloat.predicate,
|
||||
AnyInteger.predicate,
|
||||
AnySignlessInteger.predicate,
|
||||
AnyComplex.predicate,
|
||||
TF_TFDialectType.predicate]>,
|
||||
"tf.dtype">;
|
||||
|
@ -50,6 +50,12 @@ def TF_LayoutSensitiveInterface : OpInterface<"LayoutSensitiveInterface"> {
|
||||
[{Returns indices of layout dependent results.}],
|
||||
"SmallVector<unsigned, 4>", "GetLayoutDependentResults", (ins)
|
||||
>,
|
||||
InterfaceMethod<
|
||||
[{Updates operation attributes and operands to account for the updated
|
||||
data format. If data format is not supported, must return failure.}],
|
||||
"LogicalResult", "UpdateDataFormat",
|
||||
(ins "StringRef":$data_format)
|
||||
>,
|
||||
];
|
||||
|
||||
let verify = [{
|
||||
|
@ -35,7 +35,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/Traits.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
@ -151,26 +151,6 @@ static bool AreCastCompatible(Type a, Type b) {
|
||||
b_kind == TensorFlowTypes::VARIANT;
|
||||
}
|
||||
|
||||
static bool AreCancellablePermutations(DenseIntElementsAttr perm0,
|
||||
DenseIntElementsAttr perm1) {
|
||||
if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false;
|
||||
if (perm0.getNumElements() != perm1.getNumElements()) return false;
|
||||
|
||||
SmallVector<int64_t, 8> perm0_values;
|
||||
for (auto value : perm0.getIntValues())
|
||||
perm0_values.push_back(value.getSExtValue());
|
||||
|
||||
SmallVector<int64_t, 8> perm1_values;
|
||||
for (auto value : perm1.getIntValues())
|
||||
perm1_values.push_back(value.getSExtValue());
|
||||
|
||||
for (int i = 0; i < perm0_values.size(); ++i) {
|
||||
if (perm0_values[perm1_values[i]] != i) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
|
||||
return dim_or_rank == -1;
|
||||
}
|
||||
@ -312,6 +292,164 @@ static LogicalResult VerifyTypesCompatibility(
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TF op helper functions to work with layout transformation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SmallVector<int64_t, 4> ReversePermutation(ArrayRef<int64_t> permutation) {
|
||||
SmallVector<int64_t, 4> reverse(permutation.size());
|
||||
for (size_t i = 0; i < permutation.size(); ++i) {
|
||||
reverse[permutation[i]] = i;
|
||||
}
|
||||
return reverse;
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 4> GetDataFormatPermutation(StringRef from, StringRef to) {
|
||||
if (from == "NHWC" && to == "NCHW") {
|
||||
return {0, 3, 1, 2};
|
||||
} else if (from == "NCHW" && to == "NHWC") {
|
||||
return {0, 2, 3, 1};
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
// Shuffle elements in the `attr` according to the permutation. Optional
|
||||
// `inner_size` allows to shuffle array attributes created from rank 2 tensors
|
||||
// on outer dimension only.
|
||||
ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef<int64_t> permutation,
|
||||
int inner_size = 1) {
|
||||
if (attr.size() == 0) return attr;
|
||||
|
||||
assert(attr.size() % inner_size == 0);
|
||||
assert(attr.size() / inner_size == permutation.size());
|
||||
|
||||
SmallVector<Attribute, 8> values{attr.begin(), attr.end()};
|
||||
SmallVector<Attribute, 8> shuffled(values.size());
|
||||
|
||||
for (size_t i = 0; i < permutation.size(); ++i) {
|
||||
for (size_t j = 0; j < inner_size; ++j) {
|
||||
shuffled[i * inner_size + j] = values[permutation[i] * inner_size + j];
|
||||
}
|
||||
}
|
||||
|
||||
return ArrayAttr::get(shuffled, attr.getContext());
|
||||
}
|
||||
|
||||
// Shuffle ranked tensor dimensions according to the permutation.
|
||||
Type ShuffleRankedTensorType(Type type, ArrayRef<int64_t> permutation) {
|
||||
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
|
||||
ArrayRef<int64_t> shape = ranked_type.getShape();
|
||||
assert(permutation.size() == shape.size());
|
||||
|
||||
SmallVector<int64_t, 4> new_shape(permutation.size());
|
||||
for (size_t i = 0; i < permutation.size(); ++i)
|
||||
new_shape[i] = shape[permutation[i]];
|
||||
|
||||
return RankedTensorType::get(new_shape, ranked_type.getElementType());
|
||||
}
|
||||
|
||||
return type;
|
||||
}
|
||||
|
||||
static bool AreCancellablePermutations(DenseIntElementsAttr perm0,
|
||||
DenseIntElementsAttr perm1) {
|
||||
if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false;
|
||||
if (perm0.getNumElements() != perm1.getNumElements()) return false;
|
||||
|
||||
SmallVector<int64_t, 8> perm0_values;
|
||||
for (auto value : perm0.getIntValues())
|
||||
perm0_values.push_back(value.getSExtValue());
|
||||
|
||||
SmallVector<int64_t, 8> perm1_values;
|
||||
for (auto value : perm1.getIntValues())
|
||||
perm1_values.push_back(value.getSExtValue());
|
||||
|
||||
for (int i = 0; i < perm0_values.size(); ++i) {
|
||||
if (perm0_values[perm1_values[i]] != i) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Default implementation of `LayoutSensitiveInterface::UpdateDataFormat` for
|
||||
// layout sensitive operations that do not have any additional layout dependent
|
||||
// attributes besides `data_format` string.
|
||||
template <typename Op>
|
||||
LogicalResult UpdateDataFormat(StringRef data_format, Op *op) {
|
||||
auto perm = GetDataFormatPermutation(op->data_format(), data_format);
|
||||
if (perm.empty()) return failure();
|
||||
|
||||
// Update data format attribute.
|
||||
op->setAttr("data_format", StringAttr::get(data_format, op->getContext()));
|
||||
|
||||
// Update types for all layout sensitive results.
|
||||
auto layout_sensitive = cast<LayoutSensitiveInterface>(op->getOperation());
|
||||
for (unsigned idx : layout_sensitive.GetLayoutDependentResults()) {
|
||||
OpResult result = op->getOperation()->getResult(idx);
|
||||
result.setType(ShuffleRankedTensorType(result.getType(), perm));
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// Default implementation for folding operand transpose into the operation.
|
||||
// See `FoldOperandsTransposeInterface::FoldOperandsPermutation`.
|
||||
template <typename Op>
|
||||
LogicalResult FoldOperandsPermutation(
|
||||
ArrayRef<int64_t> permutation, Op *op,
|
||||
ArrayRef<std::pair<StringRef, ArrayAttr>> shuffle_attrs = {}) {
|
||||
MLIRContext *context = op->template getParentOfType<ModuleOp>().getContext();
|
||||
|
||||
// We only support NHWC <-> NCHW permutations.
|
||||
static constexpr std::array<int64_t, 4> kNchwToNhwc = {0, 2, 3, 1};
|
||||
static constexpr std::array<int64_t, 4> kNhwcToNchw = {0, 3, 1, 2};
|
||||
|
||||
// Operation data format after folding `permutation`.
|
||||
StringRef target_data_format = [&]() -> StringRef {
|
||||
if (op->data_format() == "NHWC" && permutation.equals(kNchwToNhwc)) {
|
||||
return "NCHW"; // cancel NCHW->NHWC operand permutation
|
||||
} else if (op->data_format() == "NCHW" && permutation.equals(kNhwcToNchw)) {
|
||||
return "NHWC"; // cancel NHWC->NCHW operand permutation
|
||||
} else {
|
||||
return "";
|
||||
}
|
||||
}();
|
||||
if (target_data_format.empty()) return failure();
|
||||
|
||||
// To fold operand `permutation` into the `op` we need shuffle all layout
|
||||
// dependent attributes and types with a reverse permutation, and change
|
||||
// operation data format to `target_data_format`.
|
||||
//
|
||||
// Example:
|
||||
// %1 = SomeOp(...) {data_format = NHWC}
|
||||
// %2 = Transpose(%1) {permutation = NHWC->NCHW}
|
||||
// %3 = Op(%2) {data_format = NCHW}
|
||||
//
|
||||
// To bypass %2 we have to change data format to shuffle data format from NCHW
|
||||
// to NHWC, which is the reverse of operand permutation (function argument).
|
||||
auto reverse_permutation =
|
||||
GetDataFormatPermutation(op->data_format(), target_data_format);
|
||||
if (reverse_permutation.empty()) return failure();
|
||||
|
||||
op->setAttr("data_format", StringAttr::get(target_data_format, context));
|
||||
|
||||
for (auto pair : shuffle_attrs) {
|
||||
StringRef attr_name = pair.first;
|
||||
ArrayAttr attr_value = pair.second;
|
||||
op->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation));
|
||||
}
|
||||
|
||||
auto fold = cast<FoldOperandsTransposeInterface>(op->getOperation());
|
||||
for (unsigned idx : fold.GetLayoutDependentResults()) {
|
||||
OpResult result = op->getOperation()->getResult(idx);
|
||||
result.setType(
|
||||
ShuffleRankedTensorType(result.getType(), reverse_permutation));
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
|
||||
} // namespace
|
||||
@ -479,6 +617,15 @@ static LogicalResult Verify(BiasAddOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
// TODO(ezhulenev): BiasAddOp is not really layout sensitive, it must only
|
||||
// support folding operand transposes.
|
||||
LogicalResult BiasAddOp::UpdateDataFormat(StringRef data_format) {
|
||||
auto ranked = value().getType().dyn_cast<RankedTensorType>();
|
||||
if (!ranked || ranked.getRank() != 4) return failure();
|
||||
|
||||
return ::mlir::TF::UpdateDataFormat(data_format, this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BiasAddGradOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -837,6 +984,21 @@ static LogicalResult Verify(OpT op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) {
|
||||
auto perm = GetDataFormatPermutation(this->data_format(), data_format);
|
||||
if (perm.empty()) return failure();
|
||||
|
||||
// Update data_format attribute and result types.
|
||||
if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
|
||||
|
||||
// Update convolution attributes.
|
||||
setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
|
||||
setAttr("strides", ShuffleArrayAttr(strides(), perm));
|
||||
setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conv2dBackpropInputOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1158,6 +1320,11 @@ static LogicalResult Verify(FusedBatchNormOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation(
|
||||
ArrayRef<int64_t> permutation) {
|
||||
return ::mlir::TF::FoldOperandsPermutation(permutation, this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GatherV2Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1339,6 +1506,29 @@ void LogicalNotOp::getCanonicalizationPatterns(
|
||||
LogicalNotOfLess, LogicalNotOfLessEqual>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MatrixBandPartOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(MatrixBandPartOp op) {
|
||||
if (!HasRankAtLeast(op.input(), 2)) {
|
||||
return op.emitOpError()
|
||||
<< "requires `input` to have rank of at least 2, but found "
|
||||
<< op.input().getType();
|
||||
}
|
||||
if (!IsOfRankOrUnranked(op.num_lower(), 0)) {
|
||||
return op.emitOpError()
|
||||
<< "requires `num_lower` to have 0 dimensions, but found "
|
||||
<< op.num_lower().getType();
|
||||
}
|
||||
if (!IsOfRankOrUnranked(op.num_upper(), 0)) {
|
||||
return op.emitOpError()
|
||||
<< "requires `num_upper` to have 0 dimensions, but found "
|
||||
<< op.num_upper().getType();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MaxOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1356,57 +1546,8 @@ void MaxOp::build(Builder *builder, OperationState &result, Value input,
|
||||
|
||||
LogicalResult MaxPoolOp::FoldOperandsPermutation(
|
||||
ArrayRef<int64_t> permutation) {
|
||||
MLIRContext *context = getParentOfType<ModuleOp>().getContext();
|
||||
|
||||
// For now we only support folding of NCHW->NHWC and NHWC->NCHW permutations.
|
||||
if (data_format() == "NHWC") {
|
||||
static constexpr std::array<int64_t, 4> kPerm = {0, 2, 3, 1}; // to NHWC
|
||||
if (permutation != ArrayRef<int64_t>(kPerm)) return failure();
|
||||
|
||||
setAttr("data_format", StringAttr::get("NCHW", context));
|
||||
|
||||
} else if (data_format() == "NCHW") {
|
||||
static constexpr std::array<int64_t, 4> kPerm = {0, 3, 1, 2}; // to NCHW
|
||||
if (permutation != ArrayRef<int64_t>(kPerm)) return failure();
|
||||
|
||||
setAttr("data_format", StringAttr::get("NHWC", context));
|
||||
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto shuffle_attr = [&](ArrayAttr attr) -> ArrayAttr {
|
||||
SmallVector<Attribute, 4> values{attr.begin(), attr.end()};
|
||||
SmallVector<Attribute, 4> shuffled(values.size());
|
||||
|
||||
for (size_t i = 0; i < permutation.size(); ++i)
|
||||
shuffled[permutation[i]] = values[i];
|
||||
|
||||
return ArrayAttr::get(shuffled, context);
|
||||
};
|
||||
|
||||
setAttr("strides", shuffle_attr(strides()));
|
||||
setAttr("ksize", shuffle_attr(ksize()));
|
||||
|
||||
auto shuffle_type = [&](Type type) -> Type {
|
||||
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
|
||||
ArrayRef<int64_t> shape = ranked_type.getShape();
|
||||
assert(permutation.size() == shape.size());
|
||||
|
||||
SmallVector<int64_t, 4> new_shape(permutation.size());
|
||||
for (size_t i = 0; i < permutation.size(); ++i)
|
||||
new_shape[permutation[i]] = shape[i];
|
||||
|
||||
return RankedTensorType::get(new_shape, ranked_type.getElementType());
|
||||
}
|
||||
|
||||
return type;
|
||||
};
|
||||
|
||||
OpResult result = getOperation()->getResult(0);
|
||||
result.setType(shuffle_type(result.getType()));
|
||||
|
||||
return success();
|
||||
return ::mlir::TF::FoldOperandsPermutation(
|
||||
permutation, this, {{"strides", strides()}, {"ksize", ksize()}});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1426,6 +1567,38 @@ static LogicalResult Verify(MaxPoolGradOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MeanOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
|
||||
// Reduction indices must be defined by a constant operation.
|
||||
auto reduction_op =
|
||||
dyn_cast_or_null<TF::ConstOp>(reduction_indices().getDefiningOp());
|
||||
if (!reduction_op) return failure();
|
||||
|
||||
auto reductions_value = reduction_op.value().dyn_cast<DenseElementsAttr>();
|
||||
if (!reductions_value) return failure();
|
||||
|
||||
// Prepare new reduction indices according to operand permutation.
|
||||
SmallVector<int32_t, 4> shuffled_reduction;
|
||||
llvm::transform(reductions_value.getIntValues(),
|
||||
std::back_inserter(shuffled_reduction),
|
||||
[&](APInt idx) { return permutation[idx.getSExtValue()]; });
|
||||
|
||||
// Add constant operation with a new reduction indices.
|
||||
OpBuilder builder(getOperation());
|
||||
auto type = mlir::RankedTensorType::get(shuffled_reduction.size(),
|
||||
builder.getIntegerType(32));
|
||||
auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction);
|
||||
auto shuffled_reduction_op = builder.create<TF::ConstOp>(getLoc(), values);
|
||||
|
||||
// Use new reduction indices.
|
||||
setOperand(1, shuffled_reduction_op);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NegOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1568,6 +1741,46 @@ static LogicalResult Verify(PackOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PadOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult PadOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
|
||||
// Paddings must be defined by a constant operation.
|
||||
auto paddings_op = dyn_cast_or_null<TF::ConstOp>(paddings().getDefiningOp());
|
||||
if (!paddings_op) return failure();
|
||||
|
||||
auto paddings_value = paddings_op.value().dyn_cast<DenseElementsAttr>();
|
||||
if (!paddings_value ||
|
||||
paddings_value.getNumElements() != permutation.size() * 2)
|
||||
return failure();
|
||||
|
||||
SmallVector<int32_t, 8> shuffled_paddings(paddings_value.getNumElements());
|
||||
for (auto index_pair : llvm::enumerate(paddings_value.getIntValues())) {
|
||||
size_t outer_idx = index_pair.index() / 2;
|
||||
size_t inner_idx = index_pair.index() % 2;
|
||||
|
||||
shuffled_paddings[permutation[outer_idx] * 2 + inner_idx] =
|
||||
index_pair.value().getSExtValue();
|
||||
}
|
||||
|
||||
// Add constant operation with a new paddings.
|
||||
OpBuilder builder(getOperation());
|
||||
auto type = mlir::RankedTensorType::get(paddings_value.getType().getShape(),
|
||||
builder.getIntegerType(32));
|
||||
auto values = mlir::DenseIntElementsAttr::get(type, shuffled_paddings);
|
||||
auto shuffled_paddings_op = builder.create<TF::ConstOp>(getLoc(), values);
|
||||
|
||||
// Use new paddings.
|
||||
setOperand(1, shuffled_paddings_op);
|
||||
|
||||
// Change the result type.
|
||||
getResult().setType(ShuffleRankedTensorType(getResult().getType(),
|
||||
ReversePermutation(permutation)));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ParseExampleV2Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1914,7 +2127,8 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
|
||||
}
|
||||
|
||||
Type element_type = result_ranked_type.getElementType();
|
||||
if (!element_type.isInteger(32) && !element_type.isInteger(64))
|
||||
if (!element_type.isSignlessInteger(32) &&
|
||||
!element_type.isSignlessInteger(64))
|
||||
return op->emitOpError("requires int32 or int64 return type for result")
|
||||
<< variadic_idx_str;
|
||||
|
||||
|
@ -172,7 +172,7 @@ else_branch: A function that takes 'inputs' and returns a list of
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_MeanOp : TF_Op<"Mean", [NoSideEffect]> {
|
||||
def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
|
||||
let summary = "Computes the mean of elements across dimensions of a tensor.";
|
||||
|
||||
let description = [{
|
||||
@ -195,6 +195,13 @@ retained with length 1.
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TF_FoldOperandsTransposeInterface:
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {}; }
|
||||
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_LegacyCallOp : TF_Op<"LegacyCall",
|
||||
|
@ -112,24 +112,20 @@ static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) {
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
Type GetBoundInputArgTypeFor(GlobalTensorOp global_tensor) {
|
||||
auto type = global_tensor.type().cast<TensorType>();
|
||||
return RankedTensorType::get(
|
||||
{}, TF::ResourceType::get({type}, type.getContext()));
|
||||
}
|
||||
|
||||
static LogicalResult VerifyBoundInputArgType(Operation *op_for_diagnostics,
|
||||
Type arg_type,
|
||||
GlobalTensorOp global_tensor) {
|
||||
if (global_tensor.is_mutable()) {
|
||||
auto expected_type = RankedTensorType::get(
|
||||
{}, TF::ResourceType::get({global_tensor.type().cast<TensorType>()},
|
||||
arg_type.getContext()));
|
||||
if (arg_type != expected_type) {
|
||||
return op_for_diagnostics->emitError()
|
||||
<< "mutable bound input with type " << arg_type
|
||||
<< " expected to have type " << expected_type;
|
||||
}
|
||||
} else {
|
||||
if (arg_type != global_tensor.type()) {
|
||||
return op_for_diagnostics->emitError()
|
||||
<< "bound input for immutable 'tf_saved_model.global_tensor' must "
|
||||
"match the global tensor's type";
|
||||
}
|
||||
auto expected_type = GetBoundInputArgTypeFor(global_tensor);
|
||||
if (arg_type != expected_type) {
|
||||
return op_for_diagnostics->emitError()
|
||||
<< "bound input with type " << arg_type << " expected to have type "
|
||||
<< expected_type;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -57,6 +57,10 @@ bool HasTfSavedModelSemantics(ModuleOp module);
|
||||
GlobalTensorOp LookupBoundInput(FuncOp func, int arg_index,
|
||||
const SymbolTable &symbol_table);
|
||||
|
||||
// Gets the type that an exported function arg that is bound to `global_tensor`
|
||||
// should have.
|
||||
Type GetBoundInputArgTypeFor(GlobalTensorOp global_tensor);
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -91,7 +91,7 @@ class TensorFlowType : public Type {
|
||||
// Returns true if the specified type is a valid TensorFlow element type.
|
||||
static inline bool IsValidTFElementType(Type type) {
|
||||
return type.isa<ComplexType>() || type.isa<FloatType>() ||
|
||||
type.isa<IntegerType>() || type.isa<TensorFlowType>();
|
||||
type.isSignlessInteger() || type.isa<TensorFlowType>();
|
||||
}
|
||||
|
||||
// Returns true if this is a valid TensorFlow tensor type.
|
||||
@ -141,20 +141,16 @@ class TensorFlowRefType : public TensorFlowType {
|
||||
static TensorFlowType get(Type type);
|
||||
static TensorFlowType getChecked(Type type, MLIRContext* context,
|
||||
Location loc) {
|
||||
if (failed(verifyConstructionInvariants(loc, context, type))) {
|
||||
if (failed(verifyConstructionInvariants(loc, type))) {
|
||||
return TensorFlowRefType();
|
||||
}
|
||||
return get(type);
|
||||
}
|
||||
|
||||
static LogicalResult verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext* context, Type type) {
|
||||
static LogicalResult verifyConstructionInvariants(Location loc, Type type) {
|
||||
// type should be a valid TensorFlow type.
|
||||
if (!IsValidTFTensorType(type)) {
|
||||
if (loc) {
|
||||
emitError(*loc) << "invalid TensorFlow type: " << type;
|
||||
}
|
||||
return failure();
|
||||
return emitError(loc) << "invalid TensorFlow type: " << type;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
@ -230,7 +226,7 @@ class TypeWithSubtypeImpl
|
||||
|
||||
static Derived getChecked(ArrayRef<TensorType> subtypes, MLIRContext* context,
|
||||
Location loc) {
|
||||
return Base::getChecked(loc, context, Derived::getTypeKind(), subtypes);
|
||||
return Base::getChecked(loc, Derived::getTypeKind(), subtypes);
|
||||
}
|
||||
|
||||
static Derived get(MLIRContext* context) { return get({}, context); }
|
||||
@ -239,16 +235,12 @@ class TypeWithSubtypeImpl
|
||||
static bool kindof(unsigned kind) { return kind == Derived::getTypeKind(); }
|
||||
|
||||
static LogicalResult verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext* context,
|
||||
ArrayRef<TensorType> subtypes) {
|
||||
Location loc, ArrayRef<TensorType> subtypes) {
|
||||
// Each of the subtypes should be a valid TensorFlow type.
|
||||
for (TensorType subtype : subtypes) {
|
||||
if (!IsValidTFTensorType(subtype)) {
|
||||
if (loc) {
|
||||
emitError(*loc) << "invalid " << Derived::getTypeName()
|
||||
<< " subtype: " << subtype;
|
||||
}
|
||||
return failure();
|
||||
return emitError(loc) << "invalid " << Derived::getTypeName()
|
||||
<< " subtype: " << subtype;
|
||||
}
|
||||
}
|
||||
return success();
|
||||
|
@ -280,3 +280,67 @@ func @empty_island_multiple_data_results(%arg0: tensor<*xf32>, %arg1: tensor<*xi
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// The following tests check that certain control dependencies between islands
|
||||
// and certain tf_executor ops are added correctly.
|
||||
|
||||
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
|
||||
// CHECK: tf_executor.NextIteration.Sink [{{.*}}] {{.*}}, %[[CONTROL]]
|
||||
func @next_iteration_sink_control_input() {
|
||||
tf_executor.graph {
|
||||
%source:3 = tf_executor.NextIteration.Source : tensor<*xi32>
|
||||
%island:2 = tf_executor.island {
|
||||
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
|
||||
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
tf_executor.yield %const : tensor<*xi32>
|
||||
}
|
||||
tf_executor.NextIteration.Sink[%source#1] %island#0 : tensor<*xi32>
|
||||
tf_executor.fetch %island#0 : tensor<*xi32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
|
||||
// CHECK: tf_executor.LoopCond {{.*}}, %[[CONTROL]]
|
||||
func @loop_cond_control_input() {
|
||||
tf_executor.graph {
|
||||
%island:2 = tf_executor.island {
|
||||
%const = "tf.Const"() {value = dense<1> : tensor<i1>} : () -> tensor<*xi1>
|
||||
%print = "tf.Print"(%const) : (tensor<*xi1>) -> (tensor<*xi1>)
|
||||
tf_executor.yield %const : tensor<*xi1>
|
||||
}
|
||||
%loop_cond:2 = tf_executor.LoopCond %island#0 : tensor<*xi1>
|
||||
tf_executor.fetch %loop_cond#0 : tensor<*xi1>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
|
||||
// CHECK: tf_executor.Enter {{.*}}, %[[CONTROL]]
|
||||
func @enter_control_input() {
|
||||
tf_executor.graph {
|
||||
%island:2 = tf_executor.island {
|
||||
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
|
||||
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
tf_executor.yield %const : tensor<*xi32>
|
||||
}
|
||||
%enter:2 = tf_executor.Enter %island#0 frame "some/frame" : tensor<*xi32>
|
||||
tf_executor.fetch %enter#0 : tensor<*xi32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[CONTROL:[^ ,]*]] = tf_executor.island wraps "tf.Print"
|
||||
// CHECK: tf_executor.SwitchN {{.*}}, {{.*}} of {{[0-9]*}} (%[[CONTROL]])
|
||||
func @switchn_control_input(%arg1: tensor<i32>) {
|
||||
tf_executor.graph {
|
||||
%island:2 = tf_executor.island {
|
||||
%const = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<*xi32>
|
||||
%print = "tf.Print"(%const) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
tf_executor.yield %const : tensor<*xi32>
|
||||
}
|
||||
%switchn:4 = tf_executor.SwitchN %island#0, %arg1 of 3: tensor<*xi32>
|
||||
tf_executor.fetch %switchn#0 : tensor<*xi32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -9,5 +9,7 @@ func @device_test(%arg0: tensor<3x1xf32>) -> (tensor<3x3xf32>) {
|
||||
%1 = "tf.MatMul"(%arg0, %0) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
|
||||
// CHECK: device = "cpu"
|
||||
%2 = "tf.Relu"(%1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "cpu"} : (tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||
return %2 : tensor<3x3xf32>
|
||||
// CHECK: device = "gpu"
|
||||
%3 = "tf.Relu"(%2) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"]} : (tensor<3x3xf32>) -> tensor<3x3xf32>
|
||||
return %3 : tensor<3x3xf32>
|
||||
}
|
||||
|
@ -6,10 +6,10 @@ func @transposeBiasAdd(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tens
|
||||
// Check that BiasAdd was converted to forced data format, and layout
|
||||
// dependent arguments and results passed through transpose nodes.
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
%0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
|
||||
@ -20,10 +20,10 @@ func @transposeBiasAdd(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tens
|
||||
// CHECK-LABEL: func @transposeBiasAddWithDefaultAttr
|
||||
func @transposeBiasAddWithDefaultAttr(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%[[ARG_TRANSPOSE]], %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[BIAS_ADD]], %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
%0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
|
||||
@ -38,4 +38,38 @@ func @transposeBiasWithUnknownShape(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<8x
|
||||
%0 = "tf.BiasAdd"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<*xf32>
|
||||
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @transposeConv2D
|
||||
func @transposeConv2D(%input: tensor<1x32x32x3xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32> {
|
||||
|
||||
// IMPORTANT: Tensor shapes do not match convolution parameters (stride,
|
||||
// dilations, etc...). This test only verifies that changing convolution data
|
||||
// layout will update all the attributes.
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
|
||||
// CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1)
|
||||
// CHECK-SAME: data_format = "NCHW"
|
||||
// CHECK-SAME: dilations = [1, 4, 2, 3]
|
||||
// CHECK-SAME: explicit_paddings = [1, 2, 7, 8, 3, 4, 5, 6]
|
||||
// CHECK-SAME: padding = "EXPLICIT"
|
||||
// CHECK-SAME: strides = [5, 8, 6, 7]
|
||||
// CHECK-SAME: (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32>
|
||||
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
|
||||
%0 = "tf.Conv2D"(%input, %filter)
|
||||
{
|
||||
data_format = "NHWC",
|
||||
dilations = [1, 2, 3, 4],
|
||||
explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8],
|
||||
padding = "EXPLICIT",
|
||||
strides = [5, 6, 7, 8]
|
||||
} : (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32>
|
||||
|
||||
return %0 : tensor<1x32x32x8xf32>
|
||||
}
|
@ -0,0 +1,35 @@
|
||||
// RUN: tf-opt %s -tf-layout-assignment=force-data-format=NHWC -verify-diagnostics | FileCheck %s --dump-input=always
|
||||
|
||||
// CHECK-LABEL: func @transposeConv2D
|
||||
func @transposeConv2D(%input: tensor<1x3x32x32xf32>, %filter: tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32> {
|
||||
|
||||
// IMPORTANT: Tensor shapes do not match convolution parameters (stride,
|
||||
// dilations, etc...). This test only verifies that changing convolution data
|
||||
// layout will update all the attributes.
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
|
||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
|
||||
// CHECK: %[[CONV2D:[0-9]*]] = "tf.Conv2D"(%[[ARG_TRANSPOSE]], %arg1)
|
||||
// CHECK-SAME: data_format = "NHWC"
|
||||
// CHECK-SAME: dilations = [1, 3, 4, 2]
|
||||
// CHECK-SAME: explicit_paddings = [1, 2, 5, 6, 7, 8, 3, 4]
|
||||
// CHECK-SAME: padding = "EXPLICIT"
|
||||
// CHECK-SAME: strides = [5, 7, 8, 6]
|
||||
// CHECK-SAME: (tensor<1x32x32x3xf32>, tensor<1x1x3x8xf32>) -> tensor<1x32x32x8xf32>
|
||||
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[CONV2D]], %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
|
||||
%0 = "tf.Conv2D"(%input, %filter)
|
||||
{
|
||||
data_format = "NCHW",
|
||||
dilations = [1, 2, 3, 4],
|
||||
explicit_paddings = [1, 2, 3, 4, 5, 6, 7, 8],
|
||||
padding = "EXPLICIT",
|
||||
strides = [5, 6, 7, 8]
|
||||
} : (tensor<1x3x32x32xf32>, tensor<1x1x3x8xf32>) -> tensor<1x8x32x32xf32>
|
||||
|
||||
return %0 : tensor<1x8x32x32xf32>
|
||||
}
|
@ -3,14 +3,14 @@
|
||||
// CHECK-LABEL: func @move_across_single_op
|
||||
func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: return %[[TANH]]
|
||||
|
||||
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||
%1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
%1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
return %2 : tensor<1x8x4x4xf32>
|
||||
}
|
||||
@ -18,17 +18,17 @@ func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
// CHECK-LABEL: func @move_across_multiple_ops
|
||||
func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
// CHECK: %[[TANH0:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: %[[TANH1:[0-9]*]] = "tf.Tanh"(%[[TANH0]]) {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: return %[[TANH1]]
|
||||
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: %[[RELU:[0-9]*]] = "tf.Relu"(%[[TANH]]) {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: return %[[RELU]]
|
||||
|
||||
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||
%1 = "tf.Tanh"(%0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||
%1 = "tf.Relu"(%0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||
|
||||
%2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
%2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
return %3 : tensor<1x8x4x4xf32>
|
||||
}
|
||||
@ -36,15 +36,15 @@ func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32
|
||||
// CHECK-LABEL: func @move_across_multi_operand_op
|
||||
func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||
// CHECK: %[[ARG0_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
// CHECK: %[[ARG1_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg1, %[[ARG_PERM]])
|
||||
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[ARG0_TRANSPOSE]], %[[ARG1_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: return %[[ADD]]
|
||||
|
||||
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||
%1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
%1 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%2 = "tf.Transpose"(%0, %1) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
return %2 : tensor<1x8x4x4xf32>
|
||||
}
|
||||
@ -52,7 +52,7 @@ func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4
|
||||
// CHECK-LABEL: func @move_with_multiple_uses
|
||||
func @move_with_multiple_uses(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[ARG_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||
// CHECK: %[[ARG_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%arg0, %[[ARG_PERM]])
|
||||
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%[[ARG_TRANSPOSE]]) {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[TANH]], %[[TANH]]) {{.*}} tensor<1x8x4x4xf32>
|
||||
@ -60,8 +60,8 @@ func @move_with_multiple_uses(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
%0 = "tf.Tanh"(%arg0) : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||
%1 = "tf.AddV2"(%0, %0) : (tensor<1x4x4x8xf32>, tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
|
||||
%2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
%2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%3 = "tf.Transpose"(%1, %2) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
return %3 : tensor<1x8x4x4xf32>
|
||||
}
|
||||
|
@ -3,13 +3,13 @@
|
||||
// CHECK-LABEL: func @move_across_single_op
|
||||
func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%arg0) {{.*}} tensor<1x4x4x8xf32>
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[TANH]], %[[RES_PERM]]) {{.*}} tensor<1x8x4x4xf32>
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
|
||||
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||
%2 = "tf.Tanh"(%1) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
return %2 : tensor<1x8x4x4xf32>
|
||||
@ -18,16 +18,16 @@ func @move_across_single_op(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
// CHECK-LABEL: func @move_across_multiple_ops
|
||||
func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[TANH0:[0-9]*]] = "tf.Tanh"(%arg0) {{.*}} tensor<1x4x4x8xf32>
|
||||
// CHECK: %[[TANH1:[0-9]*]] = "tf.Tanh"(%[[TANH0]]) {{.*}} tensor<1x4x4x8xf32>
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[TANH1]], %[[RES_PERM]])
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||
// CHECK: %[[TANH:[0-9]*]] = "tf.Tanh"(%arg0) {{.*}} tensor<1x4x4x8xf32>
|
||||
// CHECK: %[[RELU:[0-9]*]] = "tf.Relu"(%[[TANH]]) {{.*}} tensor<1x4x4x8xf32>
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[RELU]], %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
|
||||
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||
%2 = "tf.Tanh"(%1) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
||||
%3 = "tf.Tanh"(%2) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
||||
%3 = "tf.Relu"(%2) : (tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
return %3 : tensor<1x8x4x4xf32>
|
||||
}
|
||||
@ -35,14 +35,14 @@ func @move_across_multiple_ops(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32
|
||||
// CHECK-LABEL: func @move_across_multi_operand_op
|
||||
func @move_across_multi_operand_op(%arg0: tensor<1x4x4x8xf32>, %arg1: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>}
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>}
|
||||
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%arg0, %arg1) {{.*}} tensor<1x4x4x8xf32>
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[ADD]], %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
|
||||
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
%2 = "tf.Transpose"(%arg1, %0) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||
%2 = "tf.Transpose"(%arg1, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||
%3 = "tf.AddV2"(%1, %2) : (tensor<1x8x4x4xf32>, tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
return %3 : tensor<1x8x4x4xf32>
|
||||
@ -54,14 +54,14 @@ func @fold_into_max_pool(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x56x56x64xf
|
||||
// MaxPool operand transpose must be folded into the op and MaxPool
|
||||
// must use NCHW data format with updated kernel size and strides.
|
||||
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>}
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
|
||||
// CHECK: %[[MAX_POOL:[0-9]*]] = "tf.MaxPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 3, 3], padding = "SAME", strides = [1, 1, 2, 2]} : (tensor<1x64x112x112xf32>) -> tensor<1x64x56x56xf32>
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[ADD]], %[[RES_PERM]])
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%[[MAX_POOL]], %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
|
||||
// Transpose NCHW -> NHWC
|
||||
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32>
|
||||
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32>
|
||||
|
||||
// Compute MaxPool in NHWC format
|
||||
%2 = "tf.MaxPool"(%1)
|
||||
@ -72,3 +72,49 @@ func @fold_into_max_pool(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x56x56x64xf
|
||||
|
||||
return %2 : tensor<1x56x56x64xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_into_mean
|
||||
func @fold_into_mean(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64xf32> {
|
||||
|
||||
// CHECK: %[[RED_IDX:[0-9]*]] = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi32>}
|
||||
// CHECK: %[[MEAN:[0-9]*]] = "tf.Mean"(%arg0, %[[RED_IDX]])
|
||||
// CHECK-SAME: (tensor<1x64x112x112xf32>, tensor<2xi32>) -> tensor<1x64xf32>
|
||||
// CHECK: return %[[MEAN]]
|
||||
|
||||
// Transpose NCHW -> NHWC
|
||||
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32>
|
||||
|
||||
// Compute Mean over spatial dimensions in NHWC format.
|
||||
%2 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%3 = "tf.Mean"(%1, %2) : (tensor<1x112x112x64xf32>, tensor<2xi32>) -> tensor<1x64xf32>
|
||||
|
||||
return %3 : tensor<1x64xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_into_fused_batch_norm
|
||||
func @fold_into_fused_batch_norm(%arg0: tensor<1x64x112x112xf32>, %arg1: tensor<64xf32>) -> tensor<1x112x112x64xf32> {
|
||||
|
||||
// CHECK: %[[RES_PERM:[0-9]*]] = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>}
|
||||
// CHECK: "tf.FusedBatchNormV3"(%arg0, {{.*}} {data_format = "NCHW"
|
||||
// CHECK: %[[RES_TRANSPOSE:[0-9]*]] = "tf.Transpose"(%y, %[[RES_PERM]])
|
||||
// CHECK: return %[[RES_TRANSPOSE]]
|
||||
|
||||
// Transpose NCHW -> NHWC
|
||||
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32>
|
||||
|
||||
// Compute FusedBatchNormV3 in NHWC format
|
||||
%2, %batch_mean, %batch_var, %reserve_1, %reserve_2, %reserve_3
|
||||
= "tf.FusedBatchNormV3"(%1, %arg1, %arg1, %arg1, %arg1)
|
||||
{
|
||||
data_format = "NHWC",
|
||||
epsilon = 1.001 : f32,
|
||||
exponential_avg_factor = 1.0 : f32,
|
||||
is_training = false
|
||||
}
|
||||
: (tensor<1x112x112x64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
|
||||
-> (tensor<1x112x112x64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
|
||||
|
||||
return %2#0 : tensor<1x112x112x64xf32>
|
||||
}
|
||||
|
@ -4,15 +4,15 @@
|
||||
func @transposeBiasAdd(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<8xf32>) -> tensor<1x8x4x4xf32> {
|
||||
|
||||
// Convert input: NCHW -> NHWC
|
||||
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x8x4x4xf32>, tensor<4xi64>) -> tensor<1x4x4x8xf32>
|
||||
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32>
|
||||
|
||||
// Compute in NHWC
|
||||
%2 = "tf.BiasAdd"(%1, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
|
||||
|
||||
// Convert result back: NHWC -> NCHW
|
||||
%3 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||
%4 = "tf.Transpose"(%2, %3) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||
%3 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%4 = "tf.Transpose"(%2, %3) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||
|
||||
// Check that BiasAdd computed in NCHW format, and all redundant transpose
|
||||
// operations removed from the function.
|
@ -0,0 +1,156 @@
|
||||
// RUN: tf-opt %s -tf-layout-optimization=force-data-format=NHWC -verify-diagnostics | FileCheck %s --dump-input=always
|
||||
|
||||
// CHECK-LABEL: func @transpose_resnet_layer
|
||||
func @transpose_resnet_layer(%arg0: tensor<?x224x224x3xf32>, // input
|
||||
%arg1: tensor<64xf32>, // batch_norm args
|
||||
%arg2: tensor<256xf32>, // batch_norm args
|
||||
%arg3: tensor<7x7x3x64xf32>, // conv filter #0
|
||||
%arg4: tensor<1x1x64x256xf32> // conv filter #1
|
||||
) -> tensor<?x256xf32> {
|
||||
|
||||
// This is a simplified ResNet layer that gets input in NHWC format, converts
|
||||
// it to NCHW before padding, and does all computations in NCHW (this is the
|
||||
// default setup for ResNet model trained in fp32 on GPU).
|
||||
//
|
||||
// To be able to use Tensor Cores on latest NVIDIA GPUs this model has to be
|
||||
// converted to NHWC data format.
|
||||
|
||||
// Padding in spatial dimension (NCHW)
|
||||
%0 = "tf.Const"() {value = dense<[[0, 0], [0, 0], [3, 3], [3, 3]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32>
|
||||
|
||||
// Reduce over spatial dimensions (NCHW)
|
||||
%1 = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
|
||||
// Transpose input: NHWC -> NCHW
|
||||
%2 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
%3 = "tf.Transpose"(%arg0, %2) : (tensor<?x224x224x3xf32>, tensor<4xi32>) -> tensor<?x3x224x224xf32>
|
||||
|
||||
// Pad spatial dimensions.
|
||||
%4 = "tf.Pad"(%3, %0) : (tensor<?x3x224x224xf32>, tensor<4x2xi32>) -> tensor<?x3x230x230xf32>
|
||||
|
||||
// Shuffled paddings.
|
||||
// CHECK: %[[PADDINGS:[0-9]*]] = "tf.Const"(){{.*}}[0, 0], [3, 3], [3, 3], [0, 0]
|
||||
|
||||
// Pad input with new paddings.
|
||||
// CHECK: %[[PAD:[0-9]*]] = "tf.Pad"(%arg0, %[[PADDINGS]])
|
||||
// CHECK-SAME: (tensor<?x224x224x3xf32>, tensor<4x2xi32>) -> tensor<?x230x230x3xf32>
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
// Convolution layer #0.
|
||||
// ------------------------------------------------------------------------ //
|
||||
%5 = "tf.Conv2D"(%4, %arg3)
|
||||
{
|
||||
data_format = "NCHW",
|
||||
dilations = [1, 1, 1, 1],
|
||||
explicit_paddings = [],
|
||||
padding = "VALID",
|
||||
strides = [1, 1, 2, 2]
|
||||
} : (tensor<?x3x230x230xf32>, tensor<7x7x3x64xf32>) -> tensor<?x64x112x112xf32>
|
||||
|
||||
// CHECK: %[[CONV0:[0-9]*]] = "tf.Conv2D"
|
||||
// CHECK-SAME %[[PAD]]
|
||||
// CHECK-SAME: data_format = "NHWC"
|
||||
// CHECK-SAME: strides = [1, 2, 2, 1]
|
||||
|
||||
%6, %batch_mean, %batch_variance, %reserved_1, %reserved_2, %reserved_3 =
|
||||
"tf.FusedBatchNormV3"(%5, %arg1, %arg1, %arg1, %arg1)
|
||||
{
|
||||
data_format = "NCHW",
|
||||
epsilon = 1.001000e-05 : f32,
|
||||
is_training = false
|
||||
} : (tensor<?x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
|
||||
-> (tensor<?x64x112x112xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>)
|
||||
|
||||
// CHECK: "tf.FusedBatchNormV3"
|
||||
// CHECK-SAME: data_format = "NHWC"
|
||||
|
||||
%7 = "tf.Relu"(%6) : (tensor<?x64x112x112xf32>) -> tensor<?x64x112x112xf32>
|
||||
%8 = "tf.MaxPool"(%7)
|
||||
{
|
||||
data_format = "NCHW",
|
||||
ksize = [1, 1, 3, 3],
|
||||
padding = "SAME",
|
||||
strides = [1, 1, 2, 2]
|
||||
} : (tensor<?x64x112x112xf32>) -> tensor<?x64x56x56xf32>
|
||||
|
||||
// CHECK: %[[MAX_POOL:[0-9]*]] = "tf.MaxPool"
|
||||
// CHECK-SAME: data_format = "NHWC"
|
||||
// CHECK-SAME: ksize = [1, 3, 3, 1]
|
||||
// CHECK-SAME: strides = [1, 2, 2, 1]
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
// Convolution layer #1.
|
||||
// ------------------------------------------------------------------------ //
|
||||
%9 = "tf.Conv2D"(%8, %arg4)
|
||||
{
|
||||
data_format = "NCHW",
|
||||
dilations = [1, 1, 1, 1],
|
||||
explicit_paddings = [],
|
||||
padding = "VALID",
|
||||
strides = [1, 1, 1, 1]
|
||||
} : (tensor<?x64x56x56xf32>, tensor<1x1x64x256xf32>) -> tensor<?x256x56x56xf32>
|
||||
|
||||
// CHECK: %[[CONV1:[0-9]*]] = "tf.Conv2D"(%[[MAX_POOL]], %arg4)
|
||||
// CHECK-SAME: data_format = "NHWC"
|
||||
|
||||
%10, %batch_mean_1, %batch_variance_1, %reserved_1_1, %reserved_1_2, %reserved_1_3 =
|
||||
"tf.FusedBatchNormV3"(%9, %arg2, %arg2, %arg2, %arg2)
|
||||
{
|
||||
data_format = "NCHW",
|
||||
epsilon = 1.001000e-05 : f32
|
||||
} : (tensor<?x256x56x56xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>)
|
||||
-> (tensor<?x256x56x56xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<*xf32>)
|
||||
|
||||
// CHECK: %[[BATCH_NORM1:[_a-z0-9]*]], {{.*}} = "tf.FusedBatchNormV3"
|
||||
// CHECK-SAME: %[[CONV1]]
|
||||
// CHECK-SAME: data_format = "NHWC"
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
// Convolution layer #2.
|
||||
// ------------------------------------------------------------------------ //
|
||||
%11 = "tf.Conv2D"(%8, %arg4)
|
||||
{
|
||||
data_format = "NCHW",
|
||||
dilations = [1, 1, 1, 1],
|
||||
explicit_paddings = [],
|
||||
padding = "VALID",
|
||||
strides = [1, 1, 1, 1]
|
||||
} : (tensor<?x64x56x56xf32>, tensor<1x1x64x256xf32>) -> tensor<?x256x56x56xf32>
|
||||
|
||||
// CHECK: %[[CONV2:[0-9]*]] = "tf.Conv2D"(%[[MAX_POOL]], %arg4)
|
||||
// CHECK-SAME: data_format = "NHWC"
|
||||
|
||||
%12, %batch_mean_2, %batch_variance_2, %reserved_2_1, %reserved_2_2, %reserved_2_3 =
|
||||
"tf.FusedBatchNormV3"(%11, %arg2, %arg2, %arg2, %arg2)
|
||||
{
|
||||
data_format = "NCHW",
|
||||
epsilon = 1.001000e-05 : f32
|
||||
} : (tensor<?x256x56x56xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>)
|
||||
-> (tensor<?x256x56x56xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<*xf32>)
|
||||
|
||||
// CHECK: %[[BATCH_NORM2:[_a-z0-9]*]], {{.*}} = "tf.FusedBatchNormV3"
|
||||
// CHECK-SAME: %[[CONV2]]
|
||||
// CHECK-SAME: data_format = "NHWC"
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
// Add results of convolution layers #1 and #2.
|
||||
// ------------------------------------------------------------------------ //
|
||||
|
||||
%14 = "tf.AddV2"(%10, %12) : (tensor<?x256x56x56xf32>, tensor<?x256x56x56xf32>) -> tensor<?x256x56x56xf32>
|
||||
%15 = "tf.Relu"(%14) : (tensor<?x256x56x56xf32>) -> tensor<?x256x56x56xf32>
|
||||
|
||||
// CHECK: %[[ADD:[0-9]*]] = "tf.AddV2"(%[[BATCH_NORM1]], %[[BATCH_NORM2]])
|
||||
// CHECK: %[[RELU:[0-9]*]] = "tf.Relu"(%[[ADD]])
|
||||
|
||||
// Reduce spatial dimensions
|
||||
%16 = "tf.Mean"(%15, %1) : (tensor<?x256x56x56xf32>, tensor<2xi32>) -> tensor<?x256xf32>
|
||||
|
||||
// Mean should compute reduction over NHWC spatial dimensions.
|
||||
// CHECK: %[[MEAN_DIMS:[0-9]*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>}
|
||||
// CHECK: %[[MEAN:[0-9]*]] = "tf.Mean"(%[[RELU]], %[[MEAN_DIMS]])
|
||||
// CHECK-SAME: (tensor<?x56x56x256xf32>, tensor<2xi32>) -> tensor<?x256xf32>
|
||||
// CHECK: return %[[MEAN]] : tensor<?x256xf32>
|
||||
|
||||
return %16 : tensor<?x256xf32>
|
||||
}
|
||||
|
@ -0,0 +1,194 @@
|
||||
// RUN: tf-opt %s -tf-parallel-execute-to-islands | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK-LABEL: func @check_regions_to_islands
|
||||
func @check_regions_to_islands() {
|
||||
tf_executor.graph {
|
||||
tf_executor.island() {
|
||||
"tf_device.parallel_execute"() ({
|
||||
tf_device.return
|
||||
},
|
||||
{
|
||||
tf_device.return
|
||||
}) {} : () -> ()
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[ISLAND_INPUT_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||
// CHECK-NEXT: tf_executor.yield
|
||||
// CHECK: %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_INPUT_CTL]]) {
|
||||
// CHECK: tf_executor.yield
|
||||
// CHECK: %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_INPUT_CTL]]) {
|
||||
// CHECK: tf_executor.yield
|
||||
// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) {
|
||||
// CHECK-NEXT: tf_executor.yield
|
||||
|
||||
|
||||
// CHECK-LABEL: func @check_regions_to_islands_with_inputs
|
||||
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
|
||||
func @check_regions_to_islands_with_inputs(%arg0 : tensor<i1>) {
|
||||
tf_executor.graph {
|
||||
%1:2 = tf_executor.island {
|
||||
%2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
|
||||
tf_executor.yield %2 : tensor<i1>
|
||||
}
|
||||
tf_executor.island() {
|
||||
"tf_device.parallel_execute"() ({
|
||||
%3 = "tf.opB"(%1#0) : (tensor<i1>) -> tensor<i1>
|
||||
tf_device.return %3 : tensor<i1>
|
||||
},
|
||||
{
|
||||
%5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
|
||||
tf_device.return %5 : tensor<i32>
|
||||
}) {} : () -> (tensor<i1>, tensor<i32>)
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor<i1>) -> tensor<i1>
|
||||
// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor<i1>
|
||||
// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CONTROL:[a-z_0-9]*]] = tf_executor.island {
|
||||
// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor<i1>
|
||||
// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"(%[[INPUT_0]]) : (tensor<i1>) -> tensor<i1>
|
||||
// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor<i1>
|
||||
// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%outputs_0) : (tensor<i1>) -> tensor<i32>
|
||||
// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor<i32>
|
||||
// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) {
|
||||
// CHECK-NEXT: tf_executor.yield
|
||||
|
||||
|
||||
// CHECK-LABEL: func @check_input_sink_island_forwards_control_inputs
|
||||
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
|
||||
func @check_input_sink_island_forwards_control_inputs(%arg0 : tensor<i1>) {
|
||||
tf_executor.graph {
|
||||
%1:2 = tf_executor.island {
|
||||
%2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
|
||||
tf_executor.yield %2 : tensor<i1>
|
||||
}
|
||||
%7 = tf_executor.ControlTrigger {}
|
||||
%8 = tf_executor.ControlTrigger {}
|
||||
tf_executor.island(%7, %8) {
|
||||
"tf_device.parallel_execute"() ({
|
||||
%3 = "tf.opB"(%1#0) : (tensor<i1>) -> tensor<i1>
|
||||
tf_device.return %3 : tensor<i1>
|
||||
},
|
||||
{
|
||||
%5 = "tf.opC"() : () -> tensor<i32>
|
||||
tf_device.return %5 : tensor<i32>
|
||||
}) {} : () -> (tensor<i1>, tensor<i32>)
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor<i1>) -> tensor<i1>
|
||||
// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor<i1>
|
||||
// CHECK: %[[CT_0:[0-9]*]] = tf_executor.ControlTrigger
|
||||
// CHECK: %[[CT_1:[0-9]*]] = tf_executor.ControlTrigger
|
||||
// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CONTROL:[a-z_0-9]*]] = tf_executor.island(%[[CT_0]], %[[CT_1]]) {
|
||||
// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor<i1>
|
||||
// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"(%[[INPUT_0]]) : (tensor<i1>) -> tensor<i1>
|
||||
// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor<i1>
|
||||
// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island(%[[INPUT_CONTROL]]) {
|
||||
// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"() : () -> tensor<i32>
|
||||
// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor<i32>
|
||||
// CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) {
|
||||
// CHECK-NEXT: tf_executor.yield
|
||||
|
||||
|
||||
// CHECK-LABEL: func @check_control_dep_added_when_region_does_not_have_inputs
|
||||
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<i1>)
|
||||
func @check_control_dep_added_when_region_does_not_have_inputs(%arg0 : tensor<i1>) {
|
||||
tf_executor.graph {
|
||||
%1:2 = tf_executor.island {
|
||||
%2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
|
||||
tf_executor.yield %2 : tensor<i1>
|
||||
}
|
||||
%7:3 = tf_executor.island() {
|
||||
%8:2 = "tf_device.parallel_execute"() (
|
||||
{
|
||||
%3 = "tf.opB"() : () -> tensor<i1>
|
||||
tf_device.return %3 : tensor<i1>
|
||||
},
|
||||
{
|
||||
%5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
|
||||
tf_device.return %5 : tensor<i32>
|
||||
}
|
||||
) {} : () -> (tensor<i1>, tensor<i32>)
|
||||
|
||||
tf_executor.yield %8#0, %8#1 : tensor<i1>, tensor<i32>
|
||||
}
|
||||
|
||||
tf_executor.island {
|
||||
"tf.opD"(%7#0, %7#1) : (tensor<i1>, tensor<i32>) -> ()
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor<i1>) -> tensor<i1>
|
||||
// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor<i1>
|
||||
// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||
// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor<i1>
|
||||
// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island(%[[INPUT_CTL]]) {
|
||||
// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"() : () -> tensor<i1>
|
||||
// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor<i1>
|
||||
// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%outputs_0) : (tensor<i1>) -> tensor<i32>
|
||||
// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor<i32>
|
||||
// CHECK: %{{.*}} = tf_executor.island {
|
||||
// CHECK-NEXT: tf_executor.yield %[[ISLAND_1_OUTPUT]], %[[ISLAND_2_OUTPUT]]
|
||||
|
||||
|
||||
// CHECK-LABEL: func @check_output_barrier_correctly_forwards_outputs
|
||||
func @check_output_barrier_correctly_forwards_outputs(%arg0 : tensor<i1>) -> tensor<i1> {
|
||||
%0 = tf_executor.graph {
|
||||
%1:2 = tf_executor.island {
|
||||
%2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
|
||||
tf_executor.yield %2 : tensor<i1>
|
||||
}
|
||||
%8:3 = tf_executor.island() {
|
||||
%7:2 = "tf_device.parallel_execute"() ({
|
||||
%3 = "tf.opB"() : () -> tensor<i1>
|
||||
tf_device.return %3 : tensor<i1>
|
||||
},
|
||||
{
|
||||
%5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
|
||||
tf_device.return %5 : tensor<i32>
|
||||
}) {} : () -> (tensor<i1>, tensor<i32>)
|
||||
tf_executor.yield %7#0, %7#1 : tensor<i1>, tensor<i32>
|
||||
}
|
||||
tf_executor.fetch %8#0 : tensor<i1>
|
||||
}
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK: %[[INPUT_A:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_A_OUTPUT:[a-z_0-9]*]] = "tf.opA"(%[[ARG_0]]) : (tensor<i1>) -> tensor<i1>
|
||||
// CHECK-NEXT: tf_executor.yield %[[OP_A_OUTPUT]] : tensor<i1>
|
||||
// CHECK: %[[INPUT_0:[a-z_0-9]*]], %[[INPUT_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||
// CHECK-NEXT: tf_executor.yield %[[INPUT_A]] : tensor<i1>
|
||||
// CHECK: %[[ISLAND_1_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island(%[[INPUT_CTL]]) {
|
||||
// CHECK-NEXT: %[[OP_B_OUTPUT:[a-z_0-9]*]] = "tf.opB"() : () -> tensor<i1>
|
||||
// CHECK: tf_executor.yield %[[OP_B_OUTPUT]] : tensor<i1>
|
||||
// CHECK: %[[ISLAND_2_OUTPUT:[a-z_0-9]*]], %{{.*}} = tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_C_OUTPUT:[a-z_0-9]*]] = "tf.opC"(%[[INPUT_0]]) : (tensor<i1>) -> tensor<i32>
|
||||
// CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor<i32>
|
||||
// CHECK: %[[OUTPUT_SINK_OUTPUT:[a-z_0-9]*]]:2, %[[OUTPUT_SINK_CTL:[a-z_0-9]*]] = tf_executor.island {
|
||||
// CHECK-NEXT: tf_executor.yield %[[ISLAND_1_OUTPUT]], %[[ISLAND_2_OUTPUT]] : tensor<i1>, tensor<i32>
|
@ -854,6 +854,78 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// Test valid tf.MatrixBandPart
|
||||
// CHECK-LABEL: func @testValidMatrixBandPartOp
|
||||
func @testValidMatrixBandPartOp(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
|
||||
return %0 : tensor<64x64xbf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test valid tf.MatrixBandPart
|
||||
// CHECK-LABEL: func @testValidMatrixBandPartOp3D
|
||||
func @testValidMatrixBandPartOp3D(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64x64xbf16> {
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64x64xbf16>
|
||||
return %0 : tensor<64x64x64xbf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test valid tf.MatrixBandPart
|
||||
// CHECK-LABEL: func @testValidMatrixBandPartOpUnranked
|
||||
func @testValidMatrixBandPartOpUnranked(%arg0: tensor<*xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<*xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
|
||||
return %0 : tensor<*xbf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid tf.MatrixBandPart
|
||||
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
|
||||
// expected-error @+1 {{op failed to verify that all of {input, band} have same type}}
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
|
||||
return %0 : tensor<64x64xbf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid tf.MatrixBandPart
|
||||
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<*xbf16> {
|
||||
// expected-error @+1 {{op failed to verify that all of {input, band} have same type}}
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<*xbf16>
|
||||
return %0 : tensor<*xbf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid tf.MatrixBandPart
|
||||
func @testInvalidMatrixBandPartOp(%arg0: tensor<i64>, %arg1: tensor<64x64xi64>, %arg2: tensor<i64>) -> tensor<i64> {
|
||||
// expected-error @+1 {{op requires `input` to have rank of at least 2, but found 'tensor<i64>'}}
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<i64>, tensor<64x64xi64>, tensor<i64>) -> tensor<i64>
|
||||
return %0 : tensor<i64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid tf.MatrixBandPart
|
||||
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64xi64>, %arg1: tensor<32xi64>, %arg2: tensor<i64>) -> tensor<64x64xi64> {
|
||||
// expected-error @+1 {{op requires `num_lower` to have 0 dimensions, but found 'tensor<32xi64>'}}
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xi64>, tensor<32xi64>, tensor<i64>) -> tensor<64x64xi64>
|
||||
return %0 : tensor<64x64xi64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid tf.MatrixBandPart
|
||||
func @testInvalidMatrixBandPartOp(%arg0: tensor<64x64xi64>, %arg1: tensor<i64>, %arg2: tensor<32xi64>) -> tensor<64x64xi64> {
|
||||
// expected-error @+1 {{op requires `num_upper` to have 0 dimensions, but found 'tensor<32xi64>'}}
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xi64>, tensor<i64>, tensor<32xi64>) -> tensor<64x64xi64>
|
||||
return %0 : tensor<64x64xi64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// tf.{|Stateful}PartitionedCall
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -47,7 +47,7 @@ class TestModule(tf.Module):
|
||||
# CHECK: func {{@[a-zA-Z_0-9]+}}(
|
||||
# CHECK-SAME: %arg0: tensor<f32> {tf_saved_model.index_path = [0]},
|
||||
# CHECK-SAME: %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @[[VAR]]},
|
||||
# CHECK-SAME: %arg2: tensor<f32> {tf_saved_model.bound_input = @[[CONST]]}) -> (
|
||||
# CHECK-SAME: %arg2: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @[[CONST]]}) -> (
|
||||
# CHECK-SAME: tensor<f32> {tf_saved_model.index_path = []})
|
||||
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["some_function"]
|
||||
@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
|
||||
|
@ -1,53 +0,0 @@
|
||||
// RUN: tf-opt -tf-saved-model-inline-global-tensors -split-input-file %s | FileCheck %s --dump-input=fail
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Simple case of inlining.
|
||||
|
||||
// CHECK-NOT: tf_saved_model.global_tensor
|
||||
"tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<1.0> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f()
|
||||
func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @c})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
// CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Do not inline mutable global tensors.
|
||||
|
||||
// CHECK: tf_saved_model.global_tensor
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<1.0> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
// CHECK-NOT: tf.Const
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Sanity check handling of non-bound inputs.
|
||||
// The pass shouldn't do anything in this case.
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<f32> {tf_saved_model.index_path = [0]})
|
||||
func @f(%arg0: tensor<f32> {tf_saved_model.index_path = [0]})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
// CHECK-NOT: tf.Const
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// TODO: have an arg that isn't a bound input.
|
@ -25,7 +25,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
// CHECK: func @__concrete_function_run_computation
|
||||
func @__concrete_function_run_computation(
|
||||
%arg0: tensor<f32> {tf_saved_model.index_path = [0, "foo"]},
|
||||
%arg1: tensor<1x64xf32> {tf_saved_model.bound_input = @some_constant},
|
||||
%arg1: tensor<!tf.resource<tensor<1x64xf32>>> {tf_saved_model.bound_input = @some_constant},
|
||||
%arg2: tensor<!tf.resource<tensor<?x64xf32>>> {tf_saved_model.bound_input = @some_variable}
|
||||
) -> (
|
||||
tensor<f32> {tf_saved_model.index_path = [0, "bar"]}
|
||||
|
@ -244,7 +244,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
|
||||
// expected-error@+1 {{mutable bound input with type 'tensor<f32>' expected to have type 'tensor<!tf.resource<tensor<?xf32>>>'}}
|
||||
// expected-error@+1 {{bound input with type 'tensor<f32>' expected to have type 'tensor<!tf.resource<tensor<?xf32>>>'}}
|
||||
func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
return
|
||||
@ -253,18 +253,6 @@ module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
"tf_saved_model.global_tensor"() { sym_name = "v", type = tensor<1xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
|
||||
// expected-error@+1 {{bound input for immutable 'tf_saved_model.global_tensor' must match the global tensor's type}}
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<1xf32>>> {tf_saved_model.bound_input = @v})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// expected-error@+1 {{'type' attribute for immutable 'tf_saved_model.global_tensor' should have a static shape}}
|
||||
|
@ -6,19 +6,16 @@
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Basic test of freezing.
|
||||
// Test case: Basic test of marking immutable.
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-NOT: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
// CHECK-NOT: tf.ReadVariableOp
|
||||
%val = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
// CHECK: return %arg0
|
||||
return %val : tensor<f32>
|
||||
}
|
||||
|
||||
@ -28,18 +25,16 @@ module attributes {tf_saved_model.semantics} {
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Don't freeze if the variable is mutated.
|
||||
// Test case: Don't mark immutable if the variable is mutated.
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
|
||||
// CHECK: tf.AssignVariableOp
|
||||
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -50,14 +45,13 @@ module attributes {tf_saved_model.semantics} {
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Don't freeze if the variable is exported.
|
||||
// Test case: Don't mark immutable if the variable is exported.
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", tf_saved_model.exported_names = ["v"], type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
%val = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
@ -71,7 +65,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Check that a non-bound input is not modified.
|
||||
// Test case: Check that a non-bound input is left unchanged.
|
||||
|
||||
// CHECK: func @g
|
||||
func @g(%arg0: tensor<f32> {tf_saved_model.index_path = [0]}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
@ -86,14 +80,16 @@ module attributes {tf_saved_model.semantics} {
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Check that an immutable bound input isn't modified.
|
||||
// Test case: Check that no change is made for a global tensor that is already
|
||||
// immutable.
|
||||
|
||||
"tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @h(%arg0: tensor<f32> {tf_saved_model.bound_input = @c})
|
||||
func @h(%arg0: tensor<f32> {tf_saved_model.bound_input = @c}) -> (tensor<f32> {tf_saved_model.index_path = []})
|
||||
// CHECK: func @h(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @c})
|
||||
func @h(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @c})
|
||||
attributes {tf_saved_model.exported_names = ["h"]} {
|
||||
return %arg0 : tensor<f32>
|
||||
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
@ -134,7 +130,7 @@ module attributes {tf_saved_model.semantics} {
|
||||
"tf_saved_model.global_tensor"() { sym_name = "c", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: func @f()
|
||||
func @f(%arg0: tensor<f32> {tf_saved_model.bound_input = @c})
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @c})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
return
|
||||
}
|
||||
|
@ -102,18 +102,19 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
func @main(%arg0: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg2: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"}) {
|
||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg4: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg5: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"}) {
|
||||
%0 = "tf.Const"() {value = dense<100> : tensor<i32>} : () -> tensor<i32>
|
||||
%1:5 = "tf.While"(%0, %arg0, %arg1, %arg2, %arg3)
|
||||
{T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE",
|
||||
"tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE",
|
||||
"tfdtype$DT_RESOURCE"], body = @while_body_7560,
|
||||
cond = @while_cond_7550, device = "", is_stateless = false,
|
||||
output_shapes = ["tfshape$", "tfshape$", "tfshape$", "tfshape$", "tfshape$"]}
|
||||
%1:7 = "tf.While"(%0, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
|
||||
{body = @while_body_7560,
|
||||
cond = @while_cond_7550, device = "", is_stateless = false}
|
||||
: (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
|
||||
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>)
|
||||
return
|
||||
}
|
||||
// CHECK: func @while_body_7560
|
||||
@ -122,9 +123,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
|
||||
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg5: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg6: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"})
|
||||
-> (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) {
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>) {
|
||||
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%2:2 = "tf._TPUCompileMlir"() {
|
||||
@ -133,27 +137,33 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
"tf.TPUCompileSucceededAssert"(%2#0) : (tensor<!tf.string>) -> ()
|
||||
%new_var = "tf._UnknownOp0_"(%arg3) : (tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) -> tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
|
||||
%id0 = "tf.Identity"(%arg3) : (tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) -> tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
|
||||
"tf._Unknown_"(%id0) : (tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) -> ()
|
||||
%newvar = "tf._SomeOp"() : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
[%new_var, %arg4] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
[%arg3, %arg4] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
[%newvar, %arg6] as %arg32: tensor<*x!tf.resource<tensor<f32>>>)
|
||||
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
|
||||
// %arg30 is used in the cond function, and %arg31 is not pass-through of
|
||||
// while inputs, so neither should be formatted.
|
||||
"tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %2#1)
|
||||
// %arg30 is used in the cond function, %arg31 has other uses (%id0), and
|
||||
// %arg32 is not a pass-through.
|
||||
"tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %arg32, %2#1)
|
||||
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<!tf.string>) -> ()
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return %1, %arg1, %arg2, %arg3, %arg4 : tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
return %1, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6 : tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>,
|
||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
|
||||
tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>
|
||||
}
|
||||
// CHECK-LABEL: func @while_cond_7550
|
||||
func @while_cond_7550(%arg0: tensor<i32>,
|
||||
%arg1: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg2: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg3: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"})
|
||||
%arg4: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>> {tf.device = "/device:TPU:1"},
|
||||
%arg5: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:0"},
|
||||
%arg6: tensor<*x!tf.resource<tensor<f32>>> {tf.device = "/device:TPU:1"})
|
||||
-> tensor<i1> {
|
||||
%0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.GreaterEqual"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
|
@ -891,35 +891,3 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests simple case of launch_func on TPU with replication with multiple logical cores.
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
|
||||
// CHECK-LABEL: func @replicated_tpu_launch_func_with_multiple_logical_cores
|
||||
func @replicated_tpu_launch_func_with_multiple_logical_cores(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK-SAME: ([%[[A_OUTPUT]], %[[ARG_0]]] as %[[RI_0:[a-z0-9]*]]: tensor<?xi32>)
|
||||
// CHECK-SAME: n = 2
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"
|
||||
// CHECK: "tf_device.parallel_execute"()
|
||||
// CHECK-NEXT: %[[LAUNCH:[0-9]*]] = "tf_device.launch"() ( {
|
||||
// CHECK-NEXT: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[ARG_1]], %[[COMPILE_OUTPUT]]#1)
|
||||
// CHECK: %[[LAUNCH:[0-9]*]] = "tf_device.launch"() ( {
|
||||
// CHECK-NEXT: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[ARG_1]], %[[COMPILE_OUTPUT]]#1)
|
||||
%2 = "tf_device.launch_func"(%ri_0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
%2 = "tf.C"(%1#1) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
}
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
// `tf_device.launch` with equivalent `tf_device.launch_func` operations.
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
// Here, the element type can be any integer or float type. But, note that only
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Visitors.h" // TF:llvm-project
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
|
||||
|
@ -31,7 +31,7 @@ limitations under the License.
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
// This transformation pass transforms functional control flow operations in the
|
||||
// standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user