Merge pull request #5 from tensorflow/master
Sync master with tensorflow
This commit is contained in:
commit
d221e82eb4
16
.bazelrc
16
.bazelrc
@ -143,6 +143,11 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl --define=build_with_mkl_dnn_v1_only=true
|
||||
build:mkl -c opt
|
||||
|
||||
# config to build OneDNN backend with a user specified threadpool.
|
||||
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
|
||||
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
|
||||
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
|
||||
build:mkl_threadpool -c opt
|
||||
# This config refers to building with CUDA available. It does not necessarily
|
||||
# mean that we build CUDA op kernels.
|
||||
build:using_cuda --define=using_cuda=true
|
||||
@ -235,10 +240,15 @@ build:c++17 --cxxopt=-std=c++1z
|
||||
build:c++17 --cxxopt=-stdlib=libc++
|
||||
build:c++1z --config=c++17
|
||||
|
||||
# Enable using platform specific build settings
|
||||
# Enable using platform specific build settings, except when cross-compiling for
|
||||
# mobile platforms.
|
||||
build --enable_platform_specific_config
|
||||
build:android --noenable_platform_specific_config
|
||||
build:ios --noenable_platform_specific_config
|
||||
|
||||
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
|
||||
build:android --copt=-w
|
||||
build:ios --copt=-w
|
||||
build:linux --copt=-w
|
||||
build:macos --copt=-w
|
||||
build:windows --copt=/w
|
||||
@ -258,6 +268,10 @@ build:macos --define=INCLUDEDIR=$(PREFIX)/include
|
||||
# TF_SYSTEM_LIBS do not work on windows.
|
||||
|
||||
# By default, build TF in C++ 14 mode.
|
||||
build:android --cxxopt=-std=c++14
|
||||
build:android --host_cxxopt=-std=c++14
|
||||
build:ios --cxxopt=-std=c++14
|
||||
build:ios --host_cxxopt=-std=c++14
|
||||
build:linux --cxxopt=-std=c++14
|
||||
build:linux --host_cxxopt=-std=c++14
|
||||
build:macos --cxxopt=-std=c++14
|
||||
|
87
.github/bot_config.yml
vendored
Normal file
87
.github/bot_config.yml
vendored
Normal file
@ -0,0 +1,87 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
#
|
||||
# THIS IS A GENERATED DOCKERFILE.
|
||||
#
|
||||
# This file was assembled from multiple pieces, whose use is documented
|
||||
# throughout. Please refer to the TensorFlow dockerfiles documentation
|
||||
# for more information.
|
||||
|
||||
# A list of assignees
|
||||
assignees:
|
||||
- amahendrakar
|
||||
- ravikyram
|
||||
- Saduf2019
|
||||
# A list of assignees for compiler folder
|
||||
compiler_assignees:
|
||||
- joker-eph
|
||||
# Cuda Comment
|
||||
cuda_comment: >
|
||||
From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:
|
||||
* For TF-GPU - See point 1
|
||||
* For TF-CPU - See point 2
|
||||
-----------------------------------------------------------------------------------------------
|
||||
|
||||
**1. Installing **TensorFlow-GPU** (TF) prebuilt binaries**
|
||||
|
||||
|
||||
Make sure you are using compatible TF and CUDA versions.
|
||||
Please refer following TF version and CUDA version compatibility table.
|
||||
|
||||
| TF | CUDA |
|
||||
|
||||
| :-------------: | :-------------: |
|
||||
|
||||
| 2.1.0 - 2.2.0 | 10.1 |
|
||||
|
||||
| 1.13.1 - 2.0 | 10.0 |
|
||||
|
||||
| 1.5.0 - 1.12.0 | 9.0 |
|
||||
|
||||
* If you have above configuration and using _**Windows**_ platform -
|
||||
* Try adding the CUDA, CUPTI, and cuDNN installation directories to the %PATH% environment variable.
|
||||
* Refer [windows setup guide](https://www.tensorflow.org/install/gpu#windows_setup).
|
||||
* If you have above configuration and using _**Ubuntu/Linux**_ platform -
|
||||
* Try adding the CUDA, CUPTI, and cuDNN installation directories to the $LD_LIBRARY_PATH environment variable.
|
||||
* Refer [linux setup guide](https://www.tensorflow.org/install/gpu#linux_setup).
|
||||
* If error still persists then, apparently your CPU model does not support AVX instruction sets.
|
||||
* Refer [hardware requirements](https://www.tensorflow.org/install/pip#hardware-requirements).
|
||||
|
||||
-----------------------------------------------------------------------------------------------
|
||||
|
||||
**2. Installing **TensorFlow** (TF) CPU prebuilt binaries**
|
||||
|
||||
|
||||
*TensorFlow release binaries version 1.6 and higher are prebuilt with AVX instruction sets.*
|
||||
|
||||
|
||||
Therefore on any CPU that does not have these instruction sets, either CPU or GPU version of TF will fail to load.
|
||||
|
||||
Apparently, your CPU model does not support AVX instruction sets. You can still use TensorFlow with the alternatives given below:
|
||||
|
||||
* Try Google Colab to use TensorFlow.
|
||||
* The easiest way to use TF will be to switch to [google colab](https://colab.sandbox.google.com/notebooks/welcome.ipynb#recent=true). You get pre-installed latest stable TF version. Also you can use ```pip install``` to install any other preferred TF version.
|
||||
* It has an added advantage since you can you easily switch to different hardware accelerators (cpu, gpu, tpu) as per the task.
|
||||
* All you need is a good internet connection and you are all set.
|
||||
* Try to build TF from sources by changing CPU optimization flags.
|
||||
|
||||
*Please let us know if this helps.*
|
||||
|
||||
windows_comment: >
|
||||
From the stack trace it looks like you are hitting windows path length limit.
|
||||
* Try to disable path length limit on Windows 10.
|
||||
* Refer [disable path length limit instructions guide.](https://mspoweruser.com/ntfs-260-character-windows-10/)
|
||||
|
||||
Please let us know if this helps.
|
@ -142,6 +142,7 @@ Build Type | Status
|
||||
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
|
||||
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
|
||||
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
|
||||
* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)
|
||||
* [TensorFlow Blog](https://blog.tensorflow.org)
|
||||
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
|
||||
* [TensorFlow Twitter](https://twitter.com/tensorflow)
|
||||
|
38
RELEASE.md
38
RELEASE.md
@ -1,3 +1,41 @@
|
||||
# Release 2.3.0
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
* `tf.image.extract_glimpse` has been updated to correctly process the case
|
||||
where `centered=False` and `normalized=False`. This is a breaking change as
|
||||
the output is different from (incorrect) previous versions. Note this
|
||||
breaking change only impacts `tf.image.extract_glimpse` and
|
||||
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
|
||||
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
|
||||
exsiting C++ kernel `ExtractGlimpse` does not change as well, so saved
|
||||
models will not be impacted.
|
||||
|
||||
# Release 2.1.1
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
|
||||
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
|
||||
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
|
||||
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
|
||||
* Fixes a versioning bug which causes Keras layers from TF 1.x to be used instead of those from TF 2.x
|
||||
|
||||
# Release 2.0.2
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
|
||||
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
|
||||
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
|
||||
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
|
||||
|
||||
# Release 1.15.3
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
|
||||
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
|
||||
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
|
||||
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
|
||||
|
||||
# Release 2.2.0
|
||||
|
||||
TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update).
|
||||
|
@ -64,7 +64,7 @@ your model, and we recommend you run the TensorFlow process in a sandbox.
|
||||
|
||||
It is possible to write models that are secure in a sense that they can safely
|
||||
process untrusted inputs assuming there are no bugs. There are two main reasons
|
||||
to not rely on this: first, it is easy to write models which must not be exposed
|
||||
to not rely on this: First, it is easy to write models which must not be exposed
|
||||
to untrusted inputs, and second, there are bugs in any software system of
|
||||
sufficient complexity. Letting users control inputs could allow them to trigger
|
||||
bugs either in TensorFlow or in dependent libraries.
|
||||
@ -149,7 +149,7 @@ attack (or worse). Because TensorFlow behaves correctly, this is not a
|
||||
vulnerability in TensorFlow (although it would be a vulnerability of this
|
||||
hypothetical system).
|
||||
|
||||
As a general rule, it is incorrect behavior for Tensorflow to access memory it
|
||||
As a general rule, it is incorrect behavior for TensorFlow to access memory it
|
||||
does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to
|
||||
such behaviors constitute a vulnerability.
|
||||
|
||||
|
@ -114,6 +114,14 @@ http_archive(
|
||||
],
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "person_detect_data",
|
||||
sha256 = "170542270da256994ce24d1e357f6e84a54fdaf7d28ff2b74725a40b70b082cf",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/download.tensorflow.org/data/tf_lite_micro_person_data_grayscale_2020_05_24.zip",
|
||||
],
|
||||
)
|
||||
|
||||
# Required for dependency @com_github_grpc_grpc
|
||||
|
||||
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
|
||||
|
10
configure.py
10
configure.py
@ -1368,8 +1368,13 @@ def main():
|
||||
# environment variables.
|
||||
environ_cp = dict(os.environ)
|
||||
|
||||
current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION,
|
||||
_TF_MAX_BAZEL_VERSION)
|
||||
try:
|
||||
current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION,
|
||||
_TF_MAX_BAZEL_VERSION)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("Error checking bazel version: ", e.output.decode('UTF-8').strip())
|
||||
raise e
|
||||
|
||||
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
|
||||
|
||||
reset_tf_configure_bazelrc()
|
||||
@ -1387,7 +1392,6 @@ def main():
|
||||
# Windows.
|
||||
environ_cp['TF_DOWNLOAD_CLANG'] = '0'
|
||||
environ_cp['TF_NEED_MPI'] = '0'
|
||||
environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0'
|
||||
|
||||
if is_macos():
|
||||
environ_cp['TF_NEED_TENSORRT'] = '0'
|
||||
|
@ -524,7 +524,10 @@ package_group(
|
||||
],
|
||||
)
|
||||
|
||||
package_group(name = "ndarray_tensor_allow_list")
|
||||
package_group(
|
||||
name = "ndarray_tensor_allow_list",
|
||||
packages = ["//learning/pathways/..."],
|
||||
)
|
||||
|
||||
# Packages that use composite tensors or dispatch.
|
||||
# TODO(b/154762408) Remove this package group once it's no longer needed.
|
||||
|
@ -216,6 +216,7 @@ tf_cuda_library(
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/c:__subpackages__",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:__subpackages__",
|
||||
],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
@ -394,8 +395,14 @@ tf_cuda_library(
|
||||
deps = [
|
||||
":tf_status",
|
||||
":tf_status_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
|
@ -325,205 +325,6 @@ TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
|
||||
TF_Status* status) {
|
||||
auto* opts = TFE_NewContextOptions();
|
||||
|
||||
// Reduce GPU memory allocation, and set appropriate config options for TFE
|
||||
// context.
|
||||
auto* config = TF_CreateConfig(
|
||||
/*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
|
||||
10);
|
||||
TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
|
||||
if (!status->status.ok()) {
|
||||
CHECK(!config);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* ctx = TFE_NewContextFromSession(opts, session, status);
|
||||
TF_DeleteBuffer(config);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// TODO: retrieve the device string via TFE_ContextListDevices()
|
||||
static const char DEFAULT_CPU_DEVICE[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
|
||||
static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType,
|
||||
int tensor_id, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp(
|
||||
TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp);
|
||||
TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
// TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler.
|
||||
TFE_OpSetAttrInt(queueOp.get(), "capacity", 1);
|
||||
TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1);
|
||||
auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id);
|
||||
TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(),
|
||||
shared_name.size());
|
||||
TFE_OpSetAttrString(queueOp.get(), "container", "", 0);
|
||||
|
||||
// TODO: consider making this an unknown shape.
|
||||
const int64_t* dims_ptr = nullptr;
|
||||
int num_dims = 0;
|
||||
TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims,
|
||||
/*num_values*/ 0, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
int num_retvals = 1;
|
||||
TFE_TensorHandle* queue = nullptr;
|
||||
TFE_Execute(queueOp.get(), &queue, &num_retvals, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
CHECK_EQ(num_retvals, 1);
|
||||
|
||||
return queue;
|
||||
}
|
||||
|
||||
static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType,
|
||||
TFE_TensorHandle* queue, TFE_TensorHandle* tensor,
|
||||
TF_Status* status) {
|
||||
TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
|
||||
TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
|
||||
if (!status->status.ok()) return;
|
||||
TFE_OpAddInput(op, queue, status);
|
||||
if (!status->status.ok()) return;
|
||||
TFE_OpAddInput(op, tensor, status);
|
||||
if (!status->status.ok()) return;
|
||||
TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1);
|
||||
TFE_OpSetAttrInt(op, "timeout_ms", -1);
|
||||
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status);
|
||||
if (!status->status.ok()) return;
|
||||
CHECK_EQ(num_retvals, 0);
|
||||
}
|
||||
|
||||
static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx,
|
||||
TF_DataType inputType,
|
||||
TFE_TensorHandle* queue,
|
||||
TF_Status* status) {
|
||||
TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
|
||||
TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
TFE_OpAddInput(op, queue, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1);
|
||||
TFE_OpSetAttrInt(op, "timeout_ms", -1);
|
||||
TFE_TensorHandle* ret;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &ret, &num_retvals, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
CHECK_EQ(num_retvals, 1);
|
||||
return ret;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id,
|
||||
TF_DataType inputType,
|
||||
TF_Status* status) {
|
||||
assert(session);
|
||||
VLOG(1) << "Dequeuing data tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
auto* ret = createTFEDequeue(ctx, inputType, queue, status);
|
||||
return ret;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
|
||||
TF_DataType inputType,
|
||||
TF_Status* status) {
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
auto* ret = createTFEDequeue(ctx, inputType, queue, status);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id,
|
||||
TFE_TensorHandle* tensor, TF_Status* status) {
|
||||
assert(session);
|
||||
VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TF_DataType inputType = TFE_TensorHandleDataType(tensor);
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
createTFEEnqueue(ctx, inputType, queue, tensor, status);
|
||||
}
|
||||
|
||||
void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status) {
|
||||
VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
|
||||
|
||||
TF_DataType inputType = TFE_TensorHandleDataType(tensor);
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
createTFEEnqueue(ctx, inputType, queue, tensor, status);
|
||||
}
|
||||
|
||||
void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id,
|
||||
TFE_TensorHandle* tensor, TF_Status* status) {
|
||||
VLOG(1) << "Enqueuing variant tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
|
||||
TF_Status* status) {
|
||||
VLOG(1) << "Dequeuing variant tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
|
||||
}
|
||||
|
||||
void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
|
||||
status->status = tensorflow::errors::Internal(errMsg);
|
||||
}
|
||||
@ -622,10 +423,9 @@ void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
|
||||
void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
|
||||
const TF_DataType* values, int num_values) {
|
||||
auto iter = builder->attr_names.insert(attr_name).first;
|
||||
builder->Set(
|
||||
(*iter).c_str(),
|
||||
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
|
||||
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
|
||||
builder->Set(*iter, tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
|
||||
reinterpret_cast<const tensorflow::DataType*>(values),
|
||||
num_values));
|
||||
}
|
||||
|
||||
void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
|
||||
|
@ -146,48 +146,6 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
|
||||
// Create a serialized tensorflow.ServerDef proto.
|
||||
TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status);
|
||||
|
||||
// TODO: remove this API in favor of the next one.
|
||||
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
|
||||
const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status);
|
||||
|
||||
// Creates from `session` a new eager context to run a graph function or
|
||||
// sends/recvs, so that these concurrent TFE executions can share (via
|
||||
// `session` and its associated device mgr) the same set of fifo queue resource
|
||||
// ops, used for host<->TF tensor transfers. This way the sends/recvs calls and
|
||||
// graph function execution can access the same fifo queue resource handles
|
||||
// (associated with devices managed by the device manager, which can be obtained
|
||||
// from `session`).
|
||||
//
|
||||
// TODO: Remove this function once we migrate away from using session.
|
||||
TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession(
|
||||
TF_Session* session, TF_Status* status);
|
||||
|
||||
// TODO: Retire this API in favor of the next one.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor(
|
||||
TF_Session* session, int tensor_id, TF_DataType inputType,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(
|
||||
TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session,
|
||||
int tensor_id,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx(
|
||||
TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
// TODO: consider folding the 2 APIs below into the ones above.
|
||||
TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
|
||||
int tensor_id,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
|
||||
TF_Session* session, int tensor_id, TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
|
||||
const char* errMsg);
|
||||
|
||||
|
@ -144,6 +144,24 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "c_api_unified_internal",
|
||||
hdrs = [
|
||||
"c_api_unified_experimental_internal.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/core/platform:casts",
|
||||
"//tensorflow/core/platform:types",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_handle_interface",
|
||||
hdrs = ["tensor_handle_interface.h"],
|
||||
@ -319,6 +337,7 @@ tf_cuda_cc_test(
|
||||
tags = [
|
||||
"noguitar", # TODO(b/155445984): flaky
|
||||
#"guitar",
|
||||
"notap", # TODO(b/156981931): flaky
|
||||
"multi_gpu",
|
||||
],
|
||||
deps = [
|
||||
@ -349,7 +368,10 @@ tf_cuda_cc_test(
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = ["noasan"], # leaks gRPC server instances
|
||||
tags = [
|
||||
"noasan", # leaks gRPC server instances
|
||||
"notsan", # b/157098283
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
@ -357,10 +379,13 @@ tf_cuda_cc_test(
|
||||
":c_api_test_util",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime:function_optimization_registry",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -507,6 +532,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
|
@ -102,6 +102,15 @@ string DeviceName(const tensorflow::Device* d) {
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
|
||||
const tensorflow::ServerDef& server_def) {
|
||||
if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
|
||||
return false;
|
||||
}
|
||||
return server_def.default_session_config().SerializeAsString() ==
|
||||
context->session_options().config.SerializeAsString();
|
||||
}
|
||||
|
||||
tensorflow::Status AddRemoteDevicesToMgr(
|
||||
const std::vector<string>& added_remote_workers,
|
||||
tensorflow::WorkerCacheInterface* worker_cache,
|
||||
@ -469,10 +478,15 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server;
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
|
||||
const tensorflow::DeviceMgr* device_mgr =
|
||||
AreLocalDevicesCompatible(context, server_def)
|
||||
? context->local_device_mgr()
|
||||
: nullptr;
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
|
||||
server_def, {device_mgr}, &new_server));
|
||||
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
|
||||
ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
|
||||
} else {
|
||||
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
|
||||
&curr_remote_workers));
|
||||
@ -727,24 +741,6 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
tensorflow::GetDefaultCustomKernelCreator()));
|
||||
}
|
||||
|
||||
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
||||
TF_Session* sess, TF_Status* status) {
|
||||
const tensorflow::DeviceMgr* device_mgr = nullptr;
|
||||
status->status = sess->session->LocalDeviceManager(&device_mgr);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr);
|
||||
|
||||
return tensorflow::wrap(new tensorflow::EagerContext(
|
||||
opts->session_options.options,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
opts->device_placement_policy),
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy),
|
||||
opts->async, opts->lazy_remote_inputs_copy, device_mgr,
|
||||
/*device_mgr_owned*/ false, r,
|
||||
tensorflow::GetDefaultCustomKernelCreator()));
|
||||
}
|
||||
|
||||
void TFE_DeleteContext(TFE_Context* ctx) {
|
||||
if (ctx == nullptr) {
|
||||
return;
|
||||
@ -899,9 +895,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||
#if defined(IS_MOBILE_PLATFORM)
|
||||
status->status = tensorflow::Status::OK();
|
||||
#else // !defined(IS_MOBILE_PLATFORM)
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->SyncExecutors();
|
||||
status->status = tensorflow::unwrap(ctx)->AsyncWait();
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
|
@ -41,7 +41,7 @@ tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
|
||||
for (int i = 0; i < num_tasks; i++) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{i, tensorflow::strings::StrCat("localhost:", port)});
|
||||
{i, tensorflow::strings::StrCat("localhost", ":", port)});
|
||||
}
|
||||
return server_def;
|
||||
}
|
||||
@ -430,4 +430,70 @@ TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
|
||||
TestRemoteExecuteUpdateServerDefWithFailures(true);
|
||||
}
|
||||
|
||||
void TestConnectToCluster(bool keep_localhost_for_first_connect) {
|
||||
// Fail fast on GetStatus requests so we can get errors instead of timeout
|
||||
// when updating cluster with non-exsitent worker
|
||||
tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
|
||||
|
||||
const string first_name =
|
||||
keep_localhost_for_first_connect ? "localhost" : "abc";
|
||||
tensorflow::ServerDef server_def = GetServerDef(first_name, 1);
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
const string dev0_name = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
|
||||
EXPECT_NE(var_handle0, nullptr);
|
||||
|
||||
tensorflow::Status status2;
|
||||
EXPECT_EQ(tensorflow::unwrap(var_handle0)->DeviceName(&status2), dev0_name);
|
||||
|
||||
// Rename local device
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
const string dev1_name =
|
||||
absl::StrCat("/job:", first_name, "/replica:0/task:0/device:CPU:0");
|
||||
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
|
||||
EXPECT_NE(var_handle1, nullptr);
|
||||
EXPECT_EQ(tensorflow::unwrap(var_handle1)->DeviceName(&status2), dev1_name);
|
||||
|
||||
// Another renaming of local device
|
||||
const string second_name = "def";
|
||||
server_def.set_job_name(second_name);
|
||||
server_def.mutable_cluster()->mutable_job(0)->set_name(second_name);
|
||||
(*server_def.mutable_cluster()->mutable_job(0)->mutable_tasks())[0] =
|
||||
absl::StrCat(second_name, ":",
|
||||
tensorflow::testing::PickUnusedPortOrDie());
|
||||
|
||||
serialized = server_def.SerializeAsString();
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
const string dev2_name = "/job:def/replica:0/task:0/device:CPU:0";
|
||||
TFE_TensorHandle* var_handle2 = TestVariable(ctx, 2.0, dev2_name);
|
||||
EXPECT_NE(var_handle2, nullptr);
|
||||
EXPECT_EQ(tensorflow::unwrap(var_handle2)->DeviceName(&status2), dev2_name);
|
||||
|
||||
TFE_DeleteTensorHandle(var_handle0);
|
||||
TFE_DeleteTensorHandle(var_handle1);
|
||||
TFE_DeleteTensorHandle(var_handle2);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
tensorflow::unsetenv("GRPC_FAIL_FAST");
|
||||
}
|
||||
|
||||
TEST(CAPI, ConnectToClusterLocalhostFirst) { TestConnectToCluster(false); }
|
||||
|
||||
TEST(CAPI, ConnectToClusterRenameFirst) { TestConnectToCluster(true); }
|
||||
|
||||
} // namespace
|
||||
|
@ -19,11 +19,16 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
namespace {
|
||||
@ -574,6 +579,181 @@ TEST(CAPI, TestRemoteFunctionWithPackedInput) {
|
||||
TestFunctionWithPackedInput(/*remote=*/true);
|
||||
}
|
||||
|
||||
string VariableAddFunction() {
|
||||
tensorflow::FunctionDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
" signature {"
|
||||
" name: 'VariableAddFunction'"
|
||||
" input_arg {"
|
||||
" name: 'var0'"
|
||||
" type: DT_RESOURCE"
|
||||
" }"
|
||||
" output_arg {"
|
||||
" name: 'var0_value'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'read0'"
|
||||
" op: 'ReadVariableOp'"
|
||||
" input: 'var0'"
|
||||
" attr {"
|
||||
" key: 'dtype'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'add'"
|
||||
" op: 'Add'"
|
||||
" input: 'read0:value:0'"
|
||||
" input: 'read0:value:0'"
|
||||
" device: '/job:localhost/task:1/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'identity'"
|
||||
" op: 'Identity'"
|
||||
" input: 'add:z:0'"
|
||||
" device: '/job:localhost/task:0/device:CPU:0'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" ret {"
|
||||
" key: 'var0_value'"
|
||||
" value: 'identity:output:0'"
|
||||
" }",
|
||||
&def));
|
||||
return def.SerializeAsString();
|
||||
}
|
||||
|
||||
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
|
||||
public:
|
||||
FunctionErrorInjectionPass(string error_node, string error_device)
|
||||
: error_node_(error_node), error_device_(error_device) {}
|
||||
tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
|
||||
const tensorflow::ConfigProto& config_proto,
|
||||
std::unique_ptr<tensorflow::Graph>* graph,
|
||||
tensorflow::FunctionLibraryDefinition* flib_def,
|
||||
std::vector<std::string>* control_ret_node_names,
|
||||
bool* control_rets_updated) override {
|
||||
// Inject failure to function instantiation if finding a node that contains
|
||||
// the given node name (error_node_) and requested device (error_device_).
|
||||
for (const auto node : graph->get()->nodes()) {
|
||||
if (node->name().find(error_node_) != string::npos &&
|
||||
node->requested_device() == error_device_) {
|
||||
return tensorflow::errors::Internal("Injected graph pass error.");
|
||||
}
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
const string error_node_;
|
||||
const string error_device_;
|
||||
};
|
||||
|
||||
void TestDistributedFunctionCancellation(bool inject_error) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
if (inject_error) {
|
||||
// Inject a function optimization pass failure when it sees the 'read0' op
|
||||
// having a requested device `dev2_name`. During execution:
|
||||
// * task:0 processes the main function `VariableAddFunction` and places
|
||||
// the read0 op on task:2
|
||||
// * task:0 partitions the main function with a subgraph containing read0
|
||||
// sent to task:2
|
||||
// * task:2 graph pass reports an error when it sees read0 with dev2_name
|
||||
tensorflow::function_optimization_registration::
|
||||
FunctionOptimizationPassRegistration register_test_pass(
|
||||
std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name));
|
||||
}
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
|
||||
EXPECT_NE(var_handle, nullptr);
|
||||
|
||||
const string function_def = VariableAddFunction();
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(func, var_handle, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(func, &retvals[0], &num_retvals, status);
|
||||
|
||||
if (inject_error) {
|
||||
ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
|
||||
} else {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(1, num_retvals);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
float sum = 0;
|
||||
ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
|
||||
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
ASSERT_EQ(sum, 4.0);
|
||||
}
|
||||
|
||||
TFE_DeleteOp(func);
|
||||
TFE_DeleteTensorHandle(var_handle);
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, DistributedFunctionNoError) {
|
||||
TestDistributedFunctionCancellation(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, DistributedFunctionCancelledOnError) {
|
||||
TestDistributedFunctionCancellation(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
|
@ -58,7 +58,7 @@ T* dyncast(S source) {
|
||||
// GraphContext and vice-versa).
|
||||
class AbstractTensor {
|
||||
protected:
|
||||
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
|
||||
enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor };
|
||||
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
@ -101,7 +101,7 @@ class AbstractFunction {
|
||||
// on a given context, with the same or different input tensors.
|
||||
class AbstractOp {
|
||||
protected:
|
||||
enum AbstractOpKind { kGraphOp, kEagerOp };
|
||||
enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp };
|
||||
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
@ -129,7 +129,7 @@ class AbstractOp {
|
||||
// eager implementation or to a graph implementation.
|
||||
struct ExecutionContext {
|
||||
protected:
|
||||
enum ExecutionContextKind { kGraphContext, kEagerContext };
|
||||
enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext };
|
||||
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
|
||||
|
||||
public:
|
||||
|
@ -477,7 +477,8 @@ TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef"));
|
||||
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
|
||||
::testing::Values("graphdef", "mlir"));
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -101,6 +101,9 @@ class AbstractContextInterface {
|
||||
// Destroy the step resource container for a training step.
|
||||
virtual void EndStep() = 0;
|
||||
|
||||
// Block until all pending nodes are finished.
|
||||
virtual Status AsyncWait() = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractContextInterface() {}
|
||||
};
|
||||
|
@ -108,7 +108,7 @@ class CServerFactory : public ServerFactory {
|
||||
delete_function_(delete_function),
|
||||
rendezvous_builder_(rendezvous_builder) {}
|
||||
|
||||
Status NewServer(const ServerDef& server_def,
|
||||
Status NewServer(const ServerDef& server_def, const Options& options,
|
||||
std::unique_ptr<ServerInterface>* out_server) override {
|
||||
TF_RETURN_IF_ERROR(CGrpcServer::Create(
|
||||
server_def, init_function_, start_function_, stop_function_,
|
||||
|
@ -155,6 +155,7 @@ cc_library(
|
||||
"saved_model_api_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
],
|
||||
)
|
||||
|
@ -41,7 +41,7 @@ TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TF_SavedModel{std::move(result)};
|
||||
return tensorflow::wrap(result.release());
|
||||
}
|
||||
|
||||
TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
|
||||
@ -60,17 +60,19 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TF_SavedModel{std::move(result)};
|
||||
return tensorflow::wrap(result.release());
|
||||
}
|
||||
|
||||
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }
|
||||
void TF_DeleteSavedModel(TF_SavedModel* model) {
|
||||
delete tensorflow::unwrap(model);
|
||||
}
|
||||
|
||||
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
|
||||
const char* function_path,
|
||||
TF_Status* status) {
|
||||
tensorflow::ConcreteFunction* result = nullptr;
|
||||
tensorflow::Status get_function_status =
|
||||
model->saved_model->GetFunction(function_path, &result);
|
||||
tensorflow::unwrap(model)->GetFunction(function_path, &result);
|
||||
status->status.Update(get_function_status);
|
||||
if (!get_function_status.ok()) {
|
||||
return nullptr;
|
||||
@ -82,7 +84,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
|
||||
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
|
||||
tensorflow::ConcreteFunction* result = nullptr;
|
||||
tensorflow::Status get_function_status =
|
||||
model->saved_model->GetSignatureDefFunction(signature_def_key, &result);
|
||||
tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key,
|
||||
&result);
|
||||
status->status.Update(get_function_status);
|
||||
if (!get_function_status.ok()) {
|
||||
return nullptr;
|
||||
@ -91,7 +94,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
|
||||
}
|
||||
|
||||
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
|
||||
return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()};
|
||||
return new TF_ConcreteFunctionList{
|
||||
tensorflow::unwrap(model)->ListFunctions()};
|
||||
}
|
||||
|
||||
} // end extern "C"
|
||||
|
@ -18,13 +18,18 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to change
|
||||
// and should not be depended on.
|
||||
|
||||
struct TF_SavedModel {
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
|
||||
};
|
||||
typedef struct TF_SavedModel TF_SavedModel;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::SavedModelAPI, TF_SavedModel)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
|
||||
|
@ -20,7 +20,7 @@ load(
|
||||
"tf_cc_test",
|
||||
"tf_copts",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "tfcompile_extra_flags")
|
||||
load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
|
||||
|
||||
def tf_library(
|
||||
name,
|
||||
@ -42,7 +42,8 @@ def tf_library(
|
||||
mlir_components = "None",
|
||||
deps = None,
|
||||
tags = []):
|
||||
"""Runs tfcompile to compile a TensorFlow graph into executable code.
|
||||
"""Runs tfcompile to compile a TensorFlow graph into executable code with fast
|
||||
math enabled on cpu.
|
||||
|
||||
Given an invocation of tf_library(name="foo", ...), generates the following
|
||||
build targets:
|
||||
@ -187,7 +188,9 @@ def tf_library(
|
||||
# `find` on such an object.
|
||||
need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1
|
||||
|
||||
flags = tfcompile_extra_flags() + flags
|
||||
target_cpu = tfcompile_target_cpu()
|
||||
extra_flags = "--target_cpu=" + target_cpu + " " if target_cpu else " "
|
||||
flags = extra_flags + flags
|
||||
|
||||
if enable_xla_hlo_profiling:
|
||||
profiling_flag = "--xla_hlo_profile"
|
||||
@ -207,6 +210,15 @@ def tf_library(
|
||||
srcs.append(debug_info)
|
||||
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
|
||||
|
||||
default_fast_math_xla_flags = ("XLA_FLAGS='" +
|
||||
"--xla_cpu_enable_fast_math=true " +
|
||||
"--xla_cpu_fast_math_honor_nans=false " +
|
||||
"--xla_cpu_fast_math_honor_infs=false " +
|
||||
"--xla_cpu_fast_math_honor_functions=false " +
|
||||
"--xla_cpu_fast_math_honor_division=false " +
|
||||
"--xla_cpu_enable_fast_min_max=true " +
|
||||
"$${XLA_FLAGS:-}' ")
|
||||
|
||||
native.genrule(
|
||||
name = ("gen_" + name),
|
||||
srcs = srcs,
|
||||
@ -216,6 +228,7 @@ def tf_library(
|
||||
function_object_file,
|
||||
],
|
||||
cmd = (
|
||||
default_fast_math_xla_flags +
|
||||
"CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
@ -256,6 +269,7 @@ def tf_library(
|
||||
session_module_pb,
|
||||
],
|
||||
cmd = (
|
||||
default_fast_math_xla_flags +
|
||||
"CUDA_VISIBLE_DEVICES='' " +
|
||||
"$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
|
@ -67,6 +67,8 @@ int main(int argc, char** argv) {
|
||||
flags.entry_point = "entry";
|
||||
flags.debug_info_path_begin_marker = "";
|
||||
|
||||
// Note that tfcompile.bzl's tf_library macro sets fast math flags as that is
|
||||
// generally the preferred case.
|
||||
std::vector<tensorflow::Flag> flag_list;
|
||||
AppendMainFlags(&flag_list, &flags);
|
||||
xla::AppendDebugOptionsFlags(&flag_list);
|
||||
|
@ -505,6 +505,7 @@ cc_library(
|
||||
name = "shape_inference",
|
||||
srcs = ["shape_inference.cc"],
|
||||
hdrs = ["shape_inference.h"],
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
":shape_inference_helpers",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
|
@ -2034,6 +2034,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"TensorArraySplitV3",
|
||||
"TensorArrayV3",
|
||||
"TensorArrayWriteV3",
|
||||
"TensorListConcatV2",
|
||||
"TensorListElementShape",
|
||||
"TensorListFromTensor",
|
||||
"TensorListGather",
|
||||
@ -2043,6 +2044,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
"TensorListPushBack",
|
||||
"TensorListReserve",
|
||||
"TensorListSetItem",
|
||||
"TensorListSplit",
|
||||
"TensorListStack",
|
||||
"TensorScatterAdd",
|
||||
"TensorScatterSub",
|
||||
|
@ -104,6 +104,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Shape",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -260,6 +260,41 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tftext_utils",
|
||||
srcs = [
|
||||
"utils/tftext_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"utils/tftext_utils.h",
|
||||
],
|
||||
copts = ["-std=c++14"],
|
||||
deps = [
|
||||
":tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "tftext_utils_test",
|
||||
size = "small",
|
||||
srcs = ["utils/lstm_utils_test.cc"],
|
||||
deps = [
|
||||
":lstm_utils",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "stateful_ops_utils",
|
||||
srcs = [
|
||||
@ -320,6 +355,7 @@ cc_library(
|
||||
":lstm_utils",
|
||||
":stateful_ops_utils",
|
||||
":tensorflow_lite",
|
||||
":tftext_utils",
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
||||
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
|
||||
|
@ -799,11 +799,6 @@ Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
|
||||
|
||||
Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions(
|
||||
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
|
||||
std::string node_def_str;
|
||||
if (!node_def.SerializeToString(&node_def_str)) {
|
||||
return emitError(loc, "failed to serialize tensorflow node_def"),
|
||||
llvm::None;
|
||||
}
|
||||
auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc);
|
||||
return builder_.CreateVector(flex_builder->GetBuffer());
|
||||
}
|
||||
@ -813,9 +808,13 @@ Translator::CreateFlexBuilderWithNodeAttrs(
|
||||
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
|
||||
auto flex_builder = absl::make_unique<flexbuffers::Builder>();
|
||||
size_t map_start = flex_builder->StartMap();
|
||||
for (const auto& pair : node_def.attr()) {
|
||||
using Item = std::pair<std::string, ::tensorflow::AttrValue>;
|
||||
std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end());
|
||||
std::sort(attrs.begin(), attrs.end(),
|
||||
[](Item& p1, Item& p2) -> bool { return p1.first < p2.first; });
|
||||
for (const Item& pair : attrs) {
|
||||
const char* key = pair.first.c_str();
|
||||
const auto& attr = pair.second;
|
||||
const ::tensorflow::AttrValue& attr = pair.second;
|
||||
switch (attr.value_case()) {
|
||||
case ::tensorflow::AttrValue::kS:
|
||||
flex_builder->String(key, attr.s());
|
||||
|
@ -27,7 +27,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
@ -254,6 +254,14 @@ class TFL_TFOperandTypesWithSameBits<int i, int j, int num> :
|
||||
Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
|
||||
CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;
|
||||
|
||||
class TFL_OperandIsNoneOrHasRank<int n, int m> :
|
||||
PredOpTrait<"operand " # n # " is " # m # "-D",
|
||||
Or<[
|
||||
CPred<"$_op.getOperand(" # n # ").getType().isa<NoneType>()">,
|
||||
TFL_OperandIsUnrankedPred<n>,
|
||||
CPred<"$_op.getOperand(" # n #
|
||||
").getType().cast<ShapedType>().getRank() == " # m>]>>;
|
||||
|
||||
class TFL_OperandIsNoneOrHasRankAtMost<int n, int m> :
|
||||
PredOpTrait<"operand " # n # " is at most " # m # "-D",
|
||||
Or<[
|
||||
@ -285,14 +293,18 @@ def TFL_FloatNonNegative : AttrConstraint<
|
||||
CPred<"!$_self.cast<FloatAttr>().getValue().isNegative()">,
|
||||
"whose value is non-negative">;
|
||||
|
||||
def TFL_BoolTrue: AttrConstraint<
|
||||
def TFL_BoolTrue : AttrConstraint<
|
||||
CPred<"$_self.cast<BoolAttr>().getValue()">,
|
||||
"whose value is true">;
|
||||
|
||||
def TFL_BoolFalse: AttrConstraint<
|
||||
def TFL_BoolFalse : AttrConstraint<
|
||||
CPred<"!$_self.cast<BoolAttr>().getValue()">,
|
||||
"whose value is false">;
|
||||
|
||||
class TFL_StringEqualsTo<string value> : AttrConstraint<
|
||||
CPred<"$_self.cast<StringAttr>().getValue() == \"" # value # "\"">,
|
||||
"whose value equals to '" # value # "'">;
|
||||
|
||||
// This is a quantization-aware version of TCresVTEtIsSameAsOp
|
||||
class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
|
||||
TCOpResIsShapedTypePred<i, j>,
|
||||
@ -1892,7 +1904,10 @@ def TFL_OneHotOp : TFL_Op<"one_hot", [NoSideEffect]> {
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
def TFL_RoundOp: TFL_Op<"round", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TFL_RoundOp: TFL_Op<"round", [
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultType]> {
|
||||
let summary = "Round operator";
|
||||
|
||||
let description = [{
|
||||
@ -1909,7 +1924,14 @@ Rounds the values of a tensor to the nearest integer, element-wise.
|
||||
}
|
||||
|
||||
def TFL_SliceOp : TFL_Op<"slice", [
|
||||
NoSideEffect, SameOperandsAndResultsScale, TFL_GpuTargetOp]> {
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultsScale,
|
||||
TFL_OperandHasRankAtMost<0, 4>,
|
||||
TFL_OperandHasRankAtMost<1, 1>,
|
||||
TFL_OperandHasRankAtMost<2, 1>,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Return a slice from 'input'.";
|
||||
|
||||
let description = [{
|
||||
@ -1927,13 +1949,13 @@ equivalent to setting:
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyTensor:$input,
|
||||
TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$input,
|
||||
TFL_I32OrI64Tensor:$begin,
|
||||
TFL_I32OrI64Tensor:$size
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
AnyTensor:$output
|
||||
TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$output
|
||||
);
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
@ -1961,7 +1983,10 @@ def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> {
|
||||
}
|
||||
|
||||
def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
|
||||
NoSideEffect, SameOperandsAndResultsScale]> {
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultsScale]> {
|
||||
let summary = "Min-reduction operator";
|
||||
|
||||
let description = [{
|
||||
@ -1969,19 +1994,23 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyTensor:$input,
|
||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
|
||||
TFL_I32Tensor:$axes,
|
||||
BoolAttr:$keep_dims
|
||||
);
|
||||
|
||||
let results = (outs AnyTensor);
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
|
||||
|
||||
let hasOptions = 1;
|
||||
let customOption = "ReducerOptions";
|
||||
}
|
||||
|
||||
def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [
|
||||
NoSideEffect, SameOperandsAndResultsScale]> {
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultsScale]> {
|
||||
let summary = "Max-reduction operator";
|
||||
|
||||
let description = [{
|
||||
@ -1989,18 +2018,22 @@ def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyTensor:$input,
|
||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
|
||||
TFL_I32Tensor:$axes,
|
||||
BoolAttr:$keep_dims
|
||||
);
|
||||
|
||||
let results = (outs AnyTensor);
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
|
||||
|
||||
let hasOptions = 1;
|
||||
let customOption = "ReducerOptions";
|
||||
}
|
||||
|
||||
def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> {
|
||||
def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoSideEffect]> {
|
||||
let summary = "Prod-reduction operator";
|
||||
|
||||
let description = [{
|
||||
@ -2008,12 +2041,13 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> {
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I8, I32, I64]>:$input,
|
||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
|
||||
TFL_I32Tensor:$axes,
|
||||
BoolAttr:$keep_dims
|
||||
);
|
||||
|
||||
let results = (outs AnyTensor);
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output);
|
||||
|
||||
let hasOptions = 1;
|
||||
let customOption = "ReducerOptions";
|
||||
@ -2308,10 +2342,13 @@ def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultsScale,
|
||||
TFL_GpuTargetOp]> {
|
||||
def TFL_ReluOp: TFL_Op<"relu", [
|
||||
PredOpTrait<"x and y must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultsScale,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Relu operator";
|
||||
|
||||
let description = [{
|
||||
@ -2319,9 +2356,9 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
|
||||
x -> max(0, x)
|
||||
}];
|
||||
|
||||
let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);
|
||||
let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);
|
||||
let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y);
|
||||
|
||||
// This builder doesn't work with quantized type, so it can only be used by
|
||||
// non-quantization tablegen patterns. Currently, it is used by the
|
||||
@ -2335,10 +2372,13 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect,
|
||||
];
|
||||
}
|
||||
|
||||
def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultsScale,
|
||||
TFL_GpuTargetOp]> {
|
||||
def TFL_Relu6Op: TFL_Op<"relu6", [
|
||||
PredOpTrait<"x and y must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultsScale,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Relu6 operator";
|
||||
|
||||
let description = [{
|
||||
@ -2346,9 +2386,9 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
|
||||
x -> max(0, min(6, x))
|
||||
}];
|
||||
|
||||
let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);
|
||||
let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);
|
||||
let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y);
|
||||
|
||||
// This builder doesn't work with quantized type, so it can only be used by
|
||||
// non-quantization tablegen patterns. Currently, it is used by the
|
||||
@ -2362,9 +2402,12 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect,
|
||||
];
|
||||
}
|
||||
|
||||
def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultsScale]> {
|
||||
def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [
|
||||
PredOpTrait<"x and y must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultShape,
|
||||
SameOperandsAndResultsScale]> {
|
||||
let summary = "Relu1 operator";
|
||||
|
||||
let description = [{
|
||||
@ -2372,9 +2415,9 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect,
|
||||
x -> max(-1, min(1, x))
|
||||
}];
|
||||
|
||||
let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x);
|
||||
let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y);
|
||||
let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y);
|
||||
|
||||
// This builder doesn't work with quantized type, so it can only be used by
|
||||
// non-quantization tablegen patterns. Currently, it is used by the
|
||||
@ -2406,7 +2449,11 @@ def TFL_ReshapeOp: TFL_Op<"reshape", [
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TFL_ReverseSequenceOp : TFL_Op<"reverse_sequence", [NoSideEffect]> {
|
||||
def TFL_ReverseSequenceOp : TFL_Op<"reverse_sequence", [
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoSideEffect,
|
||||
TFL_OperandHasRank<1, 1>]> {
|
||||
let summary = "Reverses variable length slices.";
|
||||
|
||||
let description = [{
|
||||
@ -2423,15 +2470,15 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8]>:$input,
|
||||
TFL_TensorOf<[F32, I32, I64, QI16, QUI8, TFL_Quint8]>:$input,
|
||||
TFL_I32OrI64Tensor:$seq_lengths,
|
||||
|
||||
I32Attr:$seq_dim,
|
||||
I32Attr:$batch_dim
|
||||
Confined<I32Attr, [IntNonNegative]>:$seq_dim,
|
||||
Confined<I32Attr, [IntNonNegative]>:$batch_dim
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8]>:$output
|
||||
TFL_TensorOf<[F32, I32, I64, QI16, QUI8, TFL_Quint8]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
@ -2439,6 +2486,7 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension
|
||||
|
||||
def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
|
||||
SameOperandsAndResultType,
|
||||
SameOperandsAndResultShape,
|
||||
NoQuantizableResult,
|
||||
TFL_GpuTargetOp]> {
|
||||
let summary = "Reciprocal of square root operator";
|
||||
@ -2463,7 +2511,7 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> {
|
||||
|
||||
let arguments = (ins AnyTensor:$input);
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
let results = (outs TFL_TensorOf<[I32, I64]>:$output);
|
||||
|
||||
DerivedTypeAttr out_type = DerivedTypeAttr<[{
|
||||
return getResult().getType().cast<TensorType>().getElementType();
|
||||
@ -2472,9 +2520,11 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> {
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
// TODO(jpienaar): Flesh this out.
|
||||
def TFL_RangeOp: TFL_Op<"range", [NoSideEffect, TFL_OperandHasRank<0, 0>,
|
||||
TFL_OperandHasRank<1, 0>, TFL_OperandHasRank<2, 0>,
|
||||
def TFL_RangeOp: TFL_Op<"range", [
|
||||
NoSideEffect,
|
||||
TFL_OperandHasRank<0, 0>,
|
||||
TFL_OperandHasRank<1, 0>,
|
||||
TFL_OperandHasRank<2, 0>,
|
||||
PredOpTrait<"operands and output must have same element type",
|
||||
And<[TCresVTEtIsSameAsOp<0, 0>, TCresVTEtIsSameAsOp<0, 1>,
|
||||
TCresVTEtIsSameAsOp<0, 2>]>>]> {
|
||||
@ -2486,17 +2536,20 @@ def TFL_RangeOp: TFL_Op<"range", [NoSideEffect, TFL_OperandHasRank<0, 0>,
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyTensor:$start,
|
||||
AnyTensor:$limit,
|
||||
AnyTensor:$delta);
|
||||
TFL_TensorOf<[I32, F32]>:$start,
|
||||
TFL_TensorOf<[I32, F32]>:$limit,
|
||||
TFL_TensorOf<[I32, F32]>:$delta);
|
||||
|
||||
let results = (outs AnyTensor:$result);
|
||||
let results = (outs TFL_TensorOf<[I32, F32]>:$result);
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
|
||||
[NoSideEffect, TFL_OperandHasRank<1,1>]> {
|
||||
def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoSideEffect,
|
||||
TFL_OperandHasRank<1, 1>]> {
|
||||
let summary = "ReverseV2 Operator";
|
||||
|
||||
let description = [{
|
||||
@ -2518,18 +2571,18 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
|
||||
|
||||
let arguments = (
|
||||
ins
|
||||
TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$input,
|
||||
TFL_TensorOf<[I32, I64]>:$axis
|
||||
TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, TFL_Quint8, I1]>:$input,
|
||||
TFL_I32Tensor:$axis
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$output
|
||||
);
|
||||
TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, TFL_Quint8, I1]>:$output);
|
||||
}
|
||||
|
||||
// Select has many instances in TF models where one or more of its operands
|
||||
// are unranked. Therefore, we skip adding shape constraints here.
|
||||
def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
|
||||
def TFL_SelectOp : TFL_Op<"select", [
|
||||
NoSideEffect,
|
||||
PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
|
||||
PredOpTrait<"operands and result have same element type",
|
||||
TCresVTEtIsSameAsOp<0, 1>>]> {
|
||||
@ -2545,9 +2598,11 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
|
||||
|
||||
let arguments = (ins
|
||||
TFL_BoolTensor:$condition,
|
||||
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x,
|
||||
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y);
|
||||
let results = (outs AnyTensor:$output);
|
||||
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$x,
|
||||
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$y);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$output);
|
||||
|
||||
// TODO(jpienaar): autogenerate this.
|
||||
let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, "
|
||||
@ -2561,7 +2616,12 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> {
|
||||
def TFL_SelectV2Op : TFL_Op<"select_v2", [
|
||||
NoSideEffect,
|
||||
TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<1, 2, 4>,
|
||||
PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
|
||||
PredOpTrait<"operands and result have same element type",
|
||||
TCresVTEtIsSameAsOp<0, 1>>]> {
|
||||
let summary = "SelectV2 operator";
|
||||
|
||||
let description = [{
|
||||
@ -2574,9 +2634,11 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> {
|
||||
|
||||
let arguments = (ins
|
||||
TFL_BoolTensor:$condition,
|
||||
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x,
|
||||
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y);
|
||||
let results = (outs AnyTensor:$output);
|
||||
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$x,
|
||||
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$y);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$output);
|
||||
|
||||
let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, "
|
||||
"Value cond, Value x, Value y",
|
||||
@ -2605,9 +2667,11 @@ def TFL_SinOp: TFL_Op<"sin", [
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
// TODO(b/130643170): Adds some constraint for the input/output element types.
|
||||
def TFL_SoftmaxOp : TFL_Op<"softmax", [
|
||||
NoSideEffect,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
TFL_OperandHasRankRange<0, 1, 4>,
|
||||
SameOperandsAndResultShape,
|
||||
// zero_point = 0
|
||||
// scale = 1. / (max_value + 1)
|
||||
@ -2623,11 +2687,11 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins AnyTensor:$input,
|
||||
ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input,
|
||||
F32Attr:$beta
|
||||
);
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output);
|
||||
|
||||
let hasOptions = 1;
|
||||
}
|
||||
@ -2914,6 +2978,7 @@ def TFL_BatchToSpaceNdOp: TFL_Op<"batch_to_space_nd", [
|
||||
def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultsScale,
|
||||
TFL_OperandHasRankRange<0, 3, 4>,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TCresVTEtIsSameAsOp<0, 0>>
|
||||
]> {
|
||||
@ -2924,13 +2989,13 @@ def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input,
|
||||
TFL_TensorOf<[I32]>:$block_shape,
|
||||
TFL_TensorOf<[I32]>:$paddings
|
||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input,
|
||||
TFL_I32Tensor:$block_shape,
|
||||
TFL_I32Tensor:$paddings
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$output
|
||||
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output
|
||||
);
|
||||
}
|
||||
|
||||
@ -3045,7 +3110,12 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale]
|
||||
}
|
||||
|
||||
def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [
|
||||
NoSideEffect, SameOperandsAndResultsScale]> {
|
||||
NoSideEffect,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
TFL_OperandHasRank<0, 4>,
|
||||
TFL_OperandHasRank<1, 1>,
|
||||
SameOperandsAndResultsScale]> {
|
||||
let summary = "ResizeBilinear Op";
|
||||
|
||||
let description = [{
|
||||
@ -3053,23 +3123,26 @@ def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
// TODO(ycling): Support quantized types.
|
||||
TFL_TensorOf<[F32, I32, QI8, QUI8]>:$input,
|
||||
TFL_TensorOf<[I32]>:$size,
|
||||
TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input,
|
||||
TFL_I32Tensor:$size,
|
||||
BoolAttr:$align_corners,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, QI8, QUI8]>:$output
|
||||
TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor",
|
||||
[NoSideEffect,
|
||||
SameOperandsAndResultsScale]> {
|
||||
def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", [
|
||||
NoSideEffect,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
TFL_OperandHasRank<0, 4>,
|
||||
TFL_OperandHasRank<1, 1>,
|
||||
SameOperandsAndResultsScale]> {
|
||||
let summary = "ResizeNearestNeighbor Op";
|
||||
|
||||
let description = [{
|
||||
@ -3077,14 +3150,14 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor",
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$input,
|
||||
TFL_TensorOf<[I32]>:$size,
|
||||
TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$input,
|
||||
TFL_I32Tensor:$size,
|
||||
BoolAttr:$align_corners,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$output
|
||||
TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
@ -3349,7 +3422,9 @@ def TFL_SparseQConstOp : Op<TFL_Dialect, "pseudo_sparse_qconst", [
|
||||
}
|
||||
|
||||
def TFL_QuantizeOp: TFL_Op<"quantize", [
|
||||
FirstAttrDerivedResultType, NoQuantizableResult]> {
|
||||
FirstAttrDerivedResultType,
|
||||
SameOperandsAndResultShape,
|
||||
NoQuantizableResult]> {
|
||||
let summary = "Quantize operator";
|
||||
|
||||
let description = [{
|
||||
@ -3358,11 +3433,11 @@ def TFL_QuantizeOp: TFL_Op<"quantize", [
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins AnyTensor:$input,
|
||||
ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$input,
|
||||
TensorTypeAttr:$qtype
|
||||
);
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
let results = (outs TFL_TensorOf<[QI8, QUI8, QI16, TFL_Quint8]>:$output);
|
||||
}
|
||||
|
||||
def TFL_DensifyOp: TFL_Op<"densify", [
|
||||
@ -3472,6 +3547,19 @@ def TFL_LSTMOp :
|
||||
LstmOptionalPeepholeWeightConstraint,
|
||||
LstmProjectionWeightBiasConstraint,
|
||||
LstmResultConstraint,
|
||||
TFL_OperandHasRank<2, 2>, // input_to_forget_weights
|
||||
TFL_OperandHasRank<3, 2>, // input_to_cell_weights
|
||||
TFL_OperandIsNoneOrHasRank<5, 2>, // recurrent_to_input_weights
|
||||
TFL_OperandHasRank<6, 2>, // recurrent_to_forget_weights
|
||||
TFL_OperandHasRank<7, 2>, // recurrent_to_cell_weights
|
||||
TFL_OperandIsNoneOrHasRank<9, 1>, // cell_to_input_weights
|
||||
TFL_OperandIsNoneOrHasRank<10, 1>, // cell_to_forget_weights
|
||||
TFL_OperandIsNoneOrHasRank<11, 1>, // cell_to_output_weights
|
||||
TFL_OperandHasRank<13, 1>, // forget_gate_bias
|
||||
TFL_OperandHasRank<14, 1>, // cell_gate_bias
|
||||
TFL_OperandHasRank<15, 1>, // output_gate_bias
|
||||
TFL_OperandIsNoneOrHasRank<16, 2>, // projection_weights
|
||||
TFL_OperandIsNoneOrHasRank<17, 1>, // projection_bias
|
||||
TFL_StatefulOp]> {
|
||||
let summary = "The full lstm operator";
|
||||
|
||||
@ -3498,23 +3586,23 @@ Ba et al. 'Layer Normalization'
|
||||
ins TFL_TensorOf<[F32, QI8]>:$input,
|
||||
|
||||
// Weights
|
||||
TFL_TensorOfOrNone<[F32, I8, QI8]>:$input_to_input_weights,
|
||||
TFL_TensorOf<[F32, I8, QI8]>:$input_to_forget_weights,
|
||||
TFL_TensorOf<[F32, I8, QI8]>:$input_to_cell_weights,
|
||||
TFL_TensorOf<[F32, I8, QI8]>:$input_to_output_weights,
|
||||
TFL_TensorOfOrNone<[F32, QI8]>:$input_to_input_weights,
|
||||
TFL_TensorOf<[F32, QI8]>:$input_to_forget_weights,
|
||||
TFL_TensorOf<[F32, QI8]>:$input_to_cell_weights,
|
||||
TFL_TensorOf<[F32, QI8]>:$input_to_output_weights,
|
||||
|
||||
// Recurrent weights
|
||||
TFL_TensorOfOrNone<[F32, I8, QI8]>:$recurrent_to_input_weights,
|
||||
TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_forget_weights,
|
||||
TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_cell_weights,
|
||||
TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_output_weights,
|
||||
TFL_TensorOfOrNone<[F32, QI8]>:$recurrent_to_input_weights,
|
||||
TFL_TensorOf<[F32, QI8]>:$recurrent_to_forget_weights,
|
||||
TFL_TensorOf<[F32, QI8]>:$recurrent_to_cell_weights,
|
||||
TFL_TensorOf<[F32, QI8]>:$recurrent_to_output_weights,
|
||||
|
||||
// Cell weights
|
||||
TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_input_weights,
|
||||
TFL_TensorOfOrNone<[F32, QI8, QI16]>:$cell_to_input_weights,
|
||||
// Optional input
|
||||
TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_forget_weights,
|
||||
TFL_TensorOfOrNone<[F32, QI8, QI16]>:$cell_to_forget_weights,
|
||||
// Optional input
|
||||
TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_output_weights,
|
||||
TFL_TensorOfOrNone<[F32, QI8, QI16]>:$cell_to_output_weights,
|
||||
|
||||
// Bias
|
||||
TFL_TensorOfOrNone<[F32, QI32]>:$input_gate_bias,
|
||||
@ -3523,7 +3611,7 @@ Ba et al. 'Layer Normalization'
|
||||
TFL_TensorOf<[F32, QI32]>:$output_gate_bias,
|
||||
|
||||
// Projection weight and bias
|
||||
TFL_TensorOfOrNone<[F32, I8, QI8]>:$projection_weights,
|
||||
TFL_TensorOfOrNone<[F32, QI8]>:$projection_weights,
|
||||
// Optional input
|
||||
TFL_TensorOfOrNone<[F32, QI32]>:$projection_bias,
|
||||
|
||||
@ -3539,8 +3627,8 @@ Ba et al. 'Layer Normalization'
|
||||
|
||||
// Attributes
|
||||
TFL_AFAttr:$fused_activation_function,
|
||||
DefaultValuedAttr<F32Attr, "0.0f">:$cell_clip,
|
||||
DefaultValuedAttr<F32Attr, "0.0f">:$proj_clip,
|
||||
Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$cell_clip,
|
||||
Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$proj_clip,
|
||||
// Since this op is the FULL kernel only, constrain it.
|
||||
Confined<
|
||||
DefaultValuedAttr<TFL_LSTMKernelTypeAttr, "FULL">,
|
||||
@ -3580,6 +3668,24 @@ def TFL_UnidirectionalSequenceLSTMOp :
|
||||
LstmOptionalPeepholeWeightConstraint,
|
||||
LstmProjectionWeightBiasConstraint,
|
||||
LstmResultConstraint,
|
||||
TFL_OperandHasRankAtLeast<0, 2>, // input
|
||||
TFL_OperandIsNoneOrHasRank<1, 2>, // input_to_input_weights
|
||||
TFL_OperandHasRank<2, 2>, // input_to_forget_weights
|
||||
TFL_OperandHasRank<3, 2>, // input_to_cell_weights
|
||||
TFL_OperandHasRank<4, 2>, // input_to_output_weights
|
||||
TFL_OperandIsNoneOrHasRank<5, 2>, // recurrent_to_input_weights
|
||||
TFL_OperandHasRank<6, 2>, // recurrent_to_forget_weights
|
||||
TFL_OperandHasRank<7, 2>, // recurrent_to_cell_weights
|
||||
TFL_OperandHasRank<8, 2>, // recurrent_to_output_weights
|
||||
TFL_OperandIsNoneOrHasRank<9, 1>, // cell_to_input_weights
|
||||
TFL_OperandIsNoneOrHasRank<10, 1>, // cell_to_forget_weights
|
||||
TFL_OperandIsNoneOrHasRank<11, 1>, // cell_to_output_weights
|
||||
TFL_OperandIsNoneOrHasRank<12, 1>, // input_gate_bias
|
||||
TFL_OperandHasRank<13, 1>, // forget_gate_bias
|
||||
TFL_OperandHasRank<14, 1>, // cell_gate_bias
|
||||
TFL_OperandHasRank<15, 1>, // output_gate_bias
|
||||
TFL_OperandIsNoneOrHasRank<16, 2>, // projection_weights
|
||||
TFL_OperandIsNoneOrHasRank<17, 2>, // projection_bias
|
||||
TFL_StatefulOp]> {
|
||||
let summary = "Unidirectional sequence lstm operator";
|
||||
|
||||
@ -3595,35 +3701,35 @@ def TFL_UnidirectionalSequenceLSTMOp :
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[F32, I8]>:$input,
|
||||
ins TFL_FpTensor:$input,
|
||||
|
||||
// Weights
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$input_to_input_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$input_to_forget_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$input_to_cell_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$input_to_output_weights,
|
||||
TFL_TensorOfOrNone<[F32, QI8]>:$input_to_input_weights,
|
||||
TFL_TensorOf<[F32, QI8]>:$input_to_forget_weights,
|
||||
TFL_TensorOf<[F32, QI8]>:$input_to_cell_weights,
|
||||
TFL_TensorOf<[F32, QI8]>:$input_to_output_weights,
|
||||
|
||||
// Recurrent weights
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$recurrent_to_input_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$recurrent_to_forget_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$recurrent_to_cell_weights,
|
||||
TFL_TensorOf<[F32, I8]>:$recurrent_to_output_weights,
|
||||
TFL_TensorOfOrNone<[F32, QI8]>:$recurrent_to_input_weights,
|
||||
TFL_TensorOf<[F32, QI8]>:$recurrent_to_forget_weights,
|
||||
TFL_TensorOf<[F32, QI8]>:$recurrent_to_cell_weights,
|
||||
TFL_TensorOf<[F32, QI8]>:$recurrent_to_output_weights,
|
||||
|
||||
// Cell weights
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$cell_to_input_weights,
|
||||
TFL_TensorOfOrNone<[F32, QI8]>:$cell_to_input_weights,
|
||||
// Optional input
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$cell_to_forget_weights,
|
||||
TFL_TensorOfOrNone<[F32, QI8]>:$cell_to_forget_weights,
|
||||
// Optional input
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$cell_to_output_weights,
|
||||
TFL_TensorOfOrNone<[F32, QI8]>:$cell_to_output_weights,
|
||||
|
||||
// Bias
|
||||
TFL_TensorOfOrNone<[F32]>:$input_gate_bias,
|
||||
TFL_TensorOf<[F32]>:$forget_gate_bias,
|
||||
TFL_TensorOf<[F32]>:$cell_bias,
|
||||
TFL_TensorOf<[F32]>:$output_gate_bias,
|
||||
TFL_FpTensor:$forget_gate_bias,
|
||||
TFL_FpTensor:$cell_bias,
|
||||
TFL_FpTensor:$output_gate_bias,
|
||||
|
||||
// Projection weight and bias
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$projection_weights,
|
||||
TFL_TensorOfOrNone<[F32, QI8]>:$projection_weights,
|
||||
// Optional input
|
||||
TFL_TensorOfOrNone<[F32]>:$projection_bias,
|
||||
|
||||
@ -3632,19 +3738,19 @@ def TFL_UnidirectionalSequenceLSTMOp :
|
||||
TFL_StatefulTensor:$input_cell_state,
|
||||
|
||||
// Layer norm coefficients
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$input_layer_norm_coefficients,
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$forget_layer_norm_coefficients,
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$cell_layer_norm_coefficients,
|
||||
TFL_TensorOfOrNone<[F32, I8]>:$output_layer_norm_coefficients,
|
||||
TFL_TensorOfOrNone<[F32, QI8]>:$input_layer_norm_coefficients,
|
||||
TFL_TensorOfOrNone<[F32, QI8]>:$forget_layer_norm_coefficients,
|
||||
TFL_TensorOfOrNone<[F32, QI8]>:$cell_layer_norm_coefficients,
|
||||
TFL_TensorOfOrNone<[F32, QI8]>:$output_layer_norm_coefficients,
|
||||
|
||||
// Attributes
|
||||
TFL_AFAttr:$fused_activation_function,
|
||||
DefaultValuedAttr<F32Attr, "0.0f">:$cell_clip,
|
||||
DefaultValuedAttr<F32Attr, "0.0f">:$proj_clip,
|
||||
Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$cell_clip,
|
||||
Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$proj_clip,
|
||||
BoolAttr:$time_major
|
||||
);
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, QI8]>:$output);
|
||||
|
||||
let hasOptions = 1;
|
||||
|
||||
@ -3841,15 +3947,14 @@ def TFL_BidirectionalSequenceLSTMOp :
|
||||
}];
|
||||
}
|
||||
|
||||
def RnnResultConstraint : PredOpTrait<
|
||||
"the input and result tensor elemental types must be same",
|
||||
TCresVTEtIsSameAsOp<0, 0>>;
|
||||
|
||||
// UnidirectionalSequenceRNN op.
|
||||
def TFL_UnidirectionalSequenceRNNOp :
|
||||
TFL_Op<"unidirectional_sequence_rnn",
|
||||
[RnnResultConstraint, TFL_StatefulOp]> {
|
||||
|
||||
def TFL_UnidirectionalSequenceRNNOp : TFL_Op<"unidirectional_sequence_rnn", [
|
||||
TFL_OperandHasRank<4, 2>,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
PredOpTrait<"input and constant value operands must have same element type",
|
||||
TFL_TCopVTEtAreSameAt<1, 2>>,
|
||||
TFL_StatefulOp]> {
|
||||
let summary = "Unidirectional sequence rnn operator";
|
||||
|
||||
let description = [{
|
||||
@ -3866,16 +3971,16 @@ def TFL_UnidirectionalSequenceRNNOp :
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[F32, I8]>:$input,
|
||||
ins TFL_FpTensor:$input,
|
||||
|
||||
// Weights
|
||||
TFL_TensorOf<[F32, I8]>:$input_to_input_weights,
|
||||
TFL_TensorOf<[F32, QI8]>:$input_to_input_weights,
|
||||
|
||||
// Recurrent weights
|
||||
TFL_TensorOf<[F32, I8]>:$recurrent_to_input_weights,
|
||||
TFL_TensorOf<[F32, QI8]>:$recurrent_to_input_weights,
|
||||
|
||||
// Bias
|
||||
TFL_TensorOf<[F32]>:$input_gate_bias,
|
||||
TFL_FpTensor:$input_gate_bias,
|
||||
|
||||
// Hidden state.
|
||||
TFL_StatefulTensor:$hidden_state,
|
||||
@ -3885,7 +3990,7 @@ def TFL_UnidirectionalSequenceRNNOp :
|
||||
TFL_AFAttr:$fused_activation_function
|
||||
);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, I8]>:$output);
|
||||
let results = (outs TFL_FpTensor:$output);
|
||||
|
||||
let hasOptions = 1;
|
||||
|
||||
@ -3941,14 +4046,12 @@ def TFL_NumericVerifyOp : Op<TFL_Dialect, "NumericVerify", [
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
def SVDFResultConstraint: PredOpTrait<
|
||||
"the input and result tensor elemental types must be same",
|
||||
TCresVTEtIsSameAsOp<0, 0>>;
|
||||
|
||||
// SVDF op.
|
||||
def TFL_SVDFOp :
|
||||
TFL_Op<"svdf",
|
||||
[SVDFResultConstraint, TFL_StatefulOp]> {
|
||||
TFL_Op<"svdf", [
|
||||
PredOpTrait<"the input and result tensor elemental types must be same",
|
||||
TCresVTEtIsSameAsOp<0, 0>>,
|
||||
TFL_StatefulOp]> {
|
||||
|
||||
let summary = "Single value decomposition filter operator";
|
||||
|
||||
@ -3960,13 +4063,13 @@ def TFL_SVDFOp :
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[F32, I8]>:$input,
|
||||
ins TFL_TensorOf<[F32, QI8]>:$input,
|
||||
|
||||
// Feature Weights.
|
||||
TFL_TensorOf<[F32, I8]>:$feature_weights,
|
||||
TFL_TensorOf<[F32, QI8, QUI8]>:$feature_weights,
|
||||
|
||||
// Time weights
|
||||
TFL_TensorOf<[F32, I8]>:$time_weights,
|
||||
TFL_TensorOf<[F32, QI8]>:$time_weights,
|
||||
|
||||
// Bias
|
||||
TFL_TensorOfOrNone<[F32]>:$input_gate_bias,
|
||||
@ -3975,11 +4078,11 @@ def TFL_SVDFOp :
|
||||
TFL_StatefulTensor:$activation_state,
|
||||
|
||||
// Attributes
|
||||
I32Attr:$rank,
|
||||
Confined<I32Attr, [IntPositive]>:$rank,
|
||||
TFL_AFAttr:$fused_activation_function
|
||||
);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, I8]>:$output);
|
||||
let results = (outs TFL_TensorOf<[F32, QI8]>:$output);
|
||||
|
||||
let hasOptions = 1;
|
||||
|
||||
@ -3991,7 +4094,10 @@ def TFL_SVDFOp :
|
||||
}];
|
||||
}
|
||||
|
||||
def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> {
|
||||
def TFL_SegmentSumOp: TFL_Op<"segment_sum", [
|
||||
NoSideEffect,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
|
||||
let summary = "SegmentSum operator";
|
||||
|
||||
let description = [{
|
||||
@ -3999,7 +4105,7 @@ def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> {
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32, I32]>:$data,
|
||||
TFL_TensorOf<[F32, I32]>:$input,
|
||||
TFL_I32Tensor:$segment_ids
|
||||
);
|
||||
let results = (outs TFL_TensorOf<[F32, I32]>:$output);
|
||||
@ -4030,7 +4136,7 @@ def TFL_WhileOp : Op<TFL_Dialect, "while", [
|
||||
|
||||
input: A list of input tensors whose types are T.
|
||||
output: A list of output tensors whose types are T.
|
||||
cond: A region takes 'input' and returns a boolean scalar tensor.
|
||||
cond: A region that takes 'input' and returns a boolean scalar tensor.
|
||||
body: A region that takes a list of tensors and returns another
|
||||
list of tensors. Both lists have the same types.
|
||||
}];
|
||||
|
@ -121,6 +121,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
|
||||
return DT_STRING;
|
||||
case toco::IODataType::BOOL:
|
||||
return DT_BOOL;
|
||||
case toco::IODataType::COMPLEX64:
|
||||
return DT_COMPLEX64;
|
||||
default:
|
||||
return DT_INVALID;
|
||||
}
|
||||
@ -252,7 +254,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
|
||||
std::string error_message;
|
||||
auto output = mlir::openOutputFile(filename, &error_message);
|
||||
if (!error_message.empty()) {
|
||||
return errors::InvalidArgument("Failed to open file in %s.", filename);
|
||||
return errors::InvalidArgument("Failed to open file in ", filename);
|
||||
}
|
||||
mlir::PassManager pm(module.getContext());
|
||||
pm.addPass(mlir::createPrintOpGraphPass(output->os()));
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
@ -35,6 +36,7 @@ limitations under the License.
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -363,6 +365,54 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
|
||||
}
|
||||
};
|
||||
|
||||
// Fold Extra Requantize ops if the preceding ops has free scale requirement.
|
||||
template <typename RQ>
|
||||
struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
|
||||
explicit FoldTrivalRequantizeOp(MLIRContext* context)
|
||||
: OpRewritePattern<RQ>(context, 1) {}
|
||||
|
||||
LogicalResult matchAndRewrite(RQ op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
Value pre_quantized = op.input();
|
||||
auto pre_quantized_type =
|
||||
quant::QuantizedType::getQuantizedElementType(pre_quantized.getType());
|
||||
if (!pre_quantized_type) return failure();
|
||||
|
||||
Operation* def = pre_quantized.getDefiningOp();
|
||||
if (!def) return failure();
|
||||
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
|
||||
def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
op.emitWarning("Remove trivial `rescale` op. Please fix the source graph.");
|
||||
|
||||
llvm::SmallVector<Type, 4> new_output_types;
|
||||
for (auto result : def->getResults()) {
|
||||
result.getUsers().begin()->dump();
|
||||
op.dump();
|
||||
if (result.hasOneUse() && *result.getUsers().begin() == op) {
|
||||
new_output_types.push_back(op.qtype());
|
||||
} else {
|
||||
new_output_types.push_back(result.getType());
|
||||
}
|
||||
}
|
||||
|
||||
// Remove this rescale op.
|
||||
rewriter.replaceOp(op, {pre_quantized});
|
||||
|
||||
// Replace the output scale of the preceding op.
|
||||
rewriter.setInsertionPointAfter(def);
|
||||
OperationState new_state(def->getLoc(), def->getName().getStringRef(),
|
||||
def->getOperands(), new_output_types,
|
||||
def->getAttrs());
|
||||
Operation* new_op = rewriter.createOperation(new_state);
|
||||
|
||||
rewriter.replaceOp(def, new_op->getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Given a quantized type `input`, magnifying its scales by the factor stored in
|
||||
// `factor`. If `input` isn't a quantized type or the `factor` doesn't match the
|
||||
// dimension size of `input` or isn't floating-point, nullptr will be returned.
|
||||
|
@ -11,9 +11,9 @@ func @reshape_removeAdjacent(tensor<4x4x4xf32>) -> tensor<64xf32> {
|
||||
return %1 : tensor<64xf32>
|
||||
|
||||
// CHECK-LABEL: func @reshape_removeAdjacent
|
||||
// CHECK: %cst = constant dense<64> : tensor<1xi32>
|
||||
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
|
||||
// CHECK: return
|
||||
// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32>
|
||||
// CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
|
||||
// CHECK: return %[[RESHAPE]]
|
||||
}
|
||||
|
||||
// Checks that tfl.reshape should be removed if its output has more than one
|
||||
@ -29,11 +29,11 @@ func @reshape_removeAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> tensor<64xf32>
|
||||
return %3 : tensor<64xf32>
|
||||
|
||||
// CHECK-LABEL: func @reshape_removeAdjacentWithMultipleUse
|
||||
// CHECK: %cst = constant dense<64> : tensor<1xi32>
|
||||
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
|
||||
// CHECK: %1 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
|
||||
// CHECK: %2 = addf %0, %1
|
||||
// CHECK: return %2
|
||||
// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32>
|
||||
// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
|
||||
// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
|
||||
// CHECK: %[[RESULT:.*]] = addf %[[RESHAPE_1]], %[[RESHAPE_2]]
|
||||
// CHECK: return %[[RESULT]]
|
||||
}
|
||||
|
||||
// Checks that tfl.reshape should be kept if its output has more than one
|
||||
@ -47,11 +47,11 @@ func @reshape_keepAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> (tensor<16x4xf32
|
||||
return %0, %1 : tensor<16x4xf32>, tensor<64xf32>
|
||||
|
||||
// CHECK-LABEL: func @reshape_keepAdjacentWithMultipleUse
|
||||
// CHECK: %cst = constant dense<[16, 4]> : tensor<2xi32>
|
||||
// CHECK: %cst_0 = constant dense<64> : tensor<1xi32>
|
||||
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32>
|
||||
// CHECK: %1 = "tfl.reshape"(%arg0, %cst_0) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
|
||||
// CHECK: return %0, %1
|
||||
// CHECK: %[[CST:.*]] = constant dense<[16, 4]> : tensor<2xi32>
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<64> : tensor<1xi32>
|
||||
// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32>
|
||||
// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST_0]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
|
||||
// CHECK: return %[[RESHAPE_1]], %[[RESHAPE_2]]
|
||||
}
|
||||
|
||||
// Checks that tfl.reshape should be removed if its output type is the same
|
||||
|
@ -8,13 +8,13 @@ func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>,
|
||||
%2 = constant dense< 3.5> : tensor<4xf32>
|
||||
%3 = constant dense<-0.5> : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<3.500000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_0 = constant dense<-5.000000e-01> : tensor<4xf32>
|
||||
// CHECK: %cst_1 = constant dense<6.000000e+00> : tensor<f32>
|
||||
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_3 = constant dense<5.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_4 = constant dense<3.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %0 = tfl.add %cst, %cst_0 {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<3.500000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<-5.000000e-01> : tensor<4xf32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<6.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST_3:.*]] = constant dense<5.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST_4:.*]] = constant dense<3.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %0 = tfl.add %[[CST]], %[[CST_0]] {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
|
||||
|
||||
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
|
||||
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
@ -33,10 +33,10 @@ func @add_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
||||
%2 = constant dense< 4> : tensor<4xi32>
|
||||
%3 = constant dense<-2> : tensor<4xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<9> : tensor<i32>
|
||||
// CHECK: %cst_0 = constant dense<6> : tensor<4xi32>
|
||||
// CHECK: %cst_1 = constant dense<5> : tensor<4xi32>
|
||||
// CHECK: %cst_2 = constant dense<2> : tensor<4xi32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<9> : tensor<i32>
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<6> : tensor<4xi32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<5> : tensor<4xi32>
|
||||
// CHECK: %[[CST_2:.*]] = constant dense<2> : tensor<4xi32>
|
||||
|
||||
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
|
||||
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
@ -54,10 +54,10 @@ func @sub_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
|
||||
%2 = constant dense< 3.5> : tensor<4xf32>
|
||||
%3 = constant dense<-0.5> : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<3.000000e+00> : tensor<f32>
|
||||
// CHECK: %cst_0 = constant dense<5.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_1 = constant dense<2.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<3.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<5.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<2.000000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32>
|
||||
|
||||
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
|
||||
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
@ -75,10 +75,10 @@ func @sub_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
|
||||
%2 = constant dense< 4> : tensor<4xi32>
|
||||
%3 = constant dense<-2> : tensor<4xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<7> : tensor<i32>
|
||||
// CHECK: %cst_0 = constant dense<10> : tensor<4xi32>
|
||||
// CHECK: %cst_1 = constant dense<3> : tensor<4xi32>
|
||||
// CHECK: %cst_2 = constant dense<6> : tensor<4xi32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<7> : tensor<i32>
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<10> : tensor<4xi32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<3> : tensor<4xi32>
|
||||
// CHECK: %[[CST_2:.*]] = constant dense<6> : tensor<4xi32>
|
||||
|
||||
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
|
||||
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
@ -96,10 +96,10 @@ func @mul_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
|
||||
%2 = constant dense< 3.5> : tensor<4xf32>
|
||||
%3 = constant dense<-0.5> : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<6.750000e+00> : tensor<f32>
|
||||
// CHECK: %cst_0 = constant dense<-2.250000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_1 = constant dense<5.250000e+00> : tensor<4xf32>
|
||||
// CHECK: %cst_2 = constant dense<-1.750000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<6.750000e+00> : tensor<f32>
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<-2.250000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<5.250000e+00> : tensor<4xf32>
|
||||
// CHECK: %[[CST_2:.*]] = constant dense<-1.750000e+00> : tensor<4xf32>
|
||||
|
||||
%5 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
|
||||
%6 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
@ -170,8 +170,8 @@ func @add_dense_splat_int() -> tensor<4xi32> {
|
||||
|
||||
return %2 : tensor<4xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_splat_dense_int
|
||||
@ -183,8 +183,8 @@ func @add_splat_dense_int() -> tensor<4xi32> {
|
||||
|
||||
return %2 : tensor<4xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_int_same_shape
|
||||
@ -196,8 +196,8 @@ func @add_dense_dense_int_same_shape() -> tensor<4xi32> {
|
||||
|
||||
return %2 : tensor<4xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<[5, 22, -2, 98]> : tensor<4xi32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<[5, 22, -2, 98]> : tensor<4xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_int_trailing_dim
|
||||
@ -212,10 +212,10 @@ func @add_dense_dense_int_trailing_dim() -> (tensor<2x2xi32>, tensor<2x2x2xi32>,
|
||||
|
||||
return %0, %1, %2 : tensor<2x2xi32>, tensor<2x2x2xi32>, tensor<2x2x2xi32>
|
||||
|
||||
// CHECK: %cst = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32>
|
||||
// CHECK: %cst_0 = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32>
|
||||
// CHECK: %cst_1 = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32>
|
||||
// CHECK: return %cst, %cst_0, %cst_1
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32>
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32>
|
||||
// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_int_mixing_1_n
|
||||
@ -226,8 +226,8 @@ func @add_dense_dense_int_mixing_1_n() -> tensor<2x2xi32> {
|
||||
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
|
||||
|
||||
return %0 : tensor<2x2xi32>
|
||||
// CHECK: %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_splat_float
|
||||
@ -239,8 +239,8 @@ func @add_dense_splat_float() -> tensor<4xf32> {
|
||||
|
||||
return %2 : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_splat_dense_float
|
||||
@ -252,8 +252,8 @@ func @add_splat_dense_float() -> tensor<4xf32> {
|
||||
|
||||
return %2 : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_float_same_shape
|
||||
@ -265,8 +265,8 @@ func @add_dense_dense_float_same_shape() -> (tensor<4xf32>) {
|
||||
|
||||
return %2 : tensor<4xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_float_trailing_dim
|
||||
@ -281,10 +281,10 @@ func @add_dense_dense_float_trailing_dim() -> (tensor<2x2xf32>, tensor<2x2x2xf32
|
||||
|
||||
return %0, %1, %2 : tensor<2x2xf32>, tensor<2x2x2xf32>, tensor<2x2x2xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32>
|
||||
// CHECK: %cst_0 = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32>
|
||||
// CHECK: %cst_1 = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32>
|
||||
// CHECK: return %cst, %cst_0, %cst_1
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32>
|
||||
// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32>
|
||||
// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32>
|
||||
// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_float_mixfng_1_n
|
||||
@ -296,24 +296,24 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
|
||||
|
||||
return %0 : tensor<2x2xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @rank
|
||||
func @rank() -> tensor<1xi32> {
|
||||
%cst = constant dense<[[1], [2]]> : tensor<2x1xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.rank"(%cst) : (tensor<2x1xi32>) -> tensor<1xi32>
|
||||
return %0 : tensor<1xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @rank_input_known_rank
|
||||
func @rank_input_known_rank(%arg0 : tensor<2x1xi32>) -> tensor<1xi32> {
|
||||
// CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.rank"(%arg0) : (tensor<2x1xi32>) -> tensor<1xi32>
|
||||
return %0 : tensor<1xi32>
|
||||
}
|
||||
@ -323,8 +323,8 @@ func @reshape() -> tensor<4xi32> {
|
||||
%input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
|
||||
%shape = constant dense<[4]> : tensor<1xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<4xi32>
|
||||
return %0 : tensor<4xi32>
|
||||
}
|
||||
@ -334,8 +334,8 @@ func @reshape_dynamic_output() -> tensor<?xi32> {
|
||||
%input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
|
||||
%shape = constant dense<[4]> : tensor<1xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
@ -343,8 +343,8 @@ func @reshape_dynamic_output() -> tensor<?xi32> {
|
||||
|
||||
// CHECK-LABEL: @pseudo_const
|
||||
func @pseudo_const() -> tensor<i32> {
|
||||
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
@ -356,8 +356,8 @@ func @range_int() -> tensor<?xi32> {
|
||||
%cst_1 = constant dense<4> : tensor<i32>
|
||||
%cst_2 = constant dense<1> : tensor<i32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
@ -368,8 +368,8 @@ func @range_float() -> tensor<?xf32> {
|
||||
%cst_1 = constant dense<4.0> : tensor<f32>
|
||||
%cst_2 = constant dense<1.0> : tensor<f32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
@ -381,8 +381,8 @@ func @range_float_neg_delta() -> tensor<?xf32> {
|
||||
%cst_1 = constant dense<-4.0> : tensor<f32>
|
||||
%cst_2 = constant dense<-1.0> : tensor<f32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
@ -393,8 +393,8 @@ func @range_float_nonzero_base() -> tensor<?xf32> {
|
||||
%cst_1 = constant dense<7.0> : tensor<f32>
|
||||
%cst_2 = constant dense<1.5> : tensor<f32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
@ -414,8 +414,8 @@ func @transpose_1d() -> tensor<3xi32> {
|
||||
%cst = constant dense<[1, 2, 3]> : tensor<3xi32>
|
||||
%cst_perm = constant dense<0> : tensor<1xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
|
||||
return %0 : tensor<3xi32>
|
||||
}
|
||||
@ -425,8 +425,8 @@ func @transpose_dynamic() -> tensor<?xi32> {
|
||||
%cst = constant dense<[1, 2, 3]> : tensor<3xi32>
|
||||
%cst_perm = constant dense<0> : tensor<1xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
@ -436,8 +436,8 @@ func @transpose_2d() -> tensor<2x2xi32> {
|
||||
%cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
|
||||
%cst_perm = constant dense<[1, 0]> : tensor<2xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
@ -447,8 +447,8 @@ func @transpose_2d_identity() -> tensor<2x2xi32> {
|
||||
%cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
|
||||
%cst_perm = constant dense<[0, 1]> : tensor<2xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
@ -460,8 +460,8 @@ func @transpose_3d() -> tensor<4x2x3xi32> {
|
||||
%cst = constant dense<[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]> : tensor<2x3x4xi32>
|
||||
%cst_perm = constant dense<[2, 0, 1]> : tensor<3xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x2x3xi32>
|
||||
return %0 : tensor<4x2x3xi32>
|
||||
}
|
||||
@ -473,8 +473,8 @@ func @ConstantFoldBinaryOpDynamicOutput() -> tensor<?xi32> {
|
||||
%87 = "tfl.sub"(%cst_0, %cst) {fused_activation_function = "NONE"} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
return %87 : tensor<?xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @add_dense_dense_int_same_shape_dynamic
|
||||
@ -486,8 +486,8 @@ func @add_dense_dense_int_same_shape_dynamic() -> tensor<?xi32> {
|
||||
|
||||
return %2 : tensor<?xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor<?xi32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @concat_2_tensors_1_empty
|
||||
@ -497,8 +497,8 @@ func @concat_2_tensors_1_empty() -> tensor<2xi32> {
|
||||
%3 = "tfl.concatenation"(%1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<0xi32>) -> tensor<2xi32>
|
||||
return %3 : tensor<2xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<2xi32>
|
||||
// CHECK: return [[cst]] : tensor<2xi32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<2xi32>
|
||||
// CHECK: return %[[CST]] : tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @concat_3_tensors_1_empty
|
||||
@ -509,7 +509,7 @@ func @concat_3_tensors_1_empty() -> tensor<?xi32> {
|
||||
%3 = "tfl.concatenation"(%0, %1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>, tensor<0xi32>) -> tensor<?xi32>
|
||||
return %3 : tensor<?xi32>
|
||||
|
||||
// CHECK: %0 = "tfl.concatenation"(%cst, %cst) {axis = 0 : i32, fused_activation_function = "NONE"}
|
||||
// CHECK: %0 = "tfl.concatenation"(%[[CST]], %[[CST]]) {axis = 0 : i32, fused_activation_function = "NONE"}
|
||||
// CHECK: return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
@ -520,10 +520,10 @@ func @concatConstantTensorsFirstDim() -> tensor<2x2x3xi32> {
|
||||
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<2x2x3xi32>
|
||||
return %0 : tensor<2x2x3xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32>
|
||||
// CHECK-NOT: constant-dense
|
||||
// CHECK-NOT: "tfl.concatenation"
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @concatConstantTensorsMiddleDim
|
||||
@ -533,10 +533,10 @@ func @concatConstantTensorsMiddleDim() -> tensor<1x4x3xi32> {
|
||||
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x4x3xi32>
|
||||
return %0 : tensor<1x4x3xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32>
|
||||
// CHECK-NOT: constant-dense
|
||||
// CHECK-NOT: "tfl.concatenation"
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @concatConstantTensorsLastDim
|
||||
@ -546,10 +546,10 @@ func @concatConstantTensorsLastDim() -> tensor<1x2x6xi32> {
|
||||
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 2 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x2x6xi32>
|
||||
return %0 : tensor<1x2x6xi32>
|
||||
|
||||
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32>
|
||||
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32>
|
||||
// CHECK-NOT: constant-dense
|
||||
// CHECK-NOT: "tfl.concatenation"
|
||||
// CHECK: return [[cst]]
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @div_dense_dense_float_mixfng_1_n
|
||||
@ -561,8 +561,8 @@ func @div_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
|
||||
|
||||
return %0 : tensor<2x2xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @div_dense_different_rank
|
||||
@ -574,6 +574,6 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> {
|
||||
|
||||
return %0 : tensor<1x2x2xf32>
|
||||
|
||||
// CHECK: %cst = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
|
||||
// CHECK: return %cst
|
||||
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
|
||||
// CHECK: return %[[CST]]
|
||||
}
|
||||
|
@ -0,0 +1,101 @@
|
||||
# RUN: tf_tfl_translate -tf-input-arrays=Placeholder,Placeholder_1 -tf-input-shapes=2,5,3:3,7 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-output-arrays=MatMul -output-mlir %s -o - 2>&1 | FileCheck %s
|
||||
|
||||
node {
|
||||
name: "Placeholder"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
dim {
|
||||
size: 5
|
||||
}
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Placeholder_1"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: 3
|
||||
}
|
||||
dim {
|
||||
size: 7
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "MatMul"
|
||||
op: "BatchMatMulV2"
|
||||
input: "Placeholder"
|
||||
input: "Placeholder_1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "adj_x"
|
||||
value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "adj_y"
|
||||
value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 175
|
||||
}
|
||||
|
||||
# CHECK: func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} {
|
||||
# CHECK: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
|
||||
# CHECK: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
|
||||
# CHECK: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
|
||||
# CHECK: %[[VAL_5:.*]] = constant unit
|
||||
# CHECK: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
|
||||
# CHECK: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
|
||||
# CHECK: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
|
||||
# CHECK: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
|
||||
# CHECK: %[[VAL_10:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_8]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
|
||||
# CHECK: %[[VAL_11:.*]] = "tfl.reshape"(%[[VAL_10]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
|
||||
# CHECK: %[[VAL_12:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_6]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
|
||||
# CHECK: %[[VAL_13:.*]] = "tfl.reshape"(%[[VAL_12]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
|
||||
# CHECK: %[[VAL_14:.*]] = "tfl.reshape"(%[[VAL_1]], %[[VAL_9]]) : (tensor<3x7xf32>, tensor<3xi32>) -> tensor<1x3x7xf32>
|
||||
# CHECK: %[[VAL_15:.*]] = "tfl.slice"(%[[VAL_14]], %[[VAL_8]], %[[VAL_9]]) : (tensor<1x3x7xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x3x7xf32>
|
||||
# CHECK: %[[VAL_16:.*]] = "tfl.reshape"(%[[VAL_15]], %[[VAL_4]]) : (tensor<1x3x7xf32>, tensor<2xi32>) -> tensor<3x7xf32>
|
||||
# CHECK: %[[VAL_17:.*]] = "tfl.transpose"(%[[VAL_16]], %[[VAL_2]]) : (tensor<3x7xf32>, tensor<2xi32>) -> tensor<7x3xf32>
|
||||
# CHECK: %[[VAL_18:.*]] = "tfl.fully_connected"(%[[VAL_11]], %[[VAL_17]], %[[VAL_5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32>
|
||||
# CHECK: %[[VAL_19:.*]] = "tfl.fully_connected"(%[[VAL_13]], %[[VAL_17]], %[[VAL_5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32>
|
||||
# CHECK: %[[VAL_20:.*]] = "tfl.pack"(%[[VAL_18]], %[[VAL_19]]) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<2x5x7xf32>
|
||||
# CHECK: return %[[VAL_20]] : tensor<2x5x7xf32>
|
||||
# CHECK: }
|
@ -1,15 +1,15 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
|
||||
// Ensure lstm roundtrip exactly
|
||||
|
||||
func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
|
||||
func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
|
||||
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
|
||||
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
|
||||
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
||||
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
||||
return %24 : tensor<1x4xf32>
|
||||
// CHECK-LABEL: main
|
||||
// seperate lines since there is no region for this op. third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252
|
||||
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( {
|
||||
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
||||
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
||||
// CHECK: return %[[RES0]]
|
||||
|
||||
}
|
||||
|
14
tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir
Normal file
14
tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir
Normal file
@ -0,0 +1,14 @@
|
||||
// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s -split-input-file | FileCheck %s --dump-input-on-failure
|
||||
module {
|
||||
|
||||
func @_whitespace_func(%arg0: tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._GrapplerSpecializedFunc = true, tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
|
||||
%0 = "tf.op1"(%arg0) : (tensor<1x!tf.string>) -> (tensor<?x!tf.string>)
|
||||
%1 = "tf.Const"() {value = dense<-1> : tensor<i64>} : () -> tensor<?xi64>
|
||||
%2:2 = "tf.op2"(%arg0, %1) : (tensor<1x!tf.string>, tensor<?xi64>) -> (tensor<?x!tf.string>, tensor<?xi64>)
|
||||
return %2#0, %2#1 : tensor<?x!tf.string>, tensor<?xi64>
|
||||
}
|
||||
|
||||
// CHECK: func @_whitespace_func(%arg0: tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._GrapplerSpecializedFunc = true, tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
|
||||
// CHECK: "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>)
|
||||
// CHECK: return %0#0, %0#1 : tensor<?x!tf.string>, tensor<?xi64>
|
||||
}
|
@ -1048,6 +1048,15 @@ func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2
|
||||
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
|
||||
}
|
||||
|
||||
func @concatv2I64Axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
|
||||
%0 = "tf.Const"() { value = dense<-1> : tensor<i64> } : () -> tensor<i64>
|
||||
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i64>) -> tensor<2x3xi32>
|
||||
return %1 : tensor<2x3xi32>
|
||||
|
||||
// CHECK-LABEL: concatv2I64Axis
|
||||
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
|
||||
}
|
||||
|
||||
func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
|
@ -65,7 +65,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-NEXT: opcode_index: 1,
|
||||
// CHECK-NEXT: inputs: [ 2, 1 ],
|
||||
// CHECK-NEXT: outputs: [ 3 ],
|
||||
// CHECK-NEXT: custom_options: [ 105, 110, 116, 95, 97, 116, 116, 114, 0, 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 2, 33, 43, 2, 1, 2, 11, 2, 20, 4, 4, 36, 1 ]
|
||||
// CHECK-NEXT: custom_options: [ 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 105, 110, 116, 95, 97, 116, 116, 114, 0, 2, 42, 11, 2, 1, 2, 20, 2, 20, 4, 4, 36, 1 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: opcode_index: 2,
|
||||
// CHECK-NEXT: inputs: [ 3 ],
|
||||
|
@ -1,6 +1,6 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
|
||||
func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> {
|
||||
func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
@ -72,21 +72,21 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: buffer: 10,
|
||||
// CHECK-NEXT: name: "arg9",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: buffer: 11,
|
||||
// CHECK-NEXT: name: "arg10",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: buffer: 12,
|
||||
// CHECK-NEXT: name: "arg11",
|
||||
// CHECK-NEXT: quantization: {
|
||||
@ -100,21 +100,21 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1, 4 ],
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: buffer: 14,
|
||||
// CHECK-NEXT: name: "arg13",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1, 4 ],
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: buffer: 15,
|
||||
// CHECK-NEXT: name: "arg14",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1, 4 ],
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: buffer: 16,
|
||||
// CHECK-NEXT: name: "arg15",
|
||||
// CHECK-NEXT: quantization: {
|
||||
@ -128,7 +128,7 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 1, 4 ],
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: buffer: 18,
|
||||
// CHECK-NEXT: name: "arg17",
|
||||
// CHECK-NEXT: quantization: {
|
||||
@ -261,9 +261,9 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t
|
||||
// CHECK-EMPTY:
|
||||
|
||||
|
||||
^bb0(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>):
|
||||
^bb0(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>):
|
||||
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
|
||||
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
|
||||
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
||||
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
||||
return %24 : tensor<1x4xf32>
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
|
||||
|
||||
func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> {
|
||||
func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> {
|
||||
// CHECK: {
|
||||
// CHECK-NEXT: version: 3,
|
||||
// CHECK-NEXT: operator_codes: [ {
|
||||
@ -9,63 +9,63 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: subgraphs: [ {
|
||||
// CHECK-NEXT: tensors: [ {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: buffer: 1,
|
||||
// CHECK-NEXT: name: "arg0",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: buffer: 2,
|
||||
// CHECK-NEXT: name: "arg1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: buffer: 3,
|
||||
// CHECK-NEXT: name: "arg2",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: buffer: 4,
|
||||
// CHECK-NEXT: name: "arg3",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: buffer: 5,
|
||||
// CHECK-NEXT: name: "arg4",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: buffer: 6,
|
||||
// CHECK-NEXT: name: "arg5",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: buffer: 7,
|
||||
// CHECK-NEXT: name: "arg6",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: buffer: 8,
|
||||
// CHECK-NEXT: name: "arg7",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: buffer: 9,
|
||||
// CHECK-NEXT: name: "arg8",
|
||||
// CHECK-NEXT: quantization: {
|
||||
@ -121,63 +121,63 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: buffer: 17,
|
||||
// CHECK-NEXT: name: "arg16",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: buffer: 18,
|
||||
// CHECK-NEXT: name: "arg17",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: buffer: 19,
|
||||
// CHECK-NEXT: name: "arg18",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: buffer: 20,
|
||||
// CHECK-NEXT: name: "arg19",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: buffer: 21,
|
||||
// CHECK-NEXT: name: "arg20",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: buffer: 22,
|
||||
// CHECK-NEXT: name: "arg21",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: name: "Const",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: },
|
||||
// CHECK-NEXT: is_variable: true
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: name: "Const1",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: },
|
||||
// CHECK-NEXT: is_variable: true
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: buffer: 25,
|
||||
// CHECK-NEXT: name: "tfl.unidirectional_sequence_lstm",
|
||||
// CHECK-NEXT: quantization: {
|
||||
@ -244,9 +244,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
@ -259,9 +259,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-EMPTY:
|
||||
|
||||
^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>):
|
||||
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
%1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
%2 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %0, %1, %arg18, %arg19, %arg20, %arg21) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %2 : tensor<4xf32>
|
||||
^bb0(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4x4xf32>, %arg18: tensor<4x4xf32>, %arg19: tensor<4x4xf32>, %arg20: tensor<4x4xf32>, %arg21: tensor<4x4xf32>):
|
||||
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
|
||||
%1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
|
||||
%2 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %0, %1, %arg18, %arg19, %arg20, %arg21) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
return %2 : tensor<4x4xf32>
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: shape: [ 4 ],
|
||||
// CHECK-NEXT: shape: [ 4, 4 ],
|
||||
// CHECK-NEXT: name: "Const",
|
||||
// CHECK-NEXT: quantization: {
|
||||
// CHECK-EMPTY:
|
||||
@ -76,7 +76,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
@ -90,7 +90,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -
|
||||
// CHECK-EMPTY:
|
||||
|
||||
^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>):
|
||||
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
%1 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %0) {fused_activation_function = "TANH", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
|
||||
%1 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %0) {fused_activation_function = "TANH", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>) -> tensor<4xf32>
|
||||
return %1 : tensor<4xf32>
|
||||
}
|
||||
|
@ -190,9 +190,9 @@ func @testSquare(tensor<? x f32>) -> tensor<? x f32> {
|
||||
return %0 : tensor<? x f32>
|
||||
}
|
||||
|
||||
func @testQuantizedResizeNearestNeighbor(tensor<? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> {
|
||||
^bb0(%arg0: tensor<? x !quant.uniform<u8:f32, 0.1>>, %arg1: tensor<? x i32>):
|
||||
%0 = "tfl.resize_nearest_neighbor"(%arg0, %arg1) { align_corners = false, half_pixel_centers = false } : (tensor<? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>>
|
||||
func @testQuantizedResizeNearestNeighbor(tensor<? x ? x ? x ? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> {
|
||||
^bb0(%arg0: tensor<? x ? x ? x ? x !quant.uniform<u8:f32, 0.1>>, %arg1: tensor<? x i32>):
|
||||
%0 = "tfl.resize_nearest_neighbor"(%arg0, %arg1) { align_corners = false, half_pixel_centers = false } : (tensor<? x ? x ? x ? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>>
|
||||
return %0 : tensor<? x !quant.uniform<u8:f32, 0.1>>
|
||||
}
|
||||
|
||||
@ -581,36 +581,36 @@ func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testUnidirectionalSequenceRnn
|
||||
func @testUnidirectionalSequenceRnn(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
func @testUnidirectionalSequenceRnn(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x ? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testUnidirectionalSequenceLstmWithoutProjection
|
||||
func @testUnidirectionalSequenceLstmWithoutProjection(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: none, %arg17: none, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
func @testUnidirectionalSequenceLstmWithoutProjection(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: none, %arg17: none, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testUnidirectionalSequenceLstm
|
||||
func @testUnidirectionalSequenceLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
func @testUnidirectionalSequenceLstm(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x ? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr
|
||||
func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x ? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x ? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
@ -663,10 +663,10 @@ func @testLstmQuantizedType(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.0372480
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testLstm
|
||||
func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
|
||||
// CHECK-NEXT: {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
@ -689,10 +689,10 @@ func @testQuantizedBasicLstm(%arg0: tensor<1x384x!quant.uniform<u8:f32, 7.812500
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: testLstmWithNoneTypeAndOverrideAttr
|
||||
func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
|
||||
// CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
@ -707,11 +707,11 @@ func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>
|
||||
|
||||
// -----
|
||||
|
||||
// test invalid input dimension, the first input operand for lstm op should be at least 2D tensor.
|
||||
// test invalid input dimension, the third input operand for lstm op should be 2-D tensor.
|
||||
func @testLstmWithInvalidInputDimension(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>) -> tensor<4 x f32> {
|
||||
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
// expected-error @+1 {{'tfl.lstm' op the first input operand should have more than 2 dimensions.}}
|
||||
// expected-error @+1 {{'tfl.lstm' op failed to verify that operand 2 is 2-D}}
|
||||
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %24 : tensor<4xf32>
|
||||
|
||||
@ -720,22 +720,22 @@ func @testLstmWithInvalidInputDimension(%arg0: tensor<4 x f32>, %arg1: tensor<4
|
||||
// -----
|
||||
|
||||
// 'input_to_output_weights' input for lstm op has unmatched rank with `input`.
|
||||
func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x2xf32>, %arg2: tensor<4x2xf32>, %arg3: tensor<4x2xf32>, %arg4: tensor<4x2xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
|
||||
func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x2xf32>, %arg2: tensor<4x2xf32>, %arg3: tensor<4x2xf32>, %arg4: tensor<4x2xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
|
||||
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
|
||||
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
|
||||
// expected-error @+1 {{'tfl.lstm' op inputs don't match with the dimensions.}}
|
||||
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
||||
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
|
||||
return %24 : tensor<1x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Coefficient inputs of LSTM op don't match the dimension with input operand `input_to_output_weights`.
|
||||
func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<3xf32>, %arg19: tensor<3xf32>, %arg20: tensor<3xf32>, %arg21: tensor<3xf32>) -> tensor<1x4xf32> {
|
||||
func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<3xf32>, %arg19: tensor<3xf32>, %arg20: tensor<3xf32>, %arg21: tensor<3xf32>) -> tensor<1x4xf32> {
|
||||
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
|
||||
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
|
||||
// expected-error @+1 {{'tfl.lstm' op coefficient inputs have more than 2 dimensions or don't match the dimension with input operand `input_to_output_weights`.}}
|
||||
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x4xf32>
|
||||
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x4xf32>
|
||||
return %24 : tensor<1x4xf32>
|
||||
}
|
||||
|
||||
@ -1201,7 +1201,7 @@ func @testResizeBilinear(%arg0 : tensor<1x100x100x3xf32>, %arg1 : tensor<4xi32>)
|
||||
// -----
|
||||
|
||||
func @testResizeBilinearInvalidOutputType(%arg0 : tensor<1x100x100x3xf32>, %arg1 : tensor<4xi32>) -> tensor<?xi32> {
|
||||
// expected-error @+1 {{'tfl.resize_bilinear' op result #0 must be tensor of 32-bit float or QI8 type or QUI8 type values}}
|
||||
// expected-error @+1 {{'tfl.resize_bilinear' op failed to verify that input and output must have same element type}}
|
||||
%0 = "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
@ -1499,8 +1499,8 @@ func @testWrongQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x192x!q
|
||||
|
||||
// CHECK-LABEL: testSvdf
|
||||
func @testSvdf(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>) -> tensor<? x f32> {
|
||||
// CHECK: "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", rank = 2 : i32} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", rank = 2 : i32} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
// CHECK: "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "RELU", rank = 2 : i32} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "RELU", rank = 2 : i32} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -439,6 +439,19 @@ func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor<
|
||||
// CHECK: return %[[rs2]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @NotReorderReshapeAddIfHighDim
|
||||
func @NotReorderReshapeAddIfHighDim(%arg0: tensor<1x1x1x1x30x96xf32>) -> tensor<1x30x96xf32> {
|
||||
%cst = constant dense<2.0> : tensor<f32>
|
||||
%shape = constant dense<[1, 30, 96]> : tensor<3xi32>
|
||||
%1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x1x1x1x30x96xf32>, tensor<3xi32>) -> tensor<1x30x96xf32>
|
||||
%2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x30x96xf32>, tensor<f32>) -> tensor<1x30x96xf32>
|
||||
return %2 : tensor<1x30x96xf32>
|
||||
|
||||
// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0
|
||||
// CHECK: %[[rs2:.*]] = "tfl.add"(%[[rs1]]
|
||||
// CHECK: return %[[rs2]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @ReorderElementwiseValueOpAndMoveOp
|
||||
func @ReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> {
|
||||
%shape = constant dense<[40, 40]> : tensor<2xi32>
|
||||
|
@ -19,6 +19,16 @@ func @RemoveUnused(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> (tensor<2xf32>,t
|
||||
// CHECK-NEXT: return %[[split]]#0, %[[split]]#1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: RemoveTrival
|
||||
func @RemoveTrival(%arg0: tensor<384x512x!quant.uniform<i8:f32, 1.0:-128>>, %arg1: tensor<128x512x!quant.uniform<i8<-127:127>:f32, 1.0>>, %arg2: none) -> tensor<384x128x!quant.uniform<i8:f32, 2.0>> {
|
||||
%1 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<384x512x!quant.uniform<i8:f32, 1.0:-128>>, tensor<128x512x!quant.uniform<i8<-127:127>:f32, 1.0>>, none) -> tensor<384x128x!quant.uniform<i8:f32, 1.0>>
|
||||
%2 = "tfl.quantize"(%1) {qtype = tensor<384x128x!quant.uniform<i8:f32, 2.0>>} : (tensor<384x128x!quant.uniform<i8:f32, 1.0>>) -> tensor<384x128x!quant.uniform<i8:f32, 2.0>>
|
||||
return %2 : tensor<384x128x!quant.uniform<i8:f32, 2.0>>
|
||||
|
||||
// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"{{.*}} -> tensor<384x128x!quant.uniform<i8:f32, 2.000000e+00>>
|
||||
// CHECK-NEXT: return %[[fc]]
|
||||
}
|
||||
|
||||
func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
|
||||
%cst = constant dense<[1, 1001]> : tensor<2xi32>
|
||||
%0 = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
|
||||
|
@ -1,27 +1,27 @@
|
||||
// RUN: tf-opt -tfl-split-merged-operands %s | FileCheck %s
|
||||
|
||||
func @testSingleLstm(%arg0: tensor<4 x f32>) -> tensor<4xf32> {
|
||||
func @testSingleLstm(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> {
|
||||
// CHECK-LABEL: testSingleLstm
|
||||
// CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
|
||||
// CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
|
||||
// CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
|
||||
// CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
|
||||
// CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
|
||||
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %1 : tensor<4xf32>
|
||||
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
|
||||
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
return %1 : tensor<4x4xf32>
|
||||
}
|
||||
|
||||
func @testMultipleLstms(%arg0: tensor<4 x f32>) -> tensor<4xf32> {
|
||||
func @testMultipleLstms(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> {
|
||||
// CHECK-LABEL: testMultipleLstms
|
||||
// CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
|
||||
// CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
|
||||
// CHECK: %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK: %[[CST_2:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
|
||||
// CHECK: %[[CST_3:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32>
|
||||
// CHECK: %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_2]], %[[CST_3]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
|
||||
// CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
|
||||
// CHECK: %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
// CHECK: %[[CST_2:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
|
||||
// CHECK: %[[CST_3:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
|
||||
// CHECK: %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_2]], %[[CST_3]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
|
||||
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
|
||||
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%2 = "tfl.unidirectional_sequence_lstm"(%1, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %2 : tensor<4xf32>
|
||||
%0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const")
|
||||
%1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
%2 = "tfl.unidirectional_sequence_lstm"(%1, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
return %2 : tensor<4x4xf32>
|
||||
}
|
||||
|
@ -162,6 +162,10 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul));
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
if (pass_config.shape_inference) {
|
||||
// Add a shape inference pass to optimize away the unnecessary casts.
|
||||
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||
}
|
||||
pass_manager->addPass(
|
||||
mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
|
||||
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
|
||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
@ -202,6 +203,26 @@ LogicalResult ConvertTFConcatOp::matchAndRewrite(
|
||||
return success();
|
||||
}
|
||||
|
||||
// Converts any IntegerAttr to an IntegerAttr of an i32 type.
|
||||
// The value won't change in the new attribute, but if the value is out of
|
||||
// the bound of i32, the function returns a failure.
|
||||
LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
|
||||
if (attr.getType().isInteger(/*width=*/32)) {
|
||||
*attr_i32 = attr;
|
||||
return success();
|
||||
}
|
||||
|
||||
int64_t value = attr.getInt();
|
||||
if (value > std::numeric_limits<int>::max() ||
|
||||
value < std::numeric_limits<int>::min()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
*attr_i32 = IntegerAttr::get(
|
||||
IntegerType::get(/*width=*/32, attr.getContext()), value);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_concat_op = cast<TF::ConcatV2Op>(op);
|
||||
@ -211,12 +232,16 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
|
||||
// Extract axis attribute from constant axis tensor
|
||||
ElementsAttr axis;
|
||||
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure();
|
||||
IntegerAttr axis_int = ExtractSingleElementAsInteger(axis);
|
||||
|
||||
// "axis" operand could be a i64 tensor. Resolve it here.
|
||||
IntegerAttr axis_i32;
|
||||
if (failed(ConvertToI32Attr(axis_int, &axis_i32))) return failure();
|
||||
|
||||
StringAttr fused_activation_function =
|
||||
StringAttr::get("NONE", rewriter.getContext());
|
||||
rewriter.replaceOpWithNewOp<ConcatenationOp>(
|
||||
op, output_type, values, ExtractSingleElementAsInteger(axis),
|
||||
fused_activation_function);
|
||||
op, output_type, values, axis_i32, fused_activation_function);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -838,7 +838,8 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
||||
// TensorFlow operations that doesn't have operands and results of type
|
||||
// variant are legal. Here, we don't distinguish between variants encoding
|
||||
// TensorList or some other type as that information is not available here.
|
||||
// This constraint should be relaxed to support other variant types in TFLite.
|
||||
// Partial legalization is used below to still allow ops with variant types
|
||||
// still.
|
||||
auto is_legal = [](Operation *op) {
|
||||
auto is_not_variant = [](Type ty) {
|
||||
return !ty.cast<ShapedType>().getElementType().isa<TF::VariantType>();
|
||||
@ -873,7 +874,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
||||
ConvertTensorListPushBack, ConvertTensorListReserve,
|
||||
ConvertTensorListSetItem, ConvertTensorListStack,
|
||||
ConvertTensorListResize, ConvertWhile>(context);
|
||||
return applyFullConversion(func, target, patterns);
|
||||
return applyPartialConversion(func, target, patterns);
|
||||
}
|
||||
|
||||
void LowerStaticTensorListPass::runOnOperation() {
|
||||
|
@ -29,6 +29,10 @@ def ExtractSingleElementAsFloat : NativeCodeCall<
|
||||
// Checks if the value has only one user.
|
||||
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
|
||||
|
||||
// Checks if the value has rank at most 'n'.
|
||||
class HasRankAtMost<int n> : Constraint<
|
||||
CPred<"$0.getType().cast<ShapedType>().getRank() <= " # n>>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ternary ops patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -347,7 +351,9 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
|
||||
// The result of the new "BinaryOp" will have the same shape as
|
||||
// `input`. In other words, the shape of the `Reshape` op are not
|
||||
// changed after the transformation.
|
||||
(IsTailOfShape $rhs, $input)]>;
|
||||
(IsTailOfShape $rhs, $input),
|
||||
(HasRankAtMost<5> $input),
|
||||
(HasRankAtMost<5> $rhs)]>;
|
||||
}
|
||||
|
||||
foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
|
||||
|
@ -125,6 +125,7 @@ void PostQuantizePass::runOnFunction() {
|
||||
auto func = getFunction();
|
||||
auto* ctx = func.getContext();
|
||||
TFL::populateWithGenerated(ctx, &patterns);
|
||||
patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
|
||||
if (!emit_quant_adaptor_ops_) {
|
||||
|
@ -41,15 +41,22 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
// The cmd line flag to turn on/off Tf.Text API fusion.
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::opt<bool> fuse_tftext(
|
||||
"tfl-fuse-tftext", llvm::cl::value_desc("bool"),
|
||||
llvm::cl::desc("Fuse TF.Text API ops when it's true"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
namespace {
|
||||
|
||||
constexpr char kTFAPIImplements[] = "tf.api_implements";
|
||||
constexpr char kTfTextAPIPRefix[] = "tftext:";
|
||||
|
||||
// Abstracts the conversion of the embedded lookup composite function.
|
||||
class ConvertEmbeddedLookupFunc {
|
||||
@ -187,6 +194,10 @@ void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
|
||||
OpBuilder builder(func.getBody());
|
||||
if (failed(ConvertKerasLSTMLayer(func, &builder)))
|
||||
return signalPassFailure();
|
||||
} else if (fuse_tftext && attr.getValue().startswith(kTfTextAPIPRefix)) {
|
||||
if (failed(ConvertTFTextAPI(func, attr.getValue()))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -70,6 +70,7 @@ class PrepareQuantizePass
|
||||
: public PassWrapper<PrepareQuantizePass, FunctionPass> {
|
||||
public:
|
||||
// Constructor used by the PassRegistration and enforce uint8 quantization.
|
||||
// This is only used by test.
|
||||
explicit PrepareQuantizePass() {
|
||||
if (quantize_signed)
|
||||
quant_specs_.inference_type = tensorflow::DT_QINT8;
|
||||
@ -257,15 +258,16 @@ void PrepareQuantizePass::runOnFunction() {
|
||||
// convert all of them to signed.
|
||||
OwningRewritePatternList patterns;
|
||||
bool is_signed = quant_specs_.IsSignedInferenceType();
|
||||
int bit_width = quant_specs_.GetQuantizationTypeWidth();
|
||||
if (is_signed) {
|
||||
patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
|
||||
// Convert quant stats to int8 quantization parameters.
|
||||
// Currently, only activation stats are imported, so narrow_range = false.
|
||||
patterns.insert<PrepareQuantStats>(8, false, true, ctx);
|
||||
patterns.insert<PrepareQuantStats>(bit_width, false, true, ctx);
|
||||
} else {
|
||||
// Convert quant stats to uint8 quantization parameters.
|
||||
// Currently, only activation stats are imported, so narrow_range = false.
|
||||
patterns.insert<PrepareQuantStats>(8, false, false, ctx);
|
||||
patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
|
||||
}
|
||||
applyPatternsAndFoldGreedily(func, patterns);
|
||||
|
||||
|
127
tensorflow/compiler/mlir/lite/utils/tftext_utils.cc
Normal file
127
tensorflow/compiler/mlir/lite/utils/tftext_utils.cc
Normal file
@ -0,0 +1,127 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/None.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Identifier.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/OpDefinition.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer";
|
||||
constexpr char kTFAPIImplements[] = "tf.api_implements";
|
||||
|
||||
inline OpaqueElementsAttr emptyCustomOption(OpBuilder* builder) {
|
||||
std::string content = "";
|
||||
ShapedType type = RankedTensorType::get(
|
||||
{static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
|
||||
return OpaqueElementsAttr::get(
|
||||
builder->getContext()->getRegisteredDialect("tfl"), type, content);
|
||||
}
|
||||
|
||||
inline RankedTensorType getInputType(mlir::FuncOp func, int idx) {
|
||||
return func.getType()
|
||||
.getInput(idx)
|
||||
.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
}
|
||||
|
||||
inline RankedTensorType getResultType(mlir::FuncOp func, int idx) {
|
||||
return func.getType()
|
||||
.getResult(idx)
|
||||
.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
}
|
||||
|
||||
LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) {
|
||||
if (func.getNumResults() != 2) {
|
||||
return failure();
|
||||
}
|
||||
if (func.getNumArguments() != 1) {
|
||||
return failure();
|
||||
}
|
||||
auto input_type = getInputType(func, 0);
|
||||
if (!input_type || input_type.getRank() != 1 ||
|
||||
!input_type.getElementType().isa<mlir::TF::StringType>()) {
|
||||
return failure();
|
||||
}
|
||||
auto value_type = getResultType(func, 0);
|
||||
if (!value_type || value_type.getRank() != 1 ||
|
||||
!value_type.getElementType().isa<mlir::TF::StringType>()) {
|
||||
return failure();
|
||||
}
|
||||
auto offset_type = getResultType(func, 1);
|
||||
if (offset_type.getRank() != 1 ||
|
||||
!offset_type.getElementType().isInteger(64)) {
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertWhitespaceTokenizer(mlir::FuncOp func,
|
||||
llvm::StringRef api) {
|
||||
func.eraseBody();
|
||||
func.addEntryBlock();
|
||||
func.setAttr(kTFAPIImplements, StringAttr::get(api, func.getContext()));
|
||||
|
||||
Value text = func.getArgument(0);
|
||||
auto output_type = func.getType().getResult(0);
|
||||
auto offset_type = func.getType().getResult(1);
|
||||
SmallVector<Type, 2> shape = {output_type, offset_type};
|
||||
ArrayRef<Type> output_types(shape);
|
||||
|
||||
OpBuilder builder(func.getBody());
|
||||
|
||||
auto op = builder.create<mlir::TFL::CustomOp>(func.getLoc(), output_types,
|
||||
ValueRange(text), api,
|
||||
emptyCustomOption(&builder));
|
||||
|
||||
builder.create<mlir::ReturnOp>(func.getLoc(), op.getResults());
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api) {
|
||||
if (api.str() == kWhitespaceTokenizer) {
|
||||
if (succeeded(VerifyWhitespaceTokenizer(func))) {
|
||||
return ConvertWhitespaceTokenizer(func, api);
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
39
tensorflow/compiler/mlir/lite/utils/tftext_utils.h
Normal file
39
tensorflow/compiler/mlir/lite/utils/tftext_utils.h
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This header file defines common utils used by TFLite transformation
|
||||
// passes to work with op attributes.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api);
|
||||
|
||||
} // end namespace TFL
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
@ -93,9 +94,10 @@ MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() {
|
||||
static void RegisterDialects() {
|
||||
static bool init_once = []() {
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
mlir::registerDialect<mlir::shape::ShapeDialect>();
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
return true;
|
||||
}();
|
||||
(void)init_once;
|
||||
|
@ -224,7 +224,10 @@ cc_library(
|
||||
hdrs = [
|
||||
"ir/tf_attributes.h",
|
||||
],
|
||||
deps = ["@llvm-project//mlir:IR"],
|
||||
deps = [
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -427,6 +430,7 @@ cc_library(
|
||||
"transforms/parallel_execute_to_islands.cc",
|
||||
"transforms/promote_resources_to_args.cc",
|
||||
"transforms/raise_control_flow.cc",
|
||||
"transforms/readonly_references_to_resources.cc",
|
||||
"transforms/replicate_invariant_op_hoisting.cc",
|
||||
"transforms/replicate_to_island.cc",
|
||||
"transforms/resource_device_inference.cc",
|
||||
@ -555,8 +559,7 @@ cc_library(
|
||||
srcs = ["ir/dialect_registration.cc"],
|
||||
deps = [
|
||||
":tensorflow",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:SCFTransforms",
|
||||
"@llvm-project//mlir:Shape",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -785,6 +788,9 @@ cc_library(
|
||||
name = "convert_type",
|
||||
srcs = ["utils/convert_type.cc"],
|
||||
hdrs = ["utils/convert_type.h"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
":tensorflow_types",
|
||||
"//tensorflow/core:framework",
|
||||
@ -1140,6 +1146,7 @@ COMPILE_MLIR_UTIL_DEPS = [
|
||||
"//tensorflow/compiler/mlir/xla:type_to_shape",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
|
||||
"//tensorflow/compiler/mlir/xla:xla_sink_constants_to_control_flow",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/core:framework",
|
||||
@ -1278,6 +1285,7 @@ cc_library(
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1292,6 +1300,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
||||
|
55
tensorflow/compiler/mlir/tensorflow/c/BUILD
Normal file
55
tensorflow/compiler/mlir/tensorflow/c/BUILD
Normal file
@ -0,0 +1,55 @@
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_copts",
|
||||
"tf_cuda_library",
|
||||
"tfe_xla_copts",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [":friends"],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "friends",
|
||||
packages = ["//tensorflow/..."],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "mlir_c_api",
|
||||
srcs = [
|
||||
"c_api_unified_experimental_mlir.cc",
|
||||
],
|
||||
copts = tf_copts() + tfe_xla_copts(),
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_type",
|
||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:casts",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mlir_c_api_registration",
|
||||
srcs = ["c_api_unified_experimental_mlir_registration.cc"],
|
||||
deps = [
|
||||
":mlir_c_api",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
@ -0,0 +1,493 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
using tensorflow::internal::AbstractFunction;
|
||||
using tensorflow::internal::AbstractOp;
|
||||
using tensorflow::internal::AbstractTensor;
|
||||
using tensorflow::internal::dyncast;
|
||||
using tensorflow::internal::ExecutionContext;
|
||||
using tensorflow::internal::OutputList;
|
||||
|
||||
namespace {
|
||||
|
||||
static void RegisterDialects() {
|
||||
static bool init_once = []() {
|
||||
mlir::registerDialect<mlir::StandardOpsDialect>();
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
return true;
|
||||
}();
|
||||
(void)init_once;
|
||||
}
|
||||
|
||||
Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder,
|
||||
Type* type) {
|
||||
Status s = tensorflow::ConvertDataType(dtype, builder, type);
|
||||
if (s.ok()) *type = UnrankedTensorType::get(*type);
|
||||
return s;
|
||||
}
|
||||
|
||||
class MlirTensor : public AbstractTensor {
|
||||
public:
|
||||
explicit MlirTensor(Value value) : AbstractTensor(kKind), value_(value) {}
|
||||
|
||||
Value getValue() { return value_; }
|
||||
|
||||
static constexpr AbstractTensorKind kKind = kMlirTensor;
|
||||
|
||||
private:
|
||||
Value value_;
|
||||
};
|
||||
|
||||
class MlirAbstractOp : public AbstractOp {
|
||||
public:
|
||||
explicit MlirAbstractOp(MLIRContext* context)
|
||||
: AbstractOp(kKind), context_(context) {}
|
||||
|
||||
void SetOpType(const char* op_type, TF_Status* s) override;
|
||||
|
||||
void SetAttrType(const char* attr_name, TF_DataType dtype,
|
||||
TF_Status* s) override;
|
||||
|
||||
void SetOpName(const char* const op_name, TF_Status* s) override;
|
||||
|
||||
MLIRContext* GetContext() { return context_; }
|
||||
|
||||
Type AddRef(Type type, TF_Status* s);
|
||||
|
||||
OperationState* Create(ArrayRef<Value> operands, TF_Status* s);
|
||||
|
||||
static constexpr AbstractOpKind kKind = kMlirOp;
|
||||
|
||||
private:
|
||||
MLIRContext* context_;
|
||||
llvm::StringMap<Attribute> attrs_;
|
||||
std::unique_ptr<OperationState> state_;
|
||||
const char* op_name_ = nullptr;
|
||||
};
|
||||
|
||||
// MlirFunction is a thin wrapper over a FuncOp.
|
||||
class MlirFunction : public AbstractFunction {
|
||||
public:
|
||||
explicit MlirFunction(std::unique_ptr<MLIRContext> context,
|
||||
OwningModuleRef module, FuncOp func)
|
||||
: AbstractFunction(kKind),
|
||||
context_(std::move(context)),
|
||||
module_(std::move(module)),
|
||||
func_(func) {}
|
||||
|
||||
TF_Function* GetTfFunction(TF_Status* s) override;
|
||||
|
||||
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
||||
|
||||
private:
|
||||
std::unique_ptr<MLIRContext> context_;
|
||||
OwningModuleRef module_;
|
||||
FuncOp func_;
|
||||
};
|
||||
|
||||
class MlirFunctionContext : public ExecutionContext {
|
||||
public:
|
||||
explicit MlirFunctionContext(const char* name)
|
||||
: ExecutionContext(kKind),
|
||||
context_(std::make_unique<MLIRContext>()),
|
||||
builder_(context_.get()) {
|
||||
// TODO(aminim) figure out the location story here
|
||||
module_ = ModuleOp::create(builder_.getUnknownLoc());
|
||||
func_ = FuncOp::create(builder_.getUnknownLoc(), name,
|
||||
builder_.getFunctionType(llvm::None, llvm::None));
|
||||
module_->push_back(func_);
|
||||
builder_ = OpBuilder::atBlockBegin(func_.addEntryBlock());
|
||||
}
|
||||
|
||||
AbstractOp* CreateOperation() override {
|
||||
return new MlirAbstractOp(context_.get());
|
||||
}
|
||||
|
||||
void ExecuteOperation(AbstractOp* abstract_op, int num_inputs,
|
||||
AbstractTensor* const* inputs, OutputList* o,
|
||||
TF_Status* s) override;
|
||||
|
||||
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override;
|
||||
|
||||
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override;
|
||||
|
||||
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
|
||||
s->status = tensorflow::errors::Unimplemented(
|
||||
"Registering graph functions has not been implemented yet.");
|
||||
}
|
||||
|
||||
static constexpr ExecutionContextKind kKind = kMlirContext;
|
||||
|
||||
private:
|
||||
std::unique_ptr<MLIRContext> context_;
|
||||
OpBuilder builder_;
|
||||
FuncOp func_;
|
||||
OwningModuleRef module_;
|
||||
};
|
||||
|
||||
void MlirAbstractOp::SetOpType(const char* op_type, TF_Status* s) {
|
||||
if (state_) {
|
||||
s->status = tensorflow::errors::FailedPrecondition(
|
||||
"SetOpType called on already built op.");
|
||||
return;
|
||||
}
|
||||
std::string name = "tf.";
|
||||
name += op_type;
|
||||
// TODO(aminim) figure out the location story here
|
||||
state_ = std::make_unique<OperationState>(UnknownLoc::get(context_), name);
|
||||
}
|
||||
|
||||
void MlirAbstractOp::SetAttrType(const char* attr_name, TF_DataType dtype,
|
||||
TF_Status* s) {
|
||||
if (!state_) {
|
||||
s->status = tensorflow::errors::FailedPrecondition(
|
||||
"op_type must be specified before specifying attrs.");
|
||||
return;
|
||||
}
|
||||
Type mlir_type;
|
||||
Builder builder(context_);
|
||||
s->status = ConvertDataTypeToTensor(static_cast<tensorflow::DataType>(dtype),
|
||||
builder, &mlir_type);
|
||||
if (!s->status.ok()) return;
|
||||
attrs_[attr_name] = TypeAttr::get(mlir_type);
|
||||
}
|
||||
|
||||
void MlirAbstractOp::SetOpName(const char* const op_name, TF_Status* s) {
|
||||
// TODO(aminim): should we use a location?
|
||||
if (op_name_) {
|
||||
s->status = tensorflow::errors::FailedPrecondition(
|
||||
"SetOpName called on already built op.");
|
||||
return;
|
||||
}
|
||||
op_name_ = op_name;
|
||||
}
|
||||
|
||||
Type MlirAbstractOp::AddRef(Type type, TF_Status* s) {
|
||||
Type elt_type = getElementTypeOrSelf(type);
|
||||
if (elt_type.isa<mlir::TF::TensorFlowRefType>()) {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Requested reference to a reference type");
|
||||
return nullptr;
|
||||
}
|
||||
elt_type = TensorFlowRefType::get(elt_type);
|
||||
if (RankedTensorType tensor_type = type.dyn_cast<RankedTensorType>()) {
|
||||
return RankedTensorType::get(tensor_type.getShape(), elt_type);
|
||||
}
|
||||
return UnrankedTensorType::get(elt_type);
|
||||
}
|
||||
|
||||
OperationState* MlirAbstractOp::Create(ArrayRef<Value> operands, TF_Status* s) {
|
||||
state_->operands = llvm::to_vector<4>(operands);
|
||||
const tensorflow::OpDef* op_def;
|
||||
auto node_name = state_->name.getStringRef().drop_front(
|
||||
TensorFlowDialect::getDialectNamespace().size() + 1);
|
||||
s->status =
|
||||
tensorflow::OpRegistry::Global()->LookUpOpDef(node_name.str(), &op_def);
|
||||
if (!s->status.ok()) return nullptr;
|
||||
Builder builder(context_);
|
||||
// Process operands according to the op_def and infer derived attributes.
|
||||
int current_operand = 0;
|
||||
for (const tensorflow::OpDef::ArgDef& input_arg : op_def->input_arg()) {
|
||||
if (!input_arg.number_attr().empty()) {
|
||||
// TODO(b/156122856): we don't support variadic operands.
|
||||
s->status = tensorflow::errors::Unimplemented(
|
||||
"Unsupported 'number_attr' for '", input_arg.number_attr(), "'");
|
||||
return nullptr;
|
||||
} else if (!input_arg.type_list_attr().empty()) {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Unsupported 'type_list_attr' for '", input_arg.number_attr(), "'");
|
||||
return nullptr;
|
||||
}
|
||||
if (current_operand >= operands.size()) {
|
||||
s->status = tensorflow::errors::InvalidArgument("Missing operand for '",
|
||||
input_arg.name(), "'");
|
||||
return nullptr;
|
||||
}
|
||||
Type expected_type;
|
||||
if (input_arg.type() != tensorflow::DT_INVALID) {
|
||||
s->status =
|
||||
ConvertDataTypeToTensor(input_arg.type(), builder, &expected_type);
|
||||
if (!s->status.ok()) return nullptr;
|
||||
if (input_arg.is_ref()) expected_type = AddRef(expected_type, s);
|
||||
if (!s->status.ok()) return nullptr;
|
||||
} else {
|
||||
expected_type = operands[current_operand].getType();
|
||||
}
|
||||
if (!input_arg.type_attr().empty()) {
|
||||
attrs_[input_arg.type_attr()] = TypeAttr::get(expected_type);
|
||||
}
|
||||
++current_operand;
|
||||
}
|
||||
|
||||
for (const tensorflow::OpDef::ArgDef& output_arg : op_def->output_arg()) {
|
||||
int original_size = state_->types.size();
|
||||
if (!output_arg.number_attr().empty()) {
|
||||
// Same type repeated "repeats" times.
|
||||
Attribute repeats_attr = attrs_[output_arg.number_attr()];
|
||||
if (!repeats_attr) {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Missing attribute '", output_arg.number_attr(),
|
||||
"' required for output list '", output_arg.name(), "'");
|
||||
return nullptr;
|
||||
}
|
||||
if (!repeats_attr.isa<IntegerAttr>()) {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Attribute '", output_arg.number_attr(),
|
||||
"' required for output list '", output_arg.name(),
|
||||
"' isn't an integer");
|
||||
return nullptr;
|
||||
}
|
||||
int64_t repeats = repeats_attr.cast<IntegerAttr>().getInt();
|
||||
|
||||
if (!output_arg.type_attr().empty()) {
|
||||
// Same type repeated "repeats" times.
|
||||
Attribute attr = attrs_[output_arg.type_attr()];
|
||||
if (!attr) {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Missing attribute '", output_arg.type_attr(),
|
||||
"' required for output '", output_arg.name(), "'");
|
||||
return nullptr;
|
||||
}
|
||||
TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
|
||||
if (!type_attr) {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Attribute '", output_arg.type_attr(), "' required for output '",
|
||||
output_arg.name(), "' isn't a type attribute");
|
||||
return nullptr;
|
||||
}
|
||||
for (int i = 0; i < repeats; ++i)
|
||||
state_->types.push_back(type_attr.getType());
|
||||
} else if (output_arg.type() != tensorflow::DT_INVALID) {
|
||||
for (int i = 0; i < repeats; ++i) {
|
||||
Type type;
|
||||
s->status =
|
||||
ConvertDataTypeToTensor(output_arg.type(), builder, &type);
|
||||
if (!s->status.ok()) return nullptr;
|
||||
state_->types.push_back(type);
|
||||
}
|
||||
} else {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Missing type or type_attr field in ",
|
||||
output_arg.ShortDebugString());
|
||||
return nullptr;
|
||||
}
|
||||
} else if (!output_arg.type_attr().empty()) {
|
||||
Attribute attr = attrs_[output_arg.type_attr()];
|
||||
if (!attr) {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Missing attribute '", output_arg.type_attr(),
|
||||
"' required for output '", output_arg.name(), "'");
|
||||
return nullptr;
|
||||
}
|
||||
TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
|
||||
if (!type_attr) {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Attribute '", output_arg.type_attr(), "' required for output '",
|
||||
output_arg.name(), "' isn't a type attribute");
|
||||
return nullptr;
|
||||
}
|
||||
state_->types.push_back(type_attr.getValue());
|
||||
} else if (!output_arg.type_list_attr().empty()) {
|
||||
// This is pointing to an attribute which is an array of types.
|
||||
Attribute attr = attrs_[output_arg.type_list_attr()];
|
||||
if (!attr) {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Missing attribute '", output_arg.type_list_attr(),
|
||||
"' required for output '", output_arg.name(), "'");
|
||||
return nullptr;
|
||||
}
|
||||
ArrayAttr array_attr = attr.dyn_cast<ArrayAttr>();
|
||||
if (!array_attr) {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Attribute '", output_arg.type_list_attr(),
|
||||
"' required for output '", output_arg.name(),
|
||||
"' isn't an array attribute");
|
||||
return nullptr;
|
||||
}
|
||||
for (Attribute attr : array_attr) {
|
||||
TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
|
||||
if (!type_attr) {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Array Attribute '", output_arg.type_list_attr(),
|
||||
"' required for output '", output_arg.name(),
|
||||
"' has a non-Type element");
|
||||
return nullptr;
|
||||
}
|
||||
state_->types.push_back(type_attr.getValue());
|
||||
}
|
||||
} else if (output_arg.type() != tensorflow::DT_INVALID) {
|
||||
Type type;
|
||||
Builder builder(context_);
|
||||
s->status = ConvertDataTypeToTensor(output_arg.type(), builder, &type);
|
||||
if (!s->status.ok()) return nullptr;
|
||||
state_->types.push_back(type);
|
||||
} else {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"No type fields in ", output_arg.ShortDebugString());
|
||||
if (!s->status.ok()) return nullptr;
|
||||
}
|
||||
if (output_arg.is_ref()) {
|
||||
// For all types that were added by this function call, make them refs.
|
||||
for (Type& type : llvm::make_range(&state_->types[original_size],
|
||||
state_->types.end())) {
|
||||
type = AddRef(type, s);
|
||||
if (!s->status.ok()) return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
return state_.get();
|
||||
}
|
||||
|
||||
TF_Function* MlirFunction::GetTfFunction(TF_Status* s) {
|
||||
PassManager pm(func_.getContext());
|
||||
pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
|
||||
pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
|
||||
|
||||
// In case of failure, the `diag_handler` converts MLIR errors emitted to
|
||||
// the MLIRContext into a tensorflow::Status.
|
||||
StatusScopedDiagnosticHandler diag_handler(func_.getContext());
|
||||
LogicalResult result = pm.run(func_.getParentOfType<ModuleOp>());
|
||||
(void)result;
|
||||
s->status = diag_handler.ConsumeStatus();
|
||||
if (!s->status.ok()) return nullptr;
|
||||
|
||||
tensorflow::GraphExportConfig configs;
|
||||
std::unique_ptr<TF_Function> tf_function(new TF_Function);
|
||||
s->status = ConvertMlirFunctionToFunctionLibraryDef(func_, configs,
|
||||
&tf_function->fdef);
|
||||
return tf_function.release();
|
||||
}
|
||||
|
||||
void MlirFunctionContext::ExecuteOperation(AbstractOp* abstract_op,
|
||||
int num_inputs,
|
||||
AbstractTensor* const* inputs,
|
||||
OutputList* o, TF_Status* s) {
|
||||
auto* mlir_op = dyncast<MlirAbstractOp>(abstract_op);
|
||||
if (mlir_op == nullptr) {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Unable to cast AbstractOp to TF_GraphOp.");
|
||||
return;
|
||||
}
|
||||
SmallVector<Value, 8> operands;
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
auto* operand = dyncast<MlirTensor>(inputs[i]);
|
||||
if (!operand) {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Capturing eager tensors is not supported yet.");
|
||||
return;
|
||||
}
|
||||
if (operand->getValue().getContext() != context_.get()) {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Capturing tensors from other context is not supported.");
|
||||
return;
|
||||
}
|
||||
operands.push_back(operand->getValue());
|
||||
}
|
||||
OperationState* state = mlir_op->Create(operands, s);
|
||||
if (!s->status.ok() || !state) return;
|
||||
Operation* op = builder_.createOperation(*state);
|
||||
int num_results = op->getNumResults();
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_results);
|
||||
for (Value result : op->getResults())
|
||||
o->outputs.push_back(new MlirTensor(result));
|
||||
}
|
||||
|
||||
AbstractTensor* MlirFunctionContext::AddParameter(TF_DataType dtype,
|
||||
TF_Status* s) {
|
||||
Type type;
|
||||
s->status = ConvertDataTypeToTensor(static_cast<tensorflow::DataType>(dtype),
|
||||
builder_, &type);
|
||||
if (!s->status.ok()) return nullptr;
|
||||
return new MlirTensor(func_.getBody().front().addArgument(type));
|
||||
}
|
||||
|
||||
AbstractFunction* MlirFunctionContext::Finalize(OutputList* outputs,
|
||||
TF_Status* s) {
|
||||
Block& body = func_.getBody().front();
|
||||
SmallVector<Value, 8> ret_operands;
|
||||
for (AbstractTensor* output : outputs->outputs) {
|
||||
auto* operand = dyncast<MlirTensor>(output);
|
||||
if (!operand) {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Capturing eager tensors is not supported yet.");
|
||||
return nullptr;
|
||||
}
|
||||
if (operand->getValue().getContext() != context_.get()) {
|
||||
s->status = tensorflow::errors::InvalidArgument(
|
||||
"Capturing tensors from other context is not supported.");
|
||||
return nullptr;
|
||||
}
|
||||
ret_operands.push_back(operand->getValue());
|
||||
}
|
||||
builder_.create<ReturnOp>(func_.getLoc(), ret_operands);
|
||||
|
||||
auto arg_types = llvm::to_vector<8>(body.getArgumentTypes());
|
||||
auto result_types =
|
||||
llvm::to_vector<8>(body.getTerminator()->getOperandTypes());
|
||||
func_.setType(FunctionType::get(arg_types, result_types, func_.getContext()));
|
||||
return new MlirFunction(std::move(context_), std::move(module_), func_);
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s) {
|
||||
RegisterDialects();
|
||||
return new MlirFunctionContext(fn_name);
|
||||
}
|
||||
}
|
||||
|
||||
} // end anonymous namespace
|
||||
} // end namespace TF
|
||||
} // end namespace mlir
|
@ -0,0 +1,31 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
|
||||
using tensorflow::internal::ExecutionContext;
|
||||
|
||||
extern "C" {
|
||||
ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Register the tracing implemented in this file as the default tracing engine.
|
||||
static bool register_tracing = [] {
|
||||
RegisterTracingEngineFactory("mlir", MlirTracingFactory);
|
||||
return true;
|
||||
}();
|
||||
|
||||
} // namespace
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/OpDefinition.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace TFControlFlow {
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
@ -31,5 +32,6 @@ static DialectRegistration<tf_device::TensorFlowDeviceDialect>
|
||||
tf_device_dialect;
|
||||
static DialectRegistration<tf_saved_model::TensorFlowSavedModelDialect>
|
||||
tf_saved_model_dialect;
|
||||
static DialectRegistration<mlir::shape::ShapeDialect> shape_dialect;
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
|
||||
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
|
||||
@ -45,6 +47,28 @@ struct ShapeAttrStorage : public AttributeStorage {
|
||||
bool unranked = false;
|
||||
};
|
||||
|
||||
// The storage class for FuncAttr.
|
||||
struct FuncAttrStorage : public AttributeStorage {
|
||||
using KeyTy = std::pair<Attribute, Attribute>;
|
||||
|
||||
explicit FuncAttrStorage(Attribute name, Attribute attrs)
|
||||
: name(name), attrs(attrs) {}
|
||||
|
||||
bool operator==(const KeyTy& key) const { return key == KeyTy(name, attrs); }
|
||||
static unsigned hashKey(const KeyTy& key) {
|
||||
return llvm::hash_combine(key.first, key.second);
|
||||
}
|
||||
|
||||
static FuncAttrStorage* construct(mlir::AttributeStorageAllocator& allocator,
|
||||
const KeyTy& key) {
|
||||
return new (allocator.allocate<FuncAttrStorage>())
|
||||
FuncAttrStorage(key.first, key.second);
|
||||
}
|
||||
|
||||
Attribute name;
|
||||
Attribute attrs;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Get or create a shape attribute.
|
||||
@ -85,5 +109,24 @@ bool ShapeAttr::hasStaticShape() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
FuncAttr FuncAttr::get(mlir::MLIRContext* context, llvm::StringRef name,
|
||||
DictionaryAttr attr) {
|
||||
auto symbol = SymbolRefAttr::get(name, context);
|
||||
return Base::get(context, AttrKind::FUNC, symbol, attr);
|
||||
}
|
||||
|
||||
FuncAttr FuncAttr::get(mlir::MLIRContext* context, SymbolRefAttr symbol,
|
||||
DictionaryAttr attr) {
|
||||
return Base::get(context, AttrKind::FUNC, symbol, attr);
|
||||
}
|
||||
|
||||
SymbolRefAttr FuncAttr::GetName() const {
|
||||
return getImpl()->name.cast<SymbolRefAttr>();
|
||||
}
|
||||
|
||||
DictionaryAttr FuncAttr::GetAttrs() const {
|
||||
return getImpl()->attrs.cast<DictionaryAttr>();
|
||||
}
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
@ -30,6 +31,7 @@ namespace AttrKind {
|
||||
enum Kind {
|
||||
FIRST_USED_TENSORFLOW_ATTR = Attribute::FIRST_TENSORFLOW_ATTR,
|
||||
SHAPE = FIRST_USED_TENSORFLOW_ATTR,
|
||||
FUNC,
|
||||
LAST_USED_TENSORFLOW_ATTR,
|
||||
};
|
||||
|
||||
@ -38,6 +40,7 @@ enum Kind {
|
||||
namespace detail {
|
||||
|
||||
struct ShapeAttrStorage;
|
||||
struct FuncAttrStorage;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
@ -71,6 +74,33 @@ class ShapeAttr : public Attribute::AttrBase<ShapeAttr, Attribute,
|
||||
static bool kindof(unsigned kind) { return kind == AttrKind::SHAPE; }
|
||||
};
|
||||
|
||||
// Custom attribute to model AttrValue.value.func (NameAttrList type attribute).
|
||||
// This attribute holds a SymbolRefAttr, for the NameAttrList.name string and a
|
||||
// DictionaryAttr for the NameAttrList.attr map<string, AttrValue>. It is
|
||||
// currently printed and parsed for the following format:
|
||||
//
|
||||
// #tf.func<@symbol, {attr = "value"}>
|
||||
//
|
||||
// where the first element is the SymbolRefAttr and the second element is the
|
||||
// DictionaryAttr.
|
||||
class FuncAttr
|
||||
: public Attribute::AttrBase<FuncAttr, Attribute, detail::FuncAttrStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static FuncAttr get(mlir::MLIRContext* context, llvm::StringRef name,
|
||||
DictionaryAttr attr);
|
||||
|
||||
static FuncAttr get(mlir::MLIRContext* context, SymbolRefAttr symbol,
|
||||
DictionaryAttr attr);
|
||||
|
||||
SymbolRefAttr GetName() const;
|
||||
|
||||
DictionaryAttr GetAttrs() const;
|
||||
|
||||
static bool kindof(unsigned kind) { return kind == AttrKind::FUNC; }
|
||||
};
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -47,37 +47,6 @@ limitations under the License.
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_executor {
|
||||
namespace {
|
||||
|
||||
// If the given tensor has elements of type with subtypes, then returns a new
|
||||
// type after dropping subtypes info. Otherwise, returns the original type as
|
||||
// is.
|
||||
ShapedType DropTypeSubTypes(ShapedType ty) {
|
||||
Type element_ty = ty.getElementType();
|
||||
auto subtype_ty = element_ty.dyn_cast<TF::TensorFlowTypeWithSubtype>();
|
||||
if (!subtype_ty) return ty;
|
||||
|
||||
Type default_ty = GetDefaultTypeOf(subtype_ty);
|
||||
if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
|
||||
|
||||
return UnrankedTensorType::get(default_ty);
|
||||
}
|
||||
|
||||
// If the given tensor has elements of type ref, then returns a new type
|
||||
// of the shape, but corresponding non-ref type as element type. Otherwise,
|
||||
// returns the original type as is.
|
||||
ShapedType DropRefType(ShapedType ty) {
|
||||
Type element_ty = ty.getElementType();
|
||||
auto ref_ty = element_ty.dyn_cast<TF::TensorFlowRefType>();
|
||||
if (!ref_ty) return ty;
|
||||
|
||||
Type default_ty = GetDefaultTypeOf(ref_ty);
|
||||
if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
|
||||
|
||||
return UnrankedTensorType::get(default_ty);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TF Executor Dialect
|
||||
@ -85,6 +54,9 @@ ShapedType DropRefType(ShapedType ty) {
|
||||
|
||||
namespace {
|
||||
|
||||
using TF::DropRefType;
|
||||
using TF::DropTypeSubTypes;
|
||||
|
||||
struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
|
||||
using DialectInlinerInterface::DialectInlinerInterface;
|
||||
|
||||
|
@ -1195,6 +1195,46 @@ subsequent operation and then be optimized away, however.)
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_CaseOp : TF_Op<"Case", []> {
|
||||
let summary = [{
|
||||
An n-way switch statement which calls a single branch function.
|
||||
}];
|
||||
|
||||
let description = [{
|
||||
An n-way switch statement, implementing the following:
|
||||
```
|
||||
switch (branch_index) {
|
||||
case 0:
|
||||
output = branches[0](input);
|
||||
break;
|
||||
case 1:
|
||||
output = branches[1](input);
|
||||
break;
|
||||
...
|
||||
case [[nbranches-1]]:
|
||||
default:
|
||||
output = branches[nbranches-1](input);
|
||||
break;
|
||||
}
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
I32Tensor:$branch_index,
|
||||
Variadic<TF_Tensor>:$input,
|
||||
|
||||
Confined<SymbolRefArrayAttr, [ArrayMinCount<1>]>:$branches,
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TF_Tensor>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
|
||||
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> {
|
||||
let summary = "Cast x of type SrcT to y of DstT.";
|
||||
|
||||
@ -6331,6 +6371,8 @@ If `x` and `y` are reals, this will return the floating-point division.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
@ -7434,9 +7476,15 @@ select(condition, t, e) ==> [[1, 2],
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect]> {
|
||||
def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect, ResultsBroadcastableShape]> {
|
||||
let summary = "";
|
||||
|
||||
let description = [{
|
||||
@ -8309,7 +8357,7 @@ def TF_StackV2Op : TF_Op<"StackV2", []> {
|
||||
);
|
||||
}
|
||||
|
||||
def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> {
|
||||
let summary = "Stops gradient computation.";
|
||||
|
||||
let description = [{
|
||||
@ -10586,6 +10634,27 @@ def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF__HostComputeMlirOp : TF_Op<"_HostComputeMlir", []> {
|
||||
let summary = "A host-side computation called from a TPU device.";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<TF_Tensor>:$inputs,
|
||||
|
||||
StrAttr:$key,
|
||||
DefaultValuedAttr<I64Attr, "0">:$tpu_core
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TF_Tensor>:$outputs
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
|
||||
TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
|
||||
}
|
||||
|
||||
def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", []> {
|
||||
let summary = "An op that receives embeddng activations on the TPU.";
|
||||
|
||||
|
@ -58,6 +58,7 @@ limitations under the License.
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -110,7 +111,6 @@ static inline bool HasRankAtMost(Value value, int64_t rank) {
|
||||
return !type || type.getRank() <= rank;
|
||||
}
|
||||
|
||||
|
||||
static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
|
||||
return dim_or_rank == -1;
|
||||
}
|
||||
@ -252,6 +252,39 @@ static LogicalResult VerifyTypesCompatibility(
|
||||
return success();
|
||||
}
|
||||
|
||||
// This is a helper for the Select to SelectV2 canonicalization. The `data` rank
|
||||
// refers to the rank of `t`/`e` (these two inputs have equal rank; this is
|
||||
// checked in the verifier).
|
||||
//
|
||||
// In most cases, the predicate for Select can be used directly as the predicate
|
||||
// for SelectV2. However, there is one case that varies, which is when the
|
||||
// predicate is a tensor and the data is multidimensional. In this case, Select
|
||||
// op semantics dictate that the predicate tensor length must match the size of
|
||||
// the first data dimension. This varies from normal broadcasting semantics
|
||||
// (which are used in SelectV2), so we must reshape the tensor in this case to
|
||||
// be compatible.
|
||||
static Value ReshapeSelectPredIfNecessary(OpBuilder *builder, Location loc,
|
||||
Value cond, int data_rank) {
|
||||
auto cond_tensor = cond.getType().cast<RankedTensorType>();
|
||||
// Reshape is only needed in the case that the cond rank is 1 (i.e. it is
|
||||
// a vector) AND t/e rank is > 1.
|
||||
if (cond_tensor.getRank() != 1 || data_rank <= 1) {
|
||||
// No reshape necessary. Leave cond as it is.
|
||||
return cond;
|
||||
}
|
||||
|
||||
// This is the case where a reshape is needed. We want to construct the
|
||||
// shape [x,1,...1], where x is the value in the pred tensor and the
|
||||
// length of the shape is equal to data_rank.
|
||||
SmallVector<int64_t, 8> shape(data_rank, 1);
|
||||
shape[0] = cond_tensor.getShape().front();
|
||||
auto new_shape_type =
|
||||
RankedTensorType::get({data_rank}, builder->getIntegerType(64));
|
||||
auto shape_attr = DenseIntElementsAttr::get(new_shape_type, shape);
|
||||
auto new_shape = builder->create<ConstOp>(loc, shape_attr);
|
||||
return builder->create<ReshapeOp>(loc, cond, new_shape);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Helper functions detect device capabilities from RuntimeDevices.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -462,9 +495,10 @@ LogicalResult FoldOperandsPermutation(
|
||||
namespace {
|
||||
// Folder that returns LHS of an Arithmetic Op if the RHS is a constant
|
||||
// known to be Identity (e.g X+0)
|
||||
template <typename OpT,
|
||||
typename std::enable_if<llvm::is_one_of<
|
||||
OpT, AddV2Op, SubOp, MulOp, DivOp>::value>::type * = nullptr>
|
||||
template <
|
||||
typename OpT,
|
||||
typename std::enable_if<llvm::is_one_of<
|
||||
OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
|
||||
OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
|
||||
ArrayRef<Attribute> operands) {
|
||||
auto result_op_type = arithmetic_op.getResult().getType();
|
||||
@ -479,7 +513,8 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
|
||||
// Mul and Div ops have identity value one while AddV2 and SubOp have identity
|
||||
// value zero.
|
||||
int identity =
|
||||
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value);
|
||||
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value ||
|
||||
std::is_same<OpT, RealDivOp>::value);
|
||||
|
||||
Type element_ty = lhs_type.getElementType();
|
||||
Attribute identity_attr;
|
||||
@ -496,6 +531,12 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
|
||||
return arithmetic_op.x();
|
||||
}
|
||||
|
||||
auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
|
||||
// TODO(chhe): we could fold and add an identity to force the broadcast.
|
||||
if (result_op_type != rhs_type) {
|
||||
return {};
|
||||
}
|
||||
|
||||
bool is_symmetric =
|
||||
(std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
|
||||
if (auto attr = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
@ -1256,8 +1297,8 @@ static LogicalResult Verify(DataFormatVecPermuteOp op) {
|
||||
|
||||
if (rank == 1) {
|
||||
int64_t dim0 = input_ty.getDimSize(0);
|
||||
if (dim0 != ShapedType::kDynamicSize && dim0 != 4)
|
||||
return op.emitOpError("requires 1D input of size 4");
|
||||
if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2)
|
||||
return op.emitOpError("requires 1D input of size 4 or size 2");
|
||||
}
|
||||
|
||||
if (rank == 2) {
|
||||
@ -1620,10 +1661,16 @@ void FillOp::build(OpBuilder &builder, OperationState &result, Value dims,
|
||||
OpFoldResult FillOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "fill op has two operand");
|
||||
|
||||
auto type = getType().cast<ShapedType>();
|
||||
// DenseElementsAttr that is used in this folder only supports int and float
|
||||
// types.
|
||||
// TODO(hinsu): Handle complex types once there is a attribute kind for
|
||||
// complex.
|
||||
if (!type.getElementType().isIntOrFloat()) return {};
|
||||
|
||||
auto value = operands[1].dyn_cast_or_null<ElementsAttr>();
|
||||
if (!value) return {};
|
||||
|
||||
auto type = getType().cast<ShapedType>();
|
||||
if (type.hasStaticShape())
|
||||
return DenseElementsAttr::get(type, value.getValue({}));
|
||||
|
||||
@ -1774,75 +1821,125 @@ static LogicalResult Verify(GatherV2Op op) {
|
||||
|
||||
static LogicalResult Verify(IfOp op) {
|
||||
auto module = op.getParentOfType<ModuleOp>();
|
||||
auto thenFn = module.lookupSymbol<FuncOp>(op.then_branch());
|
||||
if (!thenFn)
|
||||
auto then_fn = module.lookupSymbol<FuncOp>(op.then_branch());
|
||||
if (!then_fn)
|
||||
return op.emitOpError("then_branch refers to an undefined function : ")
|
||||
<< op.then_branch();
|
||||
auto elseFn = module.lookupSymbol<FuncOp>(op.else_branch());
|
||||
if (!elseFn)
|
||||
auto else_fn = module.lookupSymbol<FuncOp>(op.else_branch());
|
||||
if (!else_fn)
|
||||
return op.emitOpError("else_branch refers to an undefined function : ")
|
||||
<< op.else_branch();
|
||||
auto thenFuncType = thenFn.getType();
|
||||
auto elseFuncType = elseFn.getType();
|
||||
auto then_fn_type = then_fn.getType();
|
||||
auto else_fn_type = else_fn.getType();
|
||||
|
||||
// Non-conditional operands starting with the second operand are passed to
|
||||
// branches and should be pair-wise compatible with branches' inputs.
|
||||
unsigned expectedNumInputs = op.getNumOperands() - 1;
|
||||
if (thenFuncType.getNumInputs() != expectedNumInputs ||
|
||||
elseFuncType.getNumInputs() != expectedNumInputs)
|
||||
return op.emitError("branches should have " + Twine(expectedNumInputs) +
|
||||
unsigned expected_num_inputs = op.getNumOperands() - 1;
|
||||
if (then_fn_type.getNumInputs() != expected_num_inputs ||
|
||||
else_fn_type.getNumInputs() != expected_num_inputs)
|
||||
return op.emitError("branches should have " + Twine(expected_num_inputs) +
|
||||
" inputs");
|
||||
|
||||
for (unsigned i = 0; i < expectedNumInputs; ++i) {
|
||||
auto operandType = op.getOperand(i + 1).getType().cast<TensorType>();
|
||||
auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
|
||||
if (!AreCastCompatible({operandType, thenInputType}))
|
||||
for (unsigned i = 0; i < expected_num_inputs; ++i) {
|
||||
auto operand_type = op.getOperand(i + 1).getType().cast<TensorType>();
|
||||
auto then_input_type = then_fn_type.getInput(i).cast<TensorType>();
|
||||
if (!AreCastCompatible({operand_type, then_input_type}))
|
||||
return op.emitError(
|
||||
llvm::formatv("then branch input type {0} is incompatible with "
|
||||
"operand type {1} at index {2}",
|
||||
thenInputType, operandType, i));
|
||||
then_input_type, operand_type, i));
|
||||
|
||||
auto elseInputType = elseFuncType.getInput(i).cast<TensorType>();
|
||||
if (!AreCastCompatible({operandType, elseInputType}))
|
||||
auto else_input_type = else_fn_type.getInput(i).cast<TensorType>();
|
||||
if (!AreCastCompatible({operand_type, else_input_type}))
|
||||
return op.emitError(
|
||||
llvm::formatv("else branch input type {0} is incompatible with "
|
||||
"operand type {1} at index {2}",
|
||||
elseInputType, operandType, i));
|
||||
else_input_type, operand_type, i));
|
||||
|
||||
// If branches have incompatible input types that means that no tensor can
|
||||
// serve as input to both the functions. Hence, the op is invalid.
|
||||
if (!AreCastCompatible({thenInputType, elseInputType}))
|
||||
if (!AreCastCompatible({then_input_type, else_input_type}))
|
||||
return op.emitError(llvm::formatv(
|
||||
"branches inputs have incompatible types {0} and {1} at index {2}",
|
||||
thenInputType, elseInputType, i));
|
||||
then_input_type, else_input_type, i));
|
||||
}
|
||||
|
||||
// Branches' results should be pair-wise compatible with the op results.
|
||||
unsigned expectedNumResults = op.getNumResults();
|
||||
if (thenFuncType.getNumResults() != expectedNumResults ||
|
||||
elseFuncType.getNumResults() != expectedNumResults)
|
||||
return op.emitError("branches should have " + Twine(expectedNumResults) +
|
||||
unsigned expected_num_results = op.getNumResults();
|
||||
if (then_fn_type.getNumResults() != expected_num_results ||
|
||||
else_fn_type.getNumResults() != expected_num_results)
|
||||
return op.emitError("branches should have " + Twine(expected_num_results) +
|
||||
" results");
|
||||
|
||||
for (unsigned i = 0; i < expectedNumResults; ++i) {
|
||||
auto resultType = op.getResult(i).getType().cast<TensorType>();
|
||||
auto thenResultType = thenFuncType.getResult(i).cast<TensorType>();
|
||||
if (!AreCastCompatible({thenResultType, resultType}))
|
||||
for (unsigned i = 0; i < expected_num_results; ++i) {
|
||||
auto result_type = op.getResult(i).getType().cast<TensorType>();
|
||||
auto then_result_type = then_fn_type.getResult(i).cast<TensorType>();
|
||||
if (!AreCastCompatible({then_result_type, result_type}))
|
||||
return op.emitError(
|
||||
llvm::formatv("then branch result type {0} is incompatible with op "
|
||||
"result type {1} at index {2}",
|
||||
thenResultType, resultType, i));
|
||||
then_result_type, result_type, i));
|
||||
|
||||
auto elseResultType = elseFuncType.getResult(i).cast<TensorType>();
|
||||
if (!AreCastCompatible({elseResultType, resultType}))
|
||||
auto else_result_type = else_fn_type.getResult(i).cast<TensorType>();
|
||||
if (!AreCastCompatible({else_result_type, result_type}))
|
||||
return op.emitError(
|
||||
llvm::formatv("else branch result type {0} is incompatible with op "
|
||||
"result type {1} at index {2}",
|
||||
elseResultType, resultType, i));
|
||||
else_result_type, result_type, i));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// YieldOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(YieldOp op) {
|
||||
auto parent = op.getParentOp();
|
||||
// A YieldOp should be contained within an IfRegion op
|
||||
// (and WhileRegion in future)
|
||||
if (!isa<IfRegionOp>(parent))
|
||||
op.emitError() << " expects parent op "
|
||||
<< "'" << IfRegionOp::getOperationName() << "' but got '"
|
||||
<< parent->getName().getStringRef() << "'";
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IfRegionOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult VerifyRegionResults(Operation *op, Region ®ion,
|
||||
StringRef region_name) {
|
||||
auto op_name = op->getName().getStringRef();
|
||||
// verify that op outputs match yield inputs
|
||||
YieldOp yield = cast<YieldOp>(region.front().getTerminator());
|
||||
unsigned expected_num_results = op->getNumResults();
|
||||
if (yield.getNumOperands() != expected_num_results)
|
||||
return op->emitError(region_name + " region should have " +
|
||||
Twine(expected_num_results) + " results");
|
||||
|
||||
for (int idx : llvm::seq<int>(0, expected_num_results)) {
|
||||
auto op_result_type = op->getResult(idx).getType().cast<TensorType>();
|
||||
auto region_result_type =
|
||||
yield.getOperand(idx).getType().cast<TensorType>();
|
||||
if (!AreCastCompatible({region_result_type, op_result_type}))
|
||||
return op->emitError(llvm::formatv(
|
||||
"{0} result type {1} is incompatible with {2} "
|
||||
"result type {3} at index {4}",
|
||||
region_name, region_result_type, op_name, op_result_type, idx));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult Verify(IfRegionOp op) {
|
||||
if (failed(VerifyRegionResults(op, op.then_branch(), "then")))
|
||||
return failure();
|
||||
if (failed(VerifyRegionResults(op, op.else_branch(), "else")))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InvertOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2408,6 +2505,10 @@ void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
results.insert<RealDivWithSqrtDivisor>(context);
|
||||
}
|
||||
|
||||
OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
|
||||
return IdentityArithmeticOpFolder<RealDivOp>(*this, operands);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReshapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2539,6 +2640,81 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
|
||||
return unranked();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SelectOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<SelectToSelectV2>(context);
|
||||
}
|
||||
|
||||
// Verifies a few extra requirements on SelectOp:
|
||||
// (1) `then` and `else` must have same shape
|
||||
// (2) At least one of the following must be true:
|
||||
// (a) `cond` has the same rank as `then` and `else`
|
||||
// (b) `cond` is a scalar
|
||||
// (c) `cond` is a vector AND `then` and `else` are non-scalar with their
|
||||
// first dimension equal to `cond`.
|
||||
static LogicalResult Verify(SelectOp op) {
|
||||
auto then_tensor = op.t().getType().cast<TensorType>();
|
||||
auto else_tensor = op.e().getType().cast<TensorType>();
|
||||
// Check (1).
|
||||
if (!AreCastCompatible({then_tensor, else_tensor}))
|
||||
return op.emitOpError() << "requires t and e have compatible shapes";
|
||||
|
||||
// Get data rank (if exists).
|
||||
int data_rank;
|
||||
// If data is unranked or data_rank is 0, this will remain -2. Otherwise
|
||||
// refers to first dimension of then and/or else.
|
||||
int data_first_dim = -2;
|
||||
bool then_has_rank = then_tensor.hasRank();
|
||||
bool else_has_rank = else_tensor.hasRank();
|
||||
if (then_has_rank && else_has_rank) {
|
||||
data_rank = then_tensor.getRank();
|
||||
if (then_tensor.getRank() > 0)
|
||||
data_first_dim = then_tensor.getShape().front();
|
||||
if (else_tensor.getRank() > 0)
|
||||
data_first_dim = std::max(
|
||||
static_cast<int>(else_tensor.getShape().front()), data_first_dim);
|
||||
} else if (then_has_rank) {
|
||||
data_rank = then_tensor.getRank();
|
||||
if (then_tensor.getRank() > 0)
|
||||
data_first_dim = then_tensor.getShape().front();
|
||||
} else if (else_has_rank) {
|
||||
data_rank = else_tensor.getRank();
|
||||
if (else_tensor.getRank() > 0)
|
||||
data_first_dim = else_tensor.getShape().front();
|
||||
} else {
|
||||
// Neither has a rank.
|
||||
return success();
|
||||
}
|
||||
|
||||
auto cond_tensor = op.condition().getType().dyn_cast<RankedTensorType>();
|
||||
if (!cond_tensor) return success();
|
||||
auto cond_rank = cond_tensor.getRank();
|
||||
// Check (2a) and (2b).
|
||||
if (cond_rank == 0 || cond_rank == data_rank) return success();
|
||||
// Check (2c).
|
||||
if (cond_rank == 1) {
|
||||
auto cond_shape = cond_tensor.getShape().front();
|
||||
if (data_rank == 0) {
|
||||
return op.emitOpError()
|
||||
<< "requires that t and e are nonscalar when pred is a vector";
|
||||
}
|
||||
// We know `data` tensor has a rank of at least 1.
|
||||
if (data_first_dim != -1 && cond_shape != -1 &&
|
||||
data_first_dim != cond_shape) {
|
||||
return op.emitOpError() << "requires that, when pred is a vector, the "
|
||||
"shape matches the first dimension of t and e";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
// None of (2a,b,c) were true; fail.
|
||||
return op.emitOpError() << "requires that pred is a scalar OR has the same "
|
||||
"rank as t and e OR is a vector";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SelectV2Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2598,9 +2774,12 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
|
||||
<< variadic_idx_str << " to match rank of operand"
|
||||
<< variadic_idx_str;
|
||||
} else if (result_ranked_type.hasStaticShape()) {
|
||||
// The operand is an unranked tensor, verify that the result is dynamic.
|
||||
return op->emitOpError("requires dynamic shape result")
|
||||
<< variadic_idx_str << " for unranked operand" << variadic_idx_str;
|
||||
// The operand is an unranked tensor, print a warning if the result
|
||||
// is static.
|
||||
// Note: We do not handle this situation as an error, this would be too
|
||||
// restrictive due to incompleteness of shape inference at this point.
|
||||
op->emitWarning("has static shape result")
|
||||
<< variadic_idx_str << " for unranked operand" << variadic_idx_str;
|
||||
}
|
||||
|
||||
Type element_type = result_ranked_type.getElementType();
|
||||
@ -3551,12 +3730,20 @@ OpFoldResult FoldIdentityTranspose(TransposeOp op) {
|
||||
if (!const_perm) return {};
|
||||
|
||||
auto const_value = const_perm.value();
|
||||
const auto &elements = const_value.getValues<APInt>();
|
||||
const auto elements = const_value.getValues<APInt>();
|
||||
|
||||
for (auto it : llvm::enumerate(elements)) {
|
||||
if (it.index() != it.value()) return {};
|
||||
}
|
||||
|
||||
// TODO(jpienaar): Remove if/when we handle this more generally.
|
||||
if (op.getType() != op.x().getType()) {
|
||||
// If the types don't match then only fold if all the operands are in the TF
|
||||
// dialect.
|
||||
for (auto user : op.getOperation()->getUsers())
|
||||
if (user->getDialect() != op.getDialect()) return {};
|
||||
}
|
||||
|
||||
return op.x();
|
||||
}
|
||||
|
||||
@ -3700,36 +3887,37 @@ OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
||||
static LogicalResult Verify(WhileOp op) {
|
||||
auto module = op.getParentOfType<ModuleOp>();
|
||||
auto condFn = module.lookupSymbol<FuncOp>(op.cond());
|
||||
auto bodyFn = module.lookupSymbol<FuncOp>(op.body());
|
||||
if (!condFn) {
|
||||
auto cond_fn = module.lookupSymbol<FuncOp>(op.cond());
|
||||
auto body_fn = module.lookupSymbol<FuncOp>(op.body());
|
||||
if (!cond_fn) {
|
||||
return op.emitOpError("cond refers to an undefined function : ")
|
||||
<< op.cond();
|
||||
}
|
||||
if (!bodyFn) {
|
||||
if (!body_fn) {
|
||||
return op.emitOpError("body refers to an undefined function : ")
|
||||
<< op.body();
|
||||
}
|
||||
|
||||
auto condFuncType = condFn.getType();
|
||||
auto bodyFuncType = bodyFn.getType();
|
||||
auto cond_fn_type = cond_fn.getType();
|
||||
auto body_fn_type = body_fn.getType();
|
||||
|
||||
// Verify that the cond function has exactly one result.
|
||||
if (condFuncType.getNumResults() != 1)
|
||||
if (cond_fn_type.getNumResults() != 1)
|
||||
return op.emitOpError("requires cond function to have exactly one result");
|
||||
|
||||
SmallVector<Type, 4> operands(op.getOperandTypes());
|
||||
|
||||
// Collect all the type lists for the op so that different pairs of type lists
|
||||
// can be compared for the compatibility.
|
||||
int numTypeLists = 5;
|
||||
std::pair<std::string, ArrayRef<Type>> typeLists[] = {
|
||||
{"operand", operands},
|
||||
{"body function result", bodyFuncType.getResults()},
|
||||
{"result", op.getResultTypes()},
|
||||
{"cond function input", condFuncType.getInputs()},
|
||||
{"body function input", bodyFuncType.getInputs()},
|
||||
};
|
||||
constexpr int kNumTypeLists = 5;
|
||||
const std::array<std::pair<std::string, ArrayRef<Type>>, kNumTypeLists>
|
||||
type_lists = {{
|
||||
{"operand", operands},
|
||||
{"body function result", body_fn_type.getResults()},
|
||||
{"result", op.getResultTypes()},
|
||||
{"cond function input", cond_fn_type.getInputs()},
|
||||
{"body function input", body_fn_type.getInputs()},
|
||||
}};
|
||||
|
||||
// A pair of type lists should be cast compatible with each other if one is
|
||||
// converted to the another for a function call or assignment or there is a
|
||||
@ -3753,28 +3941,28 @@ static LogicalResult Verify(WhileOp op) {
|
||||
// never converted from one to the another nor there is a common source
|
||||
// tensors. Compatibility requirement is not transitive.
|
||||
|
||||
for (int i = 0; i < numTypeLists; ++i) {
|
||||
for (int i = 0; i < kNumTypeLists; ++i) {
|
||||
// Skip the first pair as the While op operands and body function results
|
||||
// does not need to be compatible with each other.
|
||||
for (int j = std::max(2, i + 1); j < numTypeLists; ++j) {
|
||||
auto &a = typeLists[i];
|
||||
auto &b = typeLists[j];
|
||||
for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) {
|
||||
auto &a = type_lists[i];
|
||||
auto &b = type_lists[j];
|
||||
|
||||
int aSize = a.second.size();
|
||||
if (aSize != b.second.size())
|
||||
int a_size = a.second.size();
|
||||
if (a_size != b.second.size())
|
||||
return op.emitOpError(
|
||||
llvm::formatv("requires the number of {0}s to be equal to the "
|
||||
"number of {1}s. Found {2} and {3}, respectively",
|
||||
a.first, b.first, aSize, b.second.size()));
|
||||
a.first, b.first, a_size, b.second.size()));
|
||||
|
||||
for (int idx = 0; idx < aSize; ++idx) {
|
||||
auto aType = a.second[idx];
|
||||
auto bType = b.second[idx];
|
||||
for (int idx = 0; idx < a_size; ++idx) {
|
||||
auto a_type = a.second[idx];
|
||||
auto b_type = b.second[idx];
|
||||
|
||||
if (!AreCastCompatible({aType, bType}))
|
||||
if (!AreCastCompatible({a_type, b_type}))
|
||||
return op.emitError(llvm::formatv(
|
||||
"{0} type {1} is incompatible with {2} type {3} at index {4}",
|
||||
a.first, aType, b.first, bType, idx));
|
||||
a.first, a_type, b.first, b_type, idx));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3856,7 +4044,7 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
|
||||
>();
|
||||
addInterfaces<TFInlinerInterface>();
|
||||
addAttributes<ShapeAttr>();
|
||||
addAttributes<ShapeAttr, FuncAttr>();
|
||||
|
||||
// Support unknown operations because not all TensorFlow operations are
|
||||
// registered.
|
||||
@ -3911,6 +4099,49 @@ void PrintShapeAttr(ShapeAttr attr, DialectAsmPrinter &os) { // NOLINT
|
||||
os << ">";
|
||||
}
|
||||
|
||||
// Parses a #tf.func attribute of the following format:
|
||||
//
|
||||
// #tf.func<@symbol, {attr = "value"}>
|
||||
//
|
||||
// where the first element is a SymbolRefAttr and the second element is a
|
||||
// DictionaryAttr.
|
||||
FuncAttr ParseFuncAttr(MLIRContext *context, StringRef spec, Location loc) {
|
||||
auto emit_error = [&, spec]() {
|
||||
emitError(loc, "invalid TensorFlow func attribute: ") << spec;
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
if (!spec.consume_front("func<")) return emit_error();
|
||||
|
||||
size_t func_name_num_read = 0;
|
||||
Attribute func_name_attr =
|
||||
mlir::parseAttribute(spec, context, func_name_num_read);
|
||||
if (!func_name_attr || !func_name_attr.isa<SymbolRefAttr>())
|
||||
return emit_error();
|
||||
spec = spec.drop_front(func_name_num_read);
|
||||
|
||||
if (!spec.consume_front(", ")) return emit_error();
|
||||
|
||||
size_t func_attrs_num_read = 0;
|
||||
Attribute func_attrs_attr =
|
||||
mlir::parseAttribute(spec, context, func_attrs_num_read);
|
||||
if (!func_attrs_attr || !func_attrs_attr.isa<DictionaryAttr>())
|
||||
return emit_error();
|
||||
spec = spec.drop_front(func_attrs_num_read);
|
||||
|
||||
if (!spec.consume_front(">")) return emit_error();
|
||||
|
||||
return mlir::TF::FuncAttr::get(context, func_name_attr.cast<SymbolRefAttr>(),
|
||||
func_attrs_attr.cast<DictionaryAttr>());
|
||||
}
|
||||
|
||||
// Prints a #tf.func attribute of the following format:
|
||||
//
|
||||
// #tf.func<@symbol, {attr = "value"}>
|
||||
void PrintFuncAttr(FuncAttr attr, DialectAsmPrinter &os) {
|
||||
os << "func<" << attr.GetName() << ", " << attr.GetAttrs() << ">";
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser,
|
||||
@ -3920,6 +4151,8 @@ Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser,
|
||||
|
||||
if (spec.startswith("shape")) return ParseShapeAttr(getContext(), spec, loc);
|
||||
|
||||
if (spec.startswith("func")) return ParseFuncAttr(getContext(), spec, loc);
|
||||
|
||||
return (emitError(loc, "unknown TensorFlow attribute: " + spec), nullptr);
|
||||
}
|
||||
|
||||
@ -3929,6 +4162,9 @@ void TensorFlowDialect::printAttribute(Attribute attr,
|
||||
case AttrKind::SHAPE:
|
||||
PrintShapeAttr(attr.cast<ShapeAttr>(), os);
|
||||
break;
|
||||
case AttrKind::FUNC:
|
||||
PrintFuncAttr(attr.cast<FuncAttr>(), os);
|
||||
break;
|
||||
default:
|
||||
llvm_unreachable("unexpected tensorflow attribute kind");
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ limitations under the License.
|
||||
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"
|
||||
|
@ -207,6 +207,70 @@ else_branch: A function that takes 'inputs' and returns a list of
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_YieldOp : TF_Op<"Yield", [Terminator]> {
|
||||
let summary = "Yield operation";
|
||||
let description = [{
|
||||
The "yield" operation represents a return operation within the conditional
|
||||
and body of structured control flow (e.g., if and while). The operation
|
||||
takes a variable number of operands and produces no results. The number and
|
||||
types of inputs must match the signature of the operation that contains the
|
||||
region.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyType>:$operands);
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_IfRegionOp : TF_Op<"IfRegion",
|
||||
[SingleBlockImplicitTerminator<"YieldOp">]> {
|
||||
let summary = "output = cond ? then_branch output : else_branch output";
|
||||
|
||||
let description = [{
|
||||
"output = cond ? then_branch output : else_branch output"
|
||||
|
||||
cond: A Tensor. If the tensor is a scalar of non-boolean type, the
|
||||
scalar is converted to a boolean according to the
|
||||
following rule: if the scalar is a numerical value, non-zero means
|
||||
True and zero means False; if the scalar is a string, non-empty
|
||||
means True and empty means False. If the tensor is not a scalar,
|
||||
being empty means False and being non-empty means True.
|
||||
input: A list of input tensors.
|
||||
then_branch: A region that computes the outputs of the op if cond = true.
|
||||
It returns a list of tensors using tf.yield (as the terminator). The
|
||||
types of these returned tensors is same as that of the else_branch
|
||||
else_branch: A region that computes the outputs of the op if cond = false.
|
||||
It returns a list of tensors using tf.yield (as the terminator). The
|
||||
types of these returned tensors is same as that of the then_branch
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$cond,
|
||||
Variadic<TF_Tensor>:$input,
|
||||
|
||||
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
|
||||
|
||||
// Used to map StatelessIf and If op defined in TensorFlow to a common op.
|
||||
BoolAttr:$is_stateless
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
Variadic<TF_Tensor>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr Tcond = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
|
||||
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
|
||||
|
||||
let regions = (region SizedRegion<1>:$then_branch, SizedRegion<1>:$else_branch);
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
|
||||
let summary = "Computes the mean of elements across dimensions of a tensor.";
|
||||
|
||||
|
@ -366,5 +366,27 @@ bool AreCastCompatible(ArrayRef<Type> types) {
|
||||
return true;
|
||||
}
|
||||
|
||||
ShapedType DropTypeSubTypes(ShapedType ty) {
|
||||
Type element_ty = ty.getElementType();
|
||||
auto subtype_ty = element_ty.dyn_cast<TF::TensorFlowTypeWithSubtype>();
|
||||
if (!subtype_ty) return ty;
|
||||
|
||||
Type default_ty = GetDefaultTypeOf(subtype_ty);
|
||||
if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
|
||||
|
||||
return UnrankedTensorType::get(default_ty);
|
||||
}
|
||||
|
||||
ShapedType DropRefType(ShapedType ty) {
|
||||
Type element_ty = ty.getElementType();
|
||||
TF::TensorFlowRefType ref_ty = element_ty.dyn_cast<TF::TensorFlowRefType>();
|
||||
if (!ref_ty) return ty;
|
||||
|
||||
Type default_ty = TF::GetDefaultTypeOf(ref_ty);
|
||||
if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty);
|
||||
|
||||
return UnrankedTensorType::get(default_ty);
|
||||
}
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
@ -319,6 +319,16 @@ bool HasCompatibleElementTypes(Type lhs, Type rhs,
|
||||
// compatible.
|
||||
bool AreCastCompatible(ArrayRef<Type> types);
|
||||
|
||||
// If the given tensor has elements of type with subtypes, then returns a new
|
||||
// type after dropping subtypes info. Otherwise, returns the original type as
|
||||
// is.
|
||||
ShapedType DropTypeSubTypes(ShapedType ty);
|
||||
|
||||
// If the given tensor has elements of type ref, then returns a new type
|
||||
// of the shape, but corresponding non-ref type as element type. Otherwise,
|
||||
// returns the original type as is.
|
||||
ShapedType DropRefType(ShapedType ty);
|
||||
|
||||
} // end namespace TF
|
||||
} // end namespace mlir
|
||||
|
||||
|
@ -258,6 +258,59 @@ func @testDoubleReciprocal(%arg0: tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi
|
||||
// CHECK: return %arg0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSelectScalarPred
|
||||
func @testSelectScalarPred(%arg0: tensor<i1>, %arg1: tensor<4x2xf16>, %arg2: tensor<4x2xf16>) -> tensor<4x2xf16> {
|
||||
// CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16>
|
||||
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16>
|
||||
return %0: tensor<4x2xf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSelectVectorPred
|
||||
func @testSelectVectorPred(%arg0: tensor<2xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
|
||||
// CHECK-NEXT: %[[SHAPE:.*]] = "tf.Const"
|
||||
// CHECK-NEXT: %[[PRED:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2xi1>, tensor<2xi64>) -> tensor<2x1xi1>
|
||||
// CHECK-NEXT: "tf.SelectV2"(%[[PRED]], %arg1, %arg2) : (tensor<2x1xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
|
||||
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
|
||||
return %0: tensor<2x3xf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSelectAllSameShape
|
||||
func @testSelectAllSameShape(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
|
||||
// CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
|
||||
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
|
||||
return %0: tensor<2x3xf16>
|
||||
}
|
||||
|
||||
// If we don't have guarantees on input shapes, we can't support canonicalizing
|
||||
// to SelectV2. Test these cases.
|
||||
// CHECK-LABEL: testSelectInvalid
|
||||
func @testSelectInvalid(%arg0: tensor<?xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
|
||||
// CHECK-NEXT: tf.Select
|
||||
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<?xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
|
||||
return %0: tensor<2x3xf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSelectInvalidUnranked
|
||||
func @testSelectInvalidUnranked(%arg0: tensor<6x7xi1>, %arg1: tensor<*xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> {
|
||||
// CHECK-NEXT: tf.Select
|
||||
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<6x7xi1>, tensor<*xf16>, tensor<*xf16>) -> tensor<*xf16>
|
||||
return %0: tensor<*xf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSelectThenUnranked
|
||||
func @testSelectThenUnranked(%arg0: tensor<3xi1>, %arg1: tensor<*xf16>, %arg2: tensor<3x2xf16>) -> tensor<*xf16> {
|
||||
// CHECK-NEXT: tf.Select
|
||||
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<*xf16>, tensor<3x2xf16>) -> tensor<*xf16>
|
||||
return %0: tensor<*xf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSelectElseUnranked
|
||||
func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> {
|
||||
// CHECK-NEXT: tf.Select
|
||||
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<*xf16>) -> tensor<*xf16>
|
||||
return %0: tensor<*xf16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testLogicalNotOfEqual
|
||||
func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
|
||||
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
|
||||
@ -473,12 +526,20 @@ func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor<i32> {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @foldFill
|
||||
func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>) {
|
||||
func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>) {
|
||||
%0 = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
%1 = "tf.Const"() {value = dense<23.0> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>}
|
||||
%2 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor<f32>) -> tensor<3x2x1xf32>
|
||||
// CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>}
|
||||
%3 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor<f32>) -> tensor<*xf32>
|
||||
return %2, %3 : tensor<3x2x1xf32>, tensor<*xf32>
|
||||
|
||||
%complex_cst = "tf.Const"() {value = dense<(0.000000e+00,1.000000e+00)> : tensor<complex<f32>>} : () -> tensor<complex<f32>>
|
||||
// Here, custom folder doesn't handle complex dtypes and it is folded through
|
||||
// the constant folding hook.
|
||||
// TODO(hinsu): Handle complex dtypes in the custom folder for FillOp.
|
||||
// CHECK: "tf.Const"() {value = dense<(0.000000e+00,1.000000e+00)> : tensor<3x2x1xcomplex<f32>>} : () -> tensor<*xcomplex<f32>>
|
||||
%4 = "tf.Fill"(%0, %complex_cst) : (tensor<3xi32>, tensor<complex<f32>>) -> tensor<*xcomplex<f32>>
|
||||
|
||||
return %2, %3, %4 : tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>
|
||||
}
|
||||
|
@ -302,15 +302,13 @@ func @testTensorListElementShape(%arg0: tensor<!tf.variant<tensor<2x4xf32>>>) ->
|
||||
return %0: tensor<2xi32>
|
||||
}
|
||||
|
||||
func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%cst = constant dense<0.0> : tensor<2x2xf32>
|
||||
%0 = "tf.Add"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
%1 = "tf.Add"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %1 : tensor<2x2xf32>
|
||||
%0 = "tf.Add"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
|
||||
// CHECK-LABEL: RemoveTrivialAdd
|
||||
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
func @RemoveTrivialAddBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
|
||||
@ -331,26 +329,22 @@ func @RemoveTrivialAddBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x2xbf16>
|
||||
}
|
||||
|
||||
func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%cst = constant dense<0.0> : tensor<2x2xf32>
|
||||
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
%1 = "tf.AddV2"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %1 : tensor<2x2xf32>
|
||||
%0 = "tf.AddV2"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
|
||||
// CHECK-LABEL: RemoveTrivialAddV2
|
||||
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
func @RemoveTrivialSub(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
func @RemoveTrivialSub(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%cst = constant dense<0.0> : tensor<2x2xf32>
|
||||
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
%1 = "tf.Sub"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %1 : tensor<2x2xf32>
|
||||
%0 = "tf.Sub"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
|
||||
// CHECK-LABEL: RemoveTrivialSub
|
||||
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
func @RemoveTrivialSubInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> {
|
||||
@ -362,26 +356,31 @@ func @RemoveTrivialSubInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> {
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x2xi8>
|
||||
}
|
||||
|
||||
func @RemoveTrivialMul(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
func @RemoveTrivialMul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%cst = constant dense<1.0> : tensor<2x2xf32>
|
||||
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
%1 = "tf.Mul"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %1 : tensor<2x2xf32>
|
||||
%0 = "tf.Mul"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
|
||||
// CHECK-LABEL: RemoveTrivialMul
|
||||
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%cst = constant dense<1.0> : tensor<2x2xf32>
|
||||
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
%1 = "tf.Div"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %1 : tensor<2x2xf32>
|
||||
%0 = "tf.Div"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
|
||||
// CHECK-LABEL: RemoveTrivialDiv
|
||||
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
func @RemoveTrivialRealDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%cst = constant dense<1.0> : tensor<2x2xf32>
|
||||
%0 = "tf.RealDiv"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
|
||||
// CHECK-LABEL: RemoveTrivialRealDiv
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
func @RemoveTrivialDivBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
|
||||
@ -411,28 +410,35 @@ func @DivBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
|
||||
// CHECK: tf.Div
|
||||
}
|
||||
|
||||
func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
|
||||
func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> {
|
||||
%cst = constant dense<0.0> : tensor<2x2xf32>
|
||||
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
|
||||
%1 = "tf.AddV2"(%0, %cst) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %1 : tensor<2x2xf32>
|
||||
%0 = "tf.AddV2"(%arg0, %cst) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
|
||||
// CHECK-LABEL: DontRemoveTrivialAdd
|
||||
// CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32>
|
||||
// CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
|
||||
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %[[CONST]]) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: return %[[RESULT]] : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
func @DontRemoveTrivialAdd2(%arg0: tensor<?x?xf32>, %arg1: tensor<2x2xf32>) -> tensor<?x?xf32> {
|
||||
func @DontRemoveTrivialAdd2(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%cst = constant dense<0.0> : tensor<2x2xf32>
|
||||
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
|
||||
%1 = "tf.AddV2"(%0, %cst) : (tensor<?x?xf32> , tensor<2x2xf32>) -> tensor<?x?xf32>
|
||||
return %1 :tensor<?x?xf32>
|
||||
%0 = "tf.AddV2"(%arg0, %cst) : (tensor<?x?xf32> , tensor<2x2xf32>) -> tensor<?x?xf32>
|
||||
return %0 :tensor<?x?xf32>
|
||||
|
||||
// CHECK-LABEL: DontRemoveTrivialAdd2
|
||||
// CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32>
|
||||
// CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %[[CONST]]) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: return %[[RESULT]] : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// Test no fold because of the broadcast.
|
||||
func @DontRemoveTrivialMul(%arg0: tensor<1x6x8x1xf32>) -> tensor<1x6x8x1xf32> {
|
||||
%0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
%1 = "tf.Mul"(%arg0, %0) : (tensor<1x6x8x1xf32>, tensor<f32>) -> tensor<1x6x8x1xf32>
|
||||
return %1 : tensor<1x6x8x1xf32>
|
||||
// CHECK-LABEL: DontRemoveTrivialMul
|
||||
// CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[RESULT:.*]] = "tf.Mul"(%arg0, %[[CONST]]) : (tensor<1x6x8x1xf32>, tensor<f32>) -> tensor<1x6x8x1xf32>
|
||||
// CHECK: return %[[RESULT]] : tensor<1x6x8x1xf32>
|
||||
}
|
||||
|
@ -0,0 +1,50 @@
|
||||
// RUN: tf-opt %s -split-input-file -verify-diagnostics
|
||||
|
||||
// Tests invalid #tf.func attributes.
|
||||
|
||||
// expected-error@+1 {{invalid TensorFlow func attribute: func}}
|
||||
func @main() attributes {tf._implements = #tf.func} {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{invalid TensorFlow func attribute: func<>}}
|
||||
func @main() attributes {tf._implements = #tf.func<>} {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{invalid TensorFlow func attribute: func<@symbol>}}
|
||||
func @main() attributes {tf._implements = #tf.func<@symbol>} {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{invalid TensorFlow func attribute: func<{}>}}
|
||||
func @main() attributes {tf._implements = #tf.func<{}>} {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{invalid TensorFlow func attribute: func<"test", {}>}}
|
||||
func @main() attributes {tf._implements = #tf.func<"test", {}>} {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{invalid TensorFlow func attribute: func<@symbol, "">}}
|
||||
func @main() attributes {tf._implements = #tf.func<@symbol, "">} {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{invalid TensorFlow func attribute: func<@symbol, {}, "">}}
|
||||
func @main() attributes {tf._implements = #tf.func<@symbol, {}, "">} {
|
||||
return
|
||||
}
|
13
tensorflow/compiler/mlir/tensorflow/tests/func-attr.mlir
Normal file
13
tensorflow/compiler/mlir/tensorflow/tests/func-attr.mlir
Normal file
@ -0,0 +1,13 @@
|
||||
// RUN: tf-opt %s | tf-opt | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK-LABEL: func @func_attr
|
||||
// CHECK-SAME: tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random"}>
|
||||
func @func_attr() attributes {tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random"}>} {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @nested_func_attr
|
||||
// CHECK-SAME: tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random", nested = #tf.func<@symbol_b, {attr2 = true, attr3 = 8.000000e+00 : f32}>}>
|
||||
func @nested_func_attr() attributes {tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random", nested = #tf.func<@symbol_b, {attr2 = true, attr3 = 8.0 : f32}>}>} {
|
||||
return
|
||||
}
|
@ -0,0 +1,53 @@
|
||||
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s --dump-input-on-failure
|
||||
|
||||
node {
|
||||
name: "custom_relu_func_call"
|
||||
op: "custom_relu"
|
||||
}
|
||||
node {
|
||||
name: "custom_embedding_matmul_func_call"
|
||||
op: "custom_embedding_matmul"
|
||||
}
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "custom_relu"
|
||||
}
|
||||
attr {
|
||||
key: "_implements"
|
||||
value {
|
||||
func {
|
||||
name: "tensorflow.relu"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
function {
|
||||
signature {
|
||||
name: "custom_embedding_matmul"
|
||||
}
|
||||
attr {
|
||||
key: "_implements"
|
||||
value {
|
||||
func {
|
||||
name: "tensorflow.embedding_matmul"
|
||||
attr {
|
||||
key: "key1"
|
||||
value {
|
||||
i: 2
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "key2"
|
||||
value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# CHECK-DAG: func @custom_relu{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.relu, {}>}
|
||||
# CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>}
|
@ -2,17 +2,17 @@
|
||||
|
||||
|
||||
func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
|
||||
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
|
||||
return %0 : tensor<1x32x10x32xi32>
|
||||
}
|
||||
|
||||
func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
|
||||
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
|
||||
return %0 : tensor<1x32x10x32xi32>
|
||||
}
|
||||
|
||||
func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
|
||||
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
|
||||
return %0 : tensor<?x?x?x?xi32>
|
||||
}
|
||||
|
||||
@ -23,12 +23,12 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
}
|
||||
|
||||
func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
|
||||
func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
|
||||
%0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
|
||||
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
|
||||
return %0 : tensor<4x4x4x4xi32>
|
||||
}
|
||||
|
||||
@ -38,7 +38,7 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
}
|
||||
|
||||
func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
|
||||
%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
|
||||
@ -48,7 +48,7 @@ func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
}
|
||||
|
||||
func @div_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||
%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
return %0 : tensor<?x?xi32>
|
||||
}
|
||||
|
||||
@ -68,7 +68,7 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
}
|
||||
|
||||
func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
|
||||
%0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
%0 = "xla_chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
|
||||
@ -78,7 +78,7 @@ func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
}
|
||||
|
||||
func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
|
||||
%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
|
||||
@ -88,7 +88,7 @@ func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
}
|
||||
|
||||
func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
|
||||
%0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
%0 = "xla_chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
|
||||
@ -98,7 +98,7 @@ func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
}
|
||||
|
||||
func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> {
|
||||
%0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
|
||||
%0 = "xla_chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
|
||||
return %0 : tensor<2x4xi32>
|
||||
}
|
||||
|
||||
@ -108,12 +108,12 @@ func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> {
|
||||
}
|
||||
|
||||
func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
|
||||
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
func @and_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
|
||||
%0 = "xla_hlo.and"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
|
||||
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
|
||||
return %0 : tensor<?xi1>
|
||||
}
|
||||
|
||||
@ -123,12 +123,12 @@ func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> {
|
||||
}
|
||||
|
||||
func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
|
||||
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
func @or_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
|
||||
%0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
|
||||
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
|
||||
return %0 : tensor<?xi1>
|
||||
}
|
||||
|
||||
@ -138,12 +138,12 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
}
|
||||
|
||||
func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> {
|
||||
%0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
|
||||
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
|
||||
return %0 : tensor<1x4xi8>
|
||||
}
|
||||
|
||||
func @bitwise_or_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> {
|
||||
%0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
|
||||
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
@ -153,12 +153,12 @@ func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
}
|
||||
|
||||
func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> {
|
||||
%0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
|
||||
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
|
||||
return %0 : tensor<1x4xi8>
|
||||
}
|
||||
|
||||
func @bitwise_and_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> {
|
||||
%0 = "xla_hlo.and"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
|
||||
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
@ -174,19 +174,19 @@ func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
|
||||
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
|
||||
%0 = xla_hlo.constant dense<0> : tensor<2x3xi32>
|
||||
%1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
|
||||
%1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
|
||||
%2 = xla_hlo.constant dense<0> : tensor<3xi32>
|
||||
%3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
|
||||
%4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
|
||||
%5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||
%3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
|
||||
%4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
|
||||
%5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||
%6 = "xla_hlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%7 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
|
||||
%8 = xla_hlo.constant dense<1> : tensor<3xi32>
|
||||
%9 = xla_hlo.subtract %7, %8 : tensor<3xi32>
|
||||
%10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||
%10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||
%11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
|
||||
%13 = "xla_hlo.divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||
%13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||
%14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
return %14 : tensor<2x3xi32>
|
||||
}
|
||||
@ -195,14 +195,14 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32
|
||||
%0 = xla_hlo.constant dense<0> : tensor<3xi32>
|
||||
%1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
|
||||
%2 = xla_hlo.constant dense<0> : tensor<2x3xi32>
|
||||
%3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
|
||||
%4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
|
||||
%5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
|
||||
%4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
|
||||
%5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%6 = "xla_hlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
|
||||
%7 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%8 = xla_hlo.constant dense<1> : tensor<2x3xi32>
|
||||
%9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32>
|
||||
%10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%12 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%13 = xla_hlo.divide %11, %12 : tensor<2x3xi32>
|
||||
@ -218,8 +218,8 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
}
|
||||
|
||||
func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
|
||||
%0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
|
||||
%1 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
|
||||
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
|
||||
%1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
|
||||
%2 = "xla_hlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16>
|
||||
return %2 : tensor<2x3xf16>
|
||||
}
|
||||
@ -230,22 +230,22 @@ func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
}
|
||||
|
||||
func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
|
||||
return %0 : tensor<?xi1>
|
||||
}
|
||||
|
||||
func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
|
||||
return %0 : tensor<?xi1>
|
||||
}
|
||||
|
||||
@ -255,17 +255,17 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
}
|
||||
|
||||
func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
func @notequal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
|
||||
return %0 : tensor<?xi1>
|
||||
}
|
||||
|
||||
@ -275,7 +275,7 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
}
|
||||
|
||||
func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
@ -285,7 +285,7 @@ func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
}
|
||||
|
||||
func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
@ -295,7 +295,7 @@ func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
}
|
||||
|
||||
func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
@ -305,7 +305,7 @@ func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
}
|
||||
|
||||
func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
@ -326,35 +326,35 @@ func @const() -> tensor<2xi32> {
|
||||
|
||||
func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
|
||||
%0 = xla_hlo.constant dense<0> : tensor<i32>
|
||||
%1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
return %1 : tensor<1xi32>
|
||||
}
|
||||
|
||||
func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = xla_hlo.constant dense<0> : tensor<i32>
|
||||
%1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
%1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
|
||||
func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
|
||||
%0 = xla_hlo.constant dense<0> : tensor<i32>
|
||||
%1 = xla_hlo.constant dense<6> : tensor<i32>
|
||||
%2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
|
||||
%3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
|
||||
%2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
|
||||
%3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
|
||||
return %3 : tensor<1xi32>
|
||||
}
|
||||
|
||||
func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = xla_hlo.constant dense<0> : tensor<i32>
|
||||
%1 = xla_hlo.constant dense<6> : tensor<i32>
|
||||
%2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
%3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
%2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
%3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
return %3 : tensor<?xi32>
|
||||
}
|
||||
|
||||
func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor<?x?xf32>) -> tensor<4x8xf32> {
|
||||
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%1 = "xla_hlo.compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
|
||||
%1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
|
||||
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32>
|
||||
%3 = "xla_hlo.select"(%1, %arg0, %2) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
|
||||
return %3 : tensor<4x8xf32>
|
||||
@ -682,6 +682,11 @@ func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> {
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
func @convert_slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> {
|
||||
%0 = "xla_hlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32>
|
||||
return %0 : tensor<1x519xf32>
|
||||
}
|
||||
|
||||
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
|
||||
|
||||
// CHECK-LABEL: func @biasAdd_NHWC(
|
||||
@ -1493,3 +1498,12 @@ func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> {
|
||||
// CHECK: [[VAL_371:%.*]] = "tf.Cast"([[VAL_370]]) {Truncate = false} : (tensor<2xi32>) -> tensor<2xf32>
|
||||
// CHECK: return [[VAL_371]] : tensor<2xf32>
|
||||
// CHECK: }
|
||||
|
||||
// CHECK-LABEL: func @convert_slice(
|
||||
// CHECK-SAME: [[VAL_372:%.*]]: tensor<1x4672xf32>) -> tensor<1x519xf32> {
|
||||
// CHECK: [[VAL_373:%.*]] = "tf.Const"() {value = dense<[0, 4153]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[VAL_374:%.*]] = "tf.Const"() {value = dense<[1, 519]> : tensor<2xi64>} : () -> tensor<2xi64>
|
||||
// CHECK: [[VAL_375:%.*]] = "tf.Slice"([[VAL_372]], [[VAL_373]], [[VAL_374]]) : (tensor<1x4672xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x519xf32>
|
||||
// CHECK: return [[VAL_375]] : tensor<1x519xf32>
|
||||
// CHECK: }
|
||||
|
||||
|
@ -0,0 +1,85 @@
|
||||
// RUN: tf-opt -verify-diagnostics -readonly-references-to-resources -split-input-file %s | FileCheck %s --dump-input=fail
|
||||
|
||||
// Test case: Basic converting.
|
||||
|
||||
func @f() {
|
||||
// CHECK: "tf.VarHandleOp"
|
||||
// CHECK: "tf.ReadVariableOp"
|
||||
%val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
|
||||
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: Two ReadVariable ops.
|
||||
|
||||
func @f() {
|
||||
// CHECK: "tf.VarHandleOp"
|
||||
|
||||
// During lowering to resource variables, this pass will preserve the
|
||||
// locations of the ReadVariableOps as Identity ops to keep the original graph
|
||||
// composition and order.
|
||||
|
||||
// CHECK: "tf.ReadVariableOp"
|
||||
// CHECK: "tf.ReadVariableOp"
|
||||
%val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
|
||||
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
|
||||
%val2 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: No follow-up ReadVariable case.
|
||||
|
||||
func @f() {
|
||||
// CHECK-NOT: "tf.VariableV2"
|
||||
// CHECK-NOT: "tf.VarHandleOp"
|
||||
%val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: No converting when there is another use case.
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{'tf.VariableV2' op expects all users to be 'tf.Identity', but got user tf.CustomIdentity}}
|
||||
%val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
|
||||
%val1 = "tf.CustomIdentity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: No class attribute on VariableV2 op.
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{'tf.VariableV2' op has no '_class' attribute}}
|
||||
%val0 = "tf.VariableV2"() {container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
|
||||
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: No named location found on VariableV2 op.
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{'tf.VariableV2' op expects variable name in '_class' attribute, but got ["unrelated_class"]}}
|
||||
%val0 = "tf.VariableV2"() {_class = ["unrelated_class"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
|
||||
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: Invalid multiple location information in a class attribute on VariableV2 op.
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{'tf.VariableV2' op expects only one named location in '_class' attribute, but got ["loc:@v1", "loc:@v2"]}}
|
||||
%val0 = "tf.VariableV2"() {_class = ["loc:@v1", "loc:@v2"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref>
|
||||
%val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32>
|
||||
return
|
||||
}
|
@ -1,10 +1,11 @@
|
||||
// RUN: tf-opt %s -tf-shape-inference -verify-diagnostics | FileCheck %s -dump-input=fail
|
||||
// RUN: tf-opt %s -tf-shape-inference=propagate-caller-callee-constants=false -verify-diagnostics | FileCheck %s -dump-input=fail
|
||||
// RUN: tf-opt %s -tf-shape-inference=propagate-caller-callee-constants -verify-diagnostics | FileCheck %s -dump-input=fail
|
||||
|
||||
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} {
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32>
|
||||
func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<*xi32> {
|
||||
// CHECK-NOT: tf.Cast
|
||||
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
// CHECK: %[[RESULT:.*]] = "tf.AddV2"
|
||||
// CHECK-SAME: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
// CHECK: return %[[RESULT]] : tensor<1xi32>
|
||||
%0 = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<*xi32>
|
||||
%1 = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<*xi32>
|
||||
@ -60,8 +61,8 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||
|
||||
// CHECK-LABEL: func @simple_folding
|
||||
func @simple_folding(%arg0: tensor<1x1x1x1xi32>, %arg1: tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32> {
|
||||
// CHECK: %[[CST:.*]] = "tf.Const"{{.*}} {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[CST]]
|
||||
// CHECK: %[[SHAPE:.*]] = "tf.Shape"
|
||||
// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[SHAPE]]
|
||||
// CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
|
||||
// CHECK: return %[[CONV]] : tensor<1x1x1x1xf32>
|
||||
%0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32>
|
||||
@ -300,13 +301,6 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||
return %0 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_cast
|
||||
func @fold_cast(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK-NOT: Cast
|
||||
%0 = "tf.Cast"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>)
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @while_variant
|
||||
// CHECK-SAME: -> tensor<!tf.variant<tensor<16x1xf32>>>
|
||||
func @while_variant(%arg0: tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<!tf.variant> {
|
||||
@ -362,8 +356,6 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||
|
||||
// CHECK-LABEL: func @partitioned_call_func_const
|
||||
func @partitioned_call_func_const(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
// CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK: return %[[CONST]]
|
||||
return %arg0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
@ -410,4 +402,18 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||
%40 = "tf.Reshape"(%39, %19) {T = f32, Tshape = i32, device = ""} : (tensor<1x4x4x32xf32>, tensor<2xi32>) -> tensor<?x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: const_fold
|
||||
func @const_fold() -> () {
|
||||
// CHECK: tf.Const
|
||||
// CHECK-SAME: () -> tensor<4xi32>
|
||||
%0 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32>
|
||||
// CHECK: tf.Const
|
||||
// CHECK-SAME: () -> tensor<4xi32>
|
||||
%1 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32>
|
||||
// CHECK: tf.Add
|
||||
// CHECK-SAME: (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||
%2 = "tf.Add"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -854,6 +854,215 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid tf.Yield operation (parent should be IfRegion)
|
||||
func @testInvalidYieldOp(%arg0: f32) -> () {
|
||||
// expected-error @+1 {{expects parent op 'tf.IfRegion'}}
|
||||
"tf.Yield"(%arg0) : (f32) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test valid tf.IfRegion operation
|
||||
// CHECK-LABEL: func @testValidIfRegionOp
|
||||
func @testValidIfRegionOp(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test valid tf.IfRegion operation with multiple results
|
||||
// CHECK-LABEL: func @testValidIfRegionOpWithMultipleResults
|
||||
func @testValidIfRegionOpWithMultipleResults(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%0, %1, %2 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%t0 = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
%t1 = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
%t2 = "tf.Acosh"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t0, %t1, %t2) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> ()
|
||||
}, {
|
||||
%e0 = "tf.Neg"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
%e1 = "tf.Relu"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
%e2 = "tf.Sin"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%e0, %e1, %e2) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
|
||||
|
||||
%3 = "tf.Add"(%0, %1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||
%4 = "tf.Add"(%2, %3) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||
return %4 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid type for operand #0 for tf.IfRegion operation
|
||||
func @testInvalidIfRegionOpType0(%arg0: f32, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{operand #0 must be tensor of tf.dtype values}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (f32, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid type for operand #1 for tf.IfRegion operation
|
||||
func @testInvalidIfRegionOpType1(%arg0: tensor<i1>, %arg1: f32) -> f32 {
|
||||
// expected-error @+1 {{operand #1 must be tensor of tf.dtype values}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%t = addf %arg1, %arg1 : f32
|
||||
"tf.Yield"(%t) : (f32) -> ()
|
||||
}, {
|
||||
%e = mulf %arg1, %arg1 : f32
|
||||
"tf.Yield"(%e) : (f32) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, f32) -> f32
|
||||
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// tf.IfRegion operation should have 2 regions
|
||||
func @testInvalidIfRegionOp1Region(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{op expected 2 regions}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testInvalidIfRegionOpNoRegions(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{op expected 2 regions}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testInvalidIfRegionOp3Regions(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{op expected 2 regions}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
%te = "tf.Relu"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%te) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// tf.IfRegion regions should be terminated with a tf.Yield
|
||||
func @testIfRegionThenTerminator(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+2 {{'tf.IfRegion' op expects regions to end with 'tf.Yield'}}
|
||||
// expected-note @+1 {{in custom textual format, the absence of terminator implies 'tf.Yield'}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
}, {
|
||||
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testIfRegionElseTerminator(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+2 {{'tf.IfRegion' op expects regions to end with 'tf.Yield'}}
|
||||
// expected-note @+1 {{in custom textual format, the absence of terminator implies 'tf.Yield'}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// tf.Region yield number of results should match op number of results
|
||||
func @testIfRegionThenResultCount(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{then region should have 1 result}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t, %t) : (tensor<2xf32>, tensor<2xf32>) -> ()
|
||||
}, {
|
||||
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testIfRegionElseResultCount(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{else region should have 1 result}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%e, %e) : (tensor<2xf32>, tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// tf.IfRegion yield types should match op result types
|
||||
func @testIfRegionOpYieldMismatchThen(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{then result type tensor<i1> is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
"tf.Yield"(%arg0) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
%e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%e) : (tensor<2xf32>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testIfRegionOpYieldMismatchElse(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// expected-error @+1 {{else result type tensor<i1> is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}}
|
||||
%0 = "tf.IfRegion"(%arg0, %arg1) ({
|
||||
%t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
"tf.Yield"(%t) : (tensor<2xf32>) -> ()
|
||||
}, {
|
||||
"tf.Yield"(%arg0) : (tensor<i1>) -> ()
|
||||
}) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32>
|
||||
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test valid tf.MatrixBandPart
|
||||
// CHECK-LABEL: func @testValidMatrixBandPartOp
|
||||
func @testValidMatrixBandPartOp(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
|
||||
@ -1007,6 +1216,116 @@ func @pcall_func_2(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
|
||||
|
||||
// -----
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// tf.Select
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// Test valid tf.Select
|
||||
// CHECK-LABEL: func @testSelect
|
||||
func @testSelect(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<3x2xf16>) -> tensor<3x2xf16> {
|
||||
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<3x2xf16>) -> tensor<3x2xf16>
|
||||
return %0: tensor<3x2xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testInvalidSelect(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
|
||||
// expected-error @+1 {{requires that, when pred is a vector, the shape matches the first dimension of t and e}}
|
||||
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
|
||||
return %0: tensor<2x3xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid tf.Select - broadcasting then/else parameters is not supported
|
||||
func @selectBroadcastThen(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
|
||||
// expected-error @+1 {{requires t and e have compatible shapes}}
|
||||
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
|
||||
return %0: tensor<2x8x8xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalidSelect(%arg0: tensor<2xi1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<2xi32> {
|
||||
// expected-error @+1 {{requires that t and e are nonscalar when pred is a vector}}
|
||||
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<i32>, tensor<i32>) -> tensor<2xi32>
|
||||
return %0: tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalidSelect(%arg0: tensor<1x8xi1>, %arg1: tensor<1x8x8xi32>, %arg2: tensor<1x8x8xi32>) -> tensor<1x8x8xi32> {
|
||||
// expected-error @+1 {{requires that pred is a scalar OR has the same rank as t and e OR is a vector}}
|
||||
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<1x8xi1>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<1x8x8xi32>
|
||||
return %0: tensor<1x8x8xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// tf.SelectV2
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// Test valid tf.SelectV2
|
||||
// CHfaECK-LABEL: func @selectV2BroadcastThen
|
||||
func @selectV2BroadcastThen(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
|
||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
|
||||
return %0: tensor<2x8x8xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test valid tf.SelectV2
|
||||
// CHECK-LABEL: func @selectV2BroadcastElse
|
||||
func @selectV2BroadcastElse(%arg0: tensor<i1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> {
|
||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32>
|
||||
return %0: tensor<2x8x8xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test valid tf.SelectV2
|
||||
// CHECK-LABEL: func @selectV2BroadcastPred
|
||||
func @selectV2BroadcastPred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
|
||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
|
||||
return %0: tensor<2x8x8xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @selectV2BroadcastAll
|
||||
func @selectV2BroadcastAll(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> {
|
||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
|
||||
return %0: tensor<8x8x8xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @selectV2DynamicRanked
|
||||
func @selectV2DynamicRanked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> {
|
||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32>
|
||||
return %0: tensor<2x?x8xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @selectV2Unranked
|
||||
func @selectV2Unranked(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<*xi32>) -> tensor<*xi32> {
|
||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<*xi32>) -> tensor<*xi32>
|
||||
return %0: tensor<*xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test invalid tf.SelectV2: this is an invalid broadcast for the predicate
|
||||
func @testInvalidSelectV2(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<3x2xf16>) -> tensor<3x2xf16> {
|
||||
// expected-error @+1 {{operands don't have broadcast-compatible shapes}}
|
||||
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<3x2xf16>) -> tensor<3x2xf16>
|
||||
return %0: tensor<3x2xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// tf.Softmax
|
||||
//===--------------------------------------------------------------------===//
|
||||
@ -1326,7 +1645,7 @@ func @testShapeMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> {
|
||||
|
||||
func @testShapeWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> {
|
||||
^bb0(%arg0: tensor<*xf32>):
|
||||
// expected-error @+1 {{requires dynamic shape result for unranked operand}}
|
||||
// expected-warning @+1 {{has static shape result for unranked operand}}
|
||||
%0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor<2xi32>
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
@ -1370,7 +1689,7 @@ func @testShapeNMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> {
|
||||
|
||||
func @testShapeNWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> {
|
||||
^bb0(%arg0: tensor<*xf32>):
|
||||
// expected-error @+1 {{requires dynamic shape result #1 for unranked operand #1}}
|
||||
// expected-warning @+1 {{has static shape result #1 for unranked operand #1}}
|
||||
%0:2 = "tf.ShapeN"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<?xi32>, tensor<2xi32>)
|
||||
return %0#1 : tensor<2xi32>
|
||||
}
|
||||
@ -1428,7 +1747,7 @@ func @testVariableShapeMismatchDim(%arg0: tensor<*x!tf.resource<tensor<1x32x32x1
|
||||
// -----
|
||||
|
||||
func @testVariableShapeWrongResultDimDynamic(%arg0: tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<2xi32> {
|
||||
// expected-error @+1 {{requires dynamic shape result for unranked operand}}
|
||||
// expected-warning @+1 {{has static shape result for unranked operand}}
|
||||
%0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<2xi32>
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
@ -104,3 +104,9 @@ module attributes {tf_saved_model.semantics} {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test running the pass on a module that does not have
|
||||
// tf_saved_model.semantics.
|
||||
module {}
|
||||
|
@ -136,3 +136,9 @@ module attributes {tf_saved_model.semantics} {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test running the pass on a module that does not have
|
||||
// tf_saved_model.semantics.
|
||||
module {}
|
||||
|
@ -2,80 +2,183 @@
|
||||
|
||||
// Tests extraction of a outside compiled ops at head of TPU computation.
|
||||
|
||||
func @single_head_outside_compilation(%arg0 : tensor<i32>) -> () {
|
||||
// CHECK: tf_device.launch
|
||||
// CHECK: "tf.A"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: "tf.C"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
"tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
|
||||
"tf.B"() : () -> ()
|
||||
"tf.C"() : () -> ()
|
||||
tf_device.return
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
|
||||
return
|
||||
}
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
|
||||
// CHECK-LABEL: func @single_head_outside_compilation
|
||||
func @single_head_outside_compilation(%arg0 : tensor<i32>) -> () {
|
||||
// CHECK: tf_device.launch
|
||||
// CHECK: "tf.A"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: "tf.C"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
"tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
|
||||
"tf.B"() : () -> ()
|
||||
"tf.C"() : () -> ()
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @multiple_head_outside_compilation
|
||||
func @multiple_head_outside_compilation(%arg0 : tensor<i32>) -> () {
|
||||
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
|
||||
// CHECK: "tf.C"
|
||||
// CHECK-NEXT: tf_device.return %[[B_OUT]]
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: "tf.D"(%[[LAUNCH_OUT]])
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
%0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
|
||||
%1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
|
||||
"tf.C"(%1, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> ()
|
||||
"tf.D"(%1) : (tensor<i32>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @ops_no_operands
|
||||
func @ops_no_operands() -> () {
|
||||
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK-NEXT: tf_device.return %[[A_OUT]]
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: "tf.B"(%[[LAUNCH_OUT]])
|
||||
// CHECK-NEXT: "tf.C"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
%0 = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<i32>)
|
||||
%1 = "tf.B"(%0) {}: (tensor<i32>) -> (tensor<i32>)
|
||||
"tf.C"(%1) : (tensor<i32>) -> ()
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_do_not_outside_compiled_ops_in_middle
|
||||
func @test_do_not_outside_compiled_ops_in_middle(%arg0 : tensor<i32>) -> () {
|
||||
// CHECK-NOT: tf_device.launch
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: "tf.A"
|
||||
// CHECK-NEXT: "tf.B"
|
||||
// CHECK-NEXT: "tf.C"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
%0 = "tf.A"(%arg0) {} : (tensor<i32>) -> (tensor<i32>)
|
||||
%1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
|
||||
"tf.C"(%1) : (tensor<i32>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @aliased_output
|
||||
func @aliased_output() -> (tensor<i32>, tensor<i32>, tensor<i32>) {
|
||||
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK-NEXT: tf_device.return %[[A_OUT]]
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
//
|
||||
// CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster"
|
||||
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[LAUNCH_OUT]])
|
||||
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"
|
||||
// CHECK-NEXT: tf_device.return %[[C_OUT]], %[[B_OUT]]
|
||||
// CHECK-NEXT: {
|
||||
// CHECK-DAG: num_cores_per_replica = 1
|
||||
// CHECK-DAG: step_marker_location = ""
|
||||
// CHECK-DAG: padding_map = []
|
||||
// CHECK-DAG: topology = ""
|
||||
// CHECK-DAG: device_assignment = []
|
||||
//
|
||||
// CHECK: return %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#0, %[[CLUSTER_OUT]]#1
|
||||
%0:3 = "tf_device.cluster"() ( {
|
||||
%1 = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<i32>)
|
||||
%2 = "tf.B"(%1) {}: (tensor<i32>) -> (tensor<i32>)
|
||||
%3 = "tf.C"(%2) : (tensor<i32>) -> (tensor<i32>)
|
||||
tf_device.return %1, %3, %2 : tensor<i32>, tensor<i32>, tensor<i32>
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>, tensor<i32>, tensor<i32>)
|
||||
return %0#0, %0#1, %0#2 : tensor<i32>, tensor<i32>, tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_ops_with_tpu_operands_not_extracted
|
||||
func @test_ops_with_tpu_operands_not_extracted(%arg0 : tensor<i32>) -> () {
|
||||
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
|
||||
// CHECK-NEXT: tf_device.return %[[D_OUT]]
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: "tf.B"
|
||||
// CHECK: "tf.C"
|
||||
// CHECK: "tf.E"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
%0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
|
||||
%1 = "tf.B"() {} : () -> (tensor<i32>)
|
||||
%2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
|
||||
%3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
|
||||
%4 = "tf.E"(%3) {} : (tensor<i32>) -> (tensor<i32>)
|
||||
tf_device.return
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
|
||||
return
|
||||
// CHECK-LABEL: func @all_head_computation_ops
|
||||
func @all_head_computation_ops(%arg0 : tensor<i32>) -> (tensor<i32>) {
|
||||
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
|
||||
// CHECK: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]], %arg0)
|
||||
// CHECK-NEXT: tf_device.return %[[C_OUT]]
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
//
|
||||
// CHECK: return %[[LAUNCH_OUT]]
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
%1 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
|
||||
%2 = "tf.B"(%1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
|
||||
%3 = "tf.C"(%2, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
|
||||
tf_device.return %3 : tensor<i32>
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>)
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @multiple_head_outside_compilation
|
||||
func @multiple_head_outside_compilation(%arg0 : tensor<i32>) -> () {
|
||||
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
|
||||
// CHECK: "tf.C"
|
||||
// CHECK-NEXT: tf_device.return %[[B_OUT]]
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: "tf.D"(%[[LAUNCH_OUT]])
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
%0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
|
||||
%1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
|
||||
"tf.C"(%1, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> ()
|
||||
"tf.D"(%1) : (tensor<i32>) -> ()
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_do_not_outside_compiled_ops_in_middle
|
||||
func @test_do_not_outside_compiled_ops_in_middle(%arg0 : tensor<i32>) -> () {
|
||||
// CHECK-NOT: tf_device.launch
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: "tf.A"
|
||||
// CHECK-NEXT: "tf.B"
|
||||
// CHECK-NEXT: "tf.C"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
%0 = "tf.A"(%arg0) {} : (tensor<i32>) -> (tensor<i32>)
|
||||
%1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
|
||||
"tf.C"(%1) : (tensor<i32>) -> ()
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_ops_with_tpu_operands_not_extracted
|
||||
func @test_ops_with_tpu_operands_not_extracted(%arg0 : tensor<i32>) -> () {
|
||||
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
|
||||
// CHECK-NEXT: tf_device.return %[[D_OUT]]
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: "tf.B"
|
||||
// CHECK: "tf.C"
|
||||
// CHECK: "tf.E"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
%0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
|
||||
%1 = "tf.B"() {} : () -> (tensor<i32>)
|
||||
%2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
|
||||
%3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
|
||||
%4 = "tf.E"(%3) {} : (tensor<i32>) -> (tensor<i32>)
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_replicated_head_outside_compilation
|
||||
func @test_replicated_head_outside_compilation(%arg0 : tensor<i32>) -> () {
|
||||
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
|
||||
// CHECK-NEXT: tf_device.return %[[D_OUT]]
|
||||
// CHECK: device = "TPU_REPLICATED_HOST"
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: "tf.B"
|
||||
// CHECK: "tf.C"
|
||||
// CHECK: "tf.E"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
tf_device.replicate() {n = 2 : i32} {
|
||||
"tf_device.cluster"() ( {
|
||||
%0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
|
||||
%1 = "tf.B"() {} : () -> (tensor<i32>)
|
||||
%2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
|
||||
%3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
|
||||
%4 = "tf.E"(%3) {} : (tensor<i32>) -> (tensor<i32>)
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -141,4 +141,133 @@ func @multiple_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> te
|
||||
return %1 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// Tests extraction of a single outside compiled cluster with single device->host input.
|
||||
|
||||
// CHECK-LABEL: func @single_outside_compiled_input_single_outside_compilation
|
||||
func @single_outside_compiled_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: "tf.B"(%[[RECV_OUTPUT]])
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
"tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
|
||||
%4 = "tf.C"() : () -> tensor<?xi32>
|
||||
tf_device.return %4 : tensor<?xi32>
|
||||
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// Tests extraction of a single outside compiled cluster with arg input and single device->host input.
|
||||
|
||||
// CHECK-LABEL: func @mixed_input_single_outside_compilation
|
||||
func @mixed_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]])
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
"tf.B"(%arg0, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
|
||||
%4 = "tf.C"() : () -> tensor<?xi32>
|
||||
tf_device.return %4 : tensor<?xi32>
|
||||
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// Tests extraction of a multiple outside compiled clusters with single device->host input.
|
||||
|
||||
// CHECK-LABEL: func @single_outside_compiled_input_multiple_outside_compilation
|
||||
func @single_outside_compiled_input_multiple_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT_2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_2:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_2]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster2"
|
||||
// CHECK: "tf.D"(%[[RECV_OUTPUT_2]])
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT_1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_1:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_1]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: "tf.B"(%[[RECV_OUTPUT_1]])
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
|
||||
// CHECK: "tf._HostComputeMlir"(%[[C_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster2"
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
"tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
|
||||
%4 = "tf.C"() : () -> tensor<?xi32>
|
||||
"tf.D"(%4) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> ()
|
||||
tf_device.return %4 : tensor<?xi32>
|
||||
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// Tests extraction of a single outside compiled cluster with multiple device->host inputs.
|
||||
|
||||
// CHECK-LABEL: func @multiple_outside_compiled_inputs_single_outside_compilation
|
||||
func @multiple_outside_compiled_inputs_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: "tf.C"(%[[RECV_OUTPUT]]#0)
|
||||
// CHECK: "tf.D"(%[[RECV_OUTPUT]]#1, %[[RECV_OUTPUT]]#0)
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
|
||||
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
%4 = "tf.B"() : () -> (tensor<?xi32>)
|
||||
"tf.C"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
|
||||
"tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
|
||||
%5 = "tf.E"() : () -> tensor<?xi32>
|
||||
tf_device.return %5 : tensor<?xi32>
|
||||
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// TODO(b/154363171): Add test cases for when output of outside compilation is returned by parallel_execute.
|
||||
|
@ -747,7 +747,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
|
||||
// -----
|
||||
|
||||
// Tests simple case of `tf_device.cluster_func` on TPU with replication.
|
||||
// Tests simple case of `tf_device.cluster_func` on TPU with replication. Under
|
||||
// data parallelism replicated host devices are also added to the
|
||||
// tf_device.replicate
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
|
||||
// CHECK-LABEL: func @replicated_tpu_cluster_func
|
||||
@ -758,7 +760,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK-SAME: ([%[[A_OUTPUT]], %[[ARG_0]]] as %[[RI_0:[a-z0-9]*]]: tensor<?xi32>)
|
||||
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]}
|
||||
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}
|
||||
// CHECK-SAME: n = 2
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[RI_0]])
|
||||
@ -1222,7 +1224,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
|
||||
// -----
|
||||
|
||||
// Tests simple case of `tf_device.cluster_func` on TPU with replication and parallel_execute.
|
||||
// Tests simple case of `tf_device.cluster_func` on TPU with replication and
|
||||
// parallel_execute.
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
|
||||
// CHECK-LABEL: func @replicated_parallel_tpu_cluster_func
|
||||
@ -1231,17 +1234,26 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK: "tf._TPUCompileMlir"
|
||||
// CHECK: "tf.TPUCompileSucceededAssert"
|
||||
// CHECK: "tf_device.parallel_execute"
|
||||
// CHECK-NOT:"tf._TPUCompileMlir"
|
||||
// CHECK: "tf.D"(%[[COMPILE_OUTPUT]]#1
|
||||
// CHECK: "tf.TPUExecute"
|
||||
// CHECK-NOT:"tf._TPUCompileMlir"
|
||||
// CHECK: "tf.E"(%[[COMPILE_OUTPUT]]#1
|
||||
%3 = "tf_device.parallel_execute"() ( {
|
||||
"tf.D"() : () -> ()
|
||||
%status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
"tf.D"(%program) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}, {
|
||||
%4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
|
||||
tf_device.return %4 : tensor<?xi32>
|
||||
}, {
|
||||
%status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
"tf.E"(%program) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) : () -> (tensor<?xi32>)
|
||||
tf_device.return %3 : tensor<?xi32>
|
||||
}
|
||||
@ -1317,15 +1329,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
|
||||
// "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01"
|
||||
// -----
|
||||
|
||||
// Tests devices are set properly for replicated model parallelism.
|
||||
// Tests devices are set properly for replicated model parallelism. No
|
||||
// replicated host device should be present.
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
|
||||
// CHECK-LABEL: func @replicated_parallel_execute
|
||||
func @replicated_parallel_execute(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> (tensor<8xi32>, tensor<8xi32>) {
|
||||
// CHECK: tf_device.replicate
|
||||
// CHECK-SAME: devices =
|
||||
// CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
|
||||
// CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
|
||||
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"], TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]}
|
||||
%0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} {
|
||||
// CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"()
|
||||
@ -1357,8 +1368,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that inputs are inputs with maximal and replicate sharding are set properly
|
||||
// for replicated model parallelism.
|
||||
// Tests that inputs are inputs with maximal and replicate sharding are set
|
||||
// properly for replicated model parallelism.
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
|
||||
// CHECK-LABEL: func @parallel_execute_with_input_with_sharding_configurations
|
||||
@ -1392,8 +1403,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
|
||||
|
||||
// -----
|
||||
|
||||
// Tests devices are set properly for replicated model parallelism with
|
||||
// outputs to TPU computation placed on logical device 0.
|
||||
// Tests devices are set properly for replicated model parallelism with outputs
|
||||
// to TPU computation placed on logical device 0.
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
|
||||
// CHECK-LABEL: func @parallel_execute_with_different_outputs
|
||||
@ -1469,8 +1480,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
|
||||
|
||||
// -----
|
||||
|
||||
// Tests inputs are correctly split and fed into TPU computation for
|
||||
// tiled input sharding.
|
||||
// Tests inputs are correctly split and fed into TPU computation for tiled input
|
||||
// sharding.
|
||||
|
||||
// The following OpSharding is used for TPU computation inputs in below test:
|
||||
// Proto debug string:
|
||||
|
@ -152,6 +152,23 @@ def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp $arg0, (TF_SqrtOp $arg1)),
|
||||
def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)),
|
||||
(replaceWithValue $arg)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Select op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ReshapeSelectPredIfNecessary : NativeCodeCall<
|
||||
"ReshapeSelectPredIfNecessary(&($_builder), $0.getOwner()->getLoc(), $1, "
|
||||
"$2.getType().cast<RankedTensorType>().getRank())">;
|
||||
|
||||
// Select supports tensor `condition` where the shape is equal to the first
|
||||
// dimension of t and e. SelectV2 op supports normal broadcasting, so in these
|
||||
// cases the condition needs to be reshaped.
|
||||
def SelectToSelectV2 : Pat<
|
||||
(TF_SelectOp:$op StaticShapeTensorOf<[AnyType]>:$cond,
|
||||
StaticShapeTensorOf<[AnyType]>:$t,
|
||||
StaticShapeTensorOf<[AnyType]>:$e),
|
||||
(TF_SelectV2Op (ReshapeSelectPredIfNecessary $op, $cond, $t), $t, $e)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Square op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
@ -48,6 +48,9 @@ struct FreezeGlobalTensorsPass
|
||||
|
||||
void FreezeGlobalTensorsPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
|
||||
return;
|
||||
}
|
||||
SymbolTable symbol_table(module);
|
||||
DenseSet<Operation*> frozen_global_tensors;
|
||||
|
||||
@ -66,7 +69,9 @@ void FreezeGlobalTensorsPass::runOnOperation() {
|
||||
// previous optimize global tensors pass). If not, this pass has to fail
|
||||
// since it cannot perform one of its goals.
|
||||
if (global_tensor.is_mutable()) {
|
||||
global_tensor.emitError() << "is not immutable";
|
||||
global_tensor.emitError() << "is not immutable, try running "
|
||||
"tf-saved-model-optimize-global-tensors "
|
||||
"to prove tensors are immutable";
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
|
@ -16,21 +16,61 @@ limitations under the License.
|
||||
// This file implements logic for legalizing HLO to TensorFlow.
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
namespace {
|
||||
|
||||
class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_hlo::SliceOp slice_op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
DenseIntElementsAttr strides = slice_op.strides();
|
||||
// Strides must be 1 otherwise we cannot legalize this `xla_hlo.slice` op.
|
||||
if (!strides.isSplat() ||
|
||||
strides.getSplatValue().cast<IntegerAttr>().getInt() != 1)
|
||||
return failure();
|
||||
|
||||
rewriter.setInsertionPointAfter(slice_op);
|
||||
auto start_indices = slice_op.start_indices();
|
||||
auto limit_indices = slice_op.limit_indices();
|
||||
std::vector<int64_t> size_values;
|
||||
for (auto pair : llvm::zip(start_indices.getValues<APInt>(),
|
||||
limit_indices.getValues<APInt>())) {
|
||||
size_values.emplace_back(std::get<1>(pair).getSExtValue() -
|
||||
std::get<0>(pair).getSExtValue());
|
||||
}
|
||||
|
||||
RankedTensorType ty =
|
||||
RankedTensorType::get({static_cast<int64_t>(size_values.size())},
|
||||
rewriter.getIntegerType(64));
|
||||
auto start = rewriter.create<ConstOp>(slice_op.getLoc(), start_indices);
|
||||
auto size = rewriter.create<ConstOp>(
|
||||
slice_op.getLoc(), DenseIntElementsAttr::get(ty, size_values));
|
||||
rewriter.replaceOpWithNewOp<SliceOp>(slice_op, slice_op.getType(),
|
||||
slice_op.operand(), start, size);
|
||||
return success();
|
||||
};
|
||||
};
|
||||
|
||||
class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
|
||||
public:
|
||||
LegalizeHloToTf() = default;
|
||||
@ -63,6 +103,7 @@ void LegalizeHloToTf::runOnFunction() {
|
||||
// Add legalization patterns to the list.
|
||||
OwningRewritePatternList patterns;
|
||||
populateWithGenerated(&context, &patterns);
|
||||
patterns.insert<ConvertSliceOp>(&context);
|
||||
|
||||
ConversionTarget target(context);
|
||||
target.addLegalDialect<TensorFlowDialect>();
|
||||
|
@ -18,12 +18,16 @@ limitations under the License.
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/chlo_ops.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
|
||||
|
||||
def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Binary op patterns.
|
||||
// Note that these are legalized from chlo.broadcast_* ops, since those are
|
||||
// semantically compatible with the corresponding TF ops. Depending on
|
||||
// context, getting to these ops may require some raising.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Check that two values can be broadcasted together
|
||||
@ -31,36 +35,45 @@ def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>;
|
||||
def AreBroadcastCompatible : Constraint<CPred<"AreBroadcastCompatible($0, $1)">,
|
||||
"types must be broadcastable">;
|
||||
|
||||
foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op],
|
||||
[HLO_DivOp, TF_DivOp],
|
||||
[HLO_ShiftLeftOp, TF_LeftShiftOp],
|
||||
[HLO_MaxOp, TF_MaximumOp],
|
||||
[HLO_MinOp, TF_MinimumOp],
|
||||
[HLO_MulOp, TF_MulOp],
|
||||
[HLO_PowOp, TF_PowOp],
|
||||
[HLO_SubOp, TF_SubOp],
|
||||
[HLO_Atan2Op, TF_Atan2Op],
|
||||
[HLO_RemOp, TF_ModOp]] in
|
||||
def : Pat<(fromToBinPair[0] $l, $r, $_), (fromToBinPair[1] $l, $r),
|
||||
foreach fromToBinPair = [[HLO_AddOp, HLOClient_BroadcastAddOp, TF_AddV2Op],
|
||||
[HLO_DivOp, HLOClient_BroadcastDivOp, TF_DivOp],
|
||||
[HLO_ShiftLeftOp, HLOClient_BroadcastShiftLeftOp, TF_LeftShiftOp],
|
||||
[HLO_MaxOp, HLOClient_BroadcastMaxOp, TF_MaximumOp],
|
||||
[HLO_MinOp, HLOClient_BroadcastMinOp, TF_MinimumOp],
|
||||
[HLO_MulOp, HLOClient_BroadcastMulOp, TF_MulOp],
|
||||
[HLO_PowOp, HLOClient_BroadcastPowOp, TF_PowOp],
|
||||
[HLO_SubOp, HLOClient_BroadcastSubOp, TF_SubOp],
|
||||
[HLO_Atan2Op, HLOClient_BroadcastAtan2Op, TF_Atan2Op],
|
||||
[HLO_RemOp, HLOClient_BroadcastRemOp, TF_ModOp]] in {
|
||||
def : Pat<(fromToBinPair[0] $l, $r), (fromToBinPair[2] $l, $r)>;
|
||||
def : Pat<(fromToBinPair[1] $l, $r, $_), (fromToBinPair[2] $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
}
|
||||
|
||||
foreach pair = [[HLO_AndOp, TF_BitwiseAndOp],
|
||||
[HLO_OrOp, TF_BitwiseOrOp],
|
||||
[HLO_XorOp, TF_BitwiseXorOp]] in
|
||||
def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[1] $l, $r),
|
||||
foreach pair = [[HLO_AndOp, HLOClient_BroadcastAndOp, TF_BitwiseAndOp],
|
||||
[HLO_OrOp, HLOClient_BroadcastOrOp, TF_BitwiseOrOp],
|
||||
[HLO_XorOp, HLOClient_BroadcastXorOp, TF_BitwiseXorOp]] in {
|
||||
def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r), (pair[2] $l, $r)>;
|
||||
def : Pat<(pair[1] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[2] $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
}
|
||||
|
||||
foreach pair = [[HLO_AndOp, TF_LogicalAndOp],
|
||||
[HLO_OrOp, TF_LogicalOrOp]] in
|
||||
def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r, $_), (pair[1] $l, $r),
|
||||
foreach pair = [[HLO_AndOp, HLOClient_BroadcastAndOp, TF_LogicalAndOp],
|
||||
[HLO_OrOp, HLOClient_BroadcastOrOp, TF_LogicalOrOp]] in {
|
||||
def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r), (pair[2] $l, $r)>;
|
||||
def : Pat<(pair[1] I1Tensor:$l, I1Tensor:$r, $_), (pair[2] $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
}
|
||||
|
||||
def : Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
|
||||
def : Pat<(HLO_ShiftRightArithmeticOp $l, $r), (TF_RightShiftOp $l, $r)>;
|
||||
def : Pat<(HLOClient_BroadcastShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
def : Pat<(HLO_ShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r),
|
||||
def : Pat<(HLO_ShiftRightLogicalOp $l, $r), (TF_RightShiftOp $l, $r)>;
|
||||
def : Pat<(HLOClient_BroadcastShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
||||
def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r),
|
||||
def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r)), (TF_FloorDivOp $l, $r)>;
|
||||
def : Pat<(HLO_FloorOp (HLOClient_BroadcastDivOp $l, $r, $_)), (TF_FloorDivOp $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
||||
def : Pat<(HLO_ComplexOp $r, $i), (TF_ComplexOp $r, $i)>;
|
||||
@ -117,16 +130,23 @@ def : Pat<(HLO_ConcatenateOp $inputs, $dim),
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Compare op patterns.
|
||||
// Note that these are legalized from chlo.broadcast_* ops, since those are
|
||||
// semantically compatible with the corresponding TF ops. Depending on
|
||||
// context, getting to these ops may require some raising.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
foreach p = [[TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ],
|
||||
[TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in
|
||||
def : Pat<(HLO_CompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue),
|
||||
[TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in {
|
||||
def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
def : Pat<(HLO_CompareOp $l, $r, p[1]), (p[0] $l, $r, ConstBoolAttrTrue)>;
|
||||
}
|
||||
|
||||
foreach pair = [[TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE],
|
||||
[TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT],
|
||||
[TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE],
|
||||
[TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in
|
||||
def : Pat<(HLO_CompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r),
|
||||
[TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in {
|
||||
def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
def : Pat<(HLO_CompareOp $l, $r, pair[1]), (pair[0] $l, $r)>;
|
||||
}
|
||||
|
@ -278,6 +278,10 @@ void EraseUnusedBoundInputs(ModuleOp module) {
|
||||
|
||||
void OptimizeGlobalTensorsPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
|
||||
return;
|
||||
}
|
||||
|
||||
EraseUnusedBoundInputs(module);
|
||||
|
||||
ResourceAnalyzer resource_analyzer(module);
|
||||
|
@ -95,6 +95,11 @@ std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass();
|
||||
// functions.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass();
|
||||
|
||||
// Creates a pass that converts readonly reference variables to the
|
||||
// corresponding resource variables.
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass();
|
||||
|
||||
// Marks function visibility using tf.entry_function specification. That is,
|
||||
// functions with tf.entry_function attributes are marked with public
|
||||
// visibility while the other functions are marked with private visibility.
|
||||
|
@ -0,0 +1,179 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Types.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
namespace {
|
||||
|
||||
// Location attribute.
|
||||
constexpr StringRef kClassAttr = "_class";
|
||||
constexpr StringRef kLocationPrefix = "loc:@";
|
||||
|
||||
// A pass that converts readonly reference variables to the corresponding
|
||||
// resource variables.
|
||||
//
|
||||
// It converts (VariableV2 -> Identity) to (VarHandle -> ReadVariable).
|
||||
//
|
||||
// For the background, this pass is a part of hoisting VariableV2 ops by
|
||||
// re-using the pipeline for hoisting (VarHandle -> ReadVariable) cases, which
|
||||
// can be done by the following passes:
|
||||
// - Capturing resource values into global tensors (importing saved model).
|
||||
// - Promoting VarHandle ops to function input/outputs.
|
||||
// - Freezing global tensor pass.
|
||||
//
|
||||
// This path assumes that all the VariableV2 ops is read-only via verifying the
|
||||
// heuristic method that assumes that all the users of them is Identity op,
|
||||
// fed directly.
|
||||
class ConvertReadonlyReferenceVariablesToResourceVariablesPass
|
||||
: public PassWrapper<
|
||||
ConvertReadonlyReferenceVariablesToResourceVariablesPass,
|
||||
FunctionPass> {
|
||||
public:
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
// Parse node name from "_class" attribute.
|
||||
StringRef GetNodeNameFromClassAttr(Operation *op) {
|
||||
ArrayAttr classes_attr = op->getAttrOfType<ArrayAttr>(kClassAttr);
|
||||
if (!classes_attr) {
|
||||
op->emitOpError() << "has no '_class' attribute";
|
||||
return StringRef();
|
||||
}
|
||||
|
||||
StringRef result;
|
||||
for (Attribute class_attr : classes_attr) {
|
||||
StringRef node_name = class_attr.cast<StringAttr>().getValue();
|
||||
if (!node_name.startswith(kLocationPrefix)) {
|
||||
continue;
|
||||
}
|
||||
if (!result.empty()) {
|
||||
// Invalid case since there are multiple loc:@ attributes.
|
||||
op->emitOpError()
|
||||
<< "expects only one named location in '_class' attribute, but got "
|
||||
<< classes_attr;
|
||||
return StringRef();
|
||||
}
|
||||
result = node_name.drop_front(kLocationPrefix.size());
|
||||
}
|
||||
if (result.empty()) {
|
||||
op->emitOpError() << "expects variable name in '_class' attribute, but got "
|
||||
<< classes_attr;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void ConvertReadonlyReferenceVariablesToResourceVariablesPass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
|
||||
OpBuilder builder(func.getContext());
|
||||
SmallVector<VariableV2Op, 4> variable_v2s_to_replace;
|
||||
|
||||
// Checks all the VariableV2 ops is read-only via verifying the heuristic
|
||||
// method that assumes that all the users of them is Identity op, feeded
|
||||
// directly.
|
||||
auto read_only_vars_fn = [&variable_v2s_to_replace](
|
||||
VariableV2Op variable_v2_op) {
|
||||
if (variable_v2_op.getResult().use_empty()) {
|
||||
// Erase the op when there is no user.
|
||||
variable_v2_op.erase();
|
||||
return mlir::WalkResult::advance();
|
||||
}
|
||||
if (!all_of(variable_v2_op.getResult().getUsers(), [&variable_v2_op](
|
||||
Operation *user) {
|
||||
if (!isa<IdentityOp>(user)) {
|
||||
variable_v2_op.emitOpError()
|
||||
<< "expects all users to be 'tf.Identity', but got user "
|
||||
<< user->getName();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
})) {
|
||||
return mlir::WalkResult::interrupt();
|
||||
}
|
||||
variable_v2s_to_replace.push_back(variable_v2_op);
|
||||
return mlir::WalkResult::advance();
|
||||
};
|
||||
|
||||
WalkResult walk_res = func.walk(read_only_vars_fn);
|
||||
if (walk_res.wasInterrupted()) return signalPassFailure();
|
||||
|
||||
for (VariableV2Op variable_v2_op : variable_v2s_to_replace) {
|
||||
builder.setInsertionPoint(variable_v2_op);
|
||||
ShapedType shaped_type =
|
||||
variable_v2_op.getResult().getType().cast<ShapedType>();
|
||||
TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>();
|
||||
StringAttr device_attr = variable_v2_op.getAttrOfType<StringAttr>("device");
|
||||
if (!device_attr) device_attr = builder.getStringAttr("");
|
||||
StringRef variable_name = GetNodeNameFromClassAttr(variable_v2_op);
|
||||
if (variable_name.empty()) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
VarHandleOp var_handle_op = builder.create<VarHandleOp>(
|
||||
variable_v2_op.getLoc(),
|
||||
ArrayRef<Type>{RankedTensorType::get(
|
||||
{}, TF::ResourceType::get(ArrayRef<TensorType>{tensor_type},
|
||||
builder.getContext()))},
|
||||
ArrayRef<Value>{},
|
||||
ArrayRef<NamedAttribute>{
|
||||
builder.getNamedAttr("device", device_attr),
|
||||
builder.getNamedAttr("container", variable_v2_op.containerAttr()),
|
||||
builder.getNamedAttr("shared_name",
|
||||
builder.getStringAttr(variable_name))});
|
||||
for (Operation *user :
|
||||
make_early_inc_range(variable_v2_op.getResult().getUsers())) {
|
||||
builder.setInsertionPoint(user);
|
||||
ReadVariableOp read_variable_op = builder.create<ReadVariableOp>(
|
||||
user->getLoc(), ArrayRef<Type>{tensor_type},
|
||||
ArrayRef<Value>{var_handle_op}, ArrayRef<NamedAttribute>{});
|
||||
user->getResult(0).replaceAllUsesWith(read_variable_op.getResult());
|
||||
user->erase();
|
||||
}
|
||||
variable_v2_op.erase();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass() {
|
||||
return std::make_unique<
|
||||
ConvertReadonlyReferenceVariablesToResourceVariablesPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<
|
||||
ConvertReadonlyReferenceVariablesToResourceVariablesPass>
|
||||
pass("readonly-references-to-resources",
|
||||
"Convert readonly reference variables to resource variables.");
|
||||
|
||||
} // namespace TF
|
||||
|
||||
} // namespace mlir
|
@ -149,7 +149,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
|
||||
}
|
||||
auto walk_res = func_op.walk([&](Operation* op) {
|
||||
if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) {
|
||||
// Record VarHanldeOp's device attribute.
|
||||
// Record VarHandleOp's device attribute.
|
||||
auto device_attr =
|
||||
var_handle.getAttrOfType<mlir::StringAttr>(kDeviceAttr);
|
||||
if (!device_attr || device_attr.getValue().empty()) {
|
||||
|
@ -571,7 +571,7 @@ void AddLoadsStoresOutsideControlFlowOp(
|
||||
}
|
||||
|
||||
// Lifts loads/stores from while loop's body and cond functions.
|
||||
LogicalResult HanldeWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
|
||||
LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
|
||||
// Remove identity nodes to avoid aliasing.
|
||||
RemoveIdentity(&body.front());
|
||||
RemoveIdentity(&cond.front());
|
||||
@ -985,7 +985,7 @@ LogicalResult HoistForFunctionalControlFlow(
|
||||
lifted_partitioned_call_callees);
|
||||
HoistForFunctionalControlFlow(&cond.front(), module,
|
||||
lifted_partitioned_call_callees);
|
||||
if (failed(HanldeWhileLoop(while_op, body, cond))) return failure();
|
||||
if (failed(HandleWhileLoop(while_op, body, cond))) return failure();
|
||||
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
|
||||
auto then_branch =
|
||||
llvm::cast<FuncOp>(module.lookupSymbol(if_op.then_branch()));
|
||||
|
@ -429,7 +429,8 @@ LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
|
||||
// existing computed values.
|
||||
Attribute ComputeOutputComponent(const ValuePort& value_port,
|
||||
ValueQueryFn values) {
|
||||
LLVM_DEBUG(value_port.print(llvm::errs() << "\nComputing output for "));
|
||||
LLVM_DEBUG(value_port.print(llvm::dbgs() << "Computing output for ") << "\n");
|
||||
if (auto known = values(value_port)) return known;
|
||||
|
||||
auto op = value_port.producer.dyn_cast<Operation*>();
|
||||
if (!op) return nullptr;
|
||||
@ -454,6 +455,21 @@ Attribute ComputeOutputComponent(const ValuePort& value_port,
|
||||
ValuePort op_port(op->getOperand(port[1]));
|
||||
return values(op_port);
|
||||
}
|
||||
|
||||
if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
|
||||
if (port.size() == 1)
|
||||
return ComputeOutputComponent(
|
||||
ValuePort(graph.GetFetch().fetches()[port[0]]), values);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
|
||||
if (port.size() == 1)
|
||||
return ComputeOutputComponent(
|
||||
ValuePort(island.GetYield().fetches()[port[0]]), values);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -462,7 +478,8 @@ Attribute ComputeOutputComponent(const ValuePort& value_port,
|
||||
// TF Graph version, constant values computed, etc.)
|
||||
class ShapeInference {
|
||||
public:
|
||||
ShapeInference(int64_t graph_version, MLIRContext* context);
|
||||
ShapeInference(int64_t graph_version, MLIRContext* context,
|
||||
bool propagate_caller_callee_constants);
|
||||
|
||||
LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
|
||||
ValuePortInputs* inputs) {
|
||||
@ -475,14 +492,19 @@ class ShapeInference {
|
||||
}
|
||||
|
||||
Attribute ComputeOutputComponent(const ValuePort& value_port) {
|
||||
return ::mlir::TF::ComputeOutputComponent(
|
||||
if (auto known_attr = results_[value_port]) return known_attr;
|
||||
auto attr = ::mlir::TF::ComputeOutputComponent(
|
||||
value_port, [this](const ValuePort& port) { return results_[port]; });
|
||||
RecordValue(value_port, attr);
|
||||
return attr;
|
||||
}
|
||||
|
||||
// Returns ShapeHandle if the op result could be computed as shape.
|
||||
ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic);
|
||||
|
||||
void RecordValue(const ValuePort& value_port, Attribute value) {
|
||||
LLVM_DEBUG(value_port.print(llvm::dbgs() << "\trecording ")
|
||||
<< value << "\n");
|
||||
results_[value_port] = value;
|
||||
}
|
||||
|
||||
@ -520,19 +542,41 @@ class ShapeInference {
|
||||
LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op,
|
||||
int64_t max_iteration);
|
||||
|
||||
// Propagates any constant operand of call_op to the called function body's
|
||||
// corresponding argument if the callee has only one use.
|
||||
//
|
||||
// TODO(b/154065712): Move this to a more general inter-procedural constant
|
||||
// folding pass.
|
||||
void PropagateConstantToCallee(CallOpInterface call_op,
|
||||
SymbolRefAttr callee_sym, ModuleOp module);
|
||||
|
||||
// Propagates any constant return value of the callee function to the call
|
||||
// op's corresponding result.
|
||||
void PropagateConstantFromCallee(CallOpInterface call_op,
|
||||
SymbolRefAttr callee_sym, ModuleOp module);
|
||||
|
||||
// Tries to compute the result of folding the op. This doesn't actually
|
||||
// perform constant folding, it is just computes the equivalent constants.
|
||||
// Returns whether it was able to compute constant values.
|
||||
LogicalResult TryToFold(Operation* op);
|
||||
|
||||
private:
|
||||
// Mapping between ValuePort (which corresponds to an OpResult or smaller,
|
||||
// e.g., first element of OpResult produded) to an Attribute if the ValuePort
|
||||
// e.g., first element of OpResult produced) to an Attribute if the ValuePort
|
||||
// corresponds to a constant value.
|
||||
ValuePortResultMap results_;
|
||||
int64_t graph_version_;
|
||||
MLIRContext* context_;
|
||||
Dialect* tf_dialect_;
|
||||
|
||||
// TODO(b/154065712): Remove propagate_caller_callee_constants once using
|
||||
// SCCP pass instead.
|
||||
bool propagate_caller_callee_constants_;
|
||||
};
|
||||
|
||||
ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context)
|
||||
: graph_version_(graph_version) {
|
||||
context_ = context;
|
||||
ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context,
|
||||
bool propagate_caller_callee_constants)
|
||||
: graph_version_(graph_version),
|
||||
propagate_caller_callee_constants_(propagate_caller_callee_constants) {
|
||||
tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>();
|
||||
}
|
||||
|
||||
@ -581,7 +625,6 @@ ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
|
||||
auto ret = ComputeOutputComponent(front);
|
||||
if (!ret) continue;
|
||||
|
||||
RecordValue(front, ret);
|
||||
LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
|
||||
|
||||
// If worklist is empty, then this is the root query op.
|
||||
@ -602,6 +645,8 @@ ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
|
||||
}
|
||||
|
||||
bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
|
||||
LLVM_DEBUG(op->print(llvm::dbgs() << "InferShapeForSingleOperation for ");
|
||||
llvm::dbgs() << "\n");
|
||||
assert(tf_dialect_ == op->getDialect());
|
||||
// The shape function of these ops sometimes does not propagate subtypes
|
||||
// (handle shapes) for resource and variant types. We use a simple passthrough
|
||||
@ -686,10 +731,14 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
|
||||
size_t index = it.index();
|
||||
|
||||
// If the operand is constant, then convert it to Tensor.
|
||||
ElementsAttr attr;
|
||||
if (matchPattern(operand, m_Constant(&attr))) {
|
||||
ValuePort vp(operand);
|
||||
Attribute attr = ComputeOutputComponent(vp);
|
||||
if (!attr && matchPattern(operand, m_Constant(&attr)))
|
||||
RecordValue(vp, attr);
|
||||
if (attr) {
|
||||
tensorflow::Tensor* input_tensor = &tensors[index];
|
||||
auto status = tensorflow::ConvertToTensor(attr, input_tensor);
|
||||
auto status =
|
||||
tensorflow::ConvertToTensor(attr.cast<ElementsAttr>(), input_tensor);
|
||||
if (status.ok()) {
|
||||
input_tensors[index] = input_tensor;
|
||||
} else {
|
||||
@ -728,10 +777,12 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
|
||||
!input_tensors[input];
|
||||
});
|
||||
if (requires_inputs) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "\trequired input\n");
|
||||
std::vector<ShapeHandle> input_tensors_as_shapes;
|
||||
for (int input : llvm::seq<int>(0, c.num_inputs())) {
|
||||
if (c.requested_input_tensor_as_partial_shape(input) &&
|
||||
!input_tensors[input]) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Requesting " << input << " as shape\n");
|
||||
auto op_result = op->getOperand(input).dyn_cast<OpResult>();
|
||||
if (!op_result) continue;
|
||||
// Resize on first valid shape computed.
|
||||
@ -865,45 +916,62 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
|
||||
return success(all_succeeded);
|
||||
}
|
||||
|
||||
// If the callee has only one use, propagates any constant operand of call_op to
|
||||
// the called function body's corresponding argument.
|
||||
//
|
||||
// TODO(b/154065712): Move this to a more general inter-procedural constant
|
||||
// folding pass.
|
||||
void PropagateConstantToCallee(CallOpInterface call_op,
|
||||
SymbolRefAttr callee_sym, ModuleOp module) {
|
||||
void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op,
|
||||
SymbolRefAttr callee_sym,
|
||||
ModuleOp module) {
|
||||
auto func = module.lookupSymbol<FuncOp>(callee_sym.getRootReference());
|
||||
auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
|
||||
int num_uses = std::distance(func_uses->begin(), func_uses->end());
|
||||
if (num_uses != 1) return;
|
||||
|
||||
OpBuilder builder(&func.front().front());
|
||||
Operation* op = call_op.getOperation();
|
||||
if (num_uses == 1) {
|
||||
// If this is the only caller, and an operand is a constant, propagate
|
||||
// the constant inside the function.
|
||||
for (auto arg : func.getArguments()) {
|
||||
auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp();
|
||||
if (isa_and_nonnull<TF::ConstOp>(operand)) {
|
||||
arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0));
|
||||
// If this is the only caller, and an operand is a constant, propagate
|
||||
// the constant value inside the function.
|
||||
for (auto arg : func.getArguments()) {
|
||||
auto operand = op->getOperand(arg.getArgNumber());
|
||||
if (propagate_caller_callee_constants_) {
|
||||
if (isa_and_nonnull<TF::ConstOp>(operand.getDefiningOp())) {
|
||||
arg.replaceAllUsesWith(
|
||||
builder.clone(*operand.getDefiningOp())->getResult(0));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto known_constant = ComputeOutputComponent(ValuePort(operand));
|
||||
if (!known_constant) continue;
|
||||
LLVM_DEBUG(call_op.print(llvm::dbgs() << "Propagate to calee: ");
|
||||
known_constant.print(llvm::dbgs() << " constant ");
|
||||
llvm::dbgs() << "\n");
|
||||
RecordValue(ValuePort(arg), known_constant);
|
||||
}
|
||||
}
|
||||
|
||||
// Propagates any constant return value of the callee function to the call op's
|
||||
// corresponding result.
|
||||
void PropagateConstantFromCallee(CallOpInterface call_op,
|
||||
SymbolRefAttr callee_sym, ModuleOp module) {
|
||||
void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op,
|
||||
SymbolRefAttr callee_sym,
|
||||
ModuleOp module) {
|
||||
auto func = module.lookupSymbol<FuncOp>(callee_sym.getRootReference());
|
||||
// If the return value is a constant, replace the call result with a constant.
|
||||
// If the return value is a constant, use the constant as the value of
|
||||
// the call return.
|
||||
Operation* op = call_op.getOperation();
|
||||
OpBuilder builder(op);
|
||||
builder.setInsertionPointAfter(op);
|
||||
for (auto retval :
|
||||
llvm::enumerate(func.front().getTerminator()->getOperands())) {
|
||||
auto retval_op = retval.value().getDefiningOp();
|
||||
if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
|
||||
op->getResult(retval.index())
|
||||
.replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
|
||||
if (propagate_caller_callee_constants_) {
|
||||
auto retval_op = retval.value().getDefiningOp();
|
||||
if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
|
||||
op->getResult(retval.index())
|
||||
.replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
ValuePort vp(retval.value());
|
||||
if (auto known_constant = ComputeOutputComponent(vp)) {
|
||||
LLVM_DEBUG(known_constant.print(llvm::dbgs() << "Propagate constant ");
|
||||
call_op.print(llvm::dbgs() << "from "); llvm::dbgs() << "\n");
|
||||
RecordValue(ValuePort(op->getResult(retval.index())), known_constant);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -938,10 +1006,71 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions(
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ShapeInference::TryToFold(Operation* op) {
|
||||
LLVM_DEBUG(op->print(llvm::dbgs() << "TryToFold "); llvm::dbgs() << "\n");
|
||||
// If any output result is known, then the op probably has been computed
|
||||
// before.
|
||||
if (op->getNumResults() > 0 && results_[ValuePort(op->getResult(0))])
|
||||
return success();
|
||||
|
||||
SmallVector<Attribute, 8> constant_operands(op->getNumOperands());
|
||||
SmallVector<OpFoldResult, 8> fold_results;
|
||||
|
||||
// Check to see if any operands to the operation is constant and whether
|
||||
// the operation knows how to constant fold itself.
|
||||
bool some_unknown = false;
|
||||
for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
|
||||
if (!(constant_operands[i] =
|
||||
ComputeOutputComponent(ValuePort(op->getOperand(i)))))
|
||||
some_unknown = true;
|
||||
}
|
||||
|
||||
// Attempt to constant fold the operation.
|
||||
auto* abstract_op = op->getAbstractOperation();
|
||||
LogicalResult folded = failure();
|
||||
if (abstract_op) {
|
||||
folded = abstract_op->foldHook(op, constant_operands, fold_results);
|
||||
}
|
||||
// Attempt dialect fallback if op's fold hook failed.
|
||||
if (failed(folded)) {
|
||||
Dialect* dialect = op->getDialect();
|
||||
if (!dialect) return failure();
|
||||
// Only attempt TF dialect fallback if there are no unknown operands.
|
||||
if (some_unknown && dialect == tf_dialect_) return failure();
|
||||
SmallVector<Attribute, 8> constants;
|
||||
if (failed(dialect->constantFoldHook(op, constant_operands, constants)))
|
||||
return failure();
|
||||
fold_results.assign(constants.begin(), constants.end());
|
||||
}
|
||||
|
||||
for (auto result : zip(op->getResults(), fold_results)) {
|
||||
auto fold_result = std::get<1>(result);
|
||||
Attribute attr = nullptr;
|
||||
if ((attr = fold_result.dyn_cast<Attribute>())) {
|
||||
RecordValue(ValuePort(std::get<0>(result)), attr);
|
||||
} else {
|
||||
auto value = fold_result.get<Value>();
|
||||
if ((attr = ComputeOutputComponent(ValuePort(value))))
|
||||
RecordValue(ValuePort(std::get<0>(result)), attr);
|
||||
}
|
||||
|
||||
if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) {
|
||||
if (std::get<0>(result).getType() == eattr.getType()) continue;
|
||||
|
||||
// Inserts a cast back to the original type if any user is not in the
|
||||
// TF dialect.
|
||||
Type old_type = std::get<0>(result).getType();
|
||||
std::get<0>(result).setType(eattr.getType());
|
||||
AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result), tf_dialect_,
|
||||
old_type);
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
|
||||
int64_t max_iteration) {
|
||||
// An operation folder that is used to attempt folding before inference._
|
||||
OperationFolder folder(context_);
|
||||
bool changed = true;
|
||||
|
||||
// TODO(aminim): we could have a more efficient traversal by guiding the
|
||||
@ -955,9 +1084,7 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
|
||||
region->walk([&](Operation* op) {
|
||||
if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
|
||||
changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_);
|
||||
// TODO(jpienaar): Debug why we can't just return here. We end up with
|
||||
// additional constant due to the propagation of constant into attached
|
||||
// function if we return already.
|
||||
return;
|
||||
}
|
||||
|
||||
if (op->getDialect() != tf_dialect_) {
|
||||
@ -965,8 +1092,9 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
|
||||
return;
|
||||
}
|
||||
|
||||
// Before attempting inference, just try to fold the operation.
|
||||
if (succeeded(folder.tryToFold(op))) return;
|
||||
// Before attempting inference, just try to compute the folded
|
||||
// value/shape.
|
||||
if (succeeded(TryToFold(op))) return;
|
||||
|
||||
// Best-effort shape inference in attached functions. Do not return
|
||||
// failure even if it doesn't get to fixed point.
|
||||
@ -989,8 +1117,10 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region,
|
||||
|
||||
LogicalResult InferShapeForFunction(FuncOp func,
|
||||
ArrayRef<ArrayRef<int64_t>> arg_shapes,
|
||||
int64_t graph_version) {
|
||||
ShapeInference context(graph_version, func.getContext());
|
||||
int64_t graph_version,
|
||||
bool propagate_caller_callee_constants) {
|
||||
ShapeInference context(graph_version, func.getContext(),
|
||||
propagate_caller_callee_constants);
|
||||
if (arg_shapes.empty()) {
|
||||
if (failed(context.InferShapeUntilFixPoint(&func.getBody())))
|
||||
return failure();
|
||||
@ -1014,7 +1144,7 @@ LogicalResult InferShapeForFunction(FuncOp func,
|
||||
ArrayRef<int64_t> shape = arg_shapes[i];
|
||||
Type element_type;
|
||||
if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
|
||||
if (!input_ty || input_ty.getShape().size() != shape.size()) {
|
||||
if (input_ty.getRank() != shape.size()) {
|
||||
return failure();
|
||||
}
|
||||
element_type = input_ty.getElementType();
|
||||
|
@ -30,9 +30,11 @@ namespace TF {
|
||||
// Given a list of refined shapes matching the function arguments of func, runs
|
||||
// shape inference over the function to propagate this updated information.
|
||||
// If arg_shapes are empty, then argument shapes will be left unchanged.
|
||||
LogicalResult InferShapeForFunction(FuncOp func,
|
||||
ArrayRef<ArrayRef<int64_t>> arg_shapes,
|
||||
int64_t graph_version);
|
||||
// TODO(b/154065712): Remove propagate_caller_callee_constants once using
|
||||
// SCCP pass instead.
|
||||
LogicalResult InferShapeForFunction(
|
||||
FuncOp func, ArrayRef<ArrayRef<int64_t>> arg_shapes, int64_t graph_version,
|
||||
bool propagate_caller_callee_constants = true);
|
||||
|
||||
} // namespace TF
|
||||
|
||||
|
@ -47,8 +47,15 @@ namespace {
|
||||
|
||||
// This transformation pass propagate shapes on the TensorFlow graph.
|
||||
// It is a ModulePass in order to be able to change function types.
|
||||
struct ShapeInference
|
||||
class ShapeInference
|
||||
: public PassWrapper<ShapeInference, OperationPass<ModuleOp>> {
|
||||
public:
|
||||
ShapeInference() = default;
|
||||
ShapeInference(const ShapeInference& that) {
|
||||
propagate_caller_callee_constants_ =
|
||||
that.propagate_caller_callee_constants_;
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
auto producer_or = tensorflow::GetTfGraphProducerVersion(module);
|
||||
@ -58,10 +65,17 @@ struct ShapeInference
|
||||
}
|
||||
int64_t producer = producer_or.ValueOrDie();
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
if (failed(InferShapeForFunction(func, /*arg_shapes=*/{}, producer)))
|
||||
if (failed(InferShapeForFunction(func, /*arg_shapes=*/{}, producer,
|
||||
propagate_caller_callee_constants_)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Option<bool> propagate_caller_callee_constants_{
|
||||
*this, "propagate-caller-callee-constants",
|
||||
llvm::cl::desc("Propagate constants between callers and callees"),
|
||||
llvm::cl::init(true)};
|
||||
};
|
||||
|
||||
PassRegistration<ShapeInference> pass(
|
||||
|
@ -14,23 +14,30 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFTPU {
|
||||
@ -46,146 +53,184 @@ bool HasOutsideCompilationAttribute(Operation* op) {
|
||||
return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
|
||||
}
|
||||
|
||||
// Returns whether all operands of `op` are from values inside the
|
||||
// `input_value_set`.
|
||||
bool OpContainsOperandsFromSet(Operation* op,
|
||||
const llvm::SetVector<Value>& input_value_set) {
|
||||
for (auto operand : op->getOperands())
|
||||
if (input_value_set.count(operand) == 0) return false;
|
||||
Operation* GetOpOfValue(Value value) {
|
||||
if (auto block_arg = value.dyn_cast<BlockArgument>())
|
||||
return block_arg.getOwner()->getParentOp();
|
||||
|
||||
return true;
|
||||
return value.getDefiningOp();
|
||||
}
|
||||
|
||||
void RecordOutsideCompiledOpsAndUsages(
|
||||
Operation* op, llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops,
|
||||
llvm::SetVector<Value>* outside_compiled_op_usages) {
|
||||
if (HasOutsideCompilationAttribute(op) &&
|
||||
OpContainsOperandsFromSet(op, *outside_compiled_op_usages)) {
|
||||
outside_compiled_ops->insert(op);
|
||||
outside_compiled_op_usages->insert(op->getResults().begin(),
|
||||
op->getResults().end());
|
||||
// Returns a set of ops that are outside compiled and can be extracted to before
|
||||
// the TPU computation. These ops are either connected to the inputs of the TPU
|
||||
// computation or other ops that can be extracted, and have no dependencies with
|
||||
// other ops in the TPU computation that cannot be extracted.
|
||||
llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
|
||||
tf_device::ClusterOp cluster) {
|
||||
llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
|
||||
|
||||
auto cluster_ops = cluster.GetBody().without_terminator();
|
||||
for (Operation& cluster_op : cluster_ops) {
|
||||
if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
|
||||
// An outside compiled op can be extracted if its operands are not from
|
||||
// other ops in the cluster that cannot be extracted.
|
||||
auto result = cluster_op.walk([&](Operation* op) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
Operation* operand_op = GetOpOfValue(operand);
|
||||
if (operand_op->isProperAncestor(cluster) ||
|
||||
cluster_op.isAncestor(operand_op) ||
|
||||
head_outside_compiled_ops.count(operand_op))
|
||||
continue;
|
||||
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (!result.wasInterrupted()) head_outside_compiled_ops.insert(&cluster_op);
|
||||
}
|
||||
|
||||
return head_outside_compiled_ops.takeVector();
|
||||
}
|
||||
|
||||
// Traverses the MLIR graph and returns a set of ops that
|
||||
// are connected to inputs of TPU computation and outside compiled.
|
||||
void ExtractOutsideCompiledOpsConnectedToHead(
|
||||
Value input_value, llvm::SetVector<Value>* values_used_in_host_cluster,
|
||||
llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops) {
|
||||
llvm::SmallSetVector<Operation*, 4> parent_outside_compiled_ops_at_head;
|
||||
for (auto& usage : input_value.getUses()) {
|
||||
auto head_operation = usage.getOwner();
|
||||
RecordOutsideCompiledOpsAndUsages(head_operation,
|
||||
&parent_outside_compiled_ops_at_head,
|
||||
values_used_in_host_cluster);
|
||||
// Parses TPU compilation and execution devices from a TPU cluster and returns
|
||||
// the host device for the head and tail computations. If the TPU computation is
|
||||
// replicated, kTPUReplicatedHost is returned instead.
|
||||
LogicalResult GetHostDeviceForHeadTailComputation(
|
||||
mlir::TF::RuntimeDevices devices, tf_device::ClusterOp cluster,
|
||||
std::string* host_device) {
|
||||
auto replicate = cluster.getParentOfType<tf_device::ReplicateOp>();
|
||||
if (replicate) {
|
||||
*host_device = tensorflow::kTPUReplicatedHost;
|
||||
return success();
|
||||
}
|
||||
|
||||
// Traverse the graph and find all outside compiled ops connected from
|
||||
// the `input_value`.
|
||||
while (!parent_outside_compiled_ops_at_head.empty()) {
|
||||
llvm::SmallSetVector<Operation*, 4> connected_outside_compiled_ops;
|
||||
for (auto head_outside_compiled_op : parent_outside_compiled_ops_at_head) {
|
||||
auto op_results = head_outside_compiled_op->getOpResults();
|
||||
for (auto op_result : op_results) {
|
||||
for (auto& use : op_result.getUses()) {
|
||||
auto connected_op = use.getOwner();
|
||||
RecordOutsideCompiledOpsAndUsages(connected_op,
|
||||
&connected_outside_compiled_ops,
|
||||
values_used_in_host_cluster);
|
||||
auto num_cores_per_replica_attr =
|
||||
cluster.getAttrOfType<IntegerAttr>(tensorflow::kNumCoresPerReplicaAttr);
|
||||
if (!num_cores_per_replica_attr)
|
||||
return cluster.emitOpError(
|
||||
"cluster op missing `num_cores_per_replica` attribute");
|
||||
|
||||
if (num_cores_per_replica_attr.getInt() != 1)
|
||||
return cluster.emitOpError(
|
||||
"outside compilation is not supported with model parallelism.");
|
||||
|
||||
auto topology_attr =
|
||||
cluster.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
|
||||
if (!topology_attr)
|
||||
return cluster.emitOpError("cluster op missing `topology` attribute");
|
||||
|
||||
auto device_assignment_attr =
|
||||
cluster.getAttrOfType<mlir::ArrayAttr>(tensorflow::kDeviceAssignmentAttr);
|
||||
if (!device_assignment_attr)
|
||||
return cluster.emitOpError(llvm::formatv("requires attribute '{0}'",
|
||||
tensorflow::kDeviceAssignmentAttr)
|
||||
.str());
|
||||
|
||||
auto status_or_device_coodinates =
|
||||
tensorflow::GetDeviceCoordinates(device_assignment_attr);
|
||||
|
||||
if (!status_or_device_coodinates.ok())
|
||||
return cluster.emitError()
|
||||
<< "error in fetching tpu device coordinates: "
|
||||
<< status_or_device_coodinates.status().error_message();
|
||||
|
||||
// Determine compilation and execution devices.
|
||||
auto status_or_tpu_device_assignment =
|
||||
tensorflow::GetTPUCompilationAndExecutionDevices(
|
||||
devices.device_names(), /*num_replicas=*/1,
|
||||
/*num_cores_per_replica=*/1, topology_attr.getValue(),
|
||||
status_or_device_coodinates.ConsumeValueOrDie());
|
||||
if (!status_or_tpu_device_assignment.ok())
|
||||
return cluster.emitError()
|
||||
<< "error in fetching TPU compilation/execution devices: "
|
||||
<< status_or_tpu_device_assignment.status().error_message();
|
||||
auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
|
||||
|
||||
*host_device = tpu_device_assignment.tpu_devices[0][0].host;
|
||||
return success();
|
||||
}
|
||||
|
||||
// Moves head outside compiled ops into its own `tf_device.LaunchOp`
|
||||
// computation.
|
||||
tf_device::LaunchOp CreateHeadComputation(
|
||||
OpBuilder* builder, tf_device::ClusterOp cluster,
|
||||
llvm::ArrayRef<Operation*> head_outside_compiled_ops,
|
||||
llvm::StringRef host_device) {
|
||||
Block* launch_block = new Block;
|
||||
for (Operation* head_outside_compiled_op : head_outside_compiled_ops)
|
||||
head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
|
||||
|
||||
// Find results of ops in head computation that needs to returned.
|
||||
llvm::SmallVector<Value, 4> launch_results;
|
||||
llvm::SmallVector<Type, 4> launch_result_types;
|
||||
for (Operation& head_outside_compiled_op : *launch_block) {
|
||||
for (Value result : head_outside_compiled_op.getResults()) {
|
||||
bool has_uses_in_cluster = false;
|
||||
for (Operation* user : result.getUsers()) {
|
||||
if (user->getParentRegion() &&
|
||||
cluster.body().isAncestor(user->getParentRegion())) {
|
||||
has_uses_in_cluster = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
outside_compiled_ops->insert(parent_outside_compiled_ops_at_head.begin(),
|
||||
parent_outside_compiled_ops_at_head.end());
|
||||
std::swap(parent_outside_compiled_ops_at_head,
|
||||
connected_outside_compiled_ops);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(hongjunchoi): Also handle ops without inputs that are outside
|
||||
// compiled.
|
||||
//
|
||||
// Returns set of ops that are outside compiled and are directly connected
|
||||
// to inputs to the TPU computation.
|
||||
llvm::SmallSetVector<Operation*, 4> IdentifyOutsideCompiledOpsAtHead(
|
||||
tf_device::ClusterOp tpu_cluster) {
|
||||
llvm::SmallSetVector<Operation*, 4> outside_compiled_at_head_ops;
|
||||
llvm::SetVector<Value> values_used_in_cluster;
|
||||
auto& cluster_region = tpu_cluster.body();
|
||||
getUsedValuesDefinedAbove(cluster_region, cluster_region,
|
||||
values_used_in_cluster);
|
||||
|
||||
auto input_value_list = llvm::to_vector<8>(values_used_in_cluster);
|
||||
for (auto input_value : input_value_list)
|
||||
ExtractOutsideCompiledOpsConnectedToHead(
|
||||
input_value, &values_used_in_cluster, &outside_compiled_at_head_ops);
|
||||
return outside_compiled_at_head_ops;
|
||||
}
|
||||
|
||||
// Returns output values of extracted outside compiled cluster at head that
|
||||
// are used by the TPU computation.
|
||||
llvm::SmallVector<Value, 8> GetHeadExtractedClusterOutputs(
|
||||
const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
|
||||
llvm::SmallVector<Value, 8> outputs;
|
||||
outputs.reserve(head_outside_compiled_ops.size());
|
||||
|
||||
for (auto op : head_outside_compiled_ops) {
|
||||
for (Operation* user : op->getUsers()) {
|
||||
if (!head_outside_compiled_ops.count(user)) {
|
||||
outputs.append(op->result_begin(), op->result_end());
|
||||
break;
|
||||
if (has_uses_in_cluster) {
|
||||
launch_results.push_back(result);
|
||||
launch_result_types.push_back(result.getType());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return outputs;
|
||||
builder->setInsertionPoint(cluster);
|
||||
auto launch = builder->create<tf_device::LaunchOp>(
|
||||
cluster.getLoc(), builder->getStringAttr(host_device),
|
||||
launch_result_types);
|
||||
launch.body().push_back(launch_block);
|
||||
|
||||
builder->setInsertionPointToEnd(&launch.GetBody());
|
||||
builder->create<tf_device::ReturnOp>(cluster.getLoc(), launch_results);
|
||||
|
||||
for (auto result : llvm::zip(launch_results, launch.getResults()))
|
||||
replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result),
|
||||
cluster.body());
|
||||
|
||||
return launch;
|
||||
}
|
||||
|
||||
// Creates new tf_device.launch op with outside compiled ops extracted
|
||||
// from the head of TPU computation.
|
||||
llvm::Optional<tf_device::LaunchOp> IsolateHeadExtractedOpsToLaunchOp(
|
||||
OpBuilder* builder, tf_device::ClusterOp cluster,
|
||||
const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
|
||||
if (head_outside_compiled_ops.empty())
|
||||
return llvm::Optional<tf_device::LaunchOp>();
|
||||
|
||||
// Create tf_device.launch op to separate all extracted outside compiled ops
|
||||
// before the tf_device.cluster.
|
||||
auto output_values =
|
||||
GetHeadExtractedClusterOutputs(head_outside_compiled_ops);
|
||||
|
||||
llvm::SmallVector<Type, 8> output_return_types;
|
||||
output_return_types.reserve(output_values.size());
|
||||
for (auto output : output_values)
|
||||
output_return_types.emplace_back(output.getType());
|
||||
|
||||
builder->setInsertionPoint(cluster);
|
||||
auto host_launch_op = builder->create<tf_device::LaunchOp>(
|
||||
cluster.getLoc(), builder->getStringAttr(""), output_return_types);
|
||||
|
||||
// Replace all usages of outside compiled ops that are used in TPU
|
||||
// computation with the results of the above created launch op.
|
||||
for (auto output_and_index : llvm::enumerate(output_values)) {
|
||||
auto output_index = output_and_index.index();
|
||||
auto output = output_and_index.value();
|
||||
for (auto& use : output.getUses()) {
|
||||
if (!head_outside_compiled_ops.count(use.getOwner()))
|
||||
use.set(host_launch_op.getResult(output_index));
|
||||
// Removes aliased outputs in cluster from head computation after head
|
||||
// computation has been extracted.
|
||||
void RemoveHeadComputationAliasedOutputs(OpBuilder* builder,
|
||||
tf_device::LaunchOp head_computation,
|
||||
tf_device::ClusterOp cluster) {
|
||||
llvm::SmallVector<Value, 4> used_old_cluster_results;
|
||||
llvm::SmallVector<Value, 4> new_cluster_results;
|
||||
llvm::SmallVector<Type, 4> new_cluster_result_types;
|
||||
Operation* cluster_terminator = cluster.GetBody().getTerminator();
|
||||
for (auto result :
|
||||
llvm::zip(cluster_terminator->getOperands(), cluster.getResults())) {
|
||||
Value cluster_terminator_operand = std::get<0>(result);
|
||||
if (cluster_terminator_operand.getDefiningOp() == head_computation) {
|
||||
std::get<1>(result).replaceAllUsesWith(cluster_terminator_operand);
|
||||
} else {
|
||||
new_cluster_results.push_back(cluster_terminator_operand);
|
||||
new_cluster_result_types.push_back(cluster_terminator_operand.getType());
|
||||
used_old_cluster_results.push_back(std::get<1>(result));
|
||||
}
|
||||
}
|
||||
|
||||
// Create terminator op for the newly created launch op.
|
||||
host_launch_op.body().push_back(new Block());
|
||||
builder->setInsertionPointToEnd(&host_launch_op.GetBody());
|
||||
auto terminator = builder->create<tf_device::ReturnOp>(
|
||||
host_launch_op.getLoc(), output_values);
|
||||
if (new_cluster_results.size() == cluster.getNumResults()) return;
|
||||
|
||||
// Move all outside compile ops from cluster op to launch op.
|
||||
for (auto outside_compiled_op : head_outside_compiled_ops)
|
||||
outside_compiled_op->moveBefore(terminator);
|
||||
builder->setInsertionPoint(cluster);
|
||||
auto new_cluster = builder->create<tf_device::ClusterOp>(
|
||||
cluster.getLoc(), new_cluster_result_types,
|
||||
/*operands=*/llvm::ArrayRef<Value>{}, cluster.getAttrs());
|
||||
new_cluster.body().takeBody(cluster.body());
|
||||
new_cluster.GetBody().getTerminator()->setOperands(new_cluster_results);
|
||||
|
||||
return host_launch_op;
|
||||
for (auto result :
|
||||
llvm::zip(used_old_cluster_results, new_cluster.getResults()))
|
||||
std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
|
||||
|
||||
cluster.erase();
|
||||
}
|
||||
|
||||
struct TPUExtractHeadTailOutsideCompilation
|
||||
@ -202,17 +247,25 @@ void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
|
||||
return signalPassFailure();
|
||||
|
||||
OpBuilder builder(&getContext());
|
||||
module.walk([&](tf_device::ClusterOp cluster) {
|
||||
auto head_outside_compiled_ops = IdentifyOutsideCompiledOpsAtHead(cluster);
|
||||
IsolateHeadExtractedOpsToLaunchOp(&builder, cluster,
|
||||
head_outside_compiled_ops);
|
||||
llvm::SmallVector<tf_device::ClusterOp, 4> clusters;
|
||||
module.walk(
|
||||
[&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });
|
||||
|
||||
// TODO(b/156030523): Update device attribute of newly created host launch
|
||||
// op as well as enclosing Replicate op (if TPU computation is replicated)
|
||||
// with host device names.
|
||||
for (tf_device::ClusterOp cluster : clusters) {
|
||||
llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
|
||||
FindOutsideCompiledOpsAtHead(cluster);
|
||||
if (head_outside_compiled_ops.empty()) continue;
|
||||
std::string host_device;
|
||||
if (failed(GetHostDeviceForHeadTailComputation(devices, cluster,
|
||||
&host_device)))
|
||||
return signalPassFailure();
|
||||
|
||||
// TODO(b/155115766): Implement tail outside compiled op extraction.
|
||||
});
|
||||
tf_device::LaunchOp head_computation = CreateHeadComputation(
|
||||
&builder, cluster, head_outside_compiled_ops, host_device);
|
||||
RemoveHeadComputationAliasedOutputs(&builder, head_computation, cluster);
|
||||
|
||||
// TODO(b/157160906): Implement tail outside compiled op extraction.
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user