Merge branch 'master' into 12829-extract_glimpse

Mihai Maruseac 2020-05-19 15:35:19 +00:00 committed by GitHub
commit 158d128323
956 changed files with 40258 additions and 15186 deletions


@ -143,6 +143,11 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_mkl_dnn_v1_only=true
build:mkl -c opt
# config to build OneDNN backend with a user specified threadpool.
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
build:mkl_threadpool -c opt
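(Usage sketch, not part of the diff: a build opts into this backend by selecting the config at build time, e.g. `bazel build --config=mkl_threadpool //tensorflow/tools/pip_package:build_pip_package`; the target shown is illustrative.)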
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
@ -235,10 +240,15 @@ build:c++17 --cxxopt=-std=c++1z
build:c++17 --cxxopt=-stdlib=libc++
build:c++1z --config=c++17
# Enable using platform specific build settings
# Enable using platform specific build settings, except when cross-compiling for
# mobile platforms.
build --enable_platform_specific_config
build:android --noenable_platform_specific_config
build:ios --noenable_platform_specific_config
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
build:android --copt=-w
build:ios --copt=-w
build:linux --copt=-w
build:macos --copt=-w
build:windows --copt=/w
@ -258,6 +268,10 @@ build:macos --define=INCLUDEDIR=$(PREFIX)/include
# TF_SYSTEM_LIBS do not work on windows.
# By default, build TF in C++ 14 mode.
build:android --cxxopt=-std=c++14
build:android --host_cxxopt=-std=c++14
build:ios --cxxopt=-std=c++14
build:ios --host_cxxopt=-std=c++14
build:linux --cxxopt=-std=c++14
build:linux --host_cxxopt=-std=c++14
build:macos --cxxopt=-std=c++14

.github/bot_config.yml (new file, 29 lines)

@ -0,0 +1,29 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#
# THIS IS A GENERATED DOCKERFILE.
#
# This file was assembled from multiple pieces, whose use is documented
# throughout. Please refer to the TensorFlow dockerfiles documentation
# for more information.
# A list of assignees
assignees:
- amahendrakar
- ravikyram
- Saduf2019
# A list of assignees for compiler issues
compiler_assignees:
- joker-eph


@ -103,17 +103,17 @@ open-source software development:
### Official Builds
Build Type | Status | Artifacts
------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv6l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv7l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
Build Type | Status | Artifacts
------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
### Community Supported Builds


@ -3,6 +3,31 @@
## Breaking Changes
* `tf.image.extract_glimpse` has been updated to correctly process the case where `centered=False` and `normalized=False`. This is a breaking change as the output is different from (incorrect) previous versions. Note this breaking change only impacts `tf.image.extract_glimpse` and `tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of `tf.compat.v1.image.extract_glimpse` does not change. The behavior of the existing C++ kernel `ExtractGlimpse` does not change either, so saved models will not be impacted.
# Release 2.1.1
## Bug Fixes and Other Changes
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
* Fixes a versioning bug which causes Keras layers from TF 1.x to be used instead of those from TF 2.x
# Release 2.0.2
## Bug Fixes and Other Changes
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
# Release 1.15.3
## Bug Fixes and Other Changes
* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645)
* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601)
* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960)
* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770)
# Release 2.2.0
TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update).


@ -64,7 +64,7 @@ your model, and we recommend you run the TensorFlow process in a sandbox.
It is possible to write models that are secure in a sense that they can safely
process untrusted inputs assuming there are no bugs. There are two main reasons
to not rely on this: first, it is easy to write models which must not be exposed
to not rely on this: First, it is easy to write models which must not be exposed
to untrusted inputs, and second, there are bugs in any software system of
sufficient complexity. Letting users control inputs could allow them to trigger
bugs either in TensorFlow or in dependent libraries.
@ -149,7 +149,7 @@ attack (or worse). Because TensorFlow behaves correctly, this is not a
vulnerability in TensorFlow (although it would be a vulnerability of this
hypothetical system).
As a general rule, it is incorrect behavior for Tensorflow to access memory it
As a general rule, it is incorrect behavior for TensorFlow to access memory it
does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to
such behaviors constitute a vulnerability.


@ -144,7 +144,7 @@ def write_to_bazelrc(line):
def write_action_env_to_bazelrc(var_name, var):
write_to_bazelrc('build --action_env %s="%s"' % (var_name, str(var)))
write_to_bazelrc('build --action_env {}="{}"'.format(var_name, str(var)))
def run_shell(cmd, allow_non_zero=False, stderr=None):
@ -205,7 +205,7 @@ def setup_python(environ_cp):
# Get PYTHON_BIN_PATH, default is the current running python.
default_python_bin_path = sys.executable
ask_python_bin_path = ('Please specify the location of python. [Default is '
'%s]: ') % default_python_bin_path
'{}]: ').format(default_python_bin_path)
while True:
python_bin_path = get_from_env_or_user_or_default(environ_cp,
'PYTHON_BIN_PATH',
@ -215,9 +215,10 @@ def setup_python(environ_cp):
if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK):
break
elif not os.path.exists(python_bin_path):
print('Invalid python path: %s cannot be found.' % python_bin_path)
print('Invalid python path: {} cannot be found.'.format(python_bin_path))
else:
print('%s is not executable. Is it the python binary?' % python_bin_path)
print('{} is not executable. Is it the python binary?'.format(
python_bin_path))
environ_cp['PYTHON_BIN_PATH'] = ''
# Convert python path to Windows style before checking lib and version
@ -236,7 +237,7 @@ def setup_python(environ_cp):
default_python_lib_path = python_lib_paths[0]
python_lib_path = get_input(
'Please input the desired Python library path to use. '
'Default is [%s]\n' % python_lib_paths[0])
'Default is [{}]\n'.format(python_lib_paths[0]))
if not python_lib_path:
python_lib_path = default_python_lib_path
environ_cp['PYTHON_LIB_PATH'] = python_lib_path
@ -252,7 +253,7 @@ def setup_python(environ_cp):
# Set-up env variables used by python_configure.bzl
write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path)
write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path)
write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
write_to_bazelrc('build --python_path=\"{}"'.format(python_bin_path))
environ_cp['PYTHON_BIN_PATH'] = python_bin_path
# If the chosen python_lib_path is from a path specified in the PYTHONPATH
@ -266,7 +267,7 @@ def setup_python(environ_cp):
with open(
os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'),
'w') as f:
f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path)
f.write('export PYTHON_BIN_PATH="{}"'.format(python_bin_path))
def reset_tf_configure_bazelrc():
@ -320,11 +321,12 @@ def get_var(environ_cp,
Raise the error to avoid infinitely looping.
"""
if not question:
question = 'Do you wish to build TensorFlow with %s support?' % query_item
question = 'Do you wish to build TensorFlow with {} support?'.format(
query_item)
if not yes_reply:
yes_reply = '%s support will be enabled for TensorFlow.' % query_item
yes_reply = '{} support will be enabled for TensorFlow.'.format(query_item)
if not no_reply:
no_reply = 'No %s' % yes_reply
no_reply = 'No {}'.format(yes_reply)
yes_reply += '\n'
no_reply += '\n'
@ -368,7 +370,7 @@ def get_var(environ_cp,
print(no_reply)
var = False
else:
print('Invalid selection: %s' % user_input_origin)
print('Invalid selection: {}'.format(user_input_origin))
return var
@ -1385,7 +1387,6 @@ def main():
# Windows.
environ_cp['TF_DOWNLOAD_CLANG'] = '0'
environ_cp['TF_NEED_MPI'] = '0'
environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0'
if is_macos():
environ_cp['TF_NEED_TENSORRT'] = '0'


@ -530,6 +530,13 @@ package_group(name = "ndarray_tensor_allow_list")
# TODO(b/154762408) Remove this package group once it's no longer needed.
package_group(name = "composite_tensor_whitelist")
# Packages that use private types symbols, until they are exported.
# TODO(b/154650521) Remove.
package_group(
name = "types_whitelist",
packages = ["//learning/deepmind/tensorflow/replicator/..."],
)
filegroup(
name = "intel_binary_blob",
data = if_mkl_ml(


@ -85,7 +85,7 @@ tf_cuda_library(
],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//tensorflow:chromiumos": [
":tf_attrtype",
@ -182,7 +182,7 @@ tf_cuda_library(
":tf_status_internal",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":tf_status",
@ -219,7 +219,7 @@ tf_cuda_library(
],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:lib",
@ -232,12 +232,13 @@ cc_library(
srcs = ["tf_status.cc"],
hdrs = ["tf_status.h"],
visibility = ["//visibility:public"],
deps = select({
deps = [
":tf_status_internal",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
":tf_status_internal",
"//tensorflow/core:lib",
],
}),
@ -259,10 +260,15 @@ cc_library(
name = "tensor_interface",
hdrs = ["tensor_interface.h"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
}),
)
cc_library(
@ -272,7 +278,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:framework",
@ -286,16 +292,17 @@ cc_library(
srcs = ["tf_tensor.cc"],
hdrs = ["tf_tensor.h"],
visibility = ["//visibility:public"],
deps = select({
deps = [
":tensor_interface",
":tf_datatype",
":tf_status",
":tf_status_helper",
":tf_tensor_internal",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
":tensor_interface",
":tf_datatype",
":tf_status",
":tf_status_helper",
":tf_tensor_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -311,14 +318,15 @@ tf_cuda_library(
"tf_tensor_internal.h",
],
visibility = ["//tensorflow:internal"],
deps = select({
deps = [
":tensor_interface",
":tf_datatype",
":tf_status",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
":tensor_interface",
":tf_datatype",
":tf_status",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:casts",
@ -386,8 +394,14 @@ tf_cuda_library(
deps = [
":tf_status",
":tf_status_internal",
"//tensorflow/core:lib",
],
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:lib",
],
}),
)
tf_cc_test(
@ -426,7 +440,7 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
@ -457,7 +471,7 @@ tf_cuda_library(
] + select({
"//tensorflow:android": [
":c_api_internal",
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":c_api_internal",
@ -484,7 +498,7 @@ tf_cuda_library(
":tf_status_helper",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",


@ -35,7 +35,7 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":context_interface",
@ -319,6 +319,7 @@ tf_cuda_cc_test(
tags = [
"noguitar", # TODO(b/155445984): flaky
#"guitar",
"notap", # TODO(b/156981931): flaky
"multi_gpu",
],
deps = [
@ -357,10 +358,13 @@ tf_cuda_cc_test(
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:function_optimization_registry",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/strings",
@ -412,7 +416,7 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":c_api",
@ -448,6 +452,8 @@ tf_cuda_library(
"//conditions:default": [],
}) + [
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",


@ -899,9 +899,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
#if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::Status::OK();
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->SyncExecutors();
status->status = tensorflow::unwrap(ctx)->AsyncWait();
#endif // !IS_MOBILE_PLATFORM
}
@ -924,7 +922,7 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
context->GetDevicePlacementPolicy());
}
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
tensorflow::Tensor tensor;
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
if (!status->status.ok()) return nullptr;


@ -137,7 +137,7 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
// placed in memory of different devices or remote address spaces.
typedef struct TFE_TensorHandle TFE_TensorHandle;
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t,
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t,
TF_Status* status);
// Indicates that the caller will not be using `h` any more.
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);


@ -50,6 +50,13 @@ tensorflow::ServerDef GetServerDef(int num_tasks) {
return GetServerDef("localhost", num_tasks);
}
void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->at(task_index) =
tensorflow::strings::StrCat("localhost:", port);
}
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
const std::vector<float>& expected_values) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@ -101,6 +108,22 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
TF_DeleteStatus(status);
}
// Read the value of variable `var` and save it into `out_value`.
void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var,
TFE_TensorHandle** out_value) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
int num_retvals = 1;
TFE_Execute(op, out_value, &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(op);
TF_DeleteStatus(status);
}
void TestRemoteExecuteChangeServerDef(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
@ -243,6 +266,102 @@ TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
TestRemoteExecuteUpdateServerDef(true);
}
void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
EXPECT_NE(var_handle0, nullptr);
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
EXPECT_NE(var_handle1, nullptr);
TFE_TensorHandle* value_handle = nullptr;
ReadVariable(ctx, var_handle1, &value_handle);
CheckTFE_TensorHandleHasFloats(value_handle, {2});
TFE_DeleteTensorHandle(value_handle);
// Start a new worker to replace task:1
ReplaceTaskInServerDef(&server_def, 1);
server_def.set_task_index(1);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
// Update server def to replace the remote device with the device info on the
// new worker (different incarnation ID).
server_def.set_task_index(0);
string serialized_update = server_def.SerializeAsString();
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
serialized_update.size(), status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// The device of var_handle0 is a local device, which is the same before and
// after the cluster update. Removing a resource with a valid device should
// succeed.
TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, var_handle0, status);
TFE_OpSetDevice(op, dev0_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
int num_retvals = 0;
TFE_Execute(op, nullptr, &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(op);
// The device of var_handle1 is a remote device, which was replaced during the
// cluster update. Removing a resource with an invalid device should fail
// gracefully (i.e., with an error status) instead of crashing with a segfault.
op = TFE_NewOp(ctx, "DestroyResourceOp", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, var_handle1, status);
TFE_OpSetDevice(op, dev1_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
num_retvals = 0;
TFE_Execute(op, nullptr, &num_retvals, status);
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(op);
TFE_DeleteTensorHandle(var_handle0);
TFE_DeleteTensorHandle(var_handle1);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) {
TestRemoteExecuteUpdateServerDefResourceAccess(false);
}
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) {
TestRemoteExecuteUpdateServerDefResourceAccess(true);
}
void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
// Fail fast on GetStatus requests so we can get errors instead of timeouts
// when updating the cluster with a non-existent worker.
@ -282,6 +401,7 @@ void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{2, tensorflow::strings::StrCat("localhost:", port)});
server_def.set_task_index(0);
string serialized_update = server_def.SerializeAsString();
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
serialized_update.size(), status);


@ -657,3 +657,17 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
std::move(tensor_handles), context, &handle);
return tensorflow::wrap(handle);
}
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetAllowSoftPlacement(enable);
}
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
context->SetLogDevicePlacement(enable);
}


@ -549,6 +549,18 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
TF_Status* status);
// Configure soft device placement policy for the eager executor. Note this
// policy is applied to any subsequent op executions.
TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
unsigned char enable,
TF_Status* status);
// Configure device placement policy logging for the eager executor. Note this
// policy is applied to any subsequent op executions.
TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
unsigned char enable,
TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif
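A minimal usage sketch for these two new toggles (editorial, not part of the diff; it assumes only the standard eager C API setup used throughout this commit):

TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
TFE_DeleteContextOptions(opts);
// Let ops fall back to a compatible device instead of failing placement,
// and log the device each op is actually placed on.
TFE_ContextSetSoftDevicePlacement(ctx, /*enable=*/1, status);
TFE_ContextSetLogDevicePlacement(ctx, /*enable=*/1, status);
// ... execute ops on ctx ...
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);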


@ -19,11 +19,16 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace {
@ -434,7 +439,26 @@ string AddVariablesFunction() {
return def.SerializeAsString();
}
TEST(CAPI, TestFunctionWithPackedInput) {
void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(op, var_handle, status);
TFE_TensorHandle* is_initialized[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(op, &is_initialized[0], &num_retvals, status);
CHECK_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
bool initialized = false;
memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
EXPECT_EQ(initialized, true);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(is_initialized[0]);
TFE_DeleteOp(op);
delete status;
}
void TestFunctionWithPackedInput(const bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
@ -474,6 +498,12 @@ TEST(CAPI, TestFunctionWithPackedInput) {
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
// Add a sync point in order to make sure that variables have been initialized
// before the function execution starts.
// TODO(b/155789951): Remove once b/155789951 is fixed.
VarIsInitialized(ctx, h1);
VarIsInitialized(ctx, h2);
// Pack 3 variable handles into one TFE_TensorHandle.
int num_replicas = 3;
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
@ -502,6 +532,10 @@ TEST(CAPI, TestFunctionWithPackedInput) {
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, packed_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
if (remote) {
TFE_OpSetDevice(func, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
@ -537,6 +571,189 @@ TEST(CAPI, TestFunctionWithPackedInput) {
worker_server2.release();
}
TEST(CAPI, TestLocalFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/false);
}
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/true);
}
string VariableAddFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'VariableAddFunction'"
" input_arg {"
" name: 'var0'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'var0_value'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read0:value:0'"
" device: '/job:localhost/task:1/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'identity'"
" op: 'Identity'"
" input: 'add:z:0'"
" device: '/job:localhost/task:0/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'var0_value'"
" value: 'identity:output:0'"
" }",
&def));
return def.SerializeAsString();
}
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
public:
FunctionErrorInjectionPass(string error_node, string error_device)
: error_node_(error_node), error_device_(error_device) {}
tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
const tensorflow::ConfigProto& config_proto,
std::unique_ptr<tensorflow::Graph>* graph,
tensorflow::FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) override {
// Inject a failure into function instantiation when a node is found that
// contains the given node name (error_node_) and requested device
// (error_device_).
for (const auto node : graph->get()->nodes()) {
if (node->name().find(error_node_) != string::npos &&
node->requested_device() == error_device_) {
return tensorflow::errors::Internal("Injected graph pass error.");
}
}
return tensorflow::Status::OK();
}
private:
const string error_node_;
const string error_device_;
};
void TestDistributedFunctionCancellation(bool inject_error) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
if (inject_error) {
// Inject a function optimization pass failure when it sees the 'read0' op
// having a requested device `dev2_name`. During execution:
// * task:0 processes the main function `VariableAddFunction` and places
// the read0 op on task:2
// * task:0 partitions the main function with a subgraph containing read0
// sent to task:2
// * task:2 graph pass reports an error when it sees read0 with dev2_name
tensorflow::function_optimization_registration::
FunctionOptimizationPassRegistration register_test_pass(
std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name));
}
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle, nullptr);
const string function_def = VariableAddFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, var_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
if (inject_error) {
ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
} else {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
ASSERT_EQ(sum, 4.0);
}
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(var_handle);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, DistributedFunctionNoError) {
TestDistributedFunctionCancellation(false);
}
TEST(CAPI, DistributedFunctionCancelledOnError) {
TestDistributedFunctionCancellation(true);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);


@ -1203,6 +1203,8 @@ void BM_ReadVariable(int iters) {
CHECK_EQ(0, TFE_TensorHandleNumDims(h, status));
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
h = nullptr;
TFE_OpAddInput(op, var_handle, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
tensorflow::testing::StopTiming();
TFE_DeleteOp(op);


@ -150,6 +150,7 @@ TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
TFE_Execute(op, &var_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_DeleteOp(op);
if (TF_GetCode(status) != TF_OK) return nullptr;
CHECK_EQ(1, num_retvals);


@ -17,6 +17,8 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
@ -26,6 +28,51 @@ using tensorflow::string;
using tensorflow::internal::OutputList;
using tensorflow::internal::unwrap;
namespace tensorflow {
namespace internal {
typedef absl::flat_hash_map<std::string, FactoryFunction> FactoriesMap;
static FactoriesMap& GetFactories() {
static FactoriesMap* factories = new FactoriesMap;
return *factories;
}
static const char* default_factory = "<unset>";
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
assert((!GetFactories().count(name)) ||
(GetFactories()[name] == factory) &&
"Duplicate tracing factory registration");
GetFactories()[name] = factory;
}
void SetDefaultTracingEngine(const char* name) { default_factory = name; }
static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
TF_Status* s) {
auto entry = GetFactories().find(default_factory);
if (entry != GetFactories().end()) return entry->second(fn_name, s);
string msg = absl::StrCat(
"No tracing engine factory has been registered with the key '",
default_factory, "' (available: ");
// Ensure deterministic (sorted) order in the error message
std::set<string> factories_sorted;
for (const auto& factory : GetFactories())
factories_sorted.insert(factory.first);
const char* comma = "";
for (const string& factory : factories_sorted) {
msg += comma + factory;
comma = ", ";
}
msg += ")";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
} // end namespace internal
} // end namespace tensorflow
// =============================================================================
// Public C API entry points
//
@ -36,6 +83,28 @@ using tensorflow::internal::unwrap;
//
// =============================================================================
void TF_SetTracingImplementation(const char* name) {
tensorflow::internal::SetDefaultTracingEngine(name);
}
// Creates a new TensorFlow function; it is an execution context attached to a
// given tracing context.
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s));
}
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
TF_OutputList* outputs, TF_Status* s) {
auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s));
TF_DeleteExecutionContext(ctx);
return func;
}
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
TF_DataType dtype, TF_Status* s) {
return wrap(unwrap(func)->AddParameter(dtype, s));
}
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
@ -58,6 +127,10 @@ int TF_OutputListNumOutputs(TF_OutputList* o) {
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
return wrap(unwrap(o)->outputs[i]);
}
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
TF_Status* s) {
unwrap(o)->outputs.push_back(unwrap(tensor));
}
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s) {


@ -49,15 +49,26 @@ typedef struct TF_AbstractOp TF_AbstractOp;
// setting functional attributes of other composite ops e.g. control flow.
typedef struct TF_AbstractFunction TF_AbstractFunction;
// Creates a context for tracing the execution of operations into a function.
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s);
// This allows the client to swap the implementation of the tracing engine.
// Any future call to TF_CreateFunction will use the implementation defined
// here.
void TF_SetTracingImplementation(const char* name);
// Creates a new TensorFlow function. A Function is an execution context, and as
// such it can trace operations through TF_ExecuteOperation. After completing
// tracing, a function can be obtained by TF_FinalizeFunction.
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* status);
// Creates a context for eager execution of operations.
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
TF_Status* s);
void TF_DeleteExecutionContext(TF_ExecutionContext*);
// Add a new parameter to a TensorFlow Function.
// TODO(aminim): what about shape?
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
TF_DataType dtype, TF_Status* s);
// Create an operation suitable to use with the provided context. The operation
// requires its type (e.g. "AddV2") to be set independently.
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
@ -77,19 +88,21 @@ void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
// TF_OutputList holds the list of TF_AbstractTensor that results from executing
// an operation.
// It just lets us not specify the number of outputs of an operation
// beforehand. This forces a memory allocation in the runtime, which is bad, but
// it allows for generic code.
// TODO(aminim): the description above isn't clear with respect to
// TF_OutputListNumOutputs and the current eager implementation which requires
// the number of outputs to be set by the client.
// an operation, or provided to create a function.
// When executing an operation in an eager context, the expected number of
// outputs must be set beforehand with `TF_OutputListSetNumOutputs`.
typedef struct TF_OutputList TF_OutputList;
TF_OutputList* TF_NewOutputList();
void TF_DeleteOutputList(TF_OutputList* o);
void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
// Prepare tracing for the expected number of outputs of an operation.
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status*);
// Return the number of outputs in the list.
int TF_OutputListNumOutputs(TF_OutputList* o);
// Return the `i`th output in the list.
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
// Append a tensor at the end of the output list, growing its size by one.
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
TF_Status*);
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
// capture some inputs and then add a node in the graph. The output tensors are
@ -100,13 +113,12 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_ExecutionContext* ctx, TF_Status* s);
// Creates a new TF_AbstractFunction from the current tracing states in the
// context. The returned TF_GraphToFunction must be deleted by the client.
// context. The provided `ctx` is consumed by this API call and deleted.
// The returned TF_AbstractFunction must be deleted by the client,
// TODO(aminim): clarify the contract on the state of the context after this
// call.
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status);
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
TF_OutputList*, TF_Status*);
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
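Taken together, the reworked tracing API in this header flows roughly as follows (a sketch distilled from the updated TestBasicGraph test later in this commit; error checking elided):

TF_Status* s = TF_NewStatus();
// Create a tracing context for a function named "double"; the engine used is
// whichever one was selected via TF_SetTracingImplementation.
TF_ExecutionContext* fn_ctx = TF_CreateFunction("double", s);
// Parameters are added explicitly now, instead of tracing Placeholder ops.
TF_AbstractTensor* arg = TF_AddFunctionParameter(fn_ctx, TF_FLOAT, s);
TF_AbstractOp* add = TF_NewAbstractOp(fn_ctx);
TF_AbstractOpSetOpType(add, "Add", s);
TF_AbstractOpSetOpName(add, "my_add", s);
TF_AbstractTensor* inputs[2] = {arg, arg};
TF_OutputList* outputs = TF_NewOutputList();
TF_ExecuteOperation(add, 2, inputs, outputs, fn_ctx, s);
TF_DeleteAbstractOp(add);
// TF_FinalizeFunction consumes fn_ctx and deletes it; only the function and
// the output list remain to be cleaned up by the client.
TF_AbstractFunction* func = TF_FinalizeFunction(fn_ctx, outputs, s);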


@ -123,6 +123,17 @@ class EagerContext : public ExecutionContext {
}
}
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Can't add function parameter on an eager context.");
return nullptr;
}
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Can't use finalize function on an eager context.");
return nullptr;
}
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
auto* func = afunc->GetTfFunction(s);
if (!func) {


@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include <vector>
#include "absl/strings/str_cat.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
@ -114,12 +115,14 @@ struct GraphFunction : public AbstractFunction {
static constexpr AbstractFunctionKind kKind = kGraphFunc;
};
// GraphContext wraps a TF_Graph and manages the "execution" of operation, i.e.
// adding them to the graph.
// GraphContext wraps a TF_Graph modeling a single function and manages the
// "execution" of operation, i.e. adding them to the function.
class GraphContext : public ExecutionContext {
public:
GraphContext()
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
explicit GraphContext(const char* name)
: ExecutionContext(kKind),
graph_(new TF_Graph(), TF_DeleteGraph),
name_(name) {}
AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context.
@ -136,6 +139,10 @@ class GraphContext : public ExecutionContext {
return;
}
auto* tf_opdesc = graph_op->op_.release();
if (tf_opdesc == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete.");
return;
}
for (int i = 0; i < num_inputs; ++i) {
auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
if (!graph_tensor) {
@ -164,24 +171,38 @@ class GraphContext : public ExecutionContext {
}
}
TF_Function* ToFunction(const char* fn_name, int num_inputs,
const GraphTensor* inputs, int num_outputs,
const GraphTensor* outputs, TF_Status* status) const {
std::vector<TF_Output> graph_inputs;
graph_inputs.resize(num_inputs);
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
TF_OperationDescription* opdesc =
TF_NewOperation(graph_.get(), "Placeholder",
absl::StrCat("_input_", inputs_.size()).c_str());
TF_SetAttrType(opdesc, "dtype", dtype);
auto* operation = TF_FinishOperation(opdesc, s);
if (!s->status.ok()) return nullptr;
inputs_.push_back(TF_Output{operation, 0});
return new GraphTensor(inputs_.back(), this);
}
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
std::unique_ptr<GraphFunction> func(new GraphFunction);
std::vector<TF_Output> graph_outputs;
graph_outputs.resize(num_outputs);
for (int i = 0; i < num_inputs; i++) {
graph_inputs[i] = inputs[i].output;
}
for (int i = 0; i < num_outputs; i++) {
graph_outputs[i] = outputs[i].output;
graph_outputs.reserve(outputs->outputs.size());
for (AbstractTensor* abstract_output : outputs->outputs) {
GraphTensor* output = dyncast<GraphTensor>(abstract_output);
if (!output) {
TF_SetStatus(s, TF_UNIMPLEMENTED,
"Returning a non-graph tensor from a function has not "
"been implemented yet.");
return nullptr;
}
graph_outputs.push_back(output->output);
}
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
graph_inputs.size(), graph_inputs.data(),
graph_outputs.size(), graph_outputs.data(),
nullptr, nullptr, fn_name, status);
func->func = TF_GraphToFunction(
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
if (TF_GetCode(s) != TF_OK) return nullptr;
return func.release();
}
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
@ -195,54 +216,20 @@ class GraphContext : public ExecutionContext {
private:
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
std::vector<TF_Output> inputs_;
const char* name_;
};
// Helper that converts the graph currently held in the context into a function.
static AbstractFunction* ExecutionContextToFunction(
const ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const AbstractTensor* inputs, int num_outputs,
const AbstractTensor* outputs, TF_Status* status) {
auto* graph_ctx = dyncast<const GraphContext>(fn_body);
if (graph_ctx == nullptr) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"fn_body is not a TF_GraphContext.");
return nullptr;
}
auto* graph_inputs = dyncast<const GraphTensor>(inputs);
if (!graph_inputs) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors.");
return nullptr;
}
auto* graph_outputs = dyncast<const GraphTensor>(outputs);
if (!graph_outputs) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
return nullptr;
}
GraphFunction* func = new GraphFunction;
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
num_outputs, graph_outputs, status);
return func;
static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) {
return new GraphContext(name);
}
// Register the tracing implemented in this file as the default tracing engine.
static bool register_tracing = [] {
RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
SetDefaultTracingEngine("graphdef");
return true;
}();
} // namespace internal
} // namespace tensorflow
// =============================================================================
// Public C API entry points
// These are only the entry points specific to the Graph API.
// =============================================================================
using tensorflow::internal::unwrap;
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
return wrap(new tensorflow::internal::GraphContext());
}
TF_AbstractFunction* TF_ExecutionContextToFunction(
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
const TF_AbstractTensor* inputs, int num_outputs,
const TF_AbstractTensor* outputs, TF_Status* status) {
return wrap(ExecutionContextToFunction(unwrap(fn_body), fn_name, num_inputs,
unwrap(inputs), num_outputs,
unwrap(outputs), status));
}


@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace internal {
@ -148,6 +149,17 @@ struct ExecutionContext {
// Creates an empty AbstractOperation suitable to use with this context.
virtual AbstractOp* CreateOperation() = 0;
// Add a function parameter and return the corresponding tensor.
// This is only valid with an ExecutionContext obtained from a TracingContext;
// it'll always error out with an eager context.
virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0;
// Finalize this context and make a function out of it. The context is in an
// invalid state after this call and must be destroyed.
// This is only valid with an ExecutionContext obtained from a TracingContext;
// it'll always error out with an eager context.
virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0;
// Registers a function with this context; afterwards the function is
// available to be called/referenced by its name in this context.
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
@ -156,6 +168,11 @@ struct ExecutionContext {
const ExecutionContextKind k;
};
typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
void SetDefaultTracingEngine(const char* name);
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
FactoryFunction factory);
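As a sketch of the intended extension point, an alternative tracing engine would plug into this registry the same way the graphdef engine does later in this commit (`MyContext` and the "my_engine" key are hypothetical):

static ExecutionContext* MyTracingFactory(const char* fn_name, TF_Status* s) {
  return new MyContext(fn_name);  // hypothetical ExecutionContext subclass
}
static bool register_my_engine = [] {
  RegisterTracingEngineFactory("my_engine", MyTracingFactory);
  SetDefaultTracingEngine("my_engine");
  return true;
}();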
// Create utilities to wrap/unwrap: this convert from the C opaque types to the
// C++ implementation, and back.
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \


@ -29,7 +29,12 @@ using tensorflow::string;
namespace tensorflow {
namespace {
TEST(UnifiedCAPI, TestBasicEager) {
class UnifiedCAPI : public ::testing::TestWithParam<const char*> {
protected:
void SetUp() override { TF_SetTracingImplementation(GetParam()); }
};
TEST_P(UnifiedCAPI, TestBasicEager) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -81,33 +86,18 @@ TEST(UnifiedCAPI, TestBasicEager) {
TF_DeleteExecutionContext(ctx);
}
TEST(UnifiedCAPI, TestBasicGraph) {
TEST_P(UnifiedCAPI, TestBasicGraph) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
// Start a new function / execution context.
string fn_name = "double";
TF_ExecutionContext* graph_ctx =
TF_CreateFunction(fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
auto* placeholder_t =
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_OutputList* placeholder_outputs = TF_NewOutputList();
// Execute.
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
// Delete placeholder op.
TF_DeleteAbstractOp(placeholder_op);
// Build an abstract operation.
auto* add_op = TF_NewAbstractOp(graph_ctx);
@ -123,17 +113,13 @@ TEST(UnifiedCAPI, TestBasicGraph) {
// Execute.
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0);
// Clean up operation and inputs.
TF_DeleteAbstractOp(add_op);
string fn_name = "double";
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get());
TF_AbstractFunction* func =
TF_FinalizeFunction(graph_ctx, add_outputs, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_DeleteAbstractTensor(placeholder_t);
TF_DeleteAbstractTensor(output_t);
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -174,17 +160,160 @@ TEST(UnifiedCAPI, TestBasicGraph) {
ASSERT_EQ(*f_value, 4.0);
TF_DeleteOutputList(add_outputs);
TF_DeleteOutputList(placeholder_outputs);
TF_DeleteAbstractOp(fn_op);
TF_DeleteAbstractTensor(input_t);
TF_DeleteAbstractTensor(final_result);
TF_DeleteTensor(f_t);
TF_DeleteAbstractFunction(func);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContext(eager_execution_ctx);
}
TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Status* s = status.get();
// Start a new function / execution context.
string fn_name = "two_adds";
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Create a first "Add" computing `arg0 + arg1`.
TF_AbstractTensor* add_output1;
{
// Build an abstract operation, inputs and output.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add1", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg0, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(add_op);
// Extract the resulting tensor.
add_output1 = TF_OutputListGet(add_outputs, 0);
TF_DeleteOutputList(add_outputs);
}
// Same with a second "Add" computing `arg1 + arg1`.
TF_AbstractTensor* add_output2;
{
// Build an abstract operation, inputs and output.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add2", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg1, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(add_op);
// Extract the resulting tensor.
add_output2 = TF_OutputListGet(add_outputs, 0);
TF_DeleteOutputList(add_outputs);
}
// Finalize the function by providing the returned values.
TF_AbstractFunction* func;
{
// We want to return the outputs of both add operations, so create a new
// list and populate it.
TF_OutputList* func_outputs = TF_NewOutputList();
TF_OutputListPushBack(func_outputs, add_output1, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_OutputListPushBack(func_outputs, add_output2, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
func = TF_FinalizeFunction(graph_ctx, func_outputs, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteOutputList(func_outputs);
}
/**
* So far we have traced this function:
*
* def two_adds(a, b):
* my_add1 = a + b
* my_add2 = b + b
* return my_add1, my_add2
*
* Now we will execute this function with an eager context:
*
* output1, output2 = two_adds(2.0, 3.0)
*
* and check that we got 5.0 and 6.0 as results.
*/
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* eager_execution_ctx =
TF_NewEagerExecutionContext(opts, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TFE_DeleteContextOptions(opts);
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Build the abstract op to run the function.
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Build two abstract input tensors as function arguments.
std::vector<TF_AbstractTensor*> func_args;
{
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
input_eager = TestScalarTensorHandle(eager_ctx, 3.0f);
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
}
TF_OutputList* func_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(func_outputs, 2, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs,
eager_execution_ctx, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(fn_op);
for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t);
ASSERT_EQ(2, TF_OutputListNumOutputs(func_outputs));
float results[2];
for (int idx = 0; idx < 2; ++idx) {
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
TFE_TensorHandle* handle = TF_AbstractTensorGetEagerTensor(result, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Tensor* f_t = TFE_TensorHandleResolve(handle, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
results[idx] = *static_cast<float*>(TF_TensorData(f_t));
TF_DeleteTensor(f_t);
}
ASSERT_EQ(results[0], 5.0);
ASSERT_EQ(results[1], 6.0);
for (int idx = 0; idx < 2; ++idx) {
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
TF_DeleteAbstractTensor(result);
}
TF_DeleteOutputList(func_outputs);
TF_DeleteExecutionContext(eager_execution_ctx);
TF_DeleteAbstractFunction(func);
}
TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -193,18 +322,15 @@ TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
ctx, nullptr, 0, nullptr, 0, nullptr, status.get());
TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get());
ASSERT_EQ(nullptr, func);
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
TF_DeleteExecutionContext(ctx);
}
TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
@ -222,10 +348,10 @@ TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
@ -243,7 +369,7 @@ TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
// Build an Eager context.
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -273,7 +399,8 @@ TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build a Graph context.
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute eager op using graph context.
@ -289,10 +416,11 @@ TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
@ -349,5 +477,7 @@ TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
TF_DeleteExecutionContext(eager_execution_ctx);
}
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef"));
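// If another tracing engine were registered through
// RegisterTracingEngineFactory (hypothetical name "mlir"), the whole suite
// would cover it by extending the parameter list:
// INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
//                          ::testing::Values("graphdef", "mlir"));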
} // namespace
} // namespace tensorflow

View File

@ -101,6 +101,9 @@ class AbstractContextInterface {
// Destroy the step resource container for a training step.
virtual void EndStep() = 0;
// Block until all pending nodes are finished.
virtual Status AsyncWait() = 0;
protected:
virtual ~AbstractContextInterface() {}
};

View File

@ -44,6 +44,7 @@ tf_cc_test(
srcs = ["parallel_device_test.cc"],
deps = [
":parallel_device",
":parallel_device_ops",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
@ -53,3 +54,19 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
# Note: ParallelDevice-specific ops are experimental and not currently linked
# into TensorFlow by default; they are only used in a few tests.
filegroup(
name = "parallel_device_ops_srcs",
srcs = ["parallel_device_ops.cc"],
visibility = ["//tensorflow/python/distribute/parallel_device:__pkg__"],
)
cc_library(
name = "parallel_device_ops",
srcs = [":parallel_device_ops_srcs"],
visibility = ["//tensorflow:internal"],
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)

View File

@ -92,6 +92,10 @@ class ParallelDevice {
TFE_TensorHandle* tensor,
TF_Status* status) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
TF_Status* status) const;
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
@ -208,6 +212,46 @@ std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
},
nullptr),
TF_DeleteTensor);
// TODO(allenl): Here and when executing regular operations, we could hold
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
// device names repeatedly.
OpPtr const_op(TFE_NewOp(context, "Const", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(device_handle);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
@ -282,6 +326,13 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
}
result.emplace(std::move(outputs));
return result;
} else if (operation_name == std::string("DeviceID")) {
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(DeviceIDs(context, status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results(

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
// TODO(allenl): Figure out if we need this op, and if so whether we should move
// it to core TF. Right now the eager C API does some checking of op
// registrations before calling into custom devices, but we may be able to avoid
// that.
REGISTER_OP("DeviceID")
.Output("device_id: int64")
.SetIsStateful()
.SetShapeFn(tensorflow::shape_inference::ScalarShape);
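// Illustrative usage sketch (mirrors parallel_device_test.cc below): once this
// op is linked in, executing "DeviceID" on a parallel device yields one scalar
// int64 component per underlying device.
//
// TFE_Op* op = TFE_NewOp(context, "DeviceID", status);
// TFE_OpSetDevice(op, parallel_device_name, status);
// TFE_TensorHandle* result;
// int num_retvals = 1;
// TFE_Execute(op, &result, &num_retvals, status);
// TFE_DeleteOp(op);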

View File

@ -278,14 +278,15 @@ TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
}
// Expects that `handle` is equal to `expected_value`.
void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
template <typename value_type>
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(expected_value,
*static_cast<float*>(TF_TensorData(value_zero.get())));
EXPECT_EQ(expected_value,
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
}
template <std::size_t num_devices>
@ -343,8 +344,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(components[0].get(), 20.);
AssertScalarFloatEq(components[1].get(), 20.);
ExpectScalarEq<float>(components[0].get(), 20.);
ExpectScalarEq<float>(components[1].get(), 20.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
@ -373,8 +374,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(components[0].get(), 23.);
AssertScalarFloatEq(components[1].get(), 18.);
ExpectScalarEq<float>(components[0].get(), 23.);
ExpectScalarEq<float>(components[1].get(), 18.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
@ -383,6 +384,32 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Compute the device ID twice and verify the result
for (int i = 0; i < 2; ++i) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetDevice(op.get(), device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, result_handle, &components, status.get());
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int64_t>(components[0].get(), 0);
ExpectScalarEq<int64_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
}
TEST(PARALLEL_DEVICE, TestBasicCPU) {
@ -498,8 +525,8 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// The value of the original tensor is replicated on each device.
AssertScalarFloatEq(components[0].get(), 3.);
AssertScalarFloatEq(components[1].get(), 3.);
ExpectScalarEq<float>(components[0].get(), 3.);
ExpectScalarEq<float>(components[1].get(), 3.);
// Verify that the mirrors are placed on the component devices.
std::string first_device =
@ -630,7 +657,7 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
&second_components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(second_components[1].get(), 9.);
ExpectScalarEq<float>(second_components[1].get(), 9.);
// Verify that the mirrors are placed on the component devices.
std::string first_device = TFE_TensorHandleBackingDeviceName(
@ -644,8 +671,8 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
std::array<TensorHandlePtr, 2> first_components;
ExtractPerDeviceValues(context.get(), second_components[0].get(),
&first_components, status.get());
AssertScalarFloatEq(first_components[0].get(), 3.);
AssertScalarFloatEq(first_components[1].get(), 6.);
ExpectScalarEq<float>(first_components[0].get(), 3.);
ExpectScalarEq<float>(first_components[1].get(), 6.);
first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
status.get());
@ -806,8 +833,8 @@ TEST(PARALLEL_DEVICE, TestCollective) {
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(result_components[0].get(), 3.);
AssertScalarFloatEq(result_components[1].get(), 3.);
ExpectScalarEq<float>(result_components[0].get(), 3.);
ExpectScalarEq<float>(result_components[1].get(), 3.);
}
void RegisterCollectiveMulFunction(TFE_Context* context,
@ -909,8 +936,8 @@ TEST(PARALLEL_DEVICE, TestFunction) {
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);
std::string first_device = TFE_TensorHandleBackingDeviceName(
result_components[0].get(), status.get());

View File

@ -31,9 +31,6 @@ cc_library(
"//tensorflow/c/experimental/saved_model/public:concrete_function.h",
],
copts = tf_copts(),
# TODO(bmzhao): Remove this as we refactor C API to granular targets,
# so that we can depend on c/eager/c_api_unified_experimental.h.
features = ["-layering_check"],
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
@ -41,6 +38,8 @@ cc_library(
":concrete_function_type",
":function_metadata",
":function_metadata_type",
":tensorhandle_list",
":tensorhandle_list_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
@ -160,6 +159,38 @@ cc_library(
],
)
cc_library(
name = "tensorhandle_list",
srcs = [
"tensorhandle_list.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:tensorhandle_list.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":tensorhandle_list_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:tensor_handle_interface",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
],
)
cc_library(
name = "tensorhandle_list_type",
hdrs = [
"tensorhandle_list_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/eager:tensor_handle_interface",
],
)
tf_cc_test(
name = "saved_model_api_test",
size = "small",

View File

@ -15,12 +15,12 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
extern "C" {
@ -29,10 +29,9 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
&tensorflow::unwrap(func)->GetFunctionMetadata()));
}
TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
// TODO(bmzhao): Refactor TF_OutputList struct definition into a separate
// internal header, and implement this function.
return nullptr;
const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
TF_ConcreteFunction* func) {
return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures());
}
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
#include <stddef.h>
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
extern "C" {
size_t TF_TensorHandleListSize(const TF_TensorHandleList* list) {
return tensorflow::unwrap(list)->size();
}
TFE_TensorHandle* TF_TensorHandleListGet(const TF_TensorHandleList* list,
int i) {
return tensorflow::wrap((*tensorflow::unwrap(list))[i]);
}
} // end extern "C"

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSORHANDLE_LIST_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSORHANDLE_LIST_TYPE_H_
#include <vector>
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
// Internal structures used by the SavedModel C API. These are likely to
// change and should not be depended on.
typedef struct TF_TensorHandleList TF_TensorHandleList;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(
std::vector<tensorflow::AbstractTensorHandleInterface*>,
TF_TensorHandleList)
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSORHANDLE_LIST_TYPE_H_

View File

@ -24,6 +24,7 @@ exports_files(
"concrete_function_list.h",
"function_metadata.h",
"saved_model_api.h",
"tensorhandle_list.h",
],
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
)
@ -39,6 +40,7 @@ cc_library(
":concrete_function_list",
":function_metadata",
":saved_model_api",
":tensorhandle_list",
],
)
@ -61,3 +63,8 @@ alias(
name = "saved_model_api",
actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
)
alias(
name = "tensorhandle_list",
actual = "//tensorflow/c/experimental/saved_model/internal:tensorhandle_list",
)

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
// IWYU pragma: end_exports
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_

View File

@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
#ifdef __cplusplus
extern "C" {
@ -36,7 +36,7 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
TF_ConcreteFunction* func);
// Returns a list of TensorHandles implicitly captured by this function.
TF_CAPI_EXPORT extern TF_OutputList* TF_ConcreteFunctionGetCaptures(
TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
TF_ConcreteFunction* func);
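// Illustrative caller-side sketch (not part of this header; assumes a valid
// `func`): iterating the captured tensor handles of a concrete function. The
// returned list is const and, presumably, owned by `func`, so it is not
// deleted here.
//
// const TF_TensorHandleList* captures = TF_ConcreteFunctionGetCaptures(func);
// size_t n = TF_TensorHandleListSize(captures);
// for (size_t i = 0; i < n; ++i) {
//   TFE_TensorHandle* handle = TF_TensorHandleListGet(captures, (int)i);
//   /* inspect or execute with `handle` */
// }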
// Returns a TFE_Op suitable for executing this function.

View File

@ -21,19 +21,27 @@ limitations under the License.
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type that acts like a list of TF_ConcreteFunction pointers.
typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList;
// Returns the size of `list`.
TF_CAPI_EXPORT size_t
TF_ConcreteFunctionListSize(TF_ConcreteFunctionList* list);
TF_CAPI_EXPORT extern size_t TF_ConcreteFunctionListSize(
TF_ConcreteFunctionList* list);
// Returns the `i`th TF_ConcreteFunction in the list.
TF_CAPI_EXPORT TF_ConcreteFunction* TF_ConcreteFunctionListGet(
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_ConcreteFunctionListGet(
TF_ConcreteFunctionList* list, int i);
// Deletes `list`.
TF_CAPI_EXPORT void TF_DeleteConcreteFunctionList(
TF_CAPI_EXPORT extern void TF_DeleteConcreteFunctionList(
TF_ConcreteFunctionList* list);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
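// Illustrative caller-side sketch (not part of this header; assumes a valid
// `list`): walking the list and releasing it afterwards.
//
// size_t n = TF_ConcreteFunctionListSize(list);
// for (int i = 0; i < (int)n; ++i) {
//   TF_ConcreteFunction* fn = TF_ConcreteFunctionListGet(list, i);
//   /* use `fn` */
// }
// TF_DeleteConcreteFunctionList(list);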

View File

@ -0,0 +1,43 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
#include <stddef.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type that acts like a list of TFE_TensorHandle pointers.
typedef struct TF_TensorHandleList TF_TensorHandleList;
// Returns the size of `list`.
TF_CAPI_EXPORT extern size_t TF_TensorHandleListSize(
const TF_TensorHandleList* list);
// Returns the `i`th TFE_TensorHandle in the list.
TF_CAPI_EXPORT extern TFE_TensorHandle* TF_TensorHandleListGet(
const TF_TensorHandleList* list, int i);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_

View File

@ -62,3 +62,17 @@ cc_library(
"//tensorflow/c:tf_tensor",
],
)
cc_library(
name = "tensorhandle",
hdrs = [
"tensorhandle.h",
],
deps = [
":runtime",
":status",
":tensor",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
],
)

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// Runtime represents an opaque instance of a Tensorflow runtime, with its own
@ -40,6 +41,7 @@ class Runtime {
private:
friend class RuntimeBuilder;
friend class SavedModelAPI;
friend class TensorHandle;
// Wraps a TFE_Context. Takes ownership of ctx.
explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {}
@ -63,6 +65,7 @@ class Runtime {
};
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/cc/experimental/base/public/status.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
@ -79,6 +80,7 @@ inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
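// Illustrative usage sketch (matches the tests later in this change): build a
// Runtime and check the out-parameter Status before using the result.
//
// using tensorflow::experimental::cc::Runtime;
// using tensorflow::experimental::cc::RuntimeBuilder;
// using tensorflow::experimental::cc::Status;
//
// Status status;
// RuntimeBuilder builder;
// std::unique_ptr<Runtime> runtime = builder.Build(&status);
// if (!status.ok()) { /* handle the error; `runtime` must not be used */ }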

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/c/tf_status.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// Status is a wrapper around an error code and an optional error message.
@ -57,6 +58,7 @@ class Status {
friend class RuntimeBuilder;
friend class Runtime;
friend class SavedModelAPI;
friend class TensorHandle;
// Wraps a TF_Status*, and takes ownership of it.
explicit Status(TF_Status* status) : status_(status) {}
@ -88,6 +90,7 @@ inline void Status::SetStatus(TF_Code code, const std::string& msg) {
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/cc/experimental/base/public/status.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// Tensor represents an n-dimensional array of values.
@ -168,6 +169,7 @@ inline Tensor Tensor::FromBuffer(TF_DataType dtype,
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_

View File

@ -0,0 +1,98 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
#include <memory>
#include <vector>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// An opaque representation of a tensor computed/managed by the Tensorflow
// runtime (tensorflow::cc::Runtime). Unlike a tensor, a TensorHandle may refer
// to tensors placed in memory of different devices or remote address spaces.
// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created
// from it.
class TensorHandle {
public:
// Unwraps a Tensor from the given TensorHandle. If an error occurred,
// status->ok() will be false, and the returned Tensor must not be used.
Tensor Resolve(Status* status);
// Constructs a TensorHandle from a Tensor. If an error occurred,
// status->ok() will be false, and the returned TensorHandle must not be used.
static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime,
Status* status);
// TensorHandle is movable, and not copyable
TensorHandle(TensorHandle&&) = default;
TensorHandle& operator=(TensorHandle&&) = default;
private:
// Wraps a TFE_TensorHandle. Takes ownership of handle.
explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {}
// TensorHandle is not copyable
TensorHandle(const TensorHandle&) = delete;
TensorHandle& operator=(const TensorHandle&) = delete;
// Returns the underlying TFE_TensorHandle that this object wraps.
// This object retains ownership of the pointer.
TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); }
// Deletes the currently wrapped TFE_TensorHandle, replaces it with handle,
// and takes ownership of handle.
void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); }
struct TFETensorHandleDeleter {
void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); }
};
std::unique_ptr<TFE_TensorHandle, TFETensorHandleDeleter> handle_;
};
inline Tensor TensorHandle::Resolve(Status* status) {
TF_Tensor* tensor =
TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus());
if (!status->ok()) {
return Tensor(nullptr);
}
return Tensor(tensor);
}
inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor,
const Runtime& runtime,
Status* status) {
TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor(
runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus());
if (!status->ok()) {
return TensorHandle(nullptr);
}
return TensorHandle(tensor_handle);
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_
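// Illustrative round-trip sketch (mirrors tensorhandle_test.cc below; assumes
// a valid `tensor` and `runtime`): wrap a Tensor in a TensorHandle, then
// resolve it back.
//
// Status status;
// TensorHandle handle = TensorHandle::FromTensor(tensor, *runtime, &status);
// if (!status.ok()) { /* `handle` must not be used */ }
// Tensor resolved = handle.Resolve(&status);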

View File

@ -5,12 +5,22 @@ package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "tensor_types_test_util",
testonly = True,
hdrs = ["tensor_types_test_util.h"],
deps = [
"//tensorflow/c:tf_datatype",
],
)
tf_cc_test(
name = "tensor_test",
srcs = [
"tensor_test.cc",
],
deps = [
":tensor_types_test_util",
"//tensorflow/c:tf_datatype",
"//tensorflow/cc/experimental/base/public:status",
"//tensorflow/cc/experimental/base/public:tensor",
@ -19,3 +29,22 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "tensorhandle_test",
srcs = [
"tensorhandle_test.cc",
],
deps = [
":tensor_types_test_util",
"//tensorflow/c:tf_datatype",
"//tensorflow/cc/experimental/base/public:runtime",
"//tensorflow/cc/experimental/base/public:runtime_builder",
"//tensorflow/cc/experimental/base/public:status",
"//tensorflow/cc/experimental/base/public:tensor",
"//tensorflow/cc/experimental/base/public:tensorhandle",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)

View File

@ -16,69 +16,22 @@ limitations under the License.
#include "tensorflow/cc/experimental/base/public/tensor.h"
#include <stddef.h>
#include <cstdint>
#include <stdint.h>
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
// Each of the following struct types has two members: a kDType that
// corresponds to a TF_DataType enum value, and a typedef "type"
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
// tests via GoogleTest's TypedTests:
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
struct FloatType {
using type = float;
static constexpr TF_DataType kDType = TF_FLOAT;
};
using tensorflow::experimental::cc::Status;
using tensorflow::experimental::cc::Tensor;
struct DoubleType {
using type = double;
static constexpr TF_DataType kDType = TF_DOUBLE;
};
struct Int32Type {
using type = int32_t;
static constexpr TF_DataType kDType = TF_INT32;
};
struct UINT8Type {
using type = uint8_t;
static constexpr TF_DataType kDType = TF_UINT8;
};
struct INT8Type {
using type = int8_t;
static constexpr TF_DataType kDType = TF_INT8;
};
struct INT64Type {
using type = int64_t;
static constexpr TF_DataType kDType = TF_INT64;
};
struct UINT16Type {
using type = uint16_t;
static constexpr TF_DataType kDType = TF_UINT16;
};
struct UINT32Type {
using type = uint32_t;
static constexpr TF_DataType kDType = TF_UINT32;
};
struct UINT64Type {
using type = uint64_t;
static constexpr TF_DataType kDType = TF_UINT64;
};
using SimpleTypes =
::testing::Types<FloatType, DoubleType, Int32Type, UINT8Type, INT8Type,
INT64Type, UINT16Type, UINT32Type, UINT64Type>;
using SimpleTypes = ::testing::Types<
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
template <typename T>
class ConstructScalarTensorTest : public ::testing::Test {};
@ -88,14 +41,13 @@ TYPED_TEST_SUITE(ConstructScalarTensorTest, SimpleTypes);
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(ConstructScalarTensorTest, ValidTensorAttributesAfterConstruction) {
cc::Status status;
Status status;
TF_DataType dtype = TypeParam::kDType;
typename TypeParam::type value = 42;
cc::Tensor tensor =
cc::Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
/*data=*/&value,
/*len=*/sizeof(value),
/*deleter=*/[](void*, size_t) {}, &status);
Tensor tensor = Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
/*data=*/&value,
/*len=*/sizeof(value),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 0);
@ -113,7 +65,7 @@ TYPED_TEST_SUITE(Construct1DTensorTest, SimpleTypes);
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
cc::Status status;
Status status;
TF_DataType dtype = TypeParam::kDType;
// This is our 1D tensor of varying dtype.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
@ -121,7 +73,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
std::vector<int64_t> shape;
shape.push_back(value.size());
cc::Tensor tensor = cc::Tensor::FromBuffer(
Tensor tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type),
@ -130,7 +82,7 @@ TYPED_TEST(Construct1DTensorTest, ValidTensorAttributesAfterConstruction) {
EXPECT_EQ(tensor.dims(), 1);
EXPECT_EQ(tensor.dtype(), dtype);
gtl::ArraySlice<typename TypeParam::type> tensor_view(
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100);
@ -152,14 +104,14 @@ TYPED_TEST_SUITE(Construct2DTensorTest, SimpleTypes);
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
cc::Status status;
Status status;
TF_DataType dtype = TypeParam::kDType;
// This is our 1D tensor of varying dtype.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
// Shape is Rank 2 vector with shape 2 x 3.
std::vector<int64_t> shape({2, 3});
cc::Tensor tensor = cc::Tensor::FromBuffer(
Tensor tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type),
@ -169,7 +121,7 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
EXPECT_EQ(tensor.dims(), 2);
EXPECT_EQ(tensor.dtype(), dtype);
gtl::ArraySlice<typename TypeParam::type> tensor_view(
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100);
@ -185,22 +137,22 @@ TYPED_TEST(Construct2DTensorTest, ValidTensorAttributesAfterConstruction) {
TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
bool done = false;
cc::Status status;
Status status;
std::vector<int32_t> data_vector({12, 14, 20, 18, 39, 42, 100});
{
// data_vector is a rank 1 tensor.
std::vector<int64_t> shape;
shape.push_back(data_vector.size());
cc::Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
Tensor::DeleterCallback callback = [&done](void* data, size_t len) {
done = true;
};
cc::Tensor tensor =
cc::Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
/*data=*/data_vector.data(),
/*len=*/data_vector.size() * sizeof(int32_t),
/*deleter=*/callback, &status);
Tensor tensor =
Tensor::FromBuffer(/*dtype=*/TF_INT32, /*shape=*/shape,
/*data=*/data_vector.data(),
/*len=*/data_vector.size() * sizeof(int32_t),
/*deleter=*/callback, &status);
ASSERT_TRUE(status.ok()) << status.message();
}
// At this point, tensor has been destroyed, and the deleter callback should
@ -209,4 +161,3 @@ TEST(CPPTensorAPI, ConstructTensorFromBuffer) {
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
#include <stdint.h>
#include "tensorflow/c/tf_datatype.h"
namespace tensorflow {
// Each of the following struct types has two members: a kDType that
// corresponds to a TF_DataType enum value, and a typedef "type"
// of its corresponding C++ type. These types allow us to write Dtype-agnostic
// tests via GoogleTest's TypedTests:
// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests
struct FloatType {
using type = float;
static constexpr TF_DataType kDType = TF_FLOAT;
};
struct DoubleType {
using type = double;
static constexpr TF_DataType kDType = TF_DOUBLE;
};
struct Int32Type {
using type = int32_t;
static constexpr TF_DataType kDType = TF_INT32;
};
struct UINT8Type {
using type = uint8_t;
static constexpr TF_DataType kDType = TF_UINT8;
};
struct INT8Type {
using type = int8_t;
static constexpr TF_DataType kDType = TF_INT8;
};
struct INT64Type {
using type = int64_t;
static constexpr TF_DataType kDType = TF_INT64;
};
struct UINT16Type {
using type = uint16_t;
static constexpr TF_DataType kDType = TF_UINT16;
};
struct UINT32Type {
using type = uint32_t;
static constexpr TF_DataType kDType = TF_UINT32;
};
struct UINT64Type {
using type = uint64_t;
static constexpr TF_DataType kDType = TF_UINT64;
};
} // namespace tensorflow
#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_TEST_TENSOR_TYPES_TEST_UTIL_H_
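// Illustrative sketch (hypothetical suite name; assumes gtest and
// tf_datatype.h are included) of how these structs drive GoogleTest typed
// tests, with TF_DataTypeSize cross-checking each kDType:
//
// template <typename T>
// class DtypeSizeTest : public ::testing::Test {};
// using TwoTypes = ::testing::Types<tensorflow::FloatType, tensorflow::INT64Type>;
// TYPED_TEST_SUITE(DtypeSizeTest, TwoTypes);
// TYPED_TEST(DtypeSizeTest, SizeofMatchesDType) {
//   EXPECT_EQ(sizeof(typename TypeParam::type),
//             TF_DataTypeSize(TypeParam::kDType));
// }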

View File

@ -0,0 +1,184 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/experimental/base/public/tensorhandle.h"
#include <stddef.h>
#include <stdint.h>
#include <memory>
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
using tensorflow::experimental::cc::Runtime;
using tensorflow::experimental::cc::RuntimeBuilder;
using tensorflow::experimental::cc::Status;
using tensorflow::experimental::cc::Tensor;
using tensorflow::experimental::cc::TensorHandle;
using SimpleTypes = ::testing::Types<
tensorflow::FloatType, tensorflow::DoubleType, tensorflow::Int32Type,
tensorflow::UINT8Type, tensorflow::INT8Type, tensorflow::INT64Type,
tensorflow::UINT16Type, tensorflow::UINT32Type, tensorflow::UINT64Type>;
template <typename T>
class ConstructScalarTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(ConstructScalarTensorHandleTest, SimpleTypes);
// This test constructs a scalar tensor for each of the types in "SimpleTypes",
// then wraps it in a TensorHandle. We then unwrap it back into a Tensor, and
// verify the expected dims, dtype, value, num bytes, and num elements.
TYPED_TEST(ConstructScalarTensorHandleTest,
ValidTensorAttributesAfterConstruction) {
Status status;
RuntimeBuilder runtime_builder;
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
TF_DataType dtype = TypeParam::kDType;
typename TypeParam::type value = 42;
Tensor original_tensor =
Tensor::FromBuffer(/*dtype=*/dtype, /*shape=*/{},
/*data=*/&value,
/*len=*/sizeof(value),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
TensorHandle handle =
TensorHandle::FromTensor(original_tensor, *runtime, &status);
ASSERT_TRUE(status.ok()) << status.message();
Tensor tensor = handle.Resolve(&status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 0);
EXPECT_EQ(tensor.dtype(), dtype);
EXPECT_EQ(*reinterpret_cast<typename TypeParam::type*>(tensor.data()), 42);
EXPECT_EQ(tensor.num_bytes(), sizeof(typename TypeParam::type));
EXPECT_EQ(tensor.num_elements(), 1);
}
template <typename T>
class Construct1DTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(Construct1DTensorHandleTest, SimpleTypes);
// This test constructs a 1D tensor for each of the types in "SimpleTypes",
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct1DTensorHandleTest,
ValidTensorAttributesAfterConstruction) {
Status status;
RuntimeBuilder runtime_builder;
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
TF_DataType dtype = TypeParam::kDType;
// This is our 1D tensor of varying dtype.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
// Shape is Rank 1 vector.
std::vector<int64_t> shape;
shape.push_back(value.size());
Tensor original_tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
TensorHandle handle =
TensorHandle::FromTensor(original_tensor, *runtime, &status);
ASSERT_TRUE(status.ok()) << status.message();
Tensor tensor = handle.Resolve(&status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 1);
EXPECT_EQ(tensor.dtype(), dtype);
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100);
EXPECT_EQ(tensor_view[2], 0);
EXPECT_EQ(tensor_view[3], 1);
EXPECT_EQ(tensor_view[4], 4);
EXPECT_EQ(tensor_view[5], 29);
EXPECT_EQ(tensor.num_bytes(),
value.size() * sizeof(typename TypeParam::type));
EXPECT_EQ(tensor.num_elements(), value.size());
}
template <typename T>
class Construct2DTensorHandleTest : public ::testing::Test {};
TYPED_TEST_SUITE(Construct2DTensorHandleTest, SimpleTypes);
// This test constructs a 2D tensor for each of the types in "SimpleTypes",
// and verifies the expected dimensions, dtype, value, number of bytes, and
// number of elements.
TYPED_TEST(Construct2DTensorHandleTest,
ValidTensorAttributesAfterConstruction) {
Status status;
RuntimeBuilder runtime_builder;
std::unique_ptr<Runtime> runtime = runtime_builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
TF_DataType dtype = TypeParam::kDType;
// This is our 1D tensor of varying dtype.
std::vector<typename TypeParam::type> value = {42, 100, 0, 1, 4, 29};
// Shape is Rank 2 vector with shape 2 x 3.
std::vector<int64_t> shape({2, 3});
Tensor original_tensor = Tensor::FromBuffer(
/*dtype=*/dtype, /*shape=*/shape,
/*data=*/value.data(),
/*len=*/value.size() * sizeof(typename TypeParam::type),
/*deleter=*/[](void*, size_t) {}, &status);
ASSERT_TRUE(status.ok()) << status.message();
TensorHandle handle =
TensorHandle::FromTensor(original_tensor, *runtime, &status);
ASSERT_TRUE(status.ok()) << status.message();
Tensor tensor = handle.Resolve(&status);
ASSERT_TRUE(status.ok()) << status.message();
EXPECT_EQ(tensor.dims(), 2);
EXPECT_EQ(tensor.dtype(), dtype);
tensorflow::gtl::ArraySlice<typename TypeParam::type> tensor_view(
reinterpret_cast<typename TypeParam::type*>(tensor.data()), value.size());
EXPECT_EQ(tensor_view[0], 42);
EXPECT_EQ(tensor_view[1], 100);
EXPECT_EQ(tensor_view[2], 0);
EXPECT_EQ(tensor_view[3], 1);
EXPECT_EQ(tensor_view[4], 4);
EXPECT_EQ(tensor_view[5], 29);
EXPECT_EQ(tensor.num_bytes(),
value.size() * sizeof(typename TypeParam::type));
EXPECT_EQ(tensor.num_elements(), value.size());
}
} // namespace
} // namespace tensorflow

View File

@ -84,7 +84,7 @@ cc_library(
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
]) + if_android([
"//tensorflow/core:android_tensorflow_lib",
"//tensorflow/core:portable_tensorflow_lib",
]),
)

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// ConcreteFunction is an executable "function" loaded from a SavedModelAPI.
@ -54,6 +55,7 @@ inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() {
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// ConcreteFunctionList helps convert an opaque pointer to an array of
@ -56,6 +57,7 @@ inline std::vector<ConcreteFunction*> ConcreteFunctionList::ToVector() {
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// FunctionMetadata stores additional function information, including
@ -40,6 +41,7 @@ class FunctionMetadata final {
};
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h"
namespace tensorflow {
namespace experimental {
namespace cc {
// SavedModelAPI offers a way to load TensorFlow Saved Models
@ -155,6 +156,7 @@ inline std::vector<ConcreteFunction*> SavedModelAPI::ListFunctions() {
}
} // namespace cc
} // namespace experimental
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_

View File

@ -26,10 +26,14 @@ limitations under the License.
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
using tensorflow::experimental::cc::Runtime;
using tensorflow::experimental::cc::RuntimeBuilder;
using tensorflow::experimental::cc::SavedModelAPI;
using tensorflow::experimental::cc::Status;
constexpr char kTestData[] = "cc/saved_model/testdata";
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
@ -43,21 +47,21 @@ std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
class CPPSavedModelAPITest : public ::testing::TestWithParam<bool> {};
TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
cc::Status status;
cc::RuntimeBuilder builder;
Status status;
RuntimeBuilder builder;
bool use_tfrt = GetParam();
if (use_tfrt) {
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
builder.SetUseTFRT(use_tfrt);
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
std::unique_ptr<Runtime> runtime = builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
std::unordered_set<std::string> tags = {"serve"};
std::unique_ptr<cc::SavedModelAPI> model =
cc::SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
std::unique_ptr<SavedModelAPI> model =
SavedModelAPI::Load(model_dir, *runtime, &status, &tags);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*,
@ -67,20 +71,20 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModelWithTags) {
}
TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
cc::Status status;
cc::RuntimeBuilder builder;
Status status;
RuntimeBuilder builder;
bool use_tfrt = GetParam();
if (use_tfrt) {
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
}
builder.SetUseTFRT(use_tfrt);
std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
std::unique_ptr<Runtime> runtime = builder.Build(&status);
ASSERT_TRUE(status.ok()) << status.message();
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
std::unique_ptr<cc::SavedModelAPI> model =
cc::SavedModelAPI::Load(model_dir, *runtime, &status);
std::unique_ptr<SavedModelAPI> model =
SavedModelAPI::Load(model_dir, *runtime, &status);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*,
@ -94,4 +98,3 @@ INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,
} // namespace
} // namespace tensorflow

View File

@ -42,7 +42,8 @@ def tf_library(
mlir_components = "None",
deps = None,
tags = []):
"""Runs tfcompile to compile a TensorFlow graph into executable code.
"""Runs tfcompile to compile a TensorFlow graph into executable code with fast
math enabled on CPU.
Given an invocation of tf_library(name="foo", ...), generates the following
build targets:
@ -207,6 +208,15 @@ def tf_library(
srcs.append(debug_info)
debug_info_flag = " --debug_info=$(location " + debug_info + ")"
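# Prepend default fast-math XLA flags. Any user-supplied XLA_FLAGS value is
# appended last (via $${XLA_FLAGS:-}), so it can override these defaults.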
default_fast_math_xla_flags = ("XLA_FLAGS='" +
"--xla_cpu_enable_fast_math=true " +
"--xla_cpu_fast_math_honor_nans=false " +
"--xla_cpu_fast_math_honor_infs=false " +
"--xla_cpu_fast_math_honor_functions=false " +
"--xla_cpu_fast_math_honor_division=false " +
"--xla_cpu_enable_fast_min_max=true " +
"$${XLA_FLAGS:-}' ")
native.genrule(
name = ("gen_" + name),
srcs = srcs,
@ -216,6 +226,7 @@ def tf_library(
function_object_file,
],
cmd = (
default_fast_math_xla_flags +
"CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +
@ -256,6 +267,7 @@ def tf_library(
session_module_pb,
],
cmd = (
default_fast_math_xla_flags +
"CUDA_VISIBLE_DEVICES='' " +
"$(location " + tfcompile_tool + ")" +
" --graph=$(location " + tfcompile_graph + ")" +

View File

@ -67,6 +67,8 @@ int main(int argc, char** argv) {
flags.entry_point = "entry";
flags.debug_info_path_begin_marker = "";
// Note that tfcompile.bzl's tf_library macro sets fast math flags by default,
// as that is generally the preferred behavior.
std::vector<tensorflow::Flag> flag_list;
AppendMainFlags(&flag_list, &flags);
xla::AppendDebugOptionsFlags(&flag_list);

View File

@ -251,7 +251,7 @@ cc_library(
visibility = [":friends"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
"//tensorflow/core:portable_tensorflow_lib",
],
"//conditions:default": [
"//tensorflow/core:graph",

View File

@ -77,10 +77,6 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes",
"//tensorflow/compiler/mlir/tfrt:lower_tf_to_tfd_alwayslink",
"//tensorflow/compiler/mlir/tfrt:runtime_fallback_opdefs_alwayslink",
"//tensorflow/compiler/mlir/tfrt:tf_legalize_to_tfrt",
"//tensorflow/compiler/mlir/tfrt:tf_to_corert",
],
)
@ -152,7 +148,6 @@ tf_cc_binary(
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/compiler/mlir/tensorflow:translate_registration",
"//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op",
"//tensorflow/compiler/mlir/tfrt:compatibility_analysis",
"//tensorflow/compiler/mlir/xla:xla_mlir_translate",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -31,7 +31,7 @@ filegroup(
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
],
)

View File

@ -799,11 +799,6 @@ Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions(
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
std::string node_def_str;
if (!node_def.SerializeToString(&node_def_str)) {
return emitError(loc, "failed to serialize tensorflow node_def"),
llvm::None;
}
auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc);
return builder_.CreateVector(flex_builder->GetBuffer());
}
@ -813,9 +808,13 @@ Translator::CreateFlexBuilderWithNodeAttrs(
const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
auto flex_builder = absl::make_unique<flexbuffers::Builder>();
size_t map_start = flex_builder->StartMap();
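// Sort the NodeDef attributes by key so the serialized custom options are
// deterministic; proto map iteration order is unspecified.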
for (const auto& pair : node_def.attr()) {
using Item = std::pair<std::string, ::tensorflow::AttrValue>;
std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end());
std::sort(attrs.begin(), attrs.end(),
[](Item& p1, Item& p2) -> bool { return p1.first < p2.first; });
for (const Item& pair : attrs) {
const char* key = pair.first.c_str();
const auto& attr = pair.second;
const ::tensorflow::AttrValue& attr = pair.second;
switch (attr.value_case()) {
case ::tensorflow::AttrValue::kS:
flex_builder->String(key, attr.s());

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"

View File

@ -20,7 +20,7 @@ limitations under the License.
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffects.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td"
include "tensorflow/compiler/mlir/lite/quantization/quantization.td"
@ -414,9 +414,9 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
}];
let arguments = (
ins TFL_TensorOf<[F32, QI8, QUI8]>:$input,
ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input,
TFL_TensorOf<[F32, QI8, QUI8]>:$filter,
TFL_TensorOfOrNone<[F32, I32]>:$bias,
TFL_TensorOfOrNone<[F32, I32, I64]>:$bias,
I32Attr:$dilation_h_factor,
I32Attr:$dilation_w_factor,
TFL_AFAttr:$fused_activation_function,
@ -425,7 +425,7 @@ class TFL_ConvOp<string mnemonic, string opSummary, int index> :
I32Attr:$stride_w
);
let results = (outs TFL_TensorOf<[F32, QI8, QUI8]>:$output);
let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output);
let hasOptions = 0b1;
}
@ -1561,10 +1561,12 @@ def TFL_GreaterOp : TFL_Op<"greater", [
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
}
def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType,
TFL_GpuTargetOp]> {
def TFL_HardSwishOp: TFL_Op<"hard_swish", [
NoSideEffect,
SameOperandsAndResultShape,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_GpuTargetOp]> {
let summary = "Hardswish activation function.";
let description = [{
Computes hard-swish activation function
@ -1574,7 +1576,7 @@ def TFL_HardSwishOp: TFL_Op<"hard_swish", [NoSideEffect,
let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$input);
let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$out);
let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$output);
let hasOptions = 0;
}
@ -1606,7 +1608,8 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect,
def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [
SameOperandsAndResultShape,
NoSideEffect,
SameOperandsAndResultType]> {
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
let summary = "Leaky Relu operator";
let description = [{
@ -1740,7 +1743,8 @@ def TFL_LogOp: TFL_Op<"log", [
def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType,
PredOpTrait<"x and y must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
// zero_point = max_value
// scale = -log_softmax_output_min / (max_value + 1)
FixedResultScale<Int8UniformQuantizedType<127, 625, -4>>,
@ -1896,11 +1900,11 @@ Rounds the values of a tensor to the nearest integer, element-wise.
}];
let arguments = (ins
TFL_TensorOf<[F32]>:$x
TFL_FpTensor:$x
);
let results = (outs
TFL_TensorOf<[F32]>:$y
TFL_FpTensor:$y
);
}
@ -2443,9 +2447,9 @@ def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
Computes element-wise reverse square root of input
}];
let arguments = (ins AnyTensor:$x);
let arguments = (ins TFL_FpTensor:$x);
let results = (outs AnyTensor:$y);
let results = (outs TFL_FpTensor:$y);
let hasFolder = 1;
}
@ -3361,9 +3365,11 @@ def TFL_QuantizeOp: TFL_Op<"quantize", [
let results = (outs AnyTensor:$output);
}
def TFL_DensifyOp: TFL_Op<"densify", [NoSideEffect,
SameOperandsAndResultType,
NoQuantizableResult]> {
def TFL_DensifyOp: TFL_Op<"densify", [
NoSideEffect,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoQuantizableResult]> {
let summary = "Densify operator";
let description = [{

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mlir {
namespace lite {
@ -38,6 +39,7 @@ namespace lite {
TfLiteStatus QuantizeModel(
const tflite::ModelT& input_model, const tflite::TensorType& input_type,
const tflite::TensorType& output_type,
const tflite::TensorType& inference_type,
const std::unordered_set<std::string>& operator_names,
bool disable_per_channel, bool fully_quantize,
flatbuffers::FlatBufferBuilder* builder,
@ -73,7 +75,7 @@ TfLiteStatus QuantizeModel(
// Apply quantization passes
PassManager pm(module->getContext());
TFL::QuantizationSpecs quant_specs;
quant_specs.inference_type = tensorflow::DT_QINT8;
quant_specs.inference_type = tflite::TflTypeToTfType(inference_type);
quant_specs.post_training_quantization = true;
quant_specs.disable_per_channel = disable_per_channel;
@ -81,8 +83,10 @@ TfLiteStatus QuantizeModel(
auto input_tf_type = tflite::TflTypeToTfType(input_type);
if (input_tf_type == tensorflow::DT_FLOAT) {
emit_adaptor = true;
} else if (input_tf_type == tensorflow::DT_UINT8) {
quant_specs.inference_type = tensorflow::DT_QUINT8;
} else if (input_tf_type == tensorflow::DT_UINT8 ||
input_tf_type == tensorflow::DT_INT8 ||
input_tf_type == tensorflow::DT_INT16) {
quant_specs.inference_type = input_tf_type;
}
pm.addPass(TFL::CreatePrepareQuantizePass(quant_specs));

View File

@ -26,11 +26,13 @@ namespace mlir {
namespace lite {
// Quantize the `input_model` and write the result to a flatbuffer `builder`.
// The `input_type` and `output_type` can be float32/qint8/int8.
// The `input_type`, `output_type` and `inference_type` can be
// float32/qint8/int8/int16.
// Return partially quantized model if `fully_quantize` is false.
TfLiteStatus QuantizeModel(
const tflite::ModelT& input_model, const tflite::TensorType& input_type,
const tflite::TensorType& output_type,
const tflite::TensorType& inference_type,
const std::unordered_set<std::string>& operator_names,
bool disable_per_channel, bool fully_quantize,
flatbuffers::FlatBufferBuilder* builder,
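
For orientation, a minimal call sketch with the new `inference_type` parameter, mirroring the test harness later in this diff (`model` is assumed to be a parsed tflite::ModelT already in scope):

tflite::StderrReporter error_reporter;
flatbuffers::FlatBufferBuilder builder;
// Fully quantize to int8, using int8 for input, output, and inference.
TfLiteStatus status = mlir::lite::QuantizeModel(
    *model, tflite::TensorType_INT8, tflite::TensorType_INT8,
    /*inference_type=*/tflite::TensorType_INT8,
    /*operator_names=*/{},
    /*disable_per_channel=*/false,
    /*fully_quantize=*/true, &builder, &error_reporter);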

View File

@ -46,7 +46,8 @@ TfLiteStatus QuantizeAnnotatedModel(llvm::StringRef buffer,
tflite::StderrReporter error_reporter;
return mlir::lite::QuantizeModel(
*model, tflite::TensorType_INT8, tflite::TensorType_INT8, {},
*model, tflite::TensorType_INT8, tflite::TensorType_INT8,
tflite::TensorType_INT8, {},
/*disable_per_channel=*/false,
/*fully_quantize=*/true, builder, &error_reporter);
}

View File

@ -90,7 +90,7 @@ struct QuantizationSpecs {
bool RunWeightQuantization() const { return weight_quantization; }
// Whether this inference type represents a signed storage type.
bool IsSignedInferenceType() {
bool IsSignedInferenceType() const {
switch (inference_type) {
case tensorflow::DT_QUINT8:
case tensorflow::DT_QUINT16:
@ -102,7 +102,7 @@ struct QuantizationSpecs {
// Gets the width of this quantization type. Returns 0 if it isn't a
// quantization type.
int64_t GetQuantizationTypeWidth() {
int64_t GetQuantizationTypeWidth() const {
switch (inference_type) {
case tensorflow::DT_QINT8:
case tensorflow::DT_QUINT8:

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
@ -35,6 +36,7 @@ limitations under the License.
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
namespace mlir {
@ -363,6 +365,54 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
}
};
// Fold extra Requantize ops if the preceding op has a free scale requirement.
template <typename RQ>
struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
explicit FoldTrivalRequantizeOp(MLIRContext* context)
: OpRewritePattern<RQ>(context, 1) {}
LogicalResult matchAndRewrite(RQ op,
PatternRewriter& rewriter) const override {
Value pre_quantized = op.input();
auto pre_quantized_type =
quant::QuantizedType::getQuantizedElementType(pre_quantized.getType());
if (!pre_quantized_type) return failure();
Operation* def = pre_quantized.getDefiningOp();
if (!def) return failure();
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
return failure();
}
op.emitWarning("Remove trivial `rescale` op. Please fix the source graph.");
llvm::SmallVector<Type, 4> new_output_types;
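// Adopt the requantized type only for the result consumed solely by this op;
// all other results of the preceding op keep their original types.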
for (auto result : def->getResults()) {
if (result.hasOneUse() && *result.getUsers().begin() == op) {
new_output_types.push_back(op.qtype());
} else {
new_output_types.push_back(result.getType());
}
}
// Remove this rescale op.
rewriter.replaceOp(op, {pre_quantized});
// Replace the output scale of the preceding op.
rewriter.setInsertionPointAfter(def);
OperationState new_state(def->getLoc(), def->getName().getStringRef(),
def->getOperands(), new_output_types,
def->getAttrs());
Operation* new_op = rewriter.createOperation(new_state);
rewriter.replaceOp(def, new_op->getResults());
return success();
}
};
// Given a quantized type `input`, magnifies its scales by the factor stored in
// `factor`. If `input` isn't a quantized type or the `factor` doesn't match the
// dimension size of `input` or isn't floating-point, nullptr will be returned.

View File

@ -11,9 +11,9 @@ func @reshape_removeAdjacent(tensor<4x4x4xf32>) -> tensor<64xf32> {
return %1 : tensor<64xf32>
// CHECK-LABEL: func @reshape_removeAdjacent
// CHECK: %cst = constant dense<64> : tensor<1xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return
// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32>
// CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return %[[RESHAPE]]
}
// Checks that tfl.reshape should be removed if its output has more than one
@ -29,11 +29,11 @@ func @reshape_removeAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> tensor<64xf32>
return %3 : tensor<64xf32>
// CHECK-LABEL: func @reshape_removeAdjacentWithMultipleUse
// CHECK: %cst = constant dense<64> : tensor<1xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %1 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %2 = addf %0, %1
// CHECK: return %2
// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32>
// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: %[[RESULT:.*]] = addf %[[RESHAPE_1]], %[[RESHAPE_2]]
// CHECK: return %[[RESULT]]
}
// Checks that tfl.reshape should be kept if its output has more than one
@ -47,11 +47,11 @@ func @reshape_keepAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> (tensor<16x4xf32
return %0, %1 : tensor<16x4xf32>, tensor<64xf32>
// CHECK-LABEL: func @reshape_keepAdjacentWithMultipleUse
// CHECK: %cst = constant dense<[16, 4]> : tensor<2xi32>
// CHECK: %cst_0 = constant dense<64> : tensor<1xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32>
// CHECK: %1 = "tfl.reshape"(%arg0, %cst_0) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return %0, %1
// CHECK: %[[CST:.*]] = constant dense<[16, 4]> : tensor<2xi32>
// CHECK: %[[CST_0:.*]] = constant dense<64> : tensor<1xi32>
// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32>
// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST_0]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32>
// CHECK: return %[[RESHAPE_1]], %[[RESHAPE_2]]
}
// Checks that tfl.reshape should be removed if its output type is the same

View File

@ -8,13 +8,13 @@ func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>,
%2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<3.500000e+00> : tensor<4xf32>
// CHECK: %cst_0 = constant dense<-5.000000e-01> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<6.000000e+00> : tensor<f32>
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
// CHECK: %cst_3 = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %cst_4 = constant dense<3.000000e+00> : tensor<4xf32>
// CHECK: %0 = tfl.add %cst, %cst_0 {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
// CHECK: %[[CST:.*]] = constant dense<3.500000e+00> : tensor<4xf32>
// CHECK: %[[CST_0:.*]] = constant dense<-5.000000e-01> : tensor<4xf32>
// CHECK: %[[CST_1:.*]] = constant dense<6.000000e+00> : tensor<f32>
// CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_3:.*]] = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_4:.*]] = constant dense<3.000000e+00> : tensor<4xf32>
// CHECK: %0 = tfl.add %[[CST]], %[[CST_0]] {fused_activation_function = "SIGN_BIT"} : tensor<4xf32>
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@ -33,10 +33,10 @@ func @add_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
%2 = constant dense< 4> : tensor<4xi32>
%3 = constant dense<-2> : tensor<4xi32>
// CHECK: %cst = constant dense<9> : tensor<i32>
// CHECK: %cst_0 = constant dense<6> : tensor<4xi32>
// CHECK: %cst_1 = constant dense<5> : tensor<4xi32>
// CHECK: %cst_2 = constant dense<2> : tensor<4xi32>
// CHECK: %[[CST:.*]] = constant dense<9> : tensor<i32>
// CHECK: %[[CST_0:.*]] = constant dense<6> : tensor<4xi32>
// CHECK: %[[CST_1:.*]] = constant dense<5> : tensor<4xi32>
// CHECK: %[[CST_2:.*]] = constant dense<2> : tensor<4xi32>
%5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
%6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
@ -54,10 +54,10 @@ func @sub_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
%2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<3.000000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<2.000000e+00> : tensor<4xf32>
// CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32>
// CHECK: %[[CST:.*]] = constant dense<3.000000e+00> : tensor<f32>
// CHECK: %[[CST_0:.*]] = constant dense<5.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_1:.*]] = constant dense<2.000000e+00> : tensor<4xf32>
// CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32>
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@ -75,10 +75,10 @@ func @sub_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
%2 = constant dense< 4> : tensor<4xi32>
%3 = constant dense<-2> : tensor<4xi32>
// CHECK: %cst = constant dense<7> : tensor<i32>
// CHECK: %cst_0 = constant dense<10> : tensor<4xi32>
// CHECK: %cst_1 = constant dense<3> : tensor<4xi32>
// CHECK: %cst_2 = constant dense<6> : tensor<4xi32>
// CHECK: %[[CST:.*]] = constant dense<7> : tensor<i32>
// CHECK: %[[CST_0:.*]] = constant dense<10> : tensor<4xi32>
// CHECK: %[[CST_1:.*]] = constant dense<3> : tensor<4xi32>
// CHECK: %[[CST_2:.*]] = constant dense<6> : tensor<4xi32>
%5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32>
%6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32>
@ -96,10 +96,10 @@ func @mul_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
%2 = constant dense< 3.5> : tensor<4xf32>
%3 = constant dense<-0.5> : tensor<4xf32>
// CHECK: %cst = constant dense<6.750000e+00> : tensor<f32>
// CHECK: %cst_0 = constant dense<-2.250000e+00> : tensor<4xf32>
// CHECK: %cst_1 = constant dense<5.250000e+00> : tensor<4xf32>
// CHECK: %cst_2 = constant dense<-1.750000e+00> : tensor<4xf32>
// CHECK: %[[CST:.*]] = constant dense<6.750000e+00> : tensor<f32>
// CHECK: %[[CST_0:.*]] = constant dense<-2.250000e+00> : tensor<4xf32>
// CHECK: %[[CST_1:.*]] = constant dense<5.250000e+00> : tensor<4xf32>
// CHECK: %[[CST_2:.*]] = constant dense<-1.750000e+00> : tensor<4xf32>
%5 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32>
%6 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32>
@ -170,8 +170,8 @@ func @add_dense_splat_int() -> tensor<4xi32> {
return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_splat_dense_int
@ -183,8 +183,8 @@ func @add_splat_dense_int() -> tensor<4xi32> {
return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_int_same_shape
@ -196,8 +196,8 @@ func @add_dense_dense_int_same_shape() -> tensor<4xi32> {
return %2 : tensor<4xi32>
// CHECK: %cst = constant dense<[5, 22, -2, 98]> : tensor<4xi32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[5, 22, -2, 98]> : tensor<4xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_int_trailing_dim
@ -212,10 +212,10 @@ func @add_dense_dense_int_trailing_dim() -> (tensor<2x2xi32>, tensor<2x2x2xi32>,
return %0, %1, %2 : tensor<2x2xi32>, tensor<2x2x2xi32>, tensor<2x2x2xi32>
// CHECK: %cst = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32>
// CHECK: %cst_0 = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32>
// CHECK: %cst_1 = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32>
// CHECK: return %cst, %cst_0, %cst_1
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32>
// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32>
// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32>
// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]]
}
// CHECK-LABEL: @add_dense_dense_int_mixing_1_n
@ -226,8 +226,8 @@ func @add_dense_dense_int_mixing_1_n() -> tensor<2x2xi32> {
%0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
// CHECK: %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_splat_float
@ -239,8 +239,8 @@ func @add_dense_splat_float() -> tensor<4xf32> {
return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_splat_dense_float
@ -252,8 +252,8 @@ func @add_splat_dense_float() -> tensor<4xf32> {
return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_float_same_shape
@ -265,8 +265,8 @@ func @add_dense_dense_float_same_shape() -> (tensor<4xf32>) {
return %2 : tensor<4xf32>
// CHECK: %cst = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_float_trailing_dim
@ -281,10 +281,10 @@ func @add_dense_dense_float_trailing_dim() -> (tensor<2x2xf32>, tensor<2x2x2xf32
return %0, %1, %2 : tensor<2x2xf32>, tensor<2x2x2xf32>, tensor<2x2x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32>
// CHECK: %cst_0 = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32>
// CHECK: %cst_1 = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32>
// CHECK: return %cst, %cst_0, %cst_1
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32>
// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32>
// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32>
// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]]
}
// CHECK-LABEL: @add_dense_dense_float_mixfng_1_n
@ -296,24 +296,24 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
return %0 : tensor<2x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @rank
func @rank() -> tensor<1xi32> {
%cst = constant dense<[[1], [2]]> : tensor<2x1xi32>
// CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return %[[CST]]
%0 = "tfl.rank"(%cst) : (tensor<2x1xi32>) -> tensor<1xi32>
return %0 : tensor<1xi32>
}
// CHECK-LABEL: @rank_input_known_rank
func @rank_input_known_rank(%arg0 : tensor<2x1xi32>) -> tensor<1xi32> {
// CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32>
// CHECK: return %[[CST]]
%0 = "tfl.rank"(%arg0) : (tensor<2x1xi32>) -> tensor<1xi32>
return %0 : tensor<1xi32>
}
@ -323,8 +323,8 @@ func @reshape() -> tensor<4xi32> {
%input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
%shape = constant dense<[4]> : tensor<1xi32>
// CHECK: [[cst:%.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
// CHECK: return %[[CST]]
%0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -334,8 +334,8 @@ func @reshape_dynamic_output() -> tensor<?xi32> {
%input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
%shape = constant dense<[4]> : tensor<1xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
%0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -343,8 +343,8 @@ func @reshape_dynamic_output() -> tensor<?xi32> {
// CHECK-LABEL: @pseudo_const
func @pseudo_const() -> tensor<i32> {
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<i32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<i32>
// CHECK: return %[[CST]]
%0 = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
return %0 : tensor<i32>
}
@ -356,8 +356,8 @@ func @range_int() -> tensor<?xi32> {
%cst_1 = constant dense<4> : tensor<i32>
%cst_2 = constant dense<1> : tensor<i32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -368,8 +368,8 @@ func @range_float() -> tensor<?xf32> {
%cst_1 = constant dense<4.0> : tensor<f32>
%cst_2 = constant dense<1.0> : tensor<f32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -381,8 +381,8 @@ func @range_float_neg_delta() -> tensor<?xf32> {
%cst_1 = constant dense<-4.0> : tensor<f32>
%cst_2 = constant dense<-1.0> : tensor<f32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -393,8 +393,8 @@ func @range_float_nonzero_base() -> tensor<?xf32> {
%cst_1 = constant dense<7.0> : tensor<f32>
%cst_2 = constant dense<1.5> : tensor<f32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32>
// CHECK: return %[[CST]]
%0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@ -414,8 +414,8 @@ func @transpose_1d() -> tensor<3xi32> {
%cst = constant dense<[1, 2, 3]> : tensor<3xi32>
%cst_perm = constant dense<0> : tensor<1xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
@ -425,8 +425,8 @@ func @transpose_dynamic() -> tensor<?xi32> {
%cst = constant dense<[1, 2, 3]> : tensor<3xi32>
%cst_perm = constant dense<0> : tensor<1xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -436,8 +436,8 @@ func @transpose_2d() -> tensor<2x2xi32> {
%cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
%cst_perm = constant dense<[1, 0]> : tensor<2xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -447,8 +447,8 @@ func @transpose_2d_identity() -> tensor<2x2xi32> {
%cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
%cst_perm = constant dense<[0, 1]> : tensor<2xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -460,8 +460,8 @@ func @transpose_3d() -> tensor<4x2x3xi32> {
%cst = constant dense<[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]> : tensor<2x3x4xi32>
%cst_perm = constant dense<[2, 0, 1]> : tensor<3xi32>
// CHECK: [[cst:%.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32>
// CHECK: return %[[CST]]
%0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x2x3xi32>
return %0 : tensor<4x2x3xi32>
}
@ -473,8 +473,8 @@ func @ConstantFoldBinaryOpDynamicOutput() -> tensor<?xi32> {
%87 = "tfl.sub"(%cst_0, %cst) {fused_activation_function = "NONE"} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
return %87 : tensor<?xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @add_dense_dense_int_same_shape_dynamic
@ -486,8 +486,8 @@ func @add_dense_dense_int_same_shape_dynamic() -> tensor<?xi32> {
return %2 : tensor<?xi32>
// CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return [[cst]]
// CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor<?xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @concat_2_tensors_1_empty
@ -497,8 +497,8 @@ func @concat_2_tensors_1_empty() -> tensor<2xi32> {
%3 = "tfl.concatenation"(%1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<0xi32>) -> tensor<2xi32>
return %3 : tensor<2xi32>
// CHECK: [[cst:%.*]] = constant dense<1> : tensor<2xi32>
// CHECK: return [[cst]] : tensor<2xi32>
// CHECK: %[[CST:.*]] = constant dense<1> : tensor<2xi32>
// CHECK: return %[[CST]] : tensor<2xi32>
}
// CHECK-LABEL: @concat_3_tensors_1_empty
@ -509,7 +509,7 @@ func @concat_3_tensors_1_empty() -> tensor<?xi32> {
%3 = "tfl.concatenation"(%0, %1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>, tensor<0xi32>) -> tensor<?xi32>
return %3 : tensor<?xi32>
// CHECK: %0 = "tfl.concatenation"(%cst, %cst) {axis = 0 : i32, fused_activation_function = "NONE"}
// CHECK: %0 = "tfl.concatenation"(%[[CST]], %[[CST]]) {axis = 0 : i32, fused_activation_function = "NONE"}
// CHECK: return %0 : tensor<?xi32>
}
@ -520,10 +520,10 @@ func @concatConstantTensorsFirstDim() -> tensor<2x2x3xi32> {
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<2x2x3xi32>
return %0 : tensor<2x2x3xi32>
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32>
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32>
// CHECK-NOT: constant-dense
// CHECK-NOT: "tfl.concatenation"
// CHECK: return [[cst]]
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @concatConstantTensorsMiddleDim
@ -533,10 +533,10 @@ func @concatConstantTensorsMiddleDim() -> tensor<1x4x3xi32> {
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x4x3xi32>
return %0 : tensor<1x4x3xi32>
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32>
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32>
// CHECK-NOT: constant-dense
// CHECK-NOT: "tfl.concatenation"
// CHECK: return [[cst]]
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @concatConstantTensorsLastDim
@ -546,10 +546,10 @@ func @concatConstantTensorsLastDim() -> tensor<1x2x6xi32> {
%0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 2 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x2x6xi32>
return %0 : tensor<1x2x6xi32>
// CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32>
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32>
// CHECK-NOT: constant-dense
// CHECK-NOT: "tfl.concatenation"
// CHECK: return [[cst]]
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @div_dense_dense_float_mixfng_1_n
@ -561,8 +561,8 @@ func @div_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> {
return %0 : tensor<2x2xf32>
// CHECK: %cst = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @div_dense_different_rank
@ -574,6 +574,6 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> {
return %0 : tensor<1x2x2xf32>
// CHECK: %cst = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
// CHECK: return %cst
// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32>
// CHECK: return %[[CST]]
}

View File

@ -1048,6 +1048,15 @@ func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}
func @concatv2I64Axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i64> } : () -> tensor<i64>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i64>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
// CHECK-LABEL: concatv2I64Axis
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}
func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
%0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>

View File

@ -65,7 +65,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: opcode_index: 1,
// CHECK-NEXT: inputs: [ 2, 1 ],
// CHECK-NEXT: outputs: [ 3 ],
// CHECK-NEXT: custom_options: [ 105, 110, 116, 95, 97, 116, 116, 114, 0, 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 2, 33, 43, 2, 1, 2, 11, 2, 20, 4, 4, 36, 1 ]
// CHECK-NEXT: custom_options: [ 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 105, 110, 116, 95, 97, 116, 116, 114, 0, 2, 42, 11, 2, 1, 2, 20, 2, 20, 4, 4, 36, 1 ]
// CHECK-NEXT: }, {
// CHECK-NEXT: opcode_index: 2,
// CHECK-NEXT: inputs: [ 3 ],

View File

@ -19,6 +19,16 @@ func @RemoveUnused(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> (tensor<2xf32>,t
// CHECK-NEXT: return %[[split]]#0, %[[split]]#1
}
// CHECK-LABEL: RemoveTrival
func @RemoveTrival(%arg0: tensor<384x512x!quant.uniform<i8:f32, 1.0:-128>>, %arg1: tensor<128x512x!quant.uniform<i8<-127:127>:f32, 1.0>>, %arg2: none) -> tensor<384x128x!quant.uniform<i8:f32, 2.0>> {
%1 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<384x512x!quant.uniform<i8:f32, 1.0:-128>>, tensor<128x512x!quant.uniform<i8<-127:127>:f32, 1.0>>, none) -> tensor<384x128x!quant.uniform<i8:f32, 1.0>>
%2 = "tfl.quantize"(%1) {qtype = tensor<384x128x!quant.uniform<i8:f32, 2.0>>} : (tensor<384x128x!quant.uniform<i8:f32, 1.0>>) -> tensor<384x128x!quant.uniform<i8:f32, 2.0>>
return %2 : tensor<384x128x!quant.uniform<i8:f32, 2.0>>
// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"{{.*}} -> tensor<384x128x!quant.uniform<i8:f32, 2.000000e+00>>
// CHECK-NEXT: return %[[fc]]
}
func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
%cst = constant dense<[1, 1001]> : tensor<2xi32>
%0 = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>

View File

@ -48,7 +48,8 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
quant_specs.default_ranges.second.hasValue()) {
pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass(
quant_specs.default_ranges.first.getValueOr(0.0),
quant_specs.default_ranges.second.getValueOr(0.0)));
quant_specs.default_ranges.second.getValueOr(0.0),
quant_specs.IsSignedInferenceType()));
pass_manager->addPass(mlir::TFL::CreateQuantizePass());
pass_manager->addPass(
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));

View File

@ -46,8 +46,11 @@ namespace {
class DefaultQuantParamsPass
: public PassWrapper<DefaultQuantParamsPass, FunctionPass> {
public:
explicit DefaultQuantParamsPass(double default_min, double default_max)
: default_min_(default_min), default_max_(default_max) {}
explicit DefaultQuantParamsPass(double default_min, double default_max,
bool is_signed)
: default_min_(default_min),
default_max_(default_max),
is_signed_(is_signed) {}
void runOnFunction() override;
@ -82,6 +85,7 @@ class DefaultQuantParamsPass
double default_min_;
double default_max_;
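// Whether the default quantized storage type should be signed.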
bool is_signed_;
quant::QuantParams default_quant_params_;
};
} // namespace
@ -214,15 +218,16 @@ quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(
default_quant_params_ = quant::fakeQuantAttrsToType(
builder.getUnknownLoc(),
/*numBits=*/8, default_min_, default_max_, /*narrowRange=*/false,
builder.getF32Type());
builder.getF32Type(), is_signed_);
}
return default_quant_params_;
}
// Creates an instance of the default quant parameters pass.
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
double default_min, double default_max) {
return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max);
double default_min, double default_max, bool is_signed) {
return absl::make_unique<DefaultQuantParamsPass>(default_min, default_max,
is_signed);
}
// Registers this pass with default values, only for test
@ -230,7 +235,8 @@ static PassRegistration<DefaultQuantParamsPass> pass(
"tfl-default-quant",
"Apply quantization with default quantization parameter", [] {
return CreateDefaultQuantParamsPass(/*default_min=*/-1.0,
/*default_max=*/1.0);
/*default_max=*/1.0,
/*is_signed=*/false);
});
} // namespace TFL

View File

@ -321,7 +321,8 @@ void DenseToSparse::runOnFunction() {
if (result.needs_densify) {
const auto value = op->getOperand(operand);
auto densify = builder.create<DensifyOp>(op->getLoc(), value);
auto densify =
builder.create<DensifyOp>(op->getLoc(), value.getType(), value);
value.replaceAllUsesWith(densify);
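// replaceAllUsesWith also rewrote densify's own operand, so restore it to
// the original value to avoid a self-cycle.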
densify.setOperand(value);
}

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
@ -202,6 +203,26 @@ LogicalResult ConvertTFConcatOp::matchAndRewrite(
return success();
}
// Converts any IntegerAttr to an IntegerAttr of an i32 type.
// The value won't change in the new attribute, but if the value is out of
// the range of i32, the function returns failure.
LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
if (attr.getType().isInteger(/*width=*/32)) {
*attr_i32 = attr;
return success();
}
int64_t value = attr.getInt();
if (value > std::numeric_limits<int>::max() ||
value < std::numeric_limits<int>::min()) {
return failure();
}
*attr_i32 = IntegerAttr::get(
IntegerType::get(/*width=*/32, attr.getContext()), value);
return success();
}
LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast<TF::ConcatV2Op>(op);
@ -211,12 +232,16 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
// Extract axis attribute from constant axis tensor
ElementsAttr axis;
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure();
IntegerAttr axis_int = ExtractSingleElementAsInteger(axis);
// The "axis" operand could be an i64 tensor. Resolve it to i32 here.
IntegerAttr axis_i32;
if (failed(ConvertToI32Attr(axis_int, &axis_i32))) return failure();
StringAttr fused_activation_function =
StringAttr::get("NONE", rewriter.getContext());
rewriter.replaceOpWithNewOp<ConcatenationOp>(
op, output_type, values, ExtractSingleElementAsInteger(axis),
fused_activation_function);
op, output_type, values, axis_i32, fused_activation_function);
return success();
}

View File

@ -859,6 +859,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
target.addLegalOp<ConstantOp>();
target.addLegalOp<FuncOp>();
target.addLegalOp<ReturnOp>();
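// Mark TFL custom ops legal so this conversion leaves them untouched.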
target.addLegalOp<TFL::CustomOp>();
// Register fused LSTM/RNN ops as legal.
target.addLegalOp<TFL::LSTMOp>();
target.addLegalOp<TFL::UnidirectionalSequenceLSTMOp>();

View File

@ -76,7 +76,7 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass();
// Creates an instance of the TensorFlow Lite dialect pass to add default
// quantization parameters.
std::unique_ptr<OperationPass<FuncOp>> CreateDefaultQuantParamsPass(
double default_min, double default_max);
double default_min, double default_max, bool is_signed);
// Creates an instance of the TensorFlow Lite dialect pass to convert dense
// tensor to sparse format.

View File

@ -125,6 +125,7 @@ void PostQuantizePass::runOnFunction() {
auto func = getFunction();
auto* ctx = func.getContext();
TFL::populateWithGenerated(ctx, &patterns);
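// Fold away trivial requantize ops whose producer can absorb the new scale.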
patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
applyPatternsAndFoldGreedily(func, patterns);
if (!emit_quant_adaptor_ops_) {

View File

@ -70,6 +70,7 @@ class PrepareQuantizePass
: public PassWrapper<PrepareQuantizePass, FunctionPass> {
public:
// Constructor used by the PassRegistration; enforces uint8 quantization.
// This is only used by tests.
explicit PrepareQuantizePass() {
if (quantize_signed)
quant_specs_.inference_type = tensorflow::DT_QINT8;
@ -257,15 +258,16 @@ void PrepareQuantizePass::runOnFunction() {
// convert all of them to signed.
OwningRewritePatternList patterns;
bool is_signed = quant_specs_.IsSignedInferenceType();
int bit_width = quant_specs_.GetQuantizationTypeWidth();
if (is_signed) {
patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
// Convert quant stats to signed quantization parameters with the given bit width.
// Currently, only activation stats are imported, so narrow_range = false.
patterns.insert<PrepareQuantStats>(8, false, true, ctx);
patterns.insert<PrepareQuantStats>(bit_width, false, true, ctx);
} else {
// Convert quant stats to unsigned quantization parameters with the given bit width.
// Currently, only activation stats are imported, so narrow_range = false.
patterns.insert<PrepareQuantStats>(8, false, false, ctx);
patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
}
applyPatternsAndFoldGreedily(func, patterns);

View File

@ -0,0 +1,41 @@
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
package(licenses = ["notice"])
tf_python_pybind_extension(
name = "mlir_wrapper",
srcs = [
"attrs.cc",
"basic_classes.cc",
"builders.cc",
"mlir_wrapper.cc",
"mlir_wrapper.h",
"ops.cc",
"types.cc",
],
module_name = "mlir_wrapper",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/python:pybind11_lib",
"//tensorflow/python:pybind11_status",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@pybind11",
],
)
tf_python_pybind_extension(
name = "filecheck_wrapper",
srcs = ["filecheck_wrapper.cc"],
module_name = "filecheck_wrapper",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/python:pybind11_lib",
"//tensorflow/python:pybind11_status",
"@llvm-project//llvm:support",
"@pybind11",
],
)

View File

@ -0,0 +1,25 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
void init_attrs(py::module& m) {
py::class_<mlir::Attribute>(m, "Attribute");
py::class_<mlir::IntegerAttr, mlir::Attribute>(m, "IntegerAttr")
.def("get",
py::overload_cast<mlir::Type, int64_t>(&mlir::IntegerAttr::get));
}

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/FileCheck.h"
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Region.h" // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
void init_basic_classes(py::module& m) {
py::class_<mlir::MLIRContext>(m, "MLIRContext").def(py::init<>());
py::class_<mlir::Location>(m, "Location");
py::class_<mlir::UnknownLoc>(m, "UnknownLoc")
.def("get", &mlir::UnknownLoc::get);
py::class_<mlir::Region>(m, "Region")
.def("back", &mlir::Region::back, py::return_value_policy::reference)
.def("front", &mlir::Region::front, py::return_value_policy::reference)
.def("add_block", [](mlir::Region& r) { r.push_back(new mlir::Block); })
.def("push_back", &mlir::Region::push_back)
.def("size", [](mlir::Region& r) { return r.getBlocks().size(); })
.def("front", &mlir::Region::front, py::return_value_policy::reference);
py::class_<mlir::Block::iterator>(m, "Block_Iterator");
py::class_<mlir::Block>(m, "Block")
.def("new", ([]() { return new mlir::Block; }),
py::return_value_policy::reference)
.def("end", &mlir::Block::end)
.def("addArgument", &mlir::Block::addArgument);
py::class_<mlir::Value>(m, "Value").def("getType", &mlir::Value::getType);
py::class_<mlir::OpResult, mlir::Value>(m, "OpResult");
py::class_<mlir::BlockArgument, mlir::Value>(m, "BlockArgument");
}
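A short sketch of the context/location/block plumbing these bindings expose (import name assumed, as in the other sketches in this change):

import mlir_wrapper as mlir  # hypothetical import path

ctx = mlir.MLIRContext()
loc = mlir.UnknownLoc.get(ctx)  # a Location carrying no source information
blk = mlir.Block.new()          # heap-allocates a Block; hand it to a Region
                                # via push_back/add_block, which takes ownership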

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/Builders.h" // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
void init_builders(py::module& m) {
py::class_<mlir::Builder>(m, "Builder")
.def(py::init<mlir::MLIRContext*>())
.def("getFunctionType",
[](mlir::Builder& b, std::vector<mlir::Type> inputs,
std::vector<mlir::Type> outputs) {
return b.getFunctionType(llvm::ArrayRef<mlir::Type>(inputs),
llvm::ArrayRef<mlir::Type>(outputs));
});
py::class_<mlir::OpBuilder>(m, "OpBuilder")
.def(py::init<mlir::MLIRContext*>())
.def(py::init<mlir::Region&>())
.def(py::init<mlir::Operation*>())
.def(py::init<mlir::Block*, mlir::Block::iterator>())
.def("getUnknownLoc", &mlir::OpBuilder::getUnknownLoc)
.def("setInsertionPoint",
py::overload_cast<mlir::Block*, mlir::Block::iterator>(
&mlir::OpBuilder::setInsertionPoint))
.def("saveInsertionPoint", &mlir::OpBuilder::saveInsertionPoint)
.def("restoreInsertionPoint", &mlir::OpBuilder::restoreInsertionPoint)
.def(
"createOperation",
[](mlir::OpBuilder& opb, mlir::OperationState& state) {
return opb.createOperation(state);
},
py::return_value_policy::reference)
.def("getContext", &mlir::OpBuilder::getContext,
py::return_value_policy::reference);
py::class_<mlir::OpBuilder::InsertPoint>(m, "OpBuilder_InsertionPoint")
.def("getBlock", &mlir::OpBuilder::InsertPoint::getBlock);
}
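For example, assembling a function type and an op builder from Python might look like the sketch below (import name assumed):

import mlir_wrapper as mlir  # hypothetical import path

ctx = mlir.MLIRContext()
b = mlir.Builder(ctx)
i32 = mlir.IntegerType.get(32, ctx)
ftype = b.getFunctionType([i32, i32], [i32])  # (i32, i32) -> i32
opb = mlir.OpBuilder(ctx)  # position later with setInsertionPoint(block, iterator)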

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/FileCheck.h"
#include "llvm/Support/SourceMgr.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
PYBIND11_MODULE(filecheck_wrapper, m) {
m.def("check", [](std::string input, std::string check) {
llvm::FileCheckRequest fcr;
llvm::FileCheck fc(fcr);
llvm::SourceMgr SM = llvm::SourceMgr();
SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input),
llvm::SMLoc());
SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(check),
llvm::SMLoc());
llvm::Regex regex = fc.buildCheckPrefixRegex();
fc.readCheckFile(SM, llvm::StringRef(check), regex);
return fc.checkInput(SM, llvm::StringRef(input));
});
}
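This gives Python tests a direct FileCheck entry point. A minimal sketch of its use; the import path is an assumption:

from tensorflow.compiler.mlir.python.mlir_wrapper import filecheck_wrapper  # assumed path

ir = "%0 = tf.AddV2(%arg0, %arg1)"
assert filecheck_wrapper.check(ir, "CHECK: tf.AddV2")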

View File

@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
PYBIND11_MODULE(mlir_wrapper, m) {
m.def("registerDialects", []() {
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
});
init_basic_classes(m);
init_types(m);
init_builders(m);
init_ops(m);
init_attrs(m);
}
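Putting the pieces together: register the dialects, build a function that adds its two tensor arguments, and print the module. A sketch under the same import-path assumption; insertion-point handling follows the bindings above:

from tensorflow.compiler.mlir.python.mlir_wrapper import mlir_wrapper as mlir  # assumed path

mlir.registerDialects()
ctx = mlir.MLIRContext()
builder = mlir.Builder(ctx)
i32 = mlir.IntegerType.get(32, ctx)
t = mlir.UnrankedTensorType.get(i32)
ftype = builder.getFunctionType([t, t], [t])
loc = mlir.UnknownLoc.get(ctx)

func = mlir.FuncOp.create(loc, "add", ftype)  # the binding also adds the entry block
blk = func.getBody().front()
opb = mlir.OpBuilder(ctx)
opb.setInsertionPoint(blk, blk.end())

args = func.getArguments()
add = mlir.Tf_AddV2Op.create(opb, loc, args[0], args[1])
mlir.ReturnOp.create(opb, loc, [add.getResult(0)])

module = mlir.ModuleOp.create(loc)
module.push_back(func)
print(module.getAsStr())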

View File

@ -13,19 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
//===- dialect_static_registration.cc -------------------------------------===//
//
// This file registers the RuntimeFallbackDialect.
//
//===----------------------------------------------------------------------===//
#ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H
#include "tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace mlir {
namespace tfd {
namespace py = pybind11;
// Static initialization for dialect registration.
static DialectRegistration<RuntimeFallbackDialect> tfd_registration;
void init_basic_classes(py::module& m);
void init_types(py::module& m);
void init_builders(py::module& m);
void init_ops(py::module& m);
void init_attrs(py::module& m);
} // namespace tfd
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H

View File

@ -0,0 +1,194 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
void init_ops(py::module& m) {
py::class_<mlir::Operation, std::unique_ptr<mlir::Operation, py::nodelete>>(
m, "Operation")
.def("getRegion", &mlir::Operation::getRegion,
py::return_value_policy::reference)
.def("getResult", &mlir::Operation::getResult)
.def("dump", &mlir::Operation::dump)
.def("getNumResults", &mlir::Operation::getNumResults);
py::class_<mlir::OperationState>(m, "OperationState")
.def(py::init([](mlir::Location loc, std::string name) {
return mlir::OperationState(loc, llvm::StringRef(name));
}))
.def("addTypes",
[](mlir::OperationState& state, std::vector<mlir::Type> tys) {
state.addTypes(mlir::ArrayRef<mlir::Type>(tys));
})
.def("addOperands",
[](mlir::OperationState& os, std::vector<mlir::Value> ops) {
os.addOperands(mlir::ArrayRef<mlir::Value>(ops));
})
.def("addRegion", py::overload_cast<>(&mlir::OperationState::addRegion),
py::return_value_policy::reference);
py::class_<mlir::ModuleOp>(m, "ModuleOp")
.def("create",
[](mlir::Location loc) { return mlir::ModuleOp::create(loc); })
.def("push_back",
[](mlir::ModuleOp& m, mlir::FuncOp f) { m.push_back(f); })
.def("dump", &mlir::ModuleOp::dump)
.def("getAsStr", [](mlir::ModuleOp& m) {
std::string str;
llvm::raw_string_ostream os(str);
m.print(os);
return os.str();
});
py::class_<mlir::FuncOp>(m, "FuncOp")
.def("create",
[](mlir::Location location, std::string name,
mlir::FunctionType type) {
auto func = mlir::FuncOp::create(location, name, type);
func.addEntryBlock();
return func;
})
.def(
"getBody",
[](mlir::FuncOp& f) -> mlir::Region& { return f.getBody(); },
py::return_value_policy::reference)
.def("getArguments",
[](mlir::FuncOp& f) { return f.getArguments().vec(); })
.def("getName", [](mlir::FuncOp& f) { return f.getName().str(); })
.def("getType", &mlir::FuncOp::getType);
py::class_<mlir::ReturnOp>(m, "ReturnOp")
.def("create",
[](mlir::OpBuilder& opb, mlir::Location loc,
std::vector<mlir::Value> values) -> mlir::Operation* {
return opb
.create<mlir::ReturnOp>(loc,
mlir::ArrayRef<mlir::Value>(values))
.getOperation();
});
// mlir::TF::AddOp
py::class_<mlir::TF::AddV2Op>(m, "Tf_AddV2Op")
.def("create",
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
mlir::Value y) -> mlir::Operation* {
return opb.create<mlir::TF::AddV2Op>(loc, x, y).getOperation();
});
py::class_<mlir::TF::AnyOp>(m, "Tf_AnyOp")
.def("create",
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value input,
mlir::Value reduction_indices,
bool keep_dims = false) -> mlir::Operation* {
return opb
.create<mlir::TF::AnyOp>(loc, opb.getI1Type(), input,
reduction_indices, keep_dims)
.getOperation();
});
// mlir::TF::ConstOp
py::class_<mlir::TF::ConstOp>(m, "Tf_ConstOp")
.def("create",
[](mlir::OpBuilder& opb, mlir::Location loc,
mlir::Attribute value) -> mlir::Operation* {
return opb.create<mlir::TF::ConstOp>(loc, value).getOperation();
});
// mlir::TF::EqualOp
py::class_<mlir::TF::EqualOp>(m, "Tf_EqualOp")
.def("create",
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
mlir::Value y) -> mlir::Operation* {
return opb
.create<mlir::TF::EqualOp>(loc, x, y, opb.getBoolAttr(true))
.getOperation();
});
// mlir::TF::GreaterEqualOp
py::class_<mlir::TF::GreaterEqualOp>(m, "Tf_GreaterEqualOp")
.def("create",
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
mlir::Value y) -> mlir::Operation* {
return opb.create<mlir::TF::GreaterEqualOp>(loc, x, y)
.getOperation();
});
// mlir::TF::GreaterOp
py::class_<mlir::TF::GreaterOp>(m, "Tf_GreaterOp")
.def("create",
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
mlir::Value y) -> mlir::Operation* {
return opb.create<mlir::TF::GreaterOp>(loc, x, y).getOperation();
});
// mlir::TF::LegacyCallOp
py::class_<mlir::TF::LegacyCallOp>(m, "Tf_LegacyCallOp")
.def("create",
[](mlir::OpBuilder& opb, mlir::Location loc,
std::vector<mlir::Type> output, std::vector<mlir::Value> args,
std::string f) -> mlir::Operation* {
return opb
.create<mlir::TF::LegacyCallOp>(
loc, mlir::ArrayRef<mlir::Type>(output),
mlir::ArrayRef<mlir::Value>(args), mlir::StringRef(f))
.getOperation();
});
// mlir::TF::LessEqualOp
py::class_<mlir::TF::LessEqualOp>(m, "Tf_LessEqualOp")
.def("create",
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
mlir::Value y) -> mlir::Operation* {
return opb.create<mlir::TF::LessEqualOp>(loc, x, y).getOperation();
});
// mlir::TF::LessOp
py::class_<mlir::TF::LessOp>(m, "Tf_LessOp")
.def("create",
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
mlir::Value y) -> mlir::Operation* {
return opb.create<mlir::TF::LessOp>(loc, x, y).getOperation();
});
// mlir::TF::NegOp
py::class_<mlir::TF::NegOp>(m, "Tf_NegOp")
.def("create",
[](mlir::OpBuilder& opb, mlir::Location loc,
mlir::Value x) -> mlir::Operation* {
return opb.create<mlir::TF::NegOp>(loc, x).getOperation();
});
py::class_<mlir::TF::NotEqualOp>(m, "Tf_NotEqualOp")
.def("create", [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
mlir::Value y) {
return opb
.create<mlir::TF::NotEqualOp>(
loc, x, y, mlir::BoolAttr::get(true, opb.getContext()))
.getOperation();
});
// mlir::TF::SubOp
py::class_<mlir::TF::SubOp>(m, "Tf_SubOp")
.def("create",
[](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
mlir::Value y) -> mlir::Operation* {
return opb.create<mlir::TF::SubOp>(loc, x, y).getOperation();
});
}
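Every op binding follows the same shape: a `create` wrapper around OpBuilder::create<...> that returns a generic Operation, whose results are then fetched by index. Continuing the (assumed) names from the sketch above:

const7 = mlir.Tf_ConstOp.create(opb, loc, mlir.IntegerAttr.get(i32, 7))
eq = mlir.Tf_EqualOp.create(opb, loc, const7.getResult(0), const7.getResult(0))
value = eq.getResult(0)  # results come off the generic Operation positionally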

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
void init_types(py::module& m) {
// Type
py::class_<mlir::Type> Type(m, "Type");
Type.def("getKind", &mlir::Type::getKind);
// Type Enums
py::enum_<mlir::StandardTypes::Kind>(Type, "StandardTypes_Kind")
.value("BF16", mlir::StandardTypes::BF16);
// Type Sub-classes
py::class_<mlir::FunctionType, mlir::Type>(m, "FunctionType")
.def("getResults",
[](mlir::FunctionType& ft) { return ft.getResults().vec(); });
py::class_<mlir::FloatType, mlir::Type>(m, "FloatType")
.def("get", &mlir::FloatType::get);
py::class_<mlir::IntegerType, mlir::Type>(m, "IntegerType")
.def("get", py::overload_cast<unsigned, mlir::MLIRContext*>(
&mlir::IntegerType::get));
py::class_<mlir::UnrankedTensorType, mlir::Type>(m, "UnrankedTensorType")
.def("get", &mlir::UnrankedTensorType::get);
py::class_<mlir::RankedTensorType, mlir::Type>(m, "RankedTensorType")
.def("get", [](std::vector<int64_t> shape, mlir::Type ty) {
return mlir::RankedTensorType::get(mlir::ArrayRef<int64_t>(shape), ty);
});
}
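From Python these read naturally; a small sketch (import name assumed, with -1 standing for MLIR's dynamic-dimension sentinel):

import mlir_wrapper as mlir  # hypothetical import path

ctx = mlir.MLIRContext()
i1 = mlir.IntegerType.get(1, ctx)
unranked = mlir.UnrankedTensorType.get(i1)       # tensor<*xi1>
ranked = mlir.RankedTensorType.get([2, -1], i1)  # tensor<2x?xi1>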

View File

@ -70,9 +70,9 @@ tool_dirs = config.mlir_tf_tools_dirs + [
]
tool_names = [
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate',
'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt',
'xla-opt'
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt'
]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs)

View File

@ -44,6 +44,7 @@ mlir_tf_tools_dirs = [
'tensorflow/compiler/mlir',
'tensorflow/compiler/mlir/lite',
'tensorflow/compiler/mlir/tensorflow',
'tensorflow/compiler/mlir/tfjs',
'tensorflow/compiler/mlir/xla',
'tensorflow/compiler/aot',
'tensorflow/compiler/xla/service/mlir_gpu',

View File

@ -36,7 +36,7 @@ filegroup(
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
],
)
@ -1075,7 +1075,7 @@ genrule(
srcs = [
"@llvm-project//mlir:include/mlir/Interfaces/CallInterfaces.td",
"@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffects.td",
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
"ir/tf_generated_ops.td",
"ir/tf_op_base.td",
@ -1140,6 +1140,7 @@ COMPILE_MLIR_UTIL_DEPS = [
"//tensorflow/compiler/mlir/xla:type_to_shape",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/mlir/xla:xla_sink_constants_to_control_flow",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:framework",
@ -1278,6 +1279,7 @@ cc_library(
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
],
)
@ -1292,6 +1294,7 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
],
)

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
namespace mlir {
namespace TFControlFlow {

View File

@ -1765,7 +1765,7 @@ of corresponding 3-element vectors is cross-multiplied independently.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [AllTypesMatch<["input", "output"]>, NoSideEffect]> {
def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> {
let summary = "An Op to sum inputs across replicated TPU instances.";
let description = [{
@ -1789,7 +1789,7 @@ and `B, D, F, H` as group 1. Thus we get the outputs:
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_CumsumOp : TF_Op<"Cumsum", [AllTypesMatch<["x", "out"]>, NoSideEffect]> {
def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
let summary = "Compute the cumulative sum of the tensor `x` along `axis`.";
let description = [{
@ -2907,6 +2907,8 @@ fill([2, 3], 9) ==> [[9, 9, 9]
return Verify(*this);
}];
let hasFolder = 1;
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value dims, Value value"
>];
@ -3625,31 +3627,6 @@ tf.imag(input) ==> [4.75, 5.75]
TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>;
}
def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
let summary = [{
Create a copy of `x` with the updated specified rows 'i' with values 'v'.
}];
let description = [{
Creates a copy of tensor 'x' and updates the columns specified in tensor 'i'
with the values 'v'. Originally this function was mutative however for
compilation we make this operation create / operate on a copy.
}];
let arguments = (ins
TF_Tensor:$x,
I32Tensor:$i,
TF_Tensor:$v
);
let results = (outs
TF_Tensor:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_InvOp : TF_Op<"Inv", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the reciprocal of x element-wise.";
@ -4350,7 +4327,7 @@ cublas.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [AllTypesMatch<["input", "band"]>, NoSideEffect]> {
def TF_MatrixBandPartOp : TF_Op<"MatrixBandPart", [NoSideEffect, TF_AllTypesMatch<["input", "band"]>]> {
let summary = [{
Copy a tensor setting everything outside a central band in each innermost matrix to zero.
}];
@ -6354,6 +6331,8 @@ If `x` and `y` are reals, this will return the floating-point division.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1;
let hasFolder = 1;
}
def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> {
@ -10352,6 +10331,33 @@ https://www.tensorflow.org/xla/operation_semantics#gather
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_XlaHostComputeOp : TF_Op<"XlaHostCompute", []> {
let summary = [{
A pseudo-op to represent host-side computation in an XLA program.
}];
let description = [{
}];
let arguments = (ins
Variadic<TF_Tensor>:$inputs,
StrArrayAttr:$ancestors,
TF_ShapeAttrArray:$shapes,
SymbolRefAttr:$shape_inference_graph,
StrAttr:$key,
DefaultValuedAttr<I64Attr, "1000000">:$cost_estimate_ns,
DefaultValuedAttr<I64Attr, "0">:$tpu_core
);
let results = (outs
Variadic<TF_Tensor>:$outputs
);
TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
}
def TF_XlaKeyValueSortOp : TF_Op<"XlaKeyValueSort", [NoSideEffect]> {
let summary = "Wraps the XLA Sort operator, documented at";
@ -10400,6 +10406,24 @@ https://www.tensorflow.org/performance/xla/operation_semantics#pad
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_XlaRecvFromHostOp : TF_Op<"XlaRecvFromHost", []> {
let summary = "An op to receive a tensor from the host.";
let description = [{
}];
let arguments = (ins
TF_ShapeAttr:$shape,
StrAttr:$key
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedResultTypeAttr Toutput = TF_DerivedResultTypeAttr<0>;
}
def TF_XlaReduceOp : TF_Op<"XlaReduce", [NoSideEffect]> {
let summary = "Wraps the XLA Reduce operator, documented at";
@ -10464,6 +10488,23 @@ i=0...N-1.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_XlaSendToHostOp : TF_Op<"XlaSendToHost", []> {
let summary = "An op to send a tensor to the host.";
let description = [{
}];
let arguments = (ins
TF_Tensor:$input,
StrAttr:$key
);
let results = (outs);
TF_DerivedOperandTypeAttr Tinput = TF_DerivedOperandTypeAttr<0>;
}
def TF_XlaSvdOp : TF_Op<"XlaSvd", [NoSideEffect]> {
let summary = [{
Computes the eigen decomposition of a batch of self-adjoint matrices
@ -10547,6 +10588,27 @@ def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF__HostComputeMlirOp : TF_Op<"_HostComputeMlir", []> {
let summary = "A host-side computation called from a TPU device.";
let description = [{
}];
let arguments = (ins
Variadic<TF_Tensor>:$inputs,
StrAttr:$key,
DefaultValuedAttr<I64Attr, "0">:$tpu_core
);
let results = (outs
Variadic<TF_Tensor>:$outputs
);
TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
}
def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", []> {
let summary = "An op that receives embeddng activations on the TPU.";
@ -10605,3 +10667,44 @@ used to look up the program in the compilation cache.
TF_DerivedResultSizeAttr num_computations = TF_DerivedResultSizeAttr<1>;
TF_DerivedOperandSizeAttr NumDynamicShapes = TF_DerivedOperandSizeAttr<0>;
}
def TF__XlaRecvAtHostOp : TF_Op<"_XlaRecvAtHost", []> {
let summary = [{
A placeholder op to receive values from a running XLA computation.
}];
let description = [{
}];
let arguments = (ins
TF_StrTensor:$dynamic_key,
StrAttr:$key,
I64Attr:$device_ordinal
);
let results = (outs
Variadic<TF_Tensor>:$outputs
);
TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>;
}
def TF__XlaSendFromHostOp : TF_Op<"_XlaSendFromHost", []> {
let summary = "A placeholder op to send values to a running XLA computation.";
let description = [{
}];
let arguments = (ins
Variadic<TF_Tensor>:$inputs,
TF_StrTensor:$dynamic_key,
StrAttr:$key,
I64Attr:$device_ordinal
);
let results = (outs);
TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
}

View File

@ -23,7 +23,7 @@ limitations under the License.
#define TF_OP_BASE
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffects.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td"
//===----------------------------------------------------------------------===//
@ -70,6 +70,16 @@ class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
"$_op.getOperand(" # opId # ").getType(), "
"$_op.getResult(" # resId # ").getType())">]>;
class TF_AllTypesMatchPred<list<string> values> :
CPred<"TF::AreCastCompatible(llvm::makeArrayRef({"# StrJoin<values>.result #"}))">;
class TF_AllTypesMatch<list<string> names> :
PredOpTrait<
"all of {" # StrJoin<names>.result # "} have dynamically equal types ",
TF_AllTypesMatchPred<
!foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;
//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//

View File

@ -110,48 +110,6 @@ static inline bool HasRankAtMost(Value value, int64_t rank) {
return !type || type.getRank() <= rank;
}
// Returns true if the given pair of TensorFlow types can be cast to one
// another. In other words, a single run-time value is legal for both the types.
// For example, tensor<*xf32> and tensor<3xf32> are cast compatible.
static bool AreCastCompatible(Type a, Type b) {
if (TensorCastOp::areCastCompatible(a, b)) return true;
// Resource types may optionally contain subtypes information that does not
// match. Check subtypes compatibility when possible, otherwise treat them as
// compatible.
auto a_or_element_type = getElementTypeOrSelf(a);
auto b_or_element_type = getElementTypeOrSelf(b);
auto a_kind = a_or_element_type.getKind();
auto b_kind = b_or_element_type.getKind();
if (a_kind == TensorFlowTypes::RESOURCE &&
b_kind == TensorFlowTypes::RESOURCE) {
auto a_resource_type = a_or_element_type.dyn_cast<ResourceType>();
auto b_resource_type = b_or_element_type.dyn_cast<ResourceType>();
bool a_has_subtype = !a_resource_type.getSubtypes().empty();
bool b_has_subtype = !b_resource_type.getSubtypes().empty();
if (!a_has_subtype || !b_has_subtype) return true;
assert(a_resource_type.getSubtypes().size() <= 1 &&
"Resource type must have at most one subtype");
assert(b_resource_type.getSubtypes().size() <= 1 &&
"Resource type must have at most one subtype");
return TensorCastOp::areCastCompatible(
a_resource_type.getSubtypes().front(),
b_resource_type.getSubtypes().front());
}
// Variant types may optionally contain subtypes information that need not
// match. It is also not possible to compare subtypes for compatibility as
// their interpretation depends on the ops operating on them. So, accept all
// pairs of variant types.
return a_kind == TensorFlowTypes::VARIANT &&
b_kind == TensorFlowTypes::VARIANT;
}
static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
return dim_or_rank == -1;
}
@ -503,9 +461,10 @@ LogicalResult FoldOperandsPermutation(
namespace {
// Folder that returns LHS of an Arithmetic Op if the RHS is a constant
// known to be Identity (e.g X+0)
template <typename OpT,
typename std::enable_if<llvm::is_one_of<
OpT, AddV2Op, SubOp, MulOp, DivOp>::value>::type * = nullptr>
template <
typename OpT,
typename std::enable_if<llvm::is_one_of<
OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
ArrayRef<Attribute> operands) {
auto result_op_type = arithmetic_op.getResult().getType();
@ -520,7 +479,8 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
// Mul and Div ops have identity value one while AddV2 and SubOp have identity
// value zero.
int identity =
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value);
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value ||
std::is_same<OpT, RealDivOp>::value);
Type element_ty = lhs_type.getElementType();
Attribute identity_attr;
@ -537,6 +497,12 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
return arithmetic_op.x();
}
auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
// TODO(chhe): we could fold and add an identity to force the broadcast.
if (result_op_type != rhs_type) {
return {};
}
bool is_symmetric =
(std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
if (auto attr = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
@ -1413,7 +1379,7 @@ static LogicalResult Verify(DynamicStitchOp op) {
auto expected_out_ty =
RankedTensorType::get(expected_shape, out_ty.getElementType());
if (!AreCastCompatible(out_ty, expected_out_ty)) {
if (!AreCastCompatible({out_ty, expected_out_ty})) {
return op.emitOpError() << "has invalid output type; should be "
"compatible with inferred type "
<< expected_out_ty;
@ -1647,7 +1613,7 @@ static ShapedType InferFillOpType(Value dims, Value value) {
llvm::SmallVector<int64_t, 4> shape;
shape.reserve(dims_attr.getNumElements());
for (const APInt &dim : dims_attr.getValues<APInt>()) {
for (const APInt dim : dims_attr.getValues<APInt>()) {
shape.push_back(dim.getSExtValue());
}
return RankedTensorType::get(shape, etype);
@ -1658,6 +1624,29 @@ void FillOp::build(OpBuilder &builder, OperationState &result, Value dims,
FillOp::build(builder, result, InferFillOpType(dims, value), dims, value);
}
OpFoldResult FillOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "fill op has two operands");
auto value = operands[1].dyn_cast_or_null<ElementsAttr>();
if (!value) return {};
auto type = getType().cast<ShapedType>();
if (type.hasStaticShape())
return DenseElementsAttr::get(type, value.getValue({}));
auto dims = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
if (!dims) return {};
llvm::SmallVector<int64_t, 4> shape;
shape.reserve(dims.getNumElements());
for (const APInt dim : dims.getValues<APInt>()) {
shape.push_back(dim.getSExtValue());
}
type = RankedTensorType::get(shape, type.getElementType());
return DenseElementsAttr::get(type, value.getValue({}));
}
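The effect of the new folder is an eager fill. A NumPy model of what it computes (a model only; the real folder produces a DenseElementsAttr):

import numpy as np

dims, value = [3, 2, 1], np.float32(23.0)
folded = np.full(dims, value)  # what tf.Fill(dims, value) folds to
assert folded.shape == (3, 2, 1) and (folded == 23.0).all()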
//===----------------------------------------------------------------------===//
// FusedBatchNormGradOp
//===----------------------------------------------------------------------===//
@ -1814,14 +1803,14 @@ static LogicalResult Verify(IfOp op) {
for (unsigned i = 0; i < expectedNumInputs; ++i) {
auto operandType = op.getOperand(i + 1).getType().cast<TensorType>();
auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
if (!AreCastCompatible(operandType, thenInputType))
if (!AreCastCompatible({operandType, thenInputType}))
return op.emitError(
llvm::formatv("then branch input type {0} is incompatible with "
"operand type {1} at index {2}",
thenInputType, operandType, i));
auto elseInputType = elseFuncType.getInput(i).cast<TensorType>();
if (!AreCastCompatible(operandType, elseInputType))
if (!AreCastCompatible({operandType, elseInputType}))
return op.emitError(
llvm::formatv("else branch input type {0} is incompatible with "
"operand type {1} at index {2}",
@ -1829,7 +1818,7 @@ static LogicalResult Verify(IfOp op) {
// If branches have incompatible input types that means that no tensor can
// serve as input to both the functions. Hence, the op is invalid.
if (!AreCastCompatible(thenInputType, elseInputType))
if (!AreCastCompatible({thenInputType, elseInputType}))
return op.emitError(llvm::formatv(
"branches inputs have incompatible types {0} and {1} at index {2}",
thenInputType, elseInputType, i));
@ -1845,14 +1834,14 @@ static LogicalResult Verify(IfOp op) {
for (unsigned i = 0; i < expectedNumResults; ++i) {
auto resultType = op.getResult(i).getType().cast<TensorType>();
auto thenResultType = thenFuncType.getResult(i).cast<TensorType>();
if (!AreCastCompatible(thenResultType, resultType))
if (!AreCastCompatible({thenResultType, resultType}))
return op.emitError(
llvm::formatv("then branch result type {0} is incompatible with op "
"result type {1} at index {2}",
thenResultType, resultType, i));
auto elseResultType = elseFuncType.getResult(i).cast<TensorType>();
if (!AreCastCompatible(elseResultType, resultType))
if (!AreCastCompatible({elseResultType, resultType}))
return op.emitError(
llvm::formatv("else branch result type {0} is incompatible with op "
"result type {1} at index {2}",
@ -2426,6 +2415,10 @@ void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<RealDivWithSqrtDivisor>(context);
}
OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) {
return IdentityArithmeticOpFolder<RealDivOp>(*this, operands);
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
@ -2616,9 +2609,12 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
<< variadic_idx_str << " to match rank of operand"
<< variadic_idx_str;
} else if (result_ranked_type.hasStaticShape()) {
// The operand is an unranked tensor, verify that the result is dynamic.
return op->emitOpError("requires dynamic shape result")
<< variadic_idx_str << " for unranked operand" << variadic_idx_str;
// The operand is an unranked tensor, print a warning if the result
// is static.
// Note: We do not handle this situation as an error, this would be too
// restrictive due to incompleteness of shape inference at this point.
op->emitWarning("has static shape result")
<< variadic_idx_str << " for unranked operand" << variadic_idx_str;
}
Type element_type = result_ranked_type.getElementType();
@ -3789,7 +3785,7 @@ static LogicalResult Verify(WhileOp op) {
auto aType = a.second[idx];
auto bType = b.second[idx];
if (!AreCastCompatible(aType, bType))
if (!AreCastCompatible({aType, bType}))
return op.emitError(llvm::formatv(
"{0} type {1} is incompatible with {2} type {3} at index {4}",
a.first, aType, b.first, bType, idx));

View File

@ -31,7 +31,7 @@ limitations under the License.
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/Interfaces/SideEffects.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"

View File

@ -905,5 +905,29 @@ def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> {
TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>;
}
// TODO(b/156507832): Move tf.InplaceUpdate to tf_generated_ops.td once
// autogenerated op def matches.
def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> {
let summary = "Updates specified rows 'i' with values 'v'.";
let description = [{
Computes `x[i, :] = v; return x`.
Originally this function was mutative; however, for compilation we make this
operation create / operate on a copy of `x`.
}];
let arguments = (ins
TF_Tensor:$x,
I32Tensor:$i,
TF_Tensor:$v
);
let results = (outs
TF_Tensor:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
#endif // TF_OPS

View File

@ -28,6 +28,134 @@ llvm::Optional<llvm::ArrayRef<int64_t>> GetShape(mlir::Value value) {
if (shaped_type.hasRank()) return shaped_type.getShape();
return llvm::None;
}
// Merges cast compatible shapes and returns a more refined shape. The two
// shapes are cast compatible if they have the same rank and at each dimension,
// either both have same size or one of them is dynamic. Returns false if the
// given shapes are not cast compatible. The refined shape is same or more
// precise than the two input shapes.
bool GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,
llvm::ArrayRef<int64_t> b_shape,
llvm::SmallVectorImpl<int64_t>* refined_shape) {
if (a_shape.size() != b_shape.size()) return false;
int64_t rank = a_shape.size();
refined_shape->reserve(rank);
for (auto dims : llvm::zip(a_shape, b_shape)) {
int64_t dim1 = std::get<0>(dims);
int64_t dim2 = std::get<1>(dims);
if (mlir::ShapedType::isDynamic(dim1)) {
refined_shape->push_back(dim2);
continue;
}
if (mlir::ShapedType::isDynamic(dim2)) {
refined_shape->push_back(dim1);
continue;
}
if (dim1 == dim2) {
refined_shape->push_back(dim1);
continue;
}
return false;
}
return true;
}
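The merge logic is easiest to see as a small executable model (Python, with -1 standing for mlir::ShapedType's dynamic size):

def get_cast_compatible_shape(a_shape, b_shape):
    # Mirrors GetCastCompatibleShape above: ranks must match, and each
    # dimension pair must agree or have a dynamic (-1) side.
    if len(a_shape) != len(b_shape):
        return None
    refined = []
    for dim1, dim2 in zip(a_shape, b_shape):
        if dim1 == -1:
            refined.append(dim2)
        elif dim2 == -1 or dim1 == dim2:
            refined.append(dim1)
        else:
            return None
    return refined

# tensor<2x?x?xf32> and tensor<?x4x?xf32> refine to tensor<2x4x?xf32>:
assert get_cast_compatible_shape([2, -1, -1], [-1, 4, -1]) == [2, 4, -1]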
// Given two types `a` and `b`, returns a refined type which is cast compatible
// with both `a` and `b` and is equal to or more precise than both of them. It
// returns empty Type if the input types are not cast compatible.
//
// The two types are considered cast compatible if they have dynamically equal
// shapes and element type. For element types that do not have subtypes, they
// must be equal. However for TensorFlow types such as Resource and Variant,
// that also have subtypes, we recursively check for subtype compatibility for
// Resource types and assume all variant types are cast compatible. If either
// one of `a` or `b` have empty subtypes, they are considered cast compatible.
//
// The returned type is same or more precise than the input types. For example,
// if `a` and `b` are cast compatible types tensor<2x?x?xf32> and
// tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
//
// Provides option to ignore ref types on 'a'. This is useful for TF ops that
// might allow operands to either be same as result type or be a ref type
// corresponding to it.
mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
bool may_ignore_ref_type_a) {
// Fast path if everything is equal.
if (a == b) return b;
auto a_tt = a.dyn_cast<mlir::TensorType>();
auto b_tt = b.dyn_cast<mlir::TensorType>();
// If only one of a or b is a tensor type, they are incompatible.
if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;
// For non-tensor types, we do not need to worry about shape and can return
// early.
if (!a_tt && !b_tt) {
// Remove ref types.
if (may_ignore_ref_type_a) {
if (auto ref_type = a.dyn_cast<mlir::TF::TensorFlowRefType>()) {
a = ref_type.RemoveRef();
if (a == b) return a;
}
}
if (a.getKind() != b.getKind()) return nullptr;
// If either is not a type that contains subtypes then the types are not cast
// compatible.
auto a_wst = a.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
auto b_wst = b.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
if (!a_wst || !b_wst) return nullptr;
// For Variant types we are more permissive right now and accept all pairs
// of Variant types. If we are more constrained and check compatibility of
// subtypes, we might reject valid graphs.
// TODO(prakalps): Variant doesn't have a subtype, we assign it
// one, so we should only assign it one when we know the subtype. Then we
// can be more constrained and check subtypes for cast compatibility as
// well.
if (a.isa<mlir::TF::VariantType>()) return a;
// For Resource types, we recursively check the subtypes for cast
// compatibility, if possible. Otherwise treat them as compatible.
auto a_wst_st = a_wst.GetSubtypes();
auto b_wst_st = b_wst.GetSubtypes();
if (a_wst_st.empty() || b_wst_st.empty()) return a;
if (a_wst_st.size() != b_wst_st.size()) return nullptr;
llvm::SmallVector<mlir::TensorType, 4> refined_subtypes;
for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
mlir::Type refined_st =
GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
/*may_ignore_ref_type_a=*/false);
if (!refined_st) return nullptr;
refined_subtypes.push_back(refined_st.cast<mlir::TensorType>());
}
return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
}
// For tensor types, check compatibility of both element type and shape.
mlir::Type refined_element_ty = GetCastCompatibleType(
a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
if (!refined_element_ty) return nullptr;
if (!a_tt.hasRank() && !b_tt.hasRank()) {
return mlir::UnrankedTensorType::get(refined_element_ty);
}
if (!a_tt.hasRank()) {
return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
}
if (!b_tt.hasRank()) {
return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
}
llvm::SmallVector<int64_t, 8> refined_shape;
if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
return nullptr;
return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
}
} // namespace
namespace mlir {
@ -224,44 +352,16 @@ bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
bool HasCompatibleElementTypes(Type lhs, Type rhs,
bool may_ignore_ref_type_lhs) {
// Fast path if everything is equal.
if (lhs == rhs) return true;
return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
}
// In TF all values are tensors.
auto lhs_tt = lhs.cast<TensorType>();
auto rhs_tt = rhs.cast<TensorType>();
// Verify matching element types. These should be identical dynamically,
// so this allows for types not yet fully refined.
auto lhs_et = lhs_tt.getElementType();
auto rhs_et = rhs_tt.getElementType();
if (lhs_et == rhs_et) return true;
// Remove ref types.
if (may_ignore_ref_type_lhs) {
if (auto ref_type = lhs_et.dyn_cast<TF::TensorFlowRefType>()) {
lhs_et = ref_type.RemoveRef();
if (lhs_et == rhs_et) return true;
}
}
if (lhs_et.getKind() != rhs_et.getKind()) return false;
// If either is not type that contain subtypes then the element types don't
// match.
auto lhs_wst = lhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
auto rhs_wst = rhs_et.dyn_cast<TF::TensorFlowTypeWithSubtype>();
if (!lhs_wst || !rhs_wst) return false;
// Consider the subtype recursively.
auto lhs_wst_st = lhs_wst.GetSubtypes();
auto rhs_wst_st = rhs_wst.GetSubtypes();
if (lhs_wst_st.empty() || rhs_wst_st.empty()) return true;
if (lhs_wst_st.size() != rhs_wst_st.size()) return false;
for (auto subtypes : llvm::zip(lhs_wst_st, rhs_wst_st)) {
if (!HasCompatibleElementTypes(std::get<0>(subtypes),
std::get<1>(subtypes)))
return false;
bool AreCastCompatible(ArrayRef<Type> types) {
Type common = types.front();
for (auto type : types.drop_front()) {
Type refined_type =
GetCastCompatibleType(common, type, /*may_ignore_ref_type_a=*/false);
if (!refined_type) return false;
common = refined_type;
}
return true;
}
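The new check is a left fold of pairwise refinement; as a model, with refine standing in for GetCastCompatibleType:

def are_cast_compatible(types, refine):
    # Fold refinement across the list; any pair without a common
    # refinement makes the whole set incompatible.
    common = types[0]
    for ty in types[1:]:
        common = refine(common, ty)
        if common is None:
            return False
    return True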

View File

@ -313,6 +313,12 @@ bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs);
bool HasCompatibleElementTypes(Type lhs, Type rhs,
bool may_ignore_ref_type_lhs = false);
// Returns true if all TensorFlow types can be cast to one
// another. In other words, a single run-time value is legal for both the types.
// For example, tensor<*xf32>, tensor<?xf32> and tensor<3xf32> are cast
// compatible.
bool AreCastCompatible(ArrayRef<Type> types);
} // end namespace TF
} // end namespace mlir

View File

@ -471,3 +471,14 @@ func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor<i32> {
// CHECK: return [[VAL0]]
return %0 : tensor<i32>
}
// CHECK-LABEL: @foldFill
func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>) {
%0 = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = "tf.Const"() {value = dense<23.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>}
%2 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor<f32>) -> tensor<3x2x1xf32>
// CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>}
%3 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor<f32>) -> tensor<*xf32>
return %2, %3 : tensor<3x2x1xf32>, tensor<*xf32>
}

Some files were not shown because too many files have changed in this diff