Merge pull request #5 from tensorflow/master

Sync master with tensorflow
This commit is contained in:
andrewstevens-infineon 2020-05-27 15:23:09 +02:00 committed by GitHub
commit d221e82eb4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1169 changed files with 85849 additions and 35354 deletions

View File

@ -143,6 +143,11 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_mkl_dnn_v1_only=true
build:mkl -c opt
# config to build OneDNN backend with a user specified threadpool.
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
build:mkl_threadpool -c opt
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
@ -235,10 +240,15 @@ build:c++17 --cxxopt=-std=c++1z
build:c++17 --cxxopt=-stdlib=libc++
build:c++1z --config=c++17
# Enable using platform specific build settings
# Enable using platform specific build settings, except when cross-compiling for
# mobile platforms.
build --enable_platform_specific_config
build:android --noenable_platform_specific_config
build:ios --noenable_platform_specific_config
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
build:android --copt=-w
build:ios --copt=-w
build:linux --copt=-w
build:macos --copt=-w
build:windows --copt=/w
@ -258,6 +268,10 @@ build:macos --define=INCLUDEDIR=$(PREFIX)/include
# TF_SYSTEM_LIBS do not work on windows.
# By default, build TF in C++ 14 mode.
build:android --cxxopt=-std=c++14
build:android --host_cxxopt=-std=c++14
build:ios --cxxopt=-std=c++14
build:ios --host_cxxopt=-std=c++14
build:linux --cxxopt=-std=c++14
build:linux --host_cxxopt=-std=c++14
build:macos --cxxopt=-std=c++14

87
.github/bot_config.yml vendored Normal file
View File

@ -0,0 +1,87 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#
# THIS IS A GENERATED DOCKERFILE.
#
# This file was assembled from multiple pieces, whose use is documented
# throughout. Please refer to the TensorFlow dockerfiles documentation
# for more information.
# A list of assignees
assignees:
- amahendrakar
- ravikyram
- Saduf2019
# A list of assignees for compiler folder
compiler_assignees:
- joker-eph
# Cuda Comment
cuda_comment: >
From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:
* For TF-GPU - See point 1
* For TF-CPU - See point 2
-----------------------------------------------------------------------------------------------
**1. Installing TensorFlow-GPU (TF) prebuilt binaries**
Make sure you are using compatible TF and CUDA versions.
Please refer following TF version and CUDA version compatibility table.
| TF | CUDA |
| :-------------: | :-------------: |
| 2.1.0 - 2.2.0 | 10.1 |
| 1.13.1 - 2.0 | 10.0 |
| 1.5.0 - 1.12.0 | 9.0 |
* If you have above configuration and using _**Windows**_ platform -
* Try adding the CUDA, CUPTI, and cuDNN installation directories to the %PATH% environment variable.
* Refer [windows setup guide](https://www.tensorflow.org/install/gpu#windows_setup).
* If you have above configuration and using _**Ubuntu/Linux**_ platform -
* Try adding the CUDA, CUPTI, and cuDNN installation directories to the $LD_LIBRARY_PATH environment variable.
* Refer [linux setup guide](https://www.tensorflow.org/install/gpu#linux_setup).
* If error still persists then, apparently your CPU model does not support AVX instruction sets.
* Refer [hardware requirements](https://www.tensorflow.org/install/pip#hardware-requirements).
-----------------------------------------------------------------------------------------------
**2. Installing TensorFlow (TF) CPU prebuilt binaries**
*TensorFlow release binaries version 1.6 and higher are prebuilt with AVX instruction sets.*
Therefore on any CPU that does not have these instruction sets, either CPU or GPU version of TF will fail to load.
Apparently, your CPU model does not support AVX instruction sets. You can still use TensorFlow with the alternatives given below:
* Try Google Colab to use TensorFlow.
* The easiest way to use TF will be to switch to [google colab](https://colab.sandbox.google.com/notebooks/welcome.ipynb#recent=true). You get pre-installed latest stable TF version. Also you can use ```pip install``` to install any other preferred TF version.
   * It has an added advantage since you can easily switch to different hardware accelerators (cpu, gpu, tpu) as per the task.
* All you need is a good internet connection and you are all set.
* Try to build TF from sources by changing CPU optimization flags.
*Please let us know if this helps.*
windows_comment: >
From the stack trace it looks like you are hitting windows path length limit.
* Try to disable path length limit on Windows 10.
* Refer [disable path length limit instructions guide.](https://mspoweruser.com/ntfs-260-character-windows-10/)
Please let us know if this helps.

View File

@ -142,6 +142,7 @@ Build Type | Status
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)
* [TensorFlow Blog](https://blog.tensorflow.org)
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
* [TensorFlow Twitter](https://twitter.com/tensorflow)

View File

@ -1,3 +1,41 @@
# Release 2.3.0
## Breaking Changes
* `tf.image.extract_glimpse` has been updated to correctly process the case
where `centered=False` and `normalized=False`. This is a breaking change as
the output is different from (incorrect) previous versions. Note this
breaking change only impacts `tf.image.extract_glimpse` and
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
  existing C++ kernel `ExtractGlimpse` does not change as well, so saved
models will not be impacted.
# Release 2.1.1
## Bug Fixes and Other Changes
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
* Fixes a versioning bug which causes Keras layers from TF 1.x to be used instead of those from TF 2.x
# Release 2.0.2
## Bug Fixes and Other Changes
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
# Release 1.15.3
## Bug Fixes and Other Changes
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
# Release 2.2.0
TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update).

View File

@ -64,7 +64,7 @@ your model, and we recommend you run the TensorFlow process in a sandbox.
It is possible to write models that are secure in a sense that they can safely
process untrusted inputs assuming there are no bugs. There are two main reasons
to not rely on this: first, it is easy to write models which must not be exposed
to not rely on this: First, it is easy to write models which must not be exposed
to untrusted inputs, and second, there are bugs in any software system of
sufficient complexity. Letting users control inputs could allow them to trigger
bugs either in TensorFlow or in dependent libraries.
@ -149,7 +149,7 @@ attack (or worse). Because TensorFlow behaves correctly, this is not a
vulnerability in TensorFlow (although it would be a vulnerability of this
hypothetical system).
As a general rule, it is incorrect behavior for Tensorflow to access memory it
As a general rule, it is incorrect behavior for TensorFlow to access memory it
does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to
such behaviors constitute a vulnerability.

View File

@ -114,6 +114,14 @@ http_archive(
],
)
http_archive(
name = "person_detect_data",
sha256 = "170542270da256994ce24d1e357f6e84a54fdaf7d28ff2b74725a40b70b082cf",
urls = [
"https://storage.googleapis.com/download.tensorflow.org/data/tf_lite_micro_person_data_grayscale_2020_05_24.zip",
],
)
# Required for dependency @com_github_grpc_grpc
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")

View File

@ -1368,8 +1368,13 @@ def main():
# environment variables.
environ_cp = dict(os.environ)
current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION,
_TF_MAX_BAZEL_VERSION)
try:
current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION,
_TF_MAX_BAZEL_VERSION)
except subprocess.CalledProcessError as e:
print("Error checking bazel version: ", e.output.decode('UTF-8').strip())
raise e
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
reset_tf_configure_bazelrc()
@ -1387,7 +1392,6 @@ def main():
# Windows.
environ_cp['TF_DOWNLOAD_CLANG'] = '0'
environ_cp['TF_NEED_MPI'] = '0'
environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0'
if is_macos():
environ_cp['TF_NEED_TENSORRT'] = '0'

View File

@ -524,7 +524,10 @@ package_group(
],
)
package_group(name = "ndarray_tensor_allow_list")
package_group(
name = "ndarray_tensor_allow_list",
packages = ["//learning/pathways/..."],
)
# Packages that use composite tensors or dispatch.
# TODO(b/154762408) Remove this package group once it's no longer needed.

View File

@ -216,6 +216,7 @@ tf_cuda_library(
],
visibility = [
"//tensorflow/c:__subpackages__",
"//tensorflow/compiler/mlir/tensorflow/c:__subpackages__",
],
deps = select({
"//tensorflow:android": [
@ -394,8 +395,14 @@ tf_cuda_library(
deps = [
":tf_status",
":tf_status_internal",
"//tensorflow/core:lib",
],
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:lib",
],
}),
)
tf_cc_test(

View File

@ -325,205 +325,6 @@ TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
return ret;
}
TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
TF_Status* status) {
auto* opts = TFE_NewContextOptions();
// Reduce GPU memory allocation, and set appropriate config options for TFE
// context.
auto* config = TF_CreateConfig(
/*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
10);
TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
if (!status->status.ok()) {
CHECK(!config);
TFE_DeleteContextOptions(opts);
return nullptr;
}
auto* ctx = TFE_NewContextFromSession(opts, session, status);
TF_DeleteBuffer(config);
TFE_DeleteContextOptions(opts);
return ctx;
}
// TODO: retrieve the device string via TFE_ContextListDevices()
static const char DEFAULT_CPU_DEVICE[] =
"/job:localhost/replica:0/task:0/device:CPU:0";
static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType,
int tensor_id, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp(
TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp);
TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status);
if (!status->status.ok()) return nullptr;
// TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler.
TFE_OpSetAttrInt(queueOp.get(), "capacity", 1);
TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1);
auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id);
TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(),
shared_name.size());
TFE_OpSetAttrString(queueOp.get(), "container", "", 0);
// TODO: consider making this an unknown shape.
const int64_t* dims_ptr = nullptr;
int num_dims = 0;
TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims,
/*num_values*/ 0, status);
if (!status->status.ok()) return nullptr;
int num_retvals = 1;
TFE_TensorHandle* queue = nullptr;
TFE_Execute(queueOp.get(), &queue, &num_retvals, status);
if (!status->status.ok()) return nullptr;
CHECK_EQ(num_retvals, 1);
return queue;
}
static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType,
TFE_TensorHandle* queue, TFE_TensorHandle* tensor,
TF_Status* status) {
TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
if (!status->status.ok()) return;
TFE_OpAddInput(op, queue, status);
if (!status->status.ok()) return;
TFE_OpAddInput(op, tensor, status);
if (!status->status.ok()) return;
TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1);
TFE_OpSetAttrInt(op, "timeout_ms", -1);
int num_retvals = 0;
TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status);
if (!status->status.ok()) return;
CHECK_EQ(num_retvals, 0);
}
static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx,
TF_DataType inputType,
TFE_TensorHandle* queue,
TF_Status* status) {
TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
if (!status->status.ok()) return nullptr;
TFE_OpAddInput(op, queue, status);
if (!status->status.ok()) return nullptr;
TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1);
TFE_OpSetAttrInt(op, "timeout_ms", -1);
TFE_TensorHandle* ret;
int num_retvals = 1;
TFE_Execute(op, &ret, &num_retvals, status);
if (!status->status.ok()) return nullptr;
CHECK_EQ(num_retvals, 1);
return ret;
}
TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id,
TF_DataType inputType,
TF_Status* status) {
assert(session);
VLOG(1) << "Dequeuing data tensor with id " << tensor_id;
auto ctx = TFE_CreateContextFromSession(session, status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
ctx, TFE_DeleteContext);
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
auto* ret = createTFEDequeue(ctx, inputType, queue, status);
return ret;
}
TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
TF_DataType inputType,
TF_Status* status) {
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
auto* ret = createTFEDequeue(ctx, inputType, queue, status);
return ret;
}
void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id,
TFE_TensorHandle* tensor, TF_Status* status) {
assert(session);
VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
auto ctx = TFE_CreateContextFromSession(session, status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
ctx, TFE_DeleteContext);
TF_DataType inputType = TFE_TensorHandleDataType(tensor);
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
createTFEEnqueue(ctx, inputType, queue, tensor, status);
}
void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
TFE_TensorHandle* tensor,
TF_Status* status) {
VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
TF_DataType inputType = TFE_TensorHandleDataType(tensor);
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
createTFEEnqueue(ctx, inputType, queue, tensor, status);
}
void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id,
TFE_TensorHandle* tensor, TF_Status* status) {
VLOG(1) << "Enqueuing variant tensor with id " << tensor_id;
auto ctx = TFE_CreateContextFromSession(session, status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
ctx, TFE_DeleteContext);
TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
if (!status->status.ok()) return;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status);
}
TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
TF_Status* status) {
VLOG(1) << "Dequeuing variant tensor with id " << tensor_id;
auto ctx = TFE_CreateContextFromSession(session, status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
ctx, TFE_DeleteContext);
TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
if (!status->status.ok()) return nullptr;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
queue_deleter(queue, TFE_DeleteTensorHandle);
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
}
void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
status->status = tensorflow::errors::Internal(errMsg);
}
@ -622,10 +423,9 @@ void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
const TF_DataType* values, int num_values) {
auto iter = builder->attr_names.insert(attr_name).first;
builder->Set(
(*iter).c_str(),
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
builder->Set(*iter, tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
reinterpret_cast<const tensorflow::DataType*>(values),
num_values));
}
void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,

View File

@ -146,48 +146,6 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
// Create a serialized tensorflow.ServerDef proto.
TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status);
// TODO: remove this API in favor of the next one.
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status);
// Creates from `session` a new eager context to run a graph function or
// sends/recvs, so that these concurrent TFE executions can share (via
// `session` and its associated device mgr) the same set of fifo queue resource
// ops, used for host<->TF tensor transfers. This way the sends/recvs calls and
// graph function execution can access the same fifo queue resource handles
// (associated with devices managed by the device manager, which can be obtained
// from `session`).
//
// TODO: Remove this function once we migrate away from using session.
TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession(
TF_Session* session, TF_Status* status);
// TODO: Retire this API in favor of the next one.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor(
TF_Session* session, int tensor_id, TF_DataType inputType,
TF_Status* status);
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(
TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status);
TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session,
int tensor_id,
TFE_TensorHandle* tensor,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx(
TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor,
TF_Status* status);
// TODO: consider folding the 2 APIs below into the ones above.
TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
int tensor_id,
TFE_TensorHandle* tensor,
TF_Status* status);
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
TF_Session* session, int tensor_id, TF_Status* status);
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
const char* errMsg);

View File

@ -144,6 +144,24 @@ cc_library(
],
)
cc_library(
name = "c_api_unified_internal",
hdrs = [
"c_api_unified_experimental_internal.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":c_api",
":c_api_experimental",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:tf_status",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:types",
],
)
cc_library(
name = "tensor_handle_interface",
hdrs = ["tensor_handle_interface.h"],
@ -319,6 +337,7 @@ tf_cuda_cc_test(
tags = [
"noguitar", # TODO(b/155445984): flaky
#"guitar",
"notap", # TODO(b/156981931): flaky
"multi_gpu",
],
deps = [
@ -349,7 +368,10 @@ tf_cuda_cc_test(
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = ["noasan"], # leaks gRPC server instances
tags = [
"noasan", # leaks gRPC server instances
"notsan", # b/157098283
],
deps = [
":c_api",
":c_api_experimental",
@ -357,10 +379,13 @@ tf_cuda_cc_test(
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:function_optimization_registry",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/strings",
@ -507,6 +532,7 @@ tf_cuda_cc_test(
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",

View File

@ -102,6 +102,15 @@ string DeviceName(const tensorflow::Device* d) {
}
#if !defined(IS_MOBILE_PLATFORM)
bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
const tensorflow::ServerDef& server_def) {
if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
return false;
}
return server_def.default_session_config().SerializeAsString() ==
context->session_options().config.SerializeAsString();
}
tensorflow::Status AddRemoteDevicesToMgr(
const std::vector<string>& added_remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
@ -469,10 +478,15 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server;
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
const tensorflow::DeviceMgr* device_mgr =
AreLocalDevicesCompatible(context, server_def)
? context->local_device_mgr()
: nullptr;
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
server_def, {device_mgr}, &new_server));
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
} else {
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
&curr_remote_workers));
@ -727,24 +741,6 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
tensorflow::GetDefaultCustomKernelCreator()));
}
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
TF_Session* sess, TF_Status* status) {
const tensorflow::DeviceMgr* device_mgr = nullptr;
status->status = sess->session->LocalDeviceManager(&device_mgr);
if (!status->status.ok()) return nullptr;
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr);
return tensorflow::wrap(new tensorflow::EagerContext(
opts->session_options.options,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
opts->device_placement_policy),
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
/*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator()));
}
void TFE_DeleteContext(TFE_Context* ctx) {
if (ctx == nullptr) {
return;
@ -899,9 +895,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
#if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::Status::OK();
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->SyncExecutors();
status->status = tensorflow::unwrap(ctx)->AsyncWait();
#endif // !IS_MOBILE_PLATFORM
}

View File

@ -41,7 +41,7 @@ tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost:", port)});
{i, tensorflow::strings::StrCat("localhost", ":", port)});
}
return server_def;
}
@ -430,4 +430,70 @@ TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
TestRemoteExecuteUpdateServerDefWithFailures(true);
}
void TestConnectToCluster(bool keep_localhost_for_first_connect) {
// Fail fast on GetStatus requests so we can get errors instead of timeout
// when updating cluster with non-existent worker
tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
const string first_name =
keep_localhost_for_first_connect ? "localhost" : "abc";
tensorflow::ServerDef server_def = GetServerDef(first_name, 1);
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
const string dev0_name = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
EXPECT_NE(var_handle0, nullptr);
tensorflow::Status status2;
EXPECT_EQ(tensorflow::unwrap(var_handle0)->DeviceName(&status2), dev0_name);
// Rename local device
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const string dev1_name =
absl::StrCat("/job:", first_name, "/replica:0/task:0/device:CPU:0");
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
EXPECT_NE(var_handle1, nullptr);
EXPECT_EQ(tensorflow::unwrap(var_handle1)->DeviceName(&status2), dev1_name);
// Another renaming of local device
const string second_name = "def";
server_def.set_job_name(second_name);
server_def.mutable_cluster()->mutable_job(0)->set_name(second_name);
(*server_def.mutable_cluster()->mutable_job(0)->mutable_tasks())[0] =
absl::StrCat(second_name, ":",
tensorflow::testing::PickUnusedPortOrDie());
serialized = server_def.SerializeAsString();
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const string dev2_name = "/job:def/replica:0/task:0/device:CPU:0";
TFE_TensorHandle* var_handle2 = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle2, nullptr);
EXPECT_EQ(tensorflow::unwrap(var_handle2)->DeviceName(&status2), dev2_name);
TFE_DeleteTensorHandle(var_handle0);
TFE_DeleteTensorHandle(var_handle1);
TFE_DeleteTensorHandle(var_handle2);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
tensorflow::unsetenv("GRPC_FAIL_FAST");
}
TEST(CAPI, ConnectToClusterLocalhostFirst) { TestConnectToCluster(false); }
TEST(CAPI, ConnectToClusterRenameFirst) { TestConnectToCluster(true); }
} // namespace

View File

@ -19,11 +19,16 @@ limitations under the License.
#include <utility>

#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace {
@ -574,6 +579,181 @@ TEST(CAPI, TestRemoteFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/true);
}
// Builds a FunctionDef named 'VariableAddFunction' that takes a DT_RESOURCE
// variable handle, reads it, adds the read value to itself on
// /job:localhost/task:1, and returns the sum through an Identity op pinned to
// /job:localhost/task:0. Returns the FunctionDef serialized as a string
// (suitable for TFE_ContextAddFunctionDef).
string VariableAddFunction() {
  tensorflow::FunctionDef def;
  // CHECK-fails if the textual proto below is malformed. The explicit device
  // strings here must stay in sync with the cluster layout set up by the
  // tests that run this function.
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      "    signature {"
      "      name: 'VariableAddFunction'"
      "      input_arg {"
      "        name: 'var0'"
      "        type: DT_RESOURCE"
      "      }"
      "      output_arg {"
      "        name: 'var0_value'"
      "        type: DT_FLOAT"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'read0'"
      "      op: 'ReadVariableOp'"
      "      input: 'var0'"
      "      attr {"
      "        key: 'dtype'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'add'"
      "      op: 'Add'"
      "      input: 'read0:value:0'"
      "      input: 'read0:value:0'"
      "      device: '/job:localhost/task:1/device:CPU:0'"
      "      attr {"
      "        key: 'T'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    node_def {"
      "      name: 'identity'"
      "      op: 'Identity'"
      "      input: 'add:z:0'"
      "      device: '/job:localhost/task:0/device:CPU:0'"
      "      attr {"
      "        key: 'T'"
      "        value {"
      "          type: DT_FLOAT"
      "        }"
      "      }"
      "    }"
      "    ret {"
      "      key: 'var0_value'"
      "      value: 'identity:output:0'"
      "    }",
      &def));
  return def.SerializeAsString();
}
// Function optimization pass that fails function instantiation whenever the
// graph contains a node whose name includes `error_node_` and whose requested
// device equals `error_device_` exactly. Used by the distributed-function
// tests below to simulate a graph-pass failure on a specific task.
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
 public:
  // `error_node`: substring to look for in node names.
  // `error_device`: exact requested-device string that triggers the error.
  // Sink parameters: taken by value and moved into the const members.
  FunctionErrorInjectionPass(string error_node, string error_device)
      : error_node_(std::move(error_node)),
        error_device_(std::move(error_device)) {}

  // Returns an Internal error if a matching node is found; OK otherwise.
  tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
                         const tensorflow::ConfigProto& config_proto,
                         std::unique_ptr<tensorflow::Graph>* graph,
                         tensorflow::FunctionLibraryDefinition* flib_def,
                         std::vector<std::string>* control_ret_node_names,
                         bool* control_rets_updated) override {
    // Inject failure to function instantiation if finding a node that contains
    // the given node name (error_node_) and requested device (error_device_).
    // `nodes()` yields pointers, so iterate with `const auto*` to make the
    // pointer copy explicit.
    for (const auto* node : graph->get()->nodes()) {
      if (node->name().find(error_node_) != string::npos &&
          node->requested_device() == error_device_) {
        return tensorflow::errors::Internal("Injected graph pass error.");
      }
    }
    return tensorflow::Status::OK();
  }

 private:
  const string error_node_;
  const string error_device_;
};
// Runs `VariableAddFunction` on a 3-task cluster (task:0 is the in-process
// client; tasks 1 and 2 are gRPC workers started here). With `inject_error`
// set, an optimization pass is registered that fails instantiation when it
// sees 'read0' placed on task:2, and TFE_Execute must surface TF_INTERNAL;
// otherwise the function must complete and return 2.0 + 2.0 = 4.0.
void TestDistributedFunctionCancellation(bool inject_error) {
  tensorflow::ServerDef server_def = GetServerDef(3);
  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  // Bring up the two worker tasks in-process.
  server_def.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());
  server_def.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());

  const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";

  if (inject_error) {
    // Inject a function optimization pass failure when it sees the 'read0' op
    // having a requested device `dev2_name`. During execution:
    //   * task:0 processes the main function `VariableAddFunction` and places
    //     the read0 op on task:2
    //   * task:0 partitions the main function with a subgraph containing read0
    //     sent to task:2
    //   * task:2 graph pass reports an error when it sees read0 with dev2_name
    // NOTE(review): `register_test_pass` goes out of scope at the end of this
    // block; assumes the pass remains registered for the process lifetime —
    // confirm against FunctionOptimizationPassRegistration's semantics.
    tensorflow::function_optimization_registration::
        FunctionOptimizationPassRegistration register_test_pass(
            std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name));
  }

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  // Connect the client (task:0) to the cluster.
  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Create the resource variable on task:2 and register the function body.
  TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
  EXPECT_NE(var_handle, nullptr);
  const string function_def = VariableAddFunction();
  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
                            status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_OpAddInput(func, var_handle, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(func, &retvals[0], &num_retvals, status);

  if (inject_error) {
    // The injected graph-pass failure must propagate back to the client.
    ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
  } else {
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    ASSERT_EQ(1, num_retvals);
    TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteTensorHandle(retvals[0]);
    float sum = 0;
    ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
    memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
    TF_DeleteTensor(t);
    // The variable holds 2.0; VariableAddFunction adds it to itself.
    ASSERT_EQ(sum, 4.0);
  }

  TFE_DeleteOp(func);
  TFE_DeleteTensorHandle(var_handle);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server1.release();
  worker_server2.release();
}
// Happy path: the distributed function executes with no injected failure.
TEST(CAPI, DistributedFunctionNoError) {
  TestDistributedFunctionCancellation(false);
}
// Error path: an injected graph-pass failure on task:2 must surface as
// TF_INTERNAL from TFE_Execute.
TEST(CAPI, DistributedFunctionCancelledOnError) {
  TestDistributedFunctionCancellation(true);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);

View File

@ -58,7 +58,7 @@ T* dyncast(S source) {
// GraphContext and vice-versa).
class AbstractTensor {
protected:
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor };
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
public:
@ -101,7 +101,7 @@ class AbstractFunction {
// on a given context, with the same or different input tensors.
class AbstractOp {
protected:
enum AbstractOpKind { kGraphOp, kEagerOp };
enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp };
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
public:
@ -129,7 +129,7 @@ class AbstractOp {
// eager implementation or to a graph implementation.
struct ExecutionContext {
protected:
enum ExecutionContextKind { kGraphContext, kEagerContext };
enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext };
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
public:

View File

@ -477,7 +477,8 @@ TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
TF_DeleteExecutionContext(eager_execution_ctx);
}
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef"));
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Values("graphdef", "mlir"));
} // namespace
} // namespace tensorflow

View File

@ -101,6 +101,9 @@ class AbstractContextInterface {
// Destroy the step resource container for a training step.
virtual void EndStep() = 0;
// Block until all pending nodes are finished.
virtual Status AsyncWait() = 0;
protected:
virtual ~AbstractContextInterface() {}
};

View File

@ -108,7 +108,7 @@ class CServerFactory : public ServerFactory {
delete_function_(delete_function),
rendezvous_builder_(rendezvous_builder) {}
Status NewServer(const ServerDef& server_def,
Status NewServer(const ServerDef& server_def, const Options& options,
std::unique_ptr<ServerInterface>* out_server) override {
TF_RETURN_IF_ERROR(CGrpcServer::Create(
server_def, init_function_, start_function_, stop_function_,

View File

@ -155,6 +155,7 @@ cc_library(
"saved_model_api_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
],
)

View File

@ -41,7 +41,7 @@ TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
if (!status->status.ok()) {
return nullptr;
}
return new TF_SavedModel{std::move(result)};
return tensorflow::wrap(result.release());
}
TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
@ -60,17 +60,19 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
if (!status->status.ok()) {
return nullptr;
}
return new TF_SavedModel{std::move(result)};
return tensorflow::wrap(result.release());
}
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }
void TF_DeleteSavedModel(TF_SavedModel* model) {
delete tensorflow::unwrap(model);
}
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
const char* function_path,
TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status =
model->saved_model->GetFunction(function_path, &result);
tensorflow::unwrap(model)->GetFunction(function_path, &result);
status->status.Update(get_function_status);
if (!get_function_status.ok()) {
return nullptr;
@ -82,7 +84,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
tensorflow::ConcreteFunction* result = nullptr;
tensorflow::Status get_function_status =
model->saved_model->GetSignatureDefFunction(signature_def_key, &result);
tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key,
&result);
status->status.Update(get_function_status);
if (!get_function_status.ok()) {
return nullptr;
@ -91,7 +94,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
}
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()};
return new TF_ConcreteFunctionList{
tensorflow::unwrap(model)->ListFunctions()};
}
} // end extern "C"

View File

@ -18,13 +18,18 @@ limitations under the License.
#include <memory>
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
// Internal structures used by the SavedModel C API. These are likely to change
// and should not be depended on.
struct TF_SavedModel {
std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
};
typedef struct TF_SavedModel TF_SavedModel;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::SavedModelAPI, TF_SavedModel)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_

View File

@ -20,7 +20,7 @@ load(
"tf_cc_test",
"tf_copts",
)
load("//tensorflow:tensorflow.bzl", "tfcompile_extra_flags")
load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
def tf_library(
name,
@ -42,7 +42,8 @@ def tf_library(
mlir_components = "None",
deps = None,
tags = []):
"""Runs tfcompile to compile a TensorFlow graph into executable code.
"""Runs tfcompile to compile a TensorFlow graph into executable code with fast
math enabled on cpu.
Given an invocation of tf_library(name="foo", ...), generates the following
build targets:
@ -187,7 +188,9 @@ def tf_library(
# `find` on such an object.
need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1
flags = tfcompile_extra_flags() + flags
target_cpu = tfcompile_target_cpu()
extra_flags = "--target_cpu=" + target_cpu + " " if target_cpu else " "
flags = extra_flags + flags
if enable_xla_hlo_profiling:
profiling_flag = "--xla_hlo_profile"
@ -207,6 +210,15 @@ def tf_library(
srcs.append(debug_info)
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
default_fast_math_xla_flags = ("XLA_FLAGS='" +
"--xla_cpu_enable_fast_math=true " +
"--xla_cpu_fast_math_honor_nans=false " +
"--xla_cpu_fast_math_honor_infs=false " +
"--xla_cpu_fast_math_honor_functions=false " +
"--xla_cpu_fast_math_honor_division=false " +
"--xla_cpu_enable_fast_min_max=true " +
"$${XLA_FLAGS:-}' ")
native.genrule(
name = ("gen_" + name),
srcs = srcs,
@ -216,6 +228,7 @@ def tf_library(
function_object_file,
],
cmd = (
default_fast_math_xla_flags +
"CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
@ -256,6 +269,7 @@ def tf_library(
session_module_pb,
],
cmd = (
default_fast_math_xla_flags +
"CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +

View File

@ -67,6 +67,8 @@ int main(int argc, char** argv) {
flags.entry_point = "entry";
flags.debug_info_path_begin_marker = "";
// Note that tfcompile.bzl's tf_library macro sets fast math flags as that is
// generally the preferred case.
std::vector<tensorflow::Flag> flag_list;
AppendMainFlags(&flag_list, &flags);
xla::AppendDebugOptionsFlags(&flag_list);

View File

@ -505,6 +505,7 @@ cc_library(
name = "shape_inference",
srcs = ["shape_inference.cc"],
hdrs = ["shape_inference.h"],
visibility = [":friends"],
deps = [
":shape_inference_helpers",
"//tensorflow/compiler/xla:statusor",

View File

@ -2034,6 +2034,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"TensorArraySplitV3",
"TensorArrayV3",
"TensorArrayWriteV3",
"TensorListConcatV2",
"TensorListElementShape",
"TensorListFromTensor",
"TensorListGather",
@ -2043,6 +2044,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
"TensorListPushBack",
"TensorListReserve",
"TensorListSetItem",
"TensorListSplit",
"TensorListStack",
"TensorScatterAdd",
"TensorScatterSub",

View File

@ -104,6 +104,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
],
alwayslink = 1,

View File

@ -260,6 +260,41 @@ cc_library(
],
)
cc_library(
name = "tftext_utils",
srcs = [
"utils/tftext_utils.cc",
],
hdrs = [
"utils/tftext_utils.h",
],
copts = ["-std=c++14"],
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
tf_cc_test(
name = "tftext_utils_test",
size = "small",
srcs = ["utils/lstm_utils_test.cc"],
deps = [
":lstm_utils",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "stateful_ops_utils",
srcs = [
@ -320,6 +355,7 @@ cc_library(
":lstm_utils",
":stateful_ops_utils",
":tensorflow_lite",
":tftext_utils",
":validators",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",

View File

@ -799,11 +799,6 @@ Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions(
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
std::string node_def_str;
if (!node_def.SerializeToString(&node_def_str)) {
return emitError(loc, "failed to serialize tensorflow node_def"),
llvm::None;
}
auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc);
return builder_.CreateVector(flex_builder->GetBuffer());
}
@ -813,9 +808,13 @@ Translator::CreateFlexBuilderWithNodeAttrs(
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
auto flex_builder = absl::make_unique<flexbuffers::Builder>();
size_t map_start = flex_builder->StartMap();
for (const auto& pair : node_def.attr()) {
using Item = std::pair<std::string, ::tensorflow::AttrValue>;
std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end());
std::sort(attrs.begin(), attrs.end(),
[](Item& p1, Item& p2) -> bool { return p1.first < p2.first; });
for (const Item& pair : attrs) {
const char* key = pair.first.c_str();
const auto& attr = pair.second;
const ::tensorflow::AttrValue& attr = pair.second;
switch (attr.value_case()) {
case ::tensorflow::AttrValue::kS:
flex_builder->String(key, attr.s());

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"

View File

@ -254,6 +254,14 @@ class TFL_TFOperandTypesWithSameBits<int i, int j, int num> :
Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;
class TFL_OperandIsNoneOrHasRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
Or<[
CPred<"$_op.getOperand(" # n # ").getType().isa<NoneType>()">,
TFL_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
").getType().cast<ShapedType>().getRank() == " # m>]>>;
class TFL_OperandIsNoneOrHasRankAtMost<int n, int m> :
PredOpTrait<"operand " # n # " is at most " # m # "-D",
Or<[
@ -285,14 +293,18 @@ def TFL_FloatNonNegative : AttrConstraint<
CPred<"!$_self.cast<FloatAttr>().getValue().isNegative()">,
"whose value is non-negative">;
def TFL_BoolTrue: AttrConstraint<
def TFL_BoolTrue : AttrConstraint<
CPred<"$_self.cast<BoolAttr>().getValue()">,
"whose value is true">;
def TFL_BoolFalse: AttrConstraint<
def TFL_BoolFalse : AttrConstraint<
CPred<"!$_self.cast<BoolAttr>().getValue()">,
"whose value is false">;
class TFL_StringEqualsTo<string value> : AttrConstraint<
CPred<"$_self.cast<StringAttr>().getValue() == \"" # value # "\"">,
"whose value equals to '" # value # "'">;
// This is a quantization-aware version of TCresVTEtIsSameAsOp
class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
TCOpResIsShapedTypePred<i, j>,
@ -1892,7 +1904,10 @@ def TFL_OneHotOp : TFL_Op<"one_hot", [NoSideEffect]> {
let hasOptions = 1;
}
def TFL_RoundOp: TFL_Op<"round", [NoSideEffect, SameOperandsAndResultType]> {
def TFL_RoundOp: TFL_Op<"round", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType]> {
let summary = "Round operator";
let description = [{
@ -1909,7 +1924,14 @@ Rounds the values of a tensor to the nearest integer, element-wise.
}
def TFL_SliceOp : TFL_Op<"slice", [
NoSideEffect, SameOperandsAndResultsScale, TFL_GpuTargetOp]> {
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoSideEffect,
SameOperandsAndResultsScale,
TFL_OperandHasRankAtMost<0, 4>,
TFL_OperandHasRankAtMost<1, 1>,
TFL_OperandHasRankAtMost<2, 1>,
TFL_GpuTargetOp]> {
let summary = "Return a slice from 'input'.";
let description = [{
@ -1927,13 +1949,13 @@ equivalent to setting:
}];
let arguments = (ins
AnyTensor:$input,
TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$input,
TFL_I32OrI64Tensor:$begin,
TFL_I32OrI64Tensor:$size
);
let results = (outs
AnyTensor:$output
TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$output
);
let verifier = [{ return Verify(*this); }];
@ -1961,7 +1983,10 @@ def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> {
}
def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
NoSideEffect, SameOperandsAndResultsScale]> {
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoSideEffect,
SameOperandsAndResultsScale]> {
let summary = "Min-reduction operator";
let description = [{
@ -1969,19 +1994,23 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
}];
let arguments = (ins
AnyTensor:$input,
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
TFL_I32Tensor:$axes,
BoolAttr:$keep_dims
);
let results = (outs AnyTensor);
let results = (outs
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
let hasOptions = 1;
let customOption = "ReducerOptions";
}
def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [
NoSideEffect, SameOperandsAndResultsScale]> {
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoSideEffect,
SameOperandsAndResultsScale]> {
let summary = "Max-reduction operator";
let description = [{
@ -1989,18 +2018,22 @@ def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [
}];
let arguments = (ins
AnyTensor:$input,
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
TFL_I32Tensor:$axes,
BoolAttr:$keep_dims
);
let results = (outs AnyTensor);
let results = (outs
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
let hasOptions = 1;
let customOption = "ReducerOptions";
}
def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> {
def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoSideEffect]> {
let summary = "Prod-reduction operator";
let description = [{
@ -2008,12 +2041,13 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> {
}];
let arguments = (ins
TFL_TensorOf<[F32, I8, I32, I64]>:$input,
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
TFL_I32Tensor:$axes,
BoolAttr:$keep_dims
);
let results = (outs AnyTensor);
let results = (outs
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
let hasOptions = 1;
let customOption = "ReducerOptions";
@ -2308,10 +2342,13 @@ def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> {
let hasFolder = 1;
}
def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
def TFL_ReluOp: TFL_Op<"relu", [
PredOpTrait<"x and y must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
let summary = "Relu operator";
let description = [{
@ -2319,9 +2356,9 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
x -> max(0, x)
}];
let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);
let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x);
let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);
let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y);
// This builder doesn't work with quantized type, so it can only be used by
// non-quantization tablegen patterns. Currently, it is used by the
@ -2335,10 +2372,13 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
];
}
def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
def TFL_Relu6Op: TFL_Op<"relu6", [
PredOpTrait<"x and y must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
let summary = "Relu6 operator";
let description = [{
@ -2346,9 +2386,9 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
x -> max(0, min(6, x))
}];
let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);
let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x);
let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);
let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y);
// This builder doesn't work with quantized type, so it can only be used by
// non-quantization tablegen patterns. Currently, it is used by the
@ -2362,9 +2402,12 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
];
}
def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultsScale]> {
def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [
PredOpTrait<"x and y must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultsScale]> {
let summary = "Relu1 operator";
let description = [{
@ -2372,9 +2415,9 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
x -> max(-1, min(1, x))
}];
let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);
let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x);
let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);
let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y);
// This builder doesn't work with quantized type, so it can only be used by
// non-quantization tablegen patterns. Currently, it is used by the
@ -2406,7 +2449,11 @@ def TFL_ReshapeOp: TFL_Op<"reshape", [
let hasFolder = 1;
}
def TFL_ReverseSequenceOp : TFL_Op<"reverse_sequence", [NoSideEffect]> {
def TFL_ReverseSequenceOp : TFL_Op<"reverse_sequence", [
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoSideEffect,
TFL_OperandHasRank<1, 1>]> {
let summary = "Reverses variable length slices.";
let description = [{
@ -2423,15 +2470,15 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension
}];
let arguments = (ins
TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8]>:$input,
TFL_TensorOf<[F32, I32, I64, QI16, QUI8, TFL_Quint8]>:$input,
TFL_I32OrI64Tensor:$seq_lengths,
I32Attr:$seq_dim,
I32Attr:$batch_dim
Confined<I32Attr, [IntNonNegative]>:$seq_dim,
Confined<I32Attr, [IntNonNegative]>:$batch_dim
);
let results = (outs
TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8]>:$output
TFL_TensorOf<[F32, I32, I64, QI16, QUI8, TFL_Quint8]>:$output
);
let hasOptions = 1;
@ -2439,6 +2486,7 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension
def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
SameOperandsAndResultType,
SameOperandsAndResultShape,
NoQuantizableResult,
TFL_GpuTargetOp]> {
let summary = "Reciprocal of square root operator";
@ -2463,7 +2511,7 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> {
let arguments = (ins AnyTensor:$input);
let results = (outs AnyTensor:$output);
let results = (outs TFL_TensorOf<[I32, I64]>:$output);
DerivedTypeAttr out_type = DerivedTypeAttr<[{
return getResult().getType().cast<TensorType>().getElementType();
@ -2472,9 +2520,11 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> {
let hasOptions = 1;
}
// TODO(jpienaar): Flesh this out.
def TFL_RangeOp: TFL_Op<"range", [NoSideEffect, TFL_OperandHasRank<0, 0>,
TFL_OperandHasRank<1, 0>, TFL_OperandHasRank<2, 0>,
def TFL_RangeOp: TFL_Op<"range", [
NoSideEffect,
TFL_OperandHasRank<0, 0>,
TFL_OperandHasRank<1, 0>,
TFL_OperandHasRank<2, 0>,
PredOpTrait<"operands and output must have same element type",
And<[TCresVTEtIsSameAsOp<0, 0>, TCresVTEtIsSameAsOp<0, 1>,
TCresVTEtIsSameAsOp<0, 2>]>>]> {
@ -2486,17 +2536,20 @@ def TFL_RangeOp: TFL_Op<"range", [NoSideEffect, TFL_OperandHasRank<0, 0>,
}];
let arguments = (ins
AnyTensor:$start,
AnyTensor:$limit,
AnyTensor:$delta);
TFL_TensorOf<[I32, F32]>:$start,
TFL_TensorOf<[I32, F32]>:$limit,
TFL_TensorOf<[I32, F32]>:$delta);
let results = (outs AnyTensor:$result);
let results = (outs TFL_TensorOf<[I32, F32]>:$result);
let hasFolder = 1;
}
def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
[NoSideEffect, TFL_OperandHasRank<1,1>]> {
def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoSideEffect,
TFL_OperandHasRank<1, 1>]> {
let summary = "ReverseV2 Operator";
let description = [{
@ -2518,18 +2571,18 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
let arguments = (
ins
TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$input,
TFL_TensorOf<[I32, I64]>:$axis
TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, TFL_Quint8, I1]>:$input,
TFL_I32Tensor:$axis
);
let results = (outs
TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$output
);
TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, TFL_Quint8, I1]>:$output);
}
// Select has many instances in TF models where one or more of its operands
// are unranked. Therefore, we skip adding shape constraints here.
def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
def TFL_SelectOp : TFL_Op<"select", [
NoSideEffect,
PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
PredOpTrait<"operands and result have same element type",
TCresVTEtIsSameAsOp<0, 1>>]> {
@ -2545,9 +2598,11 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
let arguments = (ins
TFL_BoolTensor:$condition,
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x,
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y);
let results = (outs AnyTensor:$output);
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$x,
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$y);
let results = (outs
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$output);
// TODO(jpienaar): autogenerate this.
let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, "
@ -2561,7 +2616,12 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
let hasOptions = 1;
}
def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> {
def TFL_SelectV2Op : TFL_Op<"select_v2", [
NoSideEffect,
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<1, 2, 4>,
PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
PredOpTrait<"operands and result have same element type",
TCresVTEtIsSameAsOp<0, 1>>]> {
let summary = "SelectV2 operator";
let description = [{
@ -2574,9 +2634,11 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> {
let arguments = (ins
TFL_BoolTensor:$condition,
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x,
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y);
let results = (outs AnyTensor:$output);
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$x,
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$y);
let results = (outs
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$output);
let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, "
"Value cond, Value x, Value y",
@ -2605,9 +2667,11 @@ def TFL_SinOp: TFL_Op<"sin", [
let hasFolder = 1;
}
// TODO(b/130643170): Adds some constraint for the input/output element types.
def TFL_SoftmaxOp : TFL_Op<"softmax", [
NoSideEffect,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_OperandHasRankRange<0, 1, 4>,
SameOperandsAndResultShape,
// zero_point = 0
// scale = 1. / (max_value + 1)
@ -2623,11 +2687,11 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
}];
let arguments = (
ins AnyTensor:$input,
ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input,
F32Attr:$beta
);
let results = (outs AnyTensor:$output);
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output);
let hasOptions = 1;
}
@ -2914,6 +2978,7 @@ def TFL_BatchToSpaceNdOp: TFL_Op<"batch_to_space_nd", [
def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [
NoSideEffect,
SameOperandsAndResultsScale,
TFL_OperandHasRankRange<0, 3, 4>,
PredOpTrait<"input and output must have same element type",
TCresVTEtIsSameAsOp<0, 0>>
]> {
@ -2924,13 +2989,13 @@ def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [
}];
let arguments = (ins
TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input,
TFL_TensorOf<[I32]>:$block_shape,
TFL_TensorOf<[I32]>:$paddings
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
TFL_I32Tensor:$block_shape,
TFL_I32Tensor:$paddings
);
let results = (outs
TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$output
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output
);
}
@ -3045,7 +3110,12 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale]
}
def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [
NoSideEffect, SameOperandsAndResultsScale]> {
NoSideEffect,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_OperandHasRank<0, 4>,
TFL_OperandHasRank<1, 1>,
SameOperandsAndResultsScale]> {
let summary = "ResizeBilinear Op";
let description = [{
@ -3053,23 +3123,26 @@ def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [
}];
let arguments = (ins
// TODO(ycling): Support quantized types.
TFL_TensorOf<[F32, I32, QI8, QUI8]>:$input,
TFL_TensorOf<[I32]>:$size,
TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input,
TFL_I32Tensor:$size,
BoolAttr:$align_corners,
DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
);
let results = (outs
TFL_TensorOf<[F32, QI8, QUI8]>:$output
TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output
);
let hasOptions = 1;
}
def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor",
[NoSideEffect,
SameOperandsAndResultsScale]> {
def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", [
NoSideEffect,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_OperandHasRank<0, 4>,
TFL_OperandHasRank<1, 1>,
SameOperandsAndResultsScale]> {
let summary = "ResizeNearestNeighbor Op";
let description = [{
@ -3077,14 +3150,14 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor",
}];
let arguments = (ins
TFL_TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$input,
TFL_TensorOf<[I32]>:$size,
TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$input,
TFL_I32Tensor:$size,
BoolAttr:$align_corners,
DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
);
let results = (outs
TFL_TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$output
TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$output
);
let hasOptions = 1;
@ -3349,7 +3422,9 @@ def TFL_SparseQConstOp : Op<TFL_Dialect, "pseudo_sparse_qconst", [
}
def TFL_QuantizeOp: TFL_Op<"quantize", [
FirstAttrDerivedResultType, NoQuantizableResult]> {
FirstAttrDerivedResultType,
SameOperandsAndResultShape,
NoQuantizableResult]> {
let summary = "Quantize operator";
let description = [{
@ -3358,11 +3433,11 @@ def TFL_QuantizeOp: TFL_Op<"quantize", [
}];
let arguments = (
ins AnyTensor:$input,
ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$input,
TensorTypeAttr:$qtype
);
let results = (outs AnyTensor:$output);
let results = (outs TFL_TensorOf<[QI8, QUI8, QI16, TFL_Quint8]>:$output);
}
def TFL_DensifyOp: TFL_Op<"densify", [
@ -3472,6 +3547,19 @@ def TFL_LSTMOp :
LstmOptionalPeepholeWeightConstraint,
LstmProjectionWeightBiasConstraint,
LstmResultConstraint,
TFL_OperandHasRank<2, 2>, // input_to_forget_weights
TFL_OperandHasRank<3, 2>, // input_to_cell_weights
TFL_OperandIsNoneOrHasRank<5, 2>, // recurrent_to_input_weights
TFL_OperandHasRank<6, 2>, // recurrent_to_forget_weights
TFL_OperandHasRank<7, 2>, // recurrent_to_cell_weights
TFL_OperandIsNoneOrHasRank<9, 1>, // cell_to_input_weights
TFL_OperandIsNoneOrHasRank<10, 1>, // cell_to_forget_weights
TFL_OperandIsNoneOrHasRank<11, 1>, // cell_to_output_weights
TFL_OperandHasRank<13, 1>, // forget_gate_bias
TFL_OperandHasRank<14, 1>, // cell_gate_bias
TFL_OperandHasRank<15, 1>, // output_gate_bias
TFL_OperandIsNoneOrHasRank<16, 2>, // projection_weights
TFL_OperandIsNoneOrHasRank<17, 1>, // projection_bias
TFL_StatefulOp]> {
let summary = "The full lstm operator";
@ -3498,23 +3586,23 @@ Ba et al. 'Layer Normalization'
ins TFL_TensorOf<[F32, QI8]>:$input,
// Weights
TFL_TensorOfOrNone<[F32, I8, QI8]>:$input_to_input_weights,
TFL_TensorOf<[F32, I8, QI8]>:$input_to_forget_weights,
TFL_TensorOf<[F32, I8, QI8]>:$input_to_cell_weights,
TFL_TensorOf<[F32, I8, QI8]>:$input_to_output_weights,
TFL_TensorOfOrNone<[F32, QI8]>:$input_to_input_weights,
TFL_TensorOf<[F32, QI8]>:$input_to_forget_weights,
TFL_TensorOf<[F32, QI8]>:$input_to_cell_weights,
TFL_TensorOf<[F32, QI8]>:$input_to_output_weights,
// Recurrent weights
TFL_TensorOfOrNone<[F32, I8, QI8]>:$recurrent_to_input_weights,
TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_forget_weights,
TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_cell_weights,
TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_output_weights,
TFL_TensorOfOrNone<[F32, QI8]>:$recurrent_to_input_weights,
TFL_TensorOf<[F32, QI8]>:$recurrent_to_forget_weights,
TFL_TensorOf<[F32, QI8]>:$recurrent_to_cell_weights,
TFL_TensorOf<[F32, QI8]>:$recurrent_to_output_weights,
// Cell weights
TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_input_weights,
TFL_TensorOfOrNone<[F32, QI8, QI16]>:$cell_to_input_weights,
// Optional input
TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_forget_weights,
TFL_TensorOfOrNone<[F32, QI8, QI16]>:$cell_to_forget_weights,
// Optional input
TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_output_weights,
TFL_TensorOfOrNone<[F32, QI8, QI16]>:$cell_to_output_weights,
// Bias
TFL_TensorOfOrNone<[F32, QI32]>:$input_gate_bias,
@ -3523,7 +3611,7 @@ Ba et al. 'Layer Normalization'
TFL_TensorOf<[F32, QI32]>:$output_gate_bias,
// Projection weight and bias
TFL_TensorOfOrNone<[F32, I8, QI8]>:$projection_weights,
TFL_TensorOfOrNone<[F32, QI8]>:$projection_weights,
// Optional input
TFL_TensorOfOrNone<[F32, QI32]>:$projection_bias,
@ -3539,8 +3627,8 @@ Ba et al. 'Layer Normalization'
// Attributes
TFL_AFAttr:$fused_activation_function,
DefaultValuedAttr<F32Attr, "0.0f">:$cell_clip,
DefaultValuedAttr<F32Attr, "0.0f">:$proj_clip,
Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$cell_clip,
Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$proj_clip,
// Since this op is the FULL kernel only, constrain it.
Confined<
DefaultValuedAttr<TFL_LSTMKernelTypeAttr, "FULL">,
@ -3580,6 +3668,24 @@ def TFL_UnidirectionalSequenceLSTMOp :
LstmOptionalPeepholeWeightConstraint,
LstmProjectionWeightBiasConstraint,
LstmResultConstraint,
TFL_OperandHasRankAtLeast<0, 2>, // input
TFL_OperandIsNoneOrHasRank<1, 2>, // input_to_input_weights
TFL_OperandHasRank<2, 2>, // input_to_forget_weights
TFL_OperandHasRank<3, 2>, // input_to_cell_weights
TFL_OperandHasRank<4, 2>, // input_to_output_weights
TFL_OperandIsNoneOrHasRank<5, 2>, // recurrent_to_input_weights
TFL_OperandHasRank<6, 2>, // recurrent_to_forget_weights
TFL_OperandHasRank<7, 2>, // recurrent_to_cell_weights
TFL_OperandHasRank<8, 2>, // recurrent_to_output_weights
TFL_OperandIsNoneOrHasRank<9, 1>, // cell_to_input_weights
TFL_OperandIsNoneOrHasRank<10, 1>, // cell_to_forget_weights
TFL_OperandIsNoneOrHasRank<11, 1>, // cell_to_output_weights
TFL_OperandIsNoneOrHasRank<12, 1>, // input_gate_bias
TFL_OperandHasRank<13, 1>, // forget_gate_bias
TFL_OperandHasRank<14, 1>, // cell_gate_bias
TFL_OperandHasRank<15, 1>, // output_gate_bias
TFL_OperandIsNoneOrHasRank<16, 2>, // projection_weights
TFL_OperandIsNoneOrHasRank<17, 2>, // projection_bias
TFL_StatefulOp]> {
let summary = "Unidirectional sequence lstm operator";
@ -3595,35 +3701,35 @@ def TFL_UnidirectionalSequenceLSTMOp :
}];
let arguments = (
ins TFL_TensorOf<[F32, I8]>:$input,
ins TFL_FpTensor:$input,
// Weights
TFL_TensorOfOrNone<[F32, I8]>:$input_to_input_weights,
TFL_TensorOf<[F32, I8]>:$input_to_forget_weights,
TFL_TensorOf<[F32, I8]>:$input_to_cell_weights,
TFL_TensorOf<[F32, I8]>:$input_to_output_weights,
TFL_TensorOfOrNone<[F32, QI8]>:$input_to_input_weights,
TFL_TensorOf<[F32, QI8]>:$input_to_forget_weights,
TFL_TensorOf<[F32, QI8]>:$input_to_cell_weights,
TFL_TensorOf<[F32, QI8]>:$input_to_output_weights,
// Recurrent weights
TFL_TensorOfOrNone<[F32, I8]>:$recurrent_to_input_weights,
TFL_TensorOf<[F32, I8]>:$recurrent_to_forget_weights,
TFL_TensorOf<[F32, I8]>:$recurrent_to_cell_weights,
TFL_TensorOf<[F32, I8]>:$recurrent_to_output_weights,
TFL_TensorOfOrNone<[F32, QI8]>:$recurrent_to_input_weights,
TFL_TensorOf<[F32, QI8]>:$recurrent_to_forget_weights,
TFL_TensorOf<[F32, QI8]>:$recurrent_to_cell_weights,
TFL_TensorOf<[F32, QI8]>:$recurrent_to_output_weights,
// Cell weights
TFL_TensorOfOrNone<[F32, I8]>:$cell_to_input_weights,
TFL_TensorOfOrNone<[F32, QI8]>:$cell_to_input_weights,
// Optional input
TFL_TensorOfOrNone<[F32, I8]>:$cell_to_forget_weights,
TFL_TensorOfOrNone<[F32, QI8]>:$cell_to_forget_weights,
// Optional input
TFL_TensorOfOrNone<[F32, I8]>:$cell_to_output_weights,
TFL_TensorOfOrNone<[F32, QI8]>:$cell_to_output_weights,
// Bias
TFL_TensorOfOrNone<[F32]>:$input_gate_bias,
TFL_TensorOf<[F32]>:$forget_gate_bias,
TFL_TensorOf<[F32]>:$cell_bias,
TFL_TensorOf<[F32]>:$output_gate_bias,
TFL_FpTensor:$forget_gate_bias,
TFL_FpTensor:$cell_bias,
TFL_FpTensor:$output_gate_bias,
// Projection weight and bias
TFL_TensorOfOrNone<[F32, I8]>:$projection_weights,
TFL_TensorOfOrNone<[F32, QI8]>:$projection_weights,
// Optional input
TFL_TensorOfOrNone<[F32]>:$projection_bias,
@ -3632,19 +3738,19 @@ def TFL_UnidirectionalSequenceLSTMOp :
TFL_StatefulTensor:$input_cell_state,
// Layer norm coefficients
TFL_TensorOfOrNone<[F32, I8]>:$input_layer_norm_coefficients,
TFL_TensorOfOrNone<[F32, I8]>:$forget_layer_norm_coefficients,
TFL_TensorOfOrNone<[F32, I8]>:$cell_layer_norm_coefficients,
TFL_TensorOfOrNone<[F32, I8]>:$output_layer_norm_coefficients,
TFL_TensorOfOrNone<[F32, QI8]>:$input_layer_norm_coefficients,
TFL_TensorOfOrNone<[F32, QI8]>:$forget_layer_norm_coefficients,
TFL_TensorOfOrNone<[F32, QI8]>:$cell_layer_norm_coefficients,
TFL_TensorOfOrNone<[F32, QI8]>:$output_layer_norm_coefficients,
// Attributes
TFL_AFAttr:$fused_activation_function,
DefaultValuedAttr<F32Attr, "0.0f">:$cell_clip,
DefaultValuedAttr<F32Attr, "0.0f">:$proj_clip,
Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$cell_clip,
Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$proj_clip,
BoolAttr:$time_major
);
let results = (outs AnyTensor:$output);
let results = (outs TFL_TensorOf<[F32, QI8]>:$output);
let hasOptions = 1;
@ -3841,15 +3947,14 @@ def TFL_BidirectionalSequenceLSTMOp :
}];
}
def RnnResultConstraint : PredOpTrait<
"the input and result tensor elemental types must be same",
TCresVTEtIsSameAsOp<0, 0>>;
// UnidirectionalSequenceRNN op.
def TFL_UnidirectionalSequenceRNNOp :
TFL_Op<"unidirectional_sequence_rnn",
[RnnResultConstraint, TFL_StatefulOp]> {
def TFL_UnidirectionalSequenceRNNOp : TFL_Op<"unidirectional_sequence_rnn", [
TFL_OperandHasRank<4, 2>,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
PredOpTrait<"input and constant value operands must have same element type",
TFL_TCopVTEtAreSameAt<1, 2>>,
TFL_StatefulOp]> {
let summary = "Unidirectional sequence rnn operator";
let description = [{
@ -3866,16 +3971,16 @@ def TFL_UnidirectionalSequenceRNNOp :
}];
let arguments = (
ins TFL_TensorOf<[F32, I8]>:$input,
ins TFL_FpTensor:$input,
// Weights
TFL_TensorOf<[F32, I8]>:$input_to_input_weights,
TFL_TensorOf<[F32, QI8]>:$input_to_input_weights,
// Recurrent weights
TFL_TensorOf<[F32, I8]>:$recurrent_to_input_weights,
TFL_TensorOf<[F32, QI8]>:$recurrent_to_input_weights,
// Bias
TFL_TensorOf<[F32]>:$input_gate_bias,
TFL_FpTensor:$input_gate_bias,
// Hidden state.
TFL_StatefulTensor:$hidden_state,
@ -3885,7 +3990,7 @@ def TFL_UnidirectionalSequenceRNNOp :
TFL_AFAttr:$fused_activation_function
);
let results = (outs TFL_TensorOf<[F32, I8]>:$output);
let results = (outs TFL_FpTensor:$output);
let hasOptions = 1;
@ -3941,14 +4046,12 @@ def TFL_NumericVerifyOp : Op<TFL_Dialect, "NumericVerify", [
let results = (outs);
}
def SVDFResultConstraint: PredOpTrait<
"the input and result tensor elemental types must be same",
TCresVTEtIsSameAsOp<0, 0>>;
// SVDF op.
def TFL_SVDFOp :
TFL_Op<"svdf",
[SVDFResultConstraint, TFL_StatefulOp]> {
TFL_Op<"svdf", [
PredOpTrait<"the input and result tensor elemental types must be same",
TCresVTEtIsSameAsOp<0, 0>>,
TFL_StatefulOp]> {
let summary = "Single value decomposition filter operator";
@ -3960,13 +4063,13 @@ def TFL_SVDFOp :
}];
let arguments = (
ins TFL_TensorOf<[F32, I8]>:$input,
ins TFL_TensorOf<[F32, QI8]>:$input,
// Feature Weights.
TFL_TensorOf<[F32, I8]>:$feature_weights,
TFL_TensorOf<[F32, QI8, QUI8]>:$feature_weights,
// Time weights
TFL_TensorOf<[F32, I8]>:$time_weights,
TFL_TensorOf<[F32, QI8]>:$time_weights,
// Bias
TFL_TensorOfOrNone<[F32]>:$input_gate_bias,
@ -3975,11 +4078,11 @@ def TFL_SVDFOp :
TFL_StatefulTensor:$activation_state,
// Attributes
I32Attr:$rank,
Confined<I32Attr, [IntPositive]>:$rank,
TFL_AFAttr:$fused_activation_function
);
let results = (outs TFL_TensorOf<[F32, I8]>:$output);
let results = (outs TFL_TensorOf<[F32, QI8]>:$output);
let hasOptions = 1;
@ -3991,7 +4094,10 @@ def TFL_SVDFOp :
}];
}
def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> {
def TFL_SegmentSumOp: TFL_Op<"segment_sum", [
NoSideEffect,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
let summary = "SegmentSum operator";
let description = [{
@ -3999,7 +4105,7 @@ def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> {
}];
let arguments = (ins
TFL_TensorOf<[F32, I32]>:$data,
TFL_TensorOf<[F32, I32]>:$input,
TFL_I32Tensor:$segment_ids
);
let results = (outs TFL_TensorOf<[F32, I32]>:$output);
@ -4030,7 +4136,7 @@ def TFL_WhileOp : Op<TFL_Dialect, "while", [
input: A list of input tensors whose types are T.
output: A list of output tensors whose types are T.
cond: A region takes 'input' and returns a boolean scalar tensor.
cond: A region that takes 'input' and returns a boolean scalar tensor.
body: A region that takes a list of tensors and returns another
list of tensors. Both lists have the same types.
}];

View File

@ -121,6 +121,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
return DT_STRING;
case toco::IODataType::BOOL:
return DT_BOOL;
case toco::IODataType::COMPLEX64:
return DT_COMPLEX64;
default:
return DT_INVALID;
}
@ -252,7 +254,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
std::string error_message;
auto output = mlir::openOutputFile(filename, &error_message);
if (!error_message.empty()) {
return errors::InvalidArgument("Failed to open file in %s.", filename);
return errors::InvalidArgument("Failed to open file in ", filename);
}
mlir::PassManager pm(module.getContext());
pm.addPass(mlir::createPrintOpGraphPass(output->os()));

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
@ -35,6 +36,7 @@ limitations under the License.
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
namespace mlir {
@ -363,6 +365,54 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
}
};
// Fold Extra Requantize ops if the preceding ops has free scale requirement.
template <typename RQ>
struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
explicit FoldTrivalRequantizeOp(MLIRContext* context)
: OpRewritePattern<RQ>(context, 1) {}
LogicalResult matchAndRewrite(RQ op,
PatternRewriter& rewriter) const override {
Value pre_quantized = op.input();
auto pre_quantized_type =
quant::QuantizedType::getQuantizedElementType(pre_quantized.getType());
if (!pre_quantized_type) return failure();
Operation* def = pre_quantized.getDefiningOp();
if (!def) return failure();
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
return failure();
}
op.emitWarning("Remove trivial `rescale` op. Please fix the source graph.");
llvm::SmallVector<Type, 4> new_output_types;
for (auto result : def->getResults()) {
result.getUsers().begin()->dump();
op.dump();
if (result.hasOneUse() && *result.getUsers().begin() == op) {
new_output_types.push_back(op.qtype());
} else {
new_output_types.push_back(result.getType());
}
}
// Remove this rescale op.
rewriter.replaceOp(op, {pre_quantized});
// Replace the output scale of the preceding op.
rewriter.setInsertionPointAfter(def);
OperationState new_state(def->getLoc(), def->getName().getStringRef(),
def->getOperands(), new_output_types,
def->getAttrs());
Operation* new_op = rewriter.createOperation(new_state);
rewriter.replaceOp(def, new_op->getResults());
return success();
}
};
// Given a quantized type `input`, magnifying its scales by the factor stored in
// `factor`. If `input` isn't a quantized type or the `factor` doesn't match the
// dimension size of `input` or isn't floating-point, nullptr will be returned.

View File

@ -11,9 +11,9 @@ func @reshape_removeAdjacent(tensor<4x4x4xf32>) -> tensor<64xf32> {
return %1 : tensor<64xf32>
// CHECK-LABEL: func @reshape_removeAdjacent
// CHECK: %cst = constant dense<64> : tensor<1xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return
// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32>
// CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return %[[RESHAPE]]
}
// Checks that tfl.reshape should be removed if its output has more than one
@ -29,11 +29,11 @@ func @reshape_removeAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> tensor<64xf32>
return %3 : tensor<64xf32>
// CHECK-LABEL: func @reshape_removeAdjacentWithMultipleUse
// CHECK: %cst = constant dense<64> : tensor<1xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %1 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %2 = addf %0, %1
// CHECK: return %2
// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32>
// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %[[RESULT:.*]] = addf %[[RESHAPE_1]], %[[RESHAPE_2]]
// CHECK: return %[[RESULT]]
}
// Checks that tfl.reshape should be kept if its output has more than one
@ -47,11 +47,11 @@ func @reshape_keepAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> (tensor<16x4xf32
return %0, %1 : tensor<16x4xf32>, tensor<64xf32>
// CHECK-LABEL: func @reshape_keepAdjacentWithMultipleUse
// CHECK: %cst = constant dense<[16, 4]> : tensor<2xi32>
// CHECK: %cst_0 = constant dense<64> : tensor<1xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32>
// CHECK: %1 = "tfl.reshape"(%arg0, %cst_0) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return %0, %1
// CHECK: %[[CST:.*]] = constant dense<[16, 4]> : tensor<2xi32>
// CHECK: %[[CST_0:.*]] = constant dense<64> : tensor<1xi32>
// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32>
// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST_0]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return %[[RESHAPE_1]], %[[RESHAPE_2]]
}
// Checks that tfl.reshape should be removed if its output type is the same

View File

@ -8,13 +8,13 @@ func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>,
%2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<3.500000e+00> : tensor<4xf32>
// CHECK: %cst_0 = constant dense<-5.000000e-01> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<6.000000e+00> : tensor<f32>
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
// CHECK: %cst_3 = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %cst_4 = constant dense<3.000000e+00> : tensor<4xf32>
// CHECK: %0 = tfl.add %cst, %cst_0 {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
// CHECK: %[[CST:.*]] = constant dense<3.500000e+00> : tensor<4xf32>
// CHECK: %[[CST_0:.*]] = constant dense<-5.000000e-01> : tensor<4xf32>
// CHECK: %[[CST_1:.*]] = constant dense<6.000000e+00> : tensor<f32>
// CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_3:.*]] = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_4:.*]] = constant dense<3.000000e+00> : tensor<4xf32>
// CHECK: %0 = tfl.add %[[CST]], %[[CST_0]] {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@ -33,10 +33,10 @@ func @add_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
%2 = constant dense< 4> : tensor<4xi32>
%3 = constant dense<-2> : tensor<4xi32>
// CHECK: %cst = constant dense<9> : tensor<i32>
// CHECK: %cst_0 = constant dense<6> : tensor<4xi32>
// CHECK: %cst_1 = constant dense<5> : tensor<4xi32>
// CHECK: %cst_2 = constant dense<2> : tensor<4xi32>
// CHECK: %[[CST:.*]] = constant dense<9> : tensor<i32>
// CHECK: %[[CST_0:.*]] = constant dense<6> : tensor<4xi32>
// CHECK: %[[CST_1:.*]] = constant dense<5> : tensor<4xi32>
// CHECK: %[[CST_2:.*]] = constant dense<2> : tensor<4xi32>
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
@ -54,10 +54,10 @@ func @sub_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
%2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<3.000000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<2.000000e+00> : tensor<4xf32>
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
// CHECK: %[[CST:.*]] = constant dense<3.000000e+00> : tensor<f32>
// CHECK: %[[CST_0:.*]] = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_1:.*]] = constant dense<2.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32>
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@ -75,10 +75,10 @@ func @sub_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
%2 = constant dense< 4> : tensor<4xi32>
%3 = constant dense<-2> : tensor<4xi32>
// CHECK: %cst = constant dense<7> : tensor<i32>
// CHECK: %cst_0 = constant dense<10> : tensor<4xi32>
// CHECK: %cst_1 = constant dense<3> : tensor<4xi32>
// CHECK: %cst_2 = constant dense<6> : tensor<4xi32>
// CHECK: %[[CST:.*]] = constant dense<7> : tensor<i32>
// CHECK: %[[CST_0:.*]] = constant dense<10> : tensor<4xi32>
// CHECK: %[[CST_1:.*]] = constant dense<3> : tensor<4xi32>
// CHECK: %[[CST_2:.*]] = constant dense<6> : tensor<4xi32>
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
@ -96,10 +96,10 @@ func @mul_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
%2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<6.750000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<-2.250000e+00> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<5.250000e+00> : tensor<4xf32>
// CHECK: %cst_2 = constant dense<-1.750000e+00> : tensor<4xf32>
// CHECK: %[[CST:.*]] = constant dense<6.750000e+00> : tensor<f32>
// CHECK: %[[CST_0:.*]] = constant dense<-2.250000e+00> : tensor<4xf32>
// CHECK: %[[CST_1:.*]] = constant dense<5.250000e+00> : tensor<4xf32>
// CHECK: %[[CST_2:.*]] = constant dense<-1.750000e+00> : tensor<4xf32>
%5 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@ -170,8 +170,8 @@ func @add_dense_splat_int() -> tensor<4xi32> {
return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_splat_dense_int
@ -183,8 +183,8 @@ func @add_splat_dense_int() -> tensor<4xi32> {
return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_int_same_shape
@ -196,8 +196,8 @@ func @add_dense_dense_int_same_shape() -> tensor<4xi32> {
return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[5, 22, -2, 98]> : tensor<4xi32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[5, 22, -2, 98]> : tensor<4xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_int_trailing_dim
@ -212,10 +212,10 @@ func @add_dense_dense_int_trailing_dim() -> (tensor<2x2xi32>, tensor<2x2x2xi32>,
return %0, %1, %2 : tensor<2x2xi32>, tensor<2x2x2xi32>, tensor<2x2x2xi32>
// CHECK: %cst = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32>
// CHECK: %cst_0 = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32>
// CHECK: %cst_1 = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32>
// CHECK: return %cst, %cst_0, %cst_1
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32>
// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32>
// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32>
// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]]
}
// CHECK-LABEL: @add_dense_dense_int_mixing_1_n
@ -226,8 +226,8 @@ func @add_dense_dense_int_mixing_1_n() -> tensor<2x2xi32> {
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
// CHECK: %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_splat_float
@ -239,8 +239,8 @@ func @add_dense_splat_float() -> tensor<4xf32> {
return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_splat_dense_float
@ -252,8 +252,8 @@ func @add_splat_dense_float() -> tensor<4xf32> {
return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_float_same_shape
@ -265,8 +265,8 @@ func @add_dense_dense_float_same_shape() -> (tensor<4xf32>) {
return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_float_trailing_dim
@ -281,10 +281,10 @@ func @add_dense_dense_float_trailing_dim() -> (tensor<2x2xf32>, tensor<2x2x2xf32
return %0, %1, %2 : tensor<2x2xf32>, tensor<2x2x2xf32>, tensor<2x2x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32>
// CHECK: %cst_0 = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32>
// CHECK: %cst_1 = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32>
// CHECK: return %cst, %cst_0, %cst_1
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32>
// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32>
// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32>
// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]]
}
// CHECK-LABEL: @add_dense_dense_float_mixfng_1_n
@ -296,24 +296,24 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
return %0 : tensor<2x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @rank
func @rank() -> tensor<1xi32> {
%cst = constant dense<[[1], [2]]> : tensor<2x1xi32>
// CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return %[[CST]]
%0 = "tfl.rank"(%cst) : (tensor<2x1xi32>) -> tensor<1xi32>
return %0 : tensor<1xi32>
}
// CHECK-LABEL: @rank_input_known_rank
func @rank_input_known_rank(%arg0 : tensor<2x1xi32>) -> tensor<1xi32> {
// CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return %[[CST]]
%0 = "tfl.rank"(%arg0) : (tensor<2x1xi32>) -> tensor<1xi32>
return %0 : tensor<1xi32>
}
@ -323,8 +323,8 @@ func @reshape() -> tensor<4xi32> {
%input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
%shape = constant dense<[4]> : tensor<1xi32>
// CHECK: [[cst:%.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK: return %[[CST]]
%0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -334,8 +334,8 @@ func @reshape_dynamic_output() -> tensor<?xi32> {
%input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
%shape = constant dense<[4]> : tensor<1xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
%0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -343,8 +343,8 @@ func @reshape_dynamic_output() -> tensor<?xi32> {
// CHECK-LABEL: @pseudo_const
func @pseudo_const() -> tensor<i32> {
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<i32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<i32>
// CHECK: return %[[CST]]
%0 = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
return %0 : tensor<i32>
}
@ -356,8 +356,8 @@ func @range_int() -> tensor<?xi32> {
%cst_1 = constant dense<4> : tensor<i32>
%cst_2 = constant dense<1> : tensor<i32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -368,8 +368,8 @@ func @range_float() -> tensor<?xf32> {
%cst_1 = constant dense<4.0> : tensor<f32>
%cst_2 = constant dense<1.0> : tensor<f32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -381,8 +381,8 @@ func @range_float_neg_delta() -> tensor<?xf32> {
%cst_1 = constant dense<-4.0> : tensor<f32>
%cst_2 = constant dense<-1.0> : tensor<f32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -393,8 +393,8 @@ func @range_float_nonzero_base() -> tensor<?xf32> {
%cst_1 = constant dense<7.0> : tensor<f32>
%cst_2 = constant dense<1.5> : tensor<f32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -414,8 +414,8 @@ func @transpose_1d() -> tensor<3xi32> {
%cst = constant dense<[1, 2, 3]> : tensor<3xi32>
%cst_perm = constant dense<0> : tensor<1xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
@ -425,8 +425,8 @@ func @transpose_dynamic() -> tensor<?xi32> {
%cst = constant dense<[1, 2, 3]> : tensor<3xi32>
%cst_perm = constant dense<0> : tensor<1xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -436,8 +436,8 @@ func @transpose_2d() -> tensor<2x2xi32> {
%cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
%cst_perm = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -447,8 +447,8 @@ func @transpose_2d_identity() -> tensor<2x2xi32> {
%cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
%cst_perm = constant dense<[0, 1]> : tensor<2xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -460,8 +460,8 @@ func @transpose_3d() -> tensor<4x2x3xi32> {
%cst = constant dense<[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]> : tensor<2x3x4xi32>
%cst_perm = constant dense<[2, 0, 1]> : tensor<3xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x2x3xi32>
return %0 : tensor<4x2x3xi32>
}
@ -473,8 +473,8 @@ func @ConstantFoldBinaryOpDynamicOutput() -> tensor<?xi32> {
%87 = "tfl.sub"(%cst_0, %cst) {fused_activation_function = "NONE"} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
return %87 : tensor<?xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_int_same_shape_dynamic
@ -486,8 +486,8 @@ func @add_dense_dense_int_same_shape_dynamic() -> tensor<?xi32> {
return %2 : tensor<?xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @concat_2_tensors_1_empty
@ -497,8 +497,8 @@ func @concat_2_tensors_1_empty() -> tensor<2xi32> {
%3 = "tfl.concatenation"(%1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<0xi32>) -> tensor<2xi32>
return %3 : tensor<2xi32>
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<2xi32>
// CHECK: return [[cst]] : tensor<2xi32>
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<2xi32>
// CHECK: return %[[CST]] : tensor<2xi32>
}
// CHECK-LABEL: @concat_3_tensors_1_empty
@ -509,7 +509,7 @@ func @concat_3_tensors_1_empty() -> tensor<?xi32> {
%3 = "tfl.concatenation"(%0, %1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>, tensor<0xi32>) -> tensor<?xi32>
return %3 : tensor<?xi32>
// CHECK: %0 = "tfl.concatenation"(%cst, %cst) {axis = 0 : i32, fused_activation_function = "NONE"}
// CHECK: %0 = "tfl.concatenation"(%[[CST]], %[[CST]]) {axis = 0 : i32, fused_activation_function = "NONE"}
// CHECK: return %0 : tensor<?xi32>
}
@ -520,10 +520,10 @@ func @concatConstantTensorsFirstDim() -> tensor<2x2x3xi32> {
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<2x2x3xi32>
return %0 : tensor<2x2x3xi32>
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32>
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32>
// CHECK-NOT: constant-dense
// CHECK-NOT: "tfl.concatenation"
// CHECK: return [[cst]]
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @concatConstantTensorsMiddleDim
@ -533,10 +533,10 @@ func @concatConstantTensorsMiddleDim() -> tensor<1x4x3xi32> {
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x4x3xi32>
return %0 : tensor<1x4x3xi32>
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32>
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32>
// CHECK-NOT: constant-dense
// CHECK-NOT: "tfl.concatenation"
// CHECK: return [[cst]]
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @concatConstantTensorsLastDim
@ -546,10 +546,10 @@ func @concatConstantTensorsLastDim() -> tensor<1x2x6xi32> {
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 2 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x2x6xi32>
return %0 : tensor<1x2x6xi32>
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32>
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32>
// CHECK-NOT: constant-dense
// CHECK-NOT: "tfl.concatenation"
// CHECK: return [[cst]]
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @div_dense_dense_float_mixfng_1_n
@ -561,8 +561,8 @@ func @div_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
return %0 : tensor<2x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @div_dense_different_rank
@ -574,6 +574,6 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> {
return %0 : tensor<1x2x2xf32>
// CHECK: %cst = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
// CHECK: return %[[CST]]
}

View File

@ -0,0 +1,101 @@
# RUN: tf_tfl_translate -tf-input-arrays=Placeholder,Placeholder_1 -tf-input-shapes=2,5,3:3,7 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-output-arrays=MatMul -output-mlir %s -o - 2>&1 | FileCheck %s
node {
name: "Placeholder"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
dim {
size: 5
}
dim {
size: 3
}
}
}
}
}
node {
name: "Placeholder_1"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 3
}
dim {
size: 7
}
}
}
}
}
node {
name: "MatMul"
op: "BatchMatMulV2"
input: "Placeholder"
input: "Placeholder_1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "adj_x"
value {
b: false
}
}
attr {
key: "adj_y"
value {
b: false
}
}
}
versions {
producer: 175
}
# CHECK: func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} {
# CHECK: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
# CHECK: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
# CHECK: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
# CHECK: %[[VAL_5:.*]] = constant unit
# CHECK: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
# CHECK: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
# CHECK: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
# CHECK: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
# CHECK: %[[VAL_10:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_8]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
# CHECK: %[[VAL_11:.*]] = "tfl.reshape"(%[[VAL_10]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
# CHECK: %[[VAL_12:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_6]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
# CHECK: %[[VAL_13:.*]] = "tfl.reshape"(%[[VAL_12]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
# CHECK: %[[VAL_14:.*]] = "tfl.reshape"(%[[VAL_1]], %[[VAL_9]]) : (tensor<3x7xf32>, tensor<3xi32>) -> tensor<1x3x7xf32>
# CHECK: %[[VAL_15:.*]] = "tfl.slice"(%[[VAL_14]], %[[VAL_8]], %[[VAL_9]]) : (tensor<1x3x7xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x3x7xf32>
# CHECK: %[[VAL_16:.*]] = "tfl.reshape"(%[[VAL_15]], %[[VAL_4]]) : (tensor<1x3x7xf32>, tensor<2xi32>) -> tensor<3x7xf32>
# CHECK: %[[VAL_17:.*]] = "tfl.transpose"(%[[VAL_16]], %[[VAL_2]]) : (tensor<3x7xf32>, tensor<2xi32>) -> tensor<7x3xf32>
# CHECK: %[[VAL_18:.*]] = "tfl.fully_connected"(%[[VAL_11]], %[[VAL_17]], %[[VAL_5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32>
# CHECK: %[[VAL_19:.*]] = "tfl.fully_connected"(%[[VAL_13]], %[[VAL_17]], %[[VAL_5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32>
# CHECK: %[[VAL_20:.*]] = "tfl.pack"(%[[VAL_18]], %[[VAL_19]]) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<2x5x7xf32>
# CHECK: return %[[VAL_20]] : tensor<2x5x7xf32>
# CHECK: }

View File

@ -1,15 +1,15 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// Ensure lstm roundtrip exactly
func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %24 : tensor<1x4xf32>
// CHECK-LABEL: main
// seperate lines since there is no region for this op. third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( {
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
// CHECK: return %[[RES0]]
}

View File

@ -0,0 +1,14 @@
// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s -split-input-file | FileCheck %s --dump-input-on-failure
module {
func @_whitespace_func(%arg0: tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._GrapplerSpecializedFunc = true, tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
%0 = "tf.op1"(%arg0) : (tensor<1x!tf.string>) -> (tensor<?x!tf.string>)
%1 = "tf.Const"() {value = dense<-1> : tensor<i64>} : () -> tensor<?xi64>
%2:2 = "tf.op2"(%arg0, %1) : (tensor<1x!tf.string>, tensor<?xi64>) -> (tensor<?x!tf.string>, tensor<?xi64>)
return %2#0, %2#1 : tensor<?x!tf.string>, tensor<?xi64>
}
// CHECK: func @_whitespace_func(%arg0: tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._GrapplerSpecializedFunc = true, tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
// CHECK: "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>)
// CHECK: return %0#0, %0#1 : tensor<?x!tf.string>, tensor<?xi64>
}

View File

@ -1048,6 +1048,15 @@ func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}
func @concatv2I64Axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i64> } : () -> tensor<i64>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i64>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
// CHECK-LABEL: concatv2I64Axis
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}
func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
%0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>

View File

@ -65,7 +65,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 2, 1 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: custom_options: [ 105, 110, 116, 95, 97, 116, 116, 114, 0, 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 2, 33, 43, 2, 1, 2, 11, 2, 20, 4, 4, 36, 1 ]
// CHECK-NEXT: custom_options: [ 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 105, 110, 116, 95, 97, 116, 116, 114, 0, 2, 42, 11, 2, 1, 2, 20, 2, 20, 4, 4, 36, 1 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 2,
// CHECK-NEXT: inputs: [ 3 ],

View File

@ -1,6 +1,6 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> {
func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
@ -72,21 +72,21 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 10,
// CHECK-NEXT: name: "arg9",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 11,
// CHECK-NEXT: name: "arg10",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 12,
// CHECK-NEXT: name: "arg11",
// CHECK-NEXT: quantization: {
@ -100,21 +100,21 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 14,
// CHECK-NEXT: name: "arg13",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 15,
// CHECK-NEXT: name: "arg14",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 16,
// CHECK-NEXT: name: "arg15",
// CHECK-NEXT: quantization: {
@ -128,7 +128,7 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 1, 4 ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 18,
// CHECK-NEXT: name: "arg17",
// CHECK-NEXT: quantization: {
@ -261,9 +261,9 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
// CHECK-EMPTY:
^bb0(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>):
^bb0(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>):
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %24 : tensor<1x4xf32>
}

View File

@ -1,6 +1,6 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> {
func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> {
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
@ -9,63 +9,63 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "arg2",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "arg3",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "arg4",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 6,
// CHECK-NEXT: name: "arg5",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 7,
// CHECK-NEXT: name: "arg6",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 8,
// CHECK-NEXT: name: "arg7",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 9,
// CHECK-NEXT: name: "arg8",
// CHECK-NEXT: quantization: {
@ -121,63 +121,63 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 17,
// CHECK-NEXT: name: "arg16",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 18,
// CHECK-NEXT: name: "arg17",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 19,
// CHECK-NEXT: name: "arg18",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 20,
// CHECK-NEXT: name: "arg19",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 21,
// CHECK-NEXT: name: "arg20",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 22,
// CHECK-NEXT: name: "arg21",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: name: "Const",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: },
// CHECK-NEXT: is_variable: true
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: name: "Const1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: },
// CHECK-NEXT: is_variable: true
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: buffer: 25,
// CHECK-NEXT: name: "tfl.unidirectional_sequence_lstm",
// CHECK-NEXT: quantization: {
@ -244,9 +244,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
@ -259,9 +259,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
// CHECK-NEXT: }
// CHECK-EMPTY:
^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>):
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%2 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %0, %1, %arg18, %arg19, %arg20, %arg21) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %2 : tensor<4xf32>
^bb0(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4x4xf32>, %arg18: tensor<4x4xf32>, %arg19: tensor<4x4xf32>, %arg20: tensor<4x4xf32>, %arg21: tensor<4x4xf32>):
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
%1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
%2 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %0, %1, %arg18, %arg19, %arg20, %arg21) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
return %2 : tensor<4x4xf32>
}

View File

@ -37,7 +37,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: shape: [ 4, 4 ],
// CHECK-NEXT: name: "Const",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
@ -76,7 +76,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
@ -90,7 +90,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -
// CHECK-EMPTY:
^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>):
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%1 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %0) {fused_activation_function = "TANH", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
%1 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %0) {fused_activation_function = "TANH", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>) -> tensor<4xf32>
return %1 : tensor<4xf32>
}

View File

@ -190,9 +190,9 @@ func @testSquare(tensor<? x f32>) -> tensor<? x f32> {
return %0 : tensor<? x f32>
}
func @testQuantizedResizeNearestNeighbor(tensor<? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> {
^bb0(%arg0: tensor<? x !quant.uniform<u8:f32, 0.1>>, %arg1: tensor<? x i32>):
%0 = "tfl.resize_nearest_neighbor"(%arg0, %arg1) { align_corners = false, half_pixel_centers = false } : (tensor<? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>>
func @testQuantizedResizeNearestNeighbor(tensor<? x ? x ? x ? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> {
^bb0(%arg0: tensor<? x ? x ? x ? x !quant.uniform<u8:f32, 0.1>>, %arg1: tensor<? x i32>):
%0 = "tfl.resize_nearest_neighbor"(%arg0, %arg1) { align_corners = false, half_pixel_centers = false } : (tensor<? x ? x ? x ? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>>
return %0 : tensor<? x !quant.uniform<u8:f32, 0.1>>
}
@ -581,36 +581,36 @@ func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
// -----
// CHECK-LABEL: testUnidirectionalSequenceRnn
func @testUnidirectionalSequenceRnn(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
func @testUnidirectionalSequenceRnn(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x ? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: testUnidirectionalSequenceLstmWithoutProjection
func @testUnidirectionalSequenceLstmWithoutProjection(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: none, %arg17: none, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
func @testUnidirectionalSequenceLstmWithoutProjection(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: none, %arg17: none, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: testUnidirectionalSequenceLstm
func @testUnidirectionalSequenceLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
func @testUnidirectionalSequenceLstm(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x ? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr
func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x ? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x ? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -663,10 +663,10 @@ func @testLstmQuantizedType(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.0372480
// -----
// CHECK-LABEL: testLstm
func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
// CHECK-NEXT: {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -689,10 +689,10 @@ func @testQuantizedBasicLstm(%arg0: tensor<1x384x!quant.uniform<u8:f32, 7.812500
// -----
// CHECK-LABEL: testLstmWithNoneTypeAndOverrideAttr
func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
// CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -707,11 +707,11 @@ func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>
// -----
// test invalid input dimension, the first input operand for lstm op should be at least 2D tensor.
// test invalid input dimension, the third input operand for lstm op should be 2-D tensor.
func @testLstmWithInvalidInputDimension(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>) -> tensor<4 x f32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
// expected-error @+1 {{'tfl.lstm' op the first input operand should have more than 2 dimensions.}}
// expected-error @+1 {{'tfl.lstm' op failed to verify that operand 2 is 2-D}}
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %24 : tensor<4xf32>
@ -720,22 +720,22 @@ func @testLstmWithInvalidInputDimension(%arg0: tensor<4 x f32>, %arg1: tensor<4
// -----
// 'input_to_output_weights' input for lstm op has unmatched rank with `input`.
func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x2xf32>, %arg2: tensor<4x2xf32>, %arg3: tensor<4x2xf32>, %arg4: tensor<4x2xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x2xf32>, %arg2: tensor<4x2xf32>, %arg3: tensor<4x2xf32>, %arg4: tensor<4x2xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
// expected-error @+1 {{'tfl.lstm' op inputs don't match with the dimensions.}}
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %24 : tensor<1x4xf32>
}
// -----
// Coefficient inputs of LSTM op don't match the dimension with input operand `input_to_output_weights`.
func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<3xf32>, %arg19: tensor<3xf32>, %arg20: tensor<3xf32>, %arg21: tensor<3xf32>) -> tensor<1x4xf32> {
func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<3xf32>, %arg19: tensor<3xf32>, %arg20: tensor<3xf32>, %arg21: tensor<3xf32>) -> tensor<1x4xf32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
// expected-error @+1 {{'tfl.lstm' op coefficient inputs have more than 2 dimensions or don't match the dimension with input operand `input_to_output_weights`.}}
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x4xf32>
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x4xf32>
return %24 : tensor<1x4xf32>
}
@ -1201,7 +1201,7 @@ func @testResizeBilinear(%arg0 : tensor<1x100x100x3xf32>, %arg1 : tensor<4xi32>)
// -----
func @testResizeBilinearInvalidOutputType(%arg0 : tensor<1x100x100x3xf32>, %arg1 : tensor<4xi32>) -> tensor<?xi32> {
// expected-error @+1 {{'tfl.resize_bilinear' op result #0 must be tensor of 32-bit float or QI8 type or QUI8 type values}}
// expected-error @+1 {{'tfl.resize_bilinear' op failed to verify that input and output must have same element type}}
%0 = "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -1499,8 +1499,8 @@ func @testWrongQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x192x!q
// CHECK-LABEL: testSvdf
func @testSvdf(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", rank = 2 : i32} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", rank = 2 : i32} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "RELU", rank = 2 : i32} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "RELU", rank = 2 : i32} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

View File

@ -439,6 +439,19 @@ func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<
// CHECK: return %[[rs2]]
}
// CHECK-LABEL: @NotReorderReshapeAddIfHighDim
// The reshape/add reordering pattern is guarded by HasRankAtMost<5>; with a
// rank-6 pre-reshape input the rewrite must not fire, so the reshape stays
// ahead of the add in the output.
func @NotReorderReshapeAddIfHighDim(%arg0: tensor<1x1x1x1x30x96xf32>) -> tensor<1x30x96xf32> {
  %cst = constant dense<2.0> : tensor<f32>
  %shape = constant dense<[1, 30, 96]> : tensor<3xi32>
  %1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x1x1x1x30x96xf32>, tensor<3xi32>) -> tensor<1x30x96xf32>
  %2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x30x96xf32>, tensor<f32>) -> tensor<1x30x96xf32>
  return %2 : tensor<1x30x96xf32>
  // CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
  // CHECK: %[[rs2:.*]] = "tfl.add"(%[[rs1]]
  // CHECK: return %[[rs2]]
}
// CHECK-LABEL: @ReorderElementwiseValueOpAndMoveOp
func @ReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
%shape = constant dense<[40, 40]> : tensor<2xi32>

View File

@ -19,6 +19,16 @@ func @RemoveUnused(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> (tensor<2xf32>,t
// CHECK-NEXT: return %[[split]]#0, %[[split]]#1
}
// CHECK-LABEL: RemoveTrival
// A requantize whose only effect is changing the producer's output scale is
// trivial: the tfl.quantize is folded away and the fully_connected's result
// type absorbs the target quantization parameters (scale 1.0 -> 2.0).
func @RemoveTrival(%arg0: tensor<384x512x!quant.uniform<i8:f32, 1.0:-128>>, %arg1: tensor<128x512x!quant.uniform<i8<-127:127>:f32, 1.0>>, %arg2: none) -> tensor<384x128x!quant.uniform<i8:f32, 2.0>> {
  %1 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<384x512x!quant.uniform<i8:f32, 1.0:-128>>, tensor<128x512x!quant.uniform<i8<-127:127>:f32, 1.0>>, none) -> tensor<384x128x!quant.uniform<i8:f32, 1.0>>
  %2 = "tfl.quantize"(%1) {qtype = tensor<384x128x!quant.uniform<i8:f32, 2.0>>} : (tensor<384x128x!quant.uniform<i8:f32, 1.0>>) -> tensor<384x128x!quant.uniform<i8:f32, 2.0>>
  return %2 : tensor<384x128x!quant.uniform<i8:f32, 2.0>>
  // CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"{{.*}} -> tensor<384x128x!quant.uniform<i8:f32, 2.000000e+00>>
  // CHECK-NEXT: return %[[fc]]
}
func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
%cst = constant dense<[1, 1001]> : tensor<2xi32>
%0 = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>

View File

@ -1,27 +1,27 @@
// RUN: tf-opt -tfl-split-merged-operands %s | FileCheck %s
func @testSingleLstm(%arg0: tensor<4 x f32>) -> tensor<4xf32> {
func @testSingleLstm(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> {
// CHECK-LABEL: testSingleLstm
// CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
// CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
// CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
// CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
// CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %1 : tensor<4xf32>
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
return %1 : tensor<4x4xf32>
}
func @testMultipleLstms(%arg0: tensor<4 x f32>) -> tensor<4xf32> {
func @testMultipleLstms(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> {
// CHECK-LABEL: testMultipleLstms
// CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
// CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
// CHECK: %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: %[[CST_2:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
// CHECK: %[[CST_3:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
// CHECK: %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_2]], %[[CST_3]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
// CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
// CHECK: %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
// CHECK: %[[CST_2:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
// CHECK: %[[CST_3:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
// CHECK: %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_2]], %[[CST_3]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%2 = "tfl.unidirectional_sequence_lstm"(%1, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %2 : tensor<4xf32>
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
%2 = "tfl.unidirectional_sequence_lstm"(%1, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
return %2 : tensor<4x4xf32>
}

View File

@ -162,6 +162,10 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass(
mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
if (pass_config.shape_inference) {
// Add a shape inference pass to optimize away the unnecessary casts.
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
}
pass_manager->addPass(
mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
pass_manager->addPass(mlir::TFL::CreateOptimizePass());

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
@ -202,6 +203,26 @@ LogicalResult ConvertTFConcatOp::matchAndRewrite(
return success();
}
// Converts any IntegerAttr to an IntegerAttr of an i32 type.
// The value won't change in the new attribute, but if the value is out of
// the bound of i32, the function returns a failure.
LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
  // Already an i32 attribute: reuse it unchanged.
  if (attr.getType().isInteger(/*width=*/32)) {
    *attr_i32 = attr;
    return success();
  }

  const int64_t value = attr.getInt();
  // Fix: use fixed-width int32_t bounds rather than platform `int`. The
  // target type is explicitly 32 bits wide, so the range check must not
  // depend on the host's `int` width.
  if (value > std::numeric_limits<int32_t>::max() ||
      value < std::numeric_limits<int32_t>::min()) {
    return failure();
  }

  *attr_i32 = IntegerAttr::get(
      IntegerType::get(/*width=*/32, attr.getContext()), value);
  return success();
}
LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast<TF::ConcatV2Op>(op);
@ -211,12 +232,16 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
// Extract axis attribute from constant axis tensor
ElementsAttr axis;
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure();
IntegerAttr axis_int = ExtractSingleElementAsInteger(axis);
// "axis" operand could be a i64 tensor. Resolve it here.
IntegerAttr axis_i32;
if (failed(ConvertToI32Attr(axis_int, &axis_i32))) return failure();
StringAttr fused_activation_function =
StringAttr::get("NONE", rewriter.getContext());
rewriter.replaceOpWithNewOp<ConcatenationOp>(
op, output_type, values, ExtractSingleElementAsInteger(axis),
fused_activation_function);
op, output_type, values, axis_i32, fused_activation_function);
return success();
}

View File

@ -838,7 +838,8 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
// TensorFlow operations that doesn't have operands and results of type
// variant are legal. Here, we don't distinguish between variants encoding
// TensorList or some other type as that information is not available here.
// This constraint should be relaxed to support other variant types in TFLite.
// Partial legalization is used below so that ops with variant types are
// still allowed.
auto is_legal = [](Operation *op) {
auto is_not_variant = [](Type ty) {
return !ty.cast<ShapedType>().getElementType().isa<TF::VariantType>();
@ -873,7 +874,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
ConvertTensorListPushBack, ConvertTensorListReserve,
ConvertTensorListSetItem, ConvertTensorListStack,
ConvertTensorListResize, ConvertWhile>(context);
return applyFullConversion(func, target, patterns);
return applyPartialConversion(func, target, patterns);
}
void LowerStaticTensorListPass::runOnOperation() {

View File

@ -29,6 +29,10 @@ def ExtractSingleElementAsFloat : NativeCodeCall<
// Checks if the value has only one user.
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
// Checks if the value has rank at most 'n'.
class HasRankAtMost<int n> : Constraint<
CPred<"$0.getType().cast<ShapedType>().getRank() <= " # n>>;
//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
@ -347,7 +351,9 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
// The result of the new "BinaryOp" will have the same shape as
// `input`. In other words, the shape of the `Reshape` op are not
// changed after the transformation.
(IsTailOfShape $rhs, $input)]>;
(IsTailOfShape $rhs, $input),
(HasRankAtMost<5> $input),
(HasRankAtMost<5> $rhs)]>;
}
foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,

View File

@ -125,6 +125,7 @@ void PostQuantizePass::runOnFunction() {
auto func = getFunction();
auto* ctx = func.getContext();
TFL::populateWithGenerated(ctx, &patterns);
patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
applyPatternsAndFoldGreedily(func, patterns);
if (!emit_quant_adaptor_ops_) {

View File

@ -41,15 +41,22 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
// The cmd line flag to turn on/off Tf.Text API fusion.
// NOLINTNEXTLINE
static llvm::cl::opt<bool> fuse_tftext(
"tfl-fuse-tftext", llvm::cl::value_desc("bool"),
llvm::cl::desc("Fuse TF.Text API ops when it's true"),
llvm::cl::init(false));
namespace mlir {
namespace TFL {
namespace {
constexpr char kTFAPIImplements[] = "tf.api_implements";
constexpr char kTfTextAPIPRefix[] = "tftext:";
// Abstracts the conversion of the embedded lookup composite function.
class ConvertEmbeddedLookupFunc {
@ -187,6 +194,10 @@ void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
OpBuilder builder(func.getBody());
if (failed(ConvertKerasLSTMLayer(func, &builder)))
return signalPassFailure();
} else if (fuse_tftext && attr.getValue().startswith(kTfTextAPIPRefix)) {
if (failed(ConvertTFTextAPI(func, attr.getValue()))) {
return signalPassFailure();
}
}
}

View File

@ -70,6 +70,7 @@ class PrepareQuantizePass
: public PassWrapper<PrepareQuantizePass, FunctionPass> {
public:
// Constructor used by the PassRegistration and enforce uint8 quantization.
// This is only used by tests.
explicit PrepareQuantizePass() {
if (quantize_signed)
quant_specs_.inference_type = tensorflow::DT_QINT8;
@ -257,15 +258,16 @@ void PrepareQuantizePass::runOnFunction() {
// convert all of them to signed.
OwningRewritePatternList patterns;
bool is_signed = quant_specs_.IsSignedInferenceType();
int bit_width = quant_specs_.GetQuantizationTypeWidth();
if (is_signed) {
patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
// Convert quant stats to int8 quantization parameters.
// Currently, only activation stats are imported, so narrow_range = false.
patterns.insert<PrepareQuantStats>(8, false, true, ctx);
patterns.insert<PrepareQuantStats>(bit_width, false, true, ctx);
} else {
// Convert quant stats to uint8 quantization parameters.
// Currently, only activation stats are imported, so narrow_range = false.
patterns.insert<PrepareQuantStats>(8, false, false, ctx);
patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
}
applyPatternsAndFoldGreedily(func, patterns);

View File

@ -0,0 +1,127 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TFL {
namespace {
constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer";
constexpr char kTFAPIImplements[] = "tf.api_implements";
// Returns a "tfl"-dialect OpaqueElementsAttr carrying zero bytes; used as
// the custom-options payload for generated TFL custom ops.
inline OpaqueElementsAttr emptyCustomOption(OpBuilder* builder) {
  const std::string no_bytes;
  auto byte_vector_type = RankedTensorType::get(
      {static_cast<int64_t>(no_bytes.size())}, builder->getIntegerType(8));
  return OpaqueElementsAttr::get(
      builder->getContext()->getRegisteredDialect("tfl"), byte_vector_type,
      no_bytes);
}
// Returns func's idx-th argument type as a RankedTensorType, or null when
// the argument is not a ranked tensor.
inline RankedTensorType getInputType(mlir::FuncOp func, int idx) {
  auto arg_type = func.getType().getInput(idx);
  return arg_type.dyn_cast_or_null<mlir::RankedTensorType>();
}
// Returns func's idx-th result type as a RankedTensorType, or null when
// the result is not a ranked tensor.
inline RankedTensorType getResultType(mlir::FuncOp func, int idx) {
  auto result_type = func.getType().getResult(idx);
  return result_type.dyn_cast_or_null<mlir::RankedTensorType>();
}
// Verifies that `func` has the WhitespaceTokenizer composite signature:
// exactly one argument (a 1-D string tensor) and two results (a 1-D string
// tensor of values and a 1-D i64 tensor of offsets). Returns failure on any
// mismatch.
LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
  if (func.getNumResults() != 2) {
    return failure();
  }
  if (func.getNumArguments() != 1) {
    return failure();
  }
  auto input_type = getInputType(func, 0);
  if (!input_type || input_type.getRank() != 1 ||
      !input_type.getElementType().isa<mlir::TF::StringType>()) {
    return failure();
  }
  auto value_type = getResultType(func, 0);
  if (!value_type || value_type.getRank() != 1 ||
      !value_type.getElementType().isa<mlir::TF::StringType>()) {
    return failure();
  }
  auto offset_type = getResultType(func, 1);
  // Fix: getResultType uses dyn_cast_or_null and can return a null type for
  // a non-ranked result; guard before calling getRank(), matching the null
  // checks applied to input_type and value_type above.
  if (!offset_type || offset_type.getRank() != 1 ||
      !offset_type.getElementType().isInteger(64)) {
    return failure();
  }
  return success();
}
// Rewrites the composite WhitespaceTokenizer function body into a single
// TFL custom op whose custom code is the original API name. The function's
// signature is unchanged; only its body is replaced.
LogicalResult ConvertWhitespaceTokenizer(mlir::FuncOp func,
                                         llvm::StringRef api) {
  // Drop the composite implementation; the fresh entry block re-materializes
  // the original arguments.
  func.eraseBody();
  func.addEntryBlock();
  // Record the fused API name so later passes can identify this function.
  func.setAttr(kTFAPIImplements, StringAttr::get(api, func.getContext()));
  Value text = func.getArgument(0);
  auto output_type = func.getType().getResult(0);
  auto offset_type = func.getType().getResult(1);
  SmallVector<Type, 2> shape = {output_type, offset_type};
  ArrayRef<Type> output_types(shape);
  OpBuilder builder(func.getBody());
  // Emit the custom op with an empty options payload and return its results.
  auto op = builder.create<mlir::TFL::CustomOp>(func.getLoc(), output_types,
                                                ValueRange(text), api,
                                                emptyCustomOption(&builder));
  builder.create<mlir::ReturnOp>(func.getLoc(), op.getResults());
  return success();
}
} // namespace
// Entry point for TF.Text fusion: dispatches on the API name and fuses the
// function into a custom op when its signature verifies; otherwise reports
// failure so the function is left untouched.
LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api) {
  const bool is_whitespace_tokenizer = (api.str() == kWhitespaceTokenizer);
  if (is_whitespace_tokenizer && succeeded(VerifyWhitespaceTokenizer(func))) {
    return ConvertWhitespaceTokenizer(func, api);
  }
  return failure();
}
} // namespace TFL
} // namespace mlir

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This header file defines common utils used by TFLite transformation
// passes to work with op attributes.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {
namespace TFL {
// Fuses a composite TF.Text API function (identified by `api`, e.g. a
// whitespace tokenizer) into a single TFL custom op. Returns failure when
// the API is unsupported or the function signature does not match.
LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api);
} // end namespace TFL
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_os_ostream.h"
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
@ -93,9 +94,10 @@ MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() {
static void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
mlir::registerDialect<mlir::shape::ShapeDialect>();
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
return true;
}();
(void)init_once;

View File

@ -224,7 +224,10 @@ cc_library(
hdrs = [
"ir/tf_attributes.h",
],
deps = ["@llvm-project//mlir:IR"],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
],
)
cc_library(
@ -427,6 +430,7 @@ cc_library(
"transforms/parallel_execute_to_islands.cc",
"transforms/promote_resources_to_args.cc",
"transforms/raise_control_flow.cc",
"transforms/readonly_references_to_resources.cc",
"transforms/replicate_invariant_op_hoisting.cc",
"transforms/replicate_to_island.cc",
"transforms/resource_device_inference.cc",
@ -555,8 +559,7 @@ cc_library(
srcs = ["ir/dialect_registration.cc"],
deps = [
":tensorflow",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:Shape",
],
alwayslink = 1,
)
@ -785,6 +788,9 @@ cc_library(
name = "convert_type",
srcs = ["utils/convert_type.cc"],
hdrs = ["utils/convert_type.h"],
visibility = [
"//visibility:public",
],
deps = [
":tensorflow_types",
"//tensorflow/core:framework",
@ -1140,6 +1146,7 @@ COMPILE_MLIR_UTIL_DEPS = [
"//tensorflow/compiler/mlir/xla:type_to_shape",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/mlir/xla:xla_sink_constants_to_control_flow",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:framework",
@ -1278,6 +1285,7 @@ cc_library(
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
],
)
@ -1292,6 +1300,7 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
],
)

View File

@ -0,0 +1,55 @@
load(
    "//tensorflow:tensorflow.bzl",
    "tf_copts",
    "tf_cuda_library",
    "tfe_xla_copts",
)

package(
    default_visibility = [":friends"],
    licenses = ["notice"],  # Apache 2.0
)

package_group(
    name = "friends",
    packages = ["//tensorflow/..."],
)

# MLIR-backed implementation of the unified eager C API tracing interface.
tf_cuda_library(
    name = "mlir_c_api",
    srcs = [
        "c_api_unified_experimental_mlir.cc",
    ],
    copts = tf_copts() + tfe_xla_copts(),
    deps = [
        "//tensorflow/c:c_api",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c:tf_status_internal",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_internal",
        "//tensorflow/c/eager:c_api_unified_internal",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
        "//tensorflow/compiler/mlir/tensorflow:convert_type",
        "//tensorflow/compiler/mlir/tensorflow:error_util",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/core:framework",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:casts",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
    ],
)

# Registers the MLIR tracing factory under the name "mlir" at load time.
# alwayslink keeps the static initializer from being dropped by the linker.
cc_library(
    name = "mlir_c_api_registration",
    srcs = ["c_api_unified_experimental_mlir_registration.cc"],
    deps = [
        ":mlir_c_api",
        "//tensorflow/c/eager:c_api_unified_internal",
    ],
    alwayslink = 1,
)

View File

@ -0,0 +1,493 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include <memory>
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
namespace mlir {
namespace TF {
using tensorflow::internal::AbstractFunction;
using tensorflow::internal::AbstractOp;
using tensorflow::internal::AbstractTensor;
using tensorflow::internal::dyncast;
using tensorflow::internal::ExecutionContext;
using tensorflow::internal::OutputList;
namespace {
// Registers every dialect this tracing implementation emits ops for.
// The lambda-initialized local static guarantees registration runs exactly
// once per process, even with concurrent callers.
static void RegisterDialects() {
  static bool init_once = []() {
    mlir::registerDialect<mlir::StandardOpsDialect>();
    mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
    mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
    return true;
  }();
  (void)init_once;  // silences unused-variable warnings
}
// Converts a TensorFlow DataType into the corresponding unranked MLIR tensor
// type, written through `type`. Returns the status of the element-type
// conversion; on failure `type` is left untouched by the wrapping step.
Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder,
                               Type* type) {
  Status status = tensorflow::ConvertDataType(dtype, builder, type);
  if (!status.ok()) return status;
  *type = UnrankedTensorType::get(*type);
  return status;
}
// Wraps an MLIR SSA `Value` as an AbstractTensor so the unified tracing API
// can thread values between traced operations.
class MlirTensor : public AbstractTensor {
 public:
  explicit MlirTensor(Value value) : AbstractTensor(kKind), value_(value) {}

  // Returns the wrapped SSA value.
  Value getValue() { return value_; }

  static constexpr AbstractTensorKind kKind = kMlirTensor;

 private:
  Value value_;
};
// Incrementally assembles a single MLIR operation: SetOpType/SetOpName/
// SetAttrType accumulate state, and Create() materializes it into an
// OperationState with operand/result types derived from the registered OpDef.
class MlirAbstractOp : public AbstractOp {
 public:
  explicit MlirAbstractOp(MLIRContext* context)
      : AbstractOp(kKind), context_(context) {}

  void SetOpType(const char* op_type, TF_Status* s) override;

  void SetAttrType(const char* attr_name, TF_DataType dtype,
                   TF_Status* s) override;

  void SetOpName(const char* const op_name, TF_Status* s) override;

  MLIRContext* GetContext() { return context_; }

  // Returns the ref-tensor variant of `type`; errors via `s` if `type` is
  // already a reference type.
  Type AddRef(Type type, TF_Status* s);

  // Finalizes the op with the given operands. Returns the OperationState, or
  // nullptr after setting `s` on error.
  OperationState* Create(ArrayRef<Value> operands, TF_Status* s);

  static constexpr AbstractOpKind kKind = kMlirOp;

 private:
  MLIRContext* context_;
  // Attributes set explicitly plus type attributes derived in Create().
  llvm::StringMap<Attribute> attrs_;
  // Allocated by SetOpType; holds the operation being assembled.
  std::unique_ptr<OperationState> state_;
  const char* op_name_ = nullptr;
};
// MlirFunction is a thin wrapper over a FuncOp.
class MlirFunction : public AbstractFunction {
 public:
  // Takes ownership of the MLIRContext and the module the traced function
  // lives in, keeping both alive for the lifetime of this object.
  explicit MlirFunction(std::unique_ptr<MLIRContext> context,
                        OwningModuleRef module, FuncOp func)
      : AbstractFunction(kKind),
        context_(std::move(context)),
        module_(std::move(module)),
        func_(func) {}

  // Lowers the traced function to the executor dialect and exports it as a
  // TF_Function; the outcome is reported through `s`.
  TF_Function* GetTfFunction(TF_Status* s) override;

  static constexpr AbstractFunctionKind kKind = kGraphFunc;

 private:
  std::unique_ptr<MLIRContext> context_;
  OwningModuleRef module_;
  FuncOp func_;
};
// Tracing context that records executed operations into the entry block of a
// freshly created MLIR FuncOp named `name`.
class MlirFunctionContext : public ExecutionContext {
 public:
  explicit MlirFunctionContext(const char* name)
      : ExecutionContext(kKind),
        context_(std::make_unique<MLIRContext>()),
        builder_(context_.get()) {
    // TODO(aminim) figure out the location story here
    module_ = ModuleOp::create(builder_.getUnknownLoc());
    // Start with a nullary signature; Finalize() rewrites it from the actual
    // block arguments and return operands.
    func_ = FuncOp::create(builder_.getUnknownLoc(), name,
                           builder_.getFunctionType(llvm::None, llvm::None));
    module_->push_back(func_);
    // All traced ops are inserted into the function's entry block.
    builder_ = OpBuilder::atBlockBegin(func_.addEntryBlock());
  }

  AbstractOp* CreateOperation() override {
    return new MlirAbstractOp(context_.get());
  }

  void ExecuteOperation(AbstractOp* abstract_op, int num_inputs,
                        AbstractTensor* const* inputs, OutputList* o,
                        TF_Status* s) override;

  AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override;

  AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override;

  void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
    s->status = tensorflow::errors::Unimplemented(
        "Registering graph functions has not been implemented yet.");
  }

  static constexpr ExecutionContextKind kKind = kMlirContext;

 private:
  std::unique_ptr<MLIRContext> context_;
  OpBuilder builder_;
  FuncOp func_;
  OwningModuleRef module_;
};
// Records the op type and allocates the underlying OperationState. Must be
// called exactly once, before any attributes are set.
void MlirAbstractOp::SetOpType(const char* op_type, TF_Status* s) {
  if (state_) {
    s->status = tensorflow::errors::FailedPrecondition(
        "SetOpType called on already built op.");
    return;
  }
  // Qualify with the TF dialect prefix to form the full MLIR op name.
  const std::string qualified_name = std::string("tf.") + op_type;
  // TODO(aminim) figure out the location story here
  state_ = std::make_unique<OperationState>(UnknownLoc::get(context_),
                                            qualified_name);
}
void MlirAbstractOp::SetAttrType(const char* attr_name, TF_DataType dtype,
TF_Status* s) {
if (!state_) {
s->status = tensorflow::errors::FailedPrecondition(
"op_type must be specified before specifying attrs.");
return;
}
Type mlir_type;
Builder builder(context_);
s->status = ConvertDataTypeToTensor(static_cast<tensorflow::DataType>(dtype),
builder, &mlir_type);
if (!s->status.ok()) return;
attrs_[attr_name] = TypeAttr::get(mlir_type);
}
// Records the (unique) name of the op instance. Setting it twice is an error.
void MlirAbstractOp::SetOpName(const char* const op_name, TF_Status* s) {
  // TODO(aminim): should we use a location?
  if (op_name_ != nullptr) {
    s->status = tensorflow::errors::FailedPrecondition(
        "SetOpName called on already built op.");
    return;
  }
  op_name_ = op_name;
}
// Produces the reference-typed variant of a tensor type, preserving its shape
// (ranked or unranked). Asking for a ref-of-ref is an error.
Type MlirAbstractOp::AddRef(Type type, TF_Status* s) {
  Type element = getElementTypeOrSelf(type);
  if (element.isa<mlir::TF::TensorFlowRefType>()) {
    s->status = tensorflow::errors::InvalidArgument(
        "Requested reference to a reference type");
    return nullptr;
  }
  Type ref_element = TensorFlowRefType::get(element);
  auto ranked = type.dyn_cast<RankedTensorType>();
  if (ranked) return RankedTensorType::get(ranked.getShape(), ref_element);
  return UnrankedTensorType::get(ref_element);
}
// Materializes the accumulated op into an OperationState. Operand type
// attributes are derived from the actual operand types per the registered
// OpDef, and result types are resolved from the OpDef's type / type_attr /
// type_list_attr / number_attr fields. Returns nullptr after setting `s` on
// any error.
OperationState* MlirAbstractOp::Create(ArrayRef<Value> operands, TF_Status* s) {
  state_->operands = llvm::to_vector<4>(operands);
  const tensorflow::OpDef* op_def;
  // Strip the "tf." dialect prefix to recover the TensorFlow op name for the
  // OpDef lookup.
  auto node_name = state_->name.getStringRef().drop_front(
      TensorFlowDialect::getDialectNamespace().size() + 1);
  s->status =
      tensorflow::OpRegistry::Global()->LookUpOpDef(node_name.str(), &op_def);
  if (!s->status.ok()) return nullptr;
  Builder builder(context_);
  // Process operands according to the op_def and infer derived attributes.
  int current_operand = 0;
  for (const tensorflow::OpDef::ArgDef& input_arg : op_def->input_arg()) {
    if (!input_arg.number_attr().empty()) {
      // TODO(b/156122856): we don't support variadic operands.
      s->status = tensorflow::errors::Unimplemented(
          "Unsupported 'number_attr' for '", input_arg.number_attr(), "'");
      return nullptr;
    } else if (!input_arg.type_list_attr().empty()) {
      // Report the offending attribute name itself: number_attr is always
      // empty on this branch, so it must be type_list_attr.
      s->status = tensorflow::errors::InvalidArgument(
          "Unsupported 'type_list_attr' for '", input_arg.type_list_attr(),
          "'");
      return nullptr;
    }
    if (current_operand >= operands.size()) {
      s->status = tensorflow::errors::InvalidArgument("Missing operand for '",
                                                      input_arg.name(), "'");
      return nullptr;
    }
    // The expected operand type is either fixed by the OpDef, or taken from
    // the actual operand when the OpDef leaves it open.
    Type expected_type;
    if (input_arg.type() != tensorflow::DT_INVALID) {
      s->status =
          ConvertDataTypeToTensor(input_arg.type(), builder, &expected_type);
      if (!s->status.ok()) return nullptr;
      if (input_arg.is_ref()) expected_type = AddRef(expected_type, s);
      if (!s->status.ok()) return nullptr;
    } else {
      expected_type = operands[current_operand].getType();
    }
    // Record the derived type attribute when the OpDef names one.
    if (!input_arg.type_attr().empty()) {
      attrs_[input_arg.type_attr()] = TypeAttr::get(expected_type);
    }
    ++current_operand;
  }
  for (const tensorflow::OpDef::ArgDef& output_arg : op_def->output_arg()) {
    int original_size = state_->types.size();
    if (!output_arg.number_attr().empty()) {
      // Same type repeated "repeats" times.
      Attribute repeats_attr = attrs_[output_arg.number_attr()];
      if (!repeats_attr) {
        s->status = tensorflow::errors::InvalidArgument(
            "Missing attribute '", output_arg.number_attr(),
            "' required for output list '", output_arg.name(), "'");
        return nullptr;
      }
      if (!repeats_attr.isa<IntegerAttr>()) {
        s->status = tensorflow::errors::InvalidArgument(
            "Attribute '", output_arg.number_attr(),
            "' required for output list '", output_arg.name(),
            "' isn't an integer");
        return nullptr;
      }
      int64_t repeats = repeats_attr.cast<IntegerAttr>().getInt();
      if (!output_arg.type_attr().empty()) {
        // Same type repeated "repeats" times.
        Attribute attr = attrs_[output_arg.type_attr()];
        if (!attr) {
          s->status = tensorflow::errors::InvalidArgument(
              "Missing attribute '", output_arg.type_attr(),
              "' required for output '", output_arg.name(), "'");
          return nullptr;
        }
        TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
        if (!type_attr) {
          s->status = tensorflow::errors::InvalidArgument(
              "Attribute '", output_arg.type_attr(), "' required for output '",
              output_arg.name(), "' isn't a type attribute");
          return nullptr;
        }
        for (int i = 0; i < repeats; ++i)
          state_->types.push_back(type_attr.getType());
      } else if (output_arg.type() != tensorflow::DT_INVALID) {
        for (int i = 0; i < repeats; ++i) {
          Type type;
          s->status =
              ConvertDataTypeToTensor(output_arg.type(), builder, &type);
          if (!s->status.ok()) return nullptr;
          state_->types.push_back(type);
        }
      } else {
        s->status = tensorflow::errors::InvalidArgument(
            "Missing type or type_attr field in ",
            output_arg.ShortDebugString());
        return nullptr;
      }
    } else if (!output_arg.type_attr().empty()) {
      // Single output whose type is named by a type attribute.
      Attribute attr = attrs_[output_arg.type_attr()];
      if (!attr) {
        s->status = tensorflow::errors::InvalidArgument(
            "Missing attribute '", output_arg.type_attr(),
            "' required for output '", output_arg.name(), "'");
        return nullptr;
      }
      TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
      if (!type_attr) {
        s->status = tensorflow::errors::InvalidArgument(
            "Attribute '", output_arg.type_attr(), "' required for output '",
            output_arg.name(), "' isn't a type attribute");
        return nullptr;
      }
      state_->types.push_back(type_attr.getValue());
    } else if (!output_arg.type_list_attr().empty()) {
      // This is pointing to an attribute which is an array of types.
      Attribute attr = attrs_[output_arg.type_list_attr()];
      if (!attr) {
        s->status = tensorflow::errors::InvalidArgument(
            "Missing attribute '", output_arg.type_list_attr(),
            "' required for output '", output_arg.name(), "'");
        return nullptr;
      }
      ArrayAttr array_attr = attr.dyn_cast<ArrayAttr>();
      if (!array_attr) {
        s->status = tensorflow::errors::InvalidArgument(
            "Attribute '", output_arg.type_list_attr(),
            "' required for output '", output_arg.name(),
            "' isn't an array attribute");
        return nullptr;
      }
      for (Attribute attr : array_attr) {
        TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
        if (!type_attr) {
          s->status = tensorflow::errors::InvalidArgument(
              "Array Attribute '", output_arg.type_list_attr(),
              "' required for output '", output_arg.name(),
              "' has a non-Type element");
          return nullptr;
        }
        state_->types.push_back(type_attr.getValue());
      }
    } else if (output_arg.type() != tensorflow::DT_INVALID) {
      // Fixed output type declared directly on the OpDef.
      Type type;
      Builder builder(context_);
      s->status = ConvertDataTypeToTensor(output_arg.type(), builder, &type);
      if (!s->status.ok()) return nullptr;
      state_->types.push_back(type);
    } else {
      // The status set here is unconditionally an error.
      s->status = tensorflow::errors::InvalidArgument(
          "No type fields in ", output_arg.ShortDebugString());
      return nullptr;
    }
    if (output_arg.is_ref()) {
      // For all types that were added by this function call, make them refs.
      // Use begin() + offset rather than &types[original_size]: indexing at
      // original_size is out of bounds when nothing was appended.
      for (Type& type : llvm::make_range(state_->types.begin() + original_size,
                                         state_->types.end())) {
        type = AddRef(type, s);
        if (!s->status.ok()) return nullptr;
      }
    }
  }
  return state_.get();
}
// Lowers the traced FuncOp through the functional-to-executor and island
// break-up passes, then exports it as a TF_Function. Returns nullptr (and
// sets `s`) if either lowering or export fails.
TF_Function* MlirFunction::GetTfFunction(TF_Status* s) {
  PassManager pm(func_.getContext());
  pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
  pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
  // In case of failure, the `diag_handler` converts MLIR errors emitted to
  // the MLIRContext into a tensorflow::Status; the LogicalResult from the
  // pass manager is intentionally ignored in favor of that status.
  StatusScopedDiagnosticHandler diag_handler(func_.getContext());
  (void)pm.run(func_.getParentOfType<ModuleOp>());
  s->status = diag_handler.ConsumeStatus();
  if (!s->status.ok()) return nullptr;
  tensorflow::GraphExportConfig configs;
  std::unique_ptr<TF_Function> tf_function(new TF_Function);
  s->status = ConvertMlirFunctionToFunctionLibraryDef(func_, configs,
                                                      &tf_function->fdef);
  // Do not hand back a partially-initialized function on export failure;
  // the unique_ptr frees it and callers see nullptr plus the error status.
  if (!s->status.ok()) return nullptr;
  return tf_function.release();
}
// Emits the assembled op into the function body and exposes its results as
// MlirTensors through `o`. Errors are reported via `s`.
void MlirFunctionContext::ExecuteOperation(AbstractOp* abstract_op,
                                           int num_inputs,
                                           AbstractTensor* const* inputs,
                                           OutputList* o, TF_Status* s) {
  auto* op = dyncast<MlirAbstractOp>(abstract_op);
  if (op == nullptr) {
    s->status = tensorflow::errors::InvalidArgument(
        "Unable to cast AbstractOp to TF_GraphOp.");
    return;
  }
  // Collect SSA operands, rejecting tensors that are not MLIR values owned
  // by this context.
  SmallVector<Value, 8> operand_values;
  for (int i = 0; i < num_inputs; ++i) {
    auto* tensor = dyncast<MlirTensor>(inputs[i]);
    if (!tensor) {
      s->status = tensorflow::errors::InvalidArgument(
          "Capturing eager tensors is not supported yet.");
      return;
    }
    if (tensor->getValue().getContext() != context_.get()) {
      s->status = tensorflow::errors::InvalidArgument(
          "Capturing tensors from other context is not supported.");
      return;
    }
    operand_values.push_back(tensor->getValue());
  }
  OperationState* state = op->Create(operand_values, s);
  if (!s->status.ok() || !state) return;
  Operation* created = builder_.createOperation(*state);
  o->outputs.clear();
  o->outputs.reserve(created->getNumResults());
  for (Value result : created->getResults())
    o->outputs.push_back(new MlirTensor(result));
}
// Adds a function argument of the given dtype (as an unranked tensor) to the
// entry block and returns it wrapped as an MlirTensor.
AbstractTensor* MlirFunctionContext::AddParameter(TF_DataType dtype,
                                                  TF_Status* s) {
  Type param_type;
  s->status = ConvertDataTypeToTensor(static_cast<tensorflow::DataType>(dtype),
                                      builder_, &param_type);
  if (!s->status.ok()) return nullptr;
  Block& entry_block = func_.getBody().front();
  return new MlirTensor(entry_block.addArgument(param_type));
}
// Terminates the traced function: emits the ReturnOp from the requested
// outputs, rewrites the function signature to match the actual arguments and
// results, and transfers context/module ownership into an MlirFunction.
AbstractFunction* MlirFunctionContext::Finalize(OutputList* outputs,
                                                TF_Status* s) {
  Block& body = func_.getBody().front();
  SmallVector<Value, 8> return_values;
  for (AbstractTensor* output : outputs->outputs) {
    auto* tensor = dyncast<MlirTensor>(output);
    if (!tensor) {
      s->status = tensorflow::errors::InvalidArgument(
          "Capturing eager tensors is not supported yet.");
      return nullptr;
    }
    if (tensor->getValue().getContext() != context_.get()) {
      s->status = tensorflow::errors::InvalidArgument(
          "Capturing tensors from other context is not supported.");
      return nullptr;
    }
    return_values.push_back(tensor->getValue());
  }
  builder_.create<ReturnOp>(func_.getLoc(), return_values);
  // Recompute the signature from the block now that tracing is complete.
  auto arg_types = llvm::to_vector<8>(body.getArgumentTypes());
  auto result_types =
      llvm::to_vector<8>(body.getTerminator()->getOperandTypes());
  func_.setType(FunctionType::get(arg_types, result_types, func_.getContext()));
  return new MlirFunction(std::move(context_), std::move(module_), func_);
}
extern "C" {
// Factory entry point exposed with C linkage so the registration translation
// unit can declare it without MLIR headers. Creates a fresh tracing context
// for the function named `fn_name`. `s` is unused: construction here cannot
// fail.
ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s) {
  RegisterDialects();
  return new MlirFunctionContext(fn_name);
}
}
} // end anonymous namespace
} // end namespace TF
} // end namespace mlir

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"

using tensorflow::internal::ExecutionContext;

// Defined in c_api_unified_experimental_mlir.cc; declared here with C linkage
// so this TU does not need the MLIR headers.
extern "C" {
ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s);
}

namespace {
// Register the tracing implemented in this file as the default tracing engine.
// Runs at static-initialization time when this object file is linked in
// (the BUILD target uses alwayslink to keep it).
static bool register_tracing = [] {
  RegisterTracingEngineFactory("mlir", MlirTracingFactory);
  return true;
}();
} // namespace

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
namespace mlir {
namespace TFControlFlow {

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
@ -31,5 +32,6 @@ static DialectRegistration<tf_device::TensorFlowDeviceDialect>
tf_device_dialect;
static DialectRegistration<tf_saved_model::TensorFlowSavedModelDialect>
tf_saved_model_dialect;
static DialectRegistration<mlir::shape::ShapeDialect> shape_dialect;
} // namespace mlir

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
namespace mlir {
namespace TF {
@ -45,6 +47,28 @@ struct ShapeAttrStorage : public AttributeStorage {
bool unranked = false;
};
// The storage class for FuncAttr.
// Uniqued on the (name, attrs) pair; `name` holds a SymbolRefAttr and `attrs`
// a DictionaryAttr (see the casts in FuncAttr::GetName/GetAttrs).
struct FuncAttrStorage : public AttributeStorage {
  using KeyTy = std::pair<Attribute, Attribute>;

  explicit FuncAttrStorage(Attribute name, Attribute attrs)
      : name(name), attrs(attrs) {}

  bool operator==(const KeyTy& key) const { return key == KeyTy(name, attrs); }

  static unsigned hashKey(const KeyTy& key) {
    return llvm::hash_combine(key.first, key.second);
  }

  // Placement-constructs the storage in the context's allocator, as required
  // by the MLIR attribute uniquing machinery.
  static FuncAttrStorage* construct(mlir::AttributeStorageAllocator& allocator,
                                    const KeyTy& key) {
    return new (allocator.allocate<FuncAttrStorage>())
        FuncAttrStorage(key.first, key.second);
  }

  Attribute name;
  Attribute attrs;
};
} // namespace detail
// Get or create a shape attribute.
@ -85,5 +109,24 @@ bool ShapeAttr::hasStaticShape() const {
return true;
}
// Get or create a FuncAttr from a function name: wraps the name in a
// SymbolRefAttr and pairs it with the attribute dictionary.
FuncAttr FuncAttr::get(mlir::MLIRContext* context, llvm::StringRef name,
                       DictionaryAttr attr) {
  auto symbol = SymbolRefAttr::get(name, context);
  return Base::get(context, AttrKind::FUNC, symbol, attr);
}

// Get or create a FuncAttr from an existing symbol reference.
FuncAttr FuncAttr::get(mlir::MLIRContext* context, SymbolRefAttr symbol,
                       DictionaryAttr attr) {
  return Base::get(context, AttrKind::FUNC, symbol, attr);
}

// Returns the referenced function symbol (models NameAttrList.name).
SymbolRefAttr FuncAttr::GetName() const {
  return getImpl()->name.cast<SymbolRefAttr>();
}

// Returns the function's attribute dictionary (models NameAttrList.attr).
DictionaryAttr FuncAttr::GetAttrs() const {
  return getImpl()->attrs.cast<DictionaryAttr>();
}
} // namespace TF
} // namespace mlir

View File

@ -18,6 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
namespace mlir {
@ -30,6 +31,7 @@ namespace AttrKind {
enum Kind {
FIRST_USED_TENSORFLOW_ATTR = Attribute::FIRST_TENSORFLOW_ATTR,
SHAPE = FIRST_USED_TENSORFLOW_ATTR,
FUNC,
LAST_USED_TENSORFLOW_ATTR,
};
@ -38,6 +40,7 @@ enum Kind {
namespace detail {
struct ShapeAttrStorage;
struct FuncAttrStorage;
} // namespace detail
@ -71,6 +74,33 @@ class ShapeAttr : public Attribute::AttrBase<ShapeAttr, Attribute,
static bool kindof(unsigned kind) { return kind == AttrKind::SHAPE; }
};
// Custom attribute to model AttrValue.value.func (NameAttrList type attribute).
// This attribute holds a SymbolRefAttr, for the NameAttrList.name string and a
// DictionaryAttr for the NameAttrList.attr map<string, AttrValue>. It is
// currently printed and parsed for the following format:
//
// #tf.func<@symbol, {attr = "value"}>
//
// where the first element is the SymbolRefAttr and the second element is the
// DictionaryAttr.
class FuncAttr
: public Attribute::AttrBase<FuncAttr, Attribute, detail::FuncAttrStorage> {
public:
using Base::Base;
static FuncAttr get(mlir::MLIRContext* context, llvm::StringRef name,
DictionaryAttr attr);
static FuncAttr get(mlir::MLIRContext* context, SymbolRefAttr symbol,
DictionaryAttr attr);
SymbolRefAttr GetName() const;
DictionaryAttr GetAttrs() const;
static bool kindof(unsigned kind) { return kind == AttrKind::FUNC; }
};
} // namespace TF
} // namespace mlir

View File

@ -47,37 +47,6 @@ limitations under the License.
namespace mlir {
namespace tf_executor {
namespace {
// If the given tensor has elements of type with subtypes, then returns a new
// type after dropping subtypes info. Otherwise, returns the original type as
// is.
ShapedType DropTypeSubTypes(ShapedType ty) {
Type element_ty = ty.getElementType();
auto subtype_ty = element_ty.dyn_cast<TF::TensorFlowTypeWithSubtype>();
if (!subtype_ty) return ty;
Type default_ty = GetDefaultTypeOf(subtype_ty);
if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
return UnrankedTensorType::get(default_ty);
}
// If the given tensor has elements of type ref, then returns a new type
// of the shape, but corresponding non-ref type as element type. Otherwise,
// returns the original type as is.
ShapedType DropRefType(ShapedType ty) {
Type element_ty = ty.getElementType();
auto ref_ty = element_ty.dyn_cast<TF::TensorFlowRefType>();
if (!ref_ty) return ty;
Type default_ty = GetDefaultTypeOf(ref_ty);
if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
return UnrankedTensorType::get(default_ty);
}
} // namespace
//===----------------------------------------------------------------------===//
// TF Executor Dialect
@ -85,6 +54,9 @@ ShapedType DropRefType(ShapedType ty) {
namespace {
using TF::DropRefType;
using TF::DropTypeSubTypes;
struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;

View File

@ -1195,6 +1195,46 @@ subsequent operation and then be optimized away, however.)
}];
}
def TF_CaseOp : TF_Op<"Case", []> {
let summary = [{
An n-way switch statement which calls a single branch function.
}];
let description = [{
An n-way switch statement, implementing the following:
```
switch (branch_index) {
case 0:
output = branches[0](input);
break;
case 1:
output = branches[1](input);
break;
...
case [[nbranches-1]]:
default:
output = branches[nbranches-1](input);
break;
}
```
}];
let arguments = (ins
I32Tensor:$branch_index,
Variadic<TF_Tensor>:$input,
Confined<SymbolRefArrayAttr, [ArrayMinCount<1>]>:$branches,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes
);
let results = (outs
Variadic<TF_Tensor>:$output
);
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
}
def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> {
let summary = "Cast x of type SrcT to y of DstT.";
@ -6331,6 +6371,8 @@ If `x` and `y` are reals, this will return the floating-point division.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1;
let hasFolder = 1;
}
def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> {
@ -7434,9 +7476,15 @@ select(condition, t, e) ==> [[1, 2],
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
let hasCanonicalizer = 1;
let verifier = [{
return Verify(*this);
}];
}
def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect]> {
def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect, ResultsBroadcastableShape]> {
let summary = "";
let description = [{
@ -8309,7 +8357,7 @@ def TF_StackV2Op : TF_Op<"StackV2", []> {
);
}
def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, SameOperandsAndResultType]> {
def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> {
let summary = "Stops gradient computation.";
let description = [{
@ -10586,6 +10634,27 @@ def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF__HostComputeMlirOp : TF_Op<"_HostComputeMlir", []> {
let summary = "A host-side computation called from a TPU device.";
let description = [{
}];
let arguments = (ins
Variadic<TF_Tensor>:$inputs,
StrAttr:$key,
DefaultValuedAttr<I64Attr, "0">:$tpu_core
);
let results = (outs
Variadic<TF_Tensor>:$outputs
);
TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
}
def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", []> {
let summary = "An op that receives embeddng activations on the TPU.";

View File

@ -58,6 +58,7 @@ limitations under the License.
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/platform/logging.h"
@ -110,7 +111,6 @@ static inline bool HasRankAtMost(Value value, int64_t rank) {
return !type || type.getRank() <= rank;
}
static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
return dim_or_rank == -1;
}
@ -252,6 +252,39 @@ static LogicalResult VerifyTypesCompatibility(
return success();
}
// This is a helper for the Select to SelectV2 canonicalization. The `data` rank
// refers to the rank of `t`/`e` (these two inputs have equal rank; this is
// checked in the verifier).
//
// In most cases, the predicate for Select can be used directly as the predicate
// for SelectV2. However, there is one case that varies, which is when the
// predicate is a tensor and the data is multidimensional. In this case, Select
// op semantics dictate that the predicate tensor length must match the size of
// the first data dimension. This varies from normal broadcasting semantics
// (which are used in SelectV2), so we must reshape the tensor in this case to
// be compatible.
static Value ReshapeSelectPredIfNecessary(OpBuilder *builder, Location loc,
Value cond, int data_rank) {
auto cond_tensor = cond.getType().cast<RankedTensorType>();
// Reshape is only needed in the case that the cond rank is 1 (i.e. it is
// a vector) AND t/e rank is > 1.
if (cond_tensor.getRank() != 1 || data_rank <= 1) {
// No reshape necessary. Leave cond as it is.
return cond;
}
// This is the case where a reshape is needed. We want to construct the
// shape [x,1,...1], where x is the value in the pred tensor and the
// length of the shape is equal to data_rank.
SmallVector<int64_t, 8> shape(data_rank, 1);
shape[0] = cond_tensor.getShape().front();
auto new_shape_type =
RankedTensorType::get({data_rank}, builder->getIntegerType(64));
auto shape_attr = DenseIntElementsAttr::get(new_shape_type, shape);
auto new_shape = builder->create<ConstOp>(loc, shape_attr);
return builder->create<ReshapeOp>(loc, cond, new_shape);
}
//===----------------------------------------------------------------------===//
// Helper functions detect device capabilities from RuntimeDevices.
//===----------------------------------------------------------------------===//
@ -462,9 +495,10 @@ LogicalResult FoldOperandsPermutation(
namespace {
// Folder that returns LHS of an Arithmetic Op if the RHS is a constant
// known to be Identity (e.g X+0)
template <typename OpT,
typename std::enable_if<llvm::is_one_of<
OpT, AddV2Op, SubOp, MulOp, DivOp>::value>::type * = nullptr>
template <
typename OpT,
typename std::enable_if<llvm::is_one_of<
OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
ArrayRef<Attribute> operands) {
auto result_op_type = arithmetic_op.getResult().getType();
@ -479,7 +513,8 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
// Mul and Div ops have identity value one while AddV2 and SubOp have identity
// value zero.
int identity =
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value);
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value ||
std::is_same<OpT, RealDivOp>::value);
Type element_ty = lhs_type.getElementType();
Attribute identity_attr;
@ -496,6 +531,12 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
return arithmetic_op.x();
}
auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
// TODO(chhe): we could fold and add an identity to force the broadcast.
if (result_op_type != rhs_type) {
return {};
}
bool is_symmetric =
(std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
if (auto attr = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
@ -1256,8 +1297,8 @@ static LogicalResult Verify(DataFormatVecPermuteOp op) {
if (rank == 1) {
int64_t dim0 = input_ty.getDimSize(0);
if (dim0 != ShapedType::kDynamicSize && dim0 != 4)
return op.emitOpError("requires 1D input of size 4");
if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2)
return op.emitOpError("requires 1D input of size 4 or size 2");
}
if (rank == 2) {
@ -1620,10 +1661,16 @@ void FillOp::build(OpBuilder &builder, OperationState &result, Value dims,
OpFoldResult FillOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "fill op has two operand");
auto type = getType().cast<ShapedType>();
// DenseElementsAttr that is used in this folder only supports int and float
// types.
// TODO(hinsu): Handle complex types once there is a attribute kind for
// complex.
if (!type.getElementType().isIntOrFloat()) return {};
auto value = operands[1].dyn_cast_or_null<ElementsAttr>();
if (!value) return {};
auto type = getType().cast<ShapedType>();
if (type.hasStaticShape())
return DenseElementsAttr::get(type, value.getValue({}));
@ -1774,75 +1821,125 @@ static LogicalResult Verify(GatherV2Op op) {
static LogicalResult Verify(IfOp op) {
auto module = op.getParentOfType<ModuleOp>();
auto thenFn = module.lookupSymbol<FuncOp>(op.then_branch());
if (!thenFn)
auto then_fn = module.lookupSymbol<FuncOp>(op.then_branch());
if (!then_fn)
return op.emitOpError("then_branch refers to an undefined function : ")
<< op.then_branch();
auto elseFn = module.lookupSymbol<FuncOp>(op.else_branch());
if (!elseFn)
auto else_fn = module.lookupSymbol<FuncOp>(op.else_branch());
if (!else_fn)
return op.emitOpError("else_branch refers to an undefined function : ")
<< op.else_branch();
auto thenFuncType = thenFn.getType();
auto elseFuncType = elseFn.getType();
auto then_fn_type = then_fn.getType();
auto else_fn_type = else_fn.getType();
// Non-conditional operands starting with the second operand are passed to
// branches and should be pair-wise compatible with branches' inputs.
unsigned expectedNumInputs = op.getNumOperands() - 1;
if (thenFuncType.getNumInputs() != expectedNumInputs ||
elseFuncType.getNumInputs() != expectedNumInputs)
return op.emitError("branches should have " + Twine(expectedNumInputs) +
unsigned expected_num_inputs = op.getNumOperands() - 1;
if (then_fn_type.getNumInputs() != expected_num_inputs ||
else_fn_type.getNumInputs() != expected_num_inputs)
return op.emitError("branches should have " + Twine(expected_num_inputs) +
" inputs");
for (unsigned i = 0; i < expectedNumInputs; ++i) {
auto operandType = op.getOperand(i + 1).getType().cast<TensorType>();
auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
if (!AreCastCompatible({operandType, thenInputType}))
for (unsigned i = 0; i < expected_num_inputs; ++i) {
auto operand_type = op.getOperand(i + 1).getType().cast<TensorType>();
auto then_input_type = then_fn_type.getInput(i).cast<TensorType>();
if (!AreCastCompatible({operand_type, then_input_type}))
return op.emitError(
llvm::formatv("then branch input type {0} is incompatible with "
"operand type {1} at index {2}",
thenInputType, operandType, i));
then_input_type, operand_type, i));
auto elseInputType = elseFuncType.getInput(i).cast<TensorType>();
if (!AreCastCompatible({operandType, elseInputType}))
auto else_input_type = else_fn_type.getInput(i).cast<TensorType>();
if (!AreCastCompatible({operand_type, else_input_type}))
return op.emitError(
llvm::formatv("else branch input type {0} is incompatible with "
"operand type {1} at index {2}",
elseInputType, operandType, i));
else_input_type, operand_type, i));
// If branches have incompatible input types that means that no tensor can
// serve as input to both the functions. Hence, the op is invalid.
if (!AreCastCompatible({thenInputType, elseInputType}))
if (!AreCastCompatible({then_input_type, else_input_type}))
return op.emitError(llvm::formatv(
"branches inputs have incompatible types {0} and {1} at index {2}",
thenInputType, elseInputType, i));
then_input_type, else_input_type, i));
}
// Branches' results should be pair-wise compatible with the op results.
unsigned expectedNumResults = op.getNumResults();
if (thenFuncType.getNumResults() != expectedNumResults ||
elseFuncType.getNumResults() != expectedNumResults)
return op.emitError("branches should have " + Twine(expectedNumResults) +
unsigned expected_num_results = op.getNumResults();
if (then_fn_type.getNumResults() != expected_num_results ||
else_fn_type.getNumResults() != expected_num_results)
return op.emitError("branches should have " + Twine(expected_num_results) +
" results");
for (unsigned i = 0; i < expectedNumResults; ++i) {
auto resultType = op.getResult(i).getType().cast<TensorType>();
auto thenResultType = thenFuncType.getResult(i).cast<TensorType>();
if (!AreCastCompatible({thenResultType, resultType}))
for (unsigned i = 0; i < expected_num_results; ++i) {
auto result_type = op.getResult(i).getType().cast<TensorType>();
auto then_result_type = then_fn_type.getResult(i).cast<TensorType>();
if (!AreCastCompatible({then_result_type, result_type}))
return op.emitError(
llvm::formatv("then branch result type {0} is incompatible with op "
"result type {1} at index {2}",
thenResultType, resultType, i));
then_result_type, result_type, i));
auto elseResultType = elseFuncType.getResult(i).cast<TensorType>();
if (!AreCastCompatible({elseResultType, resultType}))
auto else_result_type = else_fn_type.getResult(i).cast<TensorType>();
if (!AreCastCompatible({else_result_type, result_type}))
return op.emitError(
llvm::formatv("else branch result type {0} is incompatible with op "
"result type {1} at index {2}",
elseResultType, resultType, i));
else_result_type, result_type, i));
}
return success();
}
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
// Verifies that a YieldOp terminator only appears inside ops that accept it.
static LogicalResult Verify(YieldOp op) {
  auto parent = op.getParentOp();
  // A YieldOp should be contained within an IfRegion op
  // (and WhileRegion in future).
  //
  // Bug fix: the original emitted the diagnostic but discarded it and fell
  // through to `return success()`, so verification never actually failed for
  // a YieldOp with an unexpected parent. Return the diagnostic instead.
  if (!isa<IfRegionOp>(parent))
    return op.emitError() << " expects parent op "
                          << "'" << IfRegionOp::getOperationName()
                          << "' but got '"
                          << parent->getName().getStringRef() << "'";
  return success();
}
//===----------------------------------------------------------------------===//
// IfRegionOp
//===----------------------------------------------------------------------===//
// Checks that `region`'s terminating tf.Yield produces exactly one value per
// result of `op`, with pairwise cast-compatible types. `region_name` is used
// only for diagnostics (e.g. "then"/"else").
LogicalResult VerifyRegionResults(Operation *op, Region &region,
                                  StringRef region_name) {
  auto op_name = op->getName().getStringRef();
  // The region's results are the operands of its tf.Yield terminator.
  YieldOp yield = cast<YieldOp>(region.front().getTerminator());
  unsigned num_results = op->getNumResults();
  if (yield.getNumOperands() != num_results)
    return op->emitError(region_name + " region should have " +
                         Twine(num_results) + " results");
  for (unsigned idx = 0; idx < num_results; ++idx) {
    auto op_result_ty = op->getResult(idx).getType().cast<TensorType>();
    auto yielded_ty = yield.getOperand(idx).getType().cast<TensorType>();
    if (!AreCastCompatible({yielded_ty, op_result_ty}))
      return op->emitError(llvm::formatv(
          "{0} result type {1} is incompatible with {2} "
          "result type {3} at index {4}",
          region_name, yielded_ty, op_name, op_result_ty, idx));
  }
  return success();
}
// Verifies both branch regions of tf.IfRegion against the op's result types.
static LogicalResult Verify(IfRegionOp op) {
  // Check the then branch first so its diagnostics take precedence, matching
  // the original sequential verification order.
  if (failed(VerifyRegionResults(op, op.then_branch(), "then")) ||
      failed(VerifyRegionResults(op, op.else_branch(), "else")))
    return failure();
  return success();
}
//===----------------------------------------------------------------------===//
// InvertOp
//===----------------------------------------------------------------------===//
@ -2408,6 +2505,10 @@ void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<RealDivWithSqrtDivisor>(context);
}
OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
  // Delegates to the shared identity folder: folds `x / 1` to `x` when the
  // divisor is a constant identity (RealDivOp's identity value is one).
  return IdentityArithmeticOpFolder<RealDivOp>(*this, operands);
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
@ -2539,6 +2640,81 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
return unranked();
}
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
  // Registers the pattern rewriting tf.Select into tf.SelectV2 (reshaping a
  // vector predicate when needed; see ReshapeSelectPredIfNecessary).
  results.insert<SelectToSelectV2>(context);
}
// Verifies a few extra requirements on SelectOp:
// (1) `then` and `else` must have same shape
// (2) At least one of the following must be true:
// (a) `cond` has the same rank as `then` and `else`
// (b) `cond` is a scalar
// (c) `cond` is a vector AND `then` and `else` are non-scalar with their
// first dimension equal to `cond`.
static LogicalResult Verify(SelectOp op) {
  auto then_tensor = op.t().getType().cast<TensorType>();
  auto else_tensor = op.e().getType().cast<TensorType>();
  // Check (1): `then` and `else` must have compatible shapes.
  if (!AreCastCompatible({then_tensor, else_tensor}))
    return op.emitOpError() << "requires t and e have compatible shapes";
  // Get data rank (if exists). Only meaningful when at least one of
  // then/else is ranked; the unranked/unranked case returns early below.
  int data_rank;
  // If data is unranked or data_rank is 0, this will remain -2. Otherwise
  // refers to first dimension of then and/or else. (-2 is a sentinel that is
  // distinct from -1, which encodes a dynamic dimension size.)
  int data_first_dim = -2;
  bool then_has_rank = then_tensor.hasRank();
  bool else_has_rank = else_tensor.hasRank();
  if (then_has_rank && else_has_rank) {
    // Both ranked: take then's rank; use the max of the two first dims so a
    // known (static) dim wins over a dynamic (-1) one.
    data_rank = then_tensor.getRank();
    if (then_tensor.getRank() > 0)
      data_first_dim = then_tensor.getShape().front();
    if (else_tensor.getRank() > 0)
      data_first_dim = std::max(
          static_cast<int>(else_tensor.getShape().front()), data_first_dim);
  } else if (then_has_rank) {
    data_rank = then_tensor.getRank();
    if (then_tensor.getRank() > 0)
      data_first_dim = then_tensor.getShape().front();
  } else if (else_has_rank) {
    data_rank = else_tensor.getRank();
    if (else_tensor.getRank() > 0)
      data_first_dim = else_tensor.getShape().front();
  } else {
    // Neither has a rank; nothing further can be checked statically.
    return success();
  }
  // An unranked predicate cannot be validated against the data rank.
  auto cond_tensor = op.condition().getType().dyn_cast<RankedTensorType>();
  if (!cond_tensor) return success();
  auto cond_rank = cond_tensor.getRank();
  // Check (2a) and (2b): scalar predicate, or predicate rank equals data rank.
  if (cond_rank == 0 || cond_rank == data_rank) return success();
  // Check (2c): vector predicate whose length matches the data's first dim.
  if (cond_rank == 1) {
    auto cond_shape = cond_tensor.getShape().front();
    if (data_rank == 0) {
      return op.emitOpError()
             << "requires that t and e are nonscalar when pred is a vector";
    }
    // We know `data` tensor has a rank of at least 1.
    // Dynamic (-1) dimensions on either side skip the exact-match check.
    if (data_first_dim != -1 && cond_shape != -1 &&
        data_first_dim != cond_shape) {
      return op.emitOpError() << "requires that, when pred is a vector, the "
                                 "shape matches the first dimension of t and e";
    }
    return success();
  }
  // None of (2a,b,c) were true; fail.
  return op.emitOpError() << "requires that pred is a scalar OR has the same "
                             "rank as t and e OR is a vector";
}
//===----------------------------------------------------------------------===//
// SelectV2Op
//===----------------------------------------------------------------------===//
@ -2598,9 +2774,12 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
<< variadic_idx_str << " to match rank of operand"
<< variadic_idx_str;
} else if (result_ranked_type.hasStaticShape()) {
// The operand is an unranked tensor, verify that the result is dynamic.
return op->emitOpError("requires dynamic shape result")
<< variadic_idx_str << " for unranked operand" << variadic_idx_str;
// The operand is an unranked tensor, print a warning if the result
// is static.
// Note: We do not handle this situation as an error, this would be too
// restrictive due to incompleteness of shape inference at this point.
op->emitWarning("has static shape result")
<< variadic_idx_str << " for unranked operand" << variadic_idx_str;
}
Type element_type = result_ranked_type.getElementType();
@ -3551,12 +3730,20 @@ OpFoldResult FoldIdentityTranspose(TransposeOp op) {
if (!const_perm) return {};
auto const_value = const_perm.value();
const auto &elements = const_value.getValues<APInt>();
const auto elements = const_value.getValues<APInt>();
for (auto it : llvm::enumerate(elements)) {
if (it.index() != it.value()) return {};
}
// TODO(jpienaar): Remove if/when we handle this more generally.
if (op.getType() != op.x().getType()) {
// If the types don't match then only fold if all the operands are in the TF
// dialect.
for (auto user : op.getOperation()->getUsers())
if (user->getDialect() != op.getDialect()) return {};
}
return op.x();
}
@ -3700,36 +3887,37 @@ OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
static LogicalResult Verify(WhileOp op) {
auto module = op.getParentOfType<ModuleOp>();
auto condFn = module.lookupSymbol<FuncOp>(op.cond());
auto bodyFn = module.lookupSymbol<FuncOp>(op.body());
if (!condFn) {
auto cond_fn = module.lookupSymbol<FuncOp>(op.cond());
auto body_fn = module.lookupSymbol<FuncOp>(op.body());
if (!cond_fn) {
return op.emitOpError("cond refers to an undefined function : ")
<< op.cond();
}
if (!bodyFn) {
if (!body_fn) {
return op.emitOpError("body refers to an undefined function : ")
<< op.body();
}
auto condFuncType = condFn.getType();
auto bodyFuncType = bodyFn.getType();
auto cond_fn_type = cond_fn.getType();
auto body_fn_type = body_fn.getType();
// Verify that the cond function has exactly one result.
if (condFuncType.getNumResults() != 1)
if (cond_fn_type.getNumResults() != 1)
return op.emitOpError("requires cond function to have exactly one result");
SmallVector<Type, 4> operands(op.getOperandTypes());
// Collect all the type lists for the op so that different pairs of type lists
// can be compared for the compatibility.
int numTypeLists = 5;
std::pair<std::string, ArrayRef<Type>> typeLists[] = {
{"operand", operands},
{"body function result", bodyFuncType.getResults()},
{"result", op.getResultTypes()},
{"cond function input", condFuncType.getInputs()},
{"body function input", bodyFuncType.getInputs()},
};
constexpr int kNumTypeLists = 5;
const std::array<std::pair<std::string, ArrayRef<Type>>, kNumTypeLists>
type_lists = {{
{"operand", operands},
{"body function result", body_fn_type.getResults()},
{"result", op.getResultTypes()},
{"cond function input", cond_fn_type.getInputs()},
{"body function input", body_fn_type.getInputs()},
}};
// A pair of type lists should be cast compatible with each other if one is
// converted to the another for a function call or assignment or there is a
@ -3753,28 +3941,28 @@ static LogicalResult Verify(WhileOp op) {
// never converted from one to the another nor there is a common source
// tensors. Compatibility requirement is not transitive.
for (int i = 0; i < numTypeLists; ++i) {
for (int i = 0; i < kNumTypeLists; ++i) {
// Skip the first pair as the While op operands and body function results
// does not need to be compatible with each other.
for (int j = std::max(2, i + 1); j < numTypeLists; ++j) {
auto &a = typeLists[i];
auto &b = typeLists[j];
for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) {
auto &a = type_lists[i];
auto &b = type_lists[j];
int aSize = a.second.size();
if (aSize != b.second.size())
int a_size = a.second.size();
if (a_size != b.second.size())
return op.emitOpError(
llvm::formatv("requires the number of {0}s to be equal to the "
"number of {1}s. Found {2} and {3}, respectively",
a.first, b.first, aSize, b.second.size()));
a.first, b.first, a_size, b.second.size()));
for (int idx = 0; idx < aSize; ++idx) {
auto aType = a.second[idx];
auto bType = b.second[idx];
for (int idx = 0; idx < a_size; ++idx) {
auto a_type = a.second[idx];
auto b_type = b.second[idx];
if (!AreCastCompatible({aType, bType}))
if (!AreCastCompatible({a_type, b_type}))
return op.emitError(llvm::formatv(
"{0} type {1} is incompatible with {2} type {3} at index {4}",
a.first, aType, b.first, bType, idx));
a.first, a_type, b.first, b_type, idx));
}
}
}
@ -3856,7 +4044,7 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
>();
addInterfaces<TFInlinerInterface>();
addAttributes<ShapeAttr>();
addAttributes<ShapeAttr, FuncAttr>();
// Support unknown operations because not all TensorFlow operations are
// registered.
@ -3911,6 +4099,49 @@ void PrintShapeAttr(ShapeAttr attr, DialectAsmPrinter &os) { // NOLINT
os << ">";
}
// Parses a #tf.func attribute of the following format:
//
// #tf.func<@symbol, {attr = "value"}>
//
// where the first element is a SymbolRefAttr and the second element is a
// DictionaryAttr.
FuncAttr ParseFuncAttr(MLIRContext *context, StringRef spec, Location loc) {
  // Reports a malformed attribute spec and yields a null attribute.
  auto malformed = [&, spec]() {
    emitError(loc, "invalid TensorFlow func attribute: ") << spec;
    return nullptr;
  };

  if (!spec.consume_front("func<")) return malformed();

  // First element: the function's symbol reference.
  size_t symbol_bytes_read = 0;
  Attribute symbol = mlir::parseAttribute(spec, context, symbol_bytes_read);
  if (!symbol || !symbol.isa<SymbolRefAttr>()) return malformed();
  spec = spec.drop_front(symbol_bytes_read);

  if (!spec.consume_front(", ")) return malformed();

  // Second element: the attribute dictionary attached to the function.
  size_t dict_bytes_read = 0;
  Attribute dict = mlir::parseAttribute(spec, context, dict_bytes_read);
  if (!dict || !dict.isa<DictionaryAttr>()) return malformed();
  spec = spec.drop_front(dict_bytes_read);

  if (!spec.consume_front(">")) return malformed();

  return mlir::TF::FuncAttr::get(context, symbol.cast<SymbolRefAttr>(),
                                 dict.cast<DictionaryAttr>());
}
// Prints a #tf.func attribute of the following format:
//
// #tf.func<@symbol, {attr = "value"}>
void PrintFuncAttr(FuncAttr attr, DialectAsmPrinter &os) {
  // Emits the symbol and dictionary in the same `func<@sym, {...}>` form that
  // ParseFuncAttr accepts.
  os << "func<";
  os << attr.GetName() << ", " << attr.GetAttrs();
  os << ">";
}
} // namespace
Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser,
@ -3920,6 +4151,8 @@ Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser,
if (spec.startswith("shape")) return ParseShapeAttr(getContext(), spec, loc);
if (spec.startswith("func")) return ParseFuncAttr(getContext(), spec, loc);
return (emitError(loc, "unknown TensorFlow attribute: " + spec), nullptr);
}
@ -3929,6 +4162,9 @@ void TensorFlowDialect::printAttribute(Attribute attr,
case AttrKind::SHAPE:
PrintShapeAttr(attr.cast<ShapeAttr>(), os);
break;
case AttrKind::FUNC:
PrintFuncAttr(attr.cast<FuncAttr>(), os);
break;
default:
llvm_unreachable("unexpected tensorflow attribute kind");
}

View File

@ -31,7 +31,7 @@ limitations under the License.
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"

View File

@ -207,6 +207,70 @@ else_branch: A function that takes 'inputs' and returns a list of
}];
}
def TF_YieldOp : TF_Op<"Yield", [Terminator]> {
let summary = "Yield operation";
let description = [{
The "yield" operation represents a return operation within the conditional
and body of structured control flow (e.g., if and while). The operation
takes a variable number of operands and produces no results. The number and
types of inputs must match the signature of the operation that contains the
region.
}];
let arguments = (ins Variadic<AnyType>:$operands);
let verifier = [{
return Verify(*this);
}];
}
def TF_IfRegionOp : TF_Op<"IfRegion",
[SingleBlockImplicitTerminator<"YieldOp">]> {
let summary = "output = cond ? then_branch output : else_branch output";
let description = [{
"output = cond ? then_branch output : else_branch output"
cond: A Tensor. If the tensor is a scalar of non-boolean type, the
scalar is converted to a boolean according to the
following rule: if the scalar is a numerical value, non-zero means
True and zero means False; if the scalar is a string, non-empty
means True and empty means False. If the tensor is not a scalar,
being empty means False and being non-empty means True.
input: A list of input tensors.
then_branch: A region that computes the outputs of the op if cond = true.
It returns a list of tensors using tf.yield (as the terminator). The
types of these returned tensors is same as that of the else_branch
else_branch: A region that computes the outputs of the op if cond = false.
It returns a list of tensors using tf.yield (as the terminator). The
types of these returned tensors is same as that of the then_branch
}];
let arguments = (ins
TF_Tensor:$cond,
Variadic<TF_Tensor>:$input,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
// Used to map StatelessIf and If op defined in TensorFlow to a common op.
BoolAttr:$is_stateless
);
let results = (outs
Variadic<TF_Tensor>:$output
);
TF_DerivedOperandTypeAttr Tcond = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
let regions = (region SizedRegion<1>:$then_branch, SizedRegion<1>:$else_branch);
let verifier = [{
return Verify(*this);
}];
}
def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
let summary = "Computes the mean of elements across dimensions of a tensor.";

View File

@ -366,5 +366,27 @@ bool AreCastCompatible(ArrayRef<Type> types) {
return true;
}
ShapedType DropTypeSubTypes(ShapedType ty) {
  // Element types without subtype info pass through unchanged.
  auto with_subtype =
      ty.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
  if (!with_subtype) return ty;

  // Rebuild the tensor type with the subtype-free default element type.
  Type replacement = GetDefaultTypeOf(with_subtype);
  if (!ty.hasRank()) return UnrankedTensorType::get(replacement);
  return RankedTensorType::get(ty.getShape(), replacement);
}
ShapedType DropRefType(ShapedType ty) {
  // Non-ref element types pass through unchanged.
  auto ref_ty = ty.getElementType().dyn_cast<TF::TensorFlowRefType>();
  if (!ref_ty) return ty;

  // Rebuild the tensor type with the corresponding non-ref element type.
  Type non_ref = TF::GetDefaultTypeOf(ref_ty);
  if (!ty.hasRank()) return UnrankedTensorType::get(non_ref);
  return RankedTensorType::get(ty.getShape(), non_ref);
}
} // namespace TF
} // namespace mlir

View File

@ -319,6 +319,16 @@ bool HasCompatibleElementTypes(Type lhs, Type rhs,
// compatible.
bool AreCastCompatible(ArrayRef<Type> types);
// If the given tensor has elements of type with subtypes, then returns a new
// type after dropping subtypes info. Otherwise, returns the original type as
// is.
ShapedType DropTypeSubTypes(ShapedType ty);
// If the given tensor has elements of type ref, then returns a new type
// of the shape, but corresponding non-ref type as element type. Otherwise,
// returns the original type as is.
ShapedType DropRefType(ShapedType ty);
} // end namespace TF
} // end namespace mlir

View File

@ -258,6 +258,59 @@ func @testDoubleReciprocal(%arg0: tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi
// CHECK: return %arg0
}
// CHECK-LABEL: testSelectScalarPred
func @testSelectScalarPred(%arg0: tensor<i1>, %arg1: tensor<4x2xf16>, %arg2: tensor<4x2xf16>) -> tensor<4x2xf16> {
// CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16>
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16>
return %0: tensor<4x2xf16>
}
// CHECK-LABEL: testSelectVectorPred
func @testSelectVectorPred(%arg0: tensor<2xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
// CHECK-NEXT: %[[SHAPE:.*]] = "tf.Const"
// CHECK-NEXT: %[[PRED:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2xi1>, tensor<2xi64>) -> tensor<2x1xi1>
// CHECK-NEXT: "tf.SelectV2"(%[[PRED]], %arg1, %arg2) : (tensor<2x1xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
return %0: tensor<2x3xf16>
}
// CHECK-LABEL: testSelectAllSameShape
func @testSelectAllSameShape(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
// CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
return %0: tensor<2x3xf16>
}
// If we don't have guarantees on input shapes, we can't support canonicalizing
// to SelectV2. Test these cases.
// CHECK-LABEL: testSelectInvalid
func @testSelectInvalid(%arg0: tensor<?xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
// CHECK-NEXT: tf.Select
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<?xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
return %0: tensor<2x3xf16>
}
// CHECK-LABEL: testSelectInvalidUnranked
func @testSelectInvalidUnranked(%arg0: tensor<6x7xi1>, %arg1: tensor<*xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> {
// CHECK-NEXT: tf.Select
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<6x7xi1>, tensor<*xf16>, tensor<*xf16>) -> tensor<*xf16>
return %0: tensor<*xf16>
}
// CHECK-LABEL: testSelectThenUnranked
func @testSelectThenUnranked(%arg0: tensor<3xi1>, %arg1: tensor<*xf16>, %arg2: tensor<3x2xf16>) -> tensor<*xf16> {
// CHECK-NEXT: tf.Select
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<*xf16>, tensor<3x2xf16>) -> tensor<*xf16>
return %0: tensor<*xf16>
}
// CHECK-LABEL: testSelectElseUnranked
func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> {
// CHECK-NEXT: tf.Select
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<*xf16>) -> tensor<*xf16>
return %0: tensor<*xf16>
}
// CHECK-LABEL: testLogicalNotOfEqual
func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
@ -473,12 +526,20 @@ func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor<i32> {
}
// CHECK-LABEL: @foldFill
func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>) {
func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>) {
%0 = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = "tf.Const"() {value = dense<23.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>}
%2 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor<f32>) -> tensor<3x2x1xf32>
// CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>}
%3 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor<f32>) -> tensor<*xf32>
return %2, %3 : tensor<3x2x1xf32>, tensor<*xf32>
%complex_cst = "tf.Const"() {value = dense<(0.000000e+00,1.000000e+00)> : tensor<complex<f32>>} : () -> tensor<complex<f32>>
// Here, custom folder doesn't handle complex dtypes and it is folded through
// the constant folding hook.
// TODO(hinsu): Handle complex dtypes in the custom folder for FillOp.
// CHECK: "tf.Const"() {value = dense<(0.000000e+00,1.000000e+00)> : tensor<3x2x1xcomplex<f32>>} : () -> tensor<*xcomplex<f32>>
%4 = "tf.Fill"(%0, %complex_cst) : (tensor<3xi32>, tensor<complex<f32>>) -> tensor<*xcomplex<f32>>
return %2, %3, %4 : tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>
}

View File

@ -302,15 +302,13 @@ func @testTensorListElementShape(%arg0: tensor<!tf.variant<tensor<2x4xf32>>>) ->
return %0: tensor<2xi32>
}
func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<0.0> : tensor<2x2xf32>
%0 = "tf.Add"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = "tf.Add"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
%0 = "tf.Add"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
// CHECK-LABEL: RemoveTrivialAdd
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
// CHECK-NEXT: return %arg0 : tensor<2x2xf32>
}
func @RemoveTrivialAddBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
@ -331,26 +329,22 @@ func @RemoveTrivialAddBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
// CHECK-NEXT: return %arg0 : tensor<2x2xbf16>
}
func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<0.0> : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = "tf.AddV2"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
// CHECK-LABEL: RemoveTrivialAddV2
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
// CHECK-NEXT: return %arg0 : tensor<2x2xf32>
}
func @RemoveTrivialSub(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @RemoveTrivialSub(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<0.0> : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = "tf.Sub"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
%0 = "tf.Sub"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
// CHECK-LABEL: RemoveTrivialSub
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
// CHECK-NEXT: return %arg0 : tensor<2x2xf32>
}
func @RemoveTrivialSubInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> {
@ -362,26 +356,31 @@ func @RemoveTrivialSubInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> {
// CHECK-NEXT: return %arg0 : tensor<2x2xi8>
}
func @RemoveTrivialMul(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @RemoveTrivialMul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<1.0> : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = "tf.Mul"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
%0 = "tf.Mul"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
// CHECK-LABEL: RemoveTrivialMul
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
// CHECK-NEXT: return %arg0 : tensor<2x2xf32>
}
func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<1.0> : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = "tf.Div"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
%0 = "tf.Div"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
// CHECK-LABEL: RemoveTrivialDiv
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
// CHECK-NEXT: return %arg0 : tensor<2x2xf32>
}
func @RemoveTrivialRealDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<1.0> : tensor<2x2xf32>
%0 = "tf.RealDiv"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
// CHECK-LABEL: RemoveTrivialRealDiv
// CHECK-NEXT: return %arg0 : tensor<2x2xf32>
}
func @RemoveTrivialDivBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
@ -411,28 +410,35 @@ func @DivBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
// CHECK: tf.Div
}
func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<0.0> : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
%1 = "tf.AddV2"(%0, %cst) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %cst) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
// CHECK-LABEL: DontRemoveTrivialAdd
// CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32>
// CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %[[CONST]]) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: return %[[RESULT]] : tensor<2x2xf32>
}
func @DontRemoveTrivialAdd2(%arg0: tensor<?x?xf32>, %arg1: tensor<2x2xf32>) -> tensor<?x?xf32> {
func @DontRemoveTrivialAdd2(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
%cst = constant dense<0.0> : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
%1 = "tf.AddV2"(%0, %cst) : (tensor<?x?xf32> , tensor<2x2xf32>) -> tensor<?x?xf32>
return %1 :tensor<?x?xf32>
%0 = "tf.AddV2"(%arg0, %cst) : (tensor<?x?xf32> , tensor<2x2xf32>) -> tensor<?x?xf32>
return %0 :tensor<?x?xf32>
// CHECK-LABEL: DontRemoveTrivialAdd2
// CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32>
// CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %[[CONST]]) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
// CHECK: return %[[RESULT]] : tensor<?x?xf32>
}
// Test no fold because of the broadcast.
func @DontRemoveTrivialMul(%arg0: tensor<1x6x8x1xf32>) -> tensor<1x6x8x1xf32> {
%0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
%1 = "tf.Mul"(%arg0, %0) : (tensor<1x6x8x1xf32>, tensor<f32>) -> tensor<1x6x8x1xf32>
return %1 : tensor<1x6x8x1xf32>
// CHECK-LABEL: DontRemoveTrivialMul
// CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[RESULT:.*]] = "tf.Mul"(%arg0, %[[CONST]]) : (tensor<1x6x8x1xf32>, tensor<f32>) -> tensor<1x6x8x1xf32>
// CHECK: return %[[RESULT]] : tensor<1x6x8x1xf32>
}

View File

@ -0,0 +1,50 @@
// RUN: tf-opt %s -split-input-file -verify-diagnostics
// Tests invalid #tf.func attributes.
// expected-error@+1 {{invalid TensorFlow func attribute: func}}
func @main() attributes {tf._implements = #tf.func} {
return
}
// -----
// expected-error@+1 {{invalid TensorFlow func attribute: func<>}}
func @main() attributes {tf._implements = #tf.func<>} {
return
}
// -----
// expected-error@+1 {{invalid TensorFlow func attribute: func<@symbol>}}
func @main() attributes {tf._implements = #tf.func<@symbol>} {
return
}
// -----
// expected-error@+1 {{invalid TensorFlow func attribute: func<{}>}}
func @main() attributes {tf._implements = #tf.func<{}>} {
return
}
// -----
// expected-error@+1 {{invalid TensorFlow func attribute: func<"test", {}>}}
func @main() attributes {tf._implements = #tf.func<"test", {}>} {
return
}
// -----
// expected-error@+1 {{invalid TensorFlow func attribute: func<@symbol, "">}}
func @main() attributes {tf._implements = #tf.func<@symbol, "">} {
return
}
// -----
// expected-error@+1 {{invalid TensorFlow func attribute: func<@symbol, {}, "">}}
func @main() attributes {tf._implements = #tf.func<@symbol, {}, "">} {
return
}

View File

@ -0,0 +1,13 @@
// RUN: tf-opt %s | tf-opt | FileCheck %s --dump-input=fail
// CHECK-LABEL: func @func_attr
// CHECK-SAME: tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random"}>
func @func_attr() attributes {tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random"}>} {
return
}
// CHECK-LABEL: func @nested_func_attr
// CHECK-SAME: tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random", nested = #tf.func<@symbol_b, {attr2 = true, attr3 = 8.000000e+00 : f32}>}>
func @nested_func_attr() attributes {tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random", nested = #tf.func<@symbol_b, {attr2 = true, attr3 = 8.0 : f32}>}>} {
return
}

View File

@ -0,0 +1,53 @@
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s --dump-input-on-failure
node {
name: "custom_relu_func_call"
op: "custom_relu"
}
node {
name: "custom_embedding_matmul_func_call"
op: "custom_embedding_matmul"
}
library {
function {
signature {
name: "custom_relu"
}
attr {
key: "_implements"
value {
func {
name: "tensorflow.relu"
}
}
}
}
function {
signature {
name: "custom_embedding_matmul"
}
attr {
key: "_implements"
value {
func {
name: "tensorflow.embedding_matmul"
attr {
key: "key1"
value {
i: 2
}
}
attr {
key: "key2"
value {
b: false
}
}
}
}
}
}
}
# CHECK-DAG: func @custom_relu{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.relu, {}>}
# CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>}

View File

@ -2,17 +2,17 @@
func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
return %0 : tensor<1x32x10x32xi32>
}
func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
return %0 : tensor<1x32x10x32xi32>
}
func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
return %0 : tensor<?x?x?x?xi32>
}
@ -23,12 +23,12 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
}
func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
return %0 : tensor<4x4x4x4xi32>
}
@ -38,7 +38,7 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
}
func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
@ -48,7 +48,7 @@ func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
}
func @div_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
return %0 : tensor<?x?xi32>
}
@ -68,7 +68,7 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
}
func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
%0 = "xla_chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
@ -78,7 +78,7 @@ func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
}
func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
@ -88,7 +88,7 @@ func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
}
func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
%0 = "xla_chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
@ -98,7 +98,7 @@ func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
}
func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> {
%0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
%0 = "xla_chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
return %0 : tensor<2x4xi32>
}
@ -108,12 +108,12 @@ func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> {
}
func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> {
%0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @and_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
%0 = "xla_hlo.and"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}
@ -123,12 +123,12 @@ func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> {
}
func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> {
%0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @or_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
%0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}
@ -138,12 +138,12 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
}
func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> {
%0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
return %0 : tensor<1x4xi8>
}
func @bitwise_or_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> {
%0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -153,12 +153,12 @@ func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
}
func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> {
%0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
return %0 : tensor<1x4xi8>
}
func @bitwise_and_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> {
%0 = "xla_hlo.and"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -174,19 +174,19 @@ func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
%0 = xla_hlo.constant dense<0> : tensor<2x3xi32>
%1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%2 = xla_hlo.constant dense<0> : tensor<3xi32>
%3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
%5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
%5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%6 = "xla_hlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%7 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%8 = xla_hlo.constant dense<1> : tensor<3xi32>
%9 = xla_hlo.subtract %7, %8 : tensor<3xi32>
%10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%13 = "xla_hlo.divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %14 : tensor<2x3xi32>
}
@ -195,14 +195,14 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32
%0 = xla_hlo.constant dense<0> : tensor<3xi32>
%1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%2 = xla_hlo.constant dense<0> : tensor<2x3xi32>
%3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
%5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
%5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%6 = "xla_hlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
%7 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%8 = xla_hlo.constant dense<1> : tensor<2x3xi32>
%9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32>
%10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%13 = xla_hlo.divide %11, %12 : tensor<2x3xi32>
@ -218,8 +218,8 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
}
func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%1 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%2 = "xla_hlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16>
return %2 : tensor<2x3xf16>
}
@ -230,22 +230,22 @@ func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}
func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}
@ -255,17 +255,17 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @notequal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}
@ -275,7 +275,7 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
@ -285,7 +285,7 @@ func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
@ -295,7 +295,7 @@ func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
@ -305,7 +305,7 @@ func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
@ -326,35 +326,35 @@ func @const() -> tensor<2xi32> {
func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
%1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
return %1 : tensor<1xi32>
}
func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
%1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
return %1 : tensor<?xi32>
}
func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = xla_hlo.constant dense<6> : tensor<i32>
%2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
%3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
%2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
%3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
return %3 : tensor<1xi32>
}
func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = xla_hlo.constant dense<6> : tensor<i32>
%2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
%3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
%2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
%3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
return %3 : tensor<?xi32>
}
func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor<?x?xf32>) -> tensor<4x8xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "xla_hlo.compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
%1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32>
%3 = "xla_hlo.select"(%1, %arg0, %2) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
return %3 : tensor<4x8xf32>
@ -682,6 +682,11 @@ func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> {
return %0 : tensor<2xf32>
}
func @convert_slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> {
%0 = "xla_hlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32>
return %0 : tensor<1x519xf32>
}
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK-LABEL: func @biasAdd_NHWC(
@ -1493,3 +1498,12 @@ func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> {
// CHECK: [[VAL_371:%.*]] = "tf.Cast"([[VAL_370]]) {Truncate = false} : (tensor<2xi32>) -> tensor<2xf32>
// CHECK: return [[VAL_371]] : tensor<2xf32>
// CHECK: }
// CHECK-LABEL: func @convert_slice(
// CHECK-SAME: [[VAL_372:%.*]]: tensor<1x4672xf32>) -> tensor<1x519xf32> {
// CHECK: [[VAL_373:%.*]] = "tf.Const"() {value = dense<[0, 4153]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[VAL_374:%.*]] = "tf.Const"() {value = dense<[1, 519]> : tensor<2xi64>} : () -> tensor<2xi64>
// CHECK: [[VAL_375:%.*]] = "tf.Slice"([[VAL_372]], [[VAL_373]], [[VAL_374]]) : (tensor<1x4672xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x519xf32>
// CHECK: return [[VAL_375]] : tensor<1x519xf32>
// CHECK: }

View File

@ -0,0 +1,85 @@
// RUN: tf-opt -verify-diagnostics -readonly-references-to-resources -split-input-file %s | FileCheck %s --dump-input=fail
// Test case: Basic converting.
func @f() {
// CHECK: "tf.VarHandleOp"
// CHECK: "tf.ReadVariableOp"
%val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
return
}
// -----
// Test case: Two ReadVariable ops.
func @f() {
// CHECK: "tf.VarHandleOp"
// During lowering to resource variables, this pass will preserve the
// locations of the ReadVariableOps as Identity ops to keep the original graph
// composition and order.
// CHECK: "tf.ReadVariableOp"
// CHECK: "tf.ReadVariableOp"
%val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
%val2 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
return
}
// -----
// Test case: No follow-up ReadVariable case.
func @f() {
// CHECK-NOT: "tf.VariableV2"
// CHECK-NOT: "tf.VarHandleOp"
%val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
return
}
// -----
// Test case: No converting when there is another use case.
func @f() {
// expected-error @+1 {{'tf.VariableV2' op expects all users to be 'tf.Identity', but got user tf.CustomIdentity}}
%val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
%val1 = "tf.CustomIdentity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
return
}
// -----
// Test case: No class attribute on VariableV2 op.
func @f() {
// expected-error @+1 {{'tf.VariableV2' op has no '_class' attribute}}
%val0 = "tf.VariableV2"() {container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
return
}
// -----
// Test case: No named location found on VariableV2 op.
func @f() {
// expected-error @+1 {{'tf.VariableV2' op expects variable name in '_class' attribute, but got ["unrelated_class"]}}
%val0 = "tf.VariableV2"() {_class = ["unrelated_class"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
return
}
// -----
// Test case: Invalid multiple location information in a class attribute on VariableV2 op.
func @f() {
// expected-error @+1 {{'tf.VariableV2' op expects only one named location in '_class' attribute, but got ["loc:@v1", "loc:@v2"]}}
%val0 = "tf.VariableV2"() {_class = ["loc:@v1", "loc:@v2"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
return
}

View File

@ -1,10 +1,11 @@
// RUN: tf-opt %s -tf-shape-inference -verify-diagnostics | FileCheck %s -dump-input=fail
// RUN: tf-opt %s -tf-shape-inference=propagate-caller-callee-constants=false -verify-diagnostics | FileCheck %s -dump-input=fail
// RUN: tf-opt %s -tf-shape-inference=propagate-caller-callee-constants -verify-diagnostics | FileCheck %s -dump-input=fail
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} {
// CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32>
func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<*xi32> {
// CHECK-NOT: tf.Cast
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: %[[RESULT:.*]] = "tf.AddV2"
// CHECK-SAME: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: return %[[RESULT]] : tensor<1xi32>
%0 = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<*xi32>
%1 = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<*xi32>
@ -60,8 +61,8 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
// CHECK-LABEL: func @simple_folding
func @simple_folding(%arg0: tensor<1x1x1x1xi32>, %arg1: tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32> {
// CHECK: %[[CST:.*]] = "tf.Const"{{.*}} {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[CST]]
// CHECK: %[[SHAPE:.*]] = "tf.Shape"
// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[SHAPE]]
// CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
// CHECK: return %[[CONV]] : tensor<1x1x1x1xf32>
%0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32>
@ -300,13 +301,6 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
return %0 : tensor<*xi32>
}
// CHECK-LABEL: func @fold_cast
func @fold_cast(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NOT: Cast
%0 = "tf.Cast"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>)
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @while_variant
// CHECK-SAME: -> tensor<!tf.variant<tensor<16x1xf32>>>
func @while_variant(%arg0: tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<!tf.variant> {
@ -362,8 +356,6 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
// CHECK-LABEL: func @partitioned_call_func_const
func @partitioned_call_func_const(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: return %[[CONST]]
return %arg0 : tensor<2xi32>
}
@ -410,4 +402,18 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
%40 = "tf.Reshape"(%39, %19) {T = f32, Tshape = i32, device = ""} : (tensor<1x4x4x32xf32>, tensor<2xi32>) -> tensor<?x?xf32>
return
}
// CHECK-LABEL: const_fold
func @const_fold() -> () {
// CHECK: tf.Const
// CHECK-SAME: () -> tensor<4xi32>
%0 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32>
// CHECK: tf.Const
// CHECK-SAME: () -> tensor<4xi32>
%1 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32>
// CHECK: tf.Add
// CHECK-SAME: (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%2 = "tf.Add"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
return
}
}

View File

@ -854,6 +854,215 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
// -----
// Test invalid tf.Yield operation (parent should be IfRegion)
func @testInvalidYieldOp(%arg0: f32) -> () {
// expected-error @+1 {{expects parent op 'tf.IfRegion'}}
"tf.Yield"(%arg0) : (f32) -> ()
}
// -----
// Test valid tf.IfRegion operation
// CHECK-LABEL: func @testValidIfRegionOp
func @testValidIfRegionOp(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}, {
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
// Test valid tf.IfRegion operation with multiple results
// CHECK-LABEL: func @testValidIfRegionOpWithMultipleResults
func @testValidIfRegionOpWithMultipleResults(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0, %1, %2 = "tf.IfRegion"(%arg0, %arg1) ({
%t0 = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
%t1 = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
%t2 = "tf.Acosh"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t0, %t1, %t2) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> ()
}, {
%e0 = "tf.Neg"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
%e1 = "tf.Relu"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
%e2 = "tf.Sin"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%e0, %e1, %e2) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
%3 = "tf.Add"(%0, %1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
%4 = "tf.Add"(%2, %3) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %4 : tensor<2xf32>
}
// -----
// Test invalid type for operand #0 for tf.IfRegion operation
func @testInvalidIfRegionOpType0(%arg0: f32, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{operand #0 must be tensor of tf.dtype values}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}, {
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
}) { is_stateless = false} : (f32, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
// Test invalid type for operand #1 for tf.IfRegion operation
func @testInvalidIfRegionOpType1(%arg0: tensor<i1>, %arg1: f32) -> f32 {
// expected-error @+1 {{operand #1 must be tensor of tf.dtype values}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%t = addf %arg1, %arg1 : f32
"tf.Yield"(%t) : (f32) -> ()
}, {
%e = mulf %arg1, %arg1 : f32
"tf.Yield"(%e) : (f32) -> ()
}) { is_stateless = false} : (tensor<i1>, f32) -> f32
return %0 : f32
}
// -----
// tf.IfRegion operation should have 2 regions
func @testInvalidIfRegionOp1Region(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{op expected 2 regions}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
func @testInvalidIfRegionOpNoRegions(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{op expected 2 regions}}
%0 = "tf.IfRegion"(%arg0, %arg1) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
func @testInvalidIfRegionOp3Regions(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{op expected 2 regions}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}, {
%te = "tf.Relu"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%te) : (tensor<2xf32>) -> ()
}, {
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
// tf.IfRegion regions should be terminated with a tf.Yield
func @testIfRegionThenTerminator(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+2 {{'tf.IfRegion' op expects regions to end with 'tf.Yield'}}
// expected-note @+1 {{in custom textual format, the absence of terminator implies 'tf.Yield'}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
}, {
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
func @testIfRegionElseTerminator(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+2 {{'tf.IfRegion' op expects regions to end with 'tf.Yield'}}
// expected-note @+1 {{in custom textual format, the absence of terminator implies 'tf.Yield'}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}, {
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
// tf.Region yield number of results should match op number of results
func @testIfRegionThenResultCount(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{then region should have 1 result}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t, %t) : (tensor<2xf32>, tensor<2xf32>) -> ()
}, {
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
func @testIfRegionElseResultCount(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{else region should have 1 result}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}, {
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%e, %e) : (tensor<2xf32>, tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
// tf.IfRegion yield types should match op result types
func @testIfRegionOpYieldMismatchThen(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{then result type tensor<i1> is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
"tf.Yield"(%arg0) : (tensor<i1>) -> ()
}, {
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
func @testIfRegionOpYieldMismatchElse(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// expected-error @+1 {{else result type tensor<i1> is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}}
%0 = "tf.IfRegion"(%arg0, %arg1) ({
%t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
}, {
"tf.Yield"(%arg0) : (tensor<i1>) -> ()
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
// Test valid tf.MatrixBandPart
// CHECK-LABEL: func @testValidMatrixBandPartOp
func @testValidMatrixBandPartOp(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
@ -1007,6 +1216,116 @@ func @pcall_func_2(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
// -----
//===--------------------------------------------------------------------===//
// tf.Select
//===--------------------------------------------------------------------===//
// Test valid tf.Select
// CHECK-LABEL: func @testSelect
func @testSelect(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<3x2xf16>) -> tensor<3x2xf16> {
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<3x2xf16>) -> tensor<3x2xf16>
return %0: tensor<3x2xf16>
}
// -----
func @testInvalidSelect(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
// expected-error @+1 {{requires that, when pred is a vector, the shape matches the first dimension of t and e}}
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
return %0: tensor<2x3xf16>
}
// -----
// Test invalid tf.Select - broadcasting then/else parameters is not supported
func @selectBroadcastThen(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
// expected-error @+1 {{requires t and e have compatible shapes}}
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
return %0: tensor<2x8x8xi32>
}
// -----
func @invalidSelect(%arg0: tensor<2xi1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<2xi32> {
// expected-error @+1 {{requires that t and e are nonscalar when pred is a vector}}
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<i32>, tensor<i32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}
// -----
func @invalidSelect(%arg0: tensor<1x8xi1>, %arg1: tensor<1x8x8xi32>, %arg2: tensor<1x8x8xi32>) -> tensor<1x8x8xi32> {
// expected-error @+1 {{requires that pred is a scalar OR has the same rank as t and e OR is a vector}}
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<1x8xi1>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<1x8x8xi32>
return %0: tensor<1x8x8xi32>
}
// -----
//===--------------------------------------------------------------------===//
// tf.SelectV2
//===--------------------------------------------------------------------===//
// Test valid tf.SelectV2
// CHECK-LABEL: func @selectV2BroadcastThen
func @selectV2BroadcastThen(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
return %0: tensor<2x8x8xi32>
}
// -----
// Test valid tf.SelectV2
// CHECK-LABEL: func @selectV2BroadcastElse
func @selectV2BroadcastElse(%arg0: tensor<i1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32>
return %0: tensor<2x8x8xi32>
}
// -----
// Test valid tf.SelectV2
// CHECK-LABEL: func @selectV2BroadcastPred
func @selectV2BroadcastPred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
return %0: tensor<2x8x8xi32>
}
// -----
// CHECK-LABEL: func @selectV2BroadcastAll
func @selectV2BroadcastAll(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
return %0: tensor<8x8x8xi32>
}
// -----
// CHECK-LABEL: func @selectV2DynamicRanked
func @selectV2DynamicRanked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32>
return %0: tensor<2x?x8xi32>
}
// -----
// CHECK-LABEL: func @selectV2Unranked
func @selectV2Unranked(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<*xi32>) -> tensor<*xi32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<*xi32>) -> tensor<*xi32>
return %0: tensor<*xi32>
}
// -----
// Test invalid tf.SelectV2: this is an invalid broadcast for the predicate
func @testInvalidSelectV2(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<3x2xf16>) -> tensor<3x2xf16> {
// expected-error @+1 {{operands don't have broadcast-compatible shapes}}
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<3x2xf16>) -> tensor<3x2xf16>
return %0: tensor<3x2xf16>
}
// -----
//===--------------------------------------------------------------------===//
// tf.Softmax
//===--------------------------------------------------------------------===//
@ -1326,7 +1645,7 @@ func @testShapeMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> {
func @testShapeWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> {
^bb0(%arg0: tensor<*xf32>):
// expected-error @+1 {{requires dynamic shape result for unranked operand}}
// expected-warning @+1 {{has static shape result for unranked operand}}
%0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}
@ -1370,7 +1689,7 @@ func @testShapeNMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> {
func @testShapeNWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> {
^bb0(%arg0: tensor<*xf32>):
// expected-error @+1 {{requires dynamic shape result #1 for unranked operand #1}}
// expected-warning @+1 {{has static shape result #1 for unranked operand #1}}
%0:2 = "tf.ShapeN"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<?xi32>, tensor<2xi32>)
return %0#1 : tensor<2xi32>
}
@ -1428,7 +1747,7 @@ func @testVariableShapeMismatchDim(%arg0: tensor<*x!tf.resource<tensor<1x32x32x1
// -----
func @testVariableShapeWrongResultDimDynamic(%arg0: tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<2xi32> {
// expected-error @+1 {{requires dynamic shape result for unranked operand}}
// expected-warning @+1 {{has static shape result for unranked operand}}
%0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}

View File

@ -104,3 +104,9 @@ module attributes {tf_saved_model.semantics} {
return
}
}
// -----
// Test running the pass on a module that does not have
// tf_saved_model.semantics.
module {}

View File

@ -136,3 +136,9 @@ module attributes {tf_saved_model.semantics} {
}
}
// -----
// Test running the pass on a module that does not have
// tf_saved_model.semantics.
module {}

View File

@ -2,80 +2,183 @@
// Tests extraction of a outside compiled ops at head of TPU computation.
func @single_head_outside_compilation(%arg0 : tensor<i32>) -> () {
// CHECK: tf_device.launch
// CHECK: "tf.A"
// CHECK-NEXT: tf_device.return
//
// CHECK: "tf_device.cluster"
// CHECK: "tf.C"
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
"tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
"tf.B"() : () -> ()
"tf.C"() : () -> ()
tf_device.return
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
return
}
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// CHECK-LABEL: func @single_head_outside_compilation
func @single_head_outside_compilation(%arg0 : tensor<i32>) -> () {
// CHECK: tf_device.launch
// CHECK: "tf.A"
// CHECK-NEXT: tf_device.return
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
// CHECK: "tf_device.cluster"
// CHECK: "tf.C"
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
"tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
"tf.B"() : () -> ()
"tf.C"() : () -> ()
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
return
}
// CHECK-LABEL: func @multiple_head_outside_compilation
func @multiple_head_outside_compilation(%arg0 : tensor<i32>) -> () {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK: %[[A_OUT:.*]] = "tf.A"
// CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
// CHECK: "tf.C"
// CHECK-NEXT: tf_device.return %[[B_OUT]]
//
// CHECK: "tf_device.cluster"
// CHECK: "tf.D"(%[[LAUNCH_OUT]])
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
%0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
%1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
"tf.C"(%1, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> ()
"tf.D"(%1) : (tensor<i32>) -> ()
tf_device.return
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
return
}
// CHECK-LABEL: func @ops_no_operands
func @ops_no_operands() -> () {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK: %[[A_OUT:.*]] = "tf.A"
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.B"(%[[LAUNCH_OUT]])
// CHECK-NEXT: "tf.C"
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
%0 = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<i32>)
%1 = "tf.B"(%0) {}: (tensor<i32>) -> (tensor<i32>)
"tf.C"(%1) : (tensor<i32>) -> ()
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
return
}
// CHECK-LABEL: func @test_do_not_outside_compiled_ops_in_middle
func @test_do_not_outside_compiled_ops_in_middle(%arg0 : tensor<i32>) -> () {
// CHECK-NOT: tf_device.launch
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.A"
// CHECK-NEXT: "tf.B"
// CHECK-NEXT: "tf.C"
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
%0 = "tf.A"(%arg0) {} : (tensor<i32>) -> (tensor<i32>)
%1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
"tf.C"(%1) : (tensor<i32>) -> ()
tf_device.return
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
return
}
// CHECK-LABEL: func @aliased_output
func @aliased_output() -> (tensor<i32>, tensor<i32>, tensor<i32>) {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK: %[[A_OUT:.*]] = "tf.A"
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
// CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster"
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[LAUNCH_OUT]])
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"
// CHECK-NEXT: tf_device.return %[[C_OUT]], %[[B_OUT]]
// CHECK-NEXT: {
// CHECK-DAG: num_cores_per_replica = 1
// CHECK-DAG: step_marker_location = ""
// CHECK-DAG: padding_map = []
// CHECK-DAG: topology = ""
// CHECK-DAG: device_assignment = []
//
// CHECK: return %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#0, %[[CLUSTER_OUT]]#1
%0:3 = "tf_device.cluster"() ( {
%1 = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<i32>)
%2 = "tf.B"(%1) {}: (tensor<i32>) -> (tensor<i32>)
%3 = "tf.C"(%2) : (tensor<i32>) -> (tensor<i32>)
tf_device.return %1, %3, %2 : tensor<i32>, tensor<i32>, tensor<i32>
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>, tensor<i32>, tensor<i32>)
return %0#0, %0#1, %0#2 : tensor<i32>, tensor<i32>, tensor<i32>
}
// CHECK-LABEL: func @test_ops_with_tpu_operands_not_extracted
func @test_ops_with_tpu_operands_not_extracted(%arg0 : tensor<i32>) -> () {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK: %[[A_OUT:.*]] = "tf.A"
// CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
// CHECK-NEXT: tf_device.return %[[D_OUT]]
//
// CHECK: "tf_device.cluster"
// CHECK: "tf.B"
// CHECK: "tf.C"
// CHECK: "tf.E"
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
%0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
%1 = "tf.B"() {} : () -> (tensor<i32>)
%2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
%3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
%4 = "tf.E"(%3) {} : (tensor<i32>) -> (tensor<i32>)
tf_device.return
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
return
// CHECK-LABEL: func @all_head_computation_ops
func @all_head_computation_ops(%arg0 : tensor<i32>) -> (tensor<i32>) {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK: %[[A_OUT:.*]] = "tf.A"
// CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
// CHECK: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]], %arg0)
// CHECK-NEXT: tf_device.return %[[C_OUT]]
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
// CHECK: "tf_device.cluster"
// CHECK-NEXT: tf_device.return
//
// CHECK: return %[[LAUNCH_OUT]]
%0 = "tf_device.cluster"() ( {
%1 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
%2 = "tf.B"(%1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
%3 = "tf.C"(%2, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
tf_device.return %3 : tensor<i32>
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>)
return %0 : tensor<i32>
}
// CHECK-LABEL: func @multiple_head_outside_compilation
func @multiple_head_outside_compilation(%arg0 : tensor<i32>) -> () {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK: %[[A_OUT:.*]] = "tf.A"
// CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
// CHECK: "tf.C"
// CHECK-NEXT: tf_device.return %[[B_OUT]]
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
// CHECK: "tf_device.cluster"
// CHECK: "tf.D"(%[[LAUNCH_OUT]])
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
%0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
%1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
"tf.C"(%1, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> ()
"tf.D"(%1) : (tensor<i32>) -> ()
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
return
}
// CHECK-LABEL: func @test_do_not_outside_compiled_ops_in_middle
func @test_do_not_outside_compiled_ops_in_middle(%arg0 : tensor<i32>) -> () {
// CHECK-NOT: tf_device.launch
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.A"
// CHECK-NEXT: "tf.B"
// CHECK-NEXT: "tf.C"
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
%0 = "tf.A"(%arg0) {} : (tensor<i32>) -> (tensor<i32>)
%1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
"tf.C"(%1) : (tensor<i32>) -> ()
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
return
}
// CHECK-LABEL: func @test_ops_with_tpu_operands_not_extracted
func @test_ops_with_tpu_operands_not_extracted(%arg0 : tensor<i32>) -> () {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK: %[[A_OUT:.*]] = "tf.A"
// CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
// CHECK-NEXT: tf_device.return %[[D_OUT]]
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
// CHECK: "tf_device.cluster"
// CHECK: "tf.B"
// CHECK: "tf.C"
// CHECK: "tf.E"
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
%0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
%1 = "tf.B"() {} : () -> (tensor<i32>)
%2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
%3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
%4 = "tf.E"(%3) {} : (tensor<i32>) -> (tensor<i32>)
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
return
}
// CHECK-LABEL: func @test_replicated_head_outside_compilation
func @test_replicated_head_outside_compilation(%arg0 : tensor<i32>) -> () {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK: %[[A_OUT:.*]] = "tf.A"
// CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
// CHECK-NEXT: tf_device.return %[[D_OUT]]
// CHECK: device = "TPU_REPLICATED_HOST"
//
// CHECK: "tf_device.cluster"
// CHECK: "tf.B"
// CHECK: "tf.C"
// CHECK: "tf.E"
// CHECK-NEXT: tf_device.return
tf_device.replicate() {n = 2 : i32} {
"tf_device.cluster"() ( {
%0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
%1 = "tf.B"() {} : () -> (tensor<i32>)
%2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
%3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
%4 = "tf.E"(%3) {} : (tensor<i32>) -> (tensor<i32>)
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
tf_device.return
}
return
}
}

View File

@ -141,4 +141,133 @@ func @multiple_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> te
return %1 : tensor<?xf32>
}
// Tests extraction of a single outside compiled cluster with single device->host input.
// CHECK-LABEL: func @single_outside_compiled_input_single_outside_compilation
func @single_outside_compiled_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.B"(%[[RECV_OUTPUT]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
"tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
%4 = "tf.C"() : () -> tensor<?xi32>
tf_device.return %4 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with arg input and single device->host input.
// CHECK-LABEL: func @mixed_input_single_outside_compilation
func @mixed_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
"tf.B"(%arg0, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
%4 = "tf.C"() : () -> tensor<?xi32>
tf_device.return %4 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests extraction of a multiple outside compiled clusters with single device->host input.
// CHECK-LABEL: func @single_outside_compiled_input_multiple_outside_compilation
func @single_outside_compiled_input_multiple_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT_2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_2:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_2]])
// CHECK-SAME: key = "host_compute_channel_cluster2"
// CHECK: "tf.D"(%[[RECV_OUTPUT_2]])
// CHECK: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT_1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_1:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_1]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.B"(%[[RECV_OUTPUT_1]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK: "tf._HostComputeMlir"(%[[C_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster2"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
"tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
%4 = "tf.C"() : () -> tensor<?xi32>
"tf.D"(%4) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> ()
tf_device.return %4 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with multiple device->host inputs.
// CHECK-LABEL: func @multiple_outside_compiled_inputs_single_outside_compilation
func @multiple_outside_compiled_inputs_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.C"(%[[RECV_OUTPUT]]#0)
// CHECK: "tf.D"(%[[RECV_OUTPUT]]#1, %[[RECV_OUTPUT]]#0)
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"() : () -> (tensor<?xi32>)
"tf.C"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
"tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
%5 = "tf.E"() : () -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// TODO(b/154363171): Add test cases for when output of outside compilation is returned by parallel_execute.

View File

@ -747,7 +747,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// -----
// Tests simple case of `tf_device.cluster_func` on TPU with replication.
// Tests simple case of `tf_device.cluster_func` on TPU with replication. Under
// data parallelism replicated host devices are also added to the
// tf_device.replicate
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
// CHECK-LABEL: func @replicated_tpu_cluster_func
@ -758,7 +760,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK-SAME: ([%[[A_OUTPUT]], %[[ARG_0]]] as %[[RI_0:[a-z0-9]*]]: tensor<?xi32>)
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]}
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}
// CHECK-SAME: n = 2
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[RI_0]])
@ -1222,7 +1224,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// -----
// Tests simple case of `tf_device.cluster_func` on TPU with replication and parallel_execute.
// Tests simple case of `tf_device.cluster_func` on TPU with replication and
// parallel_execute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
// CHECK-LABEL: func @replicated_parallel_tpu_cluster_func
@ -1231,17 +1234,26 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK: "tf._TPUCompileMlir"
// CHECK: "tf.TPUCompileSucceededAssert"
// CHECK: "tf_device.parallel_execute"
// CHECK-NOT:"tf._TPUCompileMlir"
// CHECK: "tf.D"(%[[COMPILE_OUTPUT]]#1
// CHECK: "tf.TPUExecute"
// CHECK-NOT:"tf._TPUCompileMlir"
// CHECK: "tf.E"(%[[COMPILE_OUTPUT]]#1
%3 = "tf_device.parallel_execute"() ( {
"tf.D"() : () -> ()
%status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
"tf.D"(%program) : (tensor<!tf.string>) -> ()
tf_device.return
}, {
%4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %4 : tensor<?xi32>
}, {
%status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
"tf.E"(%program) : (tensor<!tf.string>) -> ()
tf_device.return
}) : () -> (tensor<?xi32>)
tf_device.return %3 : tensor<?xi32>
}
@ -1317,15 +1329,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
// "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01"
// -----
// Tests devices are set properly for replicated model parallelism.
// Tests devices are set properly for replicated model parallelism. No
// replicated host device should be present.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @replicated_parallel_execute
func @replicated_parallel_execute(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> (tensor<8xi32>, tensor<8xi32>) {
// CHECK: tf_device.replicate
// CHECK-SAME: devices =
// CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
// CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"], TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]}
%0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} {
// CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
@ -1357,8 +1368,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
// -----
// Tests that inputs are inputs with maximal and replicate sharding are set properly
// for replicated model parallelism.
// Tests that inputs are inputs with maximal and replicate sharding are set
// properly for replicated model parallelism.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @parallel_execute_with_input_with_sharding_configurations
@ -1392,8 +1403,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
// -----
// Tests devices are set properly for replicated model parallelism with
// outputs to TPU computation placed on logical device 0.
// Tests devices are set properly for replicated model parallelism with outputs
// to TPU computation placed on logical device 0.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @parallel_execute_with_different_outputs
@ -1469,8 +1480,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
// -----
// Tests inputs are correctly split and fed into TPU computation for
// tiled input sharding.
// Tests inputs are correctly split and fed into TPU computation for tiled input
// sharding.
// The following OpSharding is used for TPU computation inputs in below test:
// Proto debug string:

View File

@ -152,6 +152,23 @@ def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp $arg0, (TF_SqrtOp $arg1)),
def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)),
(replaceWithValue $arg)>;
//===----------------------------------------------------------------------===//
// Select op patterns.
//===----------------------------------------------------------------------===//
def ReshapeSelectPredIfNecessary : NativeCodeCall<
"ReshapeSelectPredIfNecessary(&($_builder), $0.getOwner()->getLoc(), $1, "
"$2.getType().cast<RankedTensorType>().getRank())">;
// Select supports tensor `condition` where the shape is equal to the first
// dimension of t and e. SelectV2 op supports normal broadcasting, so in these
// cases the condition needs to be reshaped.
def SelectToSelectV2 : Pat<
(TF_SelectOp:$op StaticShapeTensorOf<[AnyType]>:$cond,
StaticShapeTensorOf<[AnyType]>:$t,
StaticShapeTensorOf<[AnyType]>:$e),
(TF_SelectV2Op (ReshapeSelectPredIfNecessary $op, $cond, $t), $t, $e)>;
//===----------------------------------------------------------------------===//
// Square op patterns.
//===----------------------------------------------------------------------===//

View File

@ -17,7 +17,7 @@ limitations under the License.
#include <algorithm>
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status.h"

View File

@ -48,6 +48,9 @@ struct FreezeGlobalTensorsPass
void FreezeGlobalTensorsPass::runOnOperation() {
auto module = getOperation();
if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
return;
}
SymbolTable symbol_table(module);
DenseSet<Operation*> frozen_global_tensors;
@ -66,7 +69,9 @@ void FreezeGlobalTensorsPass::runOnOperation() {
// previous optimize global tensors pass). If not, this pass has to fail
// since it cannot perform one of its goals.
if (global_tensor.is_mutable()) {
global_tensor.emitError() << "is not immutable";
global_tensor.emitError() << "is not immutable, try running "
"tf-saved-model-optimize-global-tensors "
"to prove tensors are immutable";
return signalPassFailure();
}

View File

@ -16,21 +16,61 @@ limitations under the License.
// This file implements logic for legalizing HLO to TensorFlow.
#include <memory>
#include <vector>
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
namespace mlir {
namespace TF {
namespace {
class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_hlo::SliceOp slice_op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
DenseIntElementsAttr strides = slice_op.strides();
// Strides must be 1 otherwise we cannot legalize this `xla_hlo.slice` op.
if (!strides.isSplat() ||
strides.getSplatValue().cast<IntegerAttr>().getInt() != 1)
return failure();
rewriter.setInsertionPointAfter(slice_op);
auto start_indices = slice_op.start_indices();
auto limit_indices = slice_op.limit_indices();
std::vector<int64_t> size_values;
for (auto pair : llvm::zip(start_indices.getValues<APInt>(),
limit_indices.getValues<APInt>())) {
size_values.emplace_back(std::get<1>(pair).getSExtValue() -
std::get<0>(pair).getSExtValue());
}
RankedTensorType ty =
RankedTensorType::get({static_cast<int64_t>(size_values.size())},
rewriter.getIntegerType(64));
auto start = rewriter.create<ConstOp>(slice_op.getLoc(), start_indices);
auto size = rewriter.create<ConstOp>(
slice_op.getLoc(), DenseIntElementsAttr::get(ty, size_values));
rewriter.replaceOpWithNewOp<SliceOp>(slice_op, slice_op.getType(),
slice_op.operand(), start, size);
return success();
};
};
class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
public:
LegalizeHloToTf() = default;
@ -63,6 +103,7 @@ void LegalizeHloToTf::runOnFunction() {
// Add legalization patterns to the list.
OwningRewritePatternList patterns;
populateWithGenerated(&context, &patterns);
patterns.insert<ConvertSliceOp>(&context);
ConversionTarget target(context);
target.addLegalDialect<TensorFlowDialect>();

View File

@ -18,12 +18,16 @@ limitations under the License.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
include "tensorflow/compiler/mlir/xla/ir/chlo_ops.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>;
//===----------------------------------------------------------------------===//
// Binary op patterns.
// Note that these are legalized from chlo.broadcast_* ops, since those are
// semantically compatible with the corresponding TF ops. Depending on
// context, getting to these ops may require some raising.
//===----------------------------------------------------------------------===//
// Check that two values can be broadcasted together
@ -31,36 +35,45 @@ def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>;
def AreBroadcastCompatible : Constraint<CPred<"AreBroadcastCompatible($0, $1)">,
"types must be broadcastable">;
foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op],
[HLO_DivOp, TF_DivOp],
[HLO_ShiftLeftOp, TF_LeftShiftOp],
[HLO_MaxOp, TF_MaximumOp],
[HLO_MinOp, TF_MinimumOp],
[HLO_MulOp, TF_MulOp],
[HLO_PowOp, TF_PowOp],
[HLO_SubOp, TF_SubOp],
[HLO_Atan2Op, TF_Atan2Op],
[HLO_RemOp, TF_ModOp]] in
def : Pat<(fromToBinPair[0] $l, $r, $_), (fromToBinPair[1] $l, $r),
foreach fromToBinPair = [[HLO_AddOp, HLOClient_BroadcastAddOp, TF_AddV2Op],
[HLO_DivOp, HLOClient_BroadcastDivOp, TF_DivOp],
[HLO_ShiftLeftOp, HLOClient_BroadcastShiftLeftOp, TF_LeftShiftOp],
[HLO_MaxOp, HLOClient_BroadcastMaxOp, TF_MaximumOp],
[HLO_MinOp, HLOClient_BroadcastMinOp, TF_MinimumOp],
[HLO_MulOp, HLOClient_BroadcastMulOp, TF_MulOp],
[HLO_PowOp, HLOClient_BroadcastPowOp, TF_PowOp],
[HLO_SubOp, HLOClient_BroadcastSubOp, TF_SubOp],
[HLO_Atan2Op, HLOClient_BroadcastAtan2Op, TF_Atan2Op],
[HLO_RemOp, HLOClient_BroadcastRemOp, TF_ModOp]] in {
def : Pat<(fromToBinPair[0] $l, $r), (fromToBinPair[2] $l, $r)>;
def : Pat<(fromToBinPair[1] $l, $r, $_), (fromToBinPair[2] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
}
foreach pair = [[HLO_AndOp, TF_BitwiseAndOp],
[HLO_OrOp, TF_BitwiseOrOp],
[HLO_XorOp, TF_BitwiseXorOp]] in
def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[1] $l, $r),
foreach pair = [[HLO_AndOp, HLOClient_BroadcastAndOp, TF_BitwiseAndOp],
[HLO_OrOp, HLOClient_BroadcastOrOp, TF_BitwiseOrOp],
[HLO_XorOp, HLOClient_BroadcastXorOp, TF_BitwiseXorOp]] in {
def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r), (pair[2] $l, $r)>;
def : Pat<(pair[1] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[2] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
}
foreach pair = [[HLO_AndOp, TF_LogicalAndOp],
[HLO_OrOp, TF_LogicalOrOp]] in
def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r, $_), (pair[1] $l, $r),
foreach pair = [[HLO_AndOp, HLOClient_BroadcastAndOp, TF_LogicalAndOp],
[HLO_OrOp, HLOClient_BroadcastOrOp, TF_LogicalOrOp]] in {
def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r), (pair[2] $l, $r)>;
def : Pat<(pair[1] I1Tensor:$l, I1Tensor:$r, $_), (pair[2] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
}
def : Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
def : Pat<(HLO_ShiftRightArithmeticOp $l, $r), (TF_RightShiftOp $l, $r)>;
def : Pat<(HLOClient_BroadcastShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_ShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r),
def : Pat<(HLO_ShiftRightLogicalOp $l, $r), (TF_RightShiftOp $l, $r)>;
def : Pat<(HLOClient_BroadcastShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r),
def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r)), (TF_FloorDivOp $l, $r)>;
def : Pat<(HLO_FloorOp (HLOClient_BroadcastDivOp $l, $r, $_)), (TF_FloorDivOp $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_ComplexOp $r, $i), (TF_ComplexOp $r, $i)>;
@ -117,16 +130,23 @@ def : Pat<(HLO_ConcatenateOp $inputs, $dim),
//===----------------------------------------------------------------------===//
// Compare op patterns.
// Note that these are legalized from chlo.broadcast_* ops, since those are
// semantically compatible with the corresponding TF ops. Depending on
// context, getting to these ops may require some raising.
//===----------------------------------------------------------------------===//
foreach p = [[TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ],
[TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in
def : Pat<(HLO_CompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue),
[TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in {
def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_CompareOp $l, $r, p[1]), (p[0] $l, $r, ConstBoolAttrTrue)>;
}
foreach pair = [[TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE],
[TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT],
[TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE],
[TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in
def : Pat<(HLO_CompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r),
[TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in {
def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_CompareOp $l, $r, pair[1]), (pair[0] $l, $r)>;
}

View File

@ -278,6 +278,10 @@ void EraseUnusedBoundInputs(ModuleOp module) {
void OptimizeGlobalTensorsPass::runOnOperation() {
auto module = getOperation();
if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
return;
}
EraseUnusedBoundInputs(module);
ResourceAnalyzer resource_analyzer(module);

View File

@ -95,6 +95,11 @@ std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass();
// functions.
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass();
// Creates a pass that converts readonly reference variables to the
// corresponding resource variables.
std::unique_ptr<OperationPass<FuncOp>>
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass();
// Marks function visibility using tf.entry_function specification. That is,
// functions with tf.entry_function attributes are marked with public
// visibility while the other functions are marked with private visibility.

View File

@ -0,0 +1,179 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
namespace TF {
namespace {
// Location attribute.
constexpr StringRef kClassAttr = "_class";
constexpr StringRef kLocationPrefix = "loc:@";
// A pass that converts readonly reference variables to the corresponding
// resource variables.
//
// It converts (VariableV2 -> Identity) to (VarHandle -> ReadVariable).
//
// For the background, this pass is a part of hoisting VariableV2 ops by
// re-using the pipeline for hoisting (VarHandle -> ReadVariable) cases, which
// can be done by the following passes:
// - Capturing resource values into global tensors (importing saved model).
// - Promoting VarHandle ops to function input/outputs.
// - Freezing global tensor pass.
//
// This path assumes that all the VariableV2 ops is read-only via verifying the
// heuristic method that assumes that all the users of them is Identity op,
// fed directly.
class ConvertReadonlyReferenceVariablesToResourceVariablesPass
: public PassWrapper<
ConvertReadonlyReferenceVariablesToResourceVariablesPass,
FunctionPass> {
public:
void runOnFunction() override;
};
// Parse node name from "_class" attribute.
StringRef GetNodeNameFromClassAttr(Operation *op) {
ArrayAttr classes_attr = op->getAttrOfType<ArrayAttr>(kClassAttr);
if (!classes_attr) {
op->emitOpError() << "has no '_class' attribute";
return StringRef();
}
StringRef result;
for (Attribute class_attr : classes_attr) {
StringRef node_name = class_attr.cast<StringAttr>().getValue();
if (!node_name.startswith(kLocationPrefix)) {
continue;
}
if (!result.empty()) {
// Invalid case since there are multiple loc:@ attributes.
op->emitOpError()
<< "expects only one named location in '_class' attribute, but got "
<< classes_attr;
return StringRef();
}
result = node_name.drop_front(kLocationPrefix.size());
}
if (result.empty()) {
op->emitOpError() << "expects variable name in '_class' attribute, but got "
<< classes_attr;
}
return result;
}
void ConvertReadonlyReferenceVariablesToResourceVariablesPass::runOnFunction() {
FuncOp func = getFunction();
OpBuilder builder(func.getContext());
SmallVector<VariableV2Op, 4> variable_v2s_to_replace;
// Checks all the VariableV2 ops is read-only via verifying the heuristic
// method that assumes that all the users of them is Identity op, feeded
// directly.
auto read_only_vars_fn = [&variable_v2s_to_replace](
VariableV2Op variable_v2_op) {
if (variable_v2_op.getResult().use_empty()) {
// Erase the op when there is no user.
variable_v2_op.erase();
return mlir::WalkResult::advance();
}
if (!all_of(variable_v2_op.getResult().getUsers(), [&variable_v2_op](
Operation *user) {
if (!isa<IdentityOp>(user)) {
variable_v2_op.emitOpError()
<< "expects all users to be 'tf.Identity', but got user "
<< user->getName();
return false;
}
return true;
})) {
return mlir::WalkResult::interrupt();
}
variable_v2s_to_replace.push_back(variable_v2_op);
return mlir::WalkResult::advance();
};
WalkResult walk_res = func.walk(read_only_vars_fn);
if (walk_res.wasInterrupted()) return signalPassFailure();
for (VariableV2Op variable_v2_op : variable_v2s_to_replace) {
builder.setInsertionPoint(variable_v2_op);
ShapedType shaped_type =
variable_v2_op.getResult().getType().cast<ShapedType>();
TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>();
StringAttr device_attr = variable_v2_op.getAttrOfType<StringAttr>("device");
if (!device_attr) device_attr = builder.getStringAttr("");
StringRef variable_name = GetNodeNameFromClassAttr(variable_v2_op);
if (variable_name.empty()) {
return signalPassFailure();
}
VarHandleOp var_handle_op = builder.create<VarHandleOp>(
variable_v2_op.getLoc(),
ArrayRef<Type>{RankedTensorType::get(
{}, TF::ResourceType::get(ArrayRef<TensorType>{tensor_type},
builder.getContext()))},
ArrayRef<Value>{},
ArrayRef<NamedAttribute>{
builder.getNamedAttr("device", device_attr),
builder.getNamedAttr("container", variable_v2_op.containerAttr()),
builder.getNamedAttr("shared_name",
builder.getStringAttr(variable_name))});
for (Operation *user :
make_early_inc_range(variable_v2_op.getResult().getUsers())) {
builder.setInsertionPoint(user);
ReadVariableOp read_variable_op = builder.create<ReadVariableOp>(
user->getLoc(), ArrayRef<Type>{tensor_type},
ArrayRef<Value>{var_handle_op}, ArrayRef<NamedAttribute>{});
user->getResult(0).replaceAllUsesWith(read_variable_op.getResult());
user->erase();
}
variable_v2_op.erase();
}
}
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass() {
return std::make_unique<
ConvertReadonlyReferenceVariablesToResourceVariablesPass>();
}
static PassRegistration<
ConvertReadonlyReferenceVariablesToResourceVariablesPass>
pass("readonly-references-to-resources",
"Convert readonly reference variables to resource variables.");
} // namespace TF
} // namespace mlir

View File

@ -149,7 +149,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
}
auto walk_res = func_op.walk([&](Operation* op) {
if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
// Record VarHanldeOp's device attribute.
// Record VarHandleOp's device attribute.
auto device_attr =
var_handle.getAttrOfType<mlir::StringAttr>(kDeviceAttr);
if (!device_attr || device_attr.getValue().empty()) {

View File

@ -571,7 +571,7 @@ void AddLoadsStoresOutsideControlFlowOp(
}
// Lifts loads/stores from while loop's body and cond functions.
LogicalResult HanldeWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
// Remove identity nodes to avoid aliasing.
RemoveIdentity(&body.front());
RemoveIdentity(&cond.front());
@ -985,7 +985,7 @@ LogicalResult HoistForFunctionalControlFlow(
lifted_partitioned_call_callees);
HoistForFunctionalControlFlow(&cond.front(), module,
lifted_partitioned_call_callees);
if (failed(HanldeWhileLoop(while_op, body, cond))) return failure();
if (failed(HandleWhileLoop(while_op, body, cond))) return failure();
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
auto then_branch =
llvm::cast<FuncOp>(module.lookupSymbol(if_op.then_branch()));

View File

@ -429,7 +429,8 @@ LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
// existing computed values.
Attribute ComputeOutputComponent(const ValuePort& value_port,
ValueQueryFn values) {
LLVM_DEBUG(value_port.print(llvm::errs() << "\nComputing output for "));
LLVM_DEBUG(value_port.print(llvm::dbgs() << "Computing output for ") << "\n");
if (auto known = values(value_port)) return known;
auto op = value_port.producer.dyn_cast<Operation*>();
if (!op) return nullptr;
@ -454,6 +455,21 @@ Attribute ComputeOutputComponent(const ValuePort& value_port,
ValuePort op_port(op->getOperand(port[1]));
return values(op_port);
}
if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
if (port.size() == 1)
return ComputeOutputComponent(
ValuePort(graph.GetFetch().fetches()[port[0]]), values);
return nullptr;
}
if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
if (port.size() == 1)
return ComputeOutputComponent(
ValuePort(island.GetYield().fetches()[port[0]]), values);
return nullptr;
}
return nullptr;
}
@ -462,7 +478,8 @@ Attribute ComputeOutputComponent(const ValuePort& value_port,
// TF Graph version, constant values computed, etc.)
class ShapeInference {
public:
ShapeInference(int64_t graph_version, MLIRContext* context);
ShapeInference(int64_t graph_version, MLIRContext* context,
bool propagate_caller_callee_constants);
LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
ValuePortInputs* inputs) {
@ -475,14 +492,19 @@ class ShapeInference {
}
Attribute ComputeOutputComponent(const ValuePort& value_port) {
return ::mlir::TF::ComputeOutputComponent(
if (auto known_attr = results_[value_port]) return known_attr;
auto attr = ::mlir::TF::ComputeOutputComponent(
value_port, [this](const ValuePort& port) { return results_[port]; });
RecordValue(value_port, attr);
return attr;
}
// Returns ShapeHandle if the op result could be computed as shape.
ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic);
void RecordValue(const ValuePort& value_port, Attribute value) {
LLVM_DEBUG(value_port.print(llvm::dbgs() << "\trecording ")
<< value << "\n");
results_[value_port] = value;
}
@ -520,19 +542,41 @@ class ShapeInference {
LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
int64_t max_iteration);
// Propagates any constant operand of call_op to the called function body's
// corresponding argument if the callee has only one use.
//
// TODO(b/154065712): Move this to a more general inter-procedural constant
// folding pass.
void PropagateConstantToCallee(CallOpInterface call_op,
SymbolRefAttr callee_sym, ModuleOp module);
// Propagates any constant return value of the callee function to the call
// op's corresponding result.
void PropagateConstantFromCallee(CallOpInterface call_op,
SymbolRefAttr callee_sym, ModuleOp module);
// Tries to compute the result of folding the op. This doesn't actually
// perform constant folding, it is just computes the equivalent constants.
// Returns whether it was able to compute constant values.
LogicalResult TryToFold(Operation* op);
private:
// Mapping between ValuePort (which corresponds to an OpResult or smaller,
// e.g., first element of OpResult produded) to an Attribute if the ValuePort
// e.g., first element of OpResult produced) to an Attribute if the ValuePort
// corresponds to a constant value.
ValuePortResultMap results_;
int64_t graph_version_;
MLIRContext* context_;
Dialect* tf_dialect_;
// TODO(b/154065712): Remove propagate_caller_callee_constants once using
// SCCP pass instead.
bool propagate_caller_callee_constants_;
};
ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context)
: graph_version_(graph_version) {
context_ = context;
ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context,
bool propagate_caller_callee_constants)
: graph_version_(graph_version),
propagate_caller_callee_constants_(propagate_caller_callee_constants) {
tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>();
}
@ -581,7 +625,6 @@ ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
auto ret = ComputeOutputComponent(front);
if (!ret) continue;
RecordValue(front, ret);
LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
// If worklist is empty, then this is the root query op.
@ -602,6 +645,8 @@ ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
}
bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
LLVM_DEBUG(op->print(llvm::dbgs() << "InferShapeForSingleOperation for ");
llvm::dbgs() << "\n");
assert(tf_dialect_ == op->getDialect());
// The shape function of these ops sometimes does not propagate subtypes
// (handle shapes) for resource and variant types. We use a simple passthrough
@ -686,10 +731,14 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
size_t index = it.index();
// If the operand is constant, then convert it to Tensor.
ElementsAttr attr;
if (matchPattern(operand, m_Constant(&attr))) {
ValuePort vp(operand);
Attribute attr = ComputeOutputComponent(vp);
if (!attr && matchPattern(operand, m_Constant(&attr)))
RecordValue(vp, attr);
if (attr) {
tensorflow::Tensor* input_tensor = &tensors[index];
auto status = tensorflow::ConvertToTensor(attr, input_tensor);
auto status =
tensorflow::ConvertToTensor(attr.cast<ElementsAttr>(), input_tensor);
if (status.ok()) {
input_tensors[index] = input_tensor;
} else {
@ -728,10 +777,12 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
!input_tensors[input];
});
if (requires_inputs) {
LLVM_DEBUG(llvm::dbgs() << "\trequired input\n");
std::vector<ShapeHandle> input_tensors_as_shapes;
for (int input : llvm::seq<int>(0, c.num_inputs())) {
if (c.requested_input_tensor_as_partial_shape(input) &&
!input_tensors[input]) {
LLVM_DEBUG(llvm::dbgs() << "Requesting " << input << " as shape\n");
auto op_result = op->getOperand(input).dyn_cast<OpResult>();
if (!op_result) continue;
// Resize on first valid shape computed.
@ -865,45 +916,62 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
return success(all_succeeded);
}
// If the callee has only one use, propagates any constant operand of call_op to
// the called function body's corresponding argument.
//
// TODO(b/154065712): Move this to a more general inter-procedural constant
// folding pass.
void PropagateConstantToCallee(CallOpInterface call_op,
SymbolRefAttr callee_sym, ModuleOp module) {
void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op,
SymbolRefAttr callee_sym,
ModuleOp module) {
auto func = module.lookupSymbol<FuncOp>(callee_sym.getRootReference());
auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
int num_uses = std::distance(func_uses->begin(), func_uses->end());
if (num_uses != 1) return;
OpBuilder builder(&func.front().front());
Operation* op = call_op.getOperation();
if (num_uses == 1) {
// If this is the only caller, and an operand is a constant, propagate
// the constant inside the function.
for (auto arg : func.getArguments()) {
auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp();
if (isa_and_nonnull<TF::ConstOp>(operand)) {
arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0));
// If this is the only caller, and an operand is a constant, propagate
// the constant value inside the function.
for (auto arg : func.getArguments()) {
auto operand = op->getOperand(arg.getArgNumber());
if (propagate_caller_callee_constants_) {
if (isa_and_nonnull<TF::ConstOp>(operand.getDefiningOp())) {
arg.replaceAllUsesWith(
builder.clone(*operand.getDefiningOp())->getResult(0));
}
continue;
}
auto known_constant = ComputeOutputComponent(ValuePort(operand));
if (!known_constant) continue;
LLVM_DEBUG(call_op.print(llvm::dbgs() << "Propagate to calee: ");
known_constant.print(llvm::dbgs() << " constant ");
llvm::dbgs() << "\n");
RecordValue(ValuePort(arg), known_constant);
}
}
// Propagates any constant return value of the callee function to the call op's
// corresponding result.
void PropagateConstantFromCallee(CallOpInterface call_op,
SymbolRefAttr callee_sym, ModuleOp module) {
void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op,
SymbolRefAttr callee_sym,
ModuleOp module) {
auto func = module.lookupSymbol<FuncOp>(callee_sym.getRootReference());
// If the return value is a constant, replace the call result with a constant.
// If the return value is a constant, use the constant as the value of
// the call return.
Operation* op = call_op.getOperation();
OpBuilder builder(op);
builder.setInsertionPointAfter(op);
for (auto retval :
llvm::enumerate(func.front().getTerminator()->getOperands())) {
auto retval_op = retval.value().getDefiningOp();
if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
op->getResult(retval.index())
.replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
if (propagate_caller_callee_constants_) {
auto retval_op = retval.value().getDefiningOp();
if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
op->getResult(retval.index())
.replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
}
continue;
}
ValuePort vp(retval.value());
if (auto known_constant = ComputeOutputComponent(vp)) {
LLVM_DEBUG(known_constant.print(llvm::dbgs() << "Propagate constant ");
call_op.print(llvm::dbgs() << "from "); llvm::dbgs() << "\n");
RecordValue(ValuePort(op->getResult(retval.index())), known_constant);
}
}
}
@ -938,10 +1006,71 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
return success();
}
LogicalResult ShapeInference::TryToFold(Operation* op) {
LLVM_DEBUG(op->print(llvm::dbgs() << "TryToFold "); llvm::dbgs() << "\n");
// If any output result is known, then the op probably has been computed
// before.
if (op->getNumResults() > 0 && results_[ValuePort(op->getResult(0))])
return success();
SmallVector<Attribute, 8> constant_operands(op->getNumOperands());
SmallVector<OpFoldResult, 8> fold_results;
// Check to see if any operands to the operation is constant and whether
// the operation knows how to constant fold itself.
bool some_unknown = false;
for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
if (!(constant_operands[i] =
ComputeOutputComponent(ValuePort(op->getOperand(i)))))
some_unknown = true;
}
// Attempt to constant fold the operation.
auto* abstract_op = op->getAbstractOperation();
LogicalResult folded = failure();
if (abstract_op) {
folded = abstract_op->foldHook(op, constant_operands, fold_results);
}
// Attempt dialect fallback if op's fold hook failed.
if (failed(folded)) {
Dialect* dialect = op->getDialect();
if (!dialect) return failure();
// Only attempt TF dialect fallback if there are no unknown operands.
if (some_unknown && dialect == tf_dialect_) return failure();
SmallVector<Attribute, 8> constants;
if (failed(dialect->constantFoldHook(op, constant_operands, constants)))
return failure();
fold_results.assign(constants.begin(), constants.end());
}
for (auto result : zip(op->getResults(), fold_results)) {
auto fold_result = std::get<1>(result);
Attribute attr = nullptr;
if ((attr = fold_result.dyn_cast<Attribute>())) {
RecordValue(ValuePort(std::get<0>(result)), attr);
} else {
auto value = fold_result.get<Value>();
if ((attr = ComputeOutputComponent(ValuePort(value))))
RecordValue(ValuePort(std::get<0>(result)), attr);
}
if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) {
if (std::get<0>(result).getType() == eattr.getType()) continue;
// Inserts a cast back to the original type if any user is not in the
// TF dialect.
Type old_type = std::get<0>(result).getType();
std::get<0>(result).setType(eattr.getType());
AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result), tf_dialect_,
old_type);
}
}
return success();
}
LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
int64_t max_iteration) {
// An operation folder that is used to attempt folding before inference._
OperationFolder folder(context_);
bool changed = true;
// TODO(aminim): we could have a more efficient traversal by guiding the
@ -955,9 +1084,7 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
region->walk([&](Operation* op) {
if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_);
// TODO(jpienaar): Debug why we can't just return here. We end up with
// additional constant due to the propagation of constant into attached
// function if we return already.
return;
}
if (op->getDialect() != tf_dialect_) {
@ -965,8 +1092,9 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
return;
}
// Before attempting inference, just try to fold the operation.
if (succeeded(folder.tryToFold(op))) return;
// Before attempting inference, just try to compute the folded
// value/shape.
if (succeeded(TryToFold(op))) return;
// Best-effort shape inference in attached functions. Do not return
// failure even if it doesn't get to fixed point.
@ -989,8 +1117,10 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
LogicalResult InferShapeForFunction(FuncOp func,
ArrayRef<ArrayRef<int64_t>> arg_shapes,
int64_t graph_version) {
ShapeInference context(graph_version, func.getContext());
int64_t graph_version,
bool propagate_caller_callee_constants) {
ShapeInference context(graph_version, func.getContext(),
propagate_caller_callee_constants);
if (arg_shapes.empty()) {
if (failed(context.InferShapeUntilFixPoint(&func.getBody())))
return failure();
@ -1014,7 +1144,7 @@ LogicalResult InferShapeForFunction(FuncOp func,
ArrayRef<int64_t> shape = arg_shapes[i];
Type element_type;
if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
if (!input_ty || input_ty.getShape().size() != shape.size()) {
if (input_ty.getRank() != shape.size()) {
return failure();
}
element_type = input_ty.getElementType();

View File

@ -30,9 +30,11 @@ namespace TF {
// Given a list of refined shapes matching the function arguments of func, runs
// shape inference over the function to propagate this updated information.
// If arg_shapes are empty, then argument shapes will be left unchanged.
LogicalResult InferShapeForFunction(FuncOp func,
ArrayRef<ArrayRef<int64_t>> arg_shapes,
int64_t graph_version);
// TODO(b/154065712): Remove propagate_caller_callee_constants once using
// SCCP pass instead.
LogicalResult InferShapeForFunction(
FuncOp func, ArrayRef<ArrayRef<int64_t>> arg_shapes, int64_t graph_version,
bool propagate_caller_callee_constants = true);
} // namespace TF

View File

@ -47,8 +47,15 @@ namespace {
// This transformation pass propagate shapes on the TensorFlow graph.
// It is a ModulePass in order to be able to change function types.
struct ShapeInference
class ShapeInference
: public PassWrapper<ShapeInference, OperationPass<ModuleOp>> {
public:
ShapeInference() = default;
ShapeInference(const ShapeInference& that) {
propagate_caller_callee_constants_ =
that.propagate_caller_callee_constants_;
}
void runOnOperation() override {
auto module = getOperation();
auto producer_or = tensorflow::GetTfGraphProducerVersion(module);
@ -58,10 +65,17 @@ struct ShapeInference
}
int64_t producer = producer_or.ValueOrDie();
for (auto func : module.getOps<FuncOp>()) {
if (failed(InferShapeForFunction(func, /*arg_shapes=*/{}, producer)))
if (failed(InferShapeForFunction(func, /*arg_shapes=*/{}, producer,
propagate_caller_callee_constants_)))
return signalPassFailure();
}
}
private:
Option<bool> propagate_caller_callee_constants_{
*this, "propagate-caller-callee-constants",
llvm::cl::desc("Propagate constants between callers and callees"),
llvm::cl::init(true)};
};
PassRegistration<ShapeInference> pass(

View File

@ -14,23 +14,30 @@ limitations under the License.
==============================================================================*/
#include <memory>
#include <tuple>
#include <type_traits>
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
namespace mlir {
namespace TFTPU {
@ -46,146 +53,184 @@ bool HasOutsideCompilationAttribute(Operation* op) {
return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
}
// Returns whether all operands of `op` are from values inside the
// `input_value_set`.
bool OpContainsOperandsFromSet(Operation* op,
const llvm::SetVector<Value>& input_value_set) {
for (auto operand : op->getOperands())
if (input_value_set.count(operand) == 0) return false;
Operation* GetOpOfValue(Value value) {
if (auto block_arg = value.dyn_cast<BlockArgument>())
return block_arg.getOwner()->getParentOp();
return true;
return value.getDefiningOp();
}
void RecordOutsideCompiledOpsAndUsages(
Operation* op, llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops,
llvm::SetVector<Value>* outside_compiled_op_usages) {
if (HasOutsideCompilationAttribute(op) &&
OpContainsOperandsFromSet(op, *outside_compiled_op_usages)) {
outside_compiled_ops->insert(op);
outside_compiled_op_usages->insert(op->getResults().begin(),
op->getResults().end());
// Returns a set of ops that are outside compiled and can be extracted to before
// the TPU computation. These ops are either connected to the inputs of the TPU
// computation or other ops that can be extracted, and have no dependencies with
// other ops in the TPU computation that cannot be extracted.
llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
tf_device::ClusterOp cluster) {
llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
auto cluster_ops = cluster.GetBody().without_terminator();
for (Operation& cluster_op : cluster_ops) {
if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
// An outside compiled op can be extracted if its operands are not from
// other ops in the cluster that cannot be extracted.
auto result = cluster_op.walk([&](Operation* op) {
for (Value operand : op->getOperands()) {
Operation* operand_op = GetOpOfValue(operand);
if (operand_op->isProperAncestor(cluster) ||
cluster_op.isAncestor(operand_op) ||
head_outside_compiled_ops.count(operand_op))
continue;
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (!result.wasInterrupted()) head_outside_compiled_ops.insert(&cluster_op);
}
return head_outside_compiled_ops.takeVector();
}
// Traverses the MLIR graph and returns a set of ops that
// are connected to inputs of TPU computation and outside compiled.
void ExtractOutsideCompiledOpsConnectedToHead(
Value input_value, llvm::SetVector<Value>* values_used_in_host_cluster,
llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops) {
llvm::SmallSetVector<Operation*, 4> parent_outside_compiled_ops_at_head;
for (auto& usage : input_value.getUses()) {
auto head_operation = usage.getOwner();
RecordOutsideCompiledOpsAndUsages(head_operation,
&parent_outside_compiled_ops_at_head,
values_used_in_host_cluster);
// Parses TPU compilation and execution devices from a TPU cluster and returns
// the host device for the head and tail computations. If the TPU computation is
// replicated, kTPUReplicatedHost is returned instead.
LogicalResult GetHostDeviceForHeadTailComputation(
mlir::TF::RuntimeDevices devices, tf_device::ClusterOp cluster,
std::string* host_device) {
auto replicate = cluster.getParentOfType<tf_device::ReplicateOp>();
if (replicate) {
*host_device = tensorflow::kTPUReplicatedHost;
return success();
}
// Traverse the graph and find all outside compiled ops connected from
// the `input_value`.
while (!parent_outside_compiled_ops_at_head.empty()) {
llvm::SmallSetVector<Operation*, 4> connected_outside_compiled_ops;
for (auto head_outside_compiled_op : parent_outside_compiled_ops_at_head) {
auto op_results = head_outside_compiled_op->getOpResults();
for (auto op_result : op_results) {
for (auto& use : op_result.getUses()) {
auto connected_op = use.getOwner();
RecordOutsideCompiledOpsAndUsages(connected_op,
&connected_outside_compiled_ops,
values_used_in_host_cluster);
auto num_cores_per_replica_attr =
cluster.getAttrOfType<IntegerAttr>(tensorflow::kNumCoresPerReplicaAttr);
if (!num_cores_per_replica_attr)
return cluster.emitOpError(
"cluster op missing `num_cores_per_replica` attribute");
if (num_cores_per_replica_attr.getInt() != 1)
return cluster.emitOpError(
"outside compilation is not supported with model parallelism.");
auto topology_attr =
cluster.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
if (!topology_attr)
return cluster.emitOpError("cluster op missing `topology` attribute");
auto device_assignment_attr =
cluster.getAttrOfType<mlir::ArrayAttr>(tensorflow::kDeviceAssignmentAttr);
if (!device_assignment_attr)
return cluster.emitOpError(llvm::formatv("requires attribute '{0}'",
tensorflow::kDeviceAssignmentAttr)
.str());
auto status_or_device_coodinates =
tensorflow::GetDeviceCoordinates(device_assignment_attr);
if (!status_or_device_coodinates.ok())
return cluster.emitError()
<< "error in fetching tpu device coordinates: "
<< status_or_device_coodinates.status().error_message();
// Determine compilation and execution devices.
auto status_or_tpu_device_assignment =
tensorflow::GetTPUCompilationAndExecutionDevices(
devices.device_names(), /*num_replicas=*/1,
/*num_cores_per_replica=*/1, topology_attr.getValue(),
status_or_device_coodinates.ConsumeValueOrDie());
if (!status_or_tpu_device_assignment.ok())
return cluster.emitError()
<< "error in fetching TPU compilation/execution devices: "
<< status_or_tpu_device_assignment.status().error_message();
auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
*host_device = tpu_device_assignment.tpu_devices[0][0].host;
return success();
}
// Moves head outside compiled ops into its own `tf_device.LaunchOp`
// computation.
tf_device::LaunchOp CreateHeadComputation(
OpBuilder* builder, tf_device::ClusterOp cluster,
llvm::ArrayRef<Operation*> head_outside_compiled_ops,
llvm::StringRef host_device) {
Block* launch_block = new Block;
for (Operation* head_outside_compiled_op : head_outside_compiled_ops)
head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
// Find results of ops in head computation that needs to returned.
llvm::SmallVector<Value, 4> launch_results;
llvm::SmallVector<Type, 4> launch_result_types;
for (Operation& head_outside_compiled_op : *launch_block) {
for (Value result : head_outside_compiled_op.getResults()) {
bool has_uses_in_cluster = false;
for (Operation* user : result.getUsers()) {
if (user->getParentRegion() &&
cluster.body().isAncestor(user->getParentRegion())) {
has_uses_in_cluster = true;
break;
}
}
}
outside_compiled_ops->insert(parent_outside_compiled_ops_at_head.begin(),
parent_outside_compiled_ops_at_head.end());
std::swap(parent_outside_compiled_ops_at_head,
connected_outside_compiled_ops);
}
}
// TODO(hongjunchoi): Also handle ops without inputs that are outside
// compiled.
//
// Returns set of ops that are outside compiled and are directly connected
// to inputs to the TPU computation.
llvm::SmallSetVector<Operation*, 4> IdentifyOutsideCompiledOpsAtHead(
tf_device::ClusterOp tpu_cluster) {
llvm::SmallSetVector<Operation*, 4> outside_compiled_at_head_ops;
llvm::SetVector<Value> values_used_in_cluster;
auto& cluster_region = tpu_cluster.body();
getUsedValuesDefinedAbove(cluster_region, cluster_region,
values_used_in_cluster);
auto input_value_list = llvm::to_vector<8>(values_used_in_cluster);
for (auto input_value : input_value_list)
ExtractOutsideCompiledOpsConnectedToHead(
input_value, &values_used_in_cluster, &outside_compiled_at_head_ops);
return outside_compiled_at_head_ops;
}
// Returns output values of extracted outside compiled cluster at head that
// are used by the TPU computation.
llvm::SmallVector<Value, 8> GetHeadExtractedClusterOutputs(
const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
llvm::SmallVector<Value, 8> outputs;
outputs.reserve(head_outside_compiled_ops.size());
for (auto op : head_outside_compiled_ops) {
for (Operation* user : op->getUsers()) {
if (!head_outside_compiled_ops.count(user)) {
outputs.append(op->result_begin(), op->result_end());
break;
if (has_uses_in_cluster) {
launch_results.push_back(result);
launch_result_types.push_back(result.getType());
}
}
}
return outputs;
builder->setInsertionPoint(cluster);
auto launch = builder->create<tf_device::LaunchOp>(
cluster.getLoc(), builder->getStringAttr(host_device),
launch_result_types);
launch.body().push_back(launch_block);
builder->setInsertionPointToEnd(&launch.GetBody());
builder->create<tf_device::ReturnOp>(cluster.getLoc(), launch_results);
for (auto result : llvm::zip(launch_results, launch.getResults()))
replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result),
cluster.body());
return launch;
}
// Creates new tf_device.launch op with outside compiled ops extracted
// from the head of TPU computation.
llvm::Optional<tf_device::LaunchOp> IsolateHeadExtractedOpsToLaunchOp(
OpBuilder* builder, tf_device::ClusterOp cluster,
const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
if (head_outside_compiled_ops.empty())
return llvm::Optional<tf_device::LaunchOp>();
// Create tf_device.launch op to separate all extracted outside compiled ops
// before the tf_device.cluster.
auto output_values =
GetHeadExtractedClusterOutputs(head_outside_compiled_ops);
llvm::SmallVector<Type, 8> output_return_types;
output_return_types.reserve(output_values.size());
for (auto output : output_values)
output_return_types.emplace_back(output.getType());
builder->setInsertionPoint(cluster);
auto host_launch_op = builder->create<tf_device::LaunchOp>(
cluster.getLoc(), builder->getStringAttr(""), output_return_types);
// Replace all usages of outside compiled ops that are used in TPU
// computation with the results of the above created launch op.
for (auto output_and_index : llvm::enumerate(output_values)) {
auto output_index = output_and_index.index();
auto output = output_and_index.value();
for (auto& use : output.getUses()) {
if (!head_outside_compiled_ops.count(use.getOwner()))
use.set(host_launch_op.getResult(output_index));
// Removes aliased outputs in cluster from head computation after head
// computation has been extracted.
void RemoveHeadComputationAliasedOutputs(OpBuilder* builder,
tf_device::LaunchOp head_computation,
tf_device::ClusterOp cluster) {
llvm::SmallVector<Value, 4> used_old_cluster_results;
llvm::SmallVector<Value, 4> new_cluster_results;
llvm::SmallVector<Type, 4> new_cluster_result_types;
Operation* cluster_terminator = cluster.GetBody().getTerminator();
for (auto result :
llvm::zip(cluster_terminator->getOperands(), cluster.getResults())) {
Value cluster_terminator_operand = std::get<0>(result);
if (cluster_terminator_operand.getDefiningOp() == head_computation) {
std::get<1>(result).replaceAllUsesWith(cluster_terminator_operand);
} else {
new_cluster_results.push_back(cluster_terminator_operand);
new_cluster_result_types.push_back(cluster_terminator_operand.getType());
used_old_cluster_results.push_back(std::get<1>(result));
}
}
// Create terminator op for the newly created launch op.
host_launch_op.body().push_back(new Block());
builder->setInsertionPointToEnd(&host_launch_op.GetBody());
auto terminator = builder->create<tf_device::ReturnOp>(
host_launch_op.getLoc(), output_values);
if (new_cluster_results.size() == cluster.getNumResults()) return;
// Move all outside compile ops from cluster op to launch op.
for (auto outside_compiled_op : head_outside_compiled_ops)
outside_compiled_op->moveBefore(terminator);
builder->setInsertionPoint(cluster);
auto new_cluster = builder->create<tf_device::ClusterOp>(
cluster.getLoc(), new_cluster_result_types,
/*operands=*/llvm::ArrayRef<Value>{}, cluster.getAttrs());
new_cluster.body().takeBody(cluster.body());
new_cluster.GetBody().getTerminator()->setOperands(new_cluster_results);
return host_launch_op;
for (auto result :
llvm::zip(used_old_cluster_results, new_cluster.getResults()))
std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
cluster.erase();
}
struct TPUExtractHeadTailOutsideCompilation
@ -202,17 +247,25 @@ void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
return signalPassFailure();
OpBuilder builder(&getContext());
module.walk([&](tf_device::ClusterOp cluster) {
auto head_outside_compiled_ops = IdentifyOutsideCompiledOpsAtHead(cluster);
IsolateHeadExtractedOpsToLaunchOp(&builder, cluster,
head_outside_compiled_ops);
llvm::SmallVector<tf_device::ClusterOp, 4> clusters;
module.walk(
[&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });
// TODO(b/156030523): Update device attribute of newly created host launch
// op as well as enclosing Replicate op (if TPU computation is replicated)
// with host device names.
for (tf_device::ClusterOp cluster : clusters) {
llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
FindOutsideCompiledOpsAtHead(cluster);
if (head_outside_compiled_ops.empty()) continue;
std::string host_device;
if (failed(GetHostDeviceForHeadTailComputation(devices, cluster,
&host_device)))
return signalPassFailure();
// TODO(b/155115766): Implement tail outside compiled op extraction.
});
tf_device::LaunchOp head_computation = CreateHeadComputation(
&builder, cluster, head_outside_compiled_ops, host_device);
RemoveHeadComputationAliasedOutputs(&builder, head_computation, cluster);
// TODO(b/157160906): Implement tail outside compiled op extraction.
}
}
} // anonymous namespace

Some files were not shown because too many files have changed in this diff Show More