Merge branch 'master' into unify_doccitations

commit d502acb5ae
Martin Wicke, 2019-04-17 22:39:54 -07:00 (committed by GitHub)
2188 changed files with 118252 additions and 32862 deletions
ISSUE_TEMPLATE.md
LICENSE
README.md
WORKSPACE
configure.py
tensorflow/
    api_template.__init__.py
    api_template_v1.__init__.py
    c/
    compat_template_v1.__init__.py
    compiler/
        aot/
        jit/
        tests/
        tf2tensorrt/
        tf2xla/
        xla/

View File

@ -32,7 +32,7 @@ https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh
You can obtain the TensorFlow version with:
```bash
python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"
```
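The change above swaps the old top-level `tf.GIT_VERSION` / `tf.VERSION` attributes for the `tf.version` submodule. A minimal sketch (editorial illustration, not part of the issue template) of the same check from inside Python, assuming a TensorFlow build recent enough to expose `tf.version`:
```python
# Sketch: print build metadata via the tf.version submodule used by the updated template.
import tensorflow as tf

print(tf.version.VERSION)      # release string, e.g. something like "1.13.1"
print(tf.version.GIT_VERSION)  # git-describe string of the build
```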
### Describe the problem

View File

@ -1,4 +1,4 @@
Copyright 2018 The TensorFlow Authors. All rights reserved.
Copyright 2019 The TensorFlow Authors. All rights reserved.
Apache License
Version 2.0, January 2004

View File

@ -25,7 +25,7 @@ networks research. The system is general enough to be applicable in a wide
variety of other domains, as well.
TensorFlow provides stable Python and C APIs as well as non-guaranteed backwards
compatible APIs for C++, Go, Java, JavaScript and Swift.
compatible APIs for C++, Go, Java, JavaScript, and Swift.
Keep up to date with release announcements and security updates by
subscribing to
@ -50,10 +50,10 @@ instructions, and how to build from source.*
People who are a little more adventurous can also try our nightly binaries:
**Nightly pip packages**
* We are pleased to announce that TensorFlow now offers nightly pip packages
under the [tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) project on pypi.
**Nightly pip packages** * We are pleased to announce that TensorFlow now offers
nightly pip packages under the
[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) project on PyPi.
Simply run `pip install tf-nightly` or `pip install tf-nightly-gpu` in a clean
environment to install the nightly TensorFlow build. We support CPU and GPU
packages on Linux, Mac, and Windows.
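As a quick follow-up to the install note above, the sketch below (not part of the README) checks which build ended up in the environment; nightly wheels typically carry a dated `dev` tag in their version string:
```python
# Sketch: confirm that the tf-nightly (or tf-nightly-gpu) wheel is the active TensorFlow.
import tensorflow as tf

version = tf.version.VERSION
print(version)
print("nightly build" if "dev" in version else "release build")
```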

View File

@ -4,11 +4,11 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file"
http_archive(
name = "io_bazel_rules_closure",
sha256 = "ddce3b3a3909f99b28b25071c40b7fec7e2e1d1d1a4b2e933f3082aa99517105",
strip_prefix = "rules_closure-316e6133888bfc39fb860a4f1a31cfcbae485aef",
sha256 = "e0a111000aeed2051f29fcc7a3f83be3ad8c6c93c186e64beb1ad313f0c7f9f9",
strip_prefix = "rules_closure-cf1e44edb908e9616030cc83d085989b8e6cd6df",
urls = [
"http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/316e6133888bfc39fb860a4f1a31cfcbae485aef.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/316e6133888bfc39fb860a4f1a31cfcbae485aef.tar.gz", # 2019-03-21
"http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz", # 2019-04-04
],
)
@ -43,47 +43,37 @@ remote_config_workspace()
# Apple and Swift rules.
http_archive(
name = "build_bazel_rules_apple",
sha256 = "4b90786009fa8df25230442244bad2832ba8d6bc4987f68150a7de59c8827e90",
strip_prefix = "rules_apple-0.14.0",
urls = ["https://github.com/bazelbuild/rules_apple/archive/0.14.0.tar.gz"],
)
http_file(
name = "xctestrunner",
executable = 1,
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.6/ios_test_runner.par"],
)
http_archive(
name = "bazel_skylib",
sha256 = "2c62d8cd4ab1e65c08647eb4afe38f51591f43f7f0885e7769832fa137633dcb",
strip_prefix = "bazel-skylib-0.7.0",
urls = ["https://github.com/bazelbuild/bazel-skylib/archive/0.7.0.tar.gz"],
)
sha256 = "8f32e2839fba28d549e1670dbed83606dd339a9f7489118e481814d61738270f",
urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.14.0/rules_apple.0.14.0.tar.gz"],
) # https://github.com/bazelbuild/rules_apple/releases
http_archive(
name = "build_bazel_apple_support",
sha256 = "835663c4bb02f4bf01dce8a2a176df7fa682dbb867d3698ae12258c1628bb8f0",
strip_prefix = "apple_support-0.5.0",
urls = ["https://github.com/bazelbuild/apple_support/archive/0.5.0.tar.gz"],
)
sha256 = "7356dbd44dea71570a929d1d4731e870622151a5f27164d966dda97305f33471",
urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.6.0/apple_support.0.6.0.tar.gz"],
) # https://github.com/bazelbuild/apple_support/releases
http_archive(
name = "bazel_skylib",
sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e",
urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.8.0/bazel-skylib.0.8.0.tar.gz"],
) # https://github.com/bazelbuild/bazel-skylib/releases
http_archive(
name = "build_bazel_rules_swift",
sha256 = "32d124878cd49775d84f59ba90440c8b23b7c775aec8fec1978f751c76ddee8a",
strip_prefix = "rules_swift-0.7.0",
urls = ["https://github.com/bazelbuild/rules_swift/archive/0.7.0.tar.gz"],
)
sha256 = "31aad005a9c4e56b256125844ad05eb27c88303502d74138186f9083479f93a6",
urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.8.0/rules_swift.0.8.0.tar.gz"],
) # https://github.com/bazelbuild/rules_swift/releases
http_archive(
name = "com_github_apple_swift_swift_protobuf",
type = "zip",
strip_prefix = "swift-protobuf-1.2.0/",
urls = ["https://github.com/apple/swift-protobuf/archive/1.2.0.zip"],
)
# Use swift_rules_dependencies to fetch the tolchains.
# Since we defined all the "git_repository" rules above, the following call will
# skip redefining them.
strip_prefix = "swift-protobuf-1.4.0/",
urls = ["https://github.com/apple/swift-protobuf/archive/1.4.0.zip"],
) # https://github.com/apple/swift-protobuf/releases
http_file(
name = "xctestrunner",
executable = 1,
urls = ["https://github.com/google/xctestrunner/releases/download/0.2.7/ios_test_runner.par"],
) # https://github.com/google/xctestrunner/releases
# Use `swift_rules_dependencies` to fetch the toolchains. With the
# `git_repository` rules above, the following call will skip redefining them.
load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")
swift_rules_dependencies()

View File

@ -33,13 +33,11 @@ except ImportError:
from distutils.spawn import find_executable as which
# pylint: enable=g-import-not-at-top
_DEFAULT_CUDA_VERSION = '10.0'
_DEFAULT_CUDA_VERSION = '10'
_DEFAULT_CUDNN_VERSION = '7'
_DEFAULT_TENSORRT_VERSION = '5'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0'
_DEFAULT_CUDA_PATH = '/usr/local/cuda'
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
@ -58,6 +56,7 @@ NCCL_LIB_PATHS = [
# List of files to configure when building Bazel on Apple platforms.
APPLE_BAZEL_FILES = [
'tensorflow/lite/experimental/ios/BUILD',
'tensorflow/lite/experimental/objc/BUILD',
'tensorflow/lite/experimental/swift/BUILD'
]
@ -68,11 +67,6 @@ IOS_FILES = [
'tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec',
]
if platform.machine() == 'ppc64le':
_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/'
else:
_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine()
class UserInputError(Exception):
pass
@ -206,9 +200,10 @@ def setup_python(environ_cp):
ask_python_bin_path = ('Please specify the location of python. [Default is '
'%s]: ') % default_python_bin_path
while True:
python_bin_path = get_from_env_or_user_or_default(
environ_cp, 'PYTHON_BIN_PATH', ask_python_bin_path,
default_python_bin_path)
python_bin_path = get_from_env_or_user_or_default(environ_cp,
'PYTHON_BIN_PATH',
ask_python_bin_path,
default_python_bin_path)
# Check if the path is valid
if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK):
break
@ -392,14 +387,14 @@ def set_build_var(environ_cp,
var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default)))
environ_cp[var_name] = var
if var == '1':
write_to_bazelrc(
'build:%s --define %s=true' % (bazel_config_name, option_name))
write_to_bazelrc('build:%s --define %s=true' %
(bazel_config_name, option_name))
write_to_bazelrc('build --config=%s' % bazel_config_name)
elif bazel_config_name is not None:
# TODO(mikecase): Migrate all users of configure.py to use --config Bazel
# options and not to set build configs through environment variables.
write_to_bazelrc(
'build:%s --define %s=true' % (bazel_config_name, option_name))
write_to_bazelrc('build:%s --define %s=true' %
(bazel_config_name, option_name))
def set_action_env_var(environ_cp,
@ -446,6 +441,9 @@ def convert_version_to_int(version):
"""
version = version.split('-')[0]
version_segments = version.split('.')
# Treat "0.24" as "0.24.0"
if len(version_segments) == 2:
version_segments.append('0')
for seg in version_segments:
if not seg.isdigit():
return None
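The two lines added above normalize two-segment versions so they compare cleanly against three-segment ones. A small standalone sketch of just that step (the real helper goes on to validate the segments and pack them into an integer):
```python
# Sketch of the normalization added above; not the full convert_version_to_int helper.
def normalize_segments(version):
  version = version.split('-')[0]  # drop anything after '-' (e.g. prerelease tags)
  segments = version.split('.')
  if len(segments) == 2:           # treat "0.24" as "0.24.0"
    segments.append('0')
  return segments

print(normalize_segments('0.24'))    # ['0', '24', '0']
print(normalize_segments('0.24.1'))  # ['0', '24', '1']
```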
@ -665,9 +663,9 @@ def prompt_loop_or_load_from_env(environ_cp,
print(error_msg % val)
environ_cp[var_name] = ''
else:
raise UserInputError(
'Invalid %s setting was provided %d times in a row. '
'Assuming to be a scripting mistake.' % (var_name, n_ask_attempts))
raise UserInputError('Invalid %s setting was provided %d times in a row. '
'Assuming to be a scripting mistake.' %
(var_name, n_ask_attempts))
environ_cp[var_name] = val
return val
@ -676,8 +674,8 @@ def prompt_loop_or_load_from_env(environ_cp,
def create_android_ndk_rule(environ_cp):
"""Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule."""
if is_windows() or is_cygwin():
default_ndk_path = cygpath(
'%s/Android/Sdk/ndk-bundle' % environ_cp['APPDATA'])
default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' %
environ_cp['APPDATA'])
elif is_macos():
default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME']
else:
@ -696,8 +694,9 @@ def create_android_ndk_rule(environ_cp):
error_msg=('The path %s or its child file "source.properties" '
'does not exist.'))
write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path)
write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL',
check_ndk_level(android_ndk_home_path))
write_action_env_to_bazelrc(
'ANDROID_NDK_API_LEVEL',
get_ndk_api_level(environ_cp, android_ndk_home_path))
def create_android_sdk_rule(environ_cp):
@ -764,8 +763,10 @@ def create_android_sdk_rule(environ_cp):
write_action_env_to_bazelrc('ANDROID_SDK_HOME', android_sdk_home_path)
def check_ndk_level(android_ndk_home_path):
"""Check the revision number of an Android NDK path."""
def get_ndk_api_level(environ_cp, android_ndk_home_path):
"""Gets the appropriate NDK API level to use for the provided Android NDK path."""
# First check to see if we're using a blessed version of the NDK.
properties_path = '%s/source.properties' % android_ndk_home_path
if is_windows() or is_cygwin():
properties_path = cygpath(properties_path)
@ -774,17 +775,40 @@ def check_ndk_level(android_ndk_home_path):
revision = re.search(r'Pkg.Revision = (\d+)', filedata)
if revision:
ndk_api_level = revision.group(1)
ndk_version = revision.group(1)
else:
raise Exception('Unable to parse NDK revision.')
if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
print(
'WARNING: The API level of the NDK in %s is %s, which is not '
'supported by Bazel (officially supported versions: %s). Please use '
'another version. Compiling Android targets may result in confusing '
'errors.\n' %
(android_ndk_home_path, ndk_api_level, _SUPPORTED_ANDROID_NDK_VERSIONS))
return ndk_api_level
if int(ndk_version) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
print('WARNING: The NDK version in %s is %s, which is not '
'supported by Bazel (officially supported versions: %s). Please use '
'another version. Compiling Android targets may result in confusing '
'errors.\n' % (android_ndk_home_path, ndk_version,
_SUPPORTED_ANDROID_NDK_VERSIONS))
# Now grab the NDK API level to use. Note that this is different from the
# SDK API level, as the NDK API level is effectively the *min* target SDK
# version.
platforms = os.path.join(android_ndk_home_path, 'platforms')
api_levels = sorted(os.listdir(platforms))
api_levels = [
x.replace('android-', '') for x in api_levels if 'android-' in x
]
def valid_api_level(api_level):
return os.path.exists(
os.path.join(android_ndk_home_path, 'platforms',
'android-' + api_level))
android_ndk_api_level = prompt_loop_or_load_from_env(
environ_cp,
var_name='ANDROID_NDK_API_LEVEL',
var_default='18', # 18 is required for GPU acceleration.
ask_for_var=('Please specify the (min) Android NDK API level to use. '
'[Available levels: %s]') % api_levels,
check_success=valid_api_level,
error_msg='Android-%s is not present in the NDK path.')
return android_ndk_api_level
def set_gcc_host_compiler_path(environ_cp):
@ -831,149 +855,39 @@ def reformat_version_sequence(version_str, sequence_count):
return '.'.join(v[:sequence_count])
def set_tf_cuda_paths(environ_cp):
"""Set TF_CUDA_PATHS."""
ask_cuda_paths = (
'Please specify the comma-separated list of base paths to look for CUDA '
'libraries and headers. [Leave empty to use the default]: ')
tf_cuda_paths = get_from_env_or_user_or_default(environ_cp, 'TF_CUDA_PATHS',
ask_cuda_paths, '')
if tf_cuda_paths:
environ_cp['TF_CUDA_PATHS'] = tf_cuda_paths
def set_tf_cuda_version(environ_cp):
"""Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION."""
"""Set TF_CUDA_VERSION."""
ask_cuda_version = (
'Please specify the CUDA SDK version you want to use. '
'[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
# Configure the Cuda SDK version to use.
tf_cuda_version = get_from_env_or_user_or_default(
environ_cp, 'TF_CUDA_VERSION', ask_cuda_version, _DEFAULT_CUDA_VERSION)
tf_cuda_version = reformat_version_sequence(str(tf_cuda_version), 2)
# Find out where the CUDA toolkit is installed
default_cuda_path = _DEFAULT_CUDA_PATH
if is_windows() or is_cygwin():
default_cuda_path = cygpath(
environ_cp.get('CUDA_PATH', _DEFAULT_CUDA_PATH_WIN))
elif is_linux():
# If the default doesn't exist, try an alternative default.
if (not os.path.exists(default_cuda_path)
) and os.path.exists(_DEFAULT_CUDA_PATH_LINUX):
default_cuda_path = _DEFAULT_CUDA_PATH_LINUX
ask_cuda_path = ('Please specify the location where CUDA %s toolkit is'
' installed. Refer to README.md for more details. '
'[Default is %s]: ') % (tf_cuda_version, default_cuda_path)
cuda_toolkit_path = get_from_env_or_user_or_default(
environ_cp, 'CUDA_TOOLKIT_PATH', ask_cuda_path, default_cuda_path)
if is_windows() or is_cygwin():
cuda_toolkit_path = cygpath(cuda_toolkit_path)
if is_windows():
cuda_rt_lib_paths = ['lib/x64/cudart.lib']
elif is_linux():
cuda_rt_lib_paths = [
'%s/libcudart.so.%s' % (x, tf_cuda_version) for x in [
'lib64',
'lib/powerpc64le-linux-gnu',
'lib/x86_64-linux-gnu',
]
]
elif is_macos():
cuda_rt_lib_paths = ['lib/libcudart.%s.dylib' % tf_cuda_version]
cuda_toolkit_paths_full = [
os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths
]
if any(os.path.exists(x) for x in cuda_toolkit_paths_full):
break
# Reset and retry
print('Invalid path to CUDA %s toolkit. %s cannot be found' %
(tf_cuda_version, cuda_toolkit_paths_full))
environ_cp['TF_CUDA_VERSION'] = ''
environ_cp['CUDA_TOOLKIT_PATH'] = ''
else:
raise UserInputError('Invalid TF_CUDA_SETTING setting was provided %d '
'times in a row. Assuming to be a scripting mistake.' %
_DEFAULT_PROMPT_ASK_ATTEMPTS)
# Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION
environ_cp['CUDA_TOOLKIT_PATH'] = cuda_toolkit_path
write_action_env_to_bazelrc('CUDA_TOOLKIT_PATH', cuda_toolkit_path)
tf_cuda_version = get_from_env_or_user_or_default(environ_cp,
'TF_CUDA_VERSION',
ask_cuda_version,
_DEFAULT_CUDA_VERSION)
environ_cp['TF_CUDA_VERSION'] = tf_cuda_version
write_action_env_to_bazelrc('TF_CUDA_VERSION', tf_cuda_version)
def set_tf_cudnn_version(environ_cp):
"""Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION."""
"""Set TF_CUDNN_VERSION."""
ask_cudnn_version = (
'Please specify the cuDNN version you want to use. '
'[Leave empty to default to cuDNN %s]: ') % _DEFAULT_CUDNN_VERSION
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
tf_cudnn_version = get_from_env_or_user_or_default(
environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version,
_DEFAULT_CUDNN_VERSION)
tf_cudnn_version = reformat_version_sequence(str(tf_cudnn_version), 1)
default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH')
ask_cudnn_path = (r'Please specify the location where cuDNN %s library is '
'installed. Refer to README.md for more details. [Default'
' is %s]: ') % (tf_cudnn_version, default_cudnn_path)
cudnn_install_path = get_from_env_or_user_or_default(
environ_cp, 'CUDNN_INSTALL_PATH', ask_cudnn_path, default_cudnn_path)
# Result returned from "read" will be used unexpanded. That makes "~"
# unusable. Going through one more level of expansion to handle that.
cudnn_install_path = os.path.realpath(
os.path.expanduser(cudnn_install_path))
if is_windows() or is_cygwin():
cudnn_install_path = cygpath(cudnn_install_path)
if is_windows():
cuda_dnn_lib_path = 'lib/x64/cudnn.lib'
cuda_dnn_lib_alt_path = 'lib/x64/cudnn.lib'
elif is_linux():
cuda_dnn_lib_path = 'lib64/libcudnn.so.%s' % tf_cudnn_version
cuda_dnn_lib_alt_path = 'libcudnn.so.%s' % tf_cudnn_version
elif is_macos():
cuda_dnn_lib_path = 'lib/libcudnn.%s.dylib' % tf_cudnn_version
cuda_dnn_lib_alt_path = 'libcudnn.%s.dylib' % tf_cudnn_version
cuda_dnn_lib_path_full = os.path.join(cudnn_install_path, cuda_dnn_lib_path)
cuda_dnn_lib_alt_path_full = os.path.join(cudnn_install_path,
cuda_dnn_lib_alt_path)
if os.path.exists(cuda_dnn_lib_path_full) or os.path.exists(
cuda_dnn_lib_alt_path_full):
break
# Try another alternative for Linux
if is_linux():
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
cudnn_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
cudnn_path_from_ldconfig = re.search('.*libcudnn.so .* => (.*)',
cudnn_path_from_ldconfig)
if cudnn_path_from_ldconfig:
cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1)
if os.path.exists(
'%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)):
cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig)
break
# Reset and Retry
print(
'Invalid path to cuDNN %s toolkit. None of the following files can be '
'found:' % tf_cudnn_version)
print(cuda_dnn_lib_path_full)
print(cuda_dnn_lib_alt_path_full)
if is_linux():
print('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version))
environ_cp['TF_CUDNN_VERSION'] = ''
else:
raise UserInputError('Invalid TF_CUDNN setting was provided %d '
'times in a row. Assuming to be a scripting mistake.' %
_DEFAULT_PROMPT_ASK_ATTEMPTS)
# Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION
environ_cp['CUDNN_INSTALL_PATH'] = cudnn_install_path
write_action_env_to_bazelrc('CUDNN_INSTALL_PATH', cudnn_install_path)
tf_cudnn_version = get_from_env_or_user_or_default(environ_cp,
'TF_CUDNN_VERSION',
ask_cudnn_version,
_DEFAULT_CUDNN_VERSION)
environ_cp['TF_CUDNN_VERSION'] = tf_cudnn_version
write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version)
def is_cuda_compatible(lib, cuda_ver, cudnn_ver):
@ -1005,252 +919,38 @@ def is_cuda_compatible(lib, cuda_ver, cudnn_ver):
return cudnn_ok and cuda_ok
def set_tf_tensorrt_install_path(environ_cp):
"""Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.
Adapted from code contributed by Sami Kama (https://github.com/samikama).
Args:
environ_cp: copy of the os.environ.
Raises:
ValueError: if this method was called under non-Linux platform.
UserInputError: if user has provided invalid input multiple times.
"""
def set_tf_tensorrt_version(environ_cp):
"""Set TF_TENSORRT_VERSION."""
if not is_linux():
raise ValueError('Currently TensorRT is only supported on Linux platform.')
# Ask user whether to add TensorRT support.
if str(int(get_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT',
False))) != '1':
if not int(environ_cp.get('TF_NEED_TENSORRT', False)):
return
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
ask_tensorrt_path = (r'Please specify the location where TensorRT is '
'installed. [Default is %s]:') % (
_DEFAULT_TENSORRT_PATH_LINUX)
trt_install_path = get_from_env_or_user_or_default(
environ_cp, 'TENSORRT_INSTALL_PATH', ask_tensorrt_path,
_DEFAULT_TENSORRT_PATH_LINUX)
# Result returned from "read" will be used unexpanded. That makes "~"
# unusable. Going through one more level of expansion to handle that.
trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path))
def find_libs(search_path):
"""Search for libnvinfer.so in "search_path"."""
fl = set()
if os.path.exists(search_path) and os.path.isdir(search_path):
fl.update([
os.path.realpath(os.path.join(search_path, x))
for x in os.listdir(search_path)
if 'libnvinfer.so' in x
])
return fl
possible_files = find_libs(trt_install_path)
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib')))
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64')))
cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$')
highest_ver = [0, None, None]
for lib_file in possible_files:
if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver):
matches = nvinfer_pattern.search(lib_file)
if not matches.groups():
continue
ver_str = matches.group(1)
ver = convert_version_to_int(ver_str) if len(ver_str) else 0
if ver > highest_ver[0]:
highest_ver = [ver, ver_str, lib_file]
if highest_ver[1] is not None:
trt_install_path = os.path.dirname(highest_ver[2])
tf_tensorrt_version = highest_ver[1]
break
# Try another alternative from ldconfig.
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
ldconfig_output = run_shell([ldconfig_bin, '-p'])
search_result = re.search('.*libnvinfer.so\\.?([0-9.]*).* => (.*)',
ldconfig_output)
if search_result:
libnvinfer_path_from_ldconfig = search_result.group(2)
if os.path.exists(libnvinfer_path_from_ldconfig):
if is_cuda_compatible(libnvinfer_path_from_ldconfig, cuda_ver,
cudnn_ver):
trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
tf_tensorrt_version = search_result.group(1)
break
# Reset and Retry
if possible_files:
print('TensorRT libraries found in one of the following directories',
'are not compatible with the selected cuda and cudnn installations')
print(trt_install_path)
print(os.path.join(trt_install_path, 'lib'))
print(os.path.join(trt_install_path, 'lib64'))
if search_result:
print(libnvinfer_path_from_ldconfig)
else:
print(
'Invalid path to TensorRT. None of the following files can be found:')
print(trt_install_path)
print(os.path.join(trt_install_path, 'lib'))
print(os.path.join(trt_install_path, 'lib64'))
if search_result:
print(libnvinfer_path_from_ldconfig)
else:
raise UserInputError('Invalid TF_TENSORRT setting was provided %d '
'times in a row. Assuming to be a scripting mistake.' %
_DEFAULT_PROMPT_ASK_ATTEMPTS)
# Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION
environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path
write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path)
ask_tensorrt_version = (
'Please specify the TensorRT version you want to use. '
'[Leave empty to default to TensorRT %s]: ') % _DEFAULT_TENSORRT_VERSION
tf_tensorrt_version = get_from_env_or_user_or_default(
environ_cp, 'TF_TENSORRT_VERSION', ask_tensorrt_version,
_DEFAULT_TENSORRT_VERSION)
environ_cp['TF_TENSORRT_VERSION'] = tf_tensorrt_version
write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version)
def set_tf_nccl_install_path(environ_cp):
"""Set NCCL_INSTALL_PATH, NCCL_HDR_PATH and TF_NCCL_VERSION.
Args:
environ_cp: copy of the os.environ.
Raises:
ValueError: if this method was called under non-Linux platform.
UserInputError: if user has provided invalid input multiple times.
"""
def set_tf_nccl_version(environ_cp):
"""Set TF_NCCL_VERSION."""
if not is_linux():
raise ValueError('Currently NCCL is only supported on Linux platforms.')
raise ValueError('Currently NCCL is only supported on Linux platform.')
if 'TF_NCCL_VERSION' in environ_cp:
return
ask_nccl_version = (
'Please specify the locally installed NCCL version you want to use. '
'[Default is to use https://github.com/nvidia/nccl]: ')
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
tf_nccl_version = get_from_env_or_user_or_default(
environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, '')
if not tf_nccl_version:
break # No need to get install path, building the open source code.
tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1)
# Look with ldconfig first if we can find the library in paths
# like /usr/lib/x86_64-linux-gnu and the header file in the corresponding
# include directory. This is where the NCCL .deb packages install them.
# First check to see if NCCL is in the ldconfig.
# If it's found, use that location.
if is_linux():
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
nccl2_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
nccl2_path_from_ldconfig = re.search('.*libnccl.so .* => (.*)',
nccl2_path_from_ldconfig)
if nccl2_path_from_ldconfig:
nccl2_path_from_ldconfig = nccl2_path_from_ldconfig.group(1)
if os.path.exists('%s.%s' % (nccl2_path_from_ldconfig, tf_nccl_version)):
nccl_install_path = os.path.dirname(nccl2_path_from_ldconfig)
print('NCCL libraries found in ' + nccl2_path_from_ldconfig)
# Check if this is the main system lib location
if re.search('.*linux-gnu', nccl_install_path):
trunc_nccl_install_path = '/usr'
print('This looks like a system path.')
else:
trunc_nccl_install_path = nccl_install_path + '/..'
# Look for header
nccl_hdr_path = trunc_nccl_install_path + '/include'
print('Assuming NCCL header path is ' + nccl_hdr_path)
if os.path.exists(nccl_hdr_path + '/nccl.h'):
# Set NCCL_INSTALL_PATH
environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path
write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path)
# Set NCCL_HDR_PATH
environ_cp['NCCL_HDR_PATH'] = nccl_hdr_path
write_action_env_to_bazelrc('NCCL_HDR_PATH', nccl_hdr_path)
break
else:
print(
'The header for NCCL2 cannot be found. Please install the libnccl-dev package.'
)
else:
print('NCCL2 is listed by ldconfig but the library is not found. '
'Your ldconfig is out of date. Please run sudo ldconfig.')
else:
# NCCL is not found in ldconfig. Ask the user for the location.
default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH')
ask_nccl_path = (
r'Please specify the location where NCCL %s library is '
'installed. Refer to README.md for more details. [Default '
'is %s]:') % (tf_nccl_version, default_nccl_path)
nccl_install_path = get_from_env_or_user_or_default(
environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path)
# Result returned from "read" will be used unexpanded. That makes "~"
# unusable. Going through one more level of expansion to handle that.
nccl_install_path = os.path.realpath(
os.path.expanduser(nccl_install_path))
if is_windows() or is_cygwin():
nccl_install_path = cygpath(nccl_install_path)
nccl_lib_path = ''
if is_windows():
nccl_lib_path = 'lib/x64/nccl.lib'
elif is_linux():
nccl_lib_filename = 'libnccl.so.%s' % tf_nccl_version
nccl_lpath = '%s/lib/%s' % (nccl_install_path, nccl_lib_filename)
if not os.path.exists(nccl_lpath):
for relative_path in NCCL_LIB_PATHS:
path = '%s/%s%s' % (nccl_install_path, relative_path,
nccl_lib_filename)
if os.path.exists(path):
print('NCCL found at ' + path)
nccl_lib_path = path
break
else:
nccl_lib_path = nccl_lpath
elif is_macos():
nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version
nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path)
nccl_hdr_path = os.path.join(
os.path.dirname(nccl_lib_path), '../include/nccl.h')
print('Assuming NCCL header path is ' + nccl_hdr_path)
if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path):
# Set NCCL_INSTALL_PATH
environ_cp['NCCL_INSTALL_PATH'] = os.path.dirname(nccl_lib_path)
write_action_env_to_bazelrc('NCCL_INSTALL_PATH',
os.path.dirname(nccl_lib_path))
# Set NCCL_HDR_PATH
environ_cp['NCCL_HDR_PATH'] = os.path.dirname(nccl_hdr_path)
write_action_env_to_bazelrc('NCCL_HDR_PATH',
os.path.dirname(nccl_hdr_path))
break
# Reset and Retry
print(
'Invalid path to NCCL %s toolkit, %s or %s not found. Please use the '
'O/S agnostic package of NCCL 2' %
(tf_nccl_version, nccl_lib_path, nccl_hdr_path))
environ_cp['TF_NCCL_VERSION'] = ''
else:
raise UserInputError('Invalid TF_NCCL setting was provided %d '
'times in a row. Assuming to be a scripting mistake.' %
_DEFAULT_PROMPT_ASK_ATTEMPTS)
# Set TF_NCCL_VERSION
'[Leave empty to use http://github.com/nvidia/nccl]: ')
tf_nccl_version = get_from_env_or_user_or_default(environ_cp,
'TF_NCCL_VERSION',
ask_nccl_version, '')
environ_cp['TF_NCCL_VERSION'] = tf_nccl_version
write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version)
def get_native_cuda_compute_capabilities(environ_cp):
"""Get native cuda compute capabilities.
@ -1607,6 +1307,66 @@ def configure_ios():
symlink_force(filepath, new_filepath)
def validate_cuda_config(environ_cp):
"""Run find_cuda_config.py and return cuda_toolkit_path, or None."""
def maybe_encode_env(env):
"""Encodes unicode in env to str on Windows python 2.x."""
if not is_windows() or sys.version_info[0] != 2:
return env
for k, v in env.items():
if isinstance(k, unicode):
k = k.encode('ascii')
if isinstance(v, unicode):
v = v.encode('ascii')
env[k] = v
return env
cuda_libraries = ['cuda', 'cudnn']
if is_linux():
if 'TF_TENSORRT_VERSION' in environ_cp: # if env variable exists
cuda_libraries.append('tensorrt')
if environ_cp.get('TF_NCCL_VERSION', None): # if env variable not empty
cuda_libraries.append('nccl')
proc = subprocess.Popen(
[environ_cp['PYTHON_BIN_PATH'], 'third_party/gpus/find_cuda_config.py'] +
cuda_libraries,
stdout=subprocess.PIPE,
env=maybe_encode_env(environ_cp))
if proc.wait():
# Errors from find_cuda_config.py were sent to stderr.
print('Asking for detailed CUDA configuration...\n')
return False
config = dict(
tuple(line.decode('ascii').rstrip().split(': ')) for line in proc.stdout)
print('Found CUDA %s in:' % config['cuda_version'])
print(' %s' % config['cuda_library_dir'])
print(' %s' % config['cuda_include_dir'])
print('Found cuDNN %s in:' % config['cudnn_version'])
print(' %s' % config['cudnn_library_dir'])
print(' %s' % config['cudnn_include_dir'])
if 'tensorrt_version' in config:
print('Found TensorRT %s in:' % config['tensorrt_version'])
print(' %s' % config['tensorrt_library_dir'])
print(' %s' % config['tensorrt_include_dir'])
if config.get('nccl_version', None):
print('Found NCCL %s in:' % config['nccl_version'])
print(' %s' % config['nccl_library_dir'])
print(' %s' % config['nccl_include_dir'])
print('\n')
environ_cp['CUDA_TOOLKIT_PATH'] = config['cuda_toolkit_path']
return True
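For reference, the dict comprehension above assumes `find_cuda_config.py` prints one `key: value` pair per stdout line. A tiny sketch with made-up sample lines (illustrative only, not real tool output):
```python
# Sketch: how validate_cuda_config turns the helper script's stdout into a config dict.
sample_stdout = [
    b'cuda_version: 10.0\n',                       # hypothetical sample lines
    b'cuda_library_dir: /usr/local/cuda/lib64\n',
    b'cudnn_version: 7\n',
]
config = dict(
    tuple(line.decode('ascii').rstrip().split(': ')) for line in sample_stdout)
print(config['cuda_version'])      # '10.0'
print(config['cuda_library_dir'])  # '/usr/local/cuda/lib64'
```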
def main():
global _TF_WORKSPACE_ROOT
global _TF_BAZELRC
@ -1627,7 +1387,7 @@ def main():
# environment variables.
environ_cp = dict(os.environ)
current_bazel_version = check_bazel_version('0.22.0', '0.24.0')
current_bazel_version = check_bazel_version('0.24.1', '0.25.0')
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
reset_tf_configure_bazelrc()
@ -1683,11 +1443,39 @@ def main():
set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)
if (environ_cp.get('TF_NEED_CUDA') == '1' and
'TF_CUDA_CONFIG_REPO' not in environ_cp):
set_tf_cuda_version(environ_cp)
set_tf_cudnn_version(environ_cp)
if is_linux():
set_tf_tensorrt_install_path(environ_cp)
set_tf_nccl_install_path(environ_cp)
set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False)
environ_save = dict(environ_cp)
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
if validate_cuda_config(environ_cp):
cuda_env_names = [
'TF_CUDA_VERSION', 'TF_CUBLAS_VERSION', 'TF_CUDNN_VERSION',
'TF_TENSORRT_VERSION', 'TF_NCCL_VERSION', 'TF_CUDA_PATHS',
'CUDA_TOOLKIT_PATH'
]
for name in cuda_env_names:
if name in environ_cp:
write_action_env_to_bazelrc(name, environ_cp[name])
break
# Restore settings changed below if CUDA config could not be validated.
environ_cp = dict(environ_save)
set_tf_cuda_version(environ_cp)
set_tf_cudnn_version(environ_cp)
if is_linux():
set_tf_tensorrt_version(environ_cp)
set_tf_nccl_version(environ_cp)
set_tf_cuda_paths(environ_cp)
else:
raise UserInputError(
'Invalid CUDA settings were provided %d '
'times in a row. Assuming to be a scripting mistake.' %
_DEFAULT_PROMPT_ASK_ATTEMPTS)
set_tf_cuda_compute_capabilities(environ_cp)
if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get(
@ -1755,10 +1543,8 @@ def main():
system_specific_test_config(os.environ)
if get_var(environ_cp, 'TF_CONFIGURE_IOS', 'Configure TensorFlow for iOS',
False, ('Would you like to configure TensorFlow for iOS builds?'),
'Configuring TensorFlow for iOS builds.',
'Not configuring TensorFlow for iOS builds.'):
set_action_env_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False)
if environ_cp.get('TF_CONFIGURE_IOS') == '1':
configure_ios()
else:
# TODO(pcloudy): Remove BAZEL_USE_CPP_ONLY_TOOLCHAIN after Bazel is upgraded

View File

@ -12,7 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in all of the public TensorFlow interface into this module."""
"""
Top-level module of TensorFlow. By convention, we refer to this module as
`tf` instead of `tensorflow`, following the common practice of importing
TensorFlow via the command `import tensorflow as tf`.
The primary function of this module is to import all of the public TensorFlow
interfaces into a single place. The interfaces themselves are located in
sub-modules, as described below.
Note that the file `__init__.py` in the TensorFlow source code tree is actually
only a placeholder to enable test cases to run. The TensorFlow build replaces
this file with a file generated from [`api_template.__init__.py`](https://www.github.com/tensorflow/tensorflow/blob/master/tensorflow/api_template.__init__.py)
"""
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
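The convention the new docstring describes, shown as a minimal usage sketch (editorial illustration; nothing here is specific to the template file itself):
```python
# Sketch: the import convention documented in the docstring above.
import tensorflow as tf

# All public interfaces hang off the single `tf` namespace, for example:
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
print(tf.reduce_sum(x))
```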

View File

@ -118,7 +118,11 @@ if _running_from_pip_package():
# pylint: disable=undefined-variable
try:
del python
if '__all__' in vars():
vars()['__all__'].remove('python')
del core
if '__all__' in vars():
vars()['__all__'].remove('core')
except NameError:
# Don't fail if these modules are not available.
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
@ -129,6 +133,8 @@ except NameError:
# others don't exist.
try:
del compiler
if '__all__' in vars():
vars()['__all__'].remove('compiler')
except NameError:
pass
# pylint: enable=undefined-variable

View File

@ -104,6 +104,7 @@ tf_cuda_library(
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"@com_google_absl//absl/strings",
"//tensorflow/cc/saved_model:loader_lite",
"//tensorflow/cc:gradients",
"//tensorflow/cc:ops",

View File

@ -799,8 +799,8 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
const auto& op_type = op->operation.Name();
auto op_name =
tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++);
auto* desc =
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str());
std::unique_ptr<TF_OperationDescription> desc(
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()));
VLOG(1) << "Adding attrs.";
tensorflow::AttrValueMap attrs;
@ -814,30 +814,42 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
size_t inputIndex = 0;
const tensorflow::OpDef& op_def = desc->node_builder.op_def();
for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) {
// TODO(bgogul): Add support for number attributes.
DCHECK(input_arg.number_attr().empty())
<< "Number attributes is not implemented yet.";
if (input_arg.type_list_attr().empty()) {
if (input_arg.type_list_attr().empty() && input_arg.number_attr().empty()) {
auto symbolic_input =
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
if (!status->status.ok()) return nullptr;
TF_AddInput(desc, symbolic_input);
TF_AddInput(desc.get(), symbolic_input);
continue;
}
const std::string& type_list_attr = input_arg.type_list_attr();
const auto& attr_value = attrs[type_list_attr];
DCHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
<< "Type list attribute should be a list!";
std::vector<TF_Output> list_inputs(attr_value.list().type_size());
size_t list_size = 0;
if (!input_arg.type_list_attr().empty()) {
const std::string& type_list_attr = input_arg.type_list_attr();
const auto& attr_value = attrs[type_list_attr];
CHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
<< "Type list attribute should be a list!";
list_size = attr_value.list().type_size();
} else {
CHECK(!input_arg.number_attr().empty());
const auto& attr_value = attrs[input_arg.number_attr()];
CHECK(attr_value.value_case() == tensorflow::AttrValue::kI)
<< "Number attribute should be int!";
if (attr_value.i() < 0) {
status->status = tensorflow::errors::Internal(
"Number attribute for length should be >=0!");
return nullptr;
}
list_size = attr_value.i();
}
std::vector<TF_Output> list_inputs(list_size);
for (TF_Output& list_input : list_inputs) {
list_input =
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
if (!status->status.ok()) return nullptr;
}
TF_AddInputList(desc, list_inputs.data(), list_inputs.size());
TF_AddInputList(desc.get(), list_inputs.data(), list_inputs.size());
}
auto* graph_op = TF_FinishOperation(desc, status);
auto* graph_op = TF_FinishOperation(desc.release(), status);
if (!status->status.ok()) return nullptr;
VLOG(1) << "Op finalized; setting return tensors.";

View File

@ -376,5 +376,60 @@ TEST_F(AddEagerOpToGraphTest, ListInputsAreAddedCorrectly) {
TFE_DeleteOp(identityn);
}
TEST_F(AddEagerOpToGraphTest, NumberAttributesAreHandledCorrectly) {
TFE_TensorHandle* matrix = TestMatrixTensorHandle();
TFE_TensorHandle* axis = TestAxisTensorHandle();
TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
TFE_OpSetAttrInt(concatv2, "N", 2);
TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);
constexpr size_t kNumInputs = 2;
for (size_t i = 0; i < kNumInputs; ++i) {
TFE_OpAddInput(concatv2, matrix, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
}
TFE_OpAddInput(concatv2, axis, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
AddEagerOpToGraphAndCheck(
concatv2, [this, kNumInputs](TF_Operation* graph_op) {
EXPECT_EQ(TF_OperationNumInputs(graph_op), kNumInputs + 1);
int64_t attrN;
TF_OperationGetAttrInt(graph_op, "N", &attrN, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
EXPECT_EQ(attrN, kNumInputs);
EXPECT_EQ(TF_OperationInputListLength(graph_op, "values", status_),
kNumInputs);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
});
TFE_DeleteTensorHandle(axis);
TFE_DeleteTensorHandle(matrix);
TFE_DeleteOp(concatv2);
}
TEST_F(AddEagerOpToGraphTest,
GeneratesInternalErrorsForInvalidNumberAttributes) {
TFE_TensorHandle* matrix = TestMatrixTensorHandle();
TFE_TensorHandle* axis = TestAxisTensorHandle();
int num_retvals = 5;
TFE_TensorHandle* retvals[5];
TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
TFE_OpSetAttrInt(concatv2, "N", -1);
TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);
TF_Operation* graph_op = TFE_AddEagerOpToGraph(concatv2, trace_ctx_, retvals,
&num_retvals, status_);
EXPECT_EQ(graph_op, nullptr);
EXPECT_EQ(status_->status.error_message(),
"Number attribute for length should be >=0!");
TFE_DeleteOp(concatv2);
TFE_DeleteTensorHandle(axis);
TFE_DeleteTensorHandle(matrix);
}
} // namespace
} // namespace tensorflow

View File

@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/c_api_internal.h"
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include "absl/strings/match.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -352,6 +352,16 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
argdef->set_type(node->output_type(idx));
const string& input_name = node_names.GetInputName(node->name());
argdef->set_name(input_name);
auto& arg_attrs = (*fdef->mutable_arg_attr())[i];
for (const auto& attr : node->attrs()) {
// Only copy internal attributes. These attributes will be applied to
// _Arg/Placeholder nodes when this FunctionDef is converted to graph, and
// normal attributes for nodes cannot be applied to those _Arg/Placeholder
// nodes.
if (absl::StartsWith(attr.first, "_")) {
arg_attrs.mutable_attr()->insert(attr);
}
}
tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
}

View File

@ -1278,6 +1278,46 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
EXPECT_EQ(func_->fdef.signature().attr(1).type(), "int");
}
void NodeWithAttrHelper(TF_Graph* graph, TF_Status* s, const char* name,
const char* attr_name, const char* attr_value,
TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
TF_SetAttrType(desc, "dtype", TF_INT32);
TF_SetAttrString(desc, attr_name, attr_value, strlen(attr_value));
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
TF_NewGraph(), TF_DeleteGraph);
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
TF_DeleteStatus);
TF_Operation* node;
NodeWithAttrHelper(func_graph.get(), s.get(), "node", "_test_attr", "value",
&node);
TF_Output inputs[] = {{node, 0}};
TF_Output outputs[] = {};
func_ = TF_GraphToFunction(
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
/*opers=*/nullptr, 1, inputs, 0, outputs,
/*output_names=*/nullptr,
/*opts=*/nullptr, /*description=*/nullptr, s.get());
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
ASSERT_NE(func_, nullptr);
// Verify that FunctionDef ArgDef has attributes.
ASSERT_EQ(func_->fdef.arg_attr_size(), 1);
auto arg_attrs = func_->fdef.arg_attr().find(0);
ASSERT_NE(arg_attrs, func_->fdef.arg_attr().end());
auto iter = arg_attrs->second.attr().find("_test_attr");
ASSERT_NE(iter, arg_attrs->second.attr().end());
EXPECT_EQ(iter->second.s(), "value");
}
TEST_F(CApiFunctionTest, SetGradientAndRun) {
// Define the function and its grad
DefineFunction(func_name_, &func_);

View File

@ -29,8 +29,7 @@ namespace checkpoint {
class TensorSliceReader;
CheckpointReader::CheckpointReader(const string& filename,
TF_Status* out_status)
CheckpointReader::CheckpointReader(const string& filename, TF_Status* status)
: reader_(nullptr),
v2_reader_(nullptr),
var_to_shape_map_(nullptr),
@ -43,7 +42,7 @@ CheckpointReader::CheckpointReader(const string& filename,
v2_reader_.reset(
new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
if (!v2_reader_->status().ok()) {
Set_TF_Status_from_Status(out_status, v2_reader_->status());
Set_TF_Status_from_Status(status, v2_reader_->status());
return;
}
auto result = BuildV2VarMaps();
@ -52,7 +51,7 @@ CheckpointReader::CheckpointReader(const string& filename,
} else {
reader_.reset(new TensorSliceReader(filename));
if (!reader_->status().ok()) {
Set_TF_Status_from_Status(out_status, reader_->status());
Set_TF_Status_from_Status(status, reader_->status());
return;
}
var_to_shape_map_.reset(

View File

@ -39,7 +39,7 @@ class TensorSliceReader;
// variables.
class CheckpointReader {
public:
CheckpointReader(const string& filepattern, TF_Status* out_status);
CheckpointReader(const string& filename, TF_Status* status);
bool HasTensor(const string& name) const;
const string DebugString() const;

View File

@ -70,6 +70,7 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/profiler/lib:profiler_eager_lib",
"//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core:gpu_runtime",
],
@ -110,6 +111,7 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/profiler/lib:profiler_eager_lib",
"//tensorflow/core/profiler/lib:profiler_session",
],
)
@ -200,6 +202,7 @@ tf_cuda_library(
"//conditions:default": [],
}) + [
"@com_google_absl//absl/memory",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
@ -236,7 +239,6 @@ tf_cuda_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/profiler:protos_all_cc",
"@com_google_absl//absl/strings",
],
)

View File

@ -143,7 +143,9 @@ tensorflow::Status CreateRemoteContexts(
request.mutable_server_def()->set_task_index(parsed_name.task);
request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
auto* eager_client = remote_eager_workers->GetClient(remote_worker);
tensorflow::eager::EagerClient* eager_client;
TF_RETURN_IF_ERROR(
remote_eager_workers->GetClient(remote_worker, &eager_client));
if (eager_client == nullptr) {
return tensorflow::errors::Internal(
"Cannot find a client for the given target:", remote_worker);

View File

@ -17,6 +17,12 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include "tensorflow/core/profiler/rpc/profiler_server.h"
@ -92,3 +98,123 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
num_tracing_attempts);
return s.ok();
}
static tensorflow::mutex gauges_map_lock(tensorflow::LINKER_INITIALIZED);
static std::unordered_map<string,
tensorflow::monitoring::Gauge<tensorflow::int64, 1>*>*
get_gauges_map() EXCLUSIVE_LOCKS_REQUIRED(gauges_map_lock) {
static std::unordered_map<
string, tensorflow::monitoring::Gauge<tensorflow::int64, 1>*>*
gauges_map = new std::unordered_map<
string, tensorflow::monitoring::Gauge<tensorflow::int64, 1>*>;
return gauges_map;
}
static tensorflow::mutex samplers_map_lock(tensorflow::LINKER_INITIALIZED);
static std::unordered_map<string, tensorflow::monitoring::Sampler<1>*>*
get_samplers_map() EXCLUSIVE_LOCKS_REQUIRED(samplers_map_lock) {
static std::unordered_map<string, tensorflow::monitoring::Sampler<1>*>*
samplers_map =
new std::unordered_map<string, tensorflow::monitoring::Sampler<1>*>;
return samplers_map;
}
void TFE_MonitoringSetGauge(const char* name, const char* label,
int64_t value) {
tensorflow::mutex_lock l(gauges_map_lock);
auto gauges_map = get_gauges_map();
if (gauges_map->find(name) == gauges_map->end()) {
gauges_map->emplace(
name, tensorflow::monitoring::Gauge<tensorflow::int64, 1>::New(
name,
tensorflow::strings::StrCat(
name, " :Gauge metric collected from Python API."),
"metric_descriptor"));
}
gauges_map->at(name)->GetCell(label)->Set(value);
}
void TFE_MonitoringAddSampler(const char* name, const char* label,
double value) {
tensorflow::mutex_lock l(samplers_map_lock);
auto samplers_map = get_samplers_map();
if (samplers_map->find(name) == samplers_map->end()) {
samplers_map->emplace(
name, tensorflow::monitoring::Sampler<1>::New(
{name,
tensorflow::strings::StrCat(
name, " :Counter metric collected from Python API."),
"metric_descriptor"},
{tensorflow::monitoring::Buckets::Exponential(1, 2, 30)}));
}
samplers_map->at(name)->GetCell(label)->Add(value);
}
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
int64_t value) {
cell->cell.IncrementBy(value);
}
int64_t TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell* cell) {
return cell->cell.value();
}
TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(const char* name,
TF_Status* status,
const char* description) {
auto* result = new TFE_MonitoringCounter0({name, description});
Set_TF_Status_from_Status(status, result->counter->GetStatus());
return result;
}
void TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0* counter) {
delete counter;
}
TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
TFE_MonitoringCounter0* counter) {
return static_cast<TFE_MonitoringCounterCell*>(
static_cast<void*>(counter->counter->GetCell()));
}
TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(const char* name,
TF_Status* status,
const char* description,
const char* label1) {
auto* result = new TFE_MonitoringCounter1({name, description, label1});
Set_TF_Status_from_Status(status, result->counter->GetStatus());
return result;
}
void TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1* counter) {
delete counter;
}
TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
TFE_MonitoringCounter1* counter, const char* label1) {
return static_cast<TFE_MonitoringCounterCell*>(
static_cast<void*>(counter->counter->GetCell(label1)));
}
TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(const char* name,
TF_Status* status,
const char* description,
const char* label1,
const char* label2) {
auto* result =
new TFE_MonitoringCounter2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->counter->GetStatus());
return result;
}
void TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2* counter) {
delete counter;
}
TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
TFE_MonitoringCounter2* counter, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringCounterCell*>(
static_cast<void*>(counter->counter->GetCell(label1, label2)));
}

View File

@ -87,6 +87,68 @@ TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
const char* service_addr, const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms, int num_tracing_attempts);
// Sets the value of a Gauge metric. If a metric with the given name does not
// exist, it will create a new Gauge metric. Right now only int64 values are
// supported; consider adding more types if needed.
TF_CAPI_EXPORT extern void TFE_MonitoringSetGauge(const char* name,
const char* label,
int64_t value);
// Adds the given value to a Sampler metric. If a metric with the given name
// does not exist, it will create a new Sampler metric.
TF_CAPI_EXPORT extern void TFE_MonitoringAddSampler(const char* name,
const char* label,
double value);
// -----------------------------------------------------------------------------
// Monitoring Counter APIs.
// These APIs provide de-templated monitoring Counters for SWIG.
typedef struct TFE_MonitoringCounterCell TFE_MonitoringCounterCell;
// Atomically increments the value of the cell. The value must be non-negative.
TF_CAPI_EXPORT extern void TFE_MonitoringCounterCellIncrementBy(
TFE_MonitoringCounterCell* cell, int64_t value);
// Retrieves the current value of the cell.
TF_CAPI_EXPORT extern int64_t TFE_MonitoringCounterCellValue(
TFE_MonitoringCounterCell* cell);
// APIs for Counter without label.
typedef struct TFE_MonitoringCounter0 TFE_MonitoringCounter0;
// Returns a new Counter metric object. The caller should manage the lifetime
// of the object. Using a duplicate metric name will crash the program with a
// fatal error.
TF_CAPI_EXPORT extern TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(
const char* name, TF_Status* status, const char* description);
// Deletes the Counter object.
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter0(
TFE_MonitoringCounter0* counter);
// Retrieves the cell from the Counter object. The Counter object will manage
// the lifetime of the cell.
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
TFE_MonitoringCounter0* counter);
// APIs for Counter with 1 label.
typedef struct TFE_MonitoringCounter1 TFE_MonitoringCounter1;
TF_CAPI_EXPORT extern TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(
const char* name, TF_Status* status, const char* description,
const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter1(
TFE_MonitoringCounter1* counter);
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
TFE_MonitoringCounter1* counter, const char* label1);
// APIs for Counter with 2 labels.
typedef struct TFE_MonitoringCounter2 TFE_MonitoringCounter2;
TF_CAPI_EXPORT extern TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(
const char* name, TF_Status* status, const char* description,
const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter2(
TFE_MonitoringCounter2* counter);
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
TFE_MonitoringCounter2* counter, const char* label1, const char* label2);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -16,14 +16,16 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include <string.h>
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/cc/profiler/profiler.h"
#include "tensorflow/core/lib/monitoring/collection_registry.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/profiler/trace_events.pb.h"
#include "tensorflow/core/protobuf/trace_events.pb.h"
using tensorflow::string;
@ -79,11 +81,15 @@ void ExecuteWithProfiling(bool async) {
profiler_result->length}));
string profile_proto_str = profile_proto.DebugString();
if (!gpu_device_name.empty()) {
EXPECT_TRUE(HasSubstr(profile_proto_str, "GPU:0"));
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
// device name with "stream:all" is collected by Device Tracer.
EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all"));
// TODO(fishx): move the following check out of this if statement.
// This is collected by TraceMe.
EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));
}
EXPECT_TRUE(HasSubstr(profile_proto_str, "CPU:0"));
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:CPU:0"));
EXPECT_TRUE(HasSubstr(profile_proto_str, "MatMul"));
TF_DeleteBuffer(profiler_result);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
@ -125,5 +131,94 @@ TEST(CAPI, MultipleProfilerSession) {
TFE_DeleteProfilerContext(profiler_context);
}
TEST(CAPI, MonitoringSetGauge) {
TFE_MonitoringSetGauge("test/gauge", "label", 1);
auto* collection_registry = monitoring::CollectionRegistry::Default();
monitoring::CollectionRegistry::CollectMetricsOptions options;
std::unique_ptr<monitoring::CollectedMetrics> metrics =
collection_registry->CollectMetrics(options);
EXPECT_EQ("test/gauge", metrics->point_set_map.at("test/gauge")->metric_name);
EXPECT_EQ(1,
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
TFE_MonitoringSetGauge("test/gauge", "label", 5);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(5,
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
}
TEST(CAPI, MonitoringCounter0) {
TF_Status* status = TF_NewStatus();
auto* counter =
TFE_MonitoringNewCounter0("test/counter", status, "description");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
auto* cell = TFE_MonitoringGetCellCounter0(counter);
TFE_MonitoringCounterCellIncrementBy(cell, 1);
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell), 1);
auto* collection_registry = monitoring::CollectionRegistry::Default();
monitoring::CollectionRegistry::CollectMetricsOptions options;
std::unique_ptr<monitoring::CollectedMetrics> metrics =
collection_registry->CollectMetrics(options);
EXPECT_EQ("test/counter",
metrics->point_set_map.at("test/counter")->metric_name);
EXPECT_EQ(
1, metrics->point_set_map.at("test/counter")->points.at(0)->int64_value);
TFE_MonitoringCounterCellIncrementBy(cell, 5);
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell), 6);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(
6, metrics->point_set_map.at("test/counter")->points.at(0)->int64_value);
TFE_MonitoringDeleteCounter0(counter);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(metrics->point_set_map.end(),
metrics->point_set_map.find("test/counter"));
}
TEST(CAPI, MonitoringCounterMultiple) {
TF_Status* status = TF_NewStatus();
auto* counter1 = TFE_MonitoringNewCounter1("test/counter1", status,
"description", "label1");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell1 = TFE_MonitoringGetCellCounter1(counter1, "test");
TFE_MonitoringCounterCellIncrementBy(cell1, 1);
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell1), 1);
auto* counter2 = TFE_MonitoringNewCounter2("test/counter2", status,
"description", "label1", "label2");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
auto* cell2 = TFE_MonitoringGetCellCounter2(counter2, "foo", "bar");
TFE_MonitoringCounterCellIncrementBy(cell2, 2);
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell2), 2);
TFE_MonitoringDeleteCounter1(counter1);
TFE_MonitoringDeleteCounter2(counter2);
}
TEST(CAPI, MonitoringAddSampler) {
TFE_MonitoringAddSampler("test/sampler", "label", 1.0);
auto* collection_registry = monitoring::CollectionRegistry::Default();
monitoring::CollectionRegistry::CollectMetricsOptions options;
std::unique_ptr<monitoring::CollectedMetrics> metrics =
collection_registry->CollectMetrics(options);
EXPECT_EQ("test/sampler",
metrics->point_set_map.at("test/sampler")->metric_name);
EXPECT_EQ(1.0, metrics->point_set_map.at("test/sampler")
->points.at(0)
->histogram_value.sum());
TFE_MonitoringAddSampler("test/sampler", "label", 5.0);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(6.0, metrics->point_set_map.at("test/sampler")
->points.at(0)
->histogram_value.sum());
}
} // namespace
} // namespace tensorflow
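
For orientation outside the test harness, here is a minimal sketch of the counter portion of the new monitoring C API, using only the functions declared in c_api_experimental.h above. The metric name "example/requests", the label "kind", and the wrapper function CountRequest are hypothetical; in real use the counter would be created once (re-registering the same name fails) rather than on every call.

```cpp
#include "tensorflow/c/eager/c_api_experimental.h"  // declares the TFE_Monitoring* functions used below

// Hypothetical helper: counts one request of the given kind.
void CountRequest(const char* request_kind) {
  TF_Status* status = TF_NewStatus();
  // One label ("kind") distinguishes the counted events.
  TFE_MonitoringCounter1* counter = TFE_MonitoringNewCounter1(
      "example/requests", status, "Number of requests seen.", "kind");
  bool ok = TF_GetCode(status) == TF_OK;
  TF_DeleteStatus(status);
  if (!ok) return;

  // Fetch the cell for this label value and bump it.
  TFE_MonitoringCounterCell* cell =
      TFE_MonitoringGetCellCounter1(counter, request_kind);
  TFE_MonitoringCounterCellIncrementBy(cell, 1);

  // The current value can be read back, e.g. for debugging.
  auto value = TFE_MonitoringCounterCellValue(cell);
  (void)value;

  // Deleting the counter unregisters it from the collection registry, as the
  // MonitoringCounter0 test above demonstrates.
  TFE_MonitoringDeleteCounter1(counter);
}
```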

View File

@ -15,8 +15,6 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
#define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
#include "tensorflow/c/eager/c_api.h"
#include <algorithm>
#include <cstddef>
#include <map>
@ -28,6 +26,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@ -50,6 +49,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
@ -133,6 +133,32 @@ struct TFE_Profiler {
std::unique_ptr<tensorflow::ProfilerSession> profiler;
};
struct TFE_MonitoringCounterCell {
tensorflow::monitoring::CounterCell cell;
};
template <int NumLabels>
struct TFE_MonitoringCounter {
template <typename... LabelDesc>
TFE_MonitoringCounter(const char* name, const char* description,
LabelDesc&&... label) {
counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
name, description, label...));
}
std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
};
struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
using TFE_MonitoringCounter::TFE_MonitoringCounter;
};
namespace tensorflow {
// Set an AttrValue on the op. Doesn't handle the list types.
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,

View File

@ -21,7 +21,6 @@ from __future__ import print_function as _print_function
import os as _os
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
# API IMPORTS PLACEHOLDER

View File

@ -163,7 +163,10 @@ def tf_library(
header_file = name + ".h"
metadata_object_file = name + "_tfcompile_metadata.o"
function_object_file = name + "_tfcompile_function.o"
ep = ("__" + native.package_name() + "__" + name).replace("/", "_")
# The XLA backends morph kernel names whose "__" prefix is not of the form
# "__xla_", so the entry point is prefixed with "__xla_" here.
ep = ("__xla_" + native.package_name() + "__" + name).replace("/", "_")
if type(tfcompile_flags) == type(""):
flags = tfcompile_flags
else:

View File

@ -371,6 +371,7 @@ cc_library(
srcs = ["resource_operation_safety_analysis.cc"],
hdrs = ["resource_operation_safety_analysis.h"],
deps = [
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/core:framework",
@ -521,6 +522,7 @@ cc_library(
":device_info_cache",
":encapsulate_util",
":flags",
":resource_operation_safety_analysis",
":shape_inference_helpers",
":union_find",
":xla_cluster_util",
@ -565,7 +567,6 @@ cc_library(
hdrs = ["xla_cluster_util.h"],
deps = [
":flags",
":resource_operation_safety_analysis",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@ -732,44 +733,6 @@ tf_cc_test(
],
)
cc_library(
name = "xla_fusion_optimizer",
srcs = ["xla_fusion_optimizer.cc"],
hdrs = ["xla_fusion_optimizer.h"],
visibility = ["//visibility:public"],
deps = [
":common",
":compilation_passes",
":union_find",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"@com_google_absl//absl/strings",
],
)
tf_cuda_cc_test(
name = "xla_fusion_optimizer_test",
srcs = ["xla_fusion_optimizer_test.cc"],
deps = [
":common",
":xla_cluster_util",
":xla_fusion_optimizer",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/core:graph",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler/utils:grappler_test",
],
)
cc_library(
name = "node_matchers",
testonly = True,

View File

@ -106,6 +106,8 @@ namespace tensorflow {
namespace {
using se::port::StatusOr;
// Represents a logical predicate, used as described in the algorithm overview
// above.
class Predicate {
@ -698,7 +700,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
Status Populate();
Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
bool HasInputsWithMismatchingDeadness(const Node& node) override;
StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
Node* n, int oidx) const override;
void Print() const override;
absl::flat_hash_map<TensorId, string, TensorId::Hasher> PredicateMapAsString()
const;
@ -1113,42 +1116,13 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
return Status::OK();
}
bool DeadnessAnalysisImpl::HasInputsWithMismatchingDeadness(const Node& node) {
CHECK(!node.IsMerge());
if (vlog_) {
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() << ")";
}
Predicate* pred = nullptr;
for (const Edge* edge : node.in_edges()) {
auto it = predicate_map_.find(InputEdgeToTensorId(edge));
CHECK(it != predicate_map_.end());
if (vlog_) {
VLOG(2) << " " << InputEdgeToTensorId(edge).ToString() << ": "
<< it->second->ToString();
}
// Today we just compare the predicates for equality (with some
// canonicalization/simplification happening before) but we could be more
// sophisticated here if need be. Comparing pointers is sufficient because
// we intern Predicate instances by their content.
if (pred != nullptr && pred != it->second) {
if (vlog_) {
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
<< ") -> true";
}
return true;
}
pred = it->second;
}
if (vlog_) {
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
<< ") -> false";
}
return false;
StatusOr<DeadnessAnalysis::DeadnessPredicate>
DeadnessAnalysisImpl::GetPredicateFor(Node* n, int oidx) const {
auto it = predicate_map_.find(TensorId(n->name(), oidx));
TF_RET_CHECK(it != predicate_map_.end())
<< "could not find " << TensorId(n->name(), oidx).ToString()
<< " in predicate map";
return MakeDeadnessPredicate(it->second);
}
void DeadnessAnalysisImpl::Print() const {

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
@ -43,14 +44,38 @@ namespace tensorflow {
// "liveness" already has other connotations.
class DeadnessAnalysis {
public:
// Returns true if `node` may have some live inputs and some dead inputs.
//
// This is a conservatively correct routine -- if it returns false then `node`
// is guaranteed to not have inputs with mismatching liveness, but not the
// converse.
//
// REQUIRES: node is not a Merge operation.
virtual bool HasInputsWithMismatchingDeadness(const Node& node) = 0;
// An opaque representation of a predicate. DeadnessPredicate
// instances that compare equal via operator== represent predicates
// that always evaluate to the same value.
struct DeadnessPredicate {
public:
DeadnessPredicate(const DeadnessPredicate&) = default;
DeadnessPredicate(DeadnessPredicate&&) = default;
DeadnessPredicate& operator=(const DeadnessPredicate&) = default;
DeadnessPredicate& operator=(DeadnessPredicate&&) = default;
bool operator==(const DeadnessPredicate& other) const {
return other.pred_ == pred_;
}
bool operator!=(const DeadnessPredicate& other) const {
return other.pred_ != pred_;
}
private:
explicit DeadnessPredicate(void* pred) : pred_(pred) {}
// This is really a Predicate*, but we don't want to expose that
// implementation detail to our clients. `pred_` has pointer equality so we
// can just compare the pointer in operator== and operator!=.
void* pred_;
friend class DeadnessAnalysis;
};
virtual se::port::StatusOr<DeadnessPredicate> GetPredicateFor(
Node* n, int oidx) const = 0;
// Prints out the internal state of this instance. For debugging purposes
// only.
@ -61,6 +86,11 @@ class DeadnessAnalysis {
// instance of DeadnessAnalysis in `result`.
static Status Run(const Graph& graph,
std::unique_ptr<DeadnessAnalysis>* result);
protected:
static DeadnessPredicate MakeDeadnessPredicate(void* pred) {
return DeadnessPredicate(pred);
}
};
} // namespace tensorflow
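
The updated tests below wrap this query into a helper; as a minimal standalone sketch of the predicate-based API (assuming a Graph `graph` plus two Node pointers `a` and `b` supplied by the caller, comparing their 0th outputs):

```cpp
// Sketch only: report whether the 0th outputs of `a` and `b` may have
// mismatching deadness.
Status CompareDeadness(const Graph& graph, Node* a, Node* b, bool* mismatch) {
  std::unique_ptr<DeadnessAnalysis> deadness;
  TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness));

  TF_ASSIGN_OR_RETURN(DeadnessAnalysis::DeadnessPredicate pred_a,
                      deadness->GetPredicateFor(a, /*oidx=*/0));
  TF_ASSIGN_OR_RETURN(DeadnessAnalysis::DeadnessPredicate pred_b,
                      deadness->GetPredicateFor(b, /*oidx=*/0));

  // Equal predicates are guaranteed to always evaluate to the same value, so
  // inequality means the two outputs may differ in liveness.
  *mismatch = (pred_a != pred_b);
  return Status::OK();
}
```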

View File

@ -37,6 +37,22 @@ limitations under the License.
namespace tensorflow {
namespace {
se::port::StatusOr<bool> HasInputsWithMismatchingDeadness(
const DeadnessAnalysis& deadness_analysis, const Node& n) {
absl::optional<DeadnessAnalysis::DeadnessPredicate> pred;
for (const Edge* edge : n.in_edges()) {
TF_ASSIGN_OR_RETURN(
DeadnessAnalysis::DeadnessPredicate this_pred,
deadness_analysis.GetPredicateFor(edge->src(), edge->src_output()));
if (pred && *pred != this_pred) {
return true;
}
pred = this_pred;
}
return false;
}
using deadness_analysis_internal::ComputePredicates;
using deadness_analysis_internal::PredicateMapTy;
@ -219,7 +235,10 @@ TEST(DeadnessAnalysisTest, BasicPositive) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, BasicNegative) {
@ -232,7 +251,10 @@ TEST(DeadnessAnalysisTest, BasicNegative) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, AndIsCommutative) {
@ -260,11 +282,27 @@ TEST(DeadnessAnalysisTest, AndIsCommutative) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));
bool has_inputs_with_mismatching_deadness;
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *live0.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *live1.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *halfdead0.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *halfdead1.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, AndIsAssociative) {
@ -287,7 +325,10 @@ TEST(DeadnessAnalysisTest, AndIsAssociative) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, OrIsCommutative) {
@ -312,11 +353,27 @@ TEST(DeadnessAnalysisTest, OrIsCommutative) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));
bool has_inputs_with_mismatching_deadness;
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *live0.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *live1.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *halfdead0.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *halfdead1.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, OrIsAssociative) {
@ -336,7 +393,10 @@ TEST(DeadnessAnalysisTest, OrIsAssociative) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, AndOfOr) {
@ -358,7 +418,10 @@ TEST(DeadnessAnalysisTest, AndOfOr) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add2.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, OrOfAnd) {
@ -382,7 +445,10 @@ TEST(DeadnessAnalysisTest, OrOfAnd) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add2.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) {
@ -430,7 +496,10 @@ TEST(DeadnessAnalysisTest, AndOrDistributive) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add3.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add3.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, Ternary) {
@ -454,7 +523,10 @@ TEST(DeadnessAnalysisTest, Ternary) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, Recv) {
@ -469,7 +541,10 @@ TEST(DeadnessAnalysisTest, Recv) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, HostRecv) {
@ -484,7 +559,10 @@ TEST(DeadnessAnalysisTest, HostRecv) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, Loop) {
@ -505,8 +583,17 @@ TEST(DeadnessAnalysisTest, Loop) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node()));
bool has_inputs_with_mismatching_deadness;
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add0.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
TF_ASSERT_OK_AND_ASSIGN(
has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add1.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
{
PredicateMapTy predicate_map;
@ -544,7 +631,10 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add0.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
{
PredicateMapTy predicate_map;
@ -634,7 +724,10 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add0.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
{
PredicateMapTy predicate_map;
@ -693,7 +786,10 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add0.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
{
@ -792,7 +888,10 @@ TEST(DeadnessAnalysisTest, ControlInputs) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, ControlTrigger) {
@ -819,7 +918,10 @@ TEST(DeadnessAnalysisTest, ControlTrigger) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
@ -840,7 +942,10 @@ TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *add.node()));
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, RecvVsSwitch) {
@ -857,7 +962,10 @@ TEST(DeadnessAnalysisTest, RecvVsSwitch) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node()));
TF_ASSERT_OK_AND_ASSIGN(
bool has_inputs_with_mismatching_deadness,
HasInputsWithMismatchingDeadness(*result, *logical_and.node()));
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
TEST(DeadnessAnalysisTest, RecvVsSwitchText) {

View File

@ -599,7 +599,8 @@ Status ConstructHostGraph(
Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
FunctionLibraryDefinition* fld,
const string& host_graph_func_name,
Node* xla_computation_node) {
Node* xla_computation_node,
Node* pivot_node) {
// Temporarily use "0" as "device_ordinal". It will be rewritten with the
// correct value in a later pass. We cannot just use placeholder value here
// because FunctionDef instantiation does not allow placeholder value for
@ -620,7 +621,11 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
// Copy all nodes.
std::map<const Node*, Node*> node_map;
node_map[host_graph->source_node()] = main_graph->source_node();
if (pivot_node) {
node_map[host_graph->source_node()] = pivot_node;
} else {
node_map[host_graph->source_node()] = main_graph->source_node();
}
node_map[host_graph->sink_node()] = main_graph->sink_node();
Status s = Status::OK();
auto copy_node_fn = [&](const Node* n) {
@ -673,7 +678,7 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
// 2) Remove control edges.
// 3) Prune nodes that are not useful for shape inference.
Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
Graph* host_graph,
Graph* host_graph, Node* pivot_node,
FunctionLibraryDefinition* fld) {
// Use "0" as "device_ordinal". It does not matter for shape inference.
AttrValue device_ordinal_attr;
@ -717,41 +722,45 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
for (Node* n : nodes) {
g->RemoveNode(n);
}
std::map<const Node*, Node*> node_map;
node_map[host_graph->source_node()] = g->source_node();
Status s;
auto copy_node_fn = [&](const Node* n) {
if (!s.ok()) {
return;
}
if (node_map.find(n) != node_map.end()) {
return;
}
NodeDef copy_def = n->def();
Node* copy = g->AddNode(copy_def, &s);
if (!s.ok()) {
return;
}
for (auto e : n->in_edges()) {
if (node_map.find(e->src()) == node_map.end()) {
s = errors::Internal("Cannot find node image for ",
e->src()->DebugString());
return;
}
g->AddEdge(node_map[e->src()], e->src_output(), copy, e->dst_input());
}
node_map[n] = copy;
Node* start_node = pivot_node ? pivot_node : host_graph->source_node();
// Reverse DFS from send_from_host_main_graph, and stop at start_node.
struct Visit {
Node* n;
bool is_exiting;
};
// TODO(b/77601805): consolidate copy graph functions.
ReverseDFSFrom(*host_graph,
std::vector<const Node*>{send_from_host_main_graph},
/*enter=*/nullptr, copy_node_fn, NodeComparatorID());
if (!s.ok()) {
return s;
std::vector<Visit> stack{{send_from_host_main_graph, false}};
std::map<Node*, Node*> node_map;
node_map[host_graph->source_node()] = g->source_node();
while (!stack.empty()) {
Visit& curr = stack.back();
if (curr.is_exiting) {
if (node_map.find(curr.n) == node_map.end()) {
Node* copy = g->CopyNode(curr.n);
if (curr.n != start_node) {
for (const Edge* e : curr.n->in_edges()) {
auto node_iter = node_map.find(e->src());
if (node_iter == node_map.end()) {
return errors::Internal("Cannot find node image for ",
e->src()->DebugString());
}
g->AddEdge(node_iter->second, e->src_output(), copy,
e->dst_input());
}
}
node_map[curr.n] = copy;
}
stack.pop_back();
} else {
curr.is_exiting = true;
if (curr.n != start_node) {
for (const Edge* e : curr.n->in_edges()) {
if (node_map.find(e->src()) != node_map.end()) {
continue;
}
stack.push_back({e->src(), false});
}
}
}
}
send_from_host = node_map[send_from_host_main_graph];
@ -1687,13 +1696,14 @@ Status ExtractOutsideCompilation(
DumpGraphToFile("extract_outside_compilation_before", *g, fld);
}
std::vector<string> shape_inference_graphs;
auto node_name_index = g->BuildNodeNameIndex();
for (auto& iter : clusters) {
string xla_cluster_name = iter.first;
Node* n = iter.second.node;
auto const& func_name_attrs = iter.second.func_name_attrs;
auto const& host_compute_core = iter.second.host_compute_core;
std::vector<string> shape_inference_graphs;
bool has_outside_compilation;
string host_graph_func_name = absl::StrCat("oc_host_graph_", n->name());
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
@ -1701,14 +1711,18 @@ Status ExtractOutsideCompilation(
func_name_attrs, func_name_attrs.name(), host_graph_func_name,
host_compute_core, flr, fld, &shape_inference_graphs,
&has_outside_compilation));
TF_RETURN_IF_ERROR(
ExpandHostGraphIntoMainGraph(g, fld, host_graph_func_name, n));
TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
}
for (auto shape_inference_graph_name : shape_inference_graphs) {
TF_RETURN_IF_ERROR(
RewriteShapeInferenceGraph(shape_inference_graph_name, g, fld));
string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
Node* pivot_node = node_name_index[pivot_name];
TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(
g, fld, host_graph_func_name, n, pivot_node));
TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
for (auto shape_inference_graph_name : shape_inference_graphs) {
TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(shape_inference_graph_name,
g, pivot_node, fld));
}
}
if (VLOG_IS_ON(4)) {

File diff suppressed because it is too large

View File

@ -41,7 +41,8 @@ class MarkForCompilationPass : public GraphOptimizationPass {
Status Run(const GraphOptimizationPassOptions& options) override;
private:
Status RunForTest(const GraphOptimizationPassOptions& options);
Status RunForTest(const GraphOptimizationPassOptions& options,
bool disable_deadness_analysis);
friend class MarkForCompilationPassTestHelper;
};

View File

@ -525,8 +525,8 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(
MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
&graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
auto clusters = GetClusters(*graph);
// The computation is: C = A + relu(A)
@ -564,8 +564,8 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(
MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
&graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
auto clusters = GetClusters(*graph);
// The computation is: D = relu(A) + (A @ relu(A))
@ -598,8 +598,8 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
TF_ASSERT_OK(
MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
&graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
auto clusters = GetClusters(*graph);
// The computation is: C = A @ relu(A)
@ -610,6 +610,77 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
EXPECT_EQ(clusters["B"], clusters["C"]);
}
TEST(XlaCompilationTest, DontClusterNodesWithMismatchingDeadness) {
Scope root = Scope::NewRootScope().ExitOnError();
Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL);
Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL);
Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
ops::Switch switch_a(root.WithOpName("switch_a"), value, cond_a);
ops::Switch switch_b(root.WithOpName("switch_b"), value, cond_b);
Output tanh_a0 = ops::Tanh(root.WithOpName("tan_a0"), switch_a.output_true);
Output tanh_a1 = ops::Tanh(root.WithOpName("tan_a1"), tanh_a0);
Output tanh_b0 = ops::Tanh(root.WithOpName("tan_b0"), switch_b.output_true);
Output tanh_b1 = ops::Tanh(root.WithOpName("tan_b1"), tanh_b0);
Output add = ops::Add(root.WithOpName("add"), tanh_a1, tanh_b1);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_EXPECT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
&graph,
MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
auto clusters = GetClusters(*graph);
EXPECT_NE(clusters["tan_a0"], "");
EXPECT_NE(clusters["tan_a1"], "");
EXPECT_NE(clusters["tan_b0"], "");
EXPECT_NE(clusters["tan_b1"], "");
EXPECT_EQ(clusters["tan_a0"], clusters["tan_a1"]);
EXPECT_EQ(clusters["tan_b0"], clusters["tan_b1"]);
EXPECT_NE(clusters["tan_a0"], clusters["tan_b0"]);
}
TEST(XlaCompilationTest, ClusterNodesWithMismatchingInputDeadness) {
Scope root = Scope::NewRootScope().ExitOnError();
Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL);
Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL);
Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
ops::Switch switch_a(root.WithOpName("switch_a"), value, cond_a);
ops::Switch switch_b(root.WithOpName("switch_b"), value, cond_b);
Output add_a = ops::Add(root.WithOpName("add_a"), switch_a.output_true,
switch_b.output_true);
Output add_b = ops::Add(root.WithOpName("add_b"), switch_a.output_true,
switch_b.output_true);
Output add = ops::Add(root.WithOpName("add_c"), add_a, add_b);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_EXPECT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
&graph,
MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
auto clusters = GetClusters(*graph);
EXPECT_NE(clusters["add_a"], "");
EXPECT_NE(clusters["add_b"], "");
EXPECT_NE(clusters["add_c"], "");
EXPECT_EQ(clusters["add_a"], clusters["add_b"]);
EXPECT_EQ(clusters["add_b"], clusters["add_c"]);
}
namespace {
Node* MakeRead(const Scope& scope, const string& id,
Node** var_handle_op = nullptr) {
@ -703,7 +774,7 @@ TEST(XlaCompilationTest, ChainOfOps) {
ASSERT_EQ(cluster_sets.size(), 1);
std::vector<string> expected_clustered_nodes_a = {
"AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"};
"AssignmentW1", "ConstN0", "ReadR0", "ValueToAssignW1"};
ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
}

View File

@ -21,7 +21,7 @@ limitations under the License.
namespace tensorflow {
/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
bool enable_global_jit) {
MarkForCompilationPassTestHelper::Options options) {
// Assign all unassigned nodes to the CPU device.
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
for (Node* n : (*graph)->nodes()) {
@ -31,7 +31,7 @@ namespace tensorflow {
}
SessionOptions session_options;
if (enable_global_jit) {
if (options.enable_global_jit) {
session_options.config.mutable_graph_options()
->mutable_optimizer_options()
->set_global_jit_level(OptimizerOptions::ON_2);
@ -49,13 +49,16 @@ namespace tensorflow {
opt_options.session_options = &session_options;
opt_options.flib_def = flib_def;
MarkForCompilationPass pass;
return pass.RunForTest(opt_options);
return pass.RunForTest(
opt_options,
/*disable_deadness_analysis=*/options.disable_deadness_analysis);
}
/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
std::unique_ptr<Graph>* graph, bool enable_global_jit) {
std::unique_ptr<Graph>* graph,
MarkForCompilationPassTestHelper::Options options) {
FunctionDefLibrary flib;
FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
return MarkForCompilation(graph, &flib_def, enable_global_jit);
return MarkForCompilation(graph, &flib_def, options);
}
} // namespace tensorflow

View File

@ -21,16 +21,35 @@ limitations under the License.
namespace tensorflow {
class MarkForCompilationPassTestHelper {
public:
struct Options {
bool enable_global_jit;
bool disable_deadness_analysis;
Options() : enable_global_jit(true), disable_deadness_analysis(true) {}
Options WithNoGlobalJit() {
Options copy = *this;
copy.enable_global_jit = false;
return copy;
}
Options WithDeadnessAnalysis() {
Options copy = *this;
copy.disable_deadness_analysis = false;
return copy;
}
};
// Runs the MarkForCompilation pass on `graph` after assigning all nodes in
// `graph` to the CPU device. To make testing easier, ignores device
// registration, _XlaCompile attributes and input deadness.
// registration and _XlaCompile attributes.
static Status MarkForCompilation(std::unique_ptr<Graph>* graph,
FunctionLibraryDefinition* flib_def,
bool enable_global_jit = true);
Options options = Options());
// Like `MarkForCompilation` but creates `flib_def` from the op registry.
static Status MarkForCompilation(std::unique_ptr<Graph>* graph,
bool enable_global_jit = true);
Options options = Options());
};
} // namespace tensorflow
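
Because each With* method returns a modified copy, the options compose by chaining. A hedged sketch of a test invocation built on these helpers (graph construction elided):

```cpp
// Sketch (inside a test body): disable global jit and enable deadness analysis.
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
// ... build `graph` ...
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
    &graph, MarkForCompilationPassTestHelper::Options()
                .WithNoGlobalJit()
                .WithDeadnessAnalysis()));
```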

View File

@ -84,6 +84,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/algorithm.h"
@ -93,22 +94,6 @@ limitations under the License.
namespace tensorflow {
namespace {
// Returns true if `n` may call a function.
Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def,
bool* out_result) {
if (flib_def->Contains(n.type_string())) {
*out_result = true;
} else {
*out_result =
std::any_of(n.def().attr().begin(), n.def().attr().end(),
[](const std::pair<string, AttrValue>& name_attr_pair) {
return name_attr_pair.second.has_func();
});
}
return Status::OK();
}
// Maps `n` to the XlaResourceOpKind corresponding to its operation. If `n` is
// not a resource operation recognized by XLA then sets `out_resource_op_kind`
// to nullopt.
@ -134,9 +119,7 @@ Status XlaResourceOpKindForNode(
// We conservatively assume that functions will both read and write resource
// variables. In the future we may consider doing some form of
// inter-procedural analysis.
bool may_call_function;
TF_RETURN_IF_ERROR(MayCallFunction(n, flib_def, &may_call_function));
if (may_call_function) {
if (MayCallFunction(n, flib_def)) {
*out_resource_op_kind = XlaResourceOpKind::kReadWrite;
} else {
*out_resource_op_kind = absl::nullopt;

View File

@ -25,7 +25,6 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -227,28 +226,6 @@ void RemoveFromXlaCluster(NodeDef* node_def) {
void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }
Status AdjustCycleDetectionGraphForResourceOps(
const Graph* graph, const FunctionLibraryDefinition* flib_def,
const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
GraphCycles* cycles) {
std::vector<std::pair<int, int>> unsafe_deps;
TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
*graph, flib_def, resource_ops_to_ignore, &unsafe_deps));
// An edge {P,Q} in `unsafe_deps` denotes that P and Q, both of which are
// operations that interact with resource variables, must not be put in the
// same cluster. We enforce this constraint by creating a phantom node, X,
// and adding edges P->X and X->Q. MarkForCompilation then cannot cluster P
// and Q together since that would create a cycle with X.
for (std::pair<int, int> unsafe_dep : unsafe_deps) {
int phantom_node_id = cycles->NewNode();
CHECK(cycles->InsertEdge(unsafe_dep.first, phantom_node_id));
CHECK(cycles->InsertEdge(phantom_node_id, unsafe_dep.second));
}
return Status::OK();
}
Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
bool allow_mixing_unknown_and_cpu,
bool* out_can_pick_device,
@ -436,4 +413,16 @@ OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph(
return result;
}
bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def) {
if (flib_def->Contains(n.type_string())) {
return true;
}
// This is a conservative check: there may be nodes with a `func`
// attribute that do not make function calls.
return absl::c_any_of(n.def().attr(),
[](const std::pair<string, AttrValue>& name_attr_pair) {
return name_attr_pair.second.has_func();
});
}
} // namespace tensorflow

View File

@ -74,13 +74,6 @@ void RemoveFromXlaCluster(Node* node);
// Returns true if `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node);
// Adds edges to `cycles` to prevent clustering resource operations that cannot
// be legally clustered.
Status AdjustCycleDetectionGraphForResourceOps(
const Graph* graph, const FunctionLibraryDefinition* flib_def,
const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
GraphCycles* cycles);
// Picks the device for which XLA should compile a cluster that contains
// operations placed in devices in `device_names`. For instance a cluster that
// contains operations solely placed on the CPU will be compiled into a CPU
@ -134,6 +127,10 @@ OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph(
// Returns true if `g` is a single-GPU graph. A single-GPU graph uses exactly
// one GPU (and any number of CPUs).
bool IsSingleGpuGraph(const Graph& g);
// Returns true if it is possible (but not guaranteed) that `n` calls a
// function.
bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_

View File

@ -30,10 +30,17 @@ namespace tensorflow {
class XlaCpuDeviceFactory : public DeviceFactory {
public:
Status ListPhysicalDevices(std::vector<string>* devices) override;
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override;
};
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
return Status::OK();
}
Status XlaCpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
@ -46,7 +53,13 @@ Status XlaCpuDeviceFactory::CreateDevices(
compile_on_demand
? XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested
: XlaOpRegistry::AutoclusteringPolicy::kAlways;
registration.compile_all_resource_ops = true;
registration.cluster_resource_variable_ops_unsafely = true;
registration.cluster_stack_ops = false;
registration.cluster_tensor_array_ops = true;
registration.cluster_stateful_rng_ops = true;
registration.cluster_control_trigger = true;
registration.elide_assert_and_checknumerics = true;
registration.cluster_variant_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration);
static XlaDeviceOpRegistrations* registrations =

View File

@ -519,7 +519,7 @@ Status XlaDevice::RefreshStatus() {
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
const char* jit_device) {
// Any op assigned to the device that isn't rewritten by the graph rewriter
// gets executed by a n XlaCompileOnDemandOp, which compiles it and executes
// gets executed by an XlaCompileOnDemandOp, which compiles it and executes
// it just-in-time.
OpKernel* (*factory)(OpKernelConstruction*) =
[](OpKernelConstruction* context) -> OpKernel* {

View File

@ -247,6 +247,9 @@ class XlaAssignVariableOp : public OpKernel {
data::MakeIteratorOp); \
REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \
data::AnonymousIteratorHandleOp); \
REGISTER_KERNEL_BUILDER( \
Name("AnonymousIteratorV2").Device(DEVICE).HostMemory("deleter"), \
data::AnonymousIteratorHandleOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \
data::IteratorGetNextOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \

View File

@ -1,349 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
#include <atomic>
#include <deque>
#include <unordered_map>
#include <unordered_set>
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
namespace tensorflow {
// Is 'node' an operator that consumes only the shape of its input, not the
// data itself?
static bool IsShapeConsumerOp(const Node& node) {
return node.type_string() == "Shape" || node.type_string() == "ShapeN" ||
node.type_string() == "Rank" || node.type_string() == "Size";
}
// Returns true if the op can be decomposed into XLA ops for which
// there are fusible elemental implementations.
static bool IsXlaFusible(const NodeDef& node) {
static const std::unordered_set<std::string>* elementwise_ops =
new std::unordered_set<std::string>(
{// tf2xla/kernels/aggregate_ops.cc
"AddN",
// tf2xla/kernels/binary_ops.cc
"Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv",
"FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift",
"LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
"ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference",
"TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater",
"GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad",
"SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual",
// tf2xla/kernels/unary_ops.cc
"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
"Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp",
"Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal",
"Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round",
"Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
"Square", "Tan", "Tanh", "Real", "Imag",
// tf2xla/kernels/bcast_ops.cc
"BroadcastArgs", "BroadcastGradientArgs",
// tf2xla/kernels/bias_ops.cc
"BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/,
// tf2xla/kernels/cast_op.cc
"Cast",
// tf2xla/kernels/concat_op.cc
"Concat", "ConcatV2", "ConcatOffset",
// tf2xla/kernels/const_op.cc
"Const",
// tf2xla/kernels/elu_op.cc
"Elu", "EluGrad", "Selu", "SeluGrad",
// tf2xla/kernels/fill_op.cc
"Fill",
// tf2xla/kernels/identity_op.cc
"Identity", "IdentityN", "PreventGradient",
"StopGradient", /*"Snapshot",*/
// tf2xla/kernels/index_ops.cc
"ArgMax", "ArgMin",
// tf2xla/kernels/mirror_pad_op.cc
"MirrorPad",
// tf2xla/kernels/one_hot_op.cc
"OneHot",
// tf2xla/kernels/pack_op.cc
"Pack",
// tf2xla/kernels/pad_op.cc
"Pad", "PadV2",
// tf2xla/kernels/relu_op.cc
"Relu", "Relu6", "ReluGrad", "Relu6Grad",
// tf2xla/kernels/reshape_op.cc
"Reshape",
// tf2xla/kernels/reverse_op.cc
"Reverse", "ReverseV2",
// tf2xla/kernels/reverse_sequence_op.cc
"ReverseSequence",
// tf2xla/kernels/shape_op.cc
"Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze",
"ZerosLike", "OnesLike",
// tf2xla/kernels/slice_op.cc
"Slice",
// tf2xla/kernels/split_op.cc
"Split", "SplitV",
// tf2xla/kernels/strided_slice_op.cc
"StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
// tf2xla/kernels/tile_ops.cc
"Tile",
// tf2xla/kernels/transpose_op.cc
"Transpose", "InvertPermutation",
// tf2xla/kernels/unpack_op.cc
"Unpack"});
return elementwise_ops->count(node.op()) > 0;
}
Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
const grappler::GrapplerItem& item,
GraphDef* output) {
VLOG(2) << "Here at fusion optimizer";
// TODO(hpucha): Implement encapsulation and replacing with XlaLaunch op.
// Once that happens, the expected interaction between this optimizer and when
// the global_jit_level is set is as follows: Fusion optimizer will replace
// appropriate fusion clusters with XlaLaunch nodes. The remaining graph can
// be further compiled where possible via mark_for_compilation_pass. Note that
// this might lead to inefficient clustering, and it is best to use either the
// fusion optimizer or the global_jit flag, and not combine the two.
// Create a Graph out of GraphDef. This is required currently because the
// helpers around clustering, encapsulation etc work on graphs.
FunctionLibraryDefinition function_library(OpRegistry::Global(),
item.graph.library());
Graph graph(function_library);
ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
shape_refiner.set_require_shape_inference_fns(false);
shape_refiner.set_disable_constant_propagation(true);
ImportGraphDefOptions options;
// Graph optimization happens at the late stage of graph execution, when
// colocation constraints are already validated previously and the device
// placement of nodes has also completed, so there is no need to validate
// colocation constraints again.
options.validate_colocation_constraints = false;
options.validate_shape = false;
TF_RETURN_IF_ERROR(
ImportGraphDef(options, item.graph, &graph, &shape_refiner));
std::unique_ptr<DeadnessAnalysis> deadness;
TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness));
// Collect nodes that can be fused via XLA, while ignoring those that
// explicitly ask for XLA: (*) nodes that are marked to be compiled
// explicitly. (*) nodes assigned to XLA device.
OrderedNodeSet compilation_candidates;
for (Node* node : graph.op_nodes()) {
// If there is a _XlaCompile annotation, ignore the node if it is
// true. Nodes are marked with this attr via experimental_jit_scope, and
// will be handled by the mark_for_compilation pass.
bool compile = false;
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
if (status.ok() && compile) {
continue;
}
// If there is already a _XlaCluster annotation, ignore the node. Nodes are
// marked with this attr to indicate they are already part of a cluster and
// hence ignored.
status = GetNodeAttr(node->attrs(), kXlaClusterAttr, &compile);
if (status.ok()) {
continue;
}
// If there is an explicit XLA device placement, ignore the node.
DeviceType device_type("");
TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type));
if (device_type.type_string().find("XLA") != string::npos) continue;
// Assume all fusible ops are registered.
// TODO(hpucha): Check for registration if possible.
if (!IsXlaFusible(node->def())) {
continue;
}
// XLA does not offer guaranteed aliasing between the input and output of
// the XLA cluster so it can't implement the forward-tensor-ref semantic.
// Leave such nodes out of XLA clusters.
if (HasForwardedRefInput(*node)) {
continue;
}
// If inputs to `node` can have conflicting deadness (i.e. some are alive
// and some are dead) then don't compile it. XLA cannot represent the
// deadness semantics of these nodes correctly and auto-clustering these
// nodes can cause deadness to propagate to nodes that should be live.
if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
continue;
}
compilation_candidates.insert(node);
}
if (compilation_candidates.empty()) {
VLOG(2) << "No compilable candidates";
*output = item.graph;
return Status::OK();
}
GraphCycles cycles;
TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
CreateCycleDetectionGraph(&graph, &cycles));
if (!cycle_detection_graph_ok) {
return Status::OK();
}
TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
&graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles));
// TODO(hpucha): Make clustering more robust. There are two known issues that
// we need to mitigate: (a) Non-resource variables can cause deadlocks
// when clustering changes order of execution. See b/77263461 for a specific
// example. (b) Queue operations can also cause deadlocks. See b/77261498 for
// example.
struct Cluster {
// Identifies the node that represents this cluster in the cycle detection
// graph.
int representative = -1;
};
// Each compilation candidate belongs to a cluster. The cluster's
// representative names the node in the 'cycles' graph that represents the
// cluster.
std::vector<UnionFind<Cluster>> clusters(graph.num_node_ids());
std::deque<UnionFind<Cluster>*> worklist;
for (Node* node : compilation_candidates) {
Cluster& cluster = clusters[node->id()].Get();
cluster.representative = node->id();
worklist.push_back(&clusters[node->id()]);
}
// Repeatedly contract edges between clusters that are on the same device,
// provided the contraction would not create a cycle. This is a simplified
// version of the clustering in mark_for_compilation_pass that also deals with
// nodes that are explicitly tagged to be compiled/clustered.
while (!worklist.empty()) {
int from = worklist.front()->Get().representative;
worklist.pop_front();
Node* node_from = graph.FindNodeId(from);
if (node_from->IsControlFlow()) {
// Control flow nodes aren't compilation candidates and should never
// appear.
return errors::Internal(
"Found control flow node in clustering worklist: ",
node_from->type_string());
}
for (int to : cycles.Successors(from)) {
if (to >= graph.num_node_ids()) {
// Node is a "frame" node that is present only in the cycle detection
// graph. No clustering is possible.
continue;
}
Node* node_to = graph.FindNodeId(to);
if (compilation_candidates.find(node_to) ==
compilation_candidates.cend()) {
continue;
}
// Do not cluster across devices.
if (node_from->def().device() != node_to->def().device()) {
VLOG(2) << "Devices " << node_from->def().device() << " "
<< node_to->def().device();
VLOG(2) << "Device names " << node_from->assigned_device_name() << " "
<< node_to->assigned_device_name();
continue;
}
// Ops that consume shapes cannot be the root of a cluster. This is an
// optimization.
if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
continue;
}
// If contracting the edge would create a cycle, bail out.
// However, just because we can't merge the clusters now does not mean
// we won't be able to merge them in the future.
// e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge
// 1->3. But if we first contract 1->2 then we can later contract 1->3.
if (!cycles.ContractEdge(from, to)) continue;
// Merge the clusters. ContractEdge uses 'from' as the number of the
// merged node, so make sure 'from' is the chosen representative.
clusters[from].Merge(&clusters[to]);
worklist.push_back(&clusters[from]);
break;
}
}
// Count the number of non-trivial elements in each cluster.
std::vector<int> effective_cluster_sizes(graph.num_node_ids());
for (const Node* n : compilation_candidates) {
int cluster = clusters[n->id()].Get().representative;
// Identity nodes will be removed if the node gets marked for compilation.
// Therefore we don't want to count them towards the effective cluster size.
if (n->def().op() != "Identity") {
effective_cluster_sizes[cluster]++;
}
}
const int min_cluster_size = 2;
int num_clusters = 0;
for (auto size : effective_cluster_sizes) {
if (size >= min_cluster_size) {
VLOG(3) << "Cluster " << num_clusters << " " << size;
num_clusters++;
}
}
// Names for each cluster.
std::unordered_map<int, string> cluster_names;
// Sequence number generator to ensure clusters have unique names.
static std::atomic<int64> cluster_sequence_num;
for (Node* n : compilation_candidates) {
int cluster = clusters[n->id()].Get().representative;
// Compile if this is a cluster of >= min_cluster_size compilable operators.
if (effective_cluster_sizes[cluster] >= min_cluster_size) {
string& name = cluster_names[cluster];
if (name.empty()) {
name = absl::StrCat("cluster_", cluster_sequence_num++);
}
n->AddAttr(kXlaClusterAttr, name);
VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
}
}
graph.ToGraphDef(output);
return Status::OK();
}
REGISTER_GRAPH_OPTIMIZER_AS(XlaFusionOptimizer, "xla-fusion");
} // namespace tensorflow

View File

@ -1,49 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
#define TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
namespace tensorflow {
// Optimizes graphs by fusing ops where possible, resulting in more efficient
// execution.
class XlaFusionOptimizer : public grappler::CustomGraphOptimizer {
public:
XlaFusionOptimizer() {}
~XlaFusionOptimizer() override {}
Status Init(
const RewriterConfig_CustomGraphOptimizer* config = nullptr) override {
return Status::OK();
}
string name() const override { return "xla-fusion"; };
Status Optimize(grappler::Cluster* cluster,
const grappler::GrapplerItem& item,
GraphDef* output) override;
void Feedback(grappler::Cluster* cluster, const grappler::GrapplerItem& item,
const GraphDef& optimize_output, double result) override {
// Nothing to do for XlaFusionOptimizer.
}
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_

View File

@ -1,208 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
class XlaFusionOptimizerTest : public grappler::GrapplerTest {
protected:
std::unordered_map<string, string> GetClusters(const GraphDef& graph) {
std::unordered_map<string, string> ids;
for (const NodeDef& node : graph.node()) {
string cluster;
if (GetNodeAttr(AttrSlice(node), kXlaClusterAttr, &cluster).ok()) {
CHECK(!cluster.empty());
ids[node.name()] = cluster;
}
}
return ids;
}
};
TEST_F(XlaFusionOptimizerTest, Chains) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a =
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
Node* d =
ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_EQ(4, clusters.size());
EXPECT_EQ(clusters["B"], clusters["C"]);
EXPECT_EQ(clusters["E"], clusters["F"]);
EXPECT_NE(clusters["B"], clusters["E"]);
EXPECT_TRUE(clusters.find("A") == clusters.cend());
EXPECT_TRUE(clusters.find("D") == clusters.cend());
}
TEST_F(XlaFusionOptimizerTest, FusibleOps) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp(
"Placeholder",
builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT));
Node* b = ops::SourceOp(
"Placeholder",
builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT));
Node* c = ops::BinaryOp("Add", a, b, builder.opts().WithName("C"));
ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D"));
ops::UnaryOp("Abs", c, builder.opts().WithName("E"));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_EQ(2, clusters.size());
EXPECT_EQ(clusters["C"], clusters["E"]);
EXPECT_TRUE(clusters.find("D") == clusters.cend());
}
TEST_F(XlaFusionOptimizerTest, IgnoreExplicitXLAAttrs) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp(
"Placeholder",
builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT));
Node* b = ops::SourceOp(
"Placeholder",
builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT));
Node* c = ops::BinaryOp(
"Add", a, b,
builder.opts().WithName("C").WithDevice("/device:XLA_CPU"));
ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D"));
Node* e = ops::UnaryOp("Abs", c, builder.opts().WithName("E"));
ops::UnaryOp("Cos", e,
builder.opts().WithName("F").WithAttr(kXlaCompileAttr, true));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_TRUE(clusters.empty());
}
TEST_F(XlaFusionOptimizerTest, UncompilableCycles) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor()));
Node* b =
ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
ops::BinaryOp("Mul", a, b, builder.opts().WithName("C"));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_TRUE(clusters.empty());
}
TEST_F(XlaFusionOptimizerTest, CompilableCycles) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor()));
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
ops::BinaryOp("Mul", a, b, builder.opts().WithName("C"));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_EQ(3, clusters.size());
EXPECT_EQ(clusters["A"], clusters["B"]);
EXPECT_EQ(clusters["A"], clusters["C"]);
}
TEST_F(XlaFusionOptimizerTest, ResourcesClusteringDisallowed) {
Scope root = Scope::NewRootScope().ExitOnError();
Output var_handle =
ops::VarHandleOp(root.WithOpName("Var"), DT_FLOAT, TensorShape({}));
Output to_assign = ops::Const(root.WithOpName("Const"), 10.0f);
Output begin = ops::Const(root.WithOpName("begin"), 0);
Output end = ops::Const(root.WithOpName("end"), 1);
Output strides = ops::Const(root.WithOpName("strides"), 1);
ops::ResourceStridedSliceAssign assign_1(
root.WithOpName("assign_1"), var_handle, begin, end, strides, to_assign);
ops::ResourceStridedSliceAssign assign_2(
root.WithOpName("assign_2"), var_handle, begin, end, strides, to_assign);
root.graph()->AddControlEdge(assign_1.operation.node(),
assign_2.operation.node());
grappler::GrapplerItem item;
root.graph()->ToGraphDef(&item.graph);
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_NE(clusters["assign_1"], clusters["assign_2"]);
}
} // namespace
} // namespace tensorflow

View File

@ -55,10 +55,32 @@ static xla::StatusOr<absl::optional<std::set<int>>> ParseVisibleDeviceList(
class XlaGpuDeviceFactory : public DeviceFactory {
public:
Status ListPhysicalDevices(std::vector<string>* devices) override;
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override;
};
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
if (!platform.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
return Status::OK();
}
int device_count = platform.ValueOrDie()->VisibleDeviceCount();
if (device_count <= 0) {
return Status::OK();
}
for (int i = 0; i < device_count; ++i) {
devices->push_back(
absl::StrCat("/physical_device:", DEVICE_XLA_GPU, ":", i));
}
return Status::OK();
}
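For context (not part of the change): devices registered through `ListPhysicalDevices` surface in Python with the `/physical_device:<TYPE>:<index>` names built above. A hedged sketch, assuming a build where `tf.config.experimental.list_physical_devices` is available:

```python
import tensorflow as tf

# Print any XLA devices the factories above registered; the output depends on
# the build and on whether a CUDA platform is actually present.
for dev in tf.config.experimental.list_physical_devices():
    if dev.device_type.startswith("XLA_"):
        print(dev.name)  # e.g. /physical_device:XLA_CPU:0, /physical_device:XLA_GPU:0
```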
Status XlaGpuDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
@ -66,7 +88,13 @@ Status XlaGpuDeviceFactory::CreateDevices(
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.autoclustering_policy =
XlaOpRegistry::AutoclusteringPolicy::kAlways;
registration.compile_all_resource_ops = true;
registration.cluster_resource_variable_ops_unsafely = true;
registration.cluster_stack_ops = false;
registration.cluster_tensor_array_ops = true;
registration.cluster_stateful_rng_ops = true;
registration.cluster_control_trigger = true;
registration.elide_assert_and_checknumerics = true;
registration.cluster_variant_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
static XlaDeviceOpRegistrations* registrations =

View File

@ -32,10 +32,19 @@ constexpr std::array<DataType, 10> kExecAllTypes = {
class XlaInterpreterDeviceFactory : public DeviceFactory {
public:
Status ListPhysicalDevices(std::vector<string>* devices) override;
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override;
};
Status XlaInterpreterDeviceFactory::ListPhysicalDevices(
std::vector<string>* devices) {
devices->push_back(
absl::StrCat("/physical_device:", DEVICE_XLA_INTERPRETER, ":0"));
return Status::OK();
}
Status XlaInterpreterDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
@ -47,7 +56,13 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
registration.autoclustering_policy =
XlaOpRegistry::AutoclusteringPolicy::kAlways;
registration.compile_all_resource_ops = true;
registration.cluster_resource_variable_ops_unsafely = true;
registration.cluster_stack_ops = false;
registration.cluster_tensor_array_ops = true;
registration.cluster_stateful_rng_ops = true;
registration.cluster_control_trigger = true;
registration.elide_assert_and_checknumerics = true;
registration.cluster_variant_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
registration);

View File

@ -347,9 +347,11 @@ Status XlaComputationLaunchContext::PopulateOutputs(
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
<< DataTypeString(type);
if (type == DT_RESOURCE) {
TF_RET_CHECK(kernel->outputs[i].input_index >= 0)
<< "Invalid input for outputs " << i;
ctx->set_output(i, ctx->input(kernel->outputs[i].input_index));
int input_index =
kernel->outputs[i].input_index - missing_ctx_input_prefix;
TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
<< "Invalid input for outputs " << i << ": " << input_index;
ctx->set_output(i, ctx->input(input_index));
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
if (allocate_xla_tensors_) {

View File

@ -875,6 +875,7 @@ tf_xla_py_test(
name = "stack_ops_test",
size = "small",
srcs = ["stack_ops_test.py"],
tags = ["config-cuda-only"],
use_xla_device = False,
deps = [
":xla_test",
@ -921,6 +922,7 @@ tf_xla_py_test(
srcs = ["tensor_array_ops_test.py"],
# TensorArray ops are not implemented in the on-demand compilation model yet.
disabled_backends = ["cpu_ondemand"],
tags = ["config-cuda-only"],
use_xla_device = False,
deps = [
":xla_test",

View File

@ -23,6 +23,7 @@ import itertools
import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.compat import compat
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
@ -1041,6 +1042,62 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([2], dtype=np.int64),
expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype))
def testBatchMatMulBroadcast(self):
"""Tests broadcasting behavior of BatchMatMul."""
with compat.forward_compatibility_horizon(2019, 4, 19):
# [2, 3] @ [1, 3, 4] -> [1, 2, 4]
self._testBinary(
math_ops.matmul,
np.array([[10, 20, 30], [11, 21, 31]], dtype=np.float32),
np.array([[[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]]],
dtype=np.float32),
expected=np.array([[[140, 280, 420, 560], [146, 292, 438, 584]]],
dtype=np.float32))
# [1, 2, 3] @ [3, 4] -> [1, 2, 4]
self._testBinary(
math_ops.matmul,
np.array([[[10, 20, 30], [11, 21, 31]]], dtype=np.float32),
np.array([[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]],
dtype=np.float32),
expected=np.array([[[140, 280, 420, 560], [146, 292, 438, 584]]],
dtype=np.float32))
# [2, 1, 3] @ [3, 1] -> [2, 1, 1]
self._testBinary(
math_ops.matmul,
np.array([[[10, 20, 30]], [[11, 21, 31]]], dtype=np.float32),
np.array([[1], [2], [3]], dtype=np.float32),
expected=np.array([[[140]], [[146]]], dtype=np.float32))
# [2, 1, 3] @ [1, 3] -> [2, 1, 1] (adjoint_b)
self._testBinary(
lambda x, y: math_ops.matmul(x, y, adjoint_b=True),
np.array([[[10, 20, 30]], [[11, 21, 31]]], dtype=np.float32),
np.array([[1, 2, 3]], dtype=np.float32),
expected=np.array([[[140]], [[146]]], dtype=np.float32))
# [2, 3, 1] @ [3, 1] -> [2, 1, 1] (adjoint_a)
self._testBinary(
lambda x, y: math_ops.matmul(x, y, adjoint_a=True),
np.array([[[10], [20], [30]], [[11], [21], [31]]], dtype=np.float32),
np.array([[1], [2], [3]], dtype=np.float32),
expected=np.array([[[140]], [[146]]], dtype=np.float32))
# [2, 3, 1] @ [1, 3] -> [2, 1, 1] (adjoint_a and adjoint_b)
self._testBinary(
lambda x, y: math_ops.matmul(x, y, adjoint_a=True, adjoint_b=True),
np.array([[[10], [20], [30]], [[11], [21], [31]]], dtype=np.float32),
np.array([[1, 2, 3]], dtype=np.float32),
expected=np.array([[[140]], [[146]]], dtype=np.float32))
# [5, 1, 2, 3] @ [1, 7, 3, 4] -> [5, 7, 2, 4]
self._testBinary(
math_ops.matmul,
np.ones([5, 1, 2, 3], dtype=np.float32),
np.ones([1, 7, 3, 4], dtype=np.float32),
expected=np.full([5, 7, 2, 4], 3, dtype=np.float32))
# [4, 5, 1, 2, 3] @ [1, 1, 3, 5] -> [4, 5, 1, 2, 5]
self._testBinary(
math_ops.matmul,
np.full([4, 5, 1, 2, 3], 2., dtype=np.float32),
np.full([1, 1, 3, 5], 3., dtype=np.float32),
expected=np.full([4, 5, 1, 2, 5], 18., dtype=np.float32))
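As a quick aside (outside the diff), the expected shapes in `testBatchMatMulBroadcast` follow NumPy-style batch broadcasting, so they can be sanity-checked directly:

```python
import numpy as np

# Batch dimensions broadcast; the trailing two dimensions are contracted.
print(np.matmul(np.ones([2, 3]), np.ones([1, 3, 4])).shape)              # (1, 2, 4)
print(np.matmul(np.ones([5, 1, 2, 3]), np.ones([1, 7, 3, 4])).shape)     # (5, 7, 2, 4)
print(np.matmul(np.ones([4, 5, 1, 2, 3]), np.ones([1, 1, 3, 5])).shape)  # (4, 5, 1, 2, 5)
```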
def testPad(self):
for dtype, pad_type in itertools.product(
self.numeric_types, [np.int32, np.int64]):

View File

@ -341,7 +341,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
var = f()
self.assertEqual(1.0, var.numpy())
def DISALBED_testResourceVariableNoInlineReadWrite(self):
def testResourceVariableNoInlineReadWrite(self):
with self.test_scope():
v = resource_variable_ops.ResourceVariable(1.0)
w = resource_variable_ops.ResourceVariable(0.0)
@ -359,8 +359,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
self.assertEqual(145.0, f().numpy())
self.assertEqual(15.0, w.read_value().numpy())
# TODO(b/36139787)
def DISABLED_testResourceVariableNoInlineReadOnly(self):
def testResourceVariableNoInlineReadOnly(self):
with self.test_scope():
v = resource_variable_ops.ResourceVariable(10.0)
@ -374,8 +373,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
self.assertEqual(50.0, f().numpy())
# TODO(b/36139787)
def DISABLED_testResourceVariableNoInlineWriteOnly(self):
def testResourceVariableNoInlineWriteOnly(self):
with self.test_scope():
v = resource_variable_ops.ResourceVariable(0.0)

View File

@ -95,7 +95,12 @@ class JitLaunchTest(test.TestCase):
# If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun
# node actually ran. However, it is sometimes possible for XlaCompile/XlaRun
# ops to be constant-folded away, so the check is optional.
def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
def _compare(self,
fn,
args,
require_kernel_launch=True,
name=None,
noinline=None):
with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
placeholders = []
feeds = {}
@ -105,7 +110,8 @@ class JitLaunchTest(test.TestCase):
placeholders.append(placeholder)
feeds[placeholder] = arg
compiled_op = CompiledKernel(fn, *placeholders, noinline=noinline)
compiled_op = CompiledKernel(
fn, *placeholders, name=name, noinline=noinline)
direct_op = fn(*placeholders)
run_metadata = config_pb2.RunMetadata()
@ -155,17 +161,16 @@ class JitLaunchTest(test.TestCase):
# to symbolically execute Bar correctly regardless of whether Bar is inlined
# or not.
# TODO(b/36139787): Re-enable this test when noinline works again.
# Tests compiled=True and noinline=True.
# self._compare(
# AddOnceReturnTwice, [np.array(
# [[[0.5, -1.0]]], dtype=np.float32)],
# noinline=True)
self._compare(
AddOnceReturnTwice, [np.array([[[0.5, -1.0]]], dtype=np.float32)],
name="AddOnceReturnTwice_inline",
noinline=True)
# Tests compiled=True and noinline=False.
self._compare(
AddOnceReturnTwice, [np.array(
[[[0.5, -1.0]]], dtype=np.float32)],
AddOnceReturnTwice, [np.array([[[0.5, -1.0]]], dtype=np.float32)],
name="AddOnceReturnTwice_noinline",
noinline=False)
def testOneConstOutput(self):

View File

@ -454,7 +454,7 @@ class PoolGradTest(xla_test.XLATestCase):
"""Verifies the output values of the pooling function.
Args:
pool_func: Pooling function to be called, e.g., tf.nn.max_pool
pool_func: Pooling function to be called, e.g., tf.nn.max_pool2d
pool_grad_func: Corresponding pooling gradient function.
input_sizes: Input tensor dimensions.
ksize: The kernel size dimensions

View File

@ -387,11 +387,18 @@ class TensorArrayTest(xla_test.XLATestCase):
def fn():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
return ta.write(-1, np.int32(7)).flow
return ta.write(-1, constant_op.constant(7)).flow
# Test writing the wrong datatype.
with self.assertRaisesOpError(
"TensorArray dtype is float but op has dtype int32"):
# TODO(b/129870929): Remove InvalidArgumentError/second regexp after all
# callers provide proper init dtype.
with self.assertRaisesRegexp(
(ValueError, errors.InvalidArgumentError),
r"("
r"conversion requested dtype float32 for Tensor with dtype int32"
r"|"
r"TensorArray dtype is float but op has dtype int32"
r")"):
xla.compile(fn)[0].eval()
@test_util.disable_control_flow_v2("b/124334096 verify dtype")

View File

@ -125,7 +125,7 @@ class ListOpsTest(xla_test.XLATestCase):
self.assertAllEqual(e0, 2.0)
l, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
self.assertAllEqual(e1, 1.0)
self.assertAllEqual(list_ops.tensor_list_length(l), 0)
self.assertAllEqual(list_ops.tensor_list_length(l), 2)
def testGetSet(self):
with self.cached_session(), self.test_scope():
@ -211,6 +211,18 @@ class ListOpsTest(xla_test.XLATestCase):
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
self.assertAllEqual(t, [0., 0., 0.])
def testZerosLikeForTensorList(self):
with self.cached_session(), self.test_scope():
l = list_ops.empty_tensor_list(
element_dtype=dtypes.float32,
element_shape=[],
max_num_elements=2)
l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
z = array_ops.zeros_like(l)
z = list_ops.tensor_list_stack(z, element_dtype=dtypes.float32)
self.assertAllEqual(z.shape.as_list(), [None])
self.assertAllEqual(z, [0.0, 0.0])
if __name__ == "__main__":
os.environ['TF_XLA_FLAGS'] = ('--tf_xla_min_cluster_size=2 ' +
os.environ.get('TF_XLA_FLAGS', ''))

View File

@ -23,9 +23,13 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
load(
"@local_config_tensorrt//:build_defs.bzl",
"if_tensorrt",
"//tensorflow/core:platform/default/build_config.bzl",
"tf_additional_all_protos",
"tf_proto_library",
)
load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
# Google-internal targets go here (must be at the end).
tf_cuda_cc_test(
name = "tensorrt_test_cc",
@ -74,13 +78,67 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "trt_engine_resource_op_kernels",
srcs = ["kernels/trt_engine_resource_ops.cc"],
copts = tf_copts(),
visibility = ["//visibility:private"],
deps = [
":trt_allocator",
":trt_engine_instance_proto_cc",
":trt_logging",
":trt_plugins",
":trt_resources",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:lib_proto_parsing",
] + if_tensorrt([
"@local_config_tensorrt//:tensorrt",
]) + tf_custom_op_library_additional_deps(),
alwayslink = 1,
)
tf_cuda_cc_test(
name = "trt_engine_resource_ops_test",
size = "small",
srcs = ["kernels/trt_engine_resource_ops_test.cc"],
tags = [
"no_cuda_on_cpu_tap",
"no_windows",
"nomac",
],
deps = [
":trt_engine_instance_proto_cc",
":trt_engine_resource_op_kernels",
":trt_engine_resource_ops_op_lib",
":trt_logging",
":trt_resources",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:ops_testutil",
"//tensorflow/core/kernels:resource_variable_ops",
"@com_google_absl//absl/memory",
],
)
tf_cc_shared_object(
name = "python/ops/libtftrt.so",
copts = tf_copts(is_external = True),
linkopts = ["-lm"],
deps = [
":trt_op_kernels",
":trt_engine_resource_op_kernels",
":trt_op_libs",
":trt_engine_resource_ops_op_lib",
"//tensorflow/core:lib_proto_parsing",
] + if_tensorrt([
"@local_config_tensorrt//:tensorrt",
@ -112,10 +170,40 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "trt_engine_op_test",
size = "small",
srcs = ["kernels/trt_engine_op_test.cc"],
tags = [
"no_cuda_on_cpu_tap",
"no_windows",
"nomac",
],
deps = [
":trt_op_kernels",
":trt_op_libs",
":trt_resources",
"@com_google_googletest//:gtest",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:ops_testutil",
] + if_tensorrt([
"@local_config_cuda//cuda:cuda_headers",
]),
)
tf_gen_op_libs(
op_lib_names = [
"trt_engine_op",
"get_serialized_resource_op",
"trt_engine_resource_ops",
],
)
@ -142,6 +230,7 @@ tf_cuda_library(
tf_gen_op_wrapper_py(
name = "trt_ops",
deps = [
":trt_engine_resource_ops_op_lib",
":trt_op_libs",
],
)
@ -156,7 +245,9 @@ tf_custom_op_py_library(
]),
kernels = [
":trt_op_kernels",
":trt_engine_resource_op_kernels",
":trt_op_libs",
":trt_engine_resource_ops_op_lib",
],
srcs_version = "PY2AND3",
deps = [
@ -173,6 +264,7 @@ tf_cuda_library(
name = "trt_resources",
srcs = [
"utils/trt_int8_calibrator.cc",
"utils/trt_lru_cache.cc",
"utils/trt_resources.cc",
],
hdrs = [
@ -440,6 +532,13 @@ cc_library(
],
)
tf_proto_library(
name = "trt_engine_instance_proto",
srcs = ["utils/trt_engine_instance.proto"],
cc_api_version = 2,
protodeps = tf_additional_all_protos(),
)
cc_library(
name = "py_utils",
srcs = ["utils/py_utils.cc"],

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <map>
#include <memory>
@ -1350,11 +1351,19 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input,
// the dims are unknown or need to be inferred. And we don't do further checks
// but rely on the caller to not make mistakes.
// Otherwise we do a simple check to make sure the total sizes are the same.
if (AreDimsStaticWithDifferentSize(input_dims, dims, input.is_tensor())) {
// If an input is a weight, it is going to become a tensor via
// CreateConstantLayer. So we can treat it as a tensor for
// AreDimsStaticWithDifferentSize(). This really only matters for 0-D tensors.
if (AreDimsStaticWithDifferentSize(input_dims, dims, /*is_tensor=*/true)) {
return errors::InvalidArgument(
"Incompatible shapes: ", DebugString(input_dims), " vs. ",
DebugString(dims));
}
// ConstantLayer requires static shapes (cannot infer -1).
if (input.is_weights() && !HasStaticShape(dims)) {
return errors::InvalidArgument("Shape is not fully defined: ",
DebugString(dims));
}
if (validation_only) {
*tensor = nullptr;
return Status::OK();
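To make the reasoning in the new checks explicit, here is a simplified Python sketch (hypothetical helper, not the converter's API): fully static shapes must have equal total sizes, and a weight input, which is materialized through a constant layer, cannot use a `-1` (inferred) dimension.

```python
import numpy as np

def check_reshape(input_dims, new_dims, input_is_weights):
    # Both shapes fully static but with different element counts -> incompatible.
    if -1 not in input_dims and -1 not in new_dims and \
            np.prod(input_dims) != np.prod(new_dims):
        raise ValueError("Incompatible shapes: %s vs. %s" % (input_dims, new_dims))
    # A constant layer cannot infer -1, so weights need a fully defined shape.
    if input_is_weights and -1 in new_dims:
        raise ValueError("Shape is not fully defined: %s" % (new_dims,))

check_reshape([2, 3, 5], [10, 3], input_is_weights=False)  # ok
check_reshape([2, 3, 5], [-1, 2], input_is_weights=False)  # ok, -1 inferred as 15
try:
    check_reshape([2, 3, 5], [-1, 2], input_is_weights=True)
except ValueError as e:
    print(e)  # Shape is not fully defined: [-1, 2]
```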
@ -1589,18 +1598,6 @@ Status AllowDataTypes(const OpConverterParams& params,
return Status::OK();
}
TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store,
const TRT_ShapedWeights& weights_src) {
TRT_ShapedWeights weights =
store->GetTempWeights(nvinfer1::DataType::kHALF, weights_src.shape_);
const float* src = static_cast<const float*>(weights_src.GetValues());
Eigen::half* dst = static_cast<Eigen::half*>(weights.GetValues());
for (int64_t i = 0; i < weights_src.count(); i++) {
dst[i] = Eigen::half_impl::float_to_half_rtne(src[i]);
}
return weights;
}
// ****************************************************************************
// Constant folding functions for weights.
// TODO(laigd): we should probably use eigen directly.
@ -1614,7 +1611,7 @@ struct LambdaFactory {
switch (op) {
case OP_CATEGORY::RSQRT: {
VLOG(2) << "RSQRT GETS DONE";
return [](T t) -> T { return 1.0 / sqrt(t); };
return [](T t) -> T { return 1.0 / std::sqrt(t); };
}
case OP_CATEGORY::NEG:
return [](T t) -> T { return -t; };
@ -1633,7 +1630,7 @@ std::function<Eigen::half(Eigen::half)> LambdaFactory::unary<Eigen::half>() {
case OP_CATEGORY::RSQRT: {
VLOG(2) << "RSQRT GETS DONE";
return [](Eigen::half t) {
return Eigen::half(1.0 / sqrt(static_cast<float>(t)));
return Eigen::half(1.0 / std::sqrt(static_cast<float>(t)));
};
}
case OP_CATEGORY::NEG:
@ -1772,10 +1769,6 @@ Status BinaryTensorOpWeight(OpConverterParams* params,
params->converter->TransposeTensor(tensor, permutation, &tensor));
}
if (params->converter->precision_mode() == TrtPrecisionMode::FP16) {
weights = ConvertFP32ToFP16(params->weight_store, weights);
}
// Prepare weights
TRT_ShapedWeights shift_weights(weights.TrtDType());
TRT_ShapedWeights scale_weights(weights.TrtDType());
@ -1937,9 +1930,6 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
// num_groups will be 1.
const int num_groups = (group == 0) ? tensor_dim.d[0] : group;
if (params->converter->precision_mode() == TrtPrecisionMode::FP16) {
weights_rsck = ConvertFP32ToFP16(params->weight_store, weights_rsck);
}
// For conv, TF weights are RSCK, and TRT expects KCRS.
// For backprop, TF weights are RSKC, and TRT expects CKRS.
// Therefore, this reorder will work for both cases.
@ -3038,9 +3028,6 @@ Status ConvertBiasAdd(OpConverterParams* params) {
}
TRT_ShapedWeights weights = inputs.at(1).weights();
if (params->converter->precision_mode() == TrtPrecisionMode::FP16) {
weights = ConvertFP32ToFP16(params->weight_store, weights);
}
nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL;
if (weights.shape_.d[0] == 1) {
mode = nvinfer1::ScaleMode::kUNIFORM;
@ -4237,6 +4224,95 @@ Status ConvertTopK(OpConverterParams* params) {
return Status::OK();
}
Status ConvertDepthSpaceShuffle(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
TF_RETURN_IF_ERROR(AllowDataTypes(
*params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
TFAttrs attrs(node_def);
const int block_size = attrs.get<int64>("block_size");
if (block_size < 2) {
return errors::InvalidArgument("Block size must be 2 or greater, at ",
node_def.name());
}
const string data_format = attrs.get<string>("data_format");
if (data_format != "NCHW" && data_format != "NHWC") {
return errors::Unimplemented("Data format ", data_format,
" is not supported, at ", node_def.name());
}
nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
if (dims.nbDims != 3) {
return errors::InvalidArgument("The input to ", node_def.op(),
" must be rank 4, at ", node_def.name());
}
const int num_channels = data_format == "NCHW" ? dims.d[0] : dims.d[2];
const int h = data_format == "NCHW" ? dims.d[1] : dims.d[0];
const int w = data_format == "NCHW" ? dims.d[2] : dims.d[1];
// Get shuffle parameters.
nvinfer1::Dims first_shuffle_shape;
nvinfer1::Permutation transpose_perm;
nvinfer1::Dims second_shuffle_shape;
if (node_def.op() == "DepthToSpace") {
if (num_channels % (block_size * block_size) != 0) {
return errors::InvalidArgument(
"Number of channels must be divisible by block_size*block_size, at ",
node_def.name());
}
// First Reshape [C, H, W] - > [r, r, C/(r*r), H, W]
first_shuffle_shape = {
/*nbDims=*/5,
/*d=*/{block_size, block_size, num_channels / (block_size * block_size),
h, w}};
// Transpose [r, r, C/(r*r), H, W] -> [C/(r*r), H, r, W, r]
transpose_perm = {2, 3, 0, 4, 1};
// Second Reshape [C/(r*r), H, r, W, r] -> [C/(r*r), H * r, W * r]
second_shuffle_shape =
nvinfer1::DimsCHW(num_channels / (block_size * block_size),
h * block_size, w * block_size);
} else if (node_def.op() == "SpaceToDepth") {
if (h % block_size != 0 || w % block_size != 0) {
return errors::InvalidArgument(
"Width and height must be divisible by block_size, at ",
node_def.name());
}
// First Reshape [C, H, W] -> [C, H/r, r, W/r, r]
first_shuffle_shape = {/*nbDims=*/5,
/*d=*/{num_channels, h / block_size, block_size,
w / block_size, block_size}};
// Transpose [C, H/r, r, W/r, r] -> [r, r, C, H/r, W/r]
transpose_perm = {2, 4, 0, 1, 3};
// Second Reshape [r, r, C, H/r, W/r] -> [C*r*r, H/r, W/r]
second_shuffle_shape = nvinfer1::DimsCHW(
num_channels * block_size * block_size, h / block_size, w / block_size);
}
if (params->validation_only) return Status::OK();
nvinfer1::IShuffleLayer* first_shuffle =
params->converter->network()->addShuffle(*inputs.at(0).tensor());
TFTRT_RETURN_ERROR_IF_NULLPTR(first_shuffle, node_def.name());
if (data_format == "NHWC") {
first_shuffle->setFirstTranspose({2, 0, 1});
}
first_shuffle->setReshapeDimensions(first_shuffle_shape);
first_shuffle->setSecondTranspose(transpose_perm);
nvinfer1::IShuffleLayer* second_shuffle =
params->converter->network()->addShuffle(*first_shuffle->getOutput(0));
TFTRT_RETURN_ERROR_IF_NULLPTR(second_shuffle, node_def.name());
second_shuffle->setReshapeDimensions(second_shuffle_shape);
if (data_format == "NHWC") {
second_shuffle->setSecondTranspose({1, 2, 0});
}
params->converter->MarkQuantizationRangesAsInferrable(
inputs.at(0).tensor(), first_shuffle->getOutput(0));
params->converter->MarkQuantizationRangesAsInferrable(
first_shuffle->getOutput(0), second_shuffle->getOutput(0));
params->outputs->push_back(TRT_TensorOrWeights(second_shuffle->getOutput(0)));
return Status::OK();
}
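For readers following the shuffle parameters above, a NumPy sketch (illustrative only) of the reshape-transpose-reshape decomposition for NCHW `DepthToSpace`; it reproduces the first expected output in the new unit tests further down:

```python
import numpy as np

def depth_to_space_nchw(x, r):
    """Single-example [C, H, W] version of the two-shuffle decomposition above."""
    c, h, w = x.shape
    x = x.reshape(r, r, c // (r * r), h, w)       # first reshape
    x = x.transpose(2, 3, 0, 4, 1)                # -> [C/(r*r), H, r, W, r]
    return x.reshape(c // (r * r), h * r, w * r)  # second reshape

print(depth_to_space_nchw(np.arange(16).reshape(4, 2, 2), 2).ravel().tolist())
# [0, 4, 1, 5, 8, 12, 9, 13, 2, 6, 3, 7, 10, 14, 11, 15]
```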
#if IS_TRT_VERSION_GE(5, 1, 0, 0)
Status ConvertCombinedNMS(OpConverterParams* params) {
TF_RETURN_IF_ERROR(
@ -4416,6 +4492,7 @@ static void RegisterValidatableOpConverters(
(*registration)["Const"] = ConvertConst;
(*registration)["Conv2D"] = ConvertConv2D;
(*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput;
(*registration)["DepthToSpace"] = ConvertDepthSpaceShuffle;
(*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
(*registration)["ExpandDims"] = ConvertExpandDims;
(*registration)["GatherV2"] = ConvertGather;
@ -4430,6 +4507,7 @@ static void RegisterValidatableOpConverters(
(*registration)["Slice"] = ConvertSlice;
(*registration)["Snapshot"] = ConvertIdentity; // Snapshot should be removed
(*registration)["Softmax"] = ConvertSoftmax;
(*registration)["SpaceToDepth"] = ConvertDepthSpaceShuffle;
(*registration)["Split"] = ConvertSplit;
(*registration)["Square"] = ConvertSquare;
(*registration)["Squeeze"] = ConvertSqueeze;

View File

@ -212,6 +212,19 @@ std::vector<CType> InitTestVector(int size, CType start_value = CType(0)) {
return res;
}
template <typename InCType, typename OutCType>
struct StaticCaster {
OutCType operator()(InCType in) const { return static_cast<OutCType>(in); }
};
template <typename InCType, typename OutCType>
std::vector<OutCType> CastTestVector(const std::vector<InCType>& vals) {
std::vector<OutCType> res(vals.size());
std::transform(vals.begin(), vals.end(), res.begin(),
StaticCaster<InCType, OutCType>());
return res;
}
// Fake ITensor implementation for testing purposes.
class FakeITensor : public nvinfer1::ITensor {
public:
@ -721,19 +734,25 @@ TEST_F(ConverterTest, TransposeTensor) {
ExpectTrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions());
}
void TestPrepareTensorForShape_Tensor(
const std::vector<int>& tensor_dims, const std::vector<int>& reshape_dims,
const std::vector<int>& expected_tensor_dims, Converter* converter,
void TestPrepareTensorForShape(
const std::vector<int>& input_dims, const std::vector<int>& reshape_dims,
const std::vector<int>& expected_tensor_dims, bool input_is_tensor,
Converter* converter, TrtWeightStore* weight_store,
error::Code expected_code = error::OK,
const char* expected_error_msg_substr = nullptr) {
nvinfer1::ITensor* input_tensor = converter->network()->addInput(
"", nvinfer1::DataType::kFLOAT, GetTestDims(tensor_dims));
TRT_TensorOrWeights input;
if (input_is_tensor) {
input = TRT_TensorOrWeights(converter->network()->addInput(
"", nvinfer1::DataType::kFLOAT, GetTestDims(input_dims)));
} else {
input = TRT_TensorOrWeights(weight_store->GetTempWeights(
nvinfer1::DataType::kFLOAT, GetTestDims(input_dims)));
}
nvinfer1::ITensor* output_tensor = nullptr;
for (bool validation_only : {false, true}) {
const Status status = converter->PrepareTensorForShape(
TRT_TensorOrWeights(input_tensor), GetTestDims(reshape_dims),
validation_only, &output_tensor);
input, GetTestDims(reshape_dims), validation_only, &output_tensor);
if (expected_code == error::OK) {
TF_EXPECT_OK(status);
if (validation_only) {
@ -748,49 +767,45 @@ void TestPrepareTensorForShape_Tensor(
}
}
TEST_F(ConverterTest, PrepareTensorForShape_Tensor) {
// Shape size doesn't match.
TEST_F(ConverterTest, PrepareTensorForShape) {
for (bool input_is_tensor : {true, false}) {
// Shape size doesn't match.
Reset();
TestPrepareTensorForShape({2, 3, 5}, {2, 3, 6}, {}, input_is_tensor,
converter_.get(), weight_store_,
error::INVALID_ARGUMENT, "Incompatible shapes");
// Regular shape.
Reset();
TestPrepareTensorForShape({2, 3, 5}, {10, 3}, {10, 3}, input_is_tensor,
converter_.get(), weight_store_);
// Reshape to zero rank.
Reset();
TestPrepareTensorForShape({1, 1}, {}, {}, input_is_tensor, converter_.get(),
weight_store_);
}
// Tensor input with zero rank.
Reset();
TestPrepareTensorForShape_Tensor({2, 3, 5}, {2, 3, 6}, {}, converter_.get(),
error::INVALID_ARGUMENT,
"Incompatible shapes");
TestPrepareTensorForShape({}, {1, 1}, {1, 1}, /*input_is_tensor=*/true,
converter_.get(), weight_store_);
// TODO(aaroey): we should check the case where uninferred dimensions are
// not an exact divisor of input dimensions, e.g. for dims {-1, 7}.
// Infer shape, ok.
// Infer tensor shape, ok.
Reset();
TestPrepareTensorForShape_Tensor({2, 3, 5}, {-1, 2}, {15, 2},
converter_.get());
TestPrepareTensorForShape({2, 3, 5}, {-1, 2}, {15, 2},
/*input_is_tensor=*/true, converter_.get(),
weight_store_);
// Regular shape.
// Infer weight shape, should fail.
Reset();
TestPrepareTensorForShape_Tensor({2, 3, 5}, {10, 3}, {10, 3},
converter_.get());
// Input with zero rank.
Reset();
TestPrepareTensorForShape_Tensor({}, {1, 1}, {1, 1}, converter_.get());
// Reshape to zero rank.
Reset();
TestPrepareTensorForShape_Tensor({1, 1}, {}, {}, converter_.get());
}
TEST_F(ConverterTest, PrepareTensorForShape_Weights) {
TRT_ShapedWeights weights = weight_store_->GetTempWeights(
nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5}));
nvinfer1::ITensor* output_tensor = nullptr;
for (bool validation_only : {false, true}) {
TF_EXPECT_OK(converter_->PrepareTensorForShape(
TRT_TensorOrWeights(weights), GetTestDims({10, 3}), validation_only,
&output_tensor));
if (validation_only) {
EXPECT_EQ(nullptr, output_tensor);
} else {
ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions());
}
}
TestPrepareTensorForShape({2, 3, 5}, {-1, 2}, {15, 2},
/*input_is_tensor=*/false, converter_.get(),
weight_store_, error::INVALID_ARGUMENT,
"Shape is not fully defined");
}
TEST_F(ConverterTest, MaybeUpdateBatchSize) {
@ -4910,6 +4925,279 @@ TEST_F(OpConverterTest, ConvertArgMinMax) {
// TestConvertArgMinMax<ops::ArgMax, DT_INT32>(this);
}
// Get the NodeDef for DepthToSpace or SpaceToDepth.
template <typename OpType>
NodeDef GetDepthSpaceShuffleNodeDef(DataType dtype, int block_size,
string data_format) {
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), dtype);
auto attrs = OpType::DataFormat(data_format);
auto shuffle = OpType(s.WithOpName("my_shuffle"), input, block_size, attrs);
return shuffle.operation.node()->def();
}
template <typename CType>
struct DepthSpaceShuffleTestParams {
std::vector<int> input_dims;
std::vector<CType> input_value;
int block_size;
string data_format;
std::vector<int> expected_output_dims;
std::vector<CType> expected_output;
};
template <typename OpType, DataType dtype, typename CType>
void TestConvertDepthSpaceShuffle(
OpConverterTest* test,
const std::vector<DepthSpaceShuffleTestParams<CType>>& params) {
for (int i = 0; i < params.size(); ++i) {
test->Reset();
NodeDef node_def = GetDepthSpaceShuffleNodeDef<OpType>(
dtype, params[i].block_size, params[i].data_format);
test->AddTestTensor("input", params[i].input_dims, 1,
TfDataTypeToTrt(dtype));
test->RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_shuffle", &output));
EXPECT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray(params[i].expected_output_dims,
output.tensor()->getDimensions());
DataVec input_data{{"input", test::AsTensor<CType>(params[i].input_value)}};
DataVec output_data{{"my_shuffle", ConstructTensor<CType>(
params[i].expected_output.size())}};
test->BuildAndRun(
input_data, &output_data,
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAreArray(params[i].expected_output));
}
}
template <DataType dtype>
void TestConvertDepthToSpace(OpConverterTest* test) {
typedef typename EnumToDataType<dtype>::Type CType;
const std::vector<CType> common_input = InitTestVector<CType>(16);
std::vector<DepthSpaceShuffleTestParams<CType>> params = {
{
/*input_shape=*/{4, 2, 2},
/*input_value=*/common_input,
/*block_size=*/2,
/*data_format=*/"NCHW",
/*expected_output_dims=*/{1, 4, 4},
/*expected_output=*/
CastTestVector<int, CType>(
{0, 4, 1, 5, 8, 12, 9, 13, 2, 6, 3, 7, 10, 14, 11, 15}),
},
{
/*input_shape=*/{2, 2, 4},
/*input_value=*/common_input,
/*block_size=*/2,
/*data_format=*/"NHWC",
/*expected_output_dims=*/{4, 4, 1},
/*expected_output=*/
CastTestVector<int, CType>(
{0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15}),
},
{
/*input_shape=*/{16, 1, 1},
/*input_value=*/common_input,
/*block_size=*/4,
/*data_format=*/"NCHW",
/*expected_output_dims=*/{1, 4, 4},
/*expected_output=*/InitTestVector<CType>(16),
},
{
/*input_shape=*/{2, 2, 8},
/*input_value=*/InitTestVector<CType>(32),
/*block_size=*/2,
/*data_format=*/"NHWC",
/*expected_output_dims=*/{4, 4, 2},
/*expected_output=*/CastTestVector<int, CType>({0, 1, 2, 3, 8,
9, 10, 11, 4, 5,
6, 7, 12, 13, 14,
15, 16, 17, 18, 19,
24, 25, 26, 27, 20,
21, 22, 23, 28, 29,
30, 31}),
},
};
TestConvertDepthSpaceShuffle<ops::DepthToSpace, dtype, CType>(test, params);
}
TEST_F(OpConverterTest, ConvertDepthToSpace) {
{
// Input list is empty, should fail.
NodeDef node_def = MakeNodeDef("my_shuffle", "DepthToSpace", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"DepthToSpace got 0 inputs but expected 1, at my_shuffle");
}
{
// Input is a weight, should fail.
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(DT_FLOAT, 2, "NCHW");
AddTestWeights<float>("input", {4, 1, 1}, {1, 2, 3, 4});
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
"The input \"input\" for DepthToSpace must be a "
"tensor, at my_shuffle");
}
{
// Input rank != 4
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(DT_FLOAT, 2, "NCHW");
AddTestTensor("input", {16, 32});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"The input to DepthToSpace must be rank 4, at my_shuffle");
}
{
// Channels not divisible by block_size, should fail.
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(DT_FLOAT, 3, "NCHW");
AddTestTensor("input", {16, 32, 32});
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
"Number of channels must be divisible by "
"block_size*block_size, at my_shuffle");
}
{
// Unsupported format, should fail.
Reset();
NodeDef node_def = GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(
DT_FLOAT, 2, "NCHW_VECT_C");
AddTestTensor("input", {16, 32, 32});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"Data format NCHW_VECT_C is not supported, at my_shuffle");
}
TestConvertDepthToSpace<DT_FLOAT>(this);
TestConvertDepthToSpace<DT_HALF>(this);
TestConvertDepthToSpace<DT_INT32>(this);
}
template <DataType dtype>
void TestConvertSpaceToDepth(OpConverterTest* test) {
typedef typename EnumToDataType<dtype>::Type CType;
const std::vector<CType> common_input = InitTestVector<CType>(16);
std::vector<DepthSpaceShuffleTestParams<CType>> params = {
{
/*input_shape=*/{1, 4, 4},
/*input_value=*/common_input,
/*block_size=*/2,
/*data_format=*/"NCHW",
/*expected_output_dims=*/{4, 2, 2},
/*expected_output=*/
CastTestVector<int, CType>(
{0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15}),
},
{
/*input_shape=*/{4, 4, 1},
/*input_value=*/common_input,
/*block_size=*/2,
/*data_format=*/"NHWC",
/*expected_output_dims=*/{2, 2, 4},
/*expected_output=*/
CastTestVector<int, CType>(
{0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15}),
},
{
/*input_shape=*/{1, 4, 4},
/*input_value=*/common_input,
/*block_size=*/4,
/*data_format=*/"NCHW",
/*expected_output_dims=*/{16, 1, 1},
/*expected_output=*/InitTestVector<CType>(16),
},
{
/*input_shape=*/{4, 4, 2},
/*input_value=*/InitTestVector<CType>(32),
/*block_size=*/2,
/*data_format=*/"NHWC",
/*expected_output_dims=*/{2, 2, 8},
/*expected_output=*/CastTestVector<int, CType>({0, 1, 2, 3, 8,
9, 10, 11, 4, 5,
6, 7, 12, 13, 14,
15, 16, 17, 18, 19,
24, 25, 26, 27, 20,
21, 22, 23, 28, 29,
30, 31}),
},
};
TestConvertDepthSpaceShuffle<ops::SpaceToDepth, dtype, CType>(test, params);
}
TEST_F(OpConverterTest, ConvertSpaceToDepth) {
{
// Input list is empty, should fail.
NodeDef node_def = MakeNodeDef("my_shuffle", "SpaceToDepth", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"SpaceToDepth got 0 inputs but expected 1, at my_shuffle");
}
{
// Input is a weight, should fail.
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 2, "NCHW");
AddTestWeights<float>("input", {4, 1, 1}, {1, 2, 3, 4});
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
"The input \"input\" for SpaceToDepth must be a "
"tensor, at my_shuffle");
}
{
// Input rank != 4
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 2, "NCHW");
AddTestTensor("input", {16, 32});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"The input to SpaceToDepth must be rank 4, at my_shuffle");
}
{
// Width not divisible by block_size, should fail.
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 3, "NCHW");
AddTestTensor("input", {16, 9, 32});
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
"Width and height must be divisible by "
"block_size, at my_shuffle");
}
{
// Height not divisible by block_size, should fail.
Reset();
NodeDef node_def =
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 3, "NCHW");
AddTestTensor("input", {16, 32, 9});
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
"Width and height must be divisible by "
"block_size, at my_shuffle");
}
{
// Unsupported format, should fail.
Reset();
NodeDef node_def = GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(
DT_FLOAT, 2, "NCHW_VECT_C");
AddTestTensor("input", {16, 32, 32});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"Data format NCHW_VECT_C is not supported, at my_shuffle");
}
TestConvertSpaceToDepth<DT_FLOAT>(this);
TestConvertSpaceToDepth<DT_HALF>(this);
TestConvertSpaceToDepth<DT_INT32>(this);
}
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
@ -290,17 +291,17 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
VLOG(1) << "Executing TRT calibration: " << name();
helper->Ref();
core::ScopedUnref sc(helper);
auto res_mgr = ctx->resource_manager();
TRTCalibrationResource* calib_res = nullptr;
OP_REQUIRES_OK(ctx,
res_mgr->LookupOrCreate(
"TF_TRT_Calibration", name(),
ctx->resource_manager()->LookupOrCreate(
"TF-TRT-Calibration", name(),
reinterpret_cast<SerializableResourceBase**>(&calib_res),
{[ctx, this](SerializableResourceBase** cr) -> Status {
return this->AllocateCalibrationResources(ctx, cr);
}}));
core::ScopedUnref calib_sc(calib_res);
int num_inputs = ctx->num_inputs();
// TODO(laigd): need to check that input shape matches.
// Pass input data to calibrator
std::unordered_map<string, void*> input_data;
for (int i = 0; i < num_inputs; i++) {
@ -425,8 +426,9 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
const_cast<float*>(input_tensor.flat<float>().data());
break;
case nvinfer1::DataType::kHALF:
LOG(ERROR) << "FP16 inputs are not supported yet!";
return kRetry;
buffers[binding_index] =
const_cast<Eigen::half*>(input_tensor.flat<Eigen::half>().data());
break;
case nvinfer1::DataType::kINT8:
LOG(ERROR) << "INT8 inputs are not supported yet!";
return kRetry;
@ -480,8 +482,9 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
const_cast<float*>(output_tensor->flat<float>().data());
break;
case nvinfer1::DataType::kHALF:
LOG(WARNING) << "half size is not supported yet!";
return kRetry;
buffers[binding_index] =
const_cast<Eigen::half*>(output_tensor->flat<Eigen::half>().data());
break;
case nvinfer1::DataType::kINT8:
LOG(WARNING) << "int8 is not supported yet!";
return kRetry;
@ -522,10 +525,22 @@ EngineContext* TRTEngineOp::GetEngine(
// TODO(tmorris): using first input to get batch size - is this reliable?
const int batch_size = input_shapes[0].dim_size(0);
// Get engine cache
// Canonicalize the op name by removing the scopes if any. This is mainly
// because in TFv2, the function graph can be instantiated in various ways and
// it'll insert scope names into the names of the TRTEngineOps, which will result
// in many different engine caches if we use the instantiated op name
// directly, but we still want all of them to share the same cache (if they were
// representing the same subgraph).
absl::string_view resource_name = name();
size_t last_slash = resource_name.find_last_of('/');
if (last_slash != absl::string_view::npos) {
resource_name.remove_prefix(last_slash + 1);
}
// Get engine cache.
TRTEngineCacheResource* cache_res = nullptr;
auto status = ctx->resource_manager()->LookupOrCreate(
"TRTEngineCache", name(), &cache_res,
"TF-TRT-Engine-Cache", string(resource_name), &cache_res,
{[this, ctx](TRTEngineCacheResource** cr) -> Status {
*cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
return Status::OK();
@ -632,12 +647,13 @@ EngineContext* TRTEngineOp::GetEngine(
cache.emplace(engine_input_shapes, absl::make_unique<EngineContext>());
return &empty_context;
}
VLOG(1) << "Conversion is done";
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
engine->createExecutionContext());
cache.emplace(engine_input_shapes,
absl::make_unique<EngineContext>(std::move(engine),
std::move(exec_context)));
VLOG(1) << "Added new engine to cache of " << name()
<< ". Cache size: " << cache.size();
}
return cache.at(engine_input_shapes).get();
}
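An illustration (op names are hypothetical) of the scope-stripping canonicalization in `GetEngine` above: only the last path component of the op name is used as the cache key, so differently scoped instantiations of the same subgraph share one engine cache.

```python
for op_name in ["my_trt_op",
                "while/body/my_trt_op",
                "StatefulPartitionedCall/my_trt_op"]:
    print(op_name.rsplit("/", 1)[-1])  # my_trt_op in every case
```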

View File

@ -0,0 +1,106 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <dirent.h>
#include <string.h>
#include <fstream>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/test.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda_runtime_api.h"
namespace tensorflow {
namespace tensorrt {
using ::testing::ElementsAre;
template <typename T>
class TRTEngineOpTest : public OpsTestBase {};
using TypeList = ::testing::Types<float, Eigen::half>;
TYPED_TEST_SUITE(TRTEngineOpTest, TypeList);
TYPED_TEST(TRTEngineOpTest, Basic) {
DataType dtype = DataTypeToEnum<TypeParam>::v();
// Create the GPU device.
std::unique_ptr<Device> device(
DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));
// Create simple TF graph.
Scope s = Scope::NewRootScope();
auto feed = ops::Placeholder(s.WithOpName("TensorRTInputPH_0"), dtype,
ops::Placeholder::Shape({1, 2}));
auto add = ops::Add(s.WithOpName("add"), feed, feed);
ops::Identity(s.WithOpName("TensorRTOutputPH_0"), add);
// Serialize the graph. TRTEngineOp will convert it using dynamic mode.
GraphDef graph_def;
TF_ASSERT_OK(s.ToGraphDef(&graph_def));
TensorShapeProto shape;
TensorShape({1, 2}).AsProto(&shape);
// Create the op.
OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
TF_ASSERT_OK(NodeDefBuilder("op", "TRTEngineOp")
.Input(FakeInput(1, dtype))
.Attr("input_shapes", {shape})
.Attr("output_shapes", {shape})
.Attr("static_engine", false)
.Attr("segment_funcdef_name", "") // no native fallback
.Attr("serialized_segment", graph_def.SerializeAsString())
.Attr("calibration_data", "")
.Attr("max_cached_engines_count", 1)
.Attr("workspace_size_bytes", 1 << 20)
.Attr("precision_mode", "FP32")
.Attr("use_calibration", false)
.Attr("OutT", {dtype})
.Finalize(OpsTestBase::node_def()));
TF_ASSERT_OK(OpsTestBase::InitOp());
// Execute the op.
OpsTestBase::AddInputFromArray<TypeParam>(TensorShape({1, 2}),
{TypeParam(0.0f), TypeParam(1.0f)});
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
// Verify the result.
// TODO(laigd): OpsTestBase::GetOutput() doesn't work.
Tensor* output = OpsTestBase::context_->mutable_output(0);
const auto& tensor_map = output->flat<TypeParam>();
std::vector<TypeParam> output_data(tensor_map.size());
ASSERT_EQ(0, cudaDeviceSynchronize());
ASSERT_EQ(0, cudaMemcpy(output_data.data(), tensor_map.data(),
sizeof(TypeParam) * tensor_map.size(),
cudaMemcpyDeviceToHost));
EXPECT_THAT(absl::Span<const TypeParam>(output_data),
ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
}
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA

View File

@ -0,0 +1,223 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h" // NOLINT
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/logging.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
using ::nvinfer1::IRuntime;
class CreateTRTEngineCache : public OpKernel {
public:
explicit CreateTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("resource_name", &resource_name_));
OP_REQUIRES_OK(
ctx, ctx->GetAttr("max_cached_engines_count", &max_cached_engines_));
}
void Compute(OpKernelContext* ctx) override {
VLOG(1) << "Creating TRT engine cache resource in container " << container_
<< " for op " << resource_name_ << " on device "
<< ctx->device()->name();
OP_REQUIRES_OK(ctx,
ctx->resource_manager()->Create(
container_, resource_name_,
new TRTEngineCacheResource(ctx, max_cached_engines_)));
Tensor* handle;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
handle->scalar<ResourceHandle>()() =
MakeResourceHandle<TRTEngineCacheResource>(ctx, container_,
resource_name_);
}
private:
string container_;
string resource_name_;
// Maximum number of cached engines
int max_cached_engines_;
TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTEngineCache);
};
REGISTER_KERNEL_BUILDER(Name("CreateTRTEngineCache")
.Device(DEVICE_GPU)
.HostMemory("engine_cache_handle"),
CreateTRTEngineCache);
class PopulateTRTEngineCache : public OpKernel {
public:
explicit PopulateTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
ResourceHandle handle = HandleFromInput(ctx, 0);
TRTEngineCacheResource* resource = nullptr;
OP_REQUIRES_OK(ctx, LookupResource(ctx, handle, &resource));
core::ScopedUnref unref_me(resource);
auto allocator = resource->allocator_.get();
OP_REQUIRES(ctx, allocator != nullptr,
errors::Internal("Not able to initialize TRT engine cache when "
"GPU allocator is empty."));
OP_REQUIRES(ctx, resource->cache_.size() == 0,
errors::Internal("Expect engine cache to be empty, but got ",
resource->cache_.size(), " entries."));
// Get the file name.
const string& filename = ctx->input(1).scalar<string>()();
OP_REQUIRES(ctx, !filename.empty(),
errors::InvalidArgument("filename cannot be empty."));
// Parse the serialized engines and add them to the cache.
std::unique_ptr<RandomAccessFile> file;
OP_REQUIRES_OK(ctx, ctx->env()->NewRandomAccessFile(filename, &file));
auto reader = absl::make_unique<io::RecordReader>(file.get());
uint64 offset = 0;
int num_loaded_engine = 0;
do {
string record;
Status status = reader->ReadRecord(&offset, &record);
if (errors::IsOutOfRange(status)) break;
TRTEngineInstance engine_instance;
engine_instance.ParseFromString(record);
std::vector<TensorShape> engine_input_shapes;
for (const TensorShapeProto& shape : engine_instance.input_shapes()) {
engine_input_shapes.emplace_back(shape);
}
TrtUniquePtrType<IRuntime> infer(
nvinfer1::createInferRuntime(TRTEngineCacheResource::GetLogger()));
infer->setGpuAllocator(allocator);
TrtUniquePtrType<nvinfer1::ICudaEngine> engine(
infer->deserializeCudaEngine(
engine_instance.serialized_engine().c_str(),
engine_instance.serialized_engine().size(),
PluginFactoryTensorRT::GetInstance()));
auto raw_engine = engine.get();
resource->cache_.emplace(
engine_input_shapes,
absl::make_unique<EngineContext>(
std::move(engine), TrtUniquePtrType<nvinfer1::IExecutionContext>(
raw_engine->createExecutionContext())));
++num_loaded_engine;
} while (1);
VLOG(1) << "Loaded " << num_loaded_engine << " TRT engines to container "
<< handle.container() << " for op " << handle.name()
<< " on device " << ctx->device()->name() << " from file "
<< filename;
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(PopulateTRTEngineCache);
};
REGISTER_KERNEL_BUILDER(Name("PopulateTRTEngineCache")
.Device(DEVICE_GPU)
.HostMemory("engine_cache_handle"),
PopulateTRTEngineCache);
class DumpTRTEngineCache : public OpKernel {
public:
explicit DumpTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_cache_after_dump",
&delete_cache_after_dump_));
}
void Compute(OpKernelContext* ctx) override {
const string& container = ctx->input(0).scalar<string>()();
const string& resource_name = ctx->input(1).scalar<string>()();
const string& filename = ctx->input(2).scalar<string>()();
OP_REQUIRES(ctx, !filename.empty(),
errors::InvalidArgument("filename cannot be empty."));
TRTEngineCacheResource* resource = nullptr;
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(
container, resource_name, &resource));
core::ScopedUnref unref_me(resource);
// Serialize the engines and write them to file.
std::unique_ptr<WritableFile> file;
OP_REQUIRES_OK(ctx, ctx->env()->NewWritableFile(filename, &file));
auto writer = absl::make_unique<io::RecordWriter>(file.get());
for (const auto& pair : resource->cache_) {
TRTEngineInstance engine_instance;
// Add input shapes.
const std::vector<TensorShape>& engine_input_shapes = pair.first;
for (const TensorShape& shape : engine_input_shapes) {
shape.AsProto(engine_instance.add_input_shapes());
}
// Add the serialized engine.
const std::unique_ptr<EngineContext>& engine = pair.second;
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(
engine->cuda_engine->serialize());
engine_instance.set_serialized_engine(engine_data->data(),
engine_data->size());
OP_REQUIRES_OK(ctx,
writer->WriteRecord(engine_instance.SerializeAsString()));
}
VLOG(1) << "Serialized " << resource->cache_.size()
<< " TRT engines in container " << container << " for op "
<< resource_name << " on device " << ctx->device()->name()
<< " to file " << filename;
if (delete_cache_after_dump_) {
VLOG(1) << "Destroying TRT engine cache resource in container "
<< container << " for op " << resource_name << " on device "
<< ctx->device()->name();
OP_REQUIRES_OK(ctx,
ctx->resource_manager()->Delete<TRTEngineCacheResource>(
container, resource_name));
}
}
private:
bool delete_cache_after_dump_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(DumpTRTEngineCache);
};
REGISTER_KERNEL_BUILDER(Name("DumpTRTEngineCache").Device(DEVICE_GPU),
DumpTRTEngineCache);
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA

View File

@ -0,0 +1,205 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <dirent.h>
#include <string.h>
#include <fstream>
#include <vector>
#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h" // NOLINT
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
namespace tensorflow {
namespace tensorrt {
class TRTEngineResourceOpsTest : public OpsTestBase {
protected:
void Reset() {
inputs_.clear();
gtl::STLDeleteElements(&tensors_);
gtl::STLDeleteElements(&managed_outputs_);
}
TrtUniquePtrType<nvinfer1::ICudaEngine> CreateTRTEngine() {
Logger logger;
TrtUniquePtrType<nvinfer1::IBuilder> builder(
nvinfer1::createInferBuilder(logger));
TrtUniquePtrType<nvinfer1::INetworkDefinition> network(
builder->createNetwork());
// Add the input.
nvinfer1::Dims dims;
dims.nbDims = 1;
dims.d[0] = 1;
nvinfer1::ITensor* input =
network->addInput("input", nvinfer1::DataType::kFLOAT, dims);
EXPECT_NE(nullptr, input);
// Add a unary layer.
nvinfer1::IUnaryLayer* layer =
network->addUnary(*input, nvinfer1::UnaryOperation::kEXP);
EXPECT_NE(nullptr, layer);
// Mark the output.
nvinfer1::ITensor* output = layer->getOutput(0);
output->setName("output");
network->markOutput(*output);
// Build the engine
builder->setMaxBatchSize(1);
builder->setMaxWorkspaceSize(1 << 10);
TrtUniquePtrType<nvinfer1::ICudaEngine> engine(
builder->buildCudaEngine(*network));
EXPECT_NE(nullptr, engine);
return engine;
}
};
TEST_F(TRTEngineResourceOpsTest, Basic) {
// Create the GPU device.
std::unique_ptr<Device> device(
DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));
ResourceMgr* rm = device->resource_manager();
SetDevice(DEVICE_GPU, std::move(device));
// Create the resource.
const string container = "mycontainer";
const string resource_name = "myresource";
Reset();
TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTEngineCache")
.Attr("container", container)
.Attr("resource_name", resource_name)
.Attr("max_cached_engines_count", 1)
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
TF_ASSERT_OK(RunOpKernel());
ResourceHandle handle =
context_->mutable_output(0)->scalar<ResourceHandle>()();
TRTEngineCacheResource* resource = nullptr;
EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok());
// Create a serialized TRT engine file.
TrtUniquePtrType<nvinfer1::ICudaEngine> engine = CreateTRTEngine();
TrtUniquePtrType<nvinfer1::IExecutionContext> context(
engine->createExecutionContext());
resource->cache_.emplace(
std::vector<TensorShape>{TensorShape({1, 1})},
absl::make_unique<EngineContext>(std::move(engine), std::move(context)));
resource->Unref();
// Serialize the engine using DumpTRTEngineCache op.
Reset();
TF_ASSERT_OK(NodeDefBuilder("op", "DumpTRTEngineCache")
.Attr("delete_cache_after_dump", true)
.Input(FakeInput(DT_STRING))
.Input(FakeInput(DT_STRING))
.Input(FakeInput(DT_STRING))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<string>(TensorShape({}), {container});
AddInputFromArray<string>(TensorShape({}), {resource_name});
const string filename = io::JoinPath(testing::TmpDir(), "trt_engine_file");
AddInputFromArray<string>(TensorShape({}), {filename});
TF_ASSERT_OK(RunOpKernel());
// Make sure the cache is deleted.
Reset();
TF_ASSERT_OK(NodeDefBuilder("op", "DestroyResourceOp")
.Attr("ignore_lookup_error", false)
.Input(FakeInput(DT_RESOURCE))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<ResourceHandle>(TensorShape({}), {handle});
EXPECT_TRUE(errors::IsNotFound(RunOpKernel()));
// Verify the serialized engine file.
Env* env = Env::Default();
std::unique_ptr<RandomAccessFile> file;
TF_ASSERT_OK(env->NewRandomAccessFile(filename, &file));
auto reader = absl::make_unique<io::RecordReader>(file.get());
uint64 offset = 0;
string record;
TF_ASSERT_OK(reader->ReadRecord(&offset, &record));
TRTEngineInstance engine_instance;
engine_instance.ParseFromString(record);
EXPECT_EQ(1, engine_instance.input_shapes_size());
EXPECT_EQ(2, engine_instance.input_shapes(0).dim_size());
EXPECT_EQ(1, engine_instance.input_shapes(0).dim(0).size());
EXPECT_EQ(1, engine_instance.input_shapes(0).dim(1).size());
EXPECT_TRUE(errors::IsOutOfRange(reader->ReadRecord(&offset, &record)));
// Recreate the cache resource.
Reset();
TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTEngineCache")
.Attr("container", container)
.Attr("resource_name", resource_name)
.Attr("max_cached_engines_count", 1)
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
TF_ASSERT_OK(RunOpKernel());
handle = context_->mutable_output(0)->scalar<ResourceHandle>()();
EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok());
EXPECT_EQ(0, resource->cache_.size());
resource->Unref();
// Deserialize the engine using PopulateTRTEngineCache op.
Reset();
TF_ASSERT_OK(NodeDefBuilder("op", "PopulateTRTEngineCache")
.Input(FakeInput(DT_RESOURCE))
.Input(FakeInput(DT_STRING))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<ResourceHandle>(TensorShape({}), {handle});
AddInputFromArray<string>(TensorShape({}), {filename});
TF_ASSERT_OK(RunOpKernel());
EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok());
EXPECT_EQ(1, resource->cache_.size());
resource->Unref();
// Destroy the engine cache again.
Reset();
TF_ASSERT_OK(NodeDefBuilder("op", "DestroyResourceOp")
.Attr("ignore_lookup_error", false)
.Input(FakeInput(DT_RESOURCE))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
AddInputFromArray<ResourceHandle>(TensorShape({}), {handle});
TF_ASSERT_OK(RunOpKernel());
EXPECT_TRUE(errors::IsNotFound(RunOpKernel()));
}
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA

View File

@ -0,0 +1,52 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
REGISTER_OP("CreateTRTEngineCache")
.Attr("container: string")
.Attr("resource_name: string")
.Attr("max_cached_engines_count: int = 1")
.Output("engine_cache_handle: resource")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("PopulateTRTEngineCache")
.Input("engine_cache_handle: resource")
.Input("filename: string")
.SetIsStateful()
.SetShapeFn(shape_inference::NoOutputs);
REGISTER_OP("DumpTRTEngineCache")
.Attr("delete_cache_after_dump: bool = false")
.Input("container: string")
.Input("resource_name: string")
.Input("filename: string")
.SetIsStateful()
.SetShapeFn(shape_inference::NoOutputs);
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA
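
Taken together, these three ops define a simple lifecycle for the engine cache: create an empty cache resource, let engines accumulate in it (or reload them from a file with PopulateTRTEngineCache), and serialize it to disk with DumpTRTEngineCache. A rough sketch of that flow, assuming a TensorRT-enabled GPU build in which these ops are exposed through `tf.raw_ops` (not something this change guarantees):

```python
import tensorflow as tf

path = "/tmp/trt_engine_file"  # hypothetical path; the ops only require it to be non-empty

# Create an empty cache resource; engines are normally added to it at runtime.
handle = tf.raw_ops.CreateTRTEngineCache(
    container="c", resource_name="r", max_cached_engines_count=1)

# Serialize whatever the cache holds and drop the resource afterwards.
tf.raw_ops.DumpTRTEngineCache(
    container="c", resource_name="r", filename=path,
    delete_cache_after_dump=True)

# Later: recreate the cache and reload the serialized engines from disk.
handle = tf.raw_ops.CreateTRTEngineCache(
    container="c", resource_name="r", max_cached_engines_count=1)
tf.raw_ops.PopulateTRTEngineCache(engine_cache_handle=handle, filename=path)
```

The C++ test above exercises the same sequence directly through NodeDefBuilder.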

View File

@ -459,7 +459,7 @@ Status SegmentGraph(const Graph* tf_graph,
}
LOG(INFO) << msg << "(For more information see "
<< "https://docs.nvidia.com/deeplearning"
<< "/dgx/integrate-tf-trt/index.html#support-ops).";
<< "/dgx/tf-trt-user-guide/index.html#supported-ops).";
// The segmentation algorithm below visits nodes in reverse topological order
// and attempts to merge nodes along output edges. That means that subgraphs

View File

@ -0,0 +1,19 @@
syntax = "proto3";
package tensorflow.tensorrt;
import "tensorflow/core/framework/tensor_shape.proto";
// Contains the information for a serialized TensorRT engine.
message TRTEngineInstance {
// The input shapes of the TRT engine.
repeated TensorShapeProto input_shapes = 1;
// The serialized TRT engine.
//
// TODO(laigd): consider using a more efficient in-memory representation
// instead of string which is the default here.
bytes serialized_engine = 2;
// TODO(laigd): consider adding calibration stats, precision_modes, etc.
}
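
Each record written by DumpTRTEngineCache is one serialized TRTEngineInstance, framed with TensorFlow's RecordWriter (i.e. the TFRecord format). A sketch of inspecting such a file from Python, assuming a generated binding for this proto is importable under the hypothetical name `trt_engine_instance_pb2`:

```python
import tensorflow as tf
from trt_engine_instance_pb2 import TRTEngineInstance  # hypothetical import path

for record in tf.compat.v1.io.tf_record_iterator("/tmp/trt_engine_file"):
    instance = TRTEngineInstance()
    instance.ParseFromString(record)
    shapes = [[d.size for d in s.dim] for s in instance.input_shapes]
    print("input shapes:", shapes,
          "serialized engine bytes:", len(instance.serialized_engine))
```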

View File

@ -0,0 +1,79 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include <sstream>
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/mutex.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorrt/include/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
Logger& TRTEngineCacheResource::GetLogger() {
static Logger* logger = new Logger();
return *logger;
}
TRTEngineCacheResource::TRTEngineCacheResource(OpKernelContext* ctx,
size_t capacity)
: cache_(capacity) {
auto device = ctx->device();
auto alloc = device->GetAllocator(AllocatorAttributes());
if (!alloc) {
LOG(ERROR) << "Can't find device allocator for gpu device "
<< device->name();
allocator_ = nullptr;
} else {
allocator_.reset(new TRTDeviceAllocator(alloc));
}
}
TRTEngineCacheResource::~TRTEngineCacheResource() {
VLOG(1) << "Destroying TRTEngineCacheResource...";
}
string TRTEngineCacheResource::DebugString() const {
std::stringstream oss;
using std::dec;
using std::endl;
using std::hex;
oss << "TRTEngineCacheResource: ";
oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", ";
oss << "LRUCache = " << hex << &cache_ << dec << endl;
oss << "Containing " << cache_.size() << " entries: " << endl;
for (const auto& item : cache_) {
mutex_lock lock(item.second->mu);
oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex
<< "ICudaEngine: " << item.second->cuda_engine.get() << ", "
<< "IExecutionContext: " << item.second->execution_context.get() << dec
<< endl;
}
return oss.str();
}
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
@ -141,36 +142,18 @@ struct EngineContext {
class TRTEngineCacheResource : public ResourceBase {
public:
TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity)
: cache_(capacity) {
auto device = ctx->device();
auto alloc = device->GetAllocator(AllocatorAttributes());
if (!alloc) {
LOG(ERROR) << "Can't find device allocator for gpu device "
<< device->name();
allocator_ = nullptr;
} else {
allocator_.reset(new TRTDeviceAllocator(alloc));
}
}
// According to the TensorRT API, the logger is considered a singleton by the
// TensorRT library, and multiple instances of IRuntime and/or IBuilder must
// all use the same logger. So here we make it a singleton.
//
// TODO(laigd): use this logger in all places where conversion happens.
static Logger& GetLogger();
string DebugString() const override {
std::stringstream oss;
using std::dec;
using std::endl;
using std::hex;
oss << "TRTEngineCacheResource: ";
oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", ";
oss << "LRUCache = " << hex << &cache_ << dec << endl;
oss << "Containing " << cache_.size() << " entries: " << endl;
for (const auto& item : cache_) {
oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex
<< "ICudaEngine: " << item.second.get()->cuda_engine.get() << ", "
<< "IExecutionContext: " << item.second.get()->execution_context.get()
<< dec << endl;
}
return oss.str();
}
TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity);
~TRTEngineCacheResource() override;
string DebugString() const override;
// Keep device allocator for TRT.
std::unique_ptr<TRTBaseAllocator> allocator_;

View File

@ -210,6 +210,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/algorithm:container",
@ -376,6 +377,7 @@ tf_cc_test(
":xla_compiler",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@ -390,6 +392,7 @@ tf_cc_test(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",

View File

@ -100,9 +100,12 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
arg.name = resource->name();
break;
}
case XlaExpression::Kind::kTensorList:
return errors::Unimplemented(
"TensorList as function argument is not yet implemented.");
case XlaExpression::Kind::kTensorList: {
arg.kind = XlaCompiler::Argument::kTensorList;
const xla::XlaOp& tensor_list = expressions[i]->handle();
arg.shape = tensor_list.builder()->GetShape(tensor_list).ValueOrDie();
break;
}
case XlaExpression::Kind::kInvalid:
return errors::InvalidArgument("Invalid function argument");
}
@ -302,8 +305,13 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
if (result.outputs[i].is_constant) {
xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
} else {
xla_op_context.SetOutput(
i, xla::GetTupleElement(output_handle, computation_output));
if (result.outputs[i].is_tensor_list) {
xla_op_context.SetTensorListOutput(
i, xla::GetTupleElement(output_handle, computation_output));
} else {
xla_op_context.SetOutput(
i, xla::GetTupleElement(output_handle, computation_output));
}
++computation_output;
}
}

View File

@ -58,6 +58,7 @@ class AddNOp : public XlaOpKernel {
xla::XlaOp push_index;
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(ctx->Input(0), &push_index));
OP_REQUIRES_OK(ctx, BuildTensorList(sum, push_index, &sum));
ctx->SetTensorListOutput(0, sum);
break;
}
default:
@ -65,9 +66,8 @@ class AddNOp : public XlaOpKernel {
for (int i = 1; i < ctx->num_inputs(); ++i) {
sum = xla::Add(sum, ctx->Input(i));
}
ctx->SetOutput(0, sum);
}
ctx->SetOutput(0, sum);
}
private:

View File

@ -44,6 +44,7 @@ class BatchMatMulOp : public XlaOpKernel {
};
REGISTER_XLA_OP(Name("BatchMatMul"), BatchMatMulOp);
REGISTER_XLA_OP(Name("BatchMatMulV2"), BatchMatMulOp);
} // namespace
} // namespace tensorflow
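
Registering BatchMatMulV2 lets XLA compile the batched matmul variant that broadcasts batch dimensions. A minimal sketch, assuming the op is reachable through `tf.raw_ops` in this build:

```python
import tensorflow as tf

a = tf.ones([2, 3, 4])
b = tf.ones([4, 5])                      # batch dims broadcast against a's leading dim
c = tf.raw_ops.BatchMatMulV2(x=a, y=b)   # expected shape: [2, 3, 5]
```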

View File

@ -15,6 +15,7 @@ limitations under the License.
// XLA implementations of Categorical op.
#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@ -35,7 +36,9 @@ namespace {
class CategoricalOp : public XlaOpKernel {
public:
explicit CategoricalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
explicit CategoricalOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx),
is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {}
void Compile(XlaOpKernelContext* ctx) override {
// Get the logits
@ -100,8 +103,15 @@ class CategoricalOp : public XlaOpKernel {
xla::PrimitiveType xla_output_type;
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(output_type(0), &xla_output_type));
xla::XlaOp argmax = xla::ArgMax(softmax_entries, xla_output_type,
/*axis=*/class_dimension);
xla::XlaOp argmax;
if (is_gpu_) {
argmax = xla::ArgMaxTwoPass(softmax_entries, xla_output_type,
/*axis=*/class_dimension);
} else {
argmax = xla::ArgMax(softmax_entries, xla_output_type,
/*axis=*/class_dimension);
}
if (num_samples == 1) {
argmax = xla::Reshape(argmax, {batch_size, 1});
}
@ -123,6 +133,7 @@ class CategoricalOp : public XlaOpKernel {
}
private:
bool is_gpu_;
TF_DISALLOW_COPY_AND_ASSIGN(CategoricalOp);
};
@ -140,8 +151,6 @@ class StatelessCategoricalOp : public CategoricalOp {
xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, xla::PrimitiveType type,
XlaOpKernelContext* ctx) override {
xla::XlaOp seed = ctx->Input(2);
auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
xla::XlaBuilder* builder = ctx->builder();
if (uniform_shape.element_type() == xla::BF16) {
@ -150,8 +159,8 @@ class StatelessCategoricalOp : public CategoricalOp {
// We want a number in (0, 1) rather than [0, 1) or (0, 1]:
// * log(-log(0)) is ∞.
// * log(-log(1)) is -∞.
auto uniforms = xla::StatelessRngUniform(
{seed0, seed1}, uniform_shape,
xla::XlaOp uniforms = StatelessRngUniform(
seed, uniform_shape,
xla::MinPositiveNormalValue(builder, uniform_shape.element_type()),
xla::One(builder, uniform_shape.element_type()));
return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type);
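
The `log(-log(uniform))` term is the Gumbel-max construction this sampler is built on: add Gumbel noise `-log(-log(u))` to the logits and take the argmax to draw a categorical sample. A standalone numpy illustration of the trick (the kernel's exact sign and reduction layout may differ):

```python
import numpy as np

rng = np.random.RandomState(0)
logits = np.log([0.1, 0.2, 0.7])
u = rng.uniform(size=(10000, 3))                            # uniforms in (0, 1)
samples = np.argmax(logits - np.log(-np.log(u)), axis=-1)   # Gumbel-max trick
print(np.bincount(samples) / len(samples))                  # roughly [0.1, 0.2, 0.7]
```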

View File

@ -16,6 +16,7 @@ limitations under the License.
// XLA-specific Ops for 2D convolution.
#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
@ -293,10 +294,9 @@ xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
return attrs;
}
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
xla::XlaOp conv_input,
xla::XlaOp filter,
const ConvOpAttrs& attrs) {
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
StringPiece /*type_string*/, xla::XlaOp conv_input, xla::XlaOp filter,
const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) {
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
auto* builder = conv_input.builder();
@ -377,12 +377,14 @@ xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
return xla::ConvGeneralDilated(
conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
dims,
/*feature_group_count=*/attrs.depthwise ? in_depth : feature_group_count);
/*feature_group_count=*/attrs.depthwise ? in_depth : feature_group_count,
/*batch_group_count=*/1, precision_config);
}
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
xla::XlaOp out_backprop, const ConvOpAttrs& attrs) {
xla::XlaOp out_backprop, const ConvOpAttrs& attrs,
const xla::PrecisionConfig* precision_config) {
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
int num_dims = attrs.num_spatial_dims + 2;
@ -456,13 +458,14 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
/*feature_group_count=*/
attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) /
filter_shape.dimensions(attrs.num_spatial_dims + 1)
: feature_group_count);
: feature_group_count,
/*batch_group_count=*/1, precision_config);
}
xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
StringPiece type_string, xla::XlaOp activations,
const xla::Shape& filter_shape, xla::XlaOp gradients,
const ConvOpAttrs& attrs) {
const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) {
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
auto* builder = activations.builder();
@ -612,7 +615,8 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
rhs_dilation, dnums,
/*feature_group_count=*/feature_group_count,
/*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1);
/*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1,
precision_config);
if (!use_batch_group_count && attrs.depthwise) {
filter_backprop = ContractFilterForDepthwiseBackprop(

View File

@ -53,17 +53,19 @@ struct ConvOpAttrs {
// Creates a new XLA forward or backward convolution with the given inputs and
// attributes.
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece type_string,
xla::XlaOp conv_input,
xla::XlaOp filter,
const ConvOpAttrs& attrs);
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
StringPiece type_string, xla::XlaOp conv_input, xla::XlaOp filter,
const ConvOpAttrs& attrs,
const xla::PrecisionConfig* precision_config = nullptr);
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
xla::XlaOp out_backprop, const ConvOpAttrs& attrs);
xla::XlaOp out_backprop, const ConvOpAttrs& attrs,
const xla::PrecisionConfig* precision_config = nullptr);
xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
StringPiece type_string, xla::XlaOp activations,
const xla::Shape& filter_shape, xla::XlaOp gradients,
const ConvOpAttrs& attrs);
const ConvOpAttrs& attrs,
const xla::PrecisionConfig* precision_config = nullptr);
} // namespace tensorflow

View File

@ -82,33 +82,71 @@ xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input,
return xla::Add(xla::Mul(rounded, input_scale), nudged_input_min);
}
// Builds a custom_call to a method named 'fake_quant_with_min_max_vars'.
// The method will be provided the input, the min/max range from the original
// TensorFlow op, and the num_bits and narrow_range attributes.
xla::StatusOr<xla::XlaOp> BuildFakeQuantCustomCall(
xla::XlaBuilder* b, xla::XlaOp input, xla::XlaOp input_min,
xla::XlaOp input_max, int num_bits, bool narrow_range) {
xla::XlaOp num_bits_arg =
XlaHelpers::IntegerLiteral(b, DataType::DT_INT32, num_bits);
xla::XlaOp narrow_range_arg = narrow_range
? XlaHelpers::One(b, DataType::DT_BOOL)
: XlaHelpers::Zero(b, DataType::DT_BOOL);
std::vector<xla::XlaOp> args = {input, input_min, input_max, num_bits_arg,
narrow_range_arg};
std::vector<xla::Shape> arg_shapes;
for (const xla::XlaOp& arg : args) {
TF_ASSIGN_OR_RETURN(xla::Shape arg_shape, b->GetShape(arg));
*arg_shape.mutable_layout() =
xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank());
arg_shapes.push_back(std::move(arg_shape));
}
// Input and output shapes match exactly.
TF_ASSIGN_OR_RETURN(xla::Shape output_shape, b->GetShape(input));
return xla::CustomCallWithLayout(b, "fake_quant_with_min_max_vars", args,
output_shape, arg_shapes);
}
class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
public:
explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
int num_bits;
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
OP_REQUIRES(ctx, num_bits_ >= 2 && num_bits_ <= 16,
errors::InvalidArgument("num_bits is out of range, expected "
"between 2 and 16, was: ",
num_bits));
bool narrow_range;
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
quant_max_ = (1 << num_bits) - 1;
num_bits_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
quant_min_ = narrow_range_ ? 1 : 0;
quant_max_ = (1 << num_bits_) - 1;
float input_min, input_max;
OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
CpuNudge(input_min, input_max, quant_min_, quant_max_, &nudged_input_min_,
OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max_));
CpuNudge(input_min_, input_max_, quant_min_, quant_max_, &nudged_input_min_,
&nudged_input_max_, &input_scale_);
}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
xla::XlaOp input = ctx->Input(0);
const DataType data_type = ctx->input_type(0);
xla::XlaBuilder* b = ctx->builder();
if (ctx->compiler()->options().allow_cpu_custom_calls &&
ctx->compiler()->options().custom_fake_quant_op_calls) {
xla::XlaOp custom_call_output =
b->ReportErrorOrReturn(BuildFakeQuantCustomCall(
b, input,
XlaHelpers::FloatLiteral(b, DataType::DT_FLOAT, input_min_),
XlaHelpers::FloatLiteral(b, DataType::DT_FLOAT, input_max_),
num_bits_, narrow_range_));
ctx->SetOutput(0, custom_call_output);
return;
}
xla::XlaOp nudged_input_min =
XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
xla::XlaOp nudged_input_max =
@ -121,6 +159,10 @@ class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
}
private:
int num_bits_;
bool narrow_range_;
float input_min_;
float input_max_;
float quant_min_;
float quant_max_;
float nudged_input_min_;
@ -184,25 +226,32 @@ class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
public:
explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
int num_bits;
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
OP_REQUIRES(ctx, num_bits_ >= 2 && num_bits_ <= 16,
errors::InvalidArgument("num_bits is out of range, expected "
"between 2 and 16, was: ",
num_bits));
bool narrow_range;
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
quant_max_ = (1 << num_bits) - 1;
num_bits_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
quant_min_ = narrow_range_ ? 1 : 0;
quant_max_ = (1 << num_bits_) - 1;
}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
xla::XlaOp input = ctx->Input(0);
const DataType data_type = ctx->input_type(0);
xla::XlaOp input_min = ctx->Input(1);
xla::XlaOp input_max = ctx->Input(2);
xla::XlaBuilder* b = ctx->builder();
if (ctx->compiler()->options().allow_cpu_custom_calls &&
ctx->compiler()->options().custom_fake_quant_op_calls) {
xla::XlaOp custom_call_output =
b->ReportErrorOrReturn(BuildFakeQuantCustomCall(
b, input, input_min, input_max, num_bits_, narrow_range_));
ctx->SetOutput(0, custom_call_output);
return;
}
xla::XlaOp nudged_input_min, nudged_input_max, input_scale;
XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
&nudged_input_min, &nudged_input_max, &input_scale);
@ -213,6 +262,8 @@ class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
}
private:
int num_bits_;
bool narrow_range_;
float quant_min_;
float quant_max_;
};
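
For reference, the integer range the nudging code works with follows directly from the `num_bits` and `narrow_range` attributes handled above:

```python
def quant_range(num_bits, narrow_range):
    # Mirrors the kernel: quant_min_ = narrow_range ? 1 : 0,
    # quant_max_ = (1 << num_bits) - 1.
    quant_min = 1 if narrow_range else 0
    quant_max = (1 << num_bits) - 1
    return quant_min, quant_max

print(quant_range(8, False))  # (0, 255)
print(quant_range(8, True))   # (1, 255)
```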

View File

@ -31,7 +31,9 @@ limitations under the License.
namespace tensorflow {
XlaArgMinMaxOp::XlaArgMinMaxOp(OpKernelConstruction* ctx, bool is_min)
: XlaOpKernel(ctx), is_min_(is_min) {}
: XlaOpKernel(ctx),
is_min_(is_min),
is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {}
void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
const TensorShape input_shape = ctx->InputShape(0);
@ -64,10 +66,19 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
xla::XlaOp input = ctx->Input(0);
xla::XlaOp output;
// One-pass ArgMin/ArgMax is slow on GPUs.
if (is_min_) {
output = xla::ArgMin(input, index_xla_type, axis);
if (is_gpu_) {
output = xla::ArgMinTwoPass(input, index_xla_type, axis);
} else {
output = xla::ArgMin(input, index_xla_type, axis);
}
} else {
output = xla::ArgMax(input, index_xla_type, axis);
if (is_gpu_) {
output = xla::ArgMaxTwoPass(input, index_xla_type, axis);
} else {
output = xla::ArgMax(input, index_xla_type, axis);
}
}
ctx->SetOutput(0, output);

View File

@ -30,6 +30,7 @@ class XlaArgMinMaxOp : public XlaOpKernel {
private:
const bool is_min_; // Are we computing ArgMin (true) or ArgMax (false)?
const bool is_gpu_;
};
class XlaArgMaxOp : public XlaArgMinMaxOp {

View File

@ -22,6 +22,13 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
// Returns a tensor containing 'shape' random values uniformly distributed in
// the range [minval, maxval). The raw random bits are produced by a ThreeFry
// generator keyed by the two 32-bit integer `seeds` and then converted to the
// requested data type and range. Currently only shapes with element type F32,
// S32 and S64 are supported.
xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape,
xla::XlaOp minval, xla::XlaOp maxval);
// Converts to bfloat16 if `dtype` equals DT_BFLOAT16, no-op otherwise.
// It masks the last 16 bits. With normal rounding, values near "maxval" would be

View File

@ -51,6 +51,7 @@ class ReshapeOp : public XlaOpKernel {
TensorShape shape;
int64 product = 1;
int unknown_index = -1;
bool shape_has_zero_dim = false;
for (int d = 0; d < num_dims; ++d) {
const int32 size = shape_input[d];
if (size == -1) {
@ -60,6 +61,12 @@ class ReshapeOp : public XlaOpKernel {
unknown_index, " and ", d));
unknown_index = d;
shape.AddDim(1);
} else if (size == 0) {
// Don't include zero-sized dimensions in the product, so that we can
// still compute the number of elements across the non-zero-sized
// dimensions and use it to infer the missing dimension.
shape.AddDim(size);
shape_has_zero_dim = true;
} else {
OP_REQUIRES(ctx, size >= 0,
errors::InvalidArgument(
@ -69,18 +76,28 @@ class ReshapeOp : public XlaOpKernel {
}
}
if (unknown_index != -1) {
OP_REQUIRES(
ctx, product > 0,
errors::InvalidArgument("Reshape cannot infer the missing input size "
"for an empty tensor unless all specified "
"input sizes are non-zero"));
const int64 missing = input_shape.num_elements() / product;
OP_REQUIRES(
ctx, product * missing == input_shape.num_elements(),
errors::InvalidArgument(
"Input to reshape is a tensor with ", input_shape.num_elements(),
" values, but the requested shape requires a multiple of ",
product));
int64 input_num_elements = 1;
bool input_has_zero_dim = false;
for (int dim = 0; dim < input_shape.dims(); dim++) {
// Zero-sized input dimensions are not counted in `input_num_elements`,
// unless the requested shape has no zero-sized dimension, so that the
// missing dimension can still be inferred from the remaining ones.
if (input_shape.dim_size(dim) > 0 || !shape_has_zero_dim) {
input_num_elements *= input_shape.dim_size(dim);
} else {
input_has_zero_dim = true;
}
}
const int64 missing = input_num_elements / product;
if (!input_has_zero_dim) {
OP_REQUIRES(
ctx, product * missing == input_num_elements,
errors::InvalidArgument(
"Input to reshape is a tensor with ", input_num_elements,
" values, but the requested shape requires a multiple of ",
product));
}
shape.set_dim(unknown_index, missing);
}
OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(),
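
In user-facing terms, and assuming the XLA kernel now matches the regular TF reshape semantics, a zero-sized dimension no longer prevents a -1 dimension from being inferred:

```python
import tensorflow as tf

x = tf.zeros([2, 0, 3])     # zero elements, but the non-zero dims multiply to 6
y = tf.reshape(x, [0, -1])  # -1 is inferred from the non-zero dimensions only
print(y.shape)              # expected: (0, 6)
```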

View File

@ -16,6 +16,8 @@ limitations under the License.
// XLA-specific Shape Ops.
#include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@ -223,14 +225,33 @@ class ZerosLikeOp : public XlaOpKernel {
explicit ZerosLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
if (IsTensorListInput(ctx, 0)) {
// Input is a TensorList.
// TODO(b/124707753): support nested TensorList.
xla::XlaOp tensor_list = ctx->Input(0);
TensorShape shape;
OP_REQUIRES_OK(ctx, GetTensorListBufferShape(tensor_list, &shape));
xla::PrimitiveType type;
OP_REQUIRES_OK(ctx, GetTensorListPrimitiveType(tensor_list, &type));
xla::XlaOp buffer;
OP_REQUIRES_OK(ctx, CreateZerosList(ctx, shape, type, &buffer));
auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes()));
xla::XlaOp push_index;
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(tensor_list, &push_index));
xla::XlaOp output_list;
OP_REQUIRES_OK(ctx, BuildTensorList(buffer, push_index, &output_list));
ctx->SetTensorListOutput(0, output_list);
} else {
const TensorShape input_shape = ctx->InputShape(0);
auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes()));
}
}
};
REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp);
REGISTER_XLA_OP(Name("ZerosLike").AllowVariantTypes(), ZerosLikeOp);
class OnesLikeOp : public XlaOpKernel {
public:

View File

@ -35,127 +35,50 @@ limitations under the License.
namespace tensorflow {
namespace {
std::pair<xla::ThreeFry2x32State, xla::XlaOp> GetInputsFromCounter(
xla::XlaOp counter, const int64 size) {
auto builder = counter.builder();
auto input_u64 = Iota(builder, xla::U64, size);
input_u64 = input_u64 + counter;
counter = counter + xla::ConstantR0<uint64>(builder, size);
return std::make_pair(xla::Uint64ToUint32s(input_u64), counter);
}
// `StatelessRngUniformU32` uses ThreeFry2x32's counter space too
// wastefully: it can only generate 2^32*2 int32 numbers for each key, while
// the real capacity is 2^64*2. Counter-space efficiency is important for
// stateful ops, hence the following 2 new functions.
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU32(
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
auto builder = key.builder();
const int64 size = xla::ShapeUtil::ElementsIn(shape);
const int64 half_size = xla::CeilOfRatio<int64>(size, 2);
const bool size_is_odd = (half_size * 2 != size);
auto inputs_counter = GetInputsFromCounter(counter, half_size);
auto inputs = inputs_counter.first;
counter = inputs_counter.second;
auto outputs = xla::ThreeFry2x32(inputs, xla::Uint64ToUint32s(key));
if (size_is_odd) {
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
}
auto result = ConcatInDim(builder, outputs, 0);
return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
counter);
}
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU64(
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
const int64 size = xla::ShapeUtil::ElementsIn(shape);
auto inputs_counter = GetInputsFromCounter(counter, size);
auto inputs = inputs_counter.first;
counter = inputs_counter.second;
auto outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
auto result = Uint32sToUint64(outputs);
return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
counter);
}
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniform(xla::XlaOp key,
xla::XlaOp counter,
const xla::Shape& shape,
xla::XlaOp minval,
xla::XlaOp maxval) {
auto builder = key.builder();
xla::RngOutput StatefulRngUniform(xla::XlaOp key, xla::XlaOp initial_state,
const xla::Shape& shape, xla::XlaOp minval,
xla::XlaOp maxval) {
xla::PrimitiveType type = shape.element_type();
switch (type) {
case xla::F32: {
auto bits_counter = StatefulRngUniformU32(key, counter, shape);
auto bits = bits_counter.first;
counter = bits_counter.second;
return std::make_pair(xla::StatelessRngUniformF32(bits, minval, maxval),
counter);
}
case xla::U32: // fall through
case xla::S32: {
auto bits_counter = StatefulRngUniformU32(key, counter, shape);
auto bits = bits_counter.first;
counter = bits_counter.second;
return std::make_pair(
xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U32),
counter);
}
case xla::U64: // fall through
case xla::S64: {
auto bits_counter = StatefulRngUniformU64(key, counter, shape);
auto bits = bits_counter.first;
counter = bits_counter.second;
return std::make_pair(
xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U64),
counter);
}
case xla::F32:
return xla::UniformF32Distribution(
key, initial_state, xla::ThreeFryBitGenerator, minval, maxval, shape);
case xla::U32:
case xla::S32:
case xla::U64:
case xla::S64:
return UniformIntDistribution(
key, initial_state, xla::ThreeFryBitGenerator, minval, maxval, shape);
default:
return std::make_pair(
builder->ReportError(xla::Unimplemented(
"Types other than F32, U32, S32, U64 and S64 "
"are not implemented by "
"StatefulRngUniform; got: %s",
xla::primitive_util::LowercasePrimitiveTypeName(type))),
counter);
return {key.builder()->ReportError(xla::Unimplemented(
"Types other than F32, U32, S32, U64 and S64 "
"are not implemented by "
"StatefulRngUniform; got %s",
xla::primitive_util::LowercasePrimitiveTypeName(type))),
initial_state};
}
}
template <typename A, typename B, typename A2>
std::pair<A2, B> map_first(std::function<A2(A)> f, std::pair<A, B> p) {
return std::make_pair(f(p.first), p.second);
}
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformFullInt(
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
xla::RngOutput StatefulRngUniformFullInt(xla::XlaOp key,
xla::XlaOp initial_state,
const xla::Shape& shape) {
xla::PrimitiveType type = shape.element_type();
xla::RngOutput output = xla::ThreeFryBitGenerator(key, initial_state, shape);
switch (type) {
case xla::U32:
return StatefulRngUniformU32(key, counter, shape);
case xla::S32: {
// Needs explicit function type because of type-inference failure.
std::function<xla::XlaOp(xla::XlaOp)> f = [](xla::XlaOp x) {
return BitcastConvertType(x, xla::S32);
};
return map_first(f, StatefulRngUniformU32(key, counter, shape));
}
case xla::U64:
return StatefulRngUniformU64(key, counter, shape);
case xla::S64: {
std::function<xla::XlaOp(xla::XlaOp)> f = [](xla::XlaOp x) {
return BitcastConvertType(x, xla::S64);
};
return map_first(f, StatefulRngUniformU64(key, counter, shape));
}
return output;
case xla::S32:
case xla::S64:
output.value = BitcastConvertType(output.value, type);
return output;
default:
auto builder = key.builder();
return std::make_pair(
builder->ReportError(xla::Unimplemented(
return {
key.builder()->ReportError(xla::Unimplemented(
"Types other than U32, S32, U64 and S64 are not implemented by "
"StatefulRngUniformFullInt; got: %s",
xla::primitive_util::LowercasePrimitiveTypeName(type))),
counter);
initial_state};
}
}
@ -177,15 +100,15 @@ xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
0);
}
using sampler_return_type = xla::StatusOr<std::pair<xla::XlaOp, xla::XlaOp>>;
using SamplerReturnType = xla::StatusOr<xla::RngOutput>;
// A helper function containing the common part of several kernels below.
// Precondition: 'algorithm' and 'shape' are compile-time constants.
Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
int alg_input_idx, int shape_input_idx,
std::function<sampler_return_type(xla::XlaOp, xla::XlaOp,
TensorShape)> const&
sample_with_threefry) {
Status CompileImpl(
XlaOpKernelContext* ctx, int state_input_idx, int alg_input_idx,
int shape_input_idx,
std::function<SamplerReturnType(xla::XlaOp, xla::XlaOp, TensorShape)> const&
sampler) {
auto alg_shape = ctx->InputShape(alg_input_idx);
if (alg_shape.dims() != 0) {
return errors::InvalidArgument("algorithm must be of shape [], not ",
@ -215,24 +138,22 @@ Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
TensorShape shape;
TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape));
static constexpr int COUNTER_SIZE = 1;
auto counter = BitcastConvertType(
xla::Reshape(xla::Slice(var, {0}, {COUNTER_SIZE}, {1}), {}), xla::U64);
static constexpr int kStateSize = 1;
auto state = BitcastConvertType(
xla::Reshape(xla::Slice(var, {0}, {kStateSize}, {1}), {}), xla::U64);
auto key = BitcastConvertType(
xla::Reshape(xla::Slice(var, {COUNTER_SIZE}, {COUNTER_SIZE + 1}, {1}),
{}),
xla::Reshape(xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}), {}),
xla::U64);
auto status_or_value = sample_with_threefry(counter, key, shape);
auto status_or_value = sampler(state, key, shape);
if (!status_or_value.ok()) {
return status_or_value.status();
}
auto output_counter = status_or_value.ConsumeValueOrDie();
auto output = output_counter.first;
counter = output_counter.second;
ctx->SetOutput(0, output);
auto builder = ctx->builder();
var = ConcatScalars(builder, {counter, key});
xla::RngOutput value_state = status_or_value.ConsumeValueOrDie();
state = value_state.state;
ctx->SetOutput(0, value_state.value);
xla::XlaBuilder* builder = ctx->builder();
var = ConcatScalars(builder, {state, key});
xla::PrimitiveType state_element_type;
TF_RETURN_IF_ERROR(
DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
@ -252,23 +173,22 @@ class StatefulUniformOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
auto builder = ctx->builder();
auto sample_with_threefry = [builder, this](
xla::XlaOp counter, xla::XlaOp key,
TensorShape shape) -> sampler_return_type {
xla::XlaBuilder* builder = ctx->builder();
auto sampler = [builder, this](xla::XlaOp state, xla::XlaOp key,
TensorShape shape) -> SamplerReturnType {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
auto uniform_counter = StatefulRngUniform(
key, counter, xla_shape, xla::ConstantR0<float>(builder, 0.0),
xla::RngOutput uniform_state = StatefulRngUniform(
key, state, xla_shape, xla::ConstantR0<float>(builder, 0.0),
xla::ConstantR0<float>(builder, 1.0));
auto uniform = uniform_counter.first;
counter = uniform_counter.second;
xla::XlaOp uniform = uniform_state.value;
state = uniform_state.state;
uniform = MaybeConvertF32ToBF16(uniform, dtype_);
return {{uniform, counter}};
return {{uniform, state}};
};
OP_REQUIRES_OK(ctx,
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
/*shape_input_idx=*/2, sample_with_threefry));
/*shape_input_idx=*/2, sampler));
}
private:
@ -293,30 +213,20 @@ class StatefulStandardNormalOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
auto builder = ctx->builder();
auto sample_with_threefry =
auto sampler =
// Needs explicit lambda return type because it fails to be inferred.
[builder, this](xla::XlaOp counter, xla::XlaOp key,
TensorShape shape) -> sampler_return_type {
[this](xla::XlaOp state, xla::XlaOp key,
TensorShape shape) -> SamplerReturnType {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
auto uniform_counter = StatefulRngUniform(
key, counter, xla_shape,
xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
xla::ConstantR0<float>(builder, 1.0));
auto uniform = uniform_counter.first;
counter = uniform_counter.second;
// Convert uniform distribution to normal distribution by computing
// sqrt(2) * erfinv(x)
auto normal =
xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
normal = MaybeConvertF32ToBF16(normal, dtype_);
return {{normal, counter}};
xla::RngOutput value_state = xla::NormalF32Distribution(
key, state, xla::ThreeFryBitGenerator, xla_shape);
xla::XlaOp normal = MaybeConvertF32ToBF16(value_state.value, dtype_);
return {{normal, value_state.state}};
};
OP_REQUIRES_OK(ctx,
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
/*shape_input_idx=*/2, sample_with_threefry));
/*shape_input_idx=*/2, sampler));
}
private:
@ -341,27 +251,27 @@ class StatefulTruncatedNormalOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
auto builder = ctx->builder();
auto sample_with_threefry =
xla::XlaBuilder* builder = ctx->builder();
auto sampler =
// Needs explicit lambda return type because it fails to be inferred.
[builder, this](xla::XlaOp counter, xla::XlaOp key,
TensorShape shape) -> sampler_return_type {
[builder, this](xla::XlaOp state, xla::XlaOp key,
TensorShape shape) -> SamplerReturnType {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
auto uniform_counter = StatefulRngUniform(
key, counter, xla_shape,
xla::RngOutput uniform_result = StatefulRngUniform(
key, state, xla_shape,
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
xla::One(builder, xla_shape.element_type()));
auto uniform = uniform_counter.first;
counter = uniform_counter.second;
xla::XlaOp uniform = uniform_result.value;
state = uniform_result.state;
xla::XlaOp truncated_normal = TruncatedNormal(uniform);
truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
return {{truncated_normal, counter}};
return {{truncated_normal, state}};
};
OP_REQUIRES_OK(ctx,
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
/*shape_input_idx=*/2, sample_with_threefry));
/*shape_input_idx=*/2, sampler));
}
private:
@ -388,11 +298,11 @@ class StatefulUniformIntOp : public XlaOpKernel {
xla::XlaOp minval = ctx->Input(3);
xla::XlaOp maxval = ctx->Input(4);
auto sample_with_threefry = [minval, maxval, this](
xla::XlaOp counter, xla::XlaOp key,
TensorShape shape) -> sampler_return_type {
xla::XlaOp state, xla::XlaOp key,
TensorShape shape) -> SamplerReturnType {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
return StatefulRngUniform(key, counter, xla_shape, minval, maxval);
return StatefulRngUniform(key, state, xla_shape, minval, maxval);
};
OP_REQUIRES_OK(ctx,
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
@ -420,12 +330,11 @@ class StatefulUniformFullIntOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
auto sample_with_threefry = [this](
xla::XlaOp counter, xla::XlaOp key,
TensorShape shape) -> sampler_return_type {
auto sample_with_threefry = [this](xla::XlaOp state, xla::XlaOp key,
TensorShape shape) -> SamplerReturnType {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
return StatefulRngUniformFullInt(key, counter, xla_shape);
return StatefulRngUniformFullInt(key, state, xla_shape);
};
OP_REQUIRES_OK(ctx,
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,

View File

@ -36,8 +36,8 @@ namespace tensorflow {
xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
if (dtype == DT_BFLOAT16) {
xla::XlaBuilder* builder = input.builder();
auto output = xla::BitcastConvertType(input, xla::U32) &
xla::ConstantR0<uint32>(builder, 0xFFFF0000);
xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) &
xla::ConstantR0<uint32>(builder, 0xFFFF0000);
return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32),
xla::BF16);
} else {
@ -45,22 +45,36 @@ xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
}
}
xla::XlaOp Uniform2NormalUsingSqrtErfinv(xla::XlaOp uniform) {
// Convert uniform distribution to normal distribution by computing
// sqrt(2) * erfinv(x)
return xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
}
xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape,
xla::XlaOp minval, xla::XlaOp maxval) {
xla::XlaBuilder* builder = seeds.builder();
// A wrapper of xla::StatelessRngUniform. Returns an op that produces random
// values with uniform distribution in the range [minval, maxval) for the given
// shape and given two 32-bit seeds. Currently only shapes of type F32, S32 and
// S64 are implemented.
xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType unused,
xla::XlaOp seed, xla::XlaOp minval,
xla::XlaOp maxval) {
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
return xla::StatelessRngUniform({seed0, seed1}, shape, minval, maxval);
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
ShiftLeft(ConvertElementType(seed1, xla::U64),
ConstantR0WithType(builder, xla::U64, 32));
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
xla::PrimitiveType type = shape.element_type();
switch (type) {
case xla::F32:
return xla::UniformF32Distribution(key, initial_state,
xla::ThreeFryBitGenerator, minval,
maxval, shape)
.value;
case xla::S32: // fall through
case xla::S64:
return UniformIntDistribution(key, initial_state,
xla::ThreeFryBitGenerator, minval, maxval,
shape)
.value;
break;
default:
return builder->ReportError(xla::Unimplemented(
"Types other than F32, S32 and S64 are not implemented by "
"StatelessRngUniform; got %s",
xla::primitive_util::LowercasePrimitiveTypeName(type)));
}
}
namespace {
@ -86,8 +100,8 @@ class StatelessRandomUniformOp : public XlaOpKernel {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
xla::XlaOp uniform = StatelessRandomUniformImpl(
xla_shape, dtype_, seed, xla::ConstantR0<float>(builder, 0.0),
xla::XlaOp uniform = StatelessRngUniform(
seed, xla_shape, xla::ConstantR0<float>(builder, 0.0),
xla::ConstantR0<float>(builder, 1.0));
uniform = MaybeConvertF32ToBF16(uniform, dtype_);
ctx->SetOutput(0, uniform);
@ -136,8 +150,8 @@ class StatelessRandomUniformIntOp : public XlaOpKernel {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
xla::XlaOp uniform =
StatelessRandomUniformImpl(xla_shape, dtype_, seed, minval, maxval);
xla::XlaOp uniform = StatelessRngUniform(seed, xla_shape, minval, maxval);
ctx->SetOutput(0, uniform);
}
@ -170,14 +184,20 @@ class StatelessRandomNormalOp : public XlaOpKernel {
errors::InvalidArgument("seed must have shape [2], not ",
seed_shape.DebugString()));
xla::XlaOp seed = ctx->Input(1);
xla::XlaBuilder* builder = ctx->builder();
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
xla::XlaOp uniform = StatelessRandomUniformImpl(
xla_shape, dtype_, seed,
xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
xla::ConstantR0<float>(builder, 1.0));
xla::XlaOp normal = Uniform2NormalUsingSqrtErfinv(uniform);
xla::XlaBuilder* builder = seed.builder();
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
ShiftLeft(ConvertElementType(seed1, xla::U64),
ConstantR0WithType(builder, xla::U64, 32));
xla::XlaOp normal =
xla::NormalF32Distribution(key, initial_state,
xla::ThreeFryBitGenerator, xla_shape)
.value;
normal = MaybeConvertF32ToBF16(normal, dtype_);
ctx->SetOutput(0, normal);
}
@ -215,8 +235,8 @@ class StatelessTruncatedNormalOp : public XlaOpKernel {
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
xla::XlaOp uniform = StatelessRandomUniformImpl(
xla_shape, dtype_, seed,
xla::XlaOp uniform = StatelessRngUniform(
seed, xla_shape,
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
xla::One(builder, xla_shape.element_type()));
xla::XlaOp truncated_normal = TruncatedNormal(uniform);
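
These kernels back the `tf.random.stateless_*` APIs, which take a length-2 32-bit seed. The new code folds that pair into a single 64-bit ThreeFry key by placing the second seed in the high 32 bits; for non-negative seeds the arithmetic is simply:

```python
import tensorflow as tf

seed0, seed1 = 7, 42
# Key construction mirrored from the kernel:
# key = u64(seed0) | (u64(seed1) << 32)
key = (seed0 & 0xFFFFFFFF) | ((seed1 & 0xFFFFFFFF) << 32)
print(hex(key))  # 0x2a00000007

# The same [seed0, seed1] pair is what callers pass to the stateless APIs
# (assuming a TF build that exposes tf.random.stateless_uniform).
x = tf.random.stateless_uniform([3], seed=[seed0, seed1])
```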

View File

@ -47,16 +47,19 @@ class TensorListLengthOp : public XlaOpKernel {
explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp index;
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(ctx->Input(0), &index));
ctx->SetOutput(0, index);
TensorShape buffer_shape;
OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &buffer_shape));
Tensor length_tensor(DT_INT32, {});
length_tensor.scalar<int32>()() =
static_cast<int32>(buffer_shape.dim_size(0));
ctx->SetConstantOutput(0, length_tensor);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(TensorListLengthOp);
};
REGISTER_XLA_OP(Name("TensorListLength"), TensorListLengthOp);
REGISTER_XLA_OP(Name("TensorListLength").IsMetadataOp(), TensorListLengthOp);
// Creates an empty list with size (leading_dim, *element_shape) if
// element_shape is known at compile time. Otherwise creates one with size

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape.h"
@ -35,6 +36,17 @@ Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index,
return Status::OK();
}
Status GetTensorListPrimitiveType(const xla::XlaOp& op,
xla::PrimitiveType* type) {
TF_RET_CHECK(op.builder());
TF_ASSIGN_OR_RETURN(const xla::Shape& list_tuple_shape,
op.builder()->GetShape(op));
xla::Shape buffer_shape =
xla::ShapeUtil::GetTupleElementShape(list_tuple_shape, 0);
*type = buffer_shape.element_type();
return Status::OK();
}
Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer) {
TF_RET_CHECK(op.builder());
*buffer = xla::GetTupleElement(op, 0);
@ -97,4 +109,12 @@ Status InitializeTensorList(const xla::XlaOp& uninitialized_list,
return BuildTensorList(new_buffer, push_index, output_list);
}
Status CreateZerosList(XlaOpKernelContext* ctx, const TensorShape& buffer_shape,
xla::PrimitiveType type, xla::XlaOp* list) {
auto zero =
xla::ConstantLiteral(ctx->builder(), xla::LiteralUtil::Zero(type));
*list = xla::Broadcast(zero, buffer_shape.dim_sizes());
return Status::OK();
}
} // namespace tensorflow

View File

@ -35,6 +35,10 @@ bool IsTensorListInput(XlaOpKernelContext* ctx, int index);
Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index,
xla::XlaOp* output_list);
// Returns XLA PrimitiveType for the TensorList.
Status GetTensorListPrimitiveType(const xla::XlaOp& op,
xla::PrimitiveType* type);
// Returns the buffer for the TensorList.
Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer);
@ -62,6 +66,10 @@ Status InitializeTensorList(const xla::XlaOp& uninitialized_list,
const TensorShape& buffer_shape,
xla::XlaOp* output_list);
// Returns a TensorList filled with zero.
Status CreateZerosList(XlaOpKernelContext* ctx, const TensorShape& buffer_shape,
xla::PrimitiveType type, xla::XlaOp* list);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_

View File

@ -529,7 +529,11 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
int resource_index = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
if (ctx->input_type(i) != DT_RESOURCE) {
ctx->SetOutput(i, xla::GetTupleElement(while_result, i));
if (IsTensorListInput(ctx, i)) {
ctx->SetTensorListOutput(i, xla::GetTupleElement(while_result, i));
} else {
ctx->SetOutput(i, xla::GetTupleElement(while_result, i));
}
++resource_index;
} else {
break;

View File

@ -165,7 +165,7 @@ Status RewriteAndPruneGraph(
TF_RETURN_IF_ERROR(
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph);
PruneForReverseReachability(graph, retval_nodes);
PruneForReverseReachability(graph, std::move(retval_nodes));
FixupSourceAndSinkEdges(graph);
VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph);
// Sanity-check, to make sure the feeds and fetches still exist post-pruning.
@ -295,6 +295,8 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
compiler_options.flib_def = &graph->flib_def();
compiler_options.graph_def_version = graph->versions().producer();
compiler_options.allow_cpu_custom_calls = true;
compiler_options.custom_fake_quant_op_calls =
config.conversion_options().custom_fake_quant_op_calls();
XlaCompiler compiler(compiler_options);
XlaCompiler::CompilationResult result;

View File

@ -53,6 +53,14 @@ message Variable {
bool readonly = 5;
}
// Options used during the conversion and compilation process.
message ConversionOptions {
// When true, tf.fake_quant_* ops will be emitted as custom calls to a
// 'fake_quant_with_min_max_vars' function accepting the input, min, max,
// num_bits, and narrow_range values as runtime arguments.
bool custom_fake_quant_op_calls = 1;
}
// Config represents configuration information for tf2xla conversion.
message Config {
// Each feed is a positional input argument for the generated computation.
@ -63,4 +71,6 @@ message Config {
repeated Fetch fetch = 2;
// Each variable is a named input and output of the generated computation.
repeated Variable variable = 3;
// Optional conversion options.
ConversionOptions conversion_options = 4;
}
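For orientation, a small sketch of toggling the new option from C++ through the generated proto API (the include path and the `tensorflow.tf2xla` package are assumptions; only the message and field names come from the proto above):

```c++
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"  // assumed generated header

// Builds a Config that asks tfcompile to emit fake-quant ops as custom calls.
tensorflow::tf2xla::Config MakeConfig() {
  tensorflow::tf2xla::Config config;
  // ... feeds, fetches, and variables would be populated here as usual ...
  config.mutable_conversion_options()->set_custom_fake_quant_op_calls(true);
  return config;
}
```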

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
@ -200,8 +201,13 @@ Status BuildComputation(
output.shape = output.constant_value.shape();
break;
case XlaExpression::Kind::kTensorList:
TF_FALLTHROUGH_INTENDED;
case XlaExpression::Kind::kTensorList: {
output.is_tensor_list = true;
xla::XlaOp value = retval.handle();
elems.push_back(value);
break;
}
case XlaExpression::Kind::kXlaOp: {
output.is_constant = false;
TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
@ -403,6 +409,8 @@ string XlaCompiler::Argument::HumanString() const {
}
case kParameter:
return absl::StrCat("kind=parameter", common);
case kTensorList:
return absl::StrCat("kind=tensorlist", common);
case kToken:
return absl::StrCat("token", common);
}
@ -641,6 +649,11 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
}
return Status::OK();
}
case XlaCompiler::Argument::kTensorList: {
TF_RET_CHECK(absl::holds_alternative<xla::Shape>(arg.shape));
*xla_shape = absl::get<xla::Shape>(arg.shape);
return Status::OK();
}
case XlaCompiler::Argument::kResource: {
TF_RET_CHECK(arg.initialized);
@ -744,6 +757,7 @@ Status XlaCompiler::BuildArguments(
break;
}
case XlaCompiler::Argument::kParameter:
case XlaCompiler::Argument::kTensorList:
case XlaCompiler::Argument::kToken: {
input_to_args->push_back(i);
break;
@ -902,6 +916,10 @@ Status XlaCompiler::BuildArguments(
arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
}
break;
case XlaCompiler::Argument::kTensorList: {
arg_expression = XlaExpression::TensorList(arg_handles[i]);
break;
}
case XlaCompiler::Argument::kToken: {
arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
break;

View File

@ -116,6 +116,9 @@ class XlaCompiler {
// Argument is an XLA token.
kToken,
// Argument is a TensorList.
kTensorList,
};
Kind kind = kInvalid;
@ -226,6 +229,9 @@ class XlaCompiler {
// When this output is a resource, i.e. `type == DT_RESOURCE`, this is
// the index of the input that contains the resource.
int input_index;
// Whether this output is a TensorList.
bool is_tensor_list = false;
};
// Describes a variable write side effect of the computation.
@ -305,6 +311,12 @@ class XlaCompiler {
// for CPU.
bool allow_cpu_custom_calls = false;
// If both this and 'allow_cpu_custom_calls' are true, then tf.fake_quant_*
// ops will be emitted as custom calls to a 'fake_quant_with_min_max_vars'
// function accepting the input, min, max, num_bits, and narrow_range values
// as runtime arguments.
bool custom_fake_quant_op_calls = false;
// If set, variables are represented to XLA with the shape given by this
// shape function. Variables are reshaped to this shape on write, and reshaped
// to their original shape on read.

View File

@ -14,10 +14,15 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
@ -35,7 +40,9 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
@ -1498,5 +1505,76 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) {
}
}
TEST_F(XlaCompilerTest, OpsWithTensorListInput) {
FunctionDefLibrary fdef_lib;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
// Build cond fn for While.
{
Scope scope = Scope::NewRootScope().ExitOnError();
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
auto result = ops::Const<bool>(scope, {true}, {});
ops::_Retval(scope.WithOpName("ret"), result, 0);
TF_ASSERT_OK(scope.ToGraph(graph.get()));
FunctionDef fdef;
TF_ASSERT_OK(GraphToFunctionDef(*graph, "cond", &fdef));
TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
}
// Build body fn for While.
{
Scope scope = Scope::NewRootScope().ExitOnError();
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
auto arg = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
ops::_Retval(scope.WithOpName("ret"), arg, 0);
TF_ASSERT_OK(scope.ToGraph(graph.get()));
FunctionDef fdef;
TF_ASSERT_OK(GraphToFunctionDef(*graph, "body", &fdef));
TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
}
Scope scope = Scope::NewRootScope().ExitOnError();
auto element_shape = ops::Const<int32>(scope, {1}, {1});
auto max_elements = ops::Const<int32>(scope, {10}, {});
auto arg = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
std::initializer_list<Output> out = {arg, arg};
auto add_n = ops::AddN(scope, out);
NameAttrList cond_fn, body_fn;
cond_fn.set_name("cond");
body_fn.set_name("body");
auto while_op =
ops::While(scope, std::initializer_list<Input>{arg}, cond_fn, body_fn);
auto ret0 = ops::_Retval(scope.WithOpName("ret0"), add_n, 0);
auto ret1 = ops::_Retval(scope.WithOpName("ret1"), while_op.output[0], 1);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kTensorList;
xla::Shape tensor_list_element_shape;
TF_ASSERT_OK(TensorShapeToXLAShape(DT_INT32, TensorShape{1},
&tensor_list_element_shape));
xla::Shape index_shape;
TF_ASSERT_OK(TensorShapeToXLAShape(DT_INT32, TensorShape{}, &index_shape));
std::vector<xla::Shape> shapes{tensor_list_element_shape, index_shape};
xla::Shape arg_shape = xla::ShapeUtil::MakeTupleShape(shapes);
args[0].shape = arg_shape;
// Compiles the graph.
XlaCompiler::Options options = DefaultOptions();
options.flib_def = &flib_def;
XlaCompiler compiler(options);
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args,
/*user_aliases=*/{}, &result));
ASSERT_EQ(result.outputs.size(), 2);
const XlaCompiler::OutputDescription& output0 = result.outputs[0];
ASSERT_TRUE(output0.is_tensor_list);
const XlaCompiler::OutputDescription& output1 = result.outputs[1];
ASSERT_TRUE(output1.is_tensor_list);
}
} // namespace
} // namespace tensorflow

View File

@ -153,7 +153,6 @@ XlaOpRegistry::~XlaOpRegistry() = default;
cpu_global_jit
? XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally
: XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested;
registration.compile_all_resource_ops = false;
}
if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) {
DeviceRegistration& registration =
@ -161,7 +160,6 @@ XlaOpRegistry::~XlaOpRegistry() = default;
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.autoclustering_policy =
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
registration.compile_all_resource_ops = false;
}
return nullptr;
}();

View File

@ -88,8 +88,33 @@ class XlaOpRegistry {
// When should we autocluster operators assigned to this device?
AutoclusteringPolicy autoclustering_policy;
// Enable compilation of operators that use DT_RESOURCE types?
bool compile_all_resource_ops = false;
// If we should ignore the resource variable memory model when clustering
// resource variable reads and writes placed on this device.
bool cluster_resource_variable_ops_unsafely = false;
// If we should auto-cluster Stack operations placed on this device.
bool cluster_stack_ops = false;
// If we should auto-cluster TensorArray operations placed on this device.
bool cluster_tensor_array_ops = false;
// If we should auto-cluster stateful RNG operations placed on this device.
// Stateful RNG semantics are not properly supported by XLA, so it is not
// necessarily correct to auto-cluster stateful RNG ops in general.
bool cluster_stateful_rng_ops = false;
// If we should auto-cluster ControlTrigger operations placed on this
// device. ControlTrigger operations are not necessarily safe to cluster
// since they affect deadness (a dead ControlTrigger produces a live
// output).
bool cluster_control_trigger = false;
// If we should cluster Assert and CheckNumerics by eliding them (XLA does
// not natively support Assert or CheckNumerics).
bool elide_assert_and_checknumerics = false;
// If we should cluster operations returning DT_VARIANT.
bool cluster_variant_ops = false;
};
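A hedged sketch of how a backend might fill in the new per-category flags (the device and JIT names are placeholders, and the use of `RegisterCompilationDevice` mirrors how existing XLA devices register; treat the whole block as illustrative):

```c++
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

// Registers a hypothetical "MYDEV" device for compilation. The single
// compile_all_resource_ops flag is gone; clustering is opted into per category.
void RegisterMyDeviceForXlaCompilation() {
  tensorflow::XlaOpRegistry::DeviceRegistration registration;
  registration.compilation_device_name = "XLA_MYDEV_JIT";
  registration.autoclustering_policy =
      tensorflow::XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested;
  registration.cluster_resource_variable_ops_unsafely = false;
  registration.cluster_stack_ops = true;
  registration.cluster_tensor_array_ops = true;
  registration.cluster_stateful_rng_ops = false;
  registration.cluster_control_trigger = true;
  registration.elide_assert_and_checknumerics = true;
  registration.cluster_variant_ops = true;
  tensorflow::XlaOpRegistry::RegisterCompilationDevice("MYDEV", registration);
}
```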
// Registers an XLA backend. `compilation_device_name` is the name of the

View File

@ -287,6 +287,7 @@ tf_cc_test(
":xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:test_main",
"@com_google_absl//absl/hash:hash_testing",
"@com_google_absl//absl/strings",
],
)

View File

@ -125,8 +125,60 @@ XlaOp Any(XlaOp predicates) {
namespace {
XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder,
PrimitiveType value_type,
PrimitiveType index_type, bool is_min) {
auto sub_builder = outer_builder->CreateSubBuilder("minmax_func");
XlaBuilder* b = sub_builder.get();
XlaOp lhs_value =
Parameter(b, 0, ShapeUtil::MakeShape(value_type, {}), "lhs_value");
XlaOp lhs_index =
Parameter(b, 1, ShapeUtil::MakeShape(index_type, {}), "lhs_index");
XlaOp rhs_value =
Parameter(b, 2, ShapeUtil::MakeShape(value_type, {}), "rhs_value");
XlaOp rhs_index =
Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index");
auto cmp = is_min ? Lt(lhs_value, rhs_value) : Gt(lhs_value, rhs_value);
XlaOp max = Select(cmp, lhs_value, rhs_value);
XlaOp arg_max = Select(cmp, lhs_index, rhs_index);
Tuple(b, {max, arg_max});
return b->Build().ConsumeValueOrDie();
}
XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
XlaBuilder* builder = input.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
XlaOp value_init_value;
if (is_min) {
value_init_value = MaxValue(builder, input_shape.element_type());
} else {
value_init_value = MinValue(builder, input_shape.element_type());
}
int64 dimension_size = input_shape.dimensions(axis);
auto index_type = dimension_size <= INT32_MAX ? S32 : output_type;
XlaOp index_init_value = Zero(builder, index_type);
auto iota_shape = input_shape;
iota_shape.set_element_type(index_type);
XlaOp iota = Iota(builder, iota_shape, axis);
XlaComputation reducer = CreateMinMaxComputation(
builder, input_shape.element_type(), index_type, is_min);
XlaOp max_argmax = Reduce(builder, {input, iota},
{value_init_value, index_init_value}, reducer,
/*dimensions_to_reduce=*/{axis});
XlaOp argmax = GetTupleElement(max_argmax, 1);
if (index_type != output_type) {
argmax = ConvertElementType(argmax, output_type);
}
return argmax;
});
}
XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
bool is_min) {
XlaBuilder* builder = input.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
XlaOp init_value;
@ -172,7 +224,6 @@ XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
/*dimensions_to_reduce=*/{axis});
});
}
} // namespace
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) {
@ -183,4 +234,11 @@ XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) {
return ArgMinMax(input, output_type, axis, /*is_min=*/true);
}
XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis) {
return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/false);
}
XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis) {
return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/true);
}
} // namespace xla

View File

@ -60,10 +60,12 @@ XlaOp Any(XlaOp predicates);
// Returns the argmax of `input` along `axis`. `output_type` is the type to
// use for the output.
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis);
XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis);
// Returns the argmin of `input` along `axis`. `output_type` is the type to
// use for the output.
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis);
XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis);
} // namespace xla
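The single-pass variant above reduces (value, iota) pairs with one reducer that tracks both the running extremum and its index, while the *TwoPass entry points keep the older reduce-then-compare formulation. A minimal usage sketch of the declared functions (builder and operand setup are assumed; the axis is illustrative):

```c++
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Index of the maximum along axis 1, using the new single-pass reduction.
xla::XlaOp ArgMaxAlongAxis1(xla::XlaOp input) {
  return xla::ArgMax(input, xla::S32, /*axis=*/1);
}

// Same result via the two-pass formulation that remains available
// (the reason both are kept is an assumption, not stated in the change).
xla::XlaOp ArgMaxAlongAxis1TwoPass(xla::XlaOp input) {
  return xla::ArgMaxTwoPass(input, xla::S32, /*axis=*/1);
}
```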

View File

@ -32,8 +32,12 @@ XlaOp RotateLeftU32(XlaOp v, int distance) {
ShiftRightLogical(v, ConstantR0<uint32>(v.builder(), 32 - distance));
}
} // namespace
// The internal state of the ThreeFry implementation.
using ThreeFry2x32State = std::array<XlaOp, 2>;
// Implements the ThreeFry counter-based PRNG algorithm.
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
XlaBuilder* builder = input[0].builder();
key[0] = BitcastConvertType(key[0], U32);
@ -104,56 +108,68 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
return x;
}
// Returns the inputs with unique counter values for ThreeFry2x32.
ThreeFry2x32State GetInputs(const int64 size, XlaBuilder* builder) {
ThreeFry2x32State inputs;
inputs[0] = Iota(builder, U32, size);
inputs[1] = inputs[0] + ConstantR0<uint32>(builder, size);
return inputs;
}
XlaOp StatelessRngUniformU32(std::array<XlaOp, 2> key, const Shape& shape) {
XlaBuilder* builder = key[0].builder();
const int64 size = ShapeUtil::ElementsIn(shape);
const int64 half_size = CeilOfRatio<int64>(size, 2);
const bool size_is_odd = (half_size * 2 != size);
ThreeFry2x32State inputs = GetInputs(half_size, builder);
ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
if (size_is_odd) {
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
}
auto result = ConcatInDim(builder, outputs, 0);
return Reshape(result, AsInt64Slice(shape.dimensions()));
}
// Converts a uint64 to two uint32s.
ThreeFry2x32State Uint64ToUint32s(XlaOp u64) {
auto builder = u64.builder();
auto const32 = ConstantR0WithType(builder, U64, 32);
auto fst = ConvertElementType(u64, U32);
auto snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
XlaBuilder* builder = u64.builder();
XlaOp const32 = ConstantR0WithType(builder, U64, 32);
XlaOp fst = ConvertElementType(u64, U32);
XlaOp snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
return {fst, snd};
}
// Converts two uint32s to a uint64.
XlaOp Uint32sToUint64(ThreeFry2x32State u32s) {
auto builder = u32s[0].builder();
XlaBuilder* builder = u32s[0].builder();
return ConvertElementType(u32s[0], U64) |
ShiftLeft(ConvertElementType(u32s[1], U64),
ConstantR0WithType(builder, U64, 32));
}
XlaOp StatelessRngUniformU64(std::array<XlaOp, 2> key, const Shape& shape) {
XlaBuilder* builder = key[0].builder();
const int64 size = ShapeUtil::ElementsIn(shape);
ThreeFry2x32State inputs = GetInputs(size, builder);
ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
// low 32 bit: outputs[0], high 32 bit: outputs[1]
auto result = Uint32sToUint64(outputs);
return Reshape(result, AsInt64Slice(shape.dimensions()));
// Given the initial state and the requested number of random numbers to be
// generated, returns the input for the random number generator and a new state.
std::pair<ThreeFry2x32State, XlaOp> GetThreeFryInputsAndUpdatedState(
XlaOp initial_state, const int64 size) {
XlaBuilder* builder = initial_state.builder();
XlaOp input_u64 = Iota(builder, U64, size);
input_u64 = input_u64 + initial_state;
XlaOp new_state = initial_state + ConstantR0<uint64>(builder, size);
return std::make_pair(Uint64ToUint32s(input_u64), new_state);
}
XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
XlaBuilder* builder = bits.builder();
// Generates random 32-bit values with the given shape using the ThreeFry
// implementation. Returns the random bits and the new state.
RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) {
XlaBuilder* builder = key.builder();
const int64 size = ShapeUtil::ElementsIn(shape);
const int64 half_size = CeilOfRatio<int64>(size, 2);
const bool size_is_odd = (half_size * 2 != size);
std::pair<ThreeFry2x32State, XlaOp> inputs_state =
GetThreeFryInputsAndUpdatedState(initial_state, half_size);
ThreeFry2x32State inputs = inputs_state.first;
ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
if (size_is_odd) {
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
}
XlaOp result = ConcatInDim(builder, outputs, 0);
return {Reshape(result, AsInt64Slice(shape.dimensions())),
inputs_state.second};
}
// Generates random 64-bit values with the given shape using the ThreeFry
// implementation. Returns the random bits and the new state.
RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) {
const int64 size = ShapeUtil::ElementsIn(shape);
std::pair<ThreeFry2x32State, XlaOp> inputs_state =
GetThreeFryInputsAndUpdatedState(initial_state, size);
ThreeFry2x32State inputs = inputs_state.first;
ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
XlaOp result = Uint32sToUint64(outputs);
return {Reshape(result, AsInt64Slice(shape.dimensions())),
inputs_state.second};
}
XlaOp ConvertRandomBitsToUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
XlaBuilder* builder = bits.builder();
// Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit
// forces the random bits into the mantissa.
constexpr int kFloatBits = 32;
@ -161,50 +177,119 @@ XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
bits = ShiftRightLogical(
bits, ConstantR0<uint32>(builder, kFloatBits - kMantissaBits)) |
ConstantR0<uint32>(builder, absl::bit_cast<uint32>(1.0f));
auto floats = BitcastConvertType(bits, F32);
XlaOp values = BitcastConvertType(bits, F32);
// We have a floating point number in the range [1.0, 2.0).
// Subtract 1.0f to shift to the range [0.0, 1.0)
floats = floats - ConstantR0<float>(builder, 1.0f);
values = values - ConstantR0<float>(builder, 1.0f);
// Multiply and add to shift to the range [minval, maxval).
return floats * (maxval - minval) + minval;
return values * (maxval - minval) + minval;
}
XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
PrimitiveType type, PrimitiveType unsigned_type) {
XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
PrimitiveType type,
PrimitiveType unsigned_type) {
XlaBuilder* builder = bits.builder();
auto range = BitcastConvertType(maxval, unsigned_type) -
BitcastConvertType(minval, unsigned_type);
auto dist = Rem(bits, range);
auto dist_div_2 =
XlaOp range = BitcastConvertType(maxval, unsigned_type) -
BitcastConvertType(minval, unsigned_type);
XlaOp dist = Rem(bits, range);
XlaOp dist_div_2 =
ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1));
return minval + BitcastConvertType(dist_div_2, type) +
BitcastConvertType(dist - dist_div_2, type);
}
XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
XlaOp minval, XlaOp maxval) {
XlaBuilder* builder = seeds[0].builder();
// Implements the Box-Muller transform, which converts random floats in the
// range [0, 1] from a uniform distribution to a normal distribution with mean
// 0 and variance 1. For more detail on the Box-Muller transform, see
// http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
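// Concretely, with u1 = max(x0, 1e-7), v1 = 2*pi*x1 and u2 = sqrt(-2*log(u1)),
// the returned pair is (sin(v1) * u2, cos(v1) * u2).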
std::pair<XlaOp, XlaOp> BoxMullerTransform(XlaOp x0, XlaOp x1) {
// Do not send a really small number to log().
XlaOp u1 = Max(x0, ScalarLike(x0, 1.0e-7f));
XlaOp v1 = ScalarLike(x1, 2.0f * M_PI) * x1;
XlaOp u2 = Sqrt(ScalarLike(u1, -2.0f) * Log(u1));
return {Sin(v1) * u2, Cos(v1) * u2};
}
} // namespace
RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
const Shape& shape) {
PrimitiveType type = shape.element_type();
switch (type) {
case F32: {
auto bits = StatelessRngUniformU32(seeds, shape);
return StatelessRngUniformF32(bits, minval, maxval);
}
case S32: {
auto bits = StatelessRngUniformU32(seeds, shape);
return StatelessRngUniformInt(bits, minval, maxval, type, U32);
}
case S64: {
auto bits = StatelessRngUniformU64(seeds, shape);
return StatelessRngUniformInt(bits, minval, maxval, type, U64);
}
case F32:
case U32:
case S32:
return ThreeFryRngBit32(key, initial_state, shape);
case U64:
case S64:
return ThreeFryRngBit64(key, initial_state, shape);
default:
return builder->ReportError(Unimplemented(
"Types other than F32, S32 and S64 are not implemented by "
"StatelessRngUniform."));
return {key.builder()->ReportError(Unimplemented(
"Types other than F32, U32, S32, U64 and S64 "
"are not implemented by ThreeFryBitGenerator; got %s",
primitive_util::LowercasePrimitiveTypeName(type))),
initial_state};
}
}
RngOutput UniformF32Distribution(XlaOp key, XlaOp initial_state,
BitGeneratorTy bit_generator, XlaOp minval,
XlaOp maxval, const Shape& shape) {
DCHECK_EQ(shape.element_type(), F32);
RngOutput bits_state = bit_generator(key, initial_state, shape);
XlaOp bits = bits_state.value;
XlaOp new_state = bits_state.state;
return {ConvertRandomBitsToUniformF32(bits, minval, maxval), new_state};
}
RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
BitGeneratorTy bit_generator, XlaOp minval,
XlaOp maxval, const Shape& shape) {
RngOutput bits_state = bit_generator(key, initial_state, shape);
XlaOp bits = bits_state.value;
XlaOp new_state = bits_state.state;
PrimitiveType type = shape.element_type();
PrimitiveType unsigned_type;
if (type == U32 || type == S32) {
unsigned_type = U32;
} else {
DCHECK(type == U64 || type == S64);
unsigned_type = U64;
}
return {
ConvertRandomBitsToUniformInt(bits, minval, maxval, type, unsigned_type),
new_state};
}
RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state,
BitGeneratorTy bit_generator,
const Shape& shape) {
DCHECK_EQ(shape.element_type(), F32);
XlaBuilder* builder = key.builder();
const int64 num_elems = ShapeUtil::ElementsIn(shape);
const int64 num_pairs = CeilOfRatio<int64>(num_elems, 2);
RngOutput bits_state = UniformF32Distribution(
key, initial_state, bit_generator, ConstantR0<float>(builder, 0.0),
ConstantR0<float>(builder, 1.0),
ShapeUtil::MakeShape(F32, {num_pairs * 2}));
// Separate the bits into two groups to perform the Box-Muller transform.
XlaOp bits_0 = Slice(bits_state.value, {0}, {num_pairs}, {1});
XlaOp bits_1 = Slice(bits_state.value, {num_pairs}, {2 * num_pairs}, {1});
std::tie(bits_0, bits_1) = BoxMullerTransform(bits_0, bits_1);
// Put the numbers from the two groups back together to form the requested shape.
XlaOp normal = ConcatInDim(builder, {bits_0, bits_1}, /*dimension=*/0);
if (num_elems != num_pairs * 2) {
normal = Slice(normal, /*start_indices=*/{0}, /*limit_indices=*/{num_elems},
/*strides=*/{1});
}
normal = Reshape(normal, shape.dimensions());
return {normal, bits_state.state};
}
} // namespace xla

View File

@ -23,37 +23,52 @@ limitations under the License.
namespace xla {
// Records the bits and state generated by a random number generator.
struct RngOutput {
XlaOp value;
XlaOp state;
};
// A BitGenerator returns random bits and updated random bit generator state.
//
// key: a value input to a random number generator that can affect the
// sequence of numbers it will generate. A random number generator constructs
// its seed using the key and the initial state. The tf2xla bridge passes the
// seed operand of a tensorflow random operation as a key to the random bit
// generator, for example.
// initial_state: the initial state of the current random number generation.
// It can be 0 for a stateless random operation, or the state returned from a
// previous execution for a stateful random operation.
// shape: the shape of the random bits.
using BitGeneratorTy = std::function<RngOutput(XlaOp key, XlaOp initial_state,
const xla::Shape& shape)>;
// Implements the ThreeFry counter-based PRNG algorithm.
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
using ThreeFry2x32State = std::array<XlaOp, 2>;
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key);
RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
const xla::Shape& shape);
// Returns a tensor containing 'shape' random values uniformly distributed in
// the range [minval, maxval). Requires 2 32-bit integer seeds.
// Currently only 'shape's of type F32, S32 and S64 are implemented.
XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
XlaOp minval, XlaOp maxval);
// Uses the given bit generator to generate random bits and then converts the
// random bits to random numbers uniformly distributed in the given range.
// Returns the random numbers and the state of the random number generator.
// This function is for shapes with a float element type.
RngOutput UniformF32Distribution(XlaOp key, XlaOp initial_state,
BitGeneratorTy bit_generator, XlaOp minval,
XlaOp maxval, const xla::Shape& shape);
// Converts a 32-bit (signed or unsigned) integer random number `bits` into a
// float32 in the range [minval, maxval).
XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval);
// Similar to UniformF32Distribution but for shapes with integer element types.
RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
BitGeneratorTy bit_generator, XlaOp minval,
XlaOp maxval, const xla::Shape& shape);
// Converts an integer random number 'bits' of type 'type' to a random number
// in the range [minval, maxval), of the same type. 'unsigned_type' is the
// unsigned version of 'type' (could be the same) with the same bit width.
// The algorithm is the same one that TF uses right now, but it's
// uniform only when maxval - minval is a divisor of the range that bits is
// generated from.
// TODO(b/72573764): Generate real uniform integer distribution.
XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
PrimitiveType type, PrimitiveType unsigned_type);
// The following 2 functions, for converting between one uint64 and two uint32s,
// use the contract "lower 32 bits for the first uint32, higher 32 bits for the
// second".
ThreeFry2x32State Uint64ToUint32s(XlaOp u64);
XlaOp Uint32sToUint64(ThreeFry2x32State u32s);
// Uses the given bit generator to generate random bits and then converts the
// random bits to random numbers following a normal distribution.
// Returns the random numbers and the state of the random number generator.
RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state,
BitGeneratorTy bit_generator,
const xla::Shape& shape);
} // namespace xla
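A minimal sketch of composing the new helpers, assuming an `XlaBuilder` is already set up and using illustrative key/state/shape values; a stateful caller would feed `RngOutput::state` back into the next call:

```c++
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Emits a [2, 3] f32 tensor of uniform [0, 1) samples driven by ThreeFry.
xla::XlaOp UniformSample(xla::XlaBuilder* b) {
  xla::XlaOp key = xla::ConstantR0<xla::uint64>(b, 42);           // illustrative
  xla::XlaOp initial_state = xla::ConstantR0<xla::uint64>(b, 0);  // stateless use
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
  xla::RngOutput out = xla::UniformF32Distribution(
      key, initial_state, xla::ThreeFryBitGenerator,
      /*minval=*/xla::ConstantR0<float>(b, 0.0f),
      /*maxval=*/xla::ConstantR0<float>(b, 1.0f), shape);
  return out.value;  // out.state holds the updated generator state.
}
```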

View File

@ -292,6 +292,11 @@ static void AllocateFlags() {
flag_values->xla_gpu_crash_on_verification_failures(),
"Crashes the program on extra verification failures, e.g. cuDNN "
"cross checking failures"),
tensorflow::Flag(
"xla_gpu_disable_autotune",
bool_setter_for(&DebugOptions::set_xla_gpu_disable_autotune),
flag_values->xla_gpu_disable_autotune(),
"Disable GEMM and Convolution auto-tuning."),
tensorflow::Flag(
"xla_force_host_platform_device_count",
int32_setter_for(

View File

@ -120,9 +120,14 @@ class Sharding(object):
tile_assignment_dimensions=tile_assignment_dims,
tile_assignment_devices=range(num_devices)))
def apply_to_tensor(self, tensor):
"""Applies this Sharding attribute to `tensor`."""
if len(tensor.op.outputs) > 1:
def apply_to_tensor(self, tensor, assign_tuple_sharding=False):
"""Applies this Sharding attribute to `tensor`.
Args:
tensor: A tf.Tensor to split.
assign_tuple_sharding: If the sharding type should be a tuple.
"""
if len(tensor.op.outputs) > 1 or assign_tuple_sharding:
proto = self._get_or_create_tuple_proto(tensor.op)
# We can't mutate an element of old_proto.tuple_shardings, so create
# a new proto.
@ -166,21 +171,30 @@ class Sharding(object):
# tensor = xla_sharding.replicate(tensor)
def replicate(tensor):
Sharding.replicate().apply_to_tensor(tensor)
def replicate(tensor, assign_tuple_sharding=False):
Sharding.replicate().apply_to_tensor(
tensor,
assign_tuple_sharding=assign_tuple_sharding)
return tensor
def assign_device(tensor, device):
Sharding.assign_device(device).apply_to_tensor(tensor)
def assign_device(tensor, device, assign_tuple_sharding=False):
Sharding.assign_device(device).apply_to_tensor(
tensor,
assign_tuple_sharding=assign_tuple_sharding)
return tensor
def tile(tensor, tile_assignment):
Sharding.tile(tile_assignment).apply_to_tensor(tensor)
def tile(tensor, tile_assignment, assign_tuple_sharding=False):
Sharding.tile(tile_assignment).apply_to_tensor(
tensor,
assign_tuple_sharding=assign_tuple_sharding
)
return tensor
def split(tensor, split_dimension, num_devices):
Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(tensor)
def split(tensor, split_dimension, num_devices, assign_tuple_sharding=False):
Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(
tensor,
assign_tuple_sharding=assign_tuple_sharding)
return tensor

View File

@ -303,6 +303,37 @@ For example, if `operand` is a scalar `f32` with value `2.0f`, and
`broadcast_sizes` is `{2, 3}`, then the result will be an array with shape
`f32[2, 3]` and all the values in the result will be `2.0f`.
## BroadcastInDim
See also
[`XlaBuilder::BroadcastInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Expands the size and rank of an array by duplicating the data in the array.
<b> `BroadcastInDim(operand, out_dim_size, broadcast_dimensions)` </b>
| Arguments | Type | Semantics |
| ---------------------- | ------------------- | ----------------------------- |
| `operand` | `XlaOp` | The array to duplicate |
| `out_dim_size` | `ArraySlice<int64>` | The sizes of the dimensions |
: : : of the target shape :
| `broadcast_dimensions` | `ArraySlice<int64>` | Which dimension in the target |
: : : shape each dimension of the :
: : : operand shape corresponds to :
Similar to Broadcast, but allows adding dimensions anywhere and expanding
existing dimensions with size 1.
The `operand` is broadcast to the shape described by `out_dim_size`.
`broadcast_dimensions` maps the dimensions of `operand` to the dimensions of the
target shape, i.e. the i'th dimension of the operand is mapped to the
broadcast_dimension\[i\]'th dimension of the output shape. The dimensions of
`operand` must have size 1 or be the same size as the dimension in the output
shape they are mapped to. The remaining dimensions are filled with dimensions of
size 1. Degenerate-dimension broadcasting then broadcasts along these degenerate
dimensions to reach the output shape. The semantics are described in detail on
the [broadcasting page](broadcasting.md).
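For example (a sketch using the C++ client API; builder and `operand` setup are assumed), broadcasting an `f32[3]` operand to `f32[2, 3]`:

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"

// operand: f32[3]. Map its single dimension to output dimension 1; output
// dimension 0 is then filled by degenerate-dimension broadcasting.
xla::XlaOp BroadcastRowToMatrix(xla::XlaOp operand) {
  return xla::BroadcastInDim(operand,
                             /*out_dim_size=*/{2, 3},
                             /*broadcast_dimensions=*/{1});
}
```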
## Call
See also
@ -1258,6 +1289,59 @@ Arguments | Type | Semantics
The function is applied to each element in the `operand` array, resulting in an
array with the same shape. It is allowed for `operand` to be a scalar (rank 0).
## Fft
The XLA FFT operation implements the forward and inverse Fourier Transforms for
real and complex inputs/outputs. Multidimensional FFTs on up to 3 axes are
supported, except on TPU, where only a single axis is supported (please file a
GitHub issue if you require higher order).
See also
[`XlaBuilder::Fft`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
| Arguments | Type | Semantics |
| ------------ | ------------------- | ------------------------ |
| `operand` | `XlaOp` | The array we are Fourier |
: : : transforming. :
| `fft_type` | `FftType` | See the table below. |
| `fft_length` | `ArraySlice<int64>` | The time-domain lengths |
: : : of the axes being :
: : : transformed. This is :
: : : needed in particular for :
: : : IRFFT to right-size the :
: : : innermost axis, since :
: : : `RFFT(fft_length=[16])` :
: : : has the same output :
: : : shape as :
: : : `RFFT(fft_length=[17])`. :
| `FftType` | Semantics |
| --------- | ---------------------------------------------------------------- |
| `FFT` | Forward complex-to-complex FFT. Shape is unchanged. |
| `IFFT` | Inverse complex-to-complex FFT. Shape is unchanged. |
| `RFFT` | Forward real-to-complex FFT. Shape of the innermost axis is |
: : reduced to `fft_length[-1] // 2 + 1` if `fft_length[-1]` is a :
: : non-zero value, omitting the reversed conjugate part of the :
: : transformed signal beyond the Nyquist frequency. :
| `IRFFT` | Inverse real-to-complex FFT (i.e. takes complex, returns real). |
: : Shape of the innermost axis is expanded to `fft_length[-1]` if :
: : `fft_length[-1]` is a non-zero value, inferring the part of the :
: : transformed signal beyond the Nyquist frequency from the reverse :
: : conjugate of the `1` to `fft_length[-1] // 2 + 1` entries. :
#### Multidimensional FFT
When more than 1 `fft_length` is provided, this is equivalent to applying a
cascade of FFT operations to each of the innermost axes. Note that for the
real->complex and complex->real cases, the innermost axis transform is
(effectively) performed first (RFFT; last for IRFFT), which is why the innermost
axis is the one which changes size. Other axis transforms will then be
complex->complex.
#### Implementation details
CPU FFT is backed by Eigen's TensorFFT. GPU FFT uses cuFFT.
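As a short sketch (builder and `operand` setup assumed), a forward real-to-complex transform over the innermost axis of an `f32[8, 16]` operand:

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"

// operand: f32[8, 16]. RFFT over the last axis yields c64[8, 9] (16 / 2 + 1).
xla::XlaOp RfftLastAxis(xla::XlaOp operand) {
  return xla::Fft(operand, xla::FftType::RFFT, /*fft_length=*/{16});
}
```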
## Gather
The XLA gather operation stitches together several slices (each slice at a

View File

@ -69,6 +69,11 @@ class Tile {
// combined with the next minor dimension before tiling is applied.
static constexpr int64 kCombineDimension = std::numeric_limits<int64>::min();
template <typename H>
friend H AbslHashValue(H h, const Tile& t) {
return H::combine(std::move(h), t.dimensions_);
}
private:
// The bounds of the tile.
std::vector<int64> dimensions_;
@ -212,6 +217,13 @@ class Layout {
element_size_in_bits_ = 0;
}
template <typename H>
friend H AbslHashValue(H h, const Layout& l) {
return H::combine(std::move(h), l.format_, l.minor_to_major_,
l.max_sparse_elements_, l.tiles_,
l.element_size_in_bits_);
}
private:
// The format of this layout.
Format format_ = INVALID_FORMAT;
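The BUILD change earlier in this diff adds `@com_google_absl//absl/hash:hash_testing`; a sketch of how the new `AbslHashValue` overloads might be exercised in a test (the exact test target and the sample layouts are assumptions):

```c++
#include "absl/hash/hash_testing.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/layout_util.h"

// Checks that hashing is consistent with equality for a few Layout values.
bool LayoutHashIsConsistentWithEquality() {
  return absl::VerifyTypeImplementsAbslHashCorrectly({
      xla::Layout(),
      xla::LayoutUtil::MakeLayout({0, 1}),
      xla::LayoutUtil::MakeLayout({1, 0}),
  });
}
```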

View File

@ -109,6 +109,9 @@ tf_pybind_extension(
":xrt",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@pybind11",
@ -132,11 +135,13 @@ tf_pybind_extension(
"//tensorflow/compiler/xla/client/lib:svd",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:traceme",
# Do NOT remove this dependency. The XLA Python extension must not
# depend on any part of TensorFlow at runtime, **including**
# libtensorflow_framework.so. The XLA module is deployed self-contained

View File

@ -20,6 +20,11 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "include/pybind11/pybind11.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
@ -33,6 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
namespace xla {
namespace xla_python {
@ -55,7 +61,8 @@ Status RegisterCpuCustomCallTarget(const std::string& fn_name,
return Status::OK();
}
StatusOr<LocalClient*> GetLocalClient(const std::string& platform_name) {
StatusOr<std::unique_ptr<PyLocalClient>> PyLocalClient::Get(
const std::string& platform_name) {
TF_ASSIGN_OR_RETURN(se::Platform * platform,
PlatformUtil::GetPlatform(platform_name));
if (platform->VisibleDeviceCount() <= 0) {
@ -64,47 +71,158 @@ StatusOr<LocalClient*> GetLocalClient(const std::string& platform_name) {
}
LocalClientOptions options;
options.set_platform(platform);
return ClientLibrary::GetOrCreateLocalClient(options);
TF_ASSIGN_OR_RETURN(LocalClient * client,
ClientLibrary::GetOrCreateLocalClient(options));
return absl::make_unique<PyLocalClient>(client);
}
/* static */
StatusOr<LocalShapedBuffer> LocalShapedBuffer::FromPython(
const py::object& argument, LocalClient* client, int device_ordinal) {
VLOG(1) << "Creating shaped buffer from literal on device ordinal: "
<< device_ordinal;
TF_ASSIGN_OR_RETURN(PythonBufferTree tree, GetPythonBufferTree(argument));
PyLocalClient::PyLocalClient(LocalClient* client)
: client_(client),
h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
client->device_count()),
execute_pool_(tensorflow::Env::Default(), "py_xla_execute",
client->device_count()) {}
// We are done manipulating Python objects; release the GIL.
Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal,
int device_ordinal) {
py::gil_scoped_release gil_release;
DeviceMemoryAllocator* allocator = client->backend().memory_allocator();
TransferManager* transfer_manager = client->backend().transfer_manager();
return client_->TransferToInfeedLocal(literal, device_ordinal);
}
StatusOr<pybind11::object> PyLocalClient::TransferFromOutfeed(
const Shape& shape, int device_ordinal) {
Literal literal;
{
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(
literal, client_->TransferFromOutfeedLocal(shape, device_ordinal));
}
return LiteralToPython(absl::make_unique<Literal>(std::move(literal)));
}
static StatusOr<LocalShapedBuffer> TransferHostToDeviceAsync(
const PythonBufferTree& tree, int device_ordinal, PyLocalClient* client,
se::Stream* stream) {
DeviceMemoryAllocator* allocator =
client->client()->backend().memory_allocator();
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(
Shape shape, transfer_manager->ChooseCompactLayoutForShape(tree.shape));
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
transfer_manager->AllocateScopedShapedBuffer(
tree.shape, allocator, device_ordinal));
TF_ASSIGN_OR_RETURN(auto stream,
client->mutable_backend()->BorrowStream(device_ordinal));
shape, allocator, device_ordinal));
TF_RETURN_IF_ERROR(
transfer_manager->WriteTupleIndexTables(stream.get(), buffer));
transfer_manager->WriteTupleIndexTablesAsync(stream, buffer));
auto it = tree.leaves.begin();
for (const ShapeUtil::IndexedShape& indexed_shape :
ShapeUtil::GetLeafShapes(tree.shape)) {
ShapeUtil::GetLeafShapes(shape)) {
TF_RET_CHECK(it != tree.leaves.end());
ShapedBuffer leaf(
indexed_shape.shape,
transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
client->platform(), device_ordinal);
client->client()->platform(), device_ordinal);
leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {});
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDeviceAsync(
stream.get(), *it, leaf));
TF_RETURN_IF_ERROR(
transfer_manager->TransferLiteralToDeviceAsync(stream, *it, leaf));
++it;
}
stream->BlockHostUntilDone();
return LocalShapedBuffer(std::move(buffer), client);
}
/* static */
StatusOr<LocalShapedBuffer> LocalShapedBuffer::FromPython(
const py::object& argument, PyLocalClient* client, int device_ordinal) {
tensorflow::profiler::TraceMe traceme("LocalShapedBuffer::FromPython");
TF_ASSIGN_OR_RETURN(PythonBufferTree tree, GetPythonBufferTree(argument));
// We are done manipulating Python objects; release the GIL.
py::gil_scoped_release gil_release;
VLOG(1) << "LocalShapedBuffer::FromPython: shape: " << tree.shape.ToString()
<< " device ordinal: " << device_ordinal;
TF_ASSIGN_OR_RETURN(
StreamPool::Ptr stream,
client->client()->mutable_backend()->BorrowStream(device_ordinal));
TF_ASSIGN_OR_RETURN(
LocalShapedBuffer buffer,
TransferHostToDeviceAsync(tree, device_ordinal, client, stream.get()));
stream->BlockHostUntilDone();
return buffer;
}
/*static */ StatusOr<std::vector<LocalShapedBuffer>>
LocalShapedBuffer::FromPythonValues(
const std::vector<std::pair<py::object, int>>& arguments,
PyLocalClient* client) {
tensorflow::profiler::TraceMe traceme("LocalShapedBuffer::FromPythonValues");
int num_arguments = static_cast<int>(arguments.size());
std::vector<LocalShapedBuffer> outputs(num_arguments);
if (num_arguments == 0) {
return outputs;
}
struct H2DTransfer {
PythonBufferTree tree;
StreamPool::Ptr stream;
StatusOr<LocalShapedBuffer> buffer;
};
std::vector<H2DTransfer> transfers(num_arguments);
for (int i = 0; i < num_arguments; ++i) {
TF_ASSIGN_OR_RETURN(transfers[i].tree,
GetPythonBufferTree(arguments[i].first));
}
// We are done manipulating Python objects; release the GIL.
py::gil_scoped_release gil_release;
for (int i = 0; i < num_arguments; ++i) {
int device_ordinal = arguments[i].second;
TF_ASSIGN_OR_RETURN(
transfers[i].stream,
client->client()->mutable_backend()->BorrowStream(device_ordinal));
}
auto transfer_h2d = [&](int i) -> StatusOr<LocalShapedBuffer> {
int device_ordinal = arguments[i].second;
return TransferHostToDeviceAsync(transfers[i].tree, device_ordinal, client,
transfers[i].stream.get());
};
// We perform the transfers on a thread pool in case XLA needs to do any
// host-side preprocessing of the input data.
if (num_arguments == 1) {
transfers[0].buffer = transfer_h2d(0);
} else {
absl::BlockingCounter counter(num_arguments - 1);
for (int i = 1; i < num_arguments; ++i) {
client->h2d_transfer_pool()->Schedule([&, i]() {
transfers[i].buffer = transfer_h2d(i);
counter.DecrementCount();
});
}
// Perform the first transfer on the main thread.
transfers[0].buffer = transfer_h2d(0);
counter.Wait();
}
// First, wait for all transfers to complete. We wait for all to complete
// since currently we maintain the invariant that the device's view of the
// state matches the host's view of the state. Returning early would mean that
// we might deallocate device-side memory before a transfer completes, which
// violates that invariant.
for (int i = 0; i < num_arguments; ++i) {
transfers[i].stream->BlockHostUntilDone();
}
for (int i = 0; i < num_arguments; ++i) {
TF_ASSIGN_OR_RETURN(outputs[i], std::move(transfers[i].buffer));
}
return outputs;
}
LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer,
LocalClient* client)
PyLocalClient* client)
: shaped_buffer_(std::move(shaped_buffer)), client_(client) {}
const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const {
@ -122,16 +240,18 @@ const Shape& LocalShapedBuffer::shape() const {
}
StatusOr<py::object> LocalShapedBuffer::ToPython() const {
tensorflow::profiler::TraceMe traceme("LocalShapedBuffer::ToPython");
auto literal = absl::make_unique<Literal>();
{
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(*literal,
client_->ShapedBufferToLiteral(*shaped_buffer()));
TF_ASSIGN_OR_RETURN(
*literal, client_->client()->ShapedBufferToLiteral(*shaped_buffer()));
}
return LiteralToPython(std::move(literal));
}
StatusOr<std::vector<LocalShapedBuffer>> LocalShapedBuffer::DestructureTuple() {
tensorflow::profiler::TraceMe traceme("LocalShapedBuffer::DestructureTuple");
const Shape tuple_shape = shape();
if (!tuple_shape.IsTuple()) {
@ -173,14 +293,14 @@ StatusOr<std::vector<LocalShapedBuffer>> LocalShapedBuffer::DestructureTuple() {
return results;
}
LocalExecutableWrapper::LocalExecutableWrapper(
PyLocalExecutable::PyLocalExecutable(
std::unique_ptr<LocalExecutable> executable,
DeviceAssignment device_assignment, LocalClient* client)
DeviceAssignment device_assignment, PyLocalClient* client)
: executable_(std::move(executable)),
device_assignment_(std::move(device_assignment)),
client_(client) {}
std::vector<int> LocalExecutableWrapper::DeviceOrdinals() const {
std::vector<int> PyLocalExecutable::DeviceOrdinals() const {
int num_replicas = device_assignment_.replica_count();
std::vector<int> device_ordinals;
device_ordinals.reserve(num_replicas);
@ -190,8 +310,9 @@ std::vector<int> LocalExecutableWrapper::DeviceOrdinals() const {
return device_ordinals;
}
StatusOr<LocalShapedBuffer> LocalExecutableWrapper::Execute(
StatusOr<LocalShapedBuffer> PyLocalExecutable::Execute(
absl::Span<LocalShapedBuffer* const> argument_handles) {
tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
if (num_replicas() != 1) {
return InvalidArgument(
"Attempted to execute computation with %d replicas using Execute()",
@ -210,26 +331,23 @@ StatusOr<LocalShapedBuffer> LocalExecutableWrapper::Execute(
ExecutableRunOptions options;
options.set_device_ordinal(device_ordinal);
options.set_allocator(client_->backend().memory_allocator());
options.set_allocator(client_->client()->backend().memory_allocator());
options.set_intra_op_thread_pool(
client_->backend().eigen_intra_op_thread_pool_device());
client_->client()->backend().eigen_intra_op_thread_pool_device());
options.set_device_assignment(&device_assignment_);
result_buffer_status = executable_->Run(argument_buffers, options);
if (!result_buffer_status.ok()) {
return InternalError(
"Failed running replica 0 (other replicas may have failed as well): "
"%s.",
result_buffer_status.status().ToString());
return result_buffer_status.status();
}
return LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie(),
client_);
}
StatusOr<std::vector<LocalShapedBuffer>>
LocalExecutableWrapper::ExecutePerReplica(
StatusOr<std::vector<LocalShapedBuffer>> PyLocalExecutable::ExecutePerReplica(
absl::Span<const std::vector<LocalShapedBuffer*>> argument_handles) {
tensorflow::profiler::TraceMe traceme("LocalExecutable::ExecutePerReplica");
const int num_devices = client_->device_count();
if (argument_handles.size() != num_replicas()) {
@ -245,8 +363,8 @@ LocalExecutableWrapper::ExecutePerReplica(
VLOG(1) << "Executing with " << num_replicas() << " replicas.";
std::vector<StatusOr<ScopedShapedBuffer>> results(num_replicas());
auto execute = [this, &argument_handles, &results](int replica) {
auto execute =
[this, &argument_handles](int replica) -> StatusOr<ScopedShapedBuffer> {
const int device_ordinal = device_assignment_(replica, 0);
VLOG(3) << "Replica " << replica
<< " mapped to device ordinal for execution: " << device_ordinal;
@ -259,39 +377,83 @@ LocalExecutableWrapper::ExecutePerReplica(
ExecutableRunOptions options;
options.set_device_ordinal(device_ordinal);
options.set_allocator(client_->backend().memory_allocator());
options.set_allocator(client_->client()->backend().memory_allocator());
options.set_intra_op_thread_pool(
client_->backend().eigen_intra_op_thread_pool_device());
client_->client()->backend().eigen_intra_op_thread_pool_device());
options.set_device_assignment(&device_assignment_);
StatusOr<ScopedShapedBuffer> result_buffer_status =
executable_->Run(argument_buffers, options);
results[replica] = std::move(result_buffer_status);
VLOG(1) << "Replica " << replica
<< " completed; ok=" << result_buffer_status.ok();
if (!result_buffer_status.ok()) {
LOG(ERROR) << "Execution of replica " << replica
<< " failed: " << result_buffer_status.status();
}
return result_buffer_status;
};
VLOG(1) << "Executing replicated computation; num_replicas="
<< num_replicas();
std::vector<StatusOr<ScopedShapedBuffer>> results(num_replicas());
if (num_replicas() == 1) {
// Fast-path if there is only one replica — run the computation on the
// current thread.
execute(0);
results[0] = execute(0);
} else {
// TODO(phawkins): don't recreate the threadpool for each execution.
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun",
num_replicas() - 1);
absl::Mutex mu;
int running GUARDED_BY(mu) = num_replicas();
int failed GUARDED_BY(mu) = 0;
for (int replica = 0; replica < num_replicas() - 1; ++replica) {
pool.Schedule([&execute, replica] { execute(replica); });
for (int replica = 0; replica < num_replicas(); ++replica) {
client_->execute_pool()->Schedule(
[&execute, &mu, &running, &failed, &results, replica] {
results[replica] = execute(replica);
absl::MutexLock lock(&mu);
--running;
if (!results[replica].ok()) {
++failed;
}
});
}
auto done_running_or_failed = [&]() {
mu.AssertHeld();
return running == 0 || failed > 0;
};
absl::MutexLock lock(&mu);
mu.Await(absl::Condition(&done_running_or_failed));
if (failed > 0) {
auto done_running = [&]() {
mu.AssertHeld();
return running == 0;
};
// If execution does not terminate within a reasonable amount of time, we
// may be stuck at a cross-replica barrier on-device. Terminate the
// process since that's the only way we can escape this situation at the
// moment (b/130629719).
if (!mu.AwaitWithTimeout(absl::Condition(&done_running),
absl::Seconds(10))) {
LOG(FATAL)
<< "Replicated computation launch failed, but not all replicas "
"terminated. Aborting process to work around deadlock. See the "
"error log for details of the failure.";
}
}
execute(num_replicas() - 1);
}
VLOG(1) << "Replicated execution complete.";
std::vector<LocalShapedBuffer> wrapped_results(num_replicas());
for (int replica = 0; replica < num_replicas(); ++replica) {
auto& statusor = results[replica];
if (!statusor.ok()) {
return InternalError(
"Failed running replica %d (other replicas may have failed as well): "
"%s.",
replica, statusor.status().ToString());
return AppendStatus(
statusor.status(),
absl::StrFormat(
"while running replica %d of a replicated computation (other "
"replicas may have failed as well).",
replica));
}
wrapped_results[replica] =
LocalShapedBuffer(std::move(statusor).ValueOrDie(), client_);
@ -334,29 +496,62 @@ StatusOr<std::string> GetComputationHloDotGraph(
RenderedGraphFormat::kDot);
}
/*static*/ StatusOr<std::unique_ptr<LocalExecutableWrapper>>
LocalExecutableWrapper::Compile(const XlaComputation& computation,
const std::vector<Shape>& argument_shapes,
const ExecutableBuildOptions* build_options,
LocalClient* client) {
std::vector<const Shape*> argument_shape_pointers;
argument_shape_pointers.reserve(argument_shapes.size());
for (auto& argument_shape : argument_shapes) {
argument_shape_pointers.push_back(&argument_shape);
/*static*/ StatusOr<std::unique_ptr<PyLocalExecutable>>
PyLocalExecutable::Compile(const XlaComputation& computation,
std::vector<Shape> argument_layouts,
const ExecutableBuildOptions* build_options,
PyLocalClient* client) {
tensorflow::profiler::TraceMe traceme("LocalExecutable::Compile");
std::vector<const Shape*> argument_layout_pointers;
argument_layout_pointers.reserve(argument_layouts.size());
// Assign a default layout to any array subshapes that are missing layouts.
auto assign_layouts = [client](Shape* shape) {
return ShapeUtil::ForEachMutableSubshapeWithStatus(
shape, [&](Shape* subshape, const ShapeIndex&) {
if (subshape->IsArray() && !subshape->has_layout()) {
LayoutUtil::SetToDefaultLayout(subshape);
TF_ASSIGN_OR_RETURN(*subshape,
client->client()
->backend()
.transfer_manager()
->ChooseCompactLayoutForShape(*subshape));
}
return Status::OK();
});
};
for (Shape& layout : argument_layouts) {
argument_layout_pointers.push_back(&layout);
assign_layouts(&layout);
}
ExecutableBuildOptions options;
if (build_options != nullptr) {
options = *build_options;
}
TF_ASSIGN_OR_RETURN(
auto local_executable,
client->Compile(computation, argument_shape_pointers, options));
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
client->backend().computation_placer()->AssignDevices(
options.num_replicas(), /*computation_count=*/1));
return absl::make_unique<LocalExecutableWrapper>(
Shape result_layout;
if (options.result_layout()) {
result_layout = *options.result_layout();
} else {
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
computation.GetProgramShape());
result_layout = program_shape.result();
LayoutUtil::ClearLayout(&result_layout);
}
assign_layouts(&result_layout);
options.set_result_layout(result_layout);
TF_ASSIGN_OR_RETURN(std::unique_ptr<LocalExecutable> local_executable,
client->client()->Compile(
computation, argument_layout_pointers, options));
TF_ASSIGN_OR_RETURN(
DeviceAssignment device_assignment,
client->client()->backend().computation_placer()->AssignDevices(
options.num_replicas(), /*computation_count=*/1));
return absl::make_unique<PyLocalExecutable>(
std::move(local_executable), std::move(device_assignment), client);
}

Some files were not shown because too many files have changed in this diff.