Merge branch 'master' into unify_doccitations

Commit: d502acb5ae

Changed files (from the commit file browser):

- ISSUE_TEMPLATE.md, LICENSE, README.md, WORKSPACE, configure.py
- tensorflow/
  - api_template.__init__.py, api_template_v1.__init__.py
  - c/
    - BUILD, c_api_experimental.cc, c_api_experimental_test.cc, c_api_function.cc, c_api_function_test.cc, checkpoint_reader.cc, checkpoint_reader.h
  - compat_template_v1.__init__.py, eager
  - compiler/
    - aot
    - jit/
      - BUILD, deadness_analysis.cc, deadness_analysis.h, deadness_analysis_test.cc, extract_outside_compilation_pass.cc, mark_for_compilation_pass.cc, mark_for_compilation_pass.h, mark_for_compilation_pass_test.cc, mark_for_compilation_pass_test_helper.cc, mark_for_compilation_pass_test_helper.h, resource_operation_safety_analysis.cc, xla_cluster_util.cc, xla_cluster_util.h, xla_cpu_device.cc, xla_device.cc, xla_device_ops.h, xla_fusion_optimizer.cc, xla_fusion_optimizer.h, xla_fusion_optimizer_test.cc, xla_gpu_device.cc, xla_interpreter_device.cc, xla_launch_util.cc
    - tests/
      - BUILD, binary_ops_test.py, eager_test.py, jit_test.py, pooling_ops_test.py, tensor_array_ops_test.py, tensor_list_ops_test.py
    - tf2tensorrt
    - tf2xla/
      - BUILD, graph_compiler.cc
      - kernels/
        - aggregate_ops.cc, batch_matmul_op.cc, categorical_op.cc, conv_op_helpers.cc, conv_op_helpers.h, fake_quantize_ops.cc, index_ops.cc, index_ops.h, random_ops_util.h, reshape_op.cc, shape_op.cc, stateful_random_ops.cc, stateless_random_ops.cc, tensor_list_ops.cc, tensor_list_utils.cc, tensor_list_utils.h, while_op.cc
      - tf2xla.cc, tf2xla.proto, xla_compiler.cc, xla_compiler.h, xla_compiler_test.cc, xla_op_registry.cc, xla_op_registry.h, xla
ISSUE_TEMPLATE.md

@@ -32,7 +32,7 @@ https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh

 You can obtain the TensorFlow version with:

 ```bash
-python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"
+python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"
 ```

 ### Describe the problem
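The issue template's version command switches from the deprecated top-level `tf.GIT_VERSION` / `tf.VERSION` attributes to the `tf.version` module. A minimal sketch of the same check, extended with the Python version the template also asks for (the `platform` usage is illustrative, not part of the template):

```python
# Minimal sketch of the environment details the issue template asks for.
# Assumes TensorFlow 1.13 or newer, where the `tf.version` module exists;
# the extra Python-version line is illustrative, not part of the template.
import platform
import tensorflow as tf

print("TensorFlow version:", tf.version.VERSION)
print("TensorFlow git version:", tf.version.GIT_VERSION)
print("Python version:", platform.python_version())
```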
LICENSE (2 changes)

@@ -1,4 +1,4 @@
-Copyright 2018 The TensorFlow Authors. All rights reserved.
+Copyright 2019 The TensorFlow Authors. All rights reserved.

 Apache License
 Version 2.0, January 2004
README.md (10 changes)

@@ -25,7 +25,7 @@ networks research. The system is general enough to be applicable in a wide
 variety of other domains, as well.

 TensorFlow provides stable Python and C APIs as well as non-guaranteed backwards
-compatible API's for C++, Go, Java, JavaScript and Swift.
+compatible API's for C++, Go, Java, JavaScript, and Swift.

 Keep up to date with release announcements and security updates by
 subscribing to
@@ -50,10 +50,10 @@ instructions, and how to build from source.*

 People who are a little more adventurous can also try our nightly binaries:

-**Nightly pip packages**
-* We are pleased to announce that TensorFlow now offers nightly pip packages
-under the [tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
-[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) project on pypi.
+**Nightly pip packages** * We are pleased to announce that TensorFlow now offers
+nightly pip packages under the
+[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
+[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) project on PyPi.
 Simply run `pip install tf-nightly` or `pip install tf-nightly-gpu` in a clean
 environment to install the nightly TensorFlow build. We support CPU and GPU
 packages on Linux, Mac, and Windows.
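After the README's `pip install tf-nightly` (or `pip install tf-nightly-gpu`) step, a quick sanity check that the nightly wheel is the one actually imported might look like this; the substring test on the version string is a heuristic assumption, not an official API:

```python
# Minimal sketch: verify that a freshly installed tf-nightly build is imported.
# Assumes `pip install tf-nightly` (or tf-nightly-gpu) ran in a clean
# environment; the "dev" substring check is a heuristic, not an official API.
import tensorflow as tf

version = tf.version.VERSION
print("Imported TensorFlow", version)
if "dev" in version:
    print("This looks like a nightly (dev) build.")
else:
    print("This looks like a stable release; the nightly wheel may be shadowed.")
```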
WORKSPACE (66 changes)

@@ -4,11 +4,11 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file"
 http_archive(
     name = "io_bazel_rules_closure",
-    sha256 = "ddce3b3a3909f99b28b25071c40b7fec7e2e1d1d1a4b2e933f3082aa99517105",
-    strip_prefix = "rules_closure-316e6133888bfc39fb860a4f1a31cfcbae485aef",
+    sha256 = "e0a111000aeed2051f29fcc7a3f83be3ad8c6c93c186e64beb1ad313f0c7f9f9",
+    strip_prefix = "rules_closure-cf1e44edb908e9616030cc83d085989b8e6cd6df",
     urls = [
-        "http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/316e6133888bfc39fb860a4f1a31cfcbae485aef.tar.gz",
-        "https://github.com/bazelbuild/rules_closure/archive/316e6133888bfc39fb860a4f1a31cfcbae485aef.tar.gz",  # 2019-03-21
+        "http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz",
+        "https://github.com/bazelbuild/rules_closure/archive/cf1e44edb908e9616030cc83d085989b8e6cd6df.tar.gz",  # 2019-04-04
     ],
 )

@@ -43,47 +43,37 @@ remote_config_workspace()
 # Apple and Swift rules.
 http_archive(
     name = "build_bazel_rules_apple",
-    sha256 = "4b90786009fa8df25230442244bad2832ba8d6bc4987f68150a7de59c8827e90",
-    strip_prefix = "rules_apple-0.14.0",
-    urls = ["https://github.com/bazelbuild/rules_apple/archive/0.14.0.tar.gz"],
-)
-http_file(
-    name = "xctestrunner",
-    executable = 1,
-    urls = ["https://github.com/google/xctestrunner/releases/download/0.2.6/ios_test_runner.par"],
-)
-
-http_archive(
-    name = "bazel_skylib",
-    sha256 = "2c62d8cd4ab1e65c08647eb4afe38f51591f43f7f0885e7769832fa137633dcb",
-    strip_prefix = "bazel-skylib-0.7.0",
-    urls = ["https://github.com/bazelbuild/bazel-skylib/archive/0.7.0.tar.gz"],
-)
-
+    sha256 = "8f32e2839fba28d549e1670dbed83606dd339a9f7489118e481814d61738270f",
+    urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.14.0/rules_apple.0.14.0.tar.gz"],
+)  # https://github.com/bazelbuild/rules_apple/releases
 http_archive(
     name = "build_bazel_apple_support",
-    sha256 = "835663c4bb02f4bf01dce8a2a176df7fa682dbb867d3698ae12258c1628bb8f0",
-    strip_prefix = "apple_support-0.5.0",
-    urls = ["https://github.com/bazelbuild/apple_support/archive/0.5.0.tar.gz"],
-)
-
+    sha256 = "7356dbd44dea71570a929d1d4731e870622151a5f27164d966dda97305f33471",
+    urls = ["https://github.com/bazelbuild/apple_support/releases/download/0.6.0/apple_support.0.6.0.tar.gz"],
+)  # https://github.com/bazelbuild/apple_support/releases
+http_archive(
+    name = "bazel_skylib",
+    sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e",
+    urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.8.0/bazel-skylib.0.8.0.tar.gz"],
+)  # https://github.com/bazelbuild/bazel-skylib/releases
 http_archive(
     name = "build_bazel_rules_swift",
-    sha256 = "32d124878cd49775d84f59ba90440c8b23b7c775aec8fec1978f751c76ddee8a",
-    strip_prefix = "rules_swift-0.7.0",
-    urls = ["https://github.com/bazelbuild/rules_swift/archive/0.7.0.tar.gz"],
-)
-
+    sha256 = "31aad005a9c4e56b256125844ad05eb27c88303502d74138186f9083479f93a6",
+    urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.8.0/rules_swift.0.8.0.tar.gz"],
+)  # https://github.com/bazelbuild/rules_swift/releases
 http_archive(
     name = "com_github_apple_swift_swift_protobuf",
     type = "zip",
-    strip_prefix = "swift-protobuf-1.2.0/",
-    urls = ["https://github.com/apple/swift-protobuf/archive/1.2.0.zip"],
-)
-
-# Use swift_rules_dependencies to fetch the tolchains.
-# Since we defined all the "git_repository" rules above, the following call will
-# skip redefining them.
+    strip_prefix = "swift-protobuf-1.4.0/",
+    urls = ["https://github.com/apple/swift-protobuf/archive/1.4.0.zip"],
+)  # https://github.com/apple/swift-protobuf/releases
+http_file(
+    name = "xctestrunner",
+    executable = 1,
+    urls = ["https://github.com/google/xctestrunner/releases/download/0.2.7/ios_test_runner.par"],
+)  # https://github.com/google/xctestrunner/releases
+# Use `swift_rules_dependencies` to fetch the toolchains. With the
+# `git_repository` rules above, the following call will skip redefining them.
 load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")
 swift_rules_dependencies()
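Each dependency bump above pairs a new release URL with a new `sha256`. A minimal sketch for recomputing such a checksum before updating an `http_archive()` entry (plain Python with `hashlib` and `urllib`; this workflow is an assumption, not a Bazel feature, and the bazel-skylib URL is taken from the hunk above):

```python
# Minimal sketch: fetch a release archive and print its SHA-256 digest, the
# value that Bazel's http_archive() expects in its sha256 attribute.
import hashlib
import urllib.request

URL = ("https://github.com/bazelbuild/bazel-skylib/releases/download/"
       "0.8.0/bazel-skylib.0.8.0.tar.gz")

with urllib.request.urlopen(URL) as resp:
    digest = hashlib.sha256(resp.read()).hexdigest()

print(digest)  # Should match the sha256 recorded in WORKSPACE.
```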
configure.py (604 changes)

@@ -33,13 +33,11 @@ except ImportError:
|
||||
from distutils.spawn import find_executable as which
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
_DEFAULT_CUDA_VERSION = '10.0'
|
||||
_DEFAULT_CUDA_VERSION = '10'
|
||||
_DEFAULT_CUDNN_VERSION = '7'
|
||||
_DEFAULT_TENSORRT_VERSION = '5'
|
||||
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0'
|
||||
_DEFAULT_CUDA_PATH = '/usr/local/cuda'
|
||||
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
|
||||
_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
|
||||
'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
|
||||
|
||||
_TF_OPENCL_VERSION = '1.2'
|
||||
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
|
||||
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
|
||||
@ -58,6 +56,7 @@ NCCL_LIB_PATHS = [
|
||||
|
||||
# List of files to configure when building Bazel on Apple platforms.
|
||||
APPLE_BAZEL_FILES = [
|
||||
'tensorflow/lite/experimental/ios/BUILD',
|
||||
'tensorflow/lite/experimental/objc/BUILD',
|
||||
'tensorflow/lite/experimental/swift/BUILD'
|
||||
]
|
||||
@ -68,11 +67,6 @@ IOS_FILES = [
|
||||
'tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec',
|
||||
]
|
||||
|
||||
if platform.machine() == 'ppc64le':
|
||||
_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/'
|
||||
else:
|
||||
_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine()
|
||||
|
||||
|
||||
class UserInputError(Exception):
|
||||
pass
|
||||
@ -206,9 +200,10 @@ def setup_python(environ_cp):
|
||||
ask_python_bin_path = ('Please specify the location of python. [Default is '
|
||||
'%s]: ') % default_python_bin_path
|
||||
while True:
|
||||
python_bin_path = get_from_env_or_user_or_default(
|
||||
environ_cp, 'PYTHON_BIN_PATH', ask_python_bin_path,
|
||||
default_python_bin_path)
|
||||
python_bin_path = get_from_env_or_user_or_default(environ_cp,
|
||||
'PYTHON_BIN_PATH',
|
||||
ask_python_bin_path,
|
||||
default_python_bin_path)
|
||||
# Check if the path is valid
|
||||
if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK):
|
||||
break
|
||||
@ -392,14 +387,14 @@ def set_build_var(environ_cp,
|
||||
var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default)))
|
||||
environ_cp[var_name] = var
|
||||
if var == '1':
|
||||
write_to_bazelrc(
|
||||
'build:%s --define %s=true' % (bazel_config_name, option_name))
|
||||
write_to_bazelrc('build:%s --define %s=true' %
|
||||
(bazel_config_name, option_name))
|
||||
write_to_bazelrc('build --config=%s' % bazel_config_name)
|
||||
elif bazel_config_name is not None:
|
||||
# TODO(mikecase): Migrate all users of configure.py to use --config Bazel
|
||||
# options and not to set build configs through environment variables.
|
||||
write_to_bazelrc(
|
||||
'build:%s --define %s=true' % (bazel_config_name, option_name))
|
||||
write_to_bazelrc('build:%s --define %s=true' %
|
||||
(bazel_config_name, option_name))
|
||||
|
||||
|
||||
def set_action_env_var(environ_cp,
|
||||
@ -446,6 +441,9 @@ def convert_version_to_int(version):
|
||||
"""
|
||||
version = version.split('-')[0]
|
||||
version_segments = version.split('.')
|
||||
# Treat "0.24" as "0.24.0"
|
||||
if len(version_segments) == 2:
|
||||
version_segments.append('0')
|
||||
for seg in version_segments:
|
||||
if not seg.isdigit():
|
||||
return None
|
||||
@ -665,9 +663,9 @@ def prompt_loop_or_load_from_env(environ_cp,
|
||||
print(error_msg % val)
|
||||
environ_cp[var_name] = ''
|
||||
else:
|
||||
raise UserInputError(
|
||||
'Invalid %s setting was provided %d times in a row. '
|
||||
'Assuming to be a scripting mistake.' % (var_name, n_ask_attempts))
|
||||
raise UserInputError('Invalid %s setting was provided %d times in a row. '
|
||||
'Assuming to be a scripting mistake.' %
|
||||
(var_name, n_ask_attempts))
|
||||
|
||||
environ_cp[var_name] = val
|
||||
return val
|
||||
@ -676,8 +674,8 @@ def prompt_loop_or_load_from_env(environ_cp,
|
||||
def create_android_ndk_rule(environ_cp):
|
||||
"""Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule."""
|
||||
if is_windows() or is_cygwin():
|
||||
default_ndk_path = cygpath(
|
||||
'%s/Android/Sdk/ndk-bundle' % environ_cp['APPDATA'])
|
||||
default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' %
|
||||
environ_cp['APPDATA'])
|
||||
elif is_macos():
|
||||
default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME']
|
||||
else:
|
||||
@ -696,8 +694,9 @@ def create_android_ndk_rule(environ_cp):
|
||||
error_msg=('The path %s or its child file "source.properties" '
|
||||
'does not exist.'))
|
||||
write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path)
|
||||
write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL',
|
||||
check_ndk_level(android_ndk_home_path))
|
||||
write_action_env_to_bazelrc(
|
||||
'ANDROID_NDK_API_LEVEL',
|
||||
get_ndk_api_level(environ_cp, android_ndk_home_path))
|
||||
|
||||
|
||||
def create_android_sdk_rule(environ_cp):
|
||||
@ -764,8 +763,10 @@ def create_android_sdk_rule(environ_cp):
|
||||
write_action_env_to_bazelrc('ANDROID_SDK_HOME', android_sdk_home_path)
|
||||
|
||||
|
||||
def check_ndk_level(android_ndk_home_path):
|
||||
"""Check the revision number of an Android NDK path."""
|
||||
def get_ndk_api_level(environ_cp, android_ndk_home_path):
|
||||
"""Gets the appropriate NDK API level to use for the provided Android NDK path."""
|
||||
|
||||
# First check to see if we're using a blessed version of the NDK.
|
||||
properties_path = '%s/source.properties' % android_ndk_home_path
|
||||
if is_windows() or is_cygwin():
|
||||
properties_path = cygpath(properties_path)
|
||||
@ -774,17 +775,40 @@ def check_ndk_level(android_ndk_home_path):
|
||||
|
||||
revision = re.search(r'Pkg.Revision = (\d+)', filedata)
|
||||
if revision:
|
||||
ndk_api_level = revision.group(1)
|
||||
ndk_version = revision.group(1)
|
||||
else:
|
||||
raise Exception('Unable to parse NDK revision.')
|
||||
if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
|
||||
print(
|
||||
'WARNING: The API level of the NDK in %s is %s, which is not '
|
||||
'supported by Bazel (officially supported versions: %s). Please use '
|
||||
'another version. Compiling Android targets may result in confusing '
|
||||
'errors.\n' %
|
||||
(android_ndk_home_path, ndk_api_level, _SUPPORTED_ANDROID_NDK_VERSIONS))
|
||||
return ndk_api_level
|
||||
if int(ndk_version) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
|
||||
print('WARNING: The NDK version in %s is %s, which is not '
|
||||
'supported by Bazel (officially supported versions: %s). Please use '
|
||||
'another version. Compiling Android targets may result in confusing '
|
||||
'errors.\n' % (android_ndk_home_path, ndk_version,
|
||||
_SUPPORTED_ANDROID_NDK_VERSIONS))
|
||||
|
||||
# Now grab the NDK API level to use. Note that this is different from the
|
||||
# SDK API level, as the NDK API level is effectively the *min* target SDK
|
||||
# version.
|
||||
platforms = os.path.join(android_ndk_home_path, 'platforms')
|
||||
api_levels = sorted(os.listdir(platforms))
|
||||
api_levels = [
|
||||
x.replace('android-', '') for x in api_levels if 'android-' in x
|
||||
]
|
||||
|
||||
def valid_api_level(api_level):
|
||||
return os.path.exists(
|
||||
os.path.join(android_ndk_home_path, 'platforms',
|
||||
'android-' + api_level))
|
||||
|
||||
android_ndk_api_level = prompt_loop_or_load_from_env(
|
||||
environ_cp,
|
||||
var_name='ANDROID_NDK_API_LEVEL',
|
||||
var_default='18', # 18 is required for GPU acceleration.
|
||||
ask_for_var=('Please specify the (min) Android NDK API level to use. '
|
||||
'[Available levels: %s]') % api_levels,
|
||||
check_success=valid_api_level,
|
||||
error_msg='Android-%s is not present in the NDK path.')
|
||||
|
||||
return android_ndk_api_level
|
||||
|
||||
|
||||
def set_gcc_host_compiler_path(environ_cp):
|
||||
@ -831,149 +855,39 @@ def reformat_version_sequence(version_str, sequence_count):
|
||||
return '.'.join(v[:sequence_count])
|
||||
|
||||
|
||||
def set_tf_cuda_paths(environ_cp):
|
||||
"""Set TF_CUDA_PATHS."""
|
||||
ask_cuda_paths = (
|
||||
'Please specify the comma-separated list of base paths to look for CUDA '
|
||||
'libraries and headers. [Leave empty to use the default]: ')
|
||||
tf_cuda_paths = get_from_env_or_user_or_default(environ_cp, 'TF_CUDA_PATHS',
|
||||
ask_cuda_paths, '')
|
||||
if tf_cuda_paths:
|
||||
environ_cp['TF_CUDA_PATHS'] = tf_cuda_paths
|
||||
|
||||
|
||||
def set_tf_cuda_version(environ_cp):
|
||||
"""Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION."""
|
||||
"""Set TF_CUDA_VERSION."""
|
||||
ask_cuda_version = (
|
||||
'Please specify the CUDA SDK version you want to use. '
|
||||
'[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION
|
||||
|
||||
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
|
||||
# Configure the Cuda SDK version to use.
|
||||
tf_cuda_version = get_from_env_or_user_or_default(
|
||||
environ_cp, 'TF_CUDA_VERSION', ask_cuda_version, _DEFAULT_CUDA_VERSION)
|
||||
tf_cuda_version = reformat_version_sequence(str(tf_cuda_version), 2)
|
||||
|
||||
# Find out where the CUDA toolkit is installed
|
||||
default_cuda_path = _DEFAULT_CUDA_PATH
|
||||
if is_windows() or is_cygwin():
|
||||
default_cuda_path = cygpath(
|
||||
environ_cp.get('CUDA_PATH', _DEFAULT_CUDA_PATH_WIN))
|
||||
elif is_linux():
|
||||
# If the default doesn't exist, try an alternative default.
|
||||
if (not os.path.exists(default_cuda_path)
|
||||
) and os.path.exists(_DEFAULT_CUDA_PATH_LINUX):
|
||||
default_cuda_path = _DEFAULT_CUDA_PATH_LINUX
|
||||
ask_cuda_path = ('Please specify the location where CUDA %s toolkit is'
|
||||
' installed. Refer to README.md for more details. '
|
||||
'[Default is %s]: ') % (tf_cuda_version, default_cuda_path)
|
||||
cuda_toolkit_path = get_from_env_or_user_or_default(
|
||||
environ_cp, 'CUDA_TOOLKIT_PATH', ask_cuda_path, default_cuda_path)
|
||||
if is_windows() or is_cygwin():
|
||||
cuda_toolkit_path = cygpath(cuda_toolkit_path)
|
||||
|
||||
if is_windows():
|
||||
cuda_rt_lib_paths = ['lib/x64/cudart.lib']
|
||||
elif is_linux():
|
||||
cuda_rt_lib_paths = [
|
||||
'%s/libcudart.so.%s' % (x, tf_cuda_version) for x in [
|
||||
'lib64',
|
||||
'lib/powerpc64le-linux-gnu',
|
||||
'lib/x86_64-linux-gnu',
|
||||
]
|
||||
]
|
||||
elif is_macos():
|
||||
cuda_rt_lib_paths = ['lib/libcudart.%s.dylib' % tf_cuda_version]
|
||||
|
||||
cuda_toolkit_paths_full = [
|
||||
os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths
|
||||
]
|
||||
if any(os.path.exists(x) for x in cuda_toolkit_paths_full):
|
||||
break
|
||||
|
||||
# Reset and retry
|
||||
print('Invalid path to CUDA %s toolkit. %s cannot be found' %
|
||||
(tf_cuda_version, cuda_toolkit_paths_full))
|
||||
environ_cp['TF_CUDA_VERSION'] = ''
|
||||
environ_cp['CUDA_TOOLKIT_PATH'] = ''
|
||||
|
||||
else:
|
||||
raise UserInputError('Invalid TF_CUDA_SETTING setting was provided %d '
|
||||
'times in a row. Assuming to be a scripting mistake.' %
|
||||
_DEFAULT_PROMPT_ASK_ATTEMPTS)
|
||||
|
||||
# Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION
|
||||
environ_cp['CUDA_TOOLKIT_PATH'] = cuda_toolkit_path
|
||||
write_action_env_to_bazelrc('CUDA_TOOLKIT_PATH', cuda_toolkit_path)
|
||||
tf_cuda_version = get_from_env_or_user_or_default(environ_cp,
|
||||
'TF_CUDA_VERSION',
|
||||
ask_cuda_version,
|
||||
_DEFAULT_CUDA_VERSION)
|
||||
environ_cp['TF_CUDA_VERSION'] = tf_cuda_version
|
||||
write_action_env_to_bazelrc('TF_CUDA_VERSION', tf_cuda_version)
|
||||
|
||||
|
||||
def set_tf_cudnn_version(environ_cp):
|
||||
"""Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION."""
|
||||
"""Set TF_CUDNN_VERSION."""
|
||||
ask_cudnn_version = (
|
||||
'Please specify the cuDNN version you want to use. '
|
||||
'[Leave empty to default to cuDNN %s]: ') % _DEFAULT_CUDNN_VERSION
|
||||
|
||||
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
|
||||
tf_cudnn_version = get_from_env_or_user_or_default(
|
||||
environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version,
|
||||
_DEFAULT_CUDNN_VERSION)
|
||||
tf_cudnn_version = reformat_version_sequence(str(tf_cudnn_version), 1)
|
||||
|
||||
default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH')
|
||||
ask_cudnn_path = (r'Please specify the location where cuDNN %s library is '
|
||||
'installed. Refer to README.md for more details. [Default'
|
||||
' is %s]: ') % (tf_cudnn_version, default_cudnn_path)
|
||||
cudnn_install_path = get_from_env_or_user_or_default(
|
||||
environ_cp, 'CUDNN_INSTALL_PATH', ask_cudnn_path, default_cudnn_path)
|
||||
|
||||
# Result returned from "read" will be used unexpanded. That make "~"
|
||||
# unusable. Going through one more level of expansion to handle that.
|
||||
cudnn_install_path = os.path.realpath(
|
||||
os.path.expanduser(cudnn_install_path))
|
||||
if is_windows() or is_cygwin():
|
||||
cudnn_install_path = cygpath(cudnn_install_path)
|
||||
|
||||
if is_windows():
|
||||
cuda_dnn_lib_path = 'lib/x64/cudnn.lib'
|
||||
cuda_dnn_lib_alt_path = 'lib/x64/cudnn.lib'
|
||||
elif is_linux():
|
||||
cuda_dnn_lib_path = 'lib64/libcudnn.so.%s' % tf_cudnn_version
|
||||
cuda_dnn_lib_alt_path = 'libcudnn.so.%s' % tf_cudnn_version
|
||||
elif is_macos():
|
||||
cuda_dnn_lib_path = 'lib/libcudnn.%s.dylib' % tf_cudnn_version
|
||||
cuda_dnn_lib_alt_path = 'libcudnn.%s.dylib' % tf_cudnn_version
|
||||
|
||||
cuda_dnn_lib_path_full = os.path.join(cudnn_install_path, cuda_dnn_lib_path)
|
||||
cuda_dnn_lib_alt_path_full = os.path.join(cudnn_install_path,
|
||||
cuda_dnn_lib_alt_path)
|
||||
if os.path.exists(cuda_dnn_lib_path_full) or os.path.exists(
|
||||
cuda_dnn_lib_alt_path_full):
|
||||
break
|
||||
|
||||
# Try another alternative for Linux
|
||||
if is_linux():
|
||||
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
|
||||
cudnn_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
|
||||
cudnn_path_from_ldconfig = re.search('.*libcudnn.so .* => (.*)',
|
||||
cudnn_path_from_ldconfig)
|
||||
if cudnn_path_from_ldconfig:
|
||||
cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1)
|
||||
if os.path.exists(
|
||||
'%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)):
|
||||
cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig)
|
||||
break
|
||||
|
||||
# Reset and Retry
|
||||
print(
|
||||
'Invalid path to cuDNN %s toolkit. None of the following files can be '
|
||||
'found:' % tf_cudnn_version)
|
||||
print(cuda_dnn_lib_path_full)
|
||||
print(cuda_dnn_lib_alt_path_full)
|
||||
if is_linux():
|
||||
print('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version))
|
||||
|
||||
environ_cp['TF_CUDNN_VERSION'] = ''
|
||||
else:
|
||||
raise UserInputError('Invalid TF_CUDNN setting was provided %d '
|
||||
'times in a row. Assuming to be a scripting mistake.' %
|
||||
_DEFAULT_PROMPT_ASK_ATTEMPTS)
|
||||
|
||||
# Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION
|
||||
environ_cp['CUDNN_INSTALL_PATH'] = cudnn_install_path
|
||||
write_action_env_to_bazelrc('CUDNN_INSTALL_PATH', cudnn_install_path)
|
||||
tf_cudnn_version = get_from_env_or_user_or_default(environ_cp,
|
||||
'TF_CUDNN_VERSION',
|
||||
ask_cudnn_version,
|
||||
_DEFAULT_CUDNN_VERSION)
|
||||
environ_cp['TF_CUDNN_VERSION'] = tf_cudnn_version
|
||||
write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version)
|
||||
|
||||
|
||||
def is_cuda_compatible(lib, cuda_ver, cudnn_ver):
|
||||
@ -1005,252 +919,38 @@ def is_cuda_compatible(lib, cuda_ver, cudnn_ver):
|
||||
return cudnn_ok and cuda_ok
|
||||
|
||||
|
||||
def set_tf_tensorrt_install_path(environ_cp):
|
||||
"""Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.
|
||||
|
||||
Adapted from code contributed by Sami Kama (https://github.com/samikama).
|
||||
|
||||
Args:
|
||||
environ_cp: copy of the os.environ.
|
||||
|
||||
Raises:
|
||||
ValueError: if this method was called under non-Linux platform.
|
||||
UserInputError: if user has provided invalid input multiple times.
|
||||
"""
|
||||
def set_tf_tensorrt_version(environ_cp):
|
||||
"""Set TF_TENSORRT_VERSION."""
|
||||
if not is_linux():
|
||||
raise ValueError('Currently TensorRT is only supported on Linux platform.')
|
||||
|
||||
# Ask user whether to add TensorRT support.
|
||||
if str(int(get_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT',
|
||||
False))) != '1':
|
||||
if not int(environ_cp.get('TF_NEED_TENSORRT', False)):
|
||||
return
|
||||
|
||||
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
|
||||
ask_tensorrt_path = (r'Please specify the location where TensorRT is '
|
||||
'installed. [Default is %s]:') % (
|
||||
_DEFAULT_TENSORRT_PATH_LINUX)
|
||||
trt_install_path = get_from_env_or_user_or_default(
|
||||
environ_cp, 'TENSORRT_INSTALL_PATH', ask_tensorrt_path,
|
||||
_DEFAULT_TENSORRT_PATH_LINUX)
|
||||
|
||||
# Result returned from "read" will be used unexpanded. That make "~"
|
||||
# unusable. Going through one more level of expansion to handle that.
|
||||
trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path))
|
||||
|
||||
def find_libs(search_path):
|
||||
"""Search for libnvinfer.so in "search_path"."""
|
||||
fl = set()
|
||||
if os.path.exists(search_path) and os.path.isdir(search_path):
|
||||
fl.update([
|
||||
os.path.realpath(os.path.join(search_path, x))
|
||||
for x in os.listdir(search_path)
|
||||
if 'libnvinfer.so' in x
|
||||
])
|
||||
return fl
|
||||
|
||||
possible_files = find_libs(trt_install_path)
|
||||
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib')))
|
||||
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64')))
|
||||
cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
|
||||
cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
|
||||
nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$')
|
||||
highest_ver = [0, None, None]
|
||||
|
||||
for lib_file in possible_files:
|
||||
if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver):
|
||||
matches = nvinfer_pattern.search(lib_file)
|
||||
if not matches.groups():
|
||||
continue
|
||||
ver_str = matches.group(1)
|
||||
ver = convert_version_to_int(ver_str) if len(ver_str) else 0
|
||||
if ver > highest_ver[0]:
|
||||
highest_ver = [ver, ver_str, lib_file]
|
||||
if highest_ver[1] is not None:
|
||||
trt_install_path = os.path.dirname(highest_ver[2])
|
||||
tf_tensorrt_version = highest_ver[1]
|
||||
break
|
||||
|
||||
# Try another alternative from ldconfig.
|
||||
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
|
||||
ldconfig_output = run_shell([ldconfig_bin, '-p'])
|
||||
search_result = re.search('.*libnvinfer.so\\.?([0-9.]*).* => (.*)',
|
||||
ldconfig_output)
|
||||
if search_result:
|
||||
libnvinfer_path_from_ldconfig = search_result.group(2)
|
||||
if os.path.exists(libnvinfer_path_from_ldconfig):
|
||||
if is_cuda_compatible(libnvinfer_path_from_ldconfig, cuda_ver,
|
||||
cudnn_ver):
|
||||
trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
|
||||
tf_tensorrt_version = search_result.group(1)
|
||||
break
|
||||
|
||||
# Reset and Retry
|
||||
if possible_files:
|
||||
print('TensorRT libraries found in one the following directories',
|
||||
'are not compatible with selected cuda and cudnn installations')
|
||||
print(trt_install_path)
|
||||
print(os.path.join(trt_install_path, 'lib'))
|
||||
print(os.path.join(trt_install_path, 'lib64'))
|
||||
if search_result:
|
||||
print(libnvinfer_path_from_ldconfig)
|
||||
else:
|
||||
print(
|
||||
'Invalid path to TensorRT. None of the following files can be found:')
|
||||
print(trt_install_path)
|
||||
print(os.path.join(trt_install_path, 'lib'))
|
||||
print(os.path.join(trt_install_path, 'lib64'))
|
||||
if search_result:
|
||||
print(libnvinfer_path_from_ldconfig)
|
||||
|
||||
else:
|
||||
raise UserInputError('Invalid TF_TENSORRT setting was provided %d '
|
||||
'times in a row. Assuming to be a scripting mistake.' %
|
||||
_DEFAULT_PROMPT_ASK_ATTEMPTS)
|
||||
|
||||
# Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION
|
||||
environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path
|
||||
write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path)
|
||||
ask_tensorrt_version = (
|
||||
'Please specify the TensorRT version you want to use. '
|
||||
'[Leave empty to default to TensorRT %s]: ') % _DEFAULT_TENSORRT_VERSION
|
||||
tf_tensorrt_version = get_from_env_or_user_or_default(
|
||||
environ_cp, 'TF_TENSORRT_VERSION', ask_tensorrt_version,
|
||||
_DEFAULT_TENSORRT_VERSION)
|
||||
environ_cp['TF_TENSORRT_VERSION'] = tf_tensorrt_version
|
||||
write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version)
|
||||
|
||||
|
||||
def set_tf_nccl_install_path(environ_cp):
|
||||
"""Set NCCL_INSTALL_PATH, NCCL_HDR_PATH and TF_NCCL_VERSION.
|
||||
|
||||
Args:
|
||||
environ_cp: copy of the os.environ.
|
||||
|
||||
Raises:
|
||||
ValueError: if this method was called under non-Linux platform.
|
||||
UserInputError: if user has provided invalid input multiple times.
|
||||
"""
|
||||
def set_tf_nccl_version(environ_cp):
|
||||
"""Set TF_NCCL_VERSION."""
|
||||
if not is_linux():
|
||||
raise ValueError('Currently NCCL is only supported on Linux platforms.')
|
||||
raise ValueError('Currently NCCL is only supported on Linux platform.')
|
||||
|
||||
if 'TF_NCCL_VERSION' in environ_cp:
|
||||
return
|
||||
|
||||
ask_nccl_version = (
|
||||
'Please specify the locally installed NCCL version you want to use. '
|
||||
'[Default is to use https://github.com/nvidia/nccl]: ')
|
||||
|
||||
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
|
||||
tf_nccl_version = get_from_env_or_user_or_default(
|
||||
environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, '')
|
||||
|
||||
if not tf_nccl_version:
|
||||
break # No need to get install path, building the open source code.
|
||||
|
||||
tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1)
|
||||
|
||||
# Look with ldconfig first if we can find the library in paths
|
||||
# like /usr/lib/x86_64-linux-gnu and the header file in the corresponding
|
||||
# include directory. This is where the NCCL .deb packages install them.
|
||||
|
||||
# First check to see if NCCL is in the ldconfig.
|
||||
# If its found, use that location.
|
||||
if is_linux():
|
||||
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
|
||||
nccl2_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
|
||||
nccl2_path_from_ldconfig = re.search('.*libnccl.so .* => (.*)',
|
||||
nccl2_path_from_ldconfig)
|
||||
if nccl2_path_from_ldconfig:
|
||||
nccl2_path_from_ldconfig = nccl2_path_from_ldconfig.group(1)
|
||||
if os.path.exists('%s.%s' % (nccl2_path_from_ldconfig, tf_nccl_version)):
|
||||
nccl_install_path = os.path.dirname(nccl2_path_from_ldconfig)
|
||||
print('NCCL libraries found in ' + nccl2_path_from_ldconfig)
|
||||
|
||||
# Check if this is the main system lib location
|
||||
if re.search('.*linux-gnu', nccl_install_path):
|
||||
trunc_nccl_install_path = '/usr'
|
||||
print('This looks like a system path.')
|
||||
else:
|
||||
trunc_nccl_install_path = nccl_install_path + '/..'
|
||||
|
||||
# Look for header
|
||||
nccl_hdr_path = trunc_nccl_install_path + '/include'
|
||||
print('Assuming NCCL header path is ' + nccl_hdr_path)
|
||||
if os.path.exists(nccl_hdr_path + '/nccl.h'):
|
||||
# Set NCCL_INSTALL_PATH
|
||||
environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path
|
||||
write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path)
|
||||
|
||||
# Set NCCL_HDR_PATH
|
||||
environ_cp['NCCL_HDR_PATH'] = nccl_hdr_path
|
||||
write_action_env_to_bazelrc('NCCL_HDR_PATH', nccl_hdr_path)
|
||||
break
|
||||
else:
|
||||
print(
|
||||
'The header for NCCL2 cannot be found. Please install the libnccl-dev package.'
|
||||
)
|
||||
else:
|
||||
print('NCCL2 is listed by ldconfig but the library is not found. '
|
||||
'Your ldconfig is out of date. Please run sudo ldconfig.')
|
||||
else:
|
||||
# NCCL is not found in ldconfig. Ask the user for the location.
|
||||
default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH')
|
||||
ask_nccl_path = (
|
||||
r'Please specify the location where NCCL %s library is '
|
||||
'installed. Refer to README.md for more details. [Default '
|
||||
'is %s]:') % (tf_nccl_version, default_nccl_path)
|
||||
nccl_install_path = get_from_env_or_user_or_default(
|
||||
environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path)
|
||||
|
||||
# Result returned from "read" will be used unexpanded. That make "~"
|
||||
# unusable. Going through one more level of expansion to handle that.
|
||||
nccl_install_path = os.path.realpath(
|
||||
os.path.expanduser(nccl_install_path))
|
||||
if is_windows() or is_cygwin():
|
||||
nccl_install_path = cygpath(nccl_install_path)
|
||||
|
||||
nccl_lib_path = ''
|
||||
if is_windows():
|
||||
nccl_lib_path = 'lib/x64/nccl.lib'
|
||||
elif is_linux():
|
||||
nccl_lib_filename = 'libnccl.so.%s' % tf_nccl_version
|
||||
nccl_lpath = '%s/lib/%s' % (nccl_install_path, nccl_lib_filename)
|
||||
if not os.path.exists(nccl_lpath):
|
||||
for relative_path in NCCL_LIB_PATHS:
|
||||
path = '%s/%s%s' % (nccl_install_path, relative_path,
|
||||
nccl_lib_filename)
|
||||
if os.path.exists(path):
|
||||
print('NCCL found at ' + path)
|
||||
nccl_lib_path = path
|
||||
break
|
||||
else:
|
||||
nccl_lib_path = nccl_lpath
|
||||
elif is_macos():
|
||||
nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version
|
||||
|
||||
nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path)
|
||||
nccl_hdr_path = os.path.join(
|
||||
os.path.dirname(nccl_lib_path), '../include/nccl.h')
|
||||
print('Assuming NCCL header path is ' + nccl_hdr_path)
|
||||
if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path):
|
||||
# Set NCCL_INSTALL_PATH
|
||||
environ_cp['NCCL_INSTALL_PATH'] = os.path.dirname(nccl_lib_path)
|
||||
write_action_env_to_bazelrc('NCCL_INSTALL_PATH',
|
||||
os.path.dirname(nccl_lib_path))
|
||||
|
||||
# Set NCCL_HDR_PATH
|
||||
environ_cp['NCCL_HDR_PATH'] = os.path.dirname(nccl_hdr_path)
|
||||
write_action_env_to_bazelrc('NCCL_HDR_PATH',
|
||||
os.path.dirname(nccl_hdr_path))
|
||||
break
|
||||
|
||||
# Reset and Retry
|
||||
print(
|
||||
'Invalid path to NCCL %s toolkit, %s or %s not found. Please use the '
|
||||
'O/S agnostic package of NCCL 2' %
|
||||
(tf_nccl_version, nccl_lib_path, nccl_hdr_path))
|
||||
|
||||
environ_cp['TF_NCCL_VERSION'] = ''
|
||||
else:
|
||||
raise UserInputError('Invalid TF_NCCL setting was provided %d '
|
||||
'times in a row. Assuming to be a scripting mistake.' %
|
||||
_DEFAULT_PROMPT_ASK_ATTEMPTS)
|
||||
|
||||
# Set TF_NCCL_VERSION
|
||||
'[Leave empty to use http://github.com/nvidia/nccl]: ')
|
||||
tf_nccl_version = get_from_env_or_user_or_default(environ_cp,
|
||||
'TF_NCCL_VERSION',
|
||||
ask_nccl_version, '')
|
||||
environ_cp['TF_NCCL_VERSION'] = tf_nccl_version
|
||||
write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version)
|
||||
|
||||
|
||||
def get_native_cuda_compute_capabilities(environ_cp):
|
||||
"""Get native cuda compute capabilities.
|
||||
@ -1607,6 +1307,66 @@ def configure_ios():
|
||||
symlink_force(filepath, new_filepath)
|
||||
|
||||
|
||||
def validate_cuda_config(environ_cp):
|
||||
"""Run find_cuda_config.py and return cuda_toolkit_path, or None."""
|
||||
|
||||
def maybe_encode_env(env):
|
||||
"""Encodes unicode in env to str on Windows python 2.x."""
|
||||
if not is_windows() or sys.version_info[0] != 2:
|
||||
return env
|
||||
for k, v in env.items():
|
||||
if isinstance(k, unicode):
|
||||
k = k.encode('ascii')
|
||||
if isinstance(v, unicode):
|
||||
v = v.encode('ascii')
|
||||
env[k] = v
|
||||
return env
|
||||
|
||||
cuda_libraries = ['cuda', 'cudnn']
|
||||
if is_linux():
|
||||
if 'TF_TENSORRT_VERSION' in environ_cp: # if env variable exists
|
||||
cuda_libraries.append('tensorrt')
|
||||
if environ_cp.get('TF_NCCL_VERSION', None): # if env variable not empty
|
||||
cuda_libraries.append('nccl')
|
||||
|
||||
proc = subprocess.Popen(
|
||||
[environ_cp['PYTHON_BIN_PATH'], 'third_party/gpus/find_cuda_config.py'] +
|
||||
cuda_libraries,
|
||||
stdout=subprocess.PIPE,
|
||||
env=maybe_encode_env(environ_cp))
|
||||
|
||||
if proc.wait():
|
||||
# Errors from find_cuda_config.py were sent to stderr.
|
||||
print('Asking for detailed CUDA configuration...\n')
|
||||
return False
|
||||
|
||||
config = dict(
|
||||
tuple(line.decode('ascii').rstrip().split(': ')) for line in proc.stdout)
|
||||
|
||||
print('Found CUDA %s in:' % config['cuda_version'])
|
||||
print(' %s' % config['cuda_library_dir'])
|
||||
print(' %s' % config['cuda_include_dir'])
|
||||
|
||||
print('Found cuDNN %s in:' % config['cudnn_version'])
|
||||
print(' %s' % config['cudnn_library_dir'])
|
||||
print(' %s' % config['cudnn_include_dir'])
|
||||
|
||||
if 'tensorrt_version' in config:
|
||||
print('Found TensorRT %s in:' % config['tensorrt_version'])
|
||||
print(' %s' % config['tensorrt_library_dir'])
|
||||
print(' %s' % config['tensorrt_include_dir'])
|
||||
|
||||
if config.get('nccl_version', None):
|
||||
print('Found NCCL %s in:' % config['nccl_version'])
|
||||
print(' %s' % config['nccl_library_dir'])
|
||||
print(' %s' % config['nccl_include_dir'])
|
||||
|
||||
print('\n')
|
||||
|
||||
environ_cp['CUDA_TOOLKIT_PATH'] = config['cuda_toolkit_path']
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
global _TF_WORKSPACE_ROOT
|
||||
global _TF_BAZELRC
|
||||
@ -1627,7 +1387,7 @@ def main():
|
||||
# environment variables.
|
||||
environ_cp = dict(os.environ)
|
||||
|
||||
current_bazel_version = check_bazel_version('0.22.0', '0.24.0')
|
||||
current_bazel_version = check_bazel_version('0.24.1', '0.25.0')
|
||||
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
|
||||
|
||||
reset_tf_configure_bazelrc()
|
||||
@ -1683,11 +1443,39 @@ def main():
|
||||
set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)
|
||||
if (environ_cp.get('TF_NEED_CUDA') == '1' and
|
||||
'TF_CUDA_CONFIG_REPO' not in environ_cp):
|
||||
set_tf_cuda_version(environ_cp)
|
||||
set_tf_cudnn_version(environ_cp)
|
||||
if is_linux():
|
||||
set_tf_tensorrt_install_path(environ_cp)
|
||||
set_tf_nccl_install_path(environ_cp)
|
||||
|
||||
set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False)
|
||||
|
||||
environ_save = dict(environ_cp)
|
||||
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
|
||||
|
||||
if validate_cuda_config(environ_cp):
|
||||
cuda_env_names = [
|
||||
'TF_CUDA_VERSION', 'TF_CUBLAS_VERSION', 'TF_CUDNN_VERSION',
|
||||
'TF_TENSORRT_VERSION', 'TF_NCCL_VERSION', 'TF_CUDA_PATHS',
|
||||
'CUDA_TOOLKIT_PATH'
|
||||
]
|
||||
for name in cuda_env_names:
|
||||
if name in environ_cp:
|
||||
write_action_env_to_bazelrc(name, environ_cp[name])
|
||||
break
|
||||
|
||||
# Restore settings changed below if CUDA config could not be validated.
|
||||
environ_cp = dict(environ_save)
|
||||
|
||||
set_tf_cuda_version(environ_cp)
|
||||
set_tf_cudnn_version(environ_cp)
|
||||
if is_linux():
|
||||
set_tf_tensorrt_version(environ_cp)
|
||||
set_tf_nccl_version(environ_cp)
|
||||
|
||||
set_tf_cuda_paths(environ_cp)
|
||||
|
||||
else:
|
||||
raise UserInputError(
|
||||
'Invalid CUDA setting were provided %d '
|
||||
'times in a row. Assuming to be a scripting mistake.' %
|
||||
_DEFAULT_PROMPT_ASK_ATTEMPTS)
|
||||
|
||||
set_tf_cuda_compute_capabilities(environ_cp)
|
||||
if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get(
|
||||
@ -1755,10 +1543,8 @@ def main():
|
||||
|
||||
system_specific_test_config(os.environ)
|
||||
|
||||
if get_var(environ_cp, 'TF_CONFIGURE_IOS', 'Configure TensorFlow for iOS',
|
||||
False, ('Would you like to configure TensorFlow for iOS builds?'),
|
||||
'Configuring TensorFlow for iOS builds.',
|
||||
'Not configuring TensorFlow for iOS builds.'):
|
||||
set_action_env_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False)
|
||||
if environ_cp.get('TF_CONFIGURE_IOS') == '1':
|
||||
configure_ios()
|
||||
else:
|
||||
# TODO(pcloudy): Remove BAZEL_USE_CPP_ONLY_TOOLCHAIN after Bazel is upgraded
|
||||
|
@ -12,7 +12,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Bring in all of the public TensorFlow interface into this module."""
|
||||
"""
|
||||
Top-level module of TensorFlow. By convention, we refer to this module as
|
||||
`tf` instead of `tensorflow`, following the common practice of importing
|
||||
TensorFlow via the command `import tensorflow as tf`.
|
||||
|
||||
The primary function of this module is to import all of the public TensorFlow
|
||||
interfaces into a single place. The interfaces themselves are located in
|
||||
sub-modules, as described below.
|
||||
|
||||
Note that the file `__init__.py` in the TensorFlow source code tree is actually
|
||||
only a placeholder to enable test cases to run. The TensorFlow build replaces
|
||||
this file with a file generated from [`api_template.__init__.py`](https://www.github.com/tensorflow/tensorflow/blob/master/tensorflow/api_template.__init__.py)
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import as _absolute_import
|
||||
from __future__ import division as _division
|
||||
|
@ -118,7 +118,11 @@ if _running_from_pip_package():
|
||||
# pylint: disable=undefined-variable
|
||||
try:
|
||||
del python
|
||||
if '__all__' in vars():
|
||||
vars()['__all__'].remove('python')
|
||||
del core
|
||||
if '__all__' in vars():
|
||||
vars()['__all__'].remove('core')
|
||||
except NameError:
|
||||
# Don't fail if these modules are not available.
|
||||
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
|
||||
@ -129,6 +133,8 @@ except NameError:
|
||||
# others don't exist.
|
||||
try:
|
||||
del compiler
|
||||
if '__all__' in vars():
|
||||
vars()['__all__'].remove('compiler')
|
||||
except NameError:
|
||||
pass
|
||||
# pylint: enable=undefined-variable
|
||||
|
@ -104,6 +104,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/cc/saved_model:loader_lite",
|
||||
"//tensorflow/cc:gradients",
|
||||
"//tensorflow/cc:ops",
|
||||
|
@ -799,8 +799,8 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
|
||||
const auto& op_type = op->operation.Name();
|
||||
auto op_name =
|
||||
tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++);
|
||||
auto* desc =
|
||||
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str());
|
||||
std::unique_ptr<TF_OperationDescription> desc(
|
||||
TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()));
|
||||
|
||||
VLOG(1) << "Adding attrs.";
|
||||
tensorflow::AttrValueMap attrs;
|
||||
@ -814,30 +814,42 @@ TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
|
||||
size_t inputIndex = 0;
|
||||
const tensorflow::OpDef& op_def = desc->node_builder.op_def();
|
||||
for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) {
|
||||
// TODO(bgogul): Add support for number attributes.
|
||||
DCHECK(input_arg.number_attr().empty())
|
||||
<< "Number attributes is not implemented yet.";
|
||||
if (input_arg.type_list_attr().empty()) {
|
||||
if (input_arg.type_list_attr().empty() && input_arg.number_attr().empty()) {
|
||||
auto symbolic_input =
|
||||
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
TF_AddInput(desc, symbolic_input);
|
||||
TF_AddInput(desc.get(), symbolic_input);
|
||||
continue;
|
||||
}
|
||||
const std::string& type_list_attr = input_arg.type_list_attr();
|
||||
const auto& attr_value = attrs[type_list_attr];
|
||||
DCHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
|
||||
<< "Type list attribute should be a list!";
|
||||
std::vector<TF_Output> list_inputs(attr_value.list().type_size());
|
||||
size_t list_size = 0;
|
||||
if (!input_arg.type_list_attr().empty()) {
|
||||
const std::string& type_list_attr = input_arg.type_list_attr();
|
||||
const auto& attr_value = attrs[type_list_attr];
|
||||
CHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
|
||||
<< "Type list attribute should be a list!";
|
||||
list_size = attr_value.list().type_size();
|
||||
} else {
|
||||
CHECK(!input_arg.number_attr().empty());
|
||||
const auto& attr_value = attrs[input_arg.number_attr()];
|
||||
CHECK(attr_value.value_case() == tensorflow::AttrValue::kI)
|
||||
<< "Number attribute should be int!";
|
||||
if (attr_value.i() < 0) {
|
||||
status->status = tensorflow::errors::Internal(
|
||||
"Number attribute for length should be >=0!");
|
||||
return nullptr;
|
||||
}
|
||||
list_size = attr_value.i();
|
||||
}
|
||||
std::vector<TF_Output> list_inputs(list_size);
|
||||
for (TF_Output& list_input : list_inputs) {
|
||||
list_input =
|
||||
getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
}
|
||||
TF_AddInputList(desc, list_inputs.data(), list_inputs.size());
|
||||
TF_AddInputList(desc.get(), list_inputs.data(), list_inputs.size());
|
||||
}
|
||||
|
||||
auto* graph_op = TF_FinishOperation(desc, status);
|
||||
auto* graph_op = TF_FinishOperation(desc.release(), status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
VLOG(1) << "Op finalized; setting return tensors.";
|
||||
|
@ -376,5 +376,60 @@ TEST_F(AddEagerOpToGraphTest, ListInputsAreAddedCorrectly) {
|
||||
TFE_DeleteOp(identityn);
|
||||
}
|
||||
|
||||
TEST_F(AddEagerOpToGraphTest, NumberAttributesAreHandledCorrectly) {
|
||||
TFE_TensorHandle* matrix = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle();
|
||||
TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
|
||||
TFE_OpSetAttrInt(concatv2, "N", 2);
|
||||
TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);
|
||||
constexpr size_t kNumInputs = 2;
|
||||
for (size_t i = 0; i < kNumInputs; ++i) {
|
||||
TFE_OpAddInput(concatv2, matrix, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
}
|
||||
TFE_OpAddInput(concatv2, axis, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
AddEagerOpToGraphAndCheck(
|
||||
concatv2, [this, kNumInputs](TF_Operation* graph_op) {
|
||||
EXPECT_EQ(TF_OperationNumInputs(graph_op), kNumInputs + 1);
|
||||
int64_t attrN;
|
||||
TF_OperationGetAttrInt(graph_op, "N", &attrN, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
EXPECT_EQ(attrN, kNumInputs);
|
||||
EXPECT_EQ(TF_OperationInputListLength(graph_op, "values", status_),
|
||||
kNumInputs);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
});
|
||||
TFE_DeleteTensorHandle(axis);
|
||||
TFE_DeleteTensorHandle(matrix);
|
||||
TFE_DeleteOp(concatv2);
|
||||
}
|
||||
|
||||
TEST_F(AddEagerOpToGraphTest,
|
||||
GeneratesInternalErrorsForInvalidNumberAttributes) {
|
||||
TFE_TensorHandle* matrix = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle();
|
||||
int num_retvals = 5;
|
||||
TFE_TensorHandle* retvals[5];
|
||||
|
||||
TFE_Op* concatv2 = TFE_NewOp(eager_ctx_, "ConcatV2", status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
TFE_OpSetAttrType(concatv2, "T", TF_FLOAT);
|
||||
TFE_OpSetAttrInt(concatv2, "N", -1);
|
||||
TFE_OpSetAttrType(concatv2, "Tidx", TF_INT32);
|
||||
|
||||
TF_Operation* graph_op = TFE_AddEagerOpToGraph(concatv2, trace_ctx_, retvals,
|
||||
&num_retvals, status_);
|
||||
EXPECT_EQ(graph_op, nullptr);
|
||||
EXPECT_EQ(status_->status.error_message(),
|
||||
"Number attribute for length should be >=0!");
|
||||
|
||||
TFE_DeleteOp(concatv2);
|
||||
TFE_DeleteTensorHandle(axis);
|
||||
TFE_DeleteTensorHandle(matrix);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -352,6 +352,16 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
|
||||
argdef->set_type(node->output_type(idx));
|
||||
const string& input_name = node_names.GetInputName(node->name());
|
||||
argdef->set_name(input_name);
|
||||
auto& arg_attrs = (*fdef->mutable_arg_attr())[i];
|
||||
for (const auto& attr : node->attrs()) {
|
||||
// Only copy internal attributes. These attributes will be applied to
|
||||
// _Arg/Placeholder nodes when this FunctionDef is converted to graph, and
|
||||
// normal attributes for nodes cannot be applied to those _Arg/Placeholder
|
||||
// nodes.
|
||||
if (absl::StartsWith(attr.first, "_")) {
|
||||
arg_attrs.mutable_attr()->insert(attr);
|
||||
}
|
||||
}
|
||||
tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
|
||||
}
|
||||
|
||||
|
@ -1278,6 +1278,46 @@ TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
|
||||
EXPECT_EQ(func_->fdef.signature().attr(1).type(), "int");
|
||||
}
|
||||
|
||||
void NodeWithAttrHelper(TF_Graph* graph, TF_Status* s, const char* name,
|
||||
const char* attr_name, const char* attr_value,
|
||||
TF_Operation** op) {
|
||||
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
|
||||
TF_SetAttrType(desc, "dtype", TF_INT32);
|
||||
TF_SetAttrString(desc, attr_name, attr_value, strlen(attr_value));
|
||||
*op = TF_FinishOperation(desc, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
ASSERT_NE(*op, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
|
||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
|
||||
TF_NewGraph(), TF_DeleteGraph);
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
|
||||
TF_DeleteStatus);
|
||||
|
||||
TF_Operation* node;
|
||||
NodeWithAttrHelper(func_graph.get(), s.get(), "node", "_test_attr", "value",
|
||||
&node);
|
||||
|
||||
TF_Output inputs[] = {{node, 0}};
|
||||
TF_Output outputs[] = {};
|
||||
func_ = TF_GraphToFunction(
|
||||
func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
|
||||
/*opers=*/nullptr, 1, inputs, 0, outputs,
|
||||
/*output_names=*/nullptr,
|
||||
/*opts=*/nullptr, /*description=*/nullptr, s.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
|
||||
ASSERT_NE(func_, nullptr);
|
||||
|
||||
// Verify that FunctionDef ArgDef has attributes.
|
||||
ASSERT_EQ(func_->fdef.arg_attr_size(), 1);
|
||||
auto arg_attrs = func_->fdef.arg_attr().find(0);
|
||||
ASSERT_NE(arg_attrs, func_->fdef.arg_attr().end());
|
||||
auto iter = arg_attrs->second.attr().find("_test_attr");
|
||||
ASSERT_NE(iter, arg_attrs->second.attr().end());
|
||||
EXPECT_EQ(iter->second.s(), "value");
|
||||
}
|
||||
|
||||
TEST_F(CApiFunctionTest, SetGradientAndRun) {
|
||||
// Define the function and its grad
|
||||
DefineFunction(func_name_, &func_);
|
||||
|
@ -29,8 +29,7 @@ namespace checkpoint {
|
||||
|
||||
class TensorSliceReader;
|
||||
|
||||
CheckpointReader::CheckpointReader(const string& filename,
|
||||
TF_Status* out_status)
|
||||
CheckpointReader::CheckpointReader(const string& filename, TF_Status* status)
|
||||
: reader_(nullptr),
|
||||
v2_reader_(nullptr),
|
||||
var_to_shape_map_(nullptr),
|
||||
@ -43,7 +42,7 @@ CheckpointReader::CheckpointReader(const string& filename,
|
||||
v2_reader_.reset(
|
||||
new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
|
||||
if (!v2_reader_->status().ok()) {
|
||||
Set_TF_Status_from_Status(out_status, v2_reader_->status());
|
||||
Set_TF_Status_from_Status(status, v2_reader_->status());
|
||||
return;
|
||||
}
|
||||
auto result = BuildV2VarMaps();
|
||||
@ -52,7 +51,7 @@ CheckpointReader::CheckpointReader(const string& filename,
|
||||
} else {
|
||||
reader_.reset(new TensorSliceReader(filename));
|
||||
if (!reader_->status().ok()) {
|
||||
Set_TF_Status_from_Status(out_status, reader_->status());
|
||||
Set_TF_Status_from_Status(status, reader_->status());
|
||||
return;
|
||||
}
|
||||
var_to_shape_map_.reset(
|
||||
|
@ -39,7 +39,7 @@ class TensorSliceReader;
|
||||
// variables.
|
||||
class CheckpointReader {
|
||||
public:
|
||||
CheckpointReader(const string& filepattern, TF_Status* out_status);
|
||||
CheckpointReader(const string& filename, TF_Status* status);
|
||||
|
||||
bool HasTensor(const string& name) const;
|
||||
const string DebugString() const;
|
||||
|
@ -70,6 +70,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core/distributed_runtime:remote_device",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/profiler/lib:profiler_eager_lib",
|
||||
"//tensorflow/core/profiler/lib:profiler_session",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
],
|
||||
@ -110,6 +111,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
|
||||
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
|
||||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
|
||||
"//tensorflow/core/profiler/lib:profiler_eager_lib",
|
||||
"//tensorflow/core/profiler/lib:profiler_session",
|
||||
],
|
||||
)
|
||||
@ -200,6 +202,7 @@ tf_cuda_library(
|
||||
"//conditions:default": [],
|
||||
}) + [
|
||||
"@com_google_absl//absl/memory",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/distributed_runtime/eager:eager_client",
|
||||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
|
||||
@ -236,7 +239,6 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/profiler:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -143,7 +143,9 @@ tensorflow::Status CreateRemoteContexts(
|
||||
request.mutable_server_def()->set_task_index(parsed_name.task);
|
||||
request.set_async(async);
|
||||
request.set_keep_alive_secs(keep_alive_secs);
|
||||
auto* eager_client = remote_eager_workers->GetClient(remote_worker);
|
||||
tensorflow::eager::EagerClient* eager_client;
|
||||
TF_RETURN_IF_ERROR(
|
||||
remote_eager_workers->GetClient(remote_worker, &eager_client));
|
||||
if (eager_client == nullptr) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Cannot find a client for the given target:", remote_worker);
|
||||
|
@ -17,6 +17,12 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
|
||||
#include "tensorflow/core/profiler/rpc/profiler_server.h"
|
||||
|
||||
@ -92,3 +98,123 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
|
||||
num_tracing_attempts);
|
||||
return s.ok();
|
||||
}
|
||||
|
||||
static tensorflow::mutex gauges_map_lock(tensorflow::LINKER_INITIALIZED);
|
||||
|
||||
static std::unordered_map<string,
|
||||
tensorflow::monitoring::Gauge<tensorflow::int64, 1>*>*
|
||||
get_gauges_map() EXCLUSIVE_LOCKS_REQUIRED(gauges_map_lock) {
|
||||
static std::unordered_map<
|
||||
string, tensorflow::monitoring::Gauge<tensorflow::int64, 1>*>*
|
||||
gauges_map = new std::unordered_map<
|
||||
string, tensorflow::monitoring::Gauge<tensorflow::int64, 1>*>;
|
||||
return gauges_map;
|
||||
}
|
||||
|
||||
static tensorflow::mutex samplers_map_lock(tensorflow::LINKER_INITIALIZED);
|
||||
|
||||
static std::unordered_map<string, tensorflow::monitoring::Sampler<1>*>*
|
||||
get_samplers_map() EXCLUSIVE_LOCKS_REQUIRED(samplers_map_lock) {
|
||||
static std::unordered_map<string, tensorflow::monitoring::Sampler<1>*>*
|
||||
samplers_map =
|
||||
new std::unordered_map<string, tensorflow::monitoring::Sampler<1>*>;
|
||||
return samplers_map;
|
||||
}
|
||||
|
||||
void TFE_MonitoringSetGauge(const char* name, const char* label,
|
||||
int64_t value) {
|
||||
tensorflow::mutex_lock l(gauges_map_lock);
|
||||
auto gauges_map = get_gauges_map();
|
||||
if (gauges_map->find(name) == gauges_map->end()) {
|
||||
gauges_map->emplace(
|
||||
name, tensorflow::monitoring::Gauge<tensorflow::int64, 1>::New(
|
||||
name,
|
||||
tensorflow::strings::StrCat(
|
||||
name, " :Gauge metric collected from Python API."),
|
||||
"metric_descriptor"));
|
||||
}
|
||||
gauges_map->at(name)->GetCell(label)->Set(value);
|
||||
}
|
||||
|
||||
void TFE_MonitoringAddSampler(const char* name, const char* label,
|
||||
double value) {
|
||||
tensorflow::mutex_lock l(samplers_map_lock);
|
||||
auto samplers_map = get_samplers_map();
|
||||
if (samplers_map->find(name) == samplers_map->end()) {
|
||||
samplers_map->emplace(
|
||||
name, tensorflow::monitoring::Sampler<1>::New(
|
||||
{name,
|
||||
tensorflow::strings::StrCat(
|
||||
name, " :Counter metric collected from Python API."),
|
||||
"metric_descriptor"},
|
||||
{tensorflow::monitoring::Buckets::Exponential(1, 2, 30)}));
|
||||
}
|
||||
samplers_map->at(name)->GetCell(label)->Add(value);
|
||||
}
|
||||
|
||||
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
|
||||
int64_t value) {
|
||||
cell->cell.IncrementBy(value);
|
||||
}
|
||||
|
||||
int64_t TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell* cell) {
|
||||
return cell->cell.value();
|
||||
}
|
||||
|
||||
TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(const char* name,
|
||||
TF_Status* status,
|
||||
const char* description) {
|
||||
auto* result = new TFE_MonitoringCounter0({name, description});
|
||||
Set_TF_Status_from_Status(status, result->counter->GetStatus());
|
||||
return result;
|
||||
}
|
||||
|
||||
void TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0* counter) {
|
||||
delete counter;
|
||||
}
|
||||
|
||||
TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
|
||||
TFE_MonitoringCounter0* counter) {
|
||||
return static_cast<TFE_MonitoringCounterCell*>(
|
||||
static_cast<void*>(counter->counter->GetCell()));
|
||||
}
|
||||
|
||||
TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(const char* name,
|
||||
TF_Status* status,
|
||||
const char* description,
|
||||
const char* label1) {
|
||||
auto* result = new TFE_MonitoringCounter1({name, description, label1});
|
||||
Set_TF_Status_from_Status(status, result->counter->GetStatus());
|
||||
return result;
|
||||
}
|
||||
|
||||
void TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1* counter) {
|
||||
delete counter;
|
||||
}
|
||||
|
||||
TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
|
||||
TFE_MonitoringCounter1* counter, const char* label1) {
|
||||
return static_cast<TFE_MonitoringCounterCell*>(
|
||||
static_cast<void*>(counter->counter->GetCell(label1)));
|
||||
}
|
||||
|
||||
TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(const char* name,
|
||||
TF_Status* status,
|
||||
const char* description,
|
||||
const char* label1,
|
||||
const char* label2) {
|
||||
auto* result =
|
||||
new TFE_MonitoringCounter2({name, description, label1, label2});
|
||||
Set_TF_Status_from_Status(status, result->counter->GetStatus());
|
||||
return result;
|
||||
}
|
||||
|
||||
void TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2* counter) {
|
||||
delete counter;
|
||||
}
|
||||
|
||||
TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
|
||||
TFE_MonitoringCounter2* counter, const char* label1, const char* label2) {
|
||||
return static_cast<TFE_MonitoringCounterCell*>(
|
||||
static_cast<void*>(counter->counter->GetCell(label1, label2)));
|
||||
}
|
||||
|
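A note on the sampler registration above: `Buckets::Exponential(1, 2, 30)` is assumed here to build 30 exponentially growing bucket limits (scale 1, growth factor 2), i.e. 1, 2, 4, ..., 2^29, plus an implicit overflow bucket. The sketch below just prints those limits under that assumption; it is illustrative only and not part of the change.

```cpp
#include <cstdio>

// Illustrative only: enumerate the bucket limits that Buckets::Exponential(1,
// 2, 30) is assumed to produce (scale * growth^i for i in [0, 30)).
int main() {
  double limit = 1.0;             // scale
  for (int i = 0; i < 30; ++i) {  // bucket_count
    std::printf("bucket %2d upper limit: %g\n", i, limit);
    limit *= 2.0;                 // growth_factor
  }
  return 0;
}
```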
@ -87,6 +87,68 @@ TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
|
||||
const char* service_addr, const char* logdir, const char* worker_list,
|
||||
bool include_dataset_ops, int duration_ms, int num_tracing_attempts);
|
||||
|
||||
// Sets the value of a Gauge metric. If a metric with the given name does not
// exist, a new int64 Gauge metric is created. Currently only int64 values are
// supported; more types can be added if needed.
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringSetGauge(const char* name,
|
||||
const char* label,
|
||||
int64_t value);
|
||||
|
||||
// Adds the given value to a Sampler metric. If a metric with the given name
// does not exist, a new Sampler metric is created.
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringAddSampler(const char* name,
|
||||
const char* label,
|
||||
double value);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Monitoring Counter APIs.
|
||||
// These APIs provide de-templated monitoring Counters for SWIG.
|
||||
|
||||
typedef struct TFE_MonitoringCounterCell TFE_MonitoringCounterCell;
|
||||
|
||||
// Atomically increments the value of the cell. The value must be non-negative.
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringCounterCellIncrementBy(
|
||||
TFE_MonitoringCounterCell* cell, int64_t value);
|
||||
|
||||
// Retrieves the current value of the cell.
|
||||
TF_CAPI_EXPORT extern int64_t TFE_MonitoringCounterCellValue(
|
||||
TFE_MonitoringCounterCell* cell);
|
||||
|
||||
// APIs for Counter without label.
|
||||
typedef struct TFE_MonitoringCounter0 TFE_MonitoringCounter0;
|
||||
// Returns a new Counter metric object. The caller should manage the lifetime
// of the object. Using a duplicate metric name will crash the program with a
// fatal error.
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(
|
||||
const char* name, TF_Status* status, const char* description);
|
||||
// Deletes the Counter object.
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter0(
|
||||
TFE_MonitoringCounter0* counter);
|
||||
// Retrieves the cell from the Counter object. The Counter object will manage
// the lifetime of the cell.
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
|
||||
TFE_MonitoringCounter0* counter);
|
||||
|
||||
// APIs for Counter with 1 label.
|
||||
typedef struct TFE_MonitoringCounter1 TFE_MonitoringCounter1;
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(
|
||||
const char* name, TF_Status* status, const char* description,
|
||||
const char* label1);
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter1(
|
||||
TFE_MonitoringCounter1* counter);
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
|
||||
TFE_MonitoringCounter1* counter, const char* label1);
|
||||
|
||||
// APIs for Counter with 2 labels.
|
||||
typedef struct TFE_MonitoringCounter2 TFE_MonitoringCounter2;
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(
|
||||
const char* name, TF_Status* status, const char* description,
|
||||
const char* label1, const char* label2);
|
||||
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter2(
|
||||
TFE_MonitoringCounter2* counter);
|
||||
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
|
||||
TFE_MonitoringCounter2* counter, const char* label1, const char* label2);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
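To make the declarations above easier to follow, here is a minimal usage sketch from C++. The metric names and labels are hypothetical and error handling is reduced to a single status check; this is not part of the change itself.

```cpp
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Sketch: record a gauge value, a sampler value, and a labelled counter
// increment using the new monitoring C API.
void RecordExampleMetrics() {
  // Gauge and Sampler metrics are keyed by name and created on first use.
  TFE_MonitoringSetGauge("example/active_workers", "worker0", 4);
  TFE_MonitoringAddSampler("example/step_time_ms", "worker0", 12.5);

  // Counters are created explicitly and addressed through cells.
  TF_Status* status = TF_NewStatus();
  TFE_MonitoringCounter1* steps = TFE_MonitoringNewCounter1(
      "example/steps", status, "Number of steps run.", "job");
  if (TF_GetCode(status) == TF_OK) {
    TFE_MonitoringCounterCell* cell =
        TFE_MonitoringGetCellCounter1(steps, "trainer");
    TFE_MonitoringCounterCellIncrementBy(cell, 1);
  }
  TF_DeleteStatus(status);
  TFE_MonitoringDeleteCounter1(steps);
}
```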
@ -16,14 +16,16 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/cc/profiler/profiler.h"
|
||||
#include "tensorflow/core/lib/monitoring/collection_registry.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/profiler/trace_events.pb.h"
|
||||
#include "tensorflow/core/protobuf/trace_events.pb.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
@ -79,11 +81,15 @@ void ExecuteWithProfiling(bool async) {
|
||||
profiler_result->length}));
|
||||
string profile_proto_str = profile_proto.DebugString();
|
||||
if (!gpu_device_name.empty()) {
|
||||
EXPECT_TRUE(HasSubstr(profile_proto_str, "GPU:0"));
|
||||
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
|
||||
// device name with "stream:all" is collected by Device Tracer.
|
||||
EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all"));
|
||||
// TODO(fishx): move following check out from this if statement.
|
||||
// This is collected by TraceMe
|
||||
EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));
|
||||
}
|
||||
EXPECT_TRUE(HasSubstr(profile_proto_str, "CPU:0"));
|
||||
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:CPU:0"));
|
||||
EXPECT_TRUE(HasSubstr(profile_proto_str, "MatMul"));
|
||||
TF_DeleteBuffer(profiler_result);
|
||||
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
@ -125,5 +131,94 @@ TEST(CAPI, MultipleProfilerSession) {
|
||||
TFE_DeleteProfilerContext(profiler_context);
|
||||
}
|
||||
|
||||
TEST(CAPI, MonitoringSetGauge) {
|
||||
TFE_MonitoringSetGauge("test/gauge", "label", 1);
|
||||
auto* collection_registry = monitoring::CollectionRegistry::Default();
|
||||
monitoring::CollectionRegistry::CollectMetricsOptions options;
|
||||
std::unique_ptr<monitoring::CollectedMetrics> metrics =
|
||||
collection_registry->CollectMetrics(options);
|
||||
|
||||
EXPECT_EQ("test/gauge", metrics->point_set_map.at("test/gauge")->metric_name);
|
||||
EXPECT_EQ(1,
|
||||
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
|
||||
|
||||
TFE_MonitoringSetGauge("test/gauge", "label", 5);
|
||||
metrics = collection_registry->CollectMetrics(options);
|
||||
EXPECT_EQ(5,
|
||||
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
|
||||
}
|
||||
|
||||
TEST(CAPI, MonitoringCounter0) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
auto* counter =
|
||||
TFE_MonitoringNewCounter0("test/counter", status, "description");
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
auto* cell = TFE_MonitoringGetCellCounter0(counter);
|
||||
TFE_MonitoringCounterCellIncrementBy(cell, 1);
|
||||
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell), 1);
|
||||
auto* collection_registry = monitoring::CollectionRegistry::Default();
|
||||
monitoring::CollectionRegistry::CollectMetricsOptions options;
|
||||
std::unique_ptr<monitoring::CollectedMetrics> metrics =
|
||||
collection_registry->CollectMetrics(options);
|
||||
|
||||
EXPECT_EQ("test/counter",
|
||||
metrics->point_set_map.at("test/counter")->metric_name);
|
||||
EXPECT_EQ(
|
||||
1, metrics->point_set_map.at("test/counter")->points.at(0)->int64_value);
|
||||
|
||||
TFE_MonitoringCounterCellIncrementBy(cell, 5);
|
||||
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell), 6);
|
||||
metrics = collection_registry->CollectMetrics(options);
|
||||
EXPECT_EQ(
|
||||
6, metrics->point_set_map.at("test/counter")->points.at(0)->int64_value);
|
||||
|
||||
TFE_MonitoringDeleteCounter0(counter);
|
||||
metrics = collection_registry->CollectMetrics(options);
|
||||
EXPECT_EQ(metrics->point_set_map.end(),
|
||||
metrics->point_set_map.find("test/counter"));
|
||||
}
|
||||
|
||||
TEST(CAPI, MonitoringCounterMultiple) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
auto* counter1 = TFE_MonitoringNewCounter1("test/counter1", status,
|
||||
"description", "label1");
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
auto* cell1 = TFE_MonitoringGetCellCounter1(counter1, "test");
|
||||
TFE_MonitoringCounterCellIncrementBy(cell1, 1);
|
||||
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell1), 1);
|
||||
|
||||
auto* counter2 = TFE_MonitoringNewCounter2("test/counter2", status,
|
||||
"description", "label1", "label2");
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteStatus(status);
|
||||
auto* cell2 = TFE_MonitoringGetCellCounter2(counter2, "foo", "bar");
|
||||
TFE_MonitoringCounterCellIncrementBy(cell2, 2);
|
||||
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell2), 2);
|
||||
|
||||
TFE_MonitoringDeleteCounter1(counter1);
|
||||
TFE_MonitoringDeleteCounter2(counter2);
|
||||
}
|
||||
|
||||
TEST(CAPI, MonitoringAddSampler) {
|
||||
TFE_MonitoringAddSampler("test/sampler", "label", 1.0);
|
||||
auto* collection_registry = monitoring::CollectionRegistry::Default();
|
||||
monitoring::CollectionRegistry::CollectMetricsOptions options;
|
||||
std::unique_ptr<monitoring::CollectedMetrics> metrics =
|
||||
collection_registry->CollectMetrics(options);
|
||||
|
||||
EXPECT_EQ("test/sampler",
|
||||
metrics->point_set_map.at("test/sampler")->metric_name);
|
||||
EXPECT_EQ(1.0, metrics->point_set_map.at("test/sampler")
|
||||
->points.at(0)
|
||||
->histogram_value.sum());
|
||||
|
||||
TFE_MonitoringAddSampler("test/sampler", "label", 5.0);
|
||||
metrics = collection_registry->CollectMetrics(options);
|
||||
EXPECT_EQ(6.0, metrics->point_set_map.at("test/sampler")
|
||||
->points.at(0)
|
||||
->histogram_value.sum());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -15,8 +15,6 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <map>
|
||||
@ -28,6 +26,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
@ -50,6 +49,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/profiler/lib/profiler_session.h"
|
||||
@ -133,6 +133,32 @@ struct TFE_Profiler {
|
||||
std::unique_ptr<tensorflow::ProfilerSession> profiler;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringCounterCell {
|
||||
tensorflow::monitoring::CounterCell cell;
|
||||
};
|
||||
|
||||
template <int NumLabels>
|
||||
struct TFE_MonitoringCounter {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringCounter(const char* name, const char* description,
|
||||
LabelDesc&&... label) {
|
||||
counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
|
||||
name, description, label...));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
// Set an AttrValue on the op. Doesn't handle the list types.
|
||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
|
@ -21,7 +21,6 @@ from __future__ import print_function as _print_function
|
||||
import os as _os
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||
|
||||
# API IMPORTS PLACEHOLDER
|
||||
|
||||
|
@ -163,7 +163,10 @@ def tf_library(
|
||||
header_file = name + ".h"
|
||||
metadata_object_file = name + "_tfcompile_metadata.o"
|
||||
function_object_file = name + "_tfcompile_function.o"
|
||||
ep = ("__" + native.package_name() + "__" + name).replace("/", "_")
|
||||
|
||||
# The XLA backends morph a kernel name prefix "__" that is not of the form
# "__xla_".
|
||||
ep = ("__xla_" + native.package_name() + "__" + name).replace("/", "_")
|
||||
if type(tfcompile_flags) == type(""):
|
||||
flags = tfcompile_flags
|
||||
else:
|
||||
|
@ -371,6 +371,7 @@ cc_library(
|
||||
srcs = ["resource_operation_safety_analysis.cc"],
|
||||
hdrs = ["resource_operation_safety_analysis.h"],
|
||||
deps = [
|
||||
":xla_cluster_util",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/tf2xla:resource_operation_table",
|
||||
"//tensorflow/core:framework",
|
||||
@ -521,6 +522,7 @@ cc_library(
|
||||
":device_info_cache",
|
||||
":encapsulate_util",
|
||||
":flags",
|
||||
":resource_operation_safety_analysis",
|
||||
":shape_inference_helpers",
|
||||
":union_find",
|
||||
":xla_cluster_util",
|
||||
@ -565,7 +567,6 @@ cc_library(
|
||||
hdrs = ["xla_cluster_util.h"],
|
||||
deps = [
|
||||
":flags",
|
||||
":resource_operation_safety_analysis",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
@ -732,44 +733,6 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_fusion_optimizer",
|
||||
srcs = ["xla_fusion_optimizer.cc"],
|
||||
hdrs = ["xla_fusion_optimizer.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":common",
|
||||
":compilation_passes",
|
||||
":union_find",
|
||||
":xla_cluster_util",
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "xla_fusion_optimizer_test",
|
||||
srcs = ["xla_fusion_optimizer_test.cc"],
|
||||
deps = [
|
||||
":common",
|
||||
":xla_cluster_util",
|
||||
":xla_fusion_optimizer",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:resource_variable_ops",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/grappler/utils:grappler_test",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "node_matchers",
|
||||
testonly = True,
|
||||
|
@ -106,6 +106,8 @@ namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
using se::port::StatusOr;
|
||||
|
||||
// Represents a logical predicate, used as described in the algorithm overview
|
||||
// above.
|
||||
class Predicate {
|
||||
@ -698,7 +700,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
|
||||
|
||||
Status Populate();
|
||||
Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
|
||||
bool HasInputsWithMismatchingDeadness(const Node& node) override;
|
||||
StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
|
||||
Node* n, int oidx) const override;
|
||||
void Print() const override;
|
||||
absl::flat_hash_map<TensorId, string, TensorId::Hasher> PredicateMapAsString()
|
||||
const;
|
||||
@ -1113,42 +1116,13 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool DeadnessAnalysisImpl::HasInputsWithMismatchingDeadness(const Node& node) {
|
||||
CHECK(!node.IsMerge());
|
||||
|
||||
if (vlog_) {
|
||||
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() << ")";
|
||||
}
|
||||
|
||||
Predicate* pred = nullptr;
|
||||
for (const Edge* edge : node.in_edges()) {
|
||||
auto it = predicate_map_.find(InputEdgeToTensorId(edge));
|
||||
CHECK(it != predicate_map_.end());
|
||||
if (vlog_) {
|
||||
VLOG(2) << " " << InputEdgeToTensorId(edge).ToString() << ": "
|
||||
<< it->second->ToString();
|
||||
}
|
||||
|
||||
// Today we just compare the predicates for equality (with some
|
||||
// canonicalization/simplification happening before) but we could be more
|
||||
// sophisticated here if need be. Comparing pointers is sufficient because
|
||||
// we intern Predicate instances by their content.
|
||||
if (pred != nullptr && pred != it->second) {
|
||||
if (vlog_) {
|
||||
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
|
||||
<< ") -> true";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
pred = it->second;
|
||||
}
|
||||
|
||||
if (vlog_) {
|
||||
VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
|
||||
<< ") -> false";
|
||||
}
|
||||
|
||||
return false;
|
||||
StatusOr<DeadnessAnalysis::DeadnessPredicate>
|
||||
DeadnessAnalysisImpl::GetPredicateFor(Node* n, int oidx) const {
|
||||
auto it = predicate_map_.find(TensorId(n->name(), oidx));
|
||||
TF_RET_CHECK(it != predicate_map_.end())
|
||||
<< "could not find " << TensorId(n->name(), oidx).ToString()
|
||||
<< " in predicate map";
|
||||
return MakeDeadnessPredicate(it->second);
|
||||
}
|
||||
|
||||
void DeadnessAnalysisImpl::Print() const {
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
|
||||
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -43,14 +44,38 @@ namespace tensorflow {
|
||||
// "liveness" already has other connotations.
|
||||
class DeadnessAnalysis {
|
||||
public:
|
||||
// Returns true if `node` may have some live inputs and some dead inputs.
|
||||
//
|
||||
// This is a conservatively correct routine -- if it returns false then `node`
|
||||
// is guaranteed to not have inputs with mismatching liveness, but not the
|
||||
// converse.
|
||||
//
|
||||
// REQUIRES: node is not a Merge operation.
|
||||
virtual bool HasInputsWithMismatchingDeadness(const Node& node) = 0;
|
||||
// An opaque representation of a predicate. DeadnessPredicate
|
||||
// instances that compare equal via operator== represent predicates
|
||||
// that always evaluate to the same value.
|
||||
struct DeadnessPredicate {
|
||||
public:
|
||||
DeadnessPredicate(const DeadnessPredicate&) = default;
|
||||
DeadnessPredicate(DeadnessPredicate&&) = default;
|
||||
|
||||
DeadnessPredicate& operator=(const DeadnessPredicate&) = default;
|
||||
DeadnessPredicate& operator=(DeadnessPredicate&&) = default;
|
||||
|
||||
bool operator==(const DeadnessPredicate& other) const {
|
||||
return other.pred_ == pred_;
|
||||
}
|
||||
|
||||
bool operator!=(const DeadnessPredicate& other) const {
|
||||
return other.pred_ != pred_;
|
||||
}
|
||||
|
||||
private:
|
||||
explicit DeadnessPredicate(void* pred) : pred_(pred) {}
|
||||
|
||||
// This is really a Predicate*, but we don't want to expose that
|
||||
// implementation detail to our clients. `pred_` has pointer equality so we
|
||||
// can just compare the pointer in operator== and operator!=.
|
||||
void* pred_;
|
||||
|
||||
friend class DeadnessAnalysis;
|
||||
};
|
||||
|
||||
virtual se::port::StatusOr<DeadnessPredicate> GetPredicateFor(
|
||||
Node* n, int oidx) const = 0;
|
||||
|
||||
// Prints out the internal state of this instance. For debugging purposes
|
||||
// only.
|
||||
@ -61,6 +86,11 @@ class DeadnessAnalysis {
|
||||
// instance of DeadnessAnalysis in `result`.
|
||||
static Status Run(const Graph& graph,
|
||||
std::unique_ptr<DeadnessAnalysis>* result);
|
||||
|
||||
protected:
|
||||
static DeadnessPredicate MakeDeadnessPredicate(void* pred) {
|
||||
return DeadnessPredicate(pred);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
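As a brief illustration of the new predicate-based interface (mirroring the helper added to the tests below), a client can compare the opaque predicates of two tensors. The names here are hypothetical and the snippet assumes an already-populated analysis; it is a sketch, not part of the change.

```cpp
// Sketch: two output tensors can have mismatching deadness only if their
// deadness predicates differ (operator== compares interned predicates).
Status OutputsMayDifferInDeadness(const DeadnessAnalysis& analysis, Node* a,
                                  Node* b, bool* may_differ) {
  TF_ASSIGN_OR_RETURN(DeadnessAnalysis::DeadnessPredicate pred_a,
                      analysis.GetPredicateFor(a, /*oidx=*/0));
  TF_ASSIGN_OR_RETURN(DeadnessAnalysis::DeadnessPredicate pred_b,
                      analysis.GetPredicateFor(b, /*oidx=*/0));
  *may_differ = (pred_a != pred_b);
  return Status::OK();
}
```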
@ -37,6 +37,22 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
se::port::StatusOr<bool> HasInputsWithMismatchingDeadness(
|
||||
const DeadnessAnalysis& deadness_analysis, const Node& n) {
|
||||
absl::optional<DeadnessAnalysis::DeadnessPredicate> pred;
|
||||
for (const Edge* edge : n.in_edges()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
DeadnessAnalysis::DeadnessPredicate this_pred,
|
||||
deadness_analysis.GetPredicateFor(edge->src(), edge->src_output()));
|
||||
if (pred && *pred != this_pred) {
|
||||
return true;
|
||||
}
|
||||
pred = this_pred;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
using deadness_analysis_internal::ComputePredicates;
|
||||
using deadness_analysis_internal::PredicateMapTy;
|
||||
|
||||
@ -219,7 +235,10 @@ TEST(DeadnessAnalysisTest, BasicPositive) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add.node()));
|
||||
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, BasicNegative) {
|
||||
@ -232,7 +251,10 @@ TEST(DeadnessAnalysisTest, BasicNegative) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add.node()));
|
||||
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, AndIsCommutative) {
|
||||
@ -260,11 +282,27 @@ TEST(DeadnessAnalysisTest, AndIsCommutative) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));
|
||||
bool has_inputs_with_mismatching_deadness;
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *live0.node()));
|
||||
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *live1.node()));
|
||||
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *halfdead0.node()));
|
||||
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *halfdead1.node()));
|
||||
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, AndIsAssociative) {
|
||||
@ -287,7 +325,10 @@ TEST(DeadnessAnalysisTest, AndIsAssociative) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add.node()));
|
||||
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, OrIsCommutative) {
|
||||
@ -312,11 +353,27 @@ TEST(DeadnessAnalysisTest, OrIsCommutative) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));
|
||||
bool has_inputs_with_mismatching_deadness;
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *live0.node()));
|
||||
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *live1.node()));
|
||||
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *halfdead0.node()));
|
||||
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *halfdead1.node()));
|
||||
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, OrIsAssociative) {
|
||||
@ -336,7 +393,10 @@ TEST(DeadnessAnalysisTest, OrIsAssociative) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add.node()));
|
||||
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, AndOfOr) {
|
||||
@ -358,7 +418,10 @@ TEST(DeadnessAnalysisTest, AndOfOr) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add2.node()));
|
||||
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, OrOfAnd) {
|
||||
@ -382,7 +445,10 @@ TEST(DeadnessAnalysisTest, OrOfAnd) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add2.node()));
|
||||
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) {
|
||||
@ -430,7 +496,10 @@ TEST(DeadnessAnalysisTest, AndOrDistributive) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add3.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add3.node()));
|
||||
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, Ternary) {
|
||||
@ -454,7 +523,10 @@ TEST(DeadnessAnalysisTest, Ternary) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add.node()));
|
||||
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, Recv) {
|
||||
@ -469,7 +541,10 @@ TEST(DeadnessAnalysisTest, Recv) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add.node()));
|
||||
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, HostRecv) {
|
||||
@ -484,7 +559,10 @@ TEST(DeadnessAnalysisTest, HostRecv) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add.node()));
|
||||
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, Loop) {
|
||||
@ -505,8 +583,17 @@ TEST(DeadnessAnalysisTest, Loop) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node()));
|
||||
bool has_inputs_with_mismatching_deadness;
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add0.node()));
|
||||
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add1.node()));
|
||||
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
@ -544,7 +631,10 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add0.node()));
|
||||
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
@ -634,7 +724,10 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add0.node()));
|
||||
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
@ -693,7 +786,10 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add0.node()));
|
||||
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
{
|
||||
@ -792,7 +888,10 @@ TEST(DeadnessAnalysisTest, ControlInputs) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add.node()));
|
||||
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, ControlTrigger) {
|
||||
@ -819,7 +918,10 @@ TEST(DeadnessAnalysisTest, ControlTrigger) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add.node()));
|
||||
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
|
||||
@ -840,7 +942,10 @@ TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *add.node()));
|
||||
EXPECT_FALSE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, RecvVsSwitch) {
|
||||
@ -857,7 +962,10 @@ TEST(DeadnessAnalysisTest, RecvVsSwitch) {
|
||||
std::unique_ptr<DeadnessAnalysis> result;
|
||||
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
|
||||
|
||||
EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool has_inputs_with_mismatching_deadness,
|
||||
HasInputsWithMismatchingDeadness(*result, *logical_and.node()));
|
||||
EXPECT_TRUE(has_inputs_with_mismatching_deadness);
|
||||
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, RecvVsSwitchText) {
|
||||
|
@ -599,7 +599,8 @@ Status ConstructHostGraph(
|
||||
Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
|
||||
FunctionLibraryDefinition* fld,
|
||||
const string& host_graph_func_name,
|
||||
Node* xla_computation_node) {
|
||||
Node* xla_computation_node,
|
||||
Node* pivot_node) {
|
||||
// Temporarily use "0" as "device_ordinal". It will be rewritten with the
|
||||
// correct value in a later pass. We cannot just use a placeholder value here
// because FunctionDef instantiation does not allow a placeholder value for
|
||||
@ -620,7 +621,11 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
|
||||
|
||||
// Copy all nodes.
|
||||
std::map<const Node*, Node*> node_map;
|
||||
node_map[host_graph->source_node()] = main_graph->source_node();
|
||||
if (pivot_node) {
|
||||
node_map[host_graph->source_node()] = pivot_node;
|
||||
} else {
|
||||
node_map[host_graph->source_node()] = main_graph->source_node();
|
||||
}
|
||||
node_map[host_graph->sink_node()] = main_graph->sink_node();
|
||||
Status s = Status::OK();
|
||||
auto copy_node_fn = [&](const Node* n) {
|
||||
@ -673,7 +678,7 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
|
||||
// 2) Remove control edges.
|
||||
// 3) Prune nodes that are not useful for shape inference.
|
||||
Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
|
||||
Graph* host_graph,
|
||||
Graph* host_graph, Node* pivot_node,
|
||||
FunctionLibraryDefinition* fld) {
|
||||
// Use "0" as "device_ordinal". It does not matter for shape inference.
|
||||
AttrValue device_ordinal_attr;
|
||||
@ -717,41 +722,45 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
|
||||
for (Node* n : nodes) {
|
||||
g->RemoveNode(n);
|
||||
}
|
||||
|
||||
std::map<const Node*, Node*> node_map;
|
||||
node_map[host_graph->source_node()] = g->source_node();
|
||||
Status s;
|
||||
auto copy_node_fn = [&](const Node* n) {
|
||||
if (!s.ok()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (node_map.find(n) != node_map.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
NodeDef copy_def = n->def();
|
||||
Node* copy = g->AddNode(copy_def, &s);
|
||||
if (!s.ok()) {
|
||||
return;
|
||||
}
|
||||
for (auto e : n->in_edges()) {
|
||||
if (node_map.find(e->src()) == node_map.end()) {
|
||||
s = errors::Internal("Cannot find node image for ",
|
||||
e->src()->DebugString());
|
||||
return;
|
||||
}
|
||||
g->AddEdge(node_map[e->src()], e->src_output(), copy, e->dst_input());
|
||||
}
|
||||
|
||||
node_map[n] = copy;
|
||||
Node* start_node = pivot_node ? pivot_node : host_graph->source_node();
|
||||
// Reverse DFS from send_from_host_main_graph, and stop at start_node.
|
||||
struct Visit {
|
||||
Node* n;
|
||||
bool is_exiting;
|
||||
};
|
||||
// TODO(b/77601805): consolidate copy graph functions.
|
||||
ReverseDFSFrom(*host_graph,
|
||||
std::vector<const Node*>{send_from_host_main_graph},
|
||||
/*enter=*/nullptr, copy_node_fn, NodeComparatorID());
|
||||
if (!s.ok()) {
|
||||
return s;
|
||||
std::vector<Visit> stack{{send_from_host_main_graph, false}};
|
||||
std::map<Node*, Node*> node_map;
|
||||
node_map[host_graph->source_node()] = g->source_node();
|
||||
while (!stack.empty()) {
|
||||
Visit& curr = stack.back();
|
||||
if (curr.is_exiting) {
|
||||
if (node_map.find(curr.n) == node_map.end()) {
|
||||
Node* copy = g->CopyNode(curr.n);
|
||||
if (curr.n != start_node) {
|
||||
for (const Edge* e : curr.n->in_edges()) {
|
||||
auto node_iter = node_map.find(e->src());
|
||||
if (node_iter == node_map.end()) {
|
||||
return errors::Internal("Cannot find node image for ",
|
||||
e->src()->DebugString());
|
||||
}
|
||||
g->AddEdge(node_iter->second, e->src_output(), copy,
|
||||
e->dst_input());
|
||||
}
|
||||
}
|
||||
node_map[curr.n] = copy;
|
||||
}
|
||||
stack.pop_back();
|
||||
} else {
|
||||
curr.is_exiting = true;
|
||||
if (curr.n != start_node) {
|
||||
for (const Edge* e : curr.n->in_edges()) {
|
||||
if (node_map.find(e->src()) != node_map.end()) {
|
||||
continue;
|
||||
}
|
||||
stack.push_back({e->src(), false});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
send_from_host = node_map[send_from_host_main_graph];
|
||||
@ -1687,13 +1696,14 @@ Status ExtractOutsideCompilation(
|
||||
DumpGraphToFile("extract_outside_compilation_before", *g, fld);
|
||||
}
|
||||
|
||||
std::vector<string> shape_inference_graphs;
|
||||
auto node_name_index = g->BuildNodeNameIndex();
|
||||
for (auto& iter : clusters) {
|
||||
string xla_cluster_name = iter.first;
|
||||
Node* n = iter.second.node;
|
||||
auto const& func_name_attrs = iter.second.func_name_attrs;
|
||||
auto const& host_compute_core = iter.second.host_compute_core;
|
||||
|
||||
std::vector<string> shape_inference_graphs;
|
||||
bool has_outside_compilation;
|
||||
string host_graph_func_name = absl::StrCat("oc_host_graph_", n->name());
|
||||
TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
|
||||
@ -1701,14 +1711,18 @@ Status ExtractOutsideCompilation(
|
||||
func_name_attrs, func_name_attrs.name(), host_graph_func_name,
|
||||
host_compute_core, flr, fld, &shape_inference_graphs,
|
||||
&has_outside_compilation));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ExpandHostGraphIntoMainGraph(g, fld, host_graph_func_name, n));
|
||||
TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
|
||||
}
|
||||
|
||||
for (auto shape_inference_graph_name : shape_inference_graphs) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
RewriteShapeInferenceGraph(shape_inference_graph_name, g, fld));
|
||||
string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
|
||||
Node* pivot_node = node_name_index[pivot_name];
|
||||
TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(
|
||||
g, fld, host_graph_func_name, n, pivot_node));
|
||||
|
||||
TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
|
||||
|
||||
for (auto shape_inference_graph_name : shape_inference_graphs) {
|
||||
TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(shape_inference_graph_name,
|
||||
g, pivot_node, fld));
|
||||
}
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(4)) {
|
||||
|
File diff suppressed because it is too large
@ -41,7 +41,8 @@ class MarkForCompilationPass : public GraphOptimizationPass {
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
|
||||
private:
|
||||
Status RunForTest(const GraphOptimizationPassOptions& options);
|
||||
Status RunForTest(const GraphOptimizationPassOptions& options,
|
||||
bool disable_deadness_analysis);
|
||||
|
||||
friend class MarkForCompilationPassTestHelper;
|
||||
};
|
||||
|
@ -525,8 +525,8 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
|
||||
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
|
||||
}
|
||||
|
||||
TF_ASSERT_OK(
|
||||
MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
|
||||
&graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
|
||||
auto clusters = GetClusters(*graph);
|
||||
|
||||
// The computation is: C = A + relu(A)
|
||||
@ -564,8 +564,8 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
|
||||
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
|
||||
}
|
||||
|
||||
TF_ASSERT_OK(
|
||||
MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
|
||||
&graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
|
||||
auto clusters = GetClusters(*graph);
|
||||
|
||||
// The computation is: D = relu(A) + (A @ relu(A))
|
||||
@ -598,8 +598,8 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
|
||||
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
|
||||
}
|
||||
|
||||
TF_ASSERT_OK(
|
||||
MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
|
||||
&graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
|
||||
auto clusters = GetClusters(*graph);
|
||||
|
||||
// The computation is: C = A @ relu(A)
|
||||
@ -610,6 +610,77 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
|
||||
EXPECT_EQ(clusters["B"], clusters["C"]);
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, DontClusterNodesWithMismatchingDeadness) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL);
|
||||
Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL);
|
||||
|
||||
Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
|
||||
|
||||
ops::Switch switch_a(root.WithOpName("switch_a"), value, cond_a);
|
||||
ops::Switch switch_b(root.WithOpName("switch_b"), value, cond_b);
|
||||
|
||||
Output tanh_a0 = ops::Tanh(root.WithOpName("tan_a0"), switch_a.output_true);
|
||||
Output tanh_a1 = ops::Tanh(root.WithOpName("tan_a1"), tanh_a0);
|
||||
|
||||
Output tanh_b0 = ops::Tanh(root.WithOpName("tan_b0"), switch_b.output_true);
|
||||
Output tanh_b1 = ops::Tanh(root.WithOpName("tan_b1"), tanh_b0);
|
||||
|
||||
Output add = ops::Add(root.WithOpName("add"), tanh_a1, tanh_b1);
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_EXPECT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
|
||||
&graph,
|
||||
MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
|
||||
auto clusters = GetClusters(*graph);
|
||||
|
||||
EXPECT_NE(clusters["tan_a0"], "");
|
||||
EXPECT_NE(clusters["tan_a1"], "");
|
||||
EXPECT_NE(clusters["tan_b0"], "");
|
||||
EXPECT_NE(clusters["tan_b1"], "");
|
||||
|
||||
EXPECT_EQ(clusters["tan_a0"], clusters["tan_a1"]);
|
||||
EXPECT_EQ(clusters["tan_b0"], clusters["tan_b1"]);
|
||||
|
||||
EXPECT_NE(clusters["tan_a0"], clusters["tan_b0"]);
|
||||
}
|
||||
|
||||
TEST(XlaCompilationTest, ClusterNodesWithMismatchingInputDeadness) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL);
|
||||
Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL);
|
||||
|
||||
Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
|
||||
|
||||
ops::Switch switch_a(root.WithOpName("switch_a"), value, cond_a);
|
||||
ops::Switch switch_b(root.WithOpName("switch_b"), value, cond_b);
|
||||
|
||||
Output add_a = ops::Add(root.WithOpName("add_a"), switch_a.output_true,
|
||||
switch_b.output_true);
|
||||
Output add_b = ops::Add(root.WithOpName("add_b"), switch_a.output_true,
|
||||
switch_b.output_true);
|
||||
Output add = ops::Add(root.WithOpName("add_c"), add_a, add_b);
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_EXPECT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
|
||||
&graph,
|
||||
MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
|
||||
auto clusters = GetClusters(*graph);
|
||||
|
||||
EXPECT_NE(clusters["add_a"], "");
|
||||
EXPECT_NE(clusters["add_b"], "");
|
||||
EXPECT_NE(clusters["add_c"], "");
|
||||
|
||||
EXPECT_EQ(clusters["add_a"], clusters["add_b"]);
|
||||
EXPECT_EQ(clusters["add_b"], clusters["add_c"]);
|
||||
}
|
||||
|
||||
namespace {
|
||||
Node* MakeRead(const Scope& scope, const string& id,
|
||||
Node** var_handle_op = nullptr) {
|
||||
@ -703,7 +774,7 @@ TEST(XlaCompilationTest, ChainOfOps) {
|
||||
ASSERT_EQ(cluster_sets.size(), 1);
|
||||
|
||||
std::vector<string> expected_clustered_nodes_a = {
|
||||
"AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"};
|
||||
"AssignmentW1", "ConstN0", "ReadR0", "ValueToAssignW1"};
|
||||
ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
|
||||
}
|
||||
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
|
||||
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
|
||||
bool enable_global_jit) {
|
||||
MarkForCompilationPassTestHelper::Options options) {
|
||||
// Assign all unassigned nodes to the CPU device.
|
||||
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
|
||||
for (Node* n : (*graph)->nodes()) {
|
||||
@ -31,7 +31,7 @@ namespace tensorflow {
|
||||
}
|
||||
|
||||
SessionOptions session_options;
|
||||
if (enable_global_jit) {
|
||||
if (options.enable_global_jit) {
|
||||
session_options.config.mutable_graph_options()
|
||||
->mutable_optimizer_options()
|
||||
->set_global_jit_level(OptimizerOptions::ON_2);
|
||||
@ -49,13 +49,16 @@ namespace tensorflow {
|
||||
opt_options.session_options = &session_options;
|
||||
opt_options.flib_def = flib_def;
|
||||
MarkForCompilationPass pass;
|
||||
return pass.RunForTest(opt_options);
|
||||
return pass.RunForTest(
|
||||
opt_options,
|
||||
/*disable_deadness_analysis=*/options.disable_deadness_analysis);
|
||||
}
|
||||
|
||||
/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
|
||||
std::unique_ptr<Graph>* graph, bool enable_global_jit) {
|
||||
std::unique_ptr<Graph>* graph,
|
||||
MarkForCompilationPassTestHelper::Options options) {
|
||||
FunctionDefLibrary flib;
|
||||
FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
|
||||
return MarkForCompilation(graph, &flib_def, enable_global_jit);
|
||||
return MarkForCompilation(graph, &flib_def, options);
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -21,16 +21,35 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
class MarkForCompilationPassTestHelper {
|
||||
public:
|
||||
struct Options {
|
||||
bool enable_global_jit;
|
||||
bool disable_deadness_analysis;
|
||||
|
||||
Options() : enable_global_jit(true), disable_deadness_analysis(true) {}
|
||||
|
||||
Options WithNoGlobalJit() {
|
||||
Options copy = *this;
|
||||
copy.enable_global_jit = false;
|
||||
return copy;
|
||||
}
|
||||
|
||||
Options WithDeadnessAnalysis() {
|
||||
Options copy = *this;
|
||||
copy.disable_deadness_analysis = false;
|
||||
return copy;
|
||||
}
|
||||
};
|
||||
|
||||
// Runs the MarkForCompilation pass on `graph` after assigning all nodes in
|
||||
// `graph` to the CPU device. To make testing easier, ignores device
|
||||
// registration, _XlaCompile attributes and input deadness.
|
||||
// registration and _XlaCompile attributes.
|
||||
static Status MarkForCompilation(std::unique_ptr<Graph>* graph,
|
||||
FunctionLibraryDefinition* flib_def,
|
||||
bool enable_global_jit = true);
|
||||
Options options = Options());
|
||||
|
||||
// Like `MarkForCompilation` but creates `flib_def` from the op registry.
|
||||
static Status MarkForCompilation(std::unique_ptr<Graph>* graph,
|
||||
bool enable_global_jit = true);
|
||||
Options options = Options());
|
||||
};
|
||||
} // namespace tensorflow
|
||||
|
||||
|
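For readers tracing the new Options plumbing: the builder-style setters return copies, so they can be chained. A hypothetical test body might look like the sketch below; it is illustrative only.

```cpp
// Sketch: run the pass with global jit disabled and deadness analysis enabled.
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
// ... populate `graph` ...
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
    &graph, MarkForCompilationPassTestHelper::Options()
                .WithNoGlobalJit()
                .WithDeadnessAnalysis()));
```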
@ -84,6 +84,7 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
@ -93,22 +94,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
// Returns true if `n` may call a function.
|
||||
Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def,
|
||||
bool* out_result) {
|
||||
if (flib_def->Contains(n.type_string())) {
|
||||
*out_result = true;
|
||||
} else {
|
||||
*out_result =
|
||||
std::any_of(n.def().attr().begin(), n.def().attr().end(),
|
||||
[](const std::pair<string, AttrValue>& name_attr_pair) {
|
||||
return name_attr_pair.second.has_func();
|
||||
});
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Maps `n` to the XlaResourceOpKind corresponding to its operation. If `n` is
|
||||
// not a resource operation recognized by XLA then sets `out_resource_op_kind`
|
||||
// to nullopt.
|
||||
@ -134,9 +119,7 @@ Status XlaResourceOpKindForNode(
|
||||
// We conservatively assume that functions will both read and write resource
|
||||
// variables. In the future we may consider doing some form of
|
||||
// inter-procedural analysis.
|
||||
bool may_call_function;
|
||||
TF_RETURN_IF_ERROR(MayCallFunction(n, flib_def, &may_call_function));
|
||||
if (may_call_function) {
|
||||
if (MayCallFunction(n, flib_def)) {
|
||||
*out_resource_op_kind = XlaResourceOpKind::kReadWrite;
|
||||
} else {
|
||||
*out_resource_op_kind = absl::nullopt;
|
||||
|
@ -25,7 +25,6 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -227,28 +226,6 @@ void RemoveFromXlaCluster(NodeDef* node_def) {
|
||||
|
||||
void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }
|
||||
|
||||
Status AdjustCycleDetectionGraphForResourceOps(
|
||||
const Graph* graph, const FunctionLibraryDefinition* flib_def,
|
||||
const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
|
||||
GraphCycles* cycles) {
|
||||
std::vector<std::pair<int, int>> unsafe_deps;
|
||||
TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
|
||||
*graph, flib_def, resource_ops_to_ignore, &unsafe_deps));
|
||||
|
||||
// An edge {P,Q} in `unsafe_deps` denotes that P and Q, both of which are
|
||||
// operations that interact with resource variables, must not be put in the
|
||||
// same cluster. We enforce this constraint by creating a phantom node, X,
|
||||
// and adding edges P->X and X->Q. MarkForCompilation then cannot cluster P
|
||||
// and Q together since that would create a cycle with X.
|
||||
|
||||
for (std::pair<int, int> unsafe_dep : unsafe_deps) {
|
||||
int phantom_node_id = cycles->NewNode();
|
||||
CHECK(cycles->InsertEdge(unsafe_dep.first, phantom_node_id));
|
||||
CHECK(cycles->InsertEdge(phantom_node_id, unsafe_dep.second));
|
||||
}
|
||||
return Status::OK();
|
||||
}

Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
                            bool allow_mixing_unknown_and_cpu,
                            bool* out_can_pick_device,
@@ -436,4 +413,16 @@ OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph(
  return result;
}

bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def) {
  if (flib_def->Contains(n.type_string())) {
    return true;
  }

  // This is a conservative check: there may be nodes with a `func`
  // attribute that do not make function calls.
  return absl::c_any_of(n.def().attr(),
                        [](const std::pair<string, AttrValue>& name_attr_pair) {
                          return name_attr_pair.second.has_func();
                        });
}
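For intuition, the attribute scan corresponds roughly to the following Python sketch over the public `NodeDef`/`AttrValue` protos; the node and attribute names are hypothetical, and the `flib_def` lookup from the C++ version is omitted.

```python
from tensorflow.core.framework import node_def_pb2


def may_call_function(node_def):
    """Conservative check: any function-valued attribute counts as a call."""
    return any(attr.WhichOneof("value") == "func"
               for attr in node_def.attr.values())


node = node_def_pb2.NodeDef(name="loop", op="While")  # hypothetical node
node.attr["body"].func.name = "loop_body"             # function-valued attribute
print(may_call_function(node))                        # True
```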
} // namespace tensorflow
@@ -74,13 +74,6 @@ void RemoveFromXlaCluster(Node* node);
// Returns true if `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node);

// Adds edges to `cycles` to prevent clustering resource operations that cannot
// be legally clustered.
Status AdjustCycleDetectionGraphForResourceOps(
    const Graph* graph, const FunctionLibraryDefinition* flib_def,
    const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
    GraphCycles* cycles);

// Picks the device for which XLA should compile a cluster that contains
// operations placed in devices in `device_names`. For instance a cluster that
// contains operations solely placed on the CPU will be compiled into a CPU
@@ -134,6 +127,10 @@ OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph(
// Returns true if `g` is a single-GPU graph. A single-GPU graph uses exactly
// one GPU (and any number of CPUs).
bool IsSingleGpuGraph(const Graph& g);

// Returns true if it is possible (but not guaranteed) that `n` calls a
// function.
bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def);
} // namespace tensorflow

#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
@ -30,10 +30,17 @@ namespace tensorflow {
|
||||
|
||||
class XlaCpuDeviceFactory : public DeviceFactory {
|
||||
public:
|
||||
Status ListPhysicalDevices(std::vector<string>* devices) override;
|
||||
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
|
||||
std::vector<std::unique_ptr<Device>>* devices) override;
|
||||
};
|
||||
|
||||
Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
||||
devices->push_back(absl::StrCat("/physical_device:", DEVICE_XLA_CPU, ":0"));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaCpuDeviceFactory::CreateDevices(
|
||||
const SessionOptions& session_options, const string& name_prefix,
|
||||
std::vector<std::unique_ptr<Device>>* devices) {
|
||||
@ -46,7 +53,13 @@ Status XlaCpuDeviceFactory::CreateDevices(
|
||||
compile_on_demand
|
||||
? XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested
|
||||
: XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
registration.compile_all_resource_ops = true;
|
||||
registration.cluster_resource_variable_ops_unsafely = true;
|
||||
registration.cluster_stack_ops = false;
|
||||
registration.cluster_tensor_array_ops = true;
|
||||
registration.cluster_stateful_rng_ops = true;
|
||||
registration.cluster_control_trigger = true;
|
||||
registration.elide_assert_and_checknumerics = true;
|
||||
registration.cluster_variant_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_CPU, registration);
|
||||
|
||||
static XlaDeviceOpRegistrations* registrations =
|
||||
|
@@ -519,7 +519,7 @@ Status XlaDevice::RefreshStatus() {
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device) {
  // Any op assigned to the device that isn't rewritten by the graph rewriter
  // gets executed by a n XlaCompileOnDemandOp, which compiles it and executes
  // gets executed by an XlaCompileOnDemandOp, which compiles it and executes
  // it just-in-time.
  OpKernel* (*factory)(OpKernelConstruction*) =
      [](OpKernelConstruction* context) -> OpKernel* {
@ -247,6 +247,9 @@ class XlaAssignVariableOp : public OpKernel {
|
||||
data::MakeIteratorOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \
|
||||
data::AnonymousIteratorHandleOp); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("AnonymousIteratorV2").Device(DEVICE).HostMemory("deleter"), \
|
||||
data::AnonymousIteratorHandleOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \
|
||||
data::IteratorGetNextOp); \
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \
|
||||
|
@ -1,349 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <deque>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/jit/deadness_analysis.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/union_find.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Is 'node' an operator that consumes only the shape of its input, not the
|
||||
// data itself?
|
||||
static bool IsShapeConsumerOp(const Node& node) {
|
||||
return node.type_string() == "Shape" || node.type_string() == "ShapeN" ||
|
||||
node.type_string() == "Rank" || node.type_string() == "Size";
|
||||
}
|
||||
|
||||
// Returns true if the op can be decomposed into XLA ops for which
|
||||
// there are fusible elemental implementations.
|
||||
static bool IsXlaFusible(const NodeDef& node) {
|
||||
static const std::unordered_set<std::string>* elementwise_ops =
|
||||
new std::unordered_set<std::string>(
|
||||
{// tf2xla/kernels/aggregate_ops.cc
|
||||
"AddN",
|
||||
// tf2xla/kernels/binary_ops.cc
|
||||
"Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv",
|
||||
"FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift",
|
||||
"LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
|
||||
"ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference",
|
||||
"TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater",
|
||||
"GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad",
|
||||
"SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual",
|
||||
// tf2xla/kernels/unary_ops.cc
|
||||
"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
|
||||
"Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp",
|
||||
"Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal",
|
||||
"Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round",
|
||||
"Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
|
||||
"Square", "Tan", "Tanh", "Real", "Imag",
|
||||
// tf2xla/kernels/bcast_ops.cc
|
||||
"BroadcastArgs", "BroadcastGradientArgs",
|
||||
// tf2xla/kernels/bias_ops.cc
|
||||
"BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/,
|
||||
// tf2xla/kernels/cast_op.cc
|
||||
"Cast",
|
||||
// tf2xla/kernels/concat_op.cc
|
||||
"Concat", "ConcatV2", "ConcatOffset",
|
||||
// tf2xla/kernels/const_op.cc
|
||||
"Const",
|
||||
// tf2xla/kernels/elu_op.cc
|
||||
"Elu", "EluGrad", "Selu", "SeluGrad",
|
||||
// tf2xla/kernels/fill_op.cc
|
||||
"Fill",
|
||||
// tf2xla/kernels/identity_op.cc
|
||||
"Identity", "IdentityN", "PreventGradient",
|
||||
"StopGradient", /*"Snapshot",*/
|
||||
// tf2xla/kernels/index_ops.cc
|
||||
"ArgMax", "ArgMin",
|
||||
// tf2xla/kernels/mirror_pad_op.cc
|
||||
"MirrorPad",
|
||||
// tf2xla/kernels/one_hot_op.cc
|
||||
"OneHot",
|
||||
// tf2xla/kernels/pack_op.cc
|
||||
"Pack",
|
||||
// tf2xla/kernels/pad_op.cc
|
||||
"Pad", "PadV2",
|
||||
// tf2xla/kernels/relu_op.cc
|
||||
"Relu", "Relu6", "ReluGrad", "Relu6Grad",
|
||||
// tf2xla/kernels/reshape_op.cc
|
||||
"Reshape",
|
||||
// tf2xla/kernels/reverse_op.cc
|
||||
"Reverse", "ReverseV2",
|
||||
// tf2xla/kernels/reverse_sequence_op.cc
|
||||
"ReverseSequence",
|
||||
// tf2xla/kernels/shape_op.cc
|
||||
"Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze",
|
||||
"ZerosLike", "OnesLike",
|
||||
// tf2xla/kernels/slice_op.cc
|
||||
"Slice",
|
||||
// tf2xla/kernels/split_op.cc
|
||||
"Split", "SplitV",
|
||||
// tf2xla/kernels/strided_slice_op.cc
|
||||
"StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
|
||||
// tf2xla/kernels/tile_ops.cc
|
||||
"Tile",
|
||||
// tf2xla/kernels/transpose_op.cc
|
||||
"Transpose", "InvertPermutation",
|
||||
// tf2xla/kernels/unpack_op.cc
|
||||
"Unpack"});
|
||||
|
||||
return elementwise_ops->count(node.op()) > 0;
|
||||
}
|
||||
|
||||
Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
|
||||
const grappler::GrapplerItem& item,
|
||||
GraphDef* output) {
|
||||
VLOG(2) << "Here at fusion optimizer";
|
||||
|
||||
// TODO(hpucha): Implement encapsulation and replacing with XlaLaunch op.
|
||||
// Once that happens, the expected interaction between this optimizer and when
|
||||
// the global_jit_level is set is as follows: Fusion optimizer will replace
|
||||
// appropriate fusion clusters with XlaLaunch nodes. The remaining graph can
|
||||
// be further compiled where possible via mark_for_compilation_pass. Note that
|
||||
// this might lead to inefficient clustering, and it is best to use either the
|
||||
// fusion optimizer or the global_jit flag, and not combine the two.
|
||||
|
||||
// Create a Graph out of GraphDef. This is required currently because the
|
||||
// helpers around clustering, encapsulation etc work on graphs.
|
||||
FunctionLibraryDefinition function_library(OpRegistry::Global(),
|
||||
item.graph.library());
|
||||
Graph graph(function_library);
|
||||
ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
|
||||
shape_refiner.set_require_shape_inference_fns(false);
|
||||
shape_refiner.set_disable_constant_propagation(true);
|
||||
ImportGraphDefOptions options;
|
||||
// Graph optimization happens at the late stage of graph execution, when
|
||||
// colocation constraints are already validated previously and the device
|
||||
// placement of nodes has also completed, so there is no need to validate
|
||||
// colocation constraints again.
|
||||
options.validate_colocation_constraints = false;
|
||||
options.validate_shape = false;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ImportGraphDef(options, item.graph, &graph, &shape_refiner));
|
||||
|
||||
std::unique_ptr<DeadnessAnalysis> deadness;
|
||||
TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness));
|
||||
|
||||
// Collect nodes that can be fused via XLA, while ignoring those that
|
||||
// explicitly ask for XLA: (*) nodes that are marked to be compiled
|
||||
// explicitly. (*) nodes assigned to XLA device.
|
||||
OrderedNodeSet compilation_candidates;
|
||||
for (Node* node : graph.op_nodes()) {
|
||||
// If there is a _XlaCompile annotation, ignore the node if it is
|
||||
// true. Nodes are marked with this attr via experimental_jit_scope, and
|
||||
// will be handled by the mark_for_compilation pass.
|
||||
bool compile = false;
|
||||
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
|
||||
if (status.ok() && compile) {
|
||||
continue;
|
||||
}
|
||||
// If there is already a _XlaCluster annotation, ignore the node. Nodes are
|
||||
// marked with this attr to indicate they are already part of a cluster and
|
||||
// hence ignored.
|
||||
status = GetNodeAttr(node->attrs(), kXlaClusterAttr, &compile);
|
||||
if (status.ok()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If there is an explicit XLA device placement, ignore the node.
|
||||
DeviceType device_type("");
|
||||
TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type));
|
||||
if (device_type.type_string().find("XLA") != string::npos) continue;
|
||||
|
||||
// Assume all fusible ops are registered.
|
||||
// TODO(hpucha): Check for registration if possible.
|
||||
if (!IsXlaFusible(node->def())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// XLA does not offer guaranteed aliasing between the input and output of
|
||||
// the XLA cluster so it can't implement the forward-tensor-ref semantic.
|
||||
// Leave such nodes out of XLA clusters.
|
||||
if (HasForwardedRefInput(*node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If inputs to `node` can have conflicting deadness (i.e. some are alive
|
||||
// and some are dead) then don't compile it. XLA cannot represent the
|
||||
// deadness semantics of these nodes correctly and auto-clustering these
|
||||
// nodes can cause deadness to propagate to nodes that should be live.
|
||||
if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
compilation_candidates.insert(node);
|
||||
}
|
||||
|
||||
if (compilation_candidates.empty()) {
|
||||
VLOG(2) << "No compilable candidates";
|
||||
*output = item.graph;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
GraphCycles cycles;
|
||||
TF_ASSIGN_OR_RETURN(bool cycle_detection_graph_ok,
|
||||
CreateCycleDetectionGraph(&graph, &cycles));
|
||||
if (!cycle_detection_graph_ok) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps(
|
||||
&graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles));
|
||||
|
||||
// TODO(hpucha): Make clustering more robust. There are two known issues that
|
||||
// we need to mitigate: (a) Non-resource variables can cause deadlocks
|
||||
// when clustering changes order of execution. See b/77263461 for a specific
|
||||
// example. (b) Queue operations can also cause deadlocks. See b/77261498 for
|
||||
// example.
|
||||
|
||||
struct Cluster {
|
||||
// Identifies the node that represents this cluster in the cycle detection
|
||||
// graph.
|
||||
int representative = -1;
|
||||
};
|
||||
|
||||
// Each compilation candidate belongs to a cluster. The cluster's
|
||||
// representative names the node in the 'cycles' graph that represents the
|
||||
// cluster.
|
||||
std::vector<UnionFind<Cluster>> clusters(graph.num_node_ids());
|
||||
std::deque<UnionFind<Cluster>*> worklist;
|
||||
for (Node* node : compilation_candidates) {
|
||||
Cluster& cluster = clusters[node->id()].Get();
|
||||
cluster.representative = node->id();
|
||||
worklist.push_back(&clusters[node->id()]);
|
||||
}
|
||||
|
||||
// Repeatedly contract edges between clusters that are on the same device,
|
||||
// provided the contraction would not create a cycle. This is a simplified
|
||||
// version of the clustering in mark_for_compilation_pass that also deals with
|
||||
// nodes that are explicitly tagged to be compiled/clustered.
|
||||
while (!worklist.empty()) {
|
||||
int from = worklist.front()->Get().representative;
|
||||
worklist.pop_front();
|
||||
|
||||
Node* node_from = graph.FindNodeId(from);
|
||||
if (node_from->IsControlFlow()) {
|
||||
// Control flow nodes aren't compilation candidates and should never
|
||||
// appear.
|
||||
return errors::Internal(
|
||||
"Found control flow node in clustering worklist: ",
|
||||
node_from->type_string());
|
||||
}
|
||||
for (int to : cycles.Successors(from)) {
|
||||
if (to >= graph.num_node_ids()) {
|
||||
// Node is a "frame" node that is present only in the cycle detection
|
||||
// graph. No clustering is possible.
|
||||
continue;
|
||||
}
|
||||
Node* node_to = graph.FindNodeId(to);
|
||||
if (compilation_candidates.find(node_to) ==
|
||||
compilation_candidates.cend()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Do not cluster across devices.
|
||||
if (node_from->def().device() != node_to->def().device()) {
|
||||
VLOG(2) << "Devices " << node_from->def().device() << " "
|
||||
<< node_to->def().device();
|
||||
VLOG(2) << "Device names " << node_from->assigned_device_name() << " "
|
||||
<< node_to->assigned_device_name();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Ops that consume shapes cannot be the root of a cluster. This is an
|
||||
// optimization.
|
||||
if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If contracting the edge would create a cycle, bail out.
|
||||
// However, just because we can't merge the clusters now does not mean
|
||||
// we won't be able to merge them in the future.
|
||||
// e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge
|
||||
// 1->3. But if we first contract 1->2 then we can later contract 1->3.
|
||||
if (!cycles.ContractEdge(from, to)) continue;
|
||||
|
||||
// Merge the clusters. ContractEdge uses 'from' as the number of the
|
||||
// merged node, so make sure 'from' is the chosen representative.
|
||||
clusters[from].Merge(&clusters[to]);
|
||||
|
||||
worklist.push_back(&clusters[from]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Count the number of non-trivial elements in each cluster.
|
||||
std::vector<int> effective_cluster_sizes(graph.num_node_ids());
|
||||
for (const Node* n : compilation_candidates) {
|
||||
int cluster = clusters[n->id()].Get().representative;
|
||||
// Identity nodes will be removed if the node gets marked for compilation.
|
||||
// Therefore we don't want to count them towards the effective cluster size.
|
||||
if (n->def().op() != "Identity") {
|
||||
effective_cluster_sizes[cluster]++;
|
||||
}
|
||||
}
|
||||
|
||||
const int min_cluster_size = 2;
|
||||
int num_clusters = 0;
|
||||
for (auto size : effective_cluster_sizes) {
|
||||
if (size >= min_cluster_size) {
|
||||
VLOG(3) << "Cluster " << num_clusters << " " << size;
|
||||
num_clusters++;
|
||||
}
|
||||
}
|
||||
|
||||
// Names for each cluster.
|
||||
std::unordered_map<int, string> cluster_names;
|
||||
// Sequence number generator to ensure clusters have unique names.
|
||||
static std::atomic<int64> cluster_sequence_num;
|
||||
|
||||
for (Node* n : compilation_candidates) {
|
||||
int cluster = clusters[n->id()].Get().representative;
|
||||
|
||||
// Compile if this is a cluster of >= min_cluster_size compilable operators.
|
||||
if (effective_cluster_sizes[cluster] >= min_cluster_size) {
|
||||
string& name = cluster_names[cluster];
|
||||
|
||||
if (name.empty()) {
|
||||
name = absl::StrCat("cluster_", cluster_sequence_num++);
|
||||
}
|
||||
n->AddAttr(kXlaClusterAttr, name);
|
||||
VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
|
||||
}
|
||||
}
|
||||
|
||||
graph.ToGraphDef(output);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_GRAPH_OPTIMIZER_AS(XlaFusionOptimizer, "xla-fusion");
|
||||
|
||||
} // namespace tensorflow
|
@ -1,49 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Optimizes graphs by fusing ops where possible, resulting in more efficient
|
||||
// execution.
|
||||
class XlaFusionOptimizer : public grappler::CustomGraphOptimizer {
|
||||
public:
|
||||
XlaFusionOptimizer() {}
|
||||
~XlaFusionOptimizer() override {}
|
||||
|
||||
Status Init(
|
||||
const RewriterConfig_CustomGraphOptimizer* config = nullptr) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string name() const override { return "xla-fusion"; };
|
||||
|
||||
Status Optimize(grappler::Cluster* cluster,
|
||||
const grappler::GrapplerItem& item,
|
||||
GraphDef* output) override;
|
||||
|
||||
void Feedback(grappler::Cluster* cluster, const grappler::GrapplerItem& item,
|
||||
const GraphDef& optimize_output, double result) override {
|
||||
// Nothing to do for XlaFusionOptimizer.
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
|
@ -1,208 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
|
||||
#include "tensorflow/cc/ops/resource_variable_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder_util.h"
|
||||
#include "tensorflow/core/grappler/utils/grappler_test.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
REGISTER_OP("UncompilableNullary").Output("o: float");
|
||||
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
|
||||
|
||||
class XlaFusionOptimizerTest : public grappler::GrapplerTest {
|
||||
protected:
|
||||
std::unordered_map<string, string> GetClusters(const GraphDef& graph) {
|
||||
std::unordered_map<string, string> ids;
|
||||
for (const NodeDef& node : graph.node()) {
|
||||
string cluster;
|
||||
if (GetNodeAttr(AttrSlice(node), kXlaClusterAttr, &cluster).ok()) {
|
||||
CHECK(!cluster.empty());
|
||||
ids[node.name()] = cluster;
|
||||
}
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(XlaFusionOptimizerTest, Chains) {
|
||||
GraphDef graph;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a =
|
||||
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
|
||||
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
|
||||
Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
|
||||
Node* d =
|
||||
ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
|
||||
Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
|
||||
ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
|
||||
TF_ASSERT_OK(builder.ToGraphDef(&graph));
|
||||
}
|
||||
grappler::GrapplerItem item;
|
||||
item.graph = graph;
|
||||
|
||||
XlaFusionOptimizer optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
auto clusters = GetClusters(output);
|
||||
EXPECT_EQ(4, clusters.size());
|
||||
EXPECT_EQ(clusters["B"], clusters["C"]);
|
||||
EXPECT_EQ(clusters["E"], clusters["F"]);
|
||||
EXPECT_NE(clusters["B"], clusters["E"]);
|
||||
EXPECT_TRUE(clusters.find("A") == clusters.cend());
|
||||
EXPECT_TRUE(clusters.find("D") == clusters.cend());
|
||||
}
|
||||
|
||||
TEST_F(XlaFusionOptimizerTest, FusibleOps) {
|
||||
GraphDef graph;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp(
|
||||
"Placeholder",
|
||||
builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT));
|
||||
Node* b = ops::SourceOp(
|
||||
"Placeholder",
|
||||
builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT));
|
||||
|
||||
Node* c = ops::BinaryOp("Add", a, b, builder.opts().WithName("C"));
|
||||
ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D"));
|
||||
ops::UnaryOp("Abs", c, builder.opts().WithName("E"));
|
||||
|
||||
TF_ASSERT_OK(builder.ToGraphDef(&graph));
|
||||
}
|
||||
grappler::GrapplerItem item;
|
||||
item.graph = graph;
|
||||
|
||||
XlaFusionOptimizer optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
auto clusters = GetClusters(output);
|
||||
EXPECT_EQ(2, clusters.size());
|
||||
EXPECT_EQ(clusters["C"], clusters["E"]);
|
||||
EXPECT_TRUE(clusters.find("D") == clusters.cend());
|
||||
}
|
||||
|
||||
TEST_F(XlaFusionOptimizerTest, IgnoreExplicitXLAAttrs) {
|
||||
GraphDef graph;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp(
|
||||
"Placeholder",
|
||||
builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT));
|
||||
Node* b = ops::SourceOp(
|
||||
"Placeholder",
|
||||
builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT));
|
||||
|
||||
Node* c = ops::BinaryOp(
|
||||
"Add", a, b,
|
||||
builder.opts().WithName("C").WithDevice("/device:XLA_CPU"));
|
||||
ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D"));
|
||||
Node* e = ops::UnaryOp("Abs", c, builder.opts().WithName("E"));
|
||||
ops::UnaryOp("Cos", e,
|
||||
builder.opts().WithName("F").WithAttr(kXlaCompileAttr, true));
|
||||
|
||||
TF_ASSERT_OK(builder.ToGraphDef(&graph));
|
||||
}
|
||||
grappler::GrapplerItem item;
|
||||
item.graph = graph;
|
||||
|
||||
XlaFusionOptimizer optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
auto clusters = GetClusters(output);
|
||||
EXPECT_TRUE(clusters.empty());
|
||||
}
|
||||
|
||||
TEST_F(XlaFusionOptimizerTest, UncompilableCycles) {
|
||||
GraphDef graph;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("A")
|
||||
.WithAttr("dtype", DT_FLOAT)
|
||||
.WithAttr("value", Tensor()));
|
||||
Node* b =
|
||||
ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
|
||||
ops::BinaryOp("Mul", a, b, builder.opts().WithName("C"));
|
||||
|
||||
TF_ASSERT_OK(builder.ToGraphDef(&graph));
|
||||
}
|
||||
grappler::GrapplerItem item;
|
||||
item.graph = graph;
|
||||
|
||||
XlaFusionOptimizer optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
auto clusters = GetClusters(output);
|
||||
EXPECT_TRUE(clusters.empty());
|
||||
}
|
||||
|
||||
TEST_F(XlaFusionOptimizerTest, CompilableCycles) {
|
||||
GraphDef graph;
|
||||
{
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = ops::SourceOp("Const", builder.opts()
|
||||
.WithName("A")
|
||||
.WithAttr("dtype", DT_FLOAT)
|
||||
.WithAttr("value", Tensor()));
|
||||
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
|
||||
ops::BinaryOp("Mul", a, b, builder.opts().WithName("C"));
|
||||
TF_ASSERT_OK(builder.ToGraphDef(&graph));
|
||||
}
|
||||
grappler::GrapplerItem item;
|
||||
item.graph = graph;
|
||||
|
||||
XlaFusionOptimizer optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
auto clusters = GetClusters(output);
|
||||
EXPECT_EQ(3, clusters.size());
|
||||
EXPECT_EQ(clusters["A"], clusters["B"]);
|
||||
EXPECT_EQ(clusters["A"], clusters["C"]);
|
||||
}
|
||||
|
||||
TEST_F(XlaFusionOptimizerTest, ResourcesClusteringDisallowed) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
Output var_handle =
|
||||
ops::VarHandleOp(root.WithOpName("Var"), DT_FLOAT, TensorShape({}));
|
||||
Output to_assign = ops::Const(root.WithOpName("Const"), 10.0f);
|
||||
Output begin = ops::Const(root.WithOpName("begin"), 0);
|
||||
Output end = ops::Const(root.WithOpName("end"), 1);
|
||||
Output strides = ops::Const(root.WithOpName("strides"), 1);
|
||||
ops::ResourceStridedSliceAssign assign_1(
|
||||
root.WithOpName("assign_1"), var_handle, begin, end, strides, to_assign);
|
||||
ops::ResourceStridedSliceAssign assign_2(
|
||||
root.WithOpName("assign_2"), var_handle, begin, end, strides, to_assign);
|
||||
root.graph()->AddControlEdge(assign_1.operation.node(),
|
||||
assign_2.operation.node());
|
||||
grappler::GrapplerItem item;
|
||||
root.graph()->ToGraphDef(&item.graph);
|
||||
|
||||
XlaFusionOptimizer optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
auto clusters = GetClusters(output);
|
||||
EXPECT_NE(clusters["assign_1"], clusters["assign_2"]);
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -55,10 +55,32 @@ static xla::StatusOr<absl::optional<std::set<int>>> ParseVisibleDeviceList(
|
||||
|
||||
class XlaGpuDeviceFactory : public DeviceFactory {
|
||||
public:
|
||||
Status ListPhysicalDevices(std::vector<string>* devices) override;
|
||||
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
|
||||
std::vector<std::unique_ptr<Device>>* devices) override;
|
||||
};
|
||||
|
||||
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
|
||||
auto platform = se::MultiPlatformManager::PlatformWithName("CUDA");
|
||||
if (!platform.ok()) {
|
||||
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
||||
VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int device_count = platform.ValueOrDie()->VisibleDeviceCount();
|
||||
if (device_count <= 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
for (int i = 0; i < device_count; ++i) {
|
||||
devices->push_back(
|
||||
absl::StrCat("/physical_device:", DEVICE_XLA_GPU, ":", i));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaGpuDeviceFactory::CreateDevices(
|
||||
const SessionOptions& session_options, const string& name_prefix,
|
||||
std::vector<std::unique_ptr<Device>>* devices) {
|
||||
@ -66,7 +88,13 @@ Status XlaGpuDeviceFactory::CreateDevices(
|
||||
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
|
||||
registration.autoclustering_policy =
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
registration.compile_all_resource_ops = true;
|
||||
registration.cluster_resource_variable_ops_unsafely = true;
|
||||
registration.cluster_stack_ops = false;
|
||||
registration.cluster_tensor_array_ops = true;
|
||||
registration.cluster_stateful_rng_ops = true;
|
||||
registration.cluster_control_trigger = true;
|
||||
registration.elide_assert_and_checknumerics = true;
|
||||
registration.cluster_variant_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
|
||||
|
||||
static XlaDeviceOpRegistrations* registrations =
|
||||
|
@ -32,10 +32,19 @@ constexpr std::array<DataType, 10> kExecAllTypes = {
|
||||
|
||||
class XlaInterpreterDeviceFactory : public DeviceFactory {
|
||||
public:
|
||||
Status ListPhysicalDevices(std::vector<string>* devices) override;
|
||||
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
|
||||
std::vector<std::unique_ptr<Device>>* devices) override;
|
||||
};
|
||||
|
||||
Status XlaInterpreterDeviceFactory::ListPhysicalDevices(
|
||||
std::vector<string>* devices) {
|
||||
devices->push_back(
|
||||
absl::StrCat("/physical_device:", DEVICE_XLA_INTERPRETER, ":0"));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaInterpreterDeviceFactory::CreateDevices(
|
||||
const SessionOptions& session_options, const string& name_prefix,
|
||||
std::vector<std::unique_ptr<Device>>* devices) {
|
||||
@ -47,7 +56,13 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
|
||||
registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
|
||||
registration.autoclustering_policy =
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
registration.compile_all_resource_ops = true;
|
||||
registration.cluster_resource_variable_ops_unsafely = true;
|
||||
registration.cluster_stack_ops = false;
|
||||
registration.cluster_tensor_array_ops = true;
|
||||
registration.cluster_stateful_rng_ops = true;
|
||||
registration.cluster_control_trigger = true;
|
||||
registration.elide_assert_and_checknumerics = true;
|
||||
registration.cluster_variant_ops = true;
|
||||
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
|
||||
registration);
|
||||
|
||||
|
@ -347,9 +347,11 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
|
||||
<< DataTypeString(type);
|
||||
if (type == DT_RESOURCE) {
|
||||
TF_RET_CHECK(kernel->outputs[i].input_index >= 0)
|
||||
<< "Invalid input for outputs " << i;
|
||||
ctx->set_output(i, ctx->input(kernel->outputs[i].input_index));
|
||||
int input_index =
|
||||
kernel->outputs[i].input_index - missing_ctx_input_prefix;
|
||||
TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
|
||||
<< "Invalid input for outputs " << i << ": " << input_index;
|
||||
ctx->set_output(i, ctx->input(input_index));
|
||||
} else {
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
if (allocate_xla_tensors_) {
|
||||
|
@ -875,6 +875,7 @@ tf_xla_py_test(
|
||||
name = "stack_ops_test",
|
||||
size = "small",
|
||||
srcs = ["stack_ops_test.py"],
|
||||
tags = ["config-cuda-only"],
|
||||
use_xla_device = False,
|
||||
deps = [
|
||||
":xla_test",
|
||||
@ -921,6 +922,7 @@ tf_xla_py_test(
|
||||
srcs = ["tensor_array_ops_test.py"],
|
||||
# TensorArray ops are not implemented in the on-demand compilation model yet.
|
||||
disabled_backends = ["cpu_ondemand"],
|
||||
tags = ["config-cuda-only"],
|
||||
use_xla_device = False,
|
||||
deps = [
|
||||
":xla_test",
|
||||
|
@ -23,6 +23,7 @@ import itertools
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -1041,6 +1042,62 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
np.array([2], dtype=np.int64),
|
||||
expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype))
|
||||
|
||||
def testBatchMatMulBroadcast(self):
|
||||
"""Tests broadcasting behavior of BatchMatMul."""
|
||||
with compat.forward_compatibility_horizon(2019, 4, 19):
|
||||
# [2, 3] @ [1, 3, 4] -> [1, 2, 4]
|
||||
self._testBinary(
|
||||
math_ops.matmul,
|
||||
np.array([[10, 20, 30], [11, 21, 31]], dtype=np.float32),
|
||||
np.array([[[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]]],
|
||||
dtype=np.float32),
|
||||
expected=np.array([[[140, 280, 420, 560], [146, 292, 438, 584]]],
|
||||
dtype=np.float32))
|
||||
# [1, 2, 3] @ [3, 4] -> [1, 2, 4]
|
||||
self._testBinary(
|
||||
math_ops.matmul,
|
||||
np.array([[[10, 20, 30], [11, 21, 31]]], dtype=np.float32),
|
||||
np.array([[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]],
|
||||
dtype=np.float32),
|
||||
expected=np.array([[[140, 280, 420, 560], [146, 292, 438, 584]]],
|
||||
dtype=np.float32))
|
||||
# [2, 1, 3] @ [3, 1] -> [2, 1, 1]
|
||||
self._testBinary(
|
||||
math_ops.matmul,
|
||||
np.array([[[10, 20, 30]], [[11, 21, 31]]], dtype=np.float32),
|
||||
np.array([[1], [2], [3]], dtype=np.float32),
|
||||
expected=np.array([[[140]], [[146]]], dtype=np.float32))
|
||||
# [2, 1, 3] @ [1, 3] -> [2, 1, 1] (adjoint_b)
|
||||
self._testBinary(
|
||||
lambda x, y: math_ops.matmul(x, y, adjoint_b=True),
|
||||
np.array([[[10, 20, 30]], [[11, 21, 31]]], dtype=np.float32),
|
||||
np.array([[1, 2, 3]], dtype=np.float32),
|
||||
expected=np.array([[[140]], [[146]]], dtype=np.float32))
|
||||
# [2, 3, 1] @ [3, 1] -> [2, 1, 1] (adjoint_a)
|
||||
self._testBinary(
|
||||
lambda x, y: math_ops.matmul(x, y, adjoint_a=True),
|
||||
np.array([[[10], [20], [30]], [[11], [21], [31]]], dtype=np.float32),
|
||||
np.array([[1], [2], [3]], dtype=np.float32),
|
||||
expected=np.array([[[140]], [[146]]], dtype=np.float32))
|
||||
# [2, 3, 1] @ [1, 3] -> [2, 1, 1] (adjoint_a and adjoint_b)
|
||||
self._testBinary(
|
||||
lambda x, y: math_ops.matmul(x, y, adjoint_a=True, adjoint_b=True),
|
||||
np.array([[[10], [20], [30]], [[11], [21], [31]]], dtype=np.float32),
|
||||
np.array([[1, 2, 3]], dtype=np.float32),
|
||||
expected=np.array([[[140]], [[146]]], dtype=np.float32))
|
||||
      # [5, 1, 2, 3] @ [1, 7, 3, 4] -> [5, 7, 2, 4]
      self._testBinary(
          math_ops.matmul,
          np.ones([5, 1, 2, 3], dtype=np.float32),
          np.ones([1, 7, 3, 4], dtype=np.float32),
          expected=np.full([5, 7, 2, 4], 3, dtype=np.float32))
      # [4, 5, 1, 2, 3] @ [1, 1, 3, 5] -> [4, 5, 1, 2, 5]
      self._testBinary(
          math_ops.matmul,
          np.full([4, 5, 1, 2, 3], 2., dtype=np.float32),
          np.full([1, 1, 3, 5], 3., dtype=np.float32),
          expected=np.full([4, 5, 1, 2, 5], 18., dtype=np.float32))
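As a quick sanity check of the broadcasting rule these cases exercise, the hedged NumPy sketch below reproduces the last case outside of XLA: batch dimensions broadcast like ordinary elementwise ops while the trailing two dimensions follow matrix-multiplication rules.

```python
import numpy as np

a = np.full([4, 5, 1, 2, 3], 2., dtype=np.float32)
b = np.full([1, 1, 3, 5], 3., dtype=np.float32)
# Batch dims [4, 5, 1] and [1, 1] broadcast to [4, 5, 1]; [2, 3] @ [3, 5] -> [2, 5].
c = np.matmul(a, b)
print(c.shape)        # (4, 5, 1, 2, 5)
print(c[0, 0, 0, 0])  # [18. 18. 18. 18. 18.] -- three products of 2 * 3 per entry
```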
def testPad(self):
|
||||
for dtype, pad_type in itertools.product(
|
||||
self.numeric_types, [np.int32, np.int64]):
|
||||
|
@ -341,7 +341,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
|
||||
var = f()
|
||||
self.assertEqual(1.0, var.numpy())
|
||||
|
||||
def DISALBED_testResourceVariableNoInlineReadWrite(self):
|
||||
def testResourceVariableNoInlineReadWrite(self):
|
||||
with self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable(1.0)
|
||||
w = resource_variable_ops.ResourceVariable(0.0)
|
||||
@ -359,8 +359,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
|
||||
self.assertEqual(145.0, f().numpy())
|
||||
self.assertEqual(15.0, w.read_value().numpy())
|
||||
|
||||
# TODO(b/36139787)
|
||||
def DISABLED_testResourceVariableNoInlineReadOnly(self):
|
||||
def testResourceVariableNoInlineReadOnly(self):
|
||||
with self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable(10.0)
|
||||
|
||||
@ -374,8 +373,7 @@ class EagerFunctionTest(xla_test.XLATestCase):
|
||||
|
||||
self.assertEqual(50.0, f().numpy())
|
||||
|
||||
# TODO(b/36139787)
|
||||
def DISABLED_testResourceVariableNoInlineWriteOnly(self):
|
||||
def testResourceVariableNoInlineWriteOnly(self):
|
||||
with self.test_scope():
|
||||
v = resource_variable_ops.ResourceVariable(0.0)
|
||||
|
||||
|
@ -95,7 +95,12 @@ class JitLaunchTest(test.TestCase):
|
||||
# If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun
|
||||
# node actually ran. However, it is sometimes possible for XlaCompile/XlaRun
|
||||
# ops to be constant-folded away, so the check is optional.
|
||||
def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
|
||||
def _compare(self,
|
||||
fn,
|
||||
args,
|
||||
require_kernel_launch=True,
|
||||
name=None,
|
||||
noinline=None):
|
||||
with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
|
||||
placeholders = []
|
||||
feeds = {}
|
||||
@ -105,7 +110,8 @@ class JitLaunchTest(test.TestCase):
|
||||
placeholders.append(placeholder)
|
||||
feeds[placeholder] = arg
|
||||
|
||||
compiled_op = CompiledKernel(fn, *placeholders, noinline=noinline)
|
||||
compiled_op = CompiledKernel(
|
||||
fn, *placeholders, name=name, noinline=noinline)
|
||||
direct_op = fn(*placeholders)
|
||||
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
@ -155,17 +161,16 @@ class JitLaunchTest(test.TestCase):
|
||||
# to symbolically execute Bar correctly regardless of whether Bar is inlined
|
||||
# or not.
|
||||
|
||||
# TODO(b/36139787): Re-enable this test when noinline works again.
|
||||
# Tests compiled=True and noinline=True.
|
||||
# self._compare(
|
||||
# AddOnceReturnTwice, [np.array(
|
||||
# [[[0.5, -1.0]]], dtype=np.float32)],
|
||||
# noinline=True)
|
||||
self._compare(
|
||||
AddOnceReturnTwice, [np.array([[[0.5, -1.0]]], dtype=np.float32)],
|
||||
name="AddOnceReturnTwice_inline",
|
||||
noinline=True)
|
||||
|
||||
# Tests compiled=True and noinline=False.
|
||||
self._compare(
|
||||
AddOnceReturnTwice, [np.array(
|
||||
[[[0.5, -1.0]]], dtype=np.float32)],
|
||||
AddOnceReturnTwice, [np.array([[[0.5, -1.0]]], dtype=np.float32)],
|
||||
name="AddOnceReturnTwice_noinline",
|
||||
noinline=False)
|
||||
|
||||
def testOneConstOutput(self):
|
||||
|
@ -454,7 +454,7 @@ class PoolGradTest(xla_test.XLATestCase):
|
||||
"""Verifies the output values of the pooling function.
|
||||
|
||||
Args:
|
||||
pool_func: Pooling function to be called, e.g., tf.nn.max_pool
|
||||
pool_func: Pooling function to be called, e.g., tf.nn.max_pool2d
|
||||
pool_grad_func: Corresponding pooling gradient function.
|
||||
input_sizes: Input tensor dimensions.
|
||||
ksize: The kernel size dimensions
|
||||
|
@ -387,11 +387,18 @@ class TensorArrayTest(xla_test.XLATestCase):
|
||||
def fn():
|
||||
ta = tensor_array_ops.TensorArray(
|
||||
dtype=dtypes.float32, tensor_array_name="foo", size=3)
|
||||
return ta.write(-1, np.int32(7)).flow
|
||||
return ta.write(-1, constant_op.constant(7)).flow
|
||||
|
||||
# Test writing the wrong datatype.
|
||||
with self.assertRaisesOpError(
|
||||
"TensorArray dtype is float but op has dtype int32"):
|
||||
# TODO(b/129870929): Remove InvalidArgumentError/second regexp after all
|
||||
# callers provide proper init dtype.
|
||||
with self.assertRaisesRegexp(
|
||||
(ValueError, errors.InvalidArgumentError),
|
||||
r"("
|
||||
r"conversion requested dtype float32 for Tensor with dtype int32"
|
||||
r"|"
|
||||
r"TensorArray dtype is float but op has dtype int32"
|
||||
r")"):
|
||||
xla.compile(fn)[0].eval()
|
||||
|
||||
@test_util.disable_control_flow_v2("b/124334096 verify dtype")
|
||||
|
@ -125,7 +125,7 @@ class ListOpsTest(xla_test.XLATestCase):
|
||||
self.assertAllEqual(e0, 2.0)
|
||||
l, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
|
||||
self.assertAllEqual(e1, 1.0)
|
||||
self.assertAllEqual(list_ops.tensor_list_length(l), 0)
|
||||
self.assertAllEqual(list_ops.tensor_list_length(l), 2)
|
||||
|
||||
def testGetSet(self):
|
||||
with self.cached_session(), self.test_scope():
|
||||
@ -211,6 +211,18 @@ class ListOpsTest(xla_test.XLATestCase):
|
||||
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
|
||||
self.assertAllEqual(t, [0., 0., 0.])
|
||||
|
||||
def testZerosLikeForTensorList(self):
|
||||
with self.cached_session(), self.test_scope():
|
||||
l = list_ops.empty_tensor_list(
|
||||
element_dtype=dtypes.float32,
|
||||
element_shape=[],
|
||||
max_num_elements=2)
|
||||
l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
|
||||
z = array_ops.zeros_like(l)
|
||||
z = list_ops.tensor_list_stack(z, element_dtype=dtypes.float32)
|
||||
self.assertAllEqual(z.shape.as_list(), [None])
|
||||
self.assertAllEqual(z, [0.0, 0.0])
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ['TF_XLA_FLAGS'] = ('--tf_xla_min_cluster_size=2 ' +
|
||||
os.environ.get('TF_XLA_FLAGS', ''))
|
||||
|
@ -23,9 +23,13 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||
load(
|
||||
"@local_config_tensorrt//:build_defs.bzl",
|
||||
"if_tensorrt",
|
||||
"//tensorflow/core:platform/default/build_config.bzl",
|
||||
"tf_additional_all_protos",
|
||||
"tf_proto_library",
|
||||
)
|
||||
load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
|
||||
|
||||
# Google-internal targets go here (must be at the end).
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "tensorrt_test_cc",
|
||||
@ -74,13 +78,67 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "trt_engine_resource_op_kernels",
|
||||
srcs = ["kernels/trt_engine_resource_ops.cc"],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":trt_allocator",
|
||||
":trt_engine_instance_proto_cc",
|
||||
":trt_logging",
|
||||
":trt_plugins",
|
||||
":trt_resources",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:gpu_headers_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]) + tf_custom_op_library_additional_deps(),
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "trt_engine_resource_ops_test",
|
||||
size = "small",
|
||||
srcs = ["kernels/trt_engine_resource_ops_test.cc"],
|
||||
tags = [
|
||||
"no_cuda_on_cpu_tap",
|
||||
"no_windows",
|
||||
"nomac",
|
||||
],
|
||||
deps = [
|
||||
":trt_engine_instance_proto_cc",
|
||||
":trt_engine_resource_op_kernels",
|
||||
":trt_engine_resource_ops_op_lib",
|
||||
":trt_logging",
|
||||
":trt_resources",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/kernels:ops_testutil",
|
||||
"//tensorflow/core/kernels:resource_variable_ops",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_shared_object(
|
||||
name = "python/ops/libtftrt.so",
|
||||
copts = tf_copts(is_external = True),
|
||||
linkopts = ["-lm"],
|
||||
deps = [
|
||||
":trt_op_kernels",
|
||||
":trt_engine_resource_op_kernels",
|
||||
":trt_op_libs",
|
||||
":trt_engine_resource_ops_op_lib",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
] + if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
@ -112,10 +170,40 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "trt_engine_op_test",
|
||||
size = "small",
|
||||
srcs = ["kernels/trt_engine_op_test.cc"],
|
||||
tags = [
|
||||
"no_cuda_on_cpu_tap",
|
||||
"no_windows",
|
||||
"nomac",
|
||||
],
|
||||
deps = [
|
||||
":trt_op_kernels",
|
||||
":trt_op_libs",
|
||||
":trt_resources",
|
||||
"@com_google_googletest//:gtest",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/kernels:ops_testutil",
|
||||
] + if_tensorrt([
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
]),
|
||||
)
|
||||
|
||||
tf_gen_op_libs(
|
||||
op_lib_names = [
|
||||
"trt_engine_op",
|
||||
"get_serialized_resource_op",
|
||||
"trt_engine_resource_ops",
|
||||
],
|
||||
)
|
||||
|
||||
@ -142,6 +230,7 @@ tf_cuda_library(
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "trt_ops",
|
||||
deps = [
|
||||
":trt_engine_resource_ops_op_lib",
|
||||
":trt_op_libs",
|
||||
],
|
||||
)
|
||||
@ -156,7 +245,9 @@ tf_custom_op_py_library(
|
||||
]),
|
||||
kernels = [
|
||||
":trt_op_kernels",
|
||||
":trt_engine_resource_op_kernels",
|
||||
":trt_op_libs",
|
||||
":trt_engine_resource_ops_op_lib",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
@ -173,6 +264,7 @@ tf_cuda_library(
|
||||
name = "trt_resources",
|
||||
srcs = [
|
||||
"utils/trt_int8_calibrator.cc",
|
||||
"utils/trt_lru_cache.cc",
|
||||
"utils/trt_resources.cc",
|
||||
],
|
||||
hdrs = [
|
||||
@ -440,6 +532,13 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "trt_engine_instance_proto",
|
||||
srcs = ["utils/trt_engine_instance.proto"],
|
||||
cc_api_version = 2,
|
||||
protodeps = tf_additional_all_protos(),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "py_utils",
|
||||
srcs = ["utils/py_utils.cc"],
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
@ -1350,11 +1351,19 @@ Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input,
|
||||
// the dims are unknown or need to be inferred. And we don't do further checks
|
||||
// but rely on the caller to not make mistakes.
|
||||
// Otherwise we do simple check to make sure the total sizes are the same.
|
||||
if (AreDimsStaticWithDifferentSize(input_dims, dims, input.is_tensor())) {
|
||||
// If an input is a weight, it is going to become a tensor via
|
||||
// CreateConstantLayer. So we can treat it as a tensor for
|
||||
// AreDimsStaticWithDifferentSize(). This really only matters for 0-D tensors.
|
||||
if (AreDimsStaticWithDifferentSize(input_dims, dims, /*is_tensor=*/true)) {
|
||||
return errors::InvalidArgument(
|
||||
"Incompatible shapes: ", DebugString(input_dims), " vs. ",
|
||||
DebugString(dims));
|
||||
}
|
||||
// ConstantLayer requires static shapes (cannot infer -1).
|
||||
if (input.is_weights() && !HasStaticShape(dims)) {
|
||||
return errors::InvalidArgument("Shape is not fully defined: ",
|
||||
DebugString(dims));
|
||||
}
|
||||
if (validation_only) {
|
||||
*tensor = nullptr;
|
||||
return Status::OK();
|
||||
@ -1589,18 +1598,6 @@ Status AllowDataTypes(const OpConverterParams& params,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store,
                                    const TRT_ShapedWeights& weights_src) {
  TRT_ShapedWeights weights =
      store->GetTempWeights(nvinfer1::DataType::kHALF, weights_src.shape_);
  const float* src = static_cast<const float*>(weights_src.GetValues());
  Eigen::half* dst = static_cast<Eigen::half*>(weights.GetValues());
  for (int64_t i = 0; i < weights_src.count(); i++) {
    dst[i] = Eigen::half_impl::float_to_half_rtne(src[i]);
  }
  return weights;
}
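The helper removed above performed a plain element-wise float32→float16 cast of the weight values. A hedged NumPy sketch of the same idea (NumPy's cast also rounds to nearest, so it serves as an illustration only, not the TensorRT code path):

```python
import numpy as np


def convert_fp32_to_fp16(weights_src):
    """Illustrative stand-in for the removed ConvertFP32ToFP16 helper."""
    src = np.asarray(weights_src, dtype=np.float32)
    return src.astype(np.float16)  # element-wise float32 -> float16 cast


w = convert_fp32_to_fp16([1.0, 0.1, 3.14159])
print(w.dtype)  # float16
```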
|
||||
|
||||
// ****************************************************************************
|
||||
// Constant folding functions for weights.
|
||||
// TODO(laigd): we should probably use eigen directly.
|
||||
@ -1614,7 +1611,7 @@ struct LambdaFactory {
|
||||
switch (op) {
|
||||
case OP_CATEGORY::RSQRT: {
|
||||
VLOG(2) << "RSQRT GETS DONE";
|
||||
return [](T t) -> T { return 1.0 / sqrt(t); };
|
||||
return [](T t) -> T { return 1.0 / std::sqrt(t); };
|
||||
}
|
||||
case OP_CATEGORY::NEG:
|
||||
return [](T t) -> T { return -t; };
|
||||
@ -1633,7 +1630,7 @@ std::function<Eigen::half(Eigen::half)> LambdaFactory::unary<Eigen::half>() {
|
||||
case OP_CATEGORY::RSQRT: {
|
||||
VLOG(2) << "RSQRT GETS DONE";
|
||||
return [](Eigen::half t) {
|
||||
return Eigen::half(1.0 / sqrt(static_cast<float>(t)));
|
||||
return Eigen::half(1.0 / std::sqrt(static_cast<float>(t)));
|
||||
};
|
||||
}
|
||||
case OP_CATEGORY::NEG:
|
||||
@ -1772,10 +1769,6 @@ Status BinaryTensorOpWeight(OpConverterParams* params,
|
||||
params->converter->TransposeTensor(tensor, permutation, &tensor));
|
||||
}
|
||||
|
||||
if (params->converter->precision_mode() == TrtPrecisionMode::FP16) {
|
||||
weights = ConvertFP32ToFP16(params->weight_store, weights);
|
||||
}
|
||||
|
||||
// Prepare weights
|
||||
TRT_ShapedWeights shift_weights(weights.TrtDType());
|
||||
TRT_ShapedWeights scale_weights(weights.TrtDType());
|
||||
@ -1937,9 +1930,6 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
|
||||
// num_groups will be 1.
|
||||
const int num_groups = (group == 0) ? tensor_dim.d[0] : group;
|
||||
|
||||
if (params->converter->precision_mode() == TrtPrecisionMode::FP16) {
|
||||
weights_rsck = ConvertFP32ToFP16(params->weight_store, weights_rsck);
|
||||
}
|
||||
// For conv, TF weights are RSCK, and TRT expects KCRS.
|
||||
// For backprop, TF weights are RSKC, and TRT expects CKRS.
|
||||
// Therefore, this reorder will work for both cases.
|
||||
@ -3038,9 +3028,6 @@ Status ConvertBiasAdd(OpConverterParams* params) {
|
||||
}
|
||||
|
||||
TRT_ShapedWeights weights = inputs.at(1).weights();
|
||||
if (params->converter->precision_mode() == TrtPrecisionMode::FP16) {
|
||||
weights = ConvertFP32ToFP16(params->weight_store, weights);
|
||||
}
|
||||
nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL;
|
||||
if (weights.shape_.d[0] == 1) {
|
||||
mode = nvinfer1::ScaleMode::kUNIFORM;
|
||||
@ -4237,6 +4224,95 @@ Status ConvertTopK(OpConverterParams* params) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertDepthSpaceShuffle(OpConverterParams* params) {
  const auto& inputs = params->inputs;
  const auto& node_def = params->node_def;
  TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
  TF_RETURN_IF_ERROR(AllowDataTypes(
      *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32}));
  TFAttrs attrs(node_def);
  const int block_size = attrs.get<int64>("block_size");
  if (block_size < 2) {
    return errors::InvalidArgument("Block size must be 2 or greater, at ",
                                   node_def.name());
  }
  const string data_format = attrs.get<string>("data_format");
  if (data_format != "NCHW" && data_format != "NHWC") {
    return errors::Unimplemented("Data format ", data_format,
                                 " is not supported, at ", node_def.name());
  }
  nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
  if (dims.nbDims != 3) {
    return errors::InvalidArgument("The input to ", node_def.op(),
                                   " must be rank 4, at ", node_def.name());
  }
  const int num_channels = data_format == "NCHW" ? dims.d[0] : dims.d[2];
  const int h = data_format == "NCHW" ? dims.d[1] : dims.d[0];
  const int w = data_format == "NCHW" ? dims.d[2] : dims.d[1];
  // Get shuffle parameters.
  nvinfer1::Dims first_shuffle_shape;
  nvinfer1::Permutation transpose_perm;
  nvinfer1::Dims second_shuffle_shape;
  if (node_def.op() == "DepthToSpace") {
    if (num_channels % (block_size * block_size) != 0) {
      return errors::InvalidArgument(
          "Number of channels must be divisible by block_size*block_size, at ",
          node_def.name());
    }
    // First Reshape [C, H, W] -> [r, r, C/(r*r), H, W]
    first_shuffle_shape = {
        /*nbDims=*/5,
        /*d=*/{block_size, block_size, num_channels / (block_size * block_size),
               h, w}};
    // Transpose [r, r, C/(r*r), H, W] -> [C/(r*r), H, r, W, r]
    transpose_perm = {2, 3, 0, 4, 1};
    // Second Reshape [C/(r*r), H, r, W, r] -> [C/(r*r), H * r, W * r]
    second_shuffle_shape =
        nvinfer1::DimsCHW(num_channels / (block_size * block_size),
                          h * block_size, w * block_size);
  } else if (node_def.op() == "SpaceToDepth") {
    if (h % block_size != 0 || w % block_size != 0) {
      return errors::InvalidArgument(
          "Width and height must be divisible by block_size, at ",
          node_def.name());
    }
    // First Reshape [C, H, W] -> [C, H/r, r, W/r, r]
    first_shuffle_shape = {/*nbDims=*/5,
                           /*d=*/{num_channels, h / block_size, block_size,
                                  w / block_size, block_size}};
    // Transpose [C, H/r, r, W/r, r] -> [r, r, C, H/r, W/r]
    transpose_perm = {2, 4, 0, 1, 3};
    // Second Reshape [r, r, C, H/r, W/r] -> [C*r*r, H/r, W/r]
    second_shuffle_shape = nvinfer1::DimsCHW(
        num_channels * block_size * block_size, h / block_size, w / block_size);
  }
  if (params->validation_only) return Status::OK();

  nvinfer1::IShuffleLayer* first_shuffle =
      params->converter->network()->addShuffle(*inputs.at(0).tensor());
  TFTRT_RETURN_ERROR_IF_NULLPTR(first_shuffle, node_def.name());
  if (data_format == "NHWC") {
    first_shuffle->setFirstTranspose({2, 0, 1});
  }
  first_shuffle->setReshapeDimensions(first_shuffle_shape);
  first_shuffle->setSecondTranspose(transpose_perm);

  nvinfer1::IShuffleLayer* second_shuffle =
      params->converter->network()->addShuffle(*first_shuffle->getOutput(0));
  TFTRT_RETURN_ERROR_IF_NULLPTR(second_shuffle, node_def.name());
  second_shuffle->setReshapeDimensions(second_shuffle_shape);
  if (data_format == "NHWC") {
    second_shuffle->setSecondTranspose({1, 2, 0});
  }

  params->converter->MarkQuantizationRangesAsInferrable(
      inputs.at(0).tensor(), first_shuffle->getOutput(0));
  params->converter->MarkQuantizationRangesAsInferrable(
      first_shuffle->getOutput(0), second_shuffle->getOutput(0));
  params->outputs->push_back(TRT_TensorOrWeights(second_shuffle->getOutput(0)));
  return Status::OK();
}
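For reference, the two shuffle layers above implement the same permutation as a plain reshape, transpose, reshape on the CHW volume. The following standalone sketch (plain C++ with an illustrative helper name and a tiny 4x2x2 tensor, no TensorRT) applies that index math directly for the NCHW DepthToSpace case and reproduces the first expected vector used in the new tests further down in this patch.

```cpp
#include <cassert>
#include <iostream>
#include <vector>

// DepthToSpace on a [C, H, W] buffer with block size r:
// output[c', h*r + i, w*r + j] = input[(i*r + j) * (C/r^2) + c', h, w].
std::vector<int> DepthToSpaceCHW(const std::vector<int>& in, int c, int h,
                                 int w, int r) {
  const int c_out = c / (r * r);
  std::vector<int> out(in.size());
  for (int co = 0; co < c_out; ++co) {
    for (int ho = 0; ho < h * r; ++ho) {
      for (int wo = 0; wo < w * r; ++wo) {
        const int i = ho % r, j = wo % r;         // position inside a block
        const int ci = (i * r + j) * c_out + co;  // source channel
        out[(co * h * r + ho) * w * r + wo] =
            in[(ci * h + ho / r) * w + wo / r];
      }
    }
  }
  return out;
}

int main() {
  std::vector<int> in(16);
  for (int i = 0; i < 16; ++i) in[i] = i;  // same values as InitTestVector(16)
  const std::vector<int> expected = {0, 4, 1, 5, 8, 12, 9, 13,
                                     2, 6, 3, 7, 10, 14, 11, 15};
  assert(DepthToSpaceCHW(in, /*c=*/4, /*h=*/2, /*w=*/2, /*r=*/2) == expected);
  std::cout << "DepthToSpace index math matches the NCHW test vector\n";
}
```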

#if IS_TRT_VERSION_GE(5, 1, 0, 0)
Status ConvertCombinedNMS(OpConverterParams* params) {
  TF_RETURN_IF_ERROR(
@@ -4416,6 +4492,7 @@ static void RegisterValidatableOpConverters(
  (*registration)["Const"] = ConvertConst;
  (*registration)["Conv2D"] = ConvertConv2D;
  (*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput;
  (*registration)["DepthToSpace"] = ConvertDepthSpaceShuffle;
  (*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
  (*registration)["ExpandDims"] = ConvertExpandDims;
  (*registration)["GatherV2"] = ConvertGather;
@@ -4430,6 +4507,7 @@ static void RegisterValidatableOpConverters(
  (*registration)["Slice"] = ConvertSlice;
  (*registration)["Snapshot"] = ConvertIdentity;  // Snapshot should be removed
  (*registration)["Softmax"] = ConvertSoftmax;
  (*registration)["SpaceToDepth"] = ConvertDepthSpaceShuffle;
  (*registration)["Split"] = ConvertSplit;
  (*registration)["Square"] = ConvertSquare;
  (*registration)["Squeeze"] = ConvertSqueeze;
@@ -212,6 +212,19 @@ std::vector<CType> InitTestVector(int size, CType start_value = CType(0)) {
  return res;
}

template <typename InCType, typename OutCType>
struct StaticCaster {
  OutCType operator()(InCType in) const { return static_cast<OutCType>(in); }
};

template <typename InCType, typename OutCType>
std::vector<OutCType> CastTestVector(const std::vector<InCType>& vals) {
  std::vector<OutCType> res(vals.size());
  std::transform(vals.begin(), vals.end(), res.begin(),
                 StaticCaster<InCType, OutCType>());
  return res;
}
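A quick standalone illustration of the helper above (the helper is copied so the snippet compiles on its own): CastTestVector converts a vector of one element type into another via static_cast, which the new tests in this patch use to build half and int32 expected outputs from int literals.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

template <typename InCType, typename OutCType>
struct StaticCaster {
  OutCType operator()(InCType in) const { return static_cast<OutCType>(in); }
};

template <typename InCType, typename OutCType>
std::vector<OutCType> CastTestVector(const std::vector<InCType>& vals) {
  std::vector<OutCType> res(vals.size());
  std::transform(vals.begin(), vals.end(), res.begin(),
                 StaticCaster<InCType, OutCType>());
  return res;
}

int main() {
  // {1, 2, 3} as ints becomes {1.0f, 2.0f, 3.0f} as floats.
  std::vector<float> casted = CastTestVector<int, float>({1, 2, 3});
  assert((casted == std::vector<float>{1.0f, 2.0f, 3.0f}));
}
```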
// Fake ITensor implementation for testing purposes.
|
||||
class FakeITensor : public nvinfer1::ITensor {
|
||||
public:
|
||||
@ -721,19 +734,25 @@ TEST_F(ConverterTest, TransposeTensor) {
|
||||
ExpectTrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions());
|
||||
}
|
||||
|
||||
void TestPrepareTensorForShape_Tensor(
|
||||
const std::vector<int>& tensor_dims, const std::vector<int>& reshape_dims,
|
||||
const std::vector<int>& expected_tensor_dims, Converter* converter,
|
||||
void TestPrepareTensorForShape(
|
||||
const std::vector<int>& input_dims, const std::vector<int>& reshape_dims,
|
||||
const std::vector<int>& expected_tensor_dims, bool input_is_tensor,
|
||||
Converter* converter, TrtWeightStore* weight_store,
|
||||
error::Code expected_code = error::OK,
|
||||
const char* expected_error_msg_substr = nullptr) {
|
||||
nvinfer1::ITensor* input_tensor = converter->network()->addInput(
|
||||
"", nvinfer1::DataType::kFLOAT, GetTestDims(tensor_dims));
|
||||
TRT_TensorOrWeights input;
|
||||
if (input_is_tensor) {
|
||||
input = TRT_TensorOrWeights(converter->network()->addInput(
|
||||
"", nvinfer1::DataType::kFLOAT, GetTestDims(input_dims)));
|
||||
} else {
|
||||
input = TRT_TensorOrWeights(weight_store->GetTempWeights(
|
||||
nvinfer1::DataType::kFLOAT, GetTestDims(input_dims)));
|
||||
}
|
||||
nvinfer1::ITensor* output_tensor = nullptr;
|
||||
|
||||
for (bool validation_only : {false, true}) {
|
||||
const Status status = converter->PrepareTensorForShape(
|
||||
TRT_TensorOrWeights(input_tensor), GetTestDims(reshape_dims),
|
||||
validation_only, &output_tensor);
|
||||
input, GetTestDims(reshape_dims), validation_only, &output_tensor);
|
||||
if (expected_code == error::OK) {
|
||||
TF_EXPECT_OK(status);
|
||||
if (validation_only) {
|
||||
@ -748,49 +767,45 @@ void TestPrepareTensorForShape_Tensor(
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ConverterTest, PrepareTensorForShape_Tensor) {
|
||||
// Shape size doesn't match.
|
||||
TEST_F(ConverterTest, PrepareTensorForShape) {
|
||||
for (bool input_is_tensor : {true, false}) {
|
||||
// Shape size doesn't match.
|
||||
Reset();
|
||||
TestPrepareTensorForShape({2, 3, 5}, {2, 3, 6}, {}, input_is_tensor,
|
||||
converter_.get(), weight_store_,
|
||||
error::INVALID_ARGUMENT, "Incompatible shapes");
|
||||
|
||||
// Regular shape.
|
||||
Reset();
|
||||
TestPrepareTensorForShape({2, 3, 5}, {10, 3}, {10, 3}, input_is_tensor,
|
||||
converter_.get(), weight_store_);
|
||||
|
||||
// Reshape to zero rank.
|
||||
Reset();
|
||||
TestPrepareTensorForShape({1, 1}, {}, {}, input_is_tensor, converter_.get(),
|
||||
weight_store_);
|
||||
}
|
||||
|
||||
// Tensor input with zero rank.
|
||||
Reset();
|
||||
TestPrepareTensorForShape_Tensor({2, 3, 5}, {2, 3, 6}, {}, converter_.get(),
|
||||
error::INVALID_ARGUMENT,
|
||||
"Incompatible shapes");
|
||||
TestPrepareTensorForShape({}, {1, 1}, {1, 1}, /*input_is_tensor=*/true,
|
||||
converter_.get(), weight_store_);
|
||||
|
||||
// TODO(aaroey): we should check the case where uninferred dimensions are
|
||||
  // not an exact divisor of input dimensions, e.g. for dims {-1, 7}.
|
||||
|
||||
// Infer shape, ok.
|
||||
// Infer tensor shape, ok.
|
||||
Reset();
|
||||
TestPrepareTensorForShape_Tensor({2, 3, 5}, {-1, 2}, {15, 2},
|
||||
converter_.get());
|
||||
TestPrepareTensorForShape({2, 3, 5}, {-1, 2}, {15, 2},
|
||||
/*input_is_tensor=*/true, converter_.get(),
|
||||
weight_store_);
|
||||
|
||||
// Regular shape.
|
||||
// Infer weight shape, should fail.
|
||||
Reset();
|
||||
TestPrepareTensorForShape_Tensor({2, 3, 5}, {10, 3}, {10, 3},
|
||||
converter_.get());
|
||||
|
||||
// Input with zero rank.
|
||||
Reset();
|
||||
TestPrepareTensorForShape_Tensor({}, {1, 1}, {1, 1}, converter_.get());
|
||||
|
||||
// Reshape to zero rank.
|
||||
Reset();
|
||||
TestPrepareTensorForShape_Tensor({1, 1}, {}, {}, converter_.get());
|
||||
}
|
||||
|
||||
TEST_F(ConverterTest, PrepareTensorForShape_Weights) {
|
||||
TRT_ShapedWeights weights = weight_store_->GetTempWeights(
|
||||
nvinfer1::DataType::kFLOAT, GetTestDims({2, 3, 5}));
|
||||
nvinfer1::ITensor* output_tensor = nullptr;
|
||||
for (bool validation_only : {false, true}) {
|
||||
TF_EXPECT_OK(converter_->PrepareTensorForShape(
|
||||
TRT_TensorOrWeights(weights), GetTestDims({10, 3}), validation_only,
|
||||
&output_tensor));
|
||||
if (validation_only) {
|
||||
EXPECT_EQ(nullptr, output_tensor);
|
||||
} else {
|
||||
ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions());
|
||||
}
|
||||
}
|
||||
TestPrepareTensorForShape({2, 3, 5}, {-1, 2}, {15, 2},
|
||||
/*input_is_tensor=*/false, converter_.get(),
|
||||
weight_store_, error::INVALID_ARGUMENT,
|
||||
"Shape is not fully defined");
|
||||
}
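The -1 cases above rely on the usual reshape rule: a single unknown dimension is inferred as the total element count divided by the product of the known dimensions, which only works when that division is exact. A minimal standalone sketch of the arithmetic (illustrative helper name, not the converter's actual routine):

```cpp
#include <cassert>
#include <vector>

// Returns the reshape dims with a single -1 entry replaced by the inferred
// value, or an empty vector if the sizes are incompatible.
std::vector<int> InferReshape(int total_size, std::vector<int> dims) {
  int known = 1, unknown_idx = -1;
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    if (dims[i] == -1) unknown_idx = i;
    else known *= dims[i];
  }
  if (unknown_idx >= 0) {
    if (total_size % known != 0) return {};
    dims[unknown_idx] = total_size / known;
  } else if (known != total_size) {
    return {};
  }
  return dims;
}

int main() {
  // Mirrors the {2,3,5} -> {-1,2} == {15,2} case in the test above.
  std::vector<int> inferred = InferReshape(2 * 3 * 5, {-1, 2});
  assert((inferred == std::vector<int>{15, 2}));
  // {2,3,5} -> {2,3,6} is incompatible ("Incompatible shapes" in the test).
  assert(InferReshape(2 * 3 * 5, {2, 3, 6}).empty());
}
```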
|
||||
|
||||
TEST_F(ConverterTest, MaybeUpdateBatchSize) {
|
||||
@ -4910,6 +4925,279 @@ TEST_F(OpConverterTest, ConvertArgMinMax) {
|
||||
// TestConvertArgMinMax<ops::ArgMax, DT_INT32>(this);
|
||||
}
|
||||
|
||||
// Get the NodeDef for DepthToSpace or SpaceToDepth.
|
||||
template <typename OpType>
|
||||
NodeDef GetDepthSpaceShuffleNodeDef(DataType dtype, int block_size,
|
||||
string data_format) {
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto input = ops::Placeholder(s.WithOpName("input"), dtype);
|
||||
auto attrs = OpType::DataFormat(data_format);
|
||||
auto shuffle = OpType(s.WithOpName("my_shuffle"), input, block_size, attrs);
|
||||
return shuffle.operation.node()->def();
|
||||
}
|
||||
|
||||
template <typename CType>
|
||||
struct DepthSpaceShuffleTestParams {
|
||||
std::vector<int> input_dims;
|
||||
std::vector<CType> input_value;
|
||||
int block_size;
|
||||
string data_format;
|
||||
std::vector<int> expected_output_dims;
|
||||
std::vector<CType> expected_output;
|
||||
};
|
||||
|
||||
template <typename OpType, DataType dtype, typename CType>
|
||||
void TestConvertDepthSpaceShuffle(
|
||||
OpConverterTest* test,
|
||||
const std::vector<DepthSpaceShuffleTestParams<CType>>& params) {
|
||||
for (int i = 0; i < params.size(); ++i) {
|
||||
test->Reset();
|
||||
|
||||
NodeDef node_def = GetDepthSpaceShuffleNodeDef<OpType>(
|
||||
dtype, params[i].block_size, params[i].data_format);
|
||||
test->AddTestTensor("input", params[i].input_dims, 1,
|
||||
TfDataTypeToTrt(dtype));
|
||||
test->RunValidationAndConversion(node_def);
|
||||
|
||||
TRT_TensorOrWeights output;
|
||||
TF_EXPECT_OK(test->GetTensorOrWeights("my_shuffle", &output));
|
||||
EXPECT_TRUE(output.is_tensor());
|
||||
ExpectTrtDimsEqualsArray(params[i].expected_output_dims,
|
||||
output.tensor()->getDimensions());
|
||||
|
||||
DataVec input_data{{"input", test::AsTensor<CType>(params[i].input_value)}};
|
||||
DataVec output_data{{"my_shuffle", ConstructTensor<CType>(
|
||||
params[i].expected_output.size())}};
|
||||
test->BuildAndRun(
|
||||
input_data, &output_data,
|
||||
dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32);
|
||||
EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
|
||||
ElementsAreArray(params[i].expected_output));
|
||||
}
|
||||
}
|
||||
|
||||
template <DataType dtype>
|
||||
void TestConvertDepthToSpace(OpConverterTest* test) {
|
||||
typedef typename EnumToDataType<dtype>::Type CType;
|
||||
const std::vector<CType> common_input = InitTestVector<CType>(16);
|
||||
std::vector<DepthSpaceShuffleTestParams<CType>> params = {
|
||||
{
|
||||
/*input_shape=*/{4, 2, 2},
|
||||
/*input_value=*/common_input,
|
||||
/*block_size=*/2,
|
||||
/*data_format=*/"NCHW",
|
||||
/*expected_output_dims=*/{1, 4, 4},
|
||||
/*expected_output=*/
|
||||
CastTestVector<int, CType>(
|
||||
{0, 4, 1, 5, 8, 12, 9, 13, 2, 6, 3, 7, 10, 14, 11, 15}),
|
||||
},
|
||||
{
|
||||
/*input_shape=*/{2, 2, 4},
|
||||
/*input_value=*/common_input,
|
||||
/*block_size=*/2,
|
||||
/*data_format=*/"NHWC",
|
||||
/*expected_output_dims=*/{4, 4, 1},
|
||||
/*expected_output=*/
|
||||
CastTestVector<int, CType>(
|
||||
{0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15}),
|
||||
},
|
||||
{
|
||||
/*input_shape=*/{16, 1, 1},
|
||||
/*input_value=*/common_input,
|
||||
/*block_size=*/4,
|
||||
/*data_format=*/"NCHW",
|
||||
/*expected_output_dims=*/{1, 4, 4},
|
||||
/*expected_output=*/InitTestVector<CType>(16),
|
||||
},
|
||||
{
|
||||
/*input_shape=*/{2, 2, 8},
|
||||
/*input_value=*/InitTestVector<CType>(32),
|
||||
/*block_size=*/2,
|
||||
/*data_format=*/"NHWC",
|
||||
/*expected_output_dims=*/{4, 4, 2},
|
||||
/*expected_output=*/CastTestVector<int, CType>({0, 1, 2, 3, 8,
|
||||
9, 10, 11, 4, 5,
|
||||
6, 7, 12, 13, 14,
|
||||
15, 16, 17, 18, 19,
|
||||
24, 25, 26, 27, 20,
|
||||
21, 22, 23, 28, 29,
|
||||
30, 31}),
|
||||
},
|
||||
};
|
||||
|
||||
TestConvertDepthSpaceShuffle<ops::DepthToSpace, dtype, CType>(test, params);
|
||||
}
|
||||
|
||||
TEST_F(OpConverterTest, ConvertDepthToSpace) {
|
||||
{
|
||||
// Input list is empty, should fail.
|
||||
NodeDef node_def = MakeNodeDef("my_shuffle", "DepthToSpace", {});
|
||||
RunValidationAndConversion(
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"DepthToSpace got 0 inputs but expected 1, at my_shuffle");
|
||||
}
|
||||
{
|
||||
// Input is a weight, should fail.
|
||||
Reset();
|
||||
NodeDef node_def =
|
||||
GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(DT_FLOAT, 2, "NCHW");
|
||||
AddTestWeights<float>("input", {4, 1, 1}, {1, 2, 3, 4});
|
||||
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
|
||||
"The input \"input\" for DepthToSpace must be a "
|
||||
"tensor, at my_shuffle");
|
||||
}
|
||||
{
|
||||
// Input rank != 4
|
||||
Reset();
|
||||
NodeDef node_def =
|
||||
GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(DT_FLOAT, 2, "NCHW");
|
||||
AddTestTensor("input", {16, 32});
|
||||
RunValidationAndConversion(
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"The input to DepthToSpace must be rank 4, at my_shuffle");
|
||||
}
|
||||
{
|
||||
// Channels not divisible by block_size, should fail.
|
||||
Reset();
|
||||
NodeDef node_def =
|
||||
GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(DT_FLOAT, 3, "NCHW");
|
||||
AddTestTensor("input", {16, 32, 32});
|
||||
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
|
||||
"Number of channels must be divisible by "
|
||||
"block_size*block_size, at my_shuffle");
|
||||
}
|
||||
{
|
||||
// Unsupported format, should fail.
|
||||
Reset();
|
||||
NodeDef node_def = GetDepthSpaceShuffleNodeDef<ops::DepthToSpace>(
|
||||
DT_FLOAT, 2, "NCHW_VECT_C");
|
||||
AddTestTensor("input", {16, 32, 32});
|
||||
RunValidationAndConversion(
|
||||
node_def, error::UNIMPLEMENTED,
|
||||
"Data format NCHW_VECT_C is not supported, at my_shuffle");
|
||||
}
|
||||
|
||||
TestConvertDepthToSpace<DT_FLOAT>(this);
|
||||
TestConvertDepthToSpace<DT_HALF>(this);
|
||||
TestConvertDepthToSpace<DT_INT32>(this);
|
||||
}
|
||||
|
||||
template <DataType dtype>
|
||||
void TestConvertSpaceToDepth(OpConverterTest* test) {
|
||||
typedef typename EnumToDataType<dtype>::Type CType;
|
||||
const std::vector<CType> common_input = InitTestVector<CType>(16);
|
||||
std::vector<DepthSpaceShuffleTestParams<CType>> params = {
|
||||
{
|
||||
/*input_shape=*/{1, 4, 4},
|
||||
/*input_value=*/common_input,
|
||||
/*block_size=*/2,
|
||||
/*data_format=*/"NCHW",
|
||||
/*expected_output_dims=*/{4, 2, 2},
|
||||
/*expected_output=*/
|
||||
CastTestVector<int, CType>(
|
||||
{0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15}),
|
||||
},
|
||||
{
|
||||
/*input_shape=*/{4, 4, 1},
|
||||
/*input_value=*/common_input,
|
||||
/*block_size=*/2,
|
||||
/*data_format=*/"NHWC",
|
||||
/*expected_output_dims=*/{2, 2, 4},
|
||||
/*expected_output=*/
|
||||
CastTestVector<int, CType>(
|
||||
{0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15}),
|
||||
},
|
||||
{
|
||||
/*input_shape=*/{1, 4, 4},
|
||||
/*input_value=*/common_input,
|
||||
/*block_size=*/4,
|
||||
/*data_format=*/"NCHW",
|
||||
/*expected_output_dims=*/{16, 1, 1},
|
||||
/*expected_output=*/InitTestVector<CType>(16),
|
||||
},
|
||||
{
|
||||
/*input_shape=*/{4, 4, 2},
|
||||
/*input_value=*/InitTestVector<CType>(32),
|
||||
/*block_size=*/2,
|
||||
/*data_format=*/"NHWC",
|
||||
/*expected_output_dims=*/{2, 2, 8},
|
||||
/*expected_output=*/CastTestVector<int, CType>({0, 1, 2, 3, 8,
|
||||
9, 10, 11, 4, 5,
|
||||
6, 7, 12, 13, 14,
|
||||
15, 16, 17, 18, 19,
|
||||
24, 25, 26, 27, 20,
|
||||
21, 22, 23, 28, 29,
|
||||
30, 31}),
|
||||
},
|
||||
};
|
||||
|
||||
TestConvertDepthSpaceShuffle<ops::SpaceToDepth, dtype, CType>(test, params);
|
||||
}
|
||||
|
||||
TEST_F(OpConverterTest, ConvertSpaceToDepth) {
|
||||
{
|
||||
// Input list is empty, should fail.
|
||||
NodeDef node_def = MakeNodeDef("my_shuffle", "SpaceToDepth", {});
|
||||
RunValidationAndConversion(
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"SpaceToDepth got 0 inputs but expected 1, at my_shuffle");
|
||||
}
|
||||
{
|
||||
// Input is a weight, should fail.
|
||||
Reset();
|
||||
NodeDef node_def =
|
||||
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 2, "NCHW");
|
||||
AddTestWeights<float>("input", {4, 1, 1}, {1, 2, 3, 4});
|
||||
RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
|
||||
"The input \"input\" for SpaceToDepth must be a "
|
||||
"tensor, at my_shuffle");
|
||||
}
|
||||
{
|
||||
// Input rank != 4
|
||||
Reset();
|
||||
NodeDef node_def =
|
||||
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 2, "NCHW");
|
||||
AddTestTensor("input", {16, 32});
|
||||
RunValidationAndConversion(
|
||||
node_def, error::INVALID_ARGUMENT,
|
||||
"The input to SpaceToDepth must be rank 4, at my_shuffle");
|
||||
}
|
||||
{
|
||||
    // Width not divisible by block_size, should fail.
|
||||
Reset();
|
||||
NodeDef node_def =
|
||||
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 3, "NCHW");
|
||||
AddTestTensor("input", {16, 9, 32});
|
||||
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
|
||||
"Width and height must be divisible by "
|
||||
"block_size, at my_shuffle");
|
||||
}
|
||||
{
|
||||
    // Height not divisible by block_size, should fail.
|
||||
Reset();
|
||||
NodeDef node_def =
|
||||
GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(DT_FLOAT, 3, "NCHW");
|
||||
AddTestTensor("input", {16, 32, 9});
|
||||
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
|
||||
"Width and height must be divisible by "
|
||||
"block_size, at my_shuffle");
|
||||
}
|
||||
{
|
||||
// Unsupported format, should fail.
|
||||
Reset();
|
||||
NodeDef node_def = GetDepthSpaceShuffleNodeDef<ops::SpaceToDepth>(
|
||||
DT_FLOAT, 2, "NCHW_VECT_C");
|
||||
AddTestTensor("input", {16, 32, 32});
|
||||
RunValidationAndConversion(
|
||||
node_def, error::UNIMPLEMENTED,
|
||||
"Data format NCHW_VECT_C is not supported, at my_shuffle");
|
||||
}
|
||||
|
||||
TestConvertSpaceToDepth<DT_FLOAT>(this);
|
||||
TestConvertSpaceToDepth<DT_HALF>(this);
|
||||
TestConvertSpaceToDepth<DT_INT32>(this);
|
||||
}
|
||||
|
||||
} // namespace convert
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
|
||||
@ -290,17 +291,17 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
|
||||
VLOG(1) << "Executing TRT calibration: " << name();
|
||||
helper->Ref();
|
||||
core::ScopedUnref sc(helper);
|
||||
auto res_mgr = ctx->resource_manager();
|
||||
TRTCalibrationResource* calib_res = nullptr;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
res_mgr->LookupOrCreate(
|
||||
"TF_TRT_Calibration", name(),
|
||||
ctx->resource_manager()->LookupOrCreate(
|
||||
"TF-TRT-Calibration", name(),
|
||||
reinterpret_cast<SerializableResourceBase**>(&calib_res),
|
||||
{[ctx, this](SerializableResourceBase** cr) -> Status {
|
||||
return this->AllocateCalibrationResources(ctx, cr);
|
||||
}}));
|
||||
core::ScopedUnref calib_sc(calib_res);
|
||||
int num_inputs = ctx->num_inputs();
|
||||
// TODO(laigd): need to check that input shape matches.
|
||||
// Pass input data to calibrator
|
||||
std::unordered_map<string, void*> input_data;
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
@ -425,8 +426,9 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
|
||||
const_cast<float*>(input_tensor.flat<float>().data());
|
||||
break;
|
||||
case nvinfer1::DataType::kHALF:
|
||||
LOG(ERROR) << "FP16 inputs are not supported yet!";
|
||||
return kRetry;
|
||||
buffers[binding_index] =
|
||||
const_cast<Eigen::half*>(input_tensor.flat<Eigen::half>().data());
|
||||
break;
|
||||
case nvinfer1::DataType::kINT8:
|
||||
LOG(ERROR) << "INT8 inputs are not supported yet!";
|
||||
return kRetry;
|
||||
@ -480,8 +482,9 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
|
||||
const_cast<float*>(output_tensor->flat<float>().data());
|
||||
break;
|
||||
case nvinfer1::DataType::kHALF:
|
||||
LOG(WARNING) << "half size is not supported yet!";
|
||||
return kRetry;
|
||||
buffers[binding_index] =
|
||||
const_cast<Eigen::half*>(output_tensor->flat<Eigen::half>().data());
|
||||
break;
|
||||
case nvinfer1::DataType::kINT8:
|
||||
LOG(WARNING) << "int8 is not supported yet!";
|
||||
return kRetry;
|
||||
@ -522,10 +525,22 @@ EngineContext* TRTEngineOp::GetEngine(
|
||||
// TODO(tmorris): using first input to get batch size - is this reliable?
|
||||
const int batch_size = input_shapes[0].dim_size(0);
|
||||
|
||||
// Get engine cache
|
||||
  // Canonicalize the op name by removing the scopes if any. This is mainly
  // because in TF 2.x the function graph can be instantiated in various ways,
  // and scope names may be prepended to the TRTEngineOp names. Using the
  // instantiated op name directly would then produce many different engine
  // caches, whereas we want all instantiations that represent the same
  // subgraph to share the same cache.
|
||||
absl::string_view resource_name = name();
|
||||
size_t last_slash = resource_name.find_last_of('/');
|
||||
if (last_slash != absl::string_view::npos) {
|
||||
resource_name.remove_prefix(last_slash + 1);
|
||||
}
|
||||
|
||||
// Get engine cache.
|
||||
TRTEngineCacheResource* cache_res = nullptr;
|
||||
auto status = ctx->resource_manager()->LookupOrCreate(
|
||||
"TRTEngineCache", name(), &cache_res,
|
||||
"TF-TRT-Engine-Cache", string(resource_name), &cache_res,
|
||||
{[this, ctx](TRTEngineCacheResource** cr) -> Status {
|
||||
*cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
|
||||
return Status::OK();
|
||||
@ -632,12 +647,13 @@ EngineContext* TRTEngineOp::GetEngine(
|
||||
cache.emplace(engine_input_shapes, absl::make_unique<EngineContext>());
|
||||
return &empty_context;
|
||||
}
|
||||
VLOG(1) << "Conversion is done";
|
||||
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
|
||||
engine->createExecutionContext());
|
||||
cache.emplace(engine_input_shapes,
|
||||
absl::make_unique<EngineContext>(std::move(engine),
|
||||
std::move(exec_context)));
|
||||
VLOG(1) << "Added new engine to cache of " << name()
|
||||
<< ". Cache size: " << cache.size();
|
||||
}
|
||||
return cache.at(engine_input_shapes).get();
|
||||
}
|
||||
|
tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc (new file, 106 lines)
@@ -0,0 +1,106 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <dirent.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
#include "cuda/include/cuda_runtime_api.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
using ::testing::ElementsAre;
|
||||
|
||||
template <typename T>
|
||||
class TRTEngineOpTest : public OpsTestBase {};
|
||||
|
||||
using TypeList = ::testing::Types<float, Eigen::half>;
|
||||
TYPED_TEST_SUITE(TRTEngineOpTest, TypeList);
|
||||
|
||||
TYPED_TEST(TRTEngineOpTest, Basic) {
|
||||
DataType dtype = DataTypeToEnum<TypeParam>::v();
|
||||
// Create the GPU device.
|
||||
std::unique_ptr<Device> device(
|
||||
DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));
|
||||
|
||||
// Create simple TF graph.
|
||||
Scope s = Scope::NewRootScope();
|
||||
auto feed = ops::Placeholder(s.WithOpName("TensorRTInputPH_0"), dtype,
|
||||
ops::Placeholder::Shape({1, 2}));
|
||||
auto add = ops::Add(s.WithOpName("add"), feed, feed);
|
||||
ops::Identity(s.WithOpName("TensorRTOutputPH_0"), add);
|
||||
|
||||
// Serialize the graph. TRTEngineOp will convert it using dynamic mode.
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(s.ToGraphDef(&graph_def));
|
||||
TensorShapeProto shape;
|
||||
TensorShape({1, 2}).AsProto(&shape);
|
||||
|
||||
// Create the op.
|
||||
OpsTestBase::SetDevice(DEVICE_GPU, std::move(device));
|
||||
TF_ASSERT_OK(NodeDefBuilder("op", "TRTEngineOp")
|
||||
.Input(FakeInput(1, dtype))
|
||||
.Attr("input_shapes", {shape})
|
||||
.Attr("output_shapes", {shape})
|
||||
.Attr("static_engine", false)
|
||||
.Attr("segment_funcdef_name", "") // no native fallback
|
||||
.Attr("serialized_segment", graph_def.SerializeAsString())
|
||||
.Attr("calibration_data", "")
|
||||
.Attr("max_cached_engines_count", 1)
|
||||
.Attr("workspace_size_bytes", 1 << 20)
|
||||
.Attr("precision_mode", "FP32")
|
||||
.Attr("use_calibration", false)
|
||||
.Attr("OutT", {dtype})
|
||||
.Finalize(OpsTestBase::node_def()));
|
||||
TF_ASSERT_OK(OpsTestBase::InitOp());
|
||||
|
||||
// Execute the op.
|
||||
OpsTestBase::AddInputFromArray<TypeParam>(TensorShape({1, 2}),
|
||||
{TypeParam(0.0f), TypeParam(1.0f)});
|
||||
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
|
||||
|
||||
// Verify the result.
|
||||
// TODO(laigd): OpsTestBase::GetOutput() doesn't work.
|
||||
Tensor* output = OpsTestBase::context_->mutable_output(0);
|
||||
const auto& tensor_map = output->flat<TypeParam>();
|
||||
std::vector<TypeParam> output_data(tensor_map.size());
|
||||
ASSERT_EQ(0, cudaDeviceSynchronize());
|
||||
ASSERT_EQ(0, cudaMemcpy(output_data.data(), tensor_map.data(),
|
||||
sizeof(TypeParam) * tensor_map.size(),
|
||||
cudaMemcpyDeviceToHost));
|
||||
EXPECT_THAT(absl::Span<const TypeParam>(output_data),
|
||||
ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
|
||||
}
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_TENSORRT
|
||||
#endif // GOOGLE_CUDA
|
@ -0,0 +1,223 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h" // NOLINT
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/io/record_reader.h"
|
||||
#include "tensorflow/core/lib/io/record_writer.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
#include "tensorrt/include/NvInfer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
using ::nvinfer1::IRuntime;
|
||||
|
||||
class CreateTRTEngineCache : public OpKernel {
|
||||
public:
|
||||
explicit CreateTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("resource_name", &resource_name_));
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->GetAttr("max_cached_engines_count", &max_cached_engines_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
VLOG(1) << "Creating TRT engine cache resource in container " << container_
|
||||
<< " for op " << resource_name_ << " on device "
|
||||
<< ctx->device()->name();
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->resource_manager()->Create(
|
||||
container_, resource_name_,
|
||||
new TRTEngineCacheResource(ctx, max_cached_engines_)));
|
||||
|
||||
Tensor* handle;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
|
||||
handle->scalar<ResourceHandle>()() =
|
||||
MakeResourceHandle<TRTEngineCacheResource>(ctx, container_,
|
||||
resource_name_);
|
||||
}
|
||||
|
||||
private:
|
||||
string container_;
|
||||
string resource_name_;
|
||||
|
||||
// Maximum number of cached engines
|
||||
int max_cached_engines_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTEngineCache);
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("CreateTRTEngineCache")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("engine_cache_handle"),
|
||||
CreateTRTEngineCache);
|
||||
|
||||
class PopulateTRTEngineCache : public OpKernel {
|
||||
public:
|
||||
explicit PopulateTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
ResourceHandle handle = HandleFromInput(ctx, 0);
|
||||
TRTEngineCacheResource* resource = nullptr;
|
||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, handle, &resource));
|
||||
core::ScopedUnref unref_me(resource);
|
||||
|
||||
auto allocator = resource->allocator_.get();
|
||||
OP_REQUIRES(ctx, allocator != nullptr,
|
||||
errors::Internal("Not able to initialize TRT engine cache when "
|
||||
"GPU allocator is empty."));
|
||||
OP_REQUIRES(ctx, resource->cache_.size() == 0,
|
||||
errors::Internal("Expect engine cache to be empty, but got ",
|
||||
resource->cache_.size(), " entries."));
|
||||
|
||||
// Get the file name.
|
||||
const string& filename = ctx->input(1).scalar<string>()();
|
||||
OP_REQUIRES(ctx, !filename.empty(),
|
||||
errors::InvalidArgument("filename cannot be empty."));
|
||||
|
||||
// Parse the serialized engines and add them to the cache.
|
||||
std::unique_ptr<RandomAccessFile> file;
|
||||
OP_REQUIRES_OK(ctx, ctx->env()->NewRandomAccessFile(filename, &file));
|
||||
auto reader = absl::make_unique<io::RecordReader>(file.get());
|
||||
|
||||
uint64 offset = 0;
|
||||
int num_loaded_engine = 0;
|
||||
do {
|
||||
string record;
|
||||
Status status = reader->ReadRecord(&offset, &record);
|
||||
if (errors::IsOutOfRange(status)) break;
|
||||
|
||||
TRTEngineInstance engine_instance;
|
||||
engine_instance.ParseFromString(record);
|
||||
std::vector<TensorShape> engine_input_shapes;
|
||||
for (const TensorShapeProto& shape : engine_instance.input_shapes()) {
|
||||
engine_input_shapes.emplace_back(shape);
|
||||
}
|
||||
|
||||
TrtUniquePtrType<IRuntime> infer(
|
||||
nvinfer1::createInferRuntime(TRTEngineCacheResource::GetLogger()));
|
||||
infer->setGpuAllocator(allocator);
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine> engine(
|
||||
infer->deserializeCudaEngine(
|
||||
engine_instance.serialized_engine().c_str(),
|
||||
engine_instance.serialized_engine().size(),
|
||||
PluginFactoryTensorRT::GetInstance()));
|
||||
auto raw_engine = engine.get();
|
||||
resource->cache_.emplace(
|
||||
engine_input_shapes,
|
||||
absl::make_unique<EngineContext>(
|
||||
std::move(engine), TrtUniquePtrType<nvinfer1::IExecutionContext>(
|
||||
raw_engine->createExecutionContext())));
|
||||
++num_loaded_engine;
|
||||
} while (1);
|
||||
VLOG(1) << "Loaded " << num_loaded_engine << " TRT engines to container "
|
||||
<< handle.container() << " for op " << handle.name()
|
||||
<< " on device " << ctx->device()->name() << " from file "
|
||||
<< filename;
|
||||
}
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(PopulateTRTEngineCache);
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("PopulateTRTEngineCache")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("engine_cache_handle"),
|
||||
PopulateTRTEngineCache);
|
||||
|
||||
class DumpTRTEngineCache : public OpKernel {
|
||||
public:
|
||||
explicit DumpTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_cache_after_dump",
|
||||
&delete_cache_after_dump_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const string& container = ctx->input(0).scalar<string>()();
|
||||
const string& resource_name = ctx->input(1).scalar<string>()();
|
||||
const string& filename = ctx->input(2).scalar<string>()();
|
||||
OP_REQUIRES(ctx, !filename.empty(),
|
||||
errors::InvalidArgument("filename cannot be empty."));
|
||||
|
||||
TRTEngineCacheResource* resource = nullptr;
|
||||
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(
|
||||
container, resource_name, &resource));
|
||||
core::ScopedUnref unref_me(resource);
|
||||
|
||||
// Serialize the engines and write them to file.
|
||||
std::unique_ptr<WritableFile> file;
|
||||
OP_REQUIRES_OK(ctx, ctx->env()->NewWritableFile(filename, &file));
|
||||
auto writer = absl::make_unique<io::RecordWriter>(file.get());
|
||||
|
||||
for (const auto& pair : resource->cache_) {
|
||||
TRTEngineInstance engine_instance;
|
||||
// Add input shapes.
|
||||
const std::vector<TensorShape>& engine_input_shapes = pair.first;
|
||||
for (const TensorShape& shape : engine_input_shapes) {
|
||||
shape.AsProto(engine_instance.add_input_shapes());
|
||||
}
|
||||
// Add the serialized engine.
|
||||
const std::unique_ptr<EngineContext>& engine = pair.second;
|
||||
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(
|
||||
engine->cuda_engine->serialize());
|
||||
engine_instance.set_serialized_engine(engine_data->data(),
|
||||
engine_data->size());
|
||||
|
||||
OP_REQUIRES_OK(ctx,
|
||||
writer->WriteRecord(engine_instance.SerializeAsString()));
|
||||
}
|
||||
VLOG(1) << "Serialized " << resource->cache_.size()
|
||||
<< " TRT engines in container " << container << " for op "
|
||||
<< resource_name << " on device " << ctx->device()->name()
|
||||
<< " to file " << filename;
|
||||
|
||||
if (delete_cache_after_dump_) {
|
||||
VLOG(1) << "Destroying TRT engine cache resource in container "
|
||||
<< container << " for op " << resource_name << " on device "
|
||||
<< ctx->device()->name();
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->resource_manager()->Delete<TRTEngineCacheResource>(
|
||||
container, resource_name));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
bool delete_cache_after_dump_ = false;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(DumpTRTEngineCache);
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("DumpTRTEngineCache").Device(DEVICE_GPU),
|
||||
DumpTRTEngineCache);
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_TENSORRT
|
||||
#endif // GOOGLE_CUDA
|
@ -0,0 +1,205 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <dirent.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_instance.pb.h" // NOLINT
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/io/record_reader.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
class TRTEngineResourceOpsTest : public OpsTestBase {
|
||||
protected:
|
||||
void Reset() {
|
||||
inputs_.clear();
|
||||
gtl::STLDeleteElements(&tensors_);
|
||||
gtl::STLDeleteElements(&managed_outputs_);
|
||||
}
|
||||
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine> CreateTRTEngine() {
|
||||
Logger logger;
|
||||
TrtUniquePtrType<nvinfer1::IBuilder> builder(
|
||||
nvinfer1::createInferBuilder(logger));
|
||||
TrtUniquePtrType<nvinfer1::INetworkDefinition> network(
|
||||
builder->createNetwork());
|
||||
|
||||
// Add the input.
|
||||
nvinfer1::Dims dims;
|
||||
dims.nbDims = 1;
|
||||
dims.d[0] = 1;
|
||||
nvinfer1::ITensor* input =
|
||||
network->addInput("input", nvinfer1::DataType::kFLOAT, dims);
|
||||
EXPECT_NE(nullptr, input);
|
||||
|
||||
// Add a unary layer.
|
||||
nvinfer1::IUnaryLayer* layer =
|
||||
network->addUnary(*input, nvinfer1::UnaryOperation::kEXP);
|
||||
EXPECT_NE(nullptr, layer);
|
||||
|
||||
// Mark the output.
|
||||
nvinfer1::ITensor* output = layer->getOutput(0);
|
||||
output->setName("output");
|
||||
network->markOutput(*output);
|
||||
|
||||
// Build the engine
|
||||
builder->setMaxBatchSize(1);
|
||||
builder->setMaxWorkspaceSize(1 << 10);
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine> engine(
|
||||
builder->buildCudaEngine(*network));
|
||||
EXPECT_NE(nullptr, engine);
|
||||
return engine;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TRTEngineResourceOpsTest, Basic) {
|
||||
// Create the GPU device.
|
||||
std::unique_ptr<Device> device(
|
||||
DeviceFactory::NewDevice("GPU", {}, "/job:worker/replica:0/task:0"));
|
||||
ResourceMgr* rm = device->resource_manager();
|
||||
SetDevice(DEVICE_GPU, std::move(device));
|
||||
|
||||
// Create the resource.
|
||||
const string container = "mycontainer";
|
||||
const string resource_name = "myresource";
|
||||
Reset();
|
||||
TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTEngineCache")
|
||||
.Attr("container", container)
|
||||
.Attr("resource_name", resource_name)
|
||||
.Attr("max_cached_engines_count", 1)
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
ResourceHandle handle =
|
||||
context_->mutable_output(0)->scalar<ResourceHandle>()();
|
||||
|
||||
TRTEngineCacheResource* resource = nullptr;
|
||||
EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok());
|
||||
|
||||
// Create a serialized TRT engine file.
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine> engine = CreateTRTEngine();
|
||||
TrtUniquePtrType<nvinfer1::IExecutionContext> context(
|
||||
engine->createExecutionContext());
|
||||
resource->cache_.emplace(
|
||||
std::vector<TensorShape>{TensorShape({1, 1})},
|
||||
absl::make_unique<EngineContext>(std::move(engine), std::move(context)));
|
||||
resource->Unref();
|
||||
|
||||
// Serialize the engine using DumpTRTEngineCache op.
|
||||
Reset();
|
||||
TF_ASSERT_OK(NodeDefBuilder("op", "DumpTRTEngineCache")
|
||||
.Attr("delete_cache_after_dump", true)
|
||||
.Input(FakeInput(DT_STRING))
|
||||
.Input(FakeInput(DT_STRING))
|
||||
.Input(FakeInput(DT_STRING))
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
AddInputFromArray<string>(TensorShape({}), {container});
|
||||
AddInputFromArray<string>(TensorShape({}), {resource_name});
|
||||
const string filename = io::JoinPath(testing::TmpDir(), "trt_engine_file");
|
||||
AddInputFromArray<string>(TensorShape({}), {filename});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Make sure the cache is deleted.
|
||||
Reset();
|
||||
TF_ASSERT_OK(NodeDefBuilder("op", "DestroyResourceOp")
|
||||
.Attr("ignore_lookup_error", false)
|
||||
.Input(FakeInput(DT_RESOURCE))
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
AddInputFromArray<ResourceHandle>(TensorShape({}), {handle});
|
||||
EXPECT_TRUE(errors::IsNotFound(RunOpKernel()));
|
||||
|
||||
// Verify the serialized engine file.
|
||||
Env* env = Env::Default();
|
||||
std::unique_ptr<RandomAccessFile> file;
|
||||
TF_ASSERT_OK(env->NewRandomAccessFile(filename, &file));
|
||||
auto reader = absl::make_unique<io::RecordReader>(file.get());
|
||||
uint64 offset = 0;
|
||||
string record;
|
||||
TF_ASSERT_OK(reader->ReadRecord(&offset, &record));
|
||||
TRTEngineInstance engine_instance;
|
||||
engine_instance.ParseFromString(record);
|
||||
EXPECT_EQ(1, engine_instance.input_shapes_size());
|
||||
EXPECT_EQ(2, engine_instance.input_shapes(0).dim_size());
|
||||
EXPECT_EQ(1, engine_instance.input_shapes(0).dim(0).size());
|
||||
EXPECT_EQ(1, engine_instance.input_shapes(0).dim(1).size());
|
||||
EXPECT_TRUE(errors::IsOutOfRange(reader->ReadRecord(&offset, &record)));
|
||||
|
||||
// Recreate the cache resource.
|
||||
Reset();
|
||||
TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTEngineCache")
|
||||
.Attr("container", container)
|
||||
.Attr("resource_name", resource_name)
|
||||
.Attr("max_cached_engines_count", 1)
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
handle = context_->mutable_output(0)->scalar<ResourceHandle>()();
|
||||
EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok());
|
||||
EXPECT_EQ(0, resource->cache_.size());
|
||||
resource->Unref();
|
||||
|
||||
// Deserialize the engine using PopulateTRTEngineCache op.
|
||||
Reset();
|
||||
TF_ASSERT_OK(NodeDefBuilder("op", "PopulateTRTEngineCache")
|
||||
.Input(FakeInput(DT_RESOURCE))
|
||||
.Input(FakeInput(DT_STRING))
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
AddInputFromArray<ResourceHandle>(TensorShape({}), {handle});
|
||||
AddInputFromArray<string>(TensorShape({}), {filename});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok());
|
||||
EXPECT_EQ(1, resource->cache_.size());
|
||||
resource->Unref();
|
||||
|
||||
// Destroy the engine cache again.
|
||||
Reset();
|
||||
TF_ASSERT_OK(NodeDefBuilder("op", "DestroyResourceOp")
|
||||
.Attr("ignore_lookup_error", false)
|
||||
.Input(FakeInput(DT_RESOURCE))
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
AddInputFromArray<ResourceHandle>(TensorShape({}), {handle});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
EXPECT_TRUE(errors::IsNotFound(RunOpKernel()));
|
||||
}
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_TENSORRT
|
||||
#endif // GOOGLE_CUDA
|
@@ -0,0 +1,52 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {

REGISTER_OP("CreateTRTEngineCache")
    .Attr("container: string")
    .Attr("resource_name: string")
    .Attr("max_cached_engines_count: int = 1")
    .Output("engine_cache_handle: resource")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("PopulateTRTEngineCache")
    .Input("engine_cache_handle: resource")
    .Input("filename: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs);

REGISTER_OP("DumpTRTEngineCache")
    .Attr("delete_cache_after_dump: bool = false")
    .Input("container: string")
    .Input("resource_name: string")
    .Input("filename: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs);

}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA
|
@ -459,7 +459,7 @@ Status SegmentGraph(const Graph* tf_graph,
|
||||
}
|
||||
LOG(INFO) << msg << "(For more information see "
|
||||
<< "https://docs.nvidia.com/deeplearning"
|
||||
<< "/dgx/integrate-tf-trt/index.html#support-ops).";
|
||||
<< "/dgx/tf-trt-user-guide/index.html#supported-ops).";
|
||||
|
||||
// The segmentation algorithm below visits nodes in reverse topological order
|
||||
// and attempts to merge nodes along output edges. That means that subgraphs
|
||||
|
@@ -0,0 +1,19 @@
syntax = "proto3";

package tensorflow.tensorrt;

import "tensorflow/core/framework/tensor_shape.proto";

// Contains information about a serialized TensorRT engine.
message TRTEngineInstance {
  // The input shapes of the TRT engine.
  repeated TensorShapeProto input_shapes = 1;

  // The serialized TRT engine.
  //
  // TODO(laigd): consider using a more efficient in-memory representation
  // instead of string which is the default here.
  bytes serialized_engine = 2;

  // TODO(laigd): consider adding calibration stats, precision_modes, etc.
}
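The dump/populate kernels above stream one serialized TRTEngineInstance per record. As a rough, self-contained sketch of that round trip, here is a simple length-prefixed record file written and read back with the standard library; the real code uses io::RecordWriter/io::RecordReader plus protobuf serialization, so this only illustrates the layout idea, and the path and blob strings are made up.

```cpp
#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// Write each blob as [uint32 length][bytes]; a stand-in for one serialized
// engine record per cache entry.
void WriteRecords(const std::string& path,
                  const std::vector<std::string>& blobs) {
  std::ofstream out(path, std::ios::binary);
  for (const std::string& blob : blobs) {
    const uint32_t len = static_cast<uint32_t>(blob.size());
    out.write(reinterpret_cast<const char*>(&len), sizeof(len));
    out.write(blob.data(), blob.size());
  }
}

std::vector<std::string> ReadRecords(const std::string& path) {
  std::ifstream in(path, std::ios::binary);
  std::vector<std::string> blobs;
  uint32_t len = 0;
  while (in.read(reinterpret_cast<char*>(&len), sizeof(len))) {  // until EOF
    std::string blob(len, '\0');
    in.read(&blob[0], len);
    blobs.push_back(blob);
  }
  return blobs;
}

int main() {
  WriteRecords("/tmp/engine_records",
               {"serialized-engine-0", "serialized-engine-1"});
  for (const std::string& blob : ReadRecords("/tmp/engine_records")) {
    std::cout << blob.size() << " bytes\n";
  }
}
```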
tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc (new file, 79 lines)
@@ -0,0 +1,79 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
|
||||
#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_TENSORRT
|
||||
#include "tensorrt/include/NvInfer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
Logger& TRTEngineCacheResource::GetLogger() {
|
||||
static Logger* logger = new Logger();
|
||||
return *logger;
|
||||
}
|
||||
|
||||
TRTEngineCacheResource::TRTEngineCacheResource(OpKernelContext* ctx,
|
||||
size_t capacity)
|
||||
: cache_(capacity) {
|
||||
auto device = ctx->device();
|
||||
auto alloc = device->GetAllocator(AllocatorAttributes());
|
||||
if (!alloc) {
|
||||
LOG(ERROR) << "Can't find device allocator for gpu device "
|
||||
<< device->name();
|
||||
allocator_ = nullptr;
|
||||
} else {
|
||||
allocator_.reset(new TRTDeviceAllocator(alloc));
|
||||
}
|
||||
}
|
||||
|
||||
TRTEngineCacheResource::~TRTEngineCacheResource() {
|
||||
VLOG(1) << "Destroying TRTEngineCacheResource...";
|
||||
}
|
||||
|
||||
string TRTEngineCacheResource::DebugString() const {
|
||||
std::stringstream oss;
|
||||
using std::dec;
|
||||
using std::endl;
|
||||
using std::hex;
|
||||
oss << "TRTEngineCacheResource: ";
|
||||
oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", ";
|
||||
oss << "LRUCache = " << hex << &cache_ << dec << endl;
|
||||
oss << "Containing " << cache_.size() << " entries: " << endl;
|
||||
for (const auto& item : cache_) {
|
||||
mutex_lock lock(item.second->mu);
|
||||
oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex
|
||||
<< "ICudaEngine: " << item.second->cuda_engine.get() << ", "
|
||||
<< "IExecutionContext: " << item.second->execution_context.get() << dec
|
||||
<< endl;
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_TENSORRT
|
||||
#endif // GOOGLE_CUDA
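TRTEngineCacheResource above wraps an LRU cache keyed by the engine's input shapes (see the cache_.emplace(engine_input_shapes, ...) calls earlier in this patch). As a generic illustration of that data structure, not TF's LRUCache template, here is a minimal least-recently-used map built from std::list plus std::unordered_map, keyed by a shape string for simplicity:

```cpp
#include <cassert>
#include <cstddef>
#include <list>
#include <string>
#include <unordered_map>
#include <utility>

// Minimal LRU map: most recently used keys at the front, evict from the back.
class LruCache {
 public:
  explicit LruCache(std::size_t capacity) : capacity_(capacity) {}

  void Put(const std::string& key, int value) {
    auto it = index_.find(key);
    if (it != index_.end()) items_.erase(it->second);
    items_.push_front({key, value});
    index_[key] = items_.begin();
    if (items_.size() > capacity_) {  // evict the least recently used entry
      index_.erase(items_.back().first);
      items_.pop_back();
    }
  }

  int* Get(const std::string& key) {
    auto it = index_.find(key);
    if (it == index_.end()) return nullptr;
    items_.splice(items_.begin(), items_, it->second);  // mark as recently used
    return &it->second->second;
  }

 private:
  std::size_t capacity_;
  std::list<std::pair<std::string, int>> items_;
  std::unordered_map<std::string,
                     std::list<std::pair<std::string, int>>::iterator>
      index_;
};

int main() {
  LruCache cache(/*capacity=*/1);
  cache.Put("[1,2]", 0);  // e.g. keyed by an input-shape string
  cache.Put("[4,2]", 1);  // capacity 1: the "[1,2]" entry is evicted
  assert(cache.Get("[1,2]") == nullptr);
  assert(*cache.Get("[4,2]") == 1);
}
```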
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
@ -141,36 +142,18 @@ struct EngineContext {
|
||||
|
||||
class TRTEngineCacheResource : public ResourceBase {
|
||||
public:
|
||||
TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity)
|
||||
: cache_(capacity) {
|
||||
auto device = ctx->device();
|
||||
auto alloc = device->GetAllocator(AllocatorAttributes());
|
||||
if (!alloc) {
|
||||
LOG(ERROR) << "Can't find device allocator for gpu device "
|
||||
<< device->name();
|
||||
allocator_ = nullptr;
|
||||
} else {
|
||||
allocator_.reset(new TRTDeviceAllocator(alloc));
|
||||
}
|
||||
}
|
||||
// According to the TensorRT API, the logger is considered a singleton by the
|
||||
// TensorRT library, and multiple instances of IRuntime and/or IBuilder must
|
||||
// all use the same logger. So here we make it a singleton.
|
||||
//
|
||||
// TODO(laigd): use this logger in all places where conversion happens.
|
||||
static Logger& GetLogger();
|
||||
|
||||
string DebugString() const override {
|
||||
std::stringstream oss;
|
||||
using std::dec;
|
||||
using std::endl;
|
||||
using std::hex;
|
||||
oss << "TRTEngineCacheResource: ";
|
||||
oss << "TRTBaseAllocator = " << hex << allocator_.get() << dec << ", ";
|
||||
oss << "LRUCache = " << hex << &cache_ << dec << endl;
|
||||
oss << "Containing " << cache_.size() << " entries: " << endl;
|
||||
for (const auto& item : cache_) {
|
||||
oss << TensorShapeUtils::ShapeListString(item.first) << ": " << hex
|
||||
<< "ICudaEngine: " << item.second.get()->cuda_engine.get() << ", "
|
||||
<< "IExecutionContext: " << item.second.get()->execution_context.get()
|
||||
<< dec << endl;
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity);
|
||||
|
||||
~TRTEngineCacheResource() override;
|
||||
|
||||
string DebugString() const override;
|
||||
|
||||
// Keep device allocator for TRT.
|
||||
std::unique_ptr<TRTBaseAllocator> allocator_;
|
||||
|
@ -210,6 +210,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
@ -376,6 +377,7 @@ tf_cc_test(
|
||||
":xla_compiler",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:function_ops",
|
||||
"//tensorflow/cc:functional_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/cc:resource_variable_ops",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
@ -390,6 +392,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensor_testutil",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
|
@ -100,9 +100,12 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
|
||||
arg.name = resource->name();
|
||||
break;
|
||||
}
|
||||
case XlaExpression::Kind::kTensorList:
|
||||
return errors::Unimplemented(
|
||||
"TensorList as function argument is not yet implemented.");
|
||||
case XlaExpression::Kind::kTensorList: {
|
||||
arg.kind = XlaCompiler::Argument::kTensorList;
|
||||
const xla::XlaOp& tensor_list = expressions[i]->handle();
|
||||
arg.shape = tensor_list.builder()->GetShape(tensor_list).ValueOrDie();
|
||||
break;
|
||||
}
|
||||
case XlaExpression::Kind::kInvalid:
|
||||
return errors::InvalidArgument("Invalid function argument");
|
||||
}
|
||||
@ -302,8 +305,13 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
|
||||
if (result.outputs[i].is_constant) {
|
||||
xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
|
||||
} else {
|
||||
xla_op_context.SetOutput(
|
||||
i, xla::GetTupleElement(output_handle, computation_output));
|
||||
if (result.outputs[i].is_tensor_list) {
|
||||
xla_op_context.SetTensorListOutput(
|
||||
i, xla::GetTupleElement(output_handle, computation_output));
|
||||
} else {
|
||||
xla_op_context.SetOutput(
|
||||
i, xla::GetTupleElement(output_handle, computation_output));
|
||||
}
|
||||
++computation_output;
|
||||
}
|
||||
}
|
||||
|
@ -58,6 +58,7 @@ class AddNOp : public XlaOpKernel {
|
||||
xla::XlaOp push_index;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(ctx->Input(0), &push_index));
|
||||
OP_REQUIRES_OK(ctx, BuildTensorList(sum, push_index, &sum));
|
||||
ctx->SetTensorListOutput(0, sum);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@ -65,9 +66,8 @@ class AddNOp : public XlaOpKernel {
|
||||
for (int i = 1; i < ctx->num_inputs(); ++i) {
|
||||
sum = xla::Add(sum, ctx->Input(i));
|
||||
}
|
||||
ctx->SetOutput(0, sum);
|
||||
}
|
||||
|
||||
ctx->SetOutput(0, sum);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -44,6 +44,7 @@ class BatchMatMulOp : public XlaOpKernel {
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("BatchMatMul"), BatchMatMulOp);
|
||||
REGISTER_XLA_OP(Name("BatchMatMulV2"), BatchMatMulOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
// XLA implementations of Categorical op.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
@ -35,7 +36,9 @@ namespace {
|
||||
|
||||
class CategoricalOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit CategoricalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
explicit CategoricalOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx),
|
||||
is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
// Get the logits
|
||||
@ -100,8 +103,15 @@ class CategoricalOp : public XlaOpKernel {
|
||||
xla::PrimitiveType xla_output_type;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
DataTypeToPrimitiveType(output_type(0), &xla_output_type));
|
||||
xla::XlaOp argmax = xla::ArgMax(softmax_entries, xla_output_type,
|
||||
/*axis=*/class_dimension);
|
||||
xla::XlaOp argmax;
|
||||
if (is_gpu_) {
|
||||
argmax = xla::ArgMaxTwoPass(softmax_entries, xla_output_type,
|
||||
/*axis=*/class_dimension);
|
||||
} else {
|
||||
argmax = xla::ArgMax(softmax_entries, xla_output_type,
|
||||
/*axis=*/class_dimension);
|
||||
}
|
||||
|
||||
if (num_samples == 1) {
|
||||
argmax = xla::Reshape(argmax, {batch_size, 1});
|
||||
}
|
||||
@ -123,6 +133,7 @@ class CategoricalOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
bool is_gpu_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CategoricalOp);
|
||||
};
|
||||
|
||||
@ -140,8 +151,6 @@ class StatelessCategoricalOp : public CategoricalOp {
|
||||
xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, xla::PrimitiveType type,
|
||||
XlaOpKernelContext* ctx) override {
|
||||
xla::XlaOp seed = ctx->Input(2);
|
||||
auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
||||
auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
||||
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
if (uniform_shape.element_type() == xla::BF16) {
|
||||
@ -150,8 +159,8 @@ class StatelessCategoricalOp : public CategoricalOp {
|
||||
// We want a number in (0, 1) rather than [0, 1) or (0, 1]:
|
||||
// * log(-log(0)) is ∞.
|
||||
// * log(-log(1)) is -∞.
|
||||
auto uniforms = xla::StatelessRngUniform(
|
||||
{seed0, seed1}, uniform_shape,
|
||||
xla::XlaOp uniforms = StatelessRngUniform(
|
||||
seed, uniform_shape,
|
||||
xla::MinPositiveNormalValue(builder, uniform_shape.element_type()),
|
||||
xla::One(builder, uniform_shape.element_type()));
|
||||
return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type);
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
// XLA-specific Ops for 2D convolution.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
@ -293,10 +294,9 @@ xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
|
||||
return attrs;
|
||||
}
|
||||
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
|
||||
xla::XlaOp conv_input,
|
||||
xla::XlaOp filter,
|
||||
const ConvOpAttrs& attrs) {
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
|
||||
StringPiece /*type_string*/, xla::XlaOp conv_input, xla::XlaOp filter,
|
||||
const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) {
|
||||
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
|
||||
|
||||
auto* builder = conv_input.builder();
|
||||
@ -377,12 +377,14 @@ xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
|
||||
return xla::ConvGeneralDilated(
|
||||
conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
|
||||
dims,
|
||||
/*feature_group_count=*/attrs.depthwise ? in_depth : feature_group_count);
|
||||
/*feature_group_count=*/attrs.depthwise ? in_depth : feature_group_count,
|
||||
/*batch_group_count=*/1, precision_config);
|
||||
}
|
||||
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
|
||||
StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
|
||||
xla::XlaOp out_backprop, const ConvOpAttrs& attrs) {
|
||||
xla::XlaOp out_backprop, const ConvOpAttrs& attrs,
|
||||
const xla::PrecisionConfig* precision_config) {
|
||||
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
|
||||
|
||||
int num_dims = attrs.num_spatial_dims + 2;
|
||||
@ -456,13 +458,14 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
|
||||
/*feature_group_count=*/
|
||||
attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) /
|
||||
filter_shape.dimensions(attrs.num_spatial_dims + 1)
|
||||
: feature_group_count);
|
||||
: feature_group_count,
|
||||
/*batch_group_count=*/1, precision_config);
|
||||
}
|
||||
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
|
||||
StringPiece type_string, xla::XlaOp activations,
|
||||
const xla::Shape& filter_shape, xla::XlaOp gradients,
|
||||
const ConvOpAttrs& attrs) {
|
||||
const ConvOpAttrs& attrs, const xla::PrecisionConfig* precision_config) {
|
||||
TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
|
||||
|
||||
auto* builder = activations.builder();
|
||||
@ -612,7 +615,8 @@ xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
|
||||
activations, gradients, window_strides, padding, /*lhs_dilation=*/ones,
|
||||
rhs_dilation, dnums,
|
||||
/*feature_group_count=*/feature_group_count,
|
||||
/*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1);
|
||||
/*batch_group_count=*/use_batch_group_count ? dims.in_depth : 1,
|
||||
precision_config);
|
||||
|
||||
if (!use_batch_group_count && attrs.depthwise) {
|
||||
filter_backprop = ContractFilterForDepthwiseBackprop(
|
||||
|
@ -53,17 +53,19 @@ struct ConvOpAttrs {
|
||||
|
||||
// Creates a new XLA forward or backward convolution with the given inputs and
|
||||
// attributes.
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece type_string,
|
||||
xla::XlaOp conv_input,
|
||||
xla::XlaOp filter,
|
||||
const ConvOpAttrs& attrs);
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(
|
||||
StringPiece type_string, xla::XlaOp conv_input, xla::XlaOp filter,
|
||||
const ConvOpAttrs& attrs,
|
||||
const xla::PrecisionConfig* precision_config = nullptr);
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
|
||||
StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
|
||||
xla::XlaOp out_backprop, const ConvOpAttrs& attrs);
|
||||
xla::XlaOp out_backprop, const ConvOpAttrs& attrs,
|
||||
const xla::PrecisionConfig* precision_config = nullptr);
|
||||
xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
|
||||
StringPiece type_string, xla::XlaOp activations,
|
||||
const xla::Shape& filter_shape, xla::XlaOp gradients,
|
||||
const ConvOpAttrs& attrs);
|
||||
const ConvOpAttrs& attrs,
|
||||
const xla::PrecisionConfig* precision_config = nullptr);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -82,33 +82,71 @@ xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input,
|
||||
return xla::Add(xla::Mul(rounded, input_scale), nudged_input_min);
|
||||
}
|
||||
|
||||
// Builds a custom_call to a method named 'fake_quant_with_min_max_vars'.
|
||||
// The method will be provided the input, the min/max range from the original
|
||||
// TensorFlow op, and the num_bits and narrow_range attributes.
|
||||
xla::StatusOr<xla::XlaOp> BuildFakeQuantCustomCall(
|
||||
xla::XlaBuilder* b, xla::XlaOp input, xla::XlaOp input_min,
|
||||
xla::XlaOp input_max, int num_bits, bool narrow_range) {
|
||||
xla::XlaOp num_bits_arg =
|
||||
XlaHelpers::IntegerLiteral(b, DataType::DT_INT32, num_bits);
|
||||
xla::XlaOp narrow_range_arg = narrow_range
|
||||
? XlaHelpers::One(b, DataType::DT_BOOL)
|
||||
: XlaHelpers::Zero(b, DataType::DT_BOOL);
|
||||
|
||||
std::vector<xla::XlaOp> args = {input, input_min, input_max, num_bits_arg,
|
||||
narrow_range_arg};
|
||||
std::vector<xla::Shape> arg_shapes;
|
||||
for (const xla::XlaOp& arg : args) {
|
||||
TF_ASSIGN_OR_RETURN(xla::Shape arg_shape, b->GetShape(arg));
|
||||
*arg_shape.mutable_layout() =
|
||||
xla::LayoutUtil::MakeDescendingLayout(arg_shape.rank());
|
||||
arg_shapes.push_back(std::move(arg_shape));
|
||||
}
|
||||
|
||||
// Input and output shapes match exactly.
|
||||
TF_ASSIGN_OR_RETURN(xla::Shape output_shape, b->GetShape(input));
|
||||
|
||||
return xla::CustomCallWithLayout(b, "fake_quant_with_min_max_vars", args,
|
||||
output_shape, arg_shapes);
|
||||
}
|
||||
|
||||
class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
int num_bits;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
|
||||
OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
|
||||
OP_REQUIRES(ctx, num_bits_ >= 2 && num_bits_ <= 16,
|
||||
errors::InvalidArgument("num_bits is out of range, expected "
|
||||
"between 2 and 16, was: ",
|
||||
num_bits));
|
||||
bool narrow_range;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
|
||||
quant_min_ = narrow_range ? 1 : 0;
|
||||
quant_max_ = (1 << num_bits) - 1;
|
||||
num_bits_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
|
||||
quant_min_ = narrow_range_ ? 1 : 0;
|
||||
quant_max_ = (1 << num_bits_) - 1;
|
||||
|
||||
float input_min, input_max;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
|
||||
CpuNudge(input_min, input_max, quant_min_, quant_max_, &nudged_input_min_,
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max_));
|
||||
CpuNudge(input_min_, input_max_, quant_min_, quant_max_, &nudged_input_min_,
|
||||
&nudged_input_max_, &input_scale_);
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
const DataType data_type = ctx->input_type(0);
|
||||
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
if (ctx->compiler()->options().allow_cpu_custom_calls &&
|
||||
ctx->compiler()->options().custom_fake_quant_op_calls) {
|
||||
xla::XlaOp custom_call_output =
|
||||
b->ReportErrorOrReturn(BuildFakeQuantCustomCall(
|
||||
b, input,
|
||||
XlaHelpers::FloatLiteral(b, DataType::DT_FLOAT, input_min_),
|
||||
XlaHelpers::FloatLiteral(b, DataType::DT_FLOAT, input_max_),
|
||||
num_bits_, narrow_range_));
|
||||
ctx->SetOutput(0, custom_call_output);
|
||||
return;
|
||||
}
|
||||
|
||||
xla::XlaOp nudged_input_min =
|
||||
XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
|
||||
xla::XlaOp nudged_input_max =
|
||||
@ -121,6 +159,10 @@ class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
int num_bits_;
|
||||
bool narrow_range_;
|
||||
float input_min_;
|
||||
float input_max_;
|
||||
float quant_min_;
|
||||
float quant_max_;
|
||||
float nudged_input_min_;
|
||||
@ -184,25 +226,32 @@ class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
int num_bits;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
|
||||
OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
|
||||
OP_REQUIRES(ctx, num_bits_ >= 2 && num_bits_ <= 16,
|
||||
errors::InvalidArgument("num_bits is out of range, expected "
|
||||
"between 2 and 16, was: ",
|
||||
num_bits));
|
||||
bool narrow_range;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
|
||||
quant_min_ = narrow_range ? 1 : 0;
|
||||
quant_max_ = (1 << num_bits) - 1;
|
||||
num_bits_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
|
||||
quant_min_ = narrow_range_ ? 1 : 0;
|
||||
quant_max_ = (1 << num_bits_) - 1;
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
const DataType data_type = ctx->input_type(0);
|
||||
xla::XlaOp input_min = ctx->Input(1);
|
||||
xla::XlaOp input_max = ctx->Input(2);
|
||||
|
||||
xla::XlaBuilder* b = ctx->builder();
|
||||
if (ctx->compiler()->options().allow_cpu_custom_calls &&
|
||||
ctx->compiler()->options().custom_fake_quant_op_calls) {
|
||||
xla::XlaOp custom_call_output =
|
||||
b->ReportErrorOrReturn(BuildFakeQuantCustomCall(
|
||||
b, input, input_min, input_max, num_bits_, narrow_range_));
|
||||
ctx->SetOutput(0, custom_call_output);
|
||||
return;
|
||||
}
|
||||
|
||||
xla::XlaOp nudged_input_min, nudged_input_max, input_scale;
|
||||
XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
|
||||
&nudged_input_min, &nudged_input_max, &input_scale);
|
||||
@ -213,6 +262,8 @@ class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
int num_bits_;
|
||||
bool narrow_range_;
|
||||
float quant_min_;
|
||||
float quant_max_;
|
||||
};
|
||||
|
@ -31,7 +31,9 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
XlaArgMinMaxOp::XlaArgMinMaxOp(OpKernelConstruction* ctx, bool is_min)
|
||||
: XlaOpKernel(ctx), is_min_(is_min) {}
|
||||
: XlaOpKernel(ctx),
|
||||
is_min_(is_min),
|
||||
is_gpu_(ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {}
|
||||
|
||||
void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
|
||||
const TensorShape input_shape = ctx->InputShape(0);
|
||||
@ -64,10 +66,19 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
|
||||
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
xla::XlaOp output;
|
||||
// One pass ArgMin/ArgMax is slow on GPUs.
|
||||
if (is_min_) {
|
||||
output = xla::ArgMin(input, index_xla_type, axis);
|
||||
if (is_gpu_) {
|
||||
output = xla::ArgMinTwoPass(input, index_xla_type, axis);
|
||||
} else {
|
||||
output = xla::ArgMin(input, index_xla_type, axis);
|
||||
}
|
||||
} else {
|
||||
output = xla::ArgMax(input, index_xla_type, axis);
|
||||
if (is_gpu_) {
|
||||
output = xla::ArgMaxTwoPass(input, index_xla_type, axis);
|
||||
} else {
|
||||
output = xla::ArgMax(input, index_xla_type, axis);
|
||||
}
|
||||
}
|
||||
|
||||
ctx->SetOutput(0, output);
|
||||
|
@ -30,6 +30,7 @@ class XlaArgMinMaxOp : public XlaOpKernel {
|
||||
|
||||
private:
|
||||
const bool is_min_; // Are we computing ArgMin (true) or ArgMax (false)?
|
||||
const bool is_gpu_;
|
||||
};
|
||||
|
||||
class XlaArgMaxOp : public XlaArgMinMaxOp {
|
||||
|
@ -22,6 +22,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
// Returns a tensor containing 'shape' random values uniformly distributed in
|
||||
// the range [minval, maxval). The raw random bits are generated by the given
|
||||
// `bit_generator` and converted to the requested data type and range. This
|
||||
// routine requires 2 32-bit integer seeds and currently only supports 'shape's
|
||||
// of type F32, S32 and S64.
|
||||
xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape,
|
||||
xla::XlaOp minval, xla::XlaOp maxval);
|
||||
|
||||
// Converts to bfloat16 if `dtype` equals DT_BFLOAT16, no-op otherwise.
|
||||
// It masks the last 16 bit. With normal rounding, values near "maxval" would be
|
||||
|
@ -51,6 +51,7 @@ class ReshapeOp : public XlaOpKernel {
|
||||
TensorShape shape;
|
||||
int64 product = 1;
|
||||
int unknown_index = -1;
|
||||
bool shape_has_zero_dim = false;
|
||||
for (int d = 0; d < num_dims; ++d) {
|
||||
const int32 size = shape_input[d];
|
||||
if (size == -1) {
|
||||
@ -60,6 +61,12 @@ class ReshapeOp : public XlaOpKernel {
|
||||
unknown_index, " and ", d));
|
||||
unknown_index = d;
|
||||
shape.AddDim(1);
|
||||
} else if (size == 0) {
|
||||
// We don't include zero-sized dimension in product, so that we can
|
||||
// still calculate number of elements for non-zero-sized dimensions and
|
||||
// therefore infer their shapes.
|
||||
shape.AddDim(size);
|
||||
shape_has_zero_dim = true;
|
||||
} else {
|
||||
OP_REQUIRES(ctx, size >= 0,
|
||||
errors::InvalidArgument(
|
||||
@ -69,18 +76,28 @@ class ReshapeOp : public XlaOpKernel {
|
||||
}
|
||||
}
|
||||
if (unknown_index != -1) {
|
||||
OP_REQUIRES(
|
||||
ctx, product > 0,
|
||||
errors::InvalidArgument("Reshape cannot infer the missing input size "
|
||||
"for an empty tensor unless all specified "
|
||||
"input sizes are non-zero"));
|
||||
const int64 missing = input_shape.num_elements() / product;
|
||||
OP_REQUIRES(
|
||||
ctx, product * missing == input_shape.num_elements(),
|
||||
errors::InvalidArgument(
|
||||
"Input to reshape is a tensor with ", input_shape.num_elements(),
|
||||
" values, but the requested shape requires a multiple of ",
|
||||
product));
|
||||
int64 input_num_elements = 1;
|
||||
bool input_has_zero_dim = false;
|
||||
for (int dim = 0; dim < input_shape.dims(); dim++) {
|
||||
// For zero dimension, we don't count it into `input_num_elements`
|
||||
// unless `sizes` has no zero dimension, so we are still able to
|
||||
// infer shapes for other dimensions.
|
||||
if (input_shape.dim_size(dim) > 0 || !shape_has_zero_dim) {
|
||||
input_num_elements *= input_shape.dim_size(dim);
|
||||
} else {
|
||||
input_has_zero_dim = true;
|
||||
}
|
||||
}
|
||||
|
||||
const int64 missing = input_num_elements / product;
|
||||
if (!input_has_zero_dim) {
|
||||
OP_REQUIRES(
|
||||
ctx, product * missing == input_num_elements,
|
||||
errors::InvalidArgument(
|
||||
"Input to reshape is a tensor with ", input_num_elements,
|
||||
" values, but the requested shape requires a multiple of ",
|
||||
product));
|
||||
}
|
||||
shape.set_dim(unknown_index, missing);
|
||||
}
|
||||
OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(),
|
||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
||||
// XLA-specific Shape Ops.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
@ -223,14 +225,33 @@ class ZerosLikeOp : public XlaOpKernel {
|
||||
explicit ZerosLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const TensorShape input_shape = ctx->InputShape(0);
|
||||
if (IsTensorListInput(ctx, 0)) {
|
||||
// Input is a TensorList.
|
||||
// TODO(b/124707753): support nested TensorList.
|
||||
xla::XlaOp tensor_list = ctx->Input(0);
|
||||
TensorShape shape;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListBufferShape(tensor_list, &shape));
|
||||
xla::PrimitiveType type;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListPrimitiveType(tensor_list, &type));
|
||||
xla::XlaOp buffer;
|
||||
OP_REQUIRES_OK(ctx, CreateZerosList(ctx, shape, type, &buffer));
|
||||
|
||||
auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
|
||||
ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes()));
|
||||
xla::XlaOp push_index;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(tensor_list, &push_index));
|
||||
|
||||
xla::XlaOp output_list;
|
||||
OP_REQUIRES_OK(ctx, BuildTensorList(buffer, push_index, &output_list));
|
||||
ctx->SetTensorListOutput(0, output_list);
|
||||
} else {
|
||||
const TensorShape input_shape = ctx->InputShape(0);
|
||||
|
||||
auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
|
||||
ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes()));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp);
|
||||
REGISTER_XLA_OP(Name("ZerosLike").AllowVariantTypes(), ZerosLikeOp);
|
||||
|
||||
class OnesLikeOp : public XlaOpKernel {
|
||||
public:
|
||||
|
@ -35,127 +35,50 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
std::pair<xla::ThreeFry2x32State, xla::XlaOp> GetInputsFromCounter(
|
||||
xla::XlaOp counter, const int64 size) {
|
||||
auto builder = counter.builder();
|
||||
auto input_u64 = Iota(builder, xla::U64, size);
|
||||
input_u64 = input_u64 + counter;
|
||||
counter = counter + xla::ConstantR0<uint64>(builder, size);
|
||||
return std::make_pair(xla::Uint64ToUint32s(input_u64), counter);
|
||||
}
|
||||
|
||||
// `StatelessRngUniformU32` uses ThreeFry2x32’s counter space too
|
||||
// wastefully, only able to generate 2^32*2 int32 numbers for each key, while
|
||||
// the real capacity is 2^64*2. Counter-space efficiency is important for
|
||||
// stateful ops, hence the following 2 new functions.
|
||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU32(
|
||||
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
|
||||
auto builder = key.builder();
|
||||
const int64 size = xla::ShapeUtil::ElementsIn(shape);
|
||||
const int64 half_size = xla::CeilOfRatio<int64>(size, 2);
|
||||
const bool size_is_odd = (half_size * 2 != size);
|
||||
auto inputs_counter = GetInputsFromCounter(counter, half_size);
|
||||
auto inputs = inputs_counter.first;
|
||||
counter = inputs_counter.second;
|
||||
auto outputs = xla::ThreeFry2x32(inputs, xla::Uint64ToUint32s(key));
|
||||
if (size_is_odd) {
|
||||
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
|
||||
}
|
||||
auto result = ConcatInDim(builder, outputs, 0);
|
||||
return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
|
||||
counter);
|
||||
}
|
||||
|
||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU64(
|
||||
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
|
||||
const int64 size = xla::ShapeUtil::ElementsIn(shape);
|
||||
auto inputs_counter = GetInputsFromCounter(counter, size);
|
||||
auto inputs = inputs_counter.first;
|
||||
counter = inputs_counter.second;
|
||||
auto outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
|
||||
auto result = Uint32sToUint64(outputs);
|
||||
return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
|
||||
counter);
|
||||
}
|
||||
|
||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniform(xla::XlaOp key,
|
||||
xla::XlaOp counter,
|
||||
const xla::Shape& shape,
|
||||
xla::XlaOp minval,
|
||||
xla::XlaOp maxval) {
|
||||
auto builder = key.builder();
|
||||
xla::RngOutput StatefulRngUniform(xla::XlaOp key, xla::XlaOp initial_state,
|
||||
const xla::Shape& shape, xla::XlaOp minval,
|
||||
xla::XlaOp maxval) {
|
||||
xla::PrimitiveType type = shape.element_type();
|
||||
switch (type) {
|
||||
case xla::F32: {
|
||||
auto bits_counter = StatefulRngUniformU32(key, counter, shape);
|
||||
auto bits = bits_counter.first;
|
||||
counter = bits_counter.second;
|
||||
return std::make_pair(xla::StatelessRngUniformF32(bits, minval, maxval),
|
||||
counter);
|
||||
}
|
||||
case xla::U32: // fall through
|
||||
case xla::S32: {
|
||||
auto bits_counter = StatefulRngUniformU32(key, counter, shape);
|
||||
auto bits = bits_counter.first;
|
||||
counter = bits_counter.second;
|
||||
return std::make_pair(
|
||||
xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U32),
|
||||
counter);
|
||||
}
|
||||
case xla::U64: // fall through
|
||||
case xla::S64: {
|
||||
auto bits_counter = StatefulRngUniformU64(key, counter, shape);
|
||||
auto bits = bits_counter.first;
|
||||
counter = bits_counter.second;
|
||||
return std::make_pair(
|
||||
xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U64),
|
||||
counter);
|
||||
}
|
||||
case xla::F32:
|
||||
return xla::UniformF32Distribution(
|
||||
key, initial_state, xla::ThreeFryBitGenerator, minval, maxval, shape);
|
||||
case xla::U32:
|
||||
case xla::S32:
|
||||
case xla::U64:
|
||||
case xla::S64:
|
||||
return UniformIntDistribution(
|
||||
key, initial_state, xla::ThreeFryBitGenerator, minval, maxval, shape);
|
||||
default:
|
||||
return std::make_pair(
|
||||
builder->ReportError(xla::Unimplemented(
|
||||
"Types other than F32, U32, S32, U64 and S64 "
|
||||
"are not implemented by "
|
||||
"StatefulRngUniform; got: %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||
counter);
|
||||
return {key.builder()->ReportError(xla::Unimplemented(
|
||||
"Types other than F32, U32, S32, U64 and S64 "
|
||||
"are not implemented by "
|
||||
"StatefulRngUniform; got %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||
initial_state};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename A, typename B, typename A2>
|
||||
std::pair<A2, B> map_first(std::function<A2(A)> f, std::pair<A, B> p) {
|
||||
return std::make_pair(f(p.first), p.second);
|
||||
}
|
||||
|
||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformFullInt(
|
||||
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
|
||||
xla::RngOutput StatefulRngUniformFullInt(xla::XlaOp key,
|
||||
xla::XlaOp initial_state,
|
||||
const xla::Shape& shape) {
|
||||
xla::PrimitiveType type = shape.element_type();
|
||||
xla::RngOutput output = xla::ThreeFryBitGenerator(key, initial_state, shape);
|
||||
switch (type) {
|
||||
case xla::U32:
|
||||
return StatefulRngUniformU32(key, counter, shape);
|
||||
case xla::S32: {
|
||||
// Needs explicit function type because of type-inference failure.
|
||||
std::function<xla::XlaOp(xla::XlaOp)> f = [](xla::XlaOp x) {
|
||||
return BitcastConvertType(x, xla::S32);
|
||||
};
|
||||
return map_first(f, StatefulRngUniformU32(key, counter, shape));
|
||||
}
|
||||
case xla::U64:
|
||||
return StatefulRngUniformU64(key, counter, shape);
|
||||
case xla::S64: {
|
||||
std::function<xla::XlaOp(xla::XlaOp)> f = [](xla::XlaOp x) {
|
||||
return BitcastConvertType(x, xla::S64);
|
||||
};
|
||||
return map_first(f, StatefulRngUniformU64(key, counter, shape));
|
||||
}
|
||||
return output;
|
||||
case xla::S32:
|
||||
case xla::S64:
|
||||
output.value = BitcastConvertType(output.value, type);
|
||||
return output;
|
||||
default:
|
||||
auto builder = key.builder();
|
||||
return std::make_pair(
|
||||
builder->ReportError(xla::Unimplemented(
|
||||
return {
|
||||
key.builder()->ReportError(xla::Unimplemented(
|
||||
"Types other than U32, S32, U64 and S64 are not implemented by "
|
||||
"StatefulRngUniformFullInt; got: %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||
counter);
|
||||
initial_state};
|
||||
}
|
||||
}
|
||||
|
||||
@ -177,15 +100,15 @@ xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
|
||||
0);
|
||||
}
|
||||
|
||||
using sampler_return_type = xla::StatusOr<std::pair<xla::XlaOp, xla::XlaOp>>;
|
||||
using SamplerReturnType = xla::StatusOr<xla::RngOutput>;
|
||||
|
||||
// A helper function containing the common part of several kernels below.
|
||||
// Precondition: 'algorithm' and 'shape' are compile-time constants.
|
||||
Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
|
||||
int alg_input_idx, int shape_input_idx,
|
||||
std::function<sampler_return_type(xla::XlaOp, xla::XlaOp,
|
||||
TensorShape)> const&
|
||||
sample_with_threefry) {
|
||||
Status CompileImpl(
|
||||
XlaOpKernelContext* ctx, int state_input_idx, int alg_input_idx,
|
||||
int shape_input_idx,
|
||||
std::function<SamplerReturnType(xla::XlaOp, xla::XlaOp, TensorShape)> const&
|
||||
sampler) {
|
||||
auto alg_shape = ctx->InputShape(alg_input_idx);
|
||||
if (alg_shape.dims() != 0) {
|
||||
return errors::InvalidArgument("algorithm must be of shape [], not ",
|
||||
@ -215,24 +138,22 @@ Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
|
||||
TensorShape shape;
|
||||
TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape));
|
||||
|
||||
static constexpr int COUNTER_SIZE = 1;
|
||||
auto counter = BitcastConvertType(
|
||||
xla::Reshape(xla::Slice(var, {0}, {COUNTER_SIZE}, {1}), {}), xla::U64);
|
||||
static constexpr int kStateSize = 1;
|
||||
auto state = BitcastConvertType(
|
||||
xla::Reshape(xla::Slice(var, {0}, {kStateSize}, {1}), {}), xla::U64);
|
||||
auto key = BitcastConvertType(
|
||||
xla::Reshape(xla::Slice(var, {COUNTER_SIZE}, {COUNTER_SIZE + 1}, {1}),
|
||||
{}),
|
||||
xla::Reshape(xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}), {}),
|
||||
xla::U64);
|
||||
|
||||
auto status_or_value = sample_with_threefry(counter, key, shape);
|
||||
auto status_or_value = sampler(state, key, shape);
|
||||
if (!status_or_value.ok()) {
|
||||
return status_or_value.status();
|
||||
}
|
||||
auto output_counter = status_or_value.ConsumeValueOrDie();
|
||||
auto output = output_counter.first;
|
||||
counter = output_counter.second;
|
||||
ctx->SetOutput(0, output);
|
||||
auto builder = ctx->builder();
|
||||
var = ConcatScalars(builder, {counter, key});
|
||||
xla::RngOutput value_state = status_or_value.ConsumeValueOrDie();
|
||||
state = value_state.state;
|
||||
ctx->SetOutput(0, value_state.value);
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
var = ConcatScalars(builder, {state, key});
|
||||
xla::PrimitiveType state_element_type;
|
||||
TF_RETURN_IF_ERROR(
|
||||
DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
|
||||
@ -252,23 +173,22 @@ class StatefulUniformOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto builder = ctx->builder();
|
||||
auto sample_with_threefry = [builder, this](
|
||||
xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
auto sampler = [builder, this](xla::XlaOp state, xla::XlaOp key,
|
||||
TensorShape shape) -> SamplerReturnType {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
auto uniform_counter = StatefulRngUniform(
|
||||
key, counter, xla_shape, xla::ConstantR0<float>(builder, 0.0),
|
||||
xla::RngOutput uniform_state = StatefulRngUniform(
|
||||
key, state, xla_shape, xla::ConstantR0<float>(builder, 0.0),
|
||||
xla::ConstantR0<float>(builder, 1.0));
|
||||
auto uniform = uniform_counter.first;
|
||||
counter = uniform_counter.second;
|
||||
xla::XlaOp uniform = uniform_state.value;
|
||||
state = uniform_state.state;
|
||||
uniform = MaybeConvertF32ToBF16(uniform, dtype_);
|
||||
return {{uniform, counter}};
|
||||
return {{uniform, state}};
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
/*shape_input_idx=*/2, sample_with_threefry));
|
||||
/*shape_input_idx=*/2, sampler));
|
||||
}
|
||||
|
||||
private:
|
||||
@ -293,30 +213,20 @@ class StatefulStandardNormalOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto builder = ctx->builder();
|
||||
auto sample_with_threefry =
|
||||
auto sampler =
|
||||
// Needs explicit lambda return type because it fails to be inferred.
|
||||
[builder, this](xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
[this](xla::XlaOp state, xla::XlaOp key,
|
||||
TensorShape shape) -> SamplerReturnType {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
|
||||
auto uniform_counter = StatefulRngUniform(
|
||||
key, counter, xla_shape,
|
||||
xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
|
||||
xla::ConstantR0<float>(builder, 1.0));
|
||||
auto uniform = uniform_counter.first;
|
||||
counter = uniform_counter.second;
|
||||
// Convert uniform distribution to normal distribution by computing
|
||||
// sqrt(2) * erfinv(x)
|
||||
auto normal =
|
||||
xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
|
||||
normal = MaybeConvertF32ToBF16(normal, dtype_);
|
||||
return {{normal, counter}};
|
||||
xla::RngOutput value_state = xla::NormalF32Distribution(
|
||||
key, state, xla::ThreeFryBitGenerator, xla_shape);
|
||||
xla::XlaOp normal = MaybeConvertF32ToBF16(value_state.value, dtype_);
|
||||
return {{normal, value_state.state}};
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
/*shape_input_idx=*/2, sample_with_threefry));
|
||||
/*shape_input_idx=*/2, sampler));
|
||||
}
|
||||
|
||||
private:
|
||||
@ -341,27 +251,27 @@ class StatefulTruncatedNormalOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto builder = ctx->builder();
|
||||
auto sample_with_threefry =
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
auto sampler =
|
||||
// Needs explicit lambda return type because it fails to be inferred.
|
||||
[builder, this](xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
[builder, this](xla::XlaOp state, xla::XlaOp key,
|
||||
TensorShape shape) -> SamplerReturnType {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
|
||||
auto uniform_counter = StatefulRngUniform(
|
||||
key, counter, xla_shape,
|
||||
xla::RngOutput uniform_result = StatefulRngUniform(
|
||||
key, state, xla_shape,
|
||||
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
|
||||
xla::One(builder, xla_shape.element_type()));
|
||||
auto uniform = uniform_counter.first;
|
||||
counter = uniform_counter.second;
|
||||
xla::XlaOp uniform = uniform_result.value;
|
||||
state = uniform_result.state;
|
||||
xla::XlaOp truncated_normal = TruncatedNormal(uniform);
|
||||
truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
|
||||
return {{truncated_normal, counter}};
|
||||
return {{truncated_normal, state}};
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
/*shape_input_idx=*/2, sample_with_threefry));
|
||||
/*shape_input_idx=*/2, sampler));
|
||||
}
|
||||
|
||||
private:
|
||||
@ -388,11 +298,11 @@ class StatefulUniformIntOp : public XlaOpKernel {
|
||||
xla::XlaOp minval = ctx->Input(3);
|
||||
xla::XlaOp maxval = ctx->Input(4);
|
||||
auto sample_with_threefry = [minval, maxval, this](
|
||||
xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
xla::XlaOp state, xla::XlaOp key,
|
||||
TensorShape shape) -> SamplerReturnType {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
||||
return StatefulRngUniform(key, counter, xla_shape, minval, maxval);
|
||||
return StatefulRngUniform(key, state, xla_shape, minval, maxval);
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
@ -420,12 +330,11 @@ class StatefulUniformFullIntOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto sample_with_threefry = [this](
|
||||
xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
auto sample_with_threefry = [this](xla::XlaOp state, xla::XlaOp key,
|
||||
TensorShape shape) -> SamplerReturnType {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
||||
return StatefulRngUniformFullInt(key, counter, xla_shape);
|
||||
return StatefulRngUniformFullInt(key, state, xla_shape);
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
|
@ -36,8 +36,8 @@ namespace tensorflow {
|
||||
xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
|
||||
if (dtype == DT_BFLOAT16) {
|
||||
xla::XlaBuilder* builder = input.builder();
|
||||
auto output = xla::BitcastConvertType(input, xla::U32) &
|
||||
xla::ConstantR0<uint32>(builder, 0xFFFF0000);
|
||||
xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) &
|
||||
xla::ConstantR0<uint32>(builder, 0xFFFF0000);
|
||||
return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32),
|
||||
xla::BF16);
|
||||
} else {
|
||||
@ -45,22 +45,36 @@ xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
|
||||
}
|
||||
}
|
||||
|
||||
xla::XlaOp Uniform2NormalUsingSqrtErfinv(xla::XlaOp uniform) {
|
||||
// Convert uniform distribution to normal distribution by computing
|
||||
// sqrt(2) * erfinv(x)
|
||||
return xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
|
||||
}
|
||||
xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape,
|
||||
xla::XlaOp minval, xla::XlaOp maxval) {
|
||||
xla::XlaBuilder* builder = seeds.builder();
|
||||
|
||||
// A wrapper of xla::StatelessRngUniform. Returns an op that produces random
|
||||
// values with uniform distribution in the range [minval, maxval) for the given
|
||||
// shape and given two 32-bit seeds. Currently only shapes of type F32, S32 and
|
||||
// S64 are implemented.
|
||||
xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType unused,
|
||||
xla::XlaOp seed, xla::XlaOp minval,
|
||||
xla::XlaOp maxval) {
|
||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
||||
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
||||
return xla::StatelessRngUniform({seed0, seed1}, shape, minval, maxval);
|
||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
|
||||
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
|
||||
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
||||
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
||||
ConstantR0WithType(builder, xla::U64, 32));
|
||||
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
|
||||
xla::PrimitiveType type = shape.element_type();
|
||||
switch (type) {
|
||||
case xla::F32:
|
||||
return xla::UniformF32Distribution(key, initial_state,
|
||||
xla::ThreeFryBitGenerator, minval,
|
||||
maxval, shape)
|
||||
.value;
|
||||
case xla::S32: // fall through
|
||||
case xla::S64:
|
||||
return UniformIntDistribution(key, initial_state,
|
||||
xla::ThreeFryBitGenerator, minval, maxval,
|
||||
shape)
|
||||
.value;
|
||||
break;
|
||||
default:
|
||||
return builder->ReportError(xla::Unimplemented(
|
||||
"Types other than F32, S32 and S64 are not implemented by "
|
||||
"StatelessRngUniform; got %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type)));
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -86,8 +100,8 @@ class StatelessRandomUniformOp : public XlaOpKernel {
|
||||
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
xla::XlaOp uniform = StatelessRandomUniformImpl(
|
||||
xla_shape, dtype_, seed, xla::ConstantR0<float>(builder, 0.0),
|
||||
xla::XlaOp uniform = StatelessRngUniform(
|
||||
seed, xla_shape, xla::ConstantR0<float>(builder, 0.0),
|
||||
xla::ConstantR0<float>(builder, 1.0));
|
||||
uniform = MaybeConvertF32ToBF16(uniform, dtype_);
|
||||
ctx->SetOutput(0, uniform);
|
||||
@ -136,8 +150,8 @@ class StatelessRandomUniformIntOp : public XlaOpKernel {
|
||||
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
||||
xla::XlaOp uniform =
|
||||
StatelessRandomUniformImpl(xla_shape, dtype_, seed, minval, maxval);
|
||||
xla::XlaOp uniform = StatelessRngUniform(seed, xla_shape, minval, maxval);
|
||||
|
||||
ctx->SetOutput(0, uniform);
|
||||
}
|
||||
|
||||
@ -170,14 +184,20 @@ class StatelessRandomNormalOp : public XlaOpKernel {
|
||||
errors::InvalidArgument("seed must have shape [2], not ",
|
||||
seed_shape.DebugString()));
|
||||
xla::XlaOp seed = ctx->Input(1);
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
xla::XlaOp uniform = StatelessRandomUniformImpl(
|
||||
xla_shape, dtype_, seed,
|
||||
xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
|
||||
xla::ConstantR0<float>(builder, 1.0));
|
||||
xla::XlaOp normal = Uniform2NormalUsingSqrtErfinv(uniform);
|
||||
|
||||
xla::XlaBuilder* builder = seed.builder();
|
||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
||||
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
||||
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
|
||||
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
||||
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
||||
ConstantR0WithType(builder, xla::U64, 32));
|
||||
xla::XlaOp normal =
|
||||
xla::NormalF32Distribution(key, initial_state,
|
||||
xla::ThreeFryBitGenerator, xla_shape)
|
||||
.value;
|
||||
normal = MaybeConvertF32ToBF16(normal, dtype_);
|
||||
ctx->SetOutput(0, normal);
|
||||
}
|
||||
@ -215,8 +235,8 @@ class StatelessTruncatedNormalOp : public XlaOpKernel {
|
||||
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
xla::XlaOp uniform = StatelessRandomUniformImpl(
|
||||
xla_shape, dtype_, seed,
|
||||
xla::XlaOp uniform = StatelessRngUniform(
|
||||
seed, xla_shape,
|
||||
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
|
||||
xla::One(builder, xla_shape.element_type()));
|
||||
xla::XlaOp truncated_normal = TruncatedNormal(uniform);
|
||||
|
@ -47,16 +47,19 @@ class TensorListLengthOp : public XlaOpKernel {
|
||||
explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::XlaOp index;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListPushIndex(ctx->Input(0), &index));
|
||||
ctx->SetOutput(0, index);
|
||||
TensorShape buffer_shape;
|
||||
OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &buffer_shape));
|
||||
Tensor length_tensor(DT_INT32, {});
|
||||
length_tensor.scalar<int32>()() =
|
||||
static_cast<int32>(buffer_shape.dim_size(0));
|
||||
ctx->SetConstantOutput(0, length_tensor);
|
||||
}
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TensorListLengthOp);
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("TensorListLength"), TensorListLengthOp);
|
||||
REGISTER_XLA_OP(Name("TensorListLength").IsMetadataOp(), TensorListLengthOp);
|
||||
|
||||
// Creates an empty list with size (leading_dim, *element_shape) if
|
||||
// element_shape is known at compile time. Otherwise creates one with size
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
@ -35,6 +36,17 @@ Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetTensorListPrimitiveType(const xla::XlaOp& op,
|
||||
xla::PrimitiveType* type) {
|
||||
TF_RET_CHECK(op.builder());
|
||||
TF_ASSIGN_OR_RETURN(const xla::Shape& list_tuple_shape,
|
||||
op.builder()->GetShape(op));
|
||||
xla::Shape buffer_shape =
|
||||
xla::ShapeUtil::GetTupleElementShape(list_tuple_shape, 0);
|
||||
*type = buffer_shape.element_type();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer) {
|
||||
TF_RET_CHECK(op.builder());
|
||||
*buffer = xla::GetTupleElement(op, 0);
|
||||
@ -97,4 +109,12 @@ Status InitializeTensorList(const xla::XlaOp& uninitialized_list,
|
||||
return BuildTensorList(new_buffer, push_index, output_list);
|
||||
}
|
||||
|
||||
Status CreateZerosList(XlaOpKernelContext* ctx, const TensorShape& buffer_shape,
|
||||
xla::PrimitiveType type, xla::XlaOp* list) {
|
||||
auto zero =
|
||||
xla::ConstantLiteral(ctx->builder(), xla::LiteralUtil::Zero(type));
|
||||
*list = xla::Broadcast(zero, buffer_shape.dim_sizes());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -35,6 +35,10 @@ bool IsTensorListInput(XlaOpKernelContext* ctx, int index);
|
||||
Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index,
|
||||
xla::XlaOp* output_list);
|
||||
|
||||
// Returns XLA PrimitiveType for the TensorList.
|
||||
Status GetTensorListPrimitiveType(const xla::XlaOp& op,
|
||||
xla::PrimitiveType* type);
|
||||
|
||||
// Returns the buffer for the TensorList.
|
||||
Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer);
|
||||
|
||||
@ -62,6 +66,10 @@ Status InitializeTensorList(const xla::XlaOp& uninitialized_list,
|
||||
const TensorShape& buffer_shape,
|
||||
xla::XlaOp* output_list);
|
||||
|
||||
// Returns a TensorList filled with zero.
|
||||
Status CreateZerosList(XlaOpKernelContext* ctx, const TensorShape& buffer_shape,
|
||||
xla::PrimitiveType type, xla::XlaOp* list);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_
|
||||
|
@ -529,7 +529,11 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
|
||||
int resource_index = 0;
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
if (ctx->input_type(i) != DT_RESOURCE) {
|
||||
ctx->SetOutput(i, xla::GetTupleElement(while_result, i));
|
||||
if (IsTensorListInput(ctx, i)) {
|
||||
ctx->SetTensorListOutput(i, xla::GetTupleElement(while_result, i));
|
||||
} else {
|
||||
ctx->SetOutput(i, xla::GetTupleElement(while_result, i));
|
||||
}
|
||||
++resource_index;
|
||||
} else {
|
||||
break;
|
||||
|
@ -165,7 +165,7 @@ Status RewriteAndPruneGraph(
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
|
||||
VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph);
|
||||
PruneForReverseReachability(graph, retval_nodes);
|
||||
PruneForReverseReachability(graph, std::move(retval_nodes));
|
||||
FixupSourceAndSinkEdges(graph);
|
||||
VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph);
|
||||
// Sanity-check, to make sure the feeds and fetches still exist post-pruning.
|
||||
@ -295,6 +295,8 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
|
||||
compiler_options.flib_def = &graph->flib_def();
|
||||
compiler_options.graph_def_version = graph->versions().producer();
|
||||
compiler_options.allow_cpu_custom_calls = true;
|
||||
compiler_options.custom_fake_quant_op_calls =
|
||||
config.conversion_options().custom_fake_quant_op_calls();
|
||||
XlaCompiler compiler(compiler_options);
|
||||
|
||||
XlaCompiler::CompilationResult result;
|
||||
|
@ -53,6 +53,14 @@ message Variable {
|
||||
bool readonly = 5;
|
||||
}
|
||||
|
||||
// Options used during the conversion and compilation process.
|
||||
message ConversionOptions {
|
||||
// When true tf.fake_quant_* ops will be emitted as custom calls to a
|
||||
// 'fake_quant_with_min_max_vars' function accepting the input, min, max,
|
||||
// num_bits, and narrow_range values as runtime arguments.
|
||||
bool custom_fake_quant_op_calls = 1;
|
||||
}
|
||||
|
||||
// Config represents configuration information for tf2xla conversion.
|
||||
message Config {
|
||||
// Each feed is a positional input argument for the generated computation.
|
||||
@ -63,4 +71,6 @@ message Config {
|
||||
repeated Fetch fetch = 2;
|
||||
// Each variable is a named input and output of the generated computation.
|
||||
repeated Variable variable = 3;
|
||||
// Optional conversion options.
|
||||
ConversionOptions conversion_options = 4;
|
||||
}
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
||||
@ -200,8 +201,13 @@ Status BuildComputation(
|
||||
output.shape = output.constant_value.shape();
|
||||
break;
|
||||
|
||||
case XlaExpression::Kind::kTensorList:
|
||||
TF_FALLTHROUGH_INTENDED;
|
||||
case XlaExpression::Kind::kTensorList: {
|
||||
output.is_tensor_list = true;
|
||||
xla::XlaOp value = retval.handle();
|
||||
elems.push_back(value);
|
||||
break;
|
||||
}
|
||||
|
||||
case XlaExpression::Kind::kXlaOp: {
|
||||
output.is_constant = false;
|
||||
TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
|
||||
@ -403,6 +409,8 @@ string XlaCompiler::Argument::HumanString() const {
|
||||
}
|
||||
case kParameter:
|
||||
return absl::StrCat("kind=parameter", common);
|
||||
case kTensorList:
|
||||
return absl::StrCat("kind=tensorlist", common);
|
||||
case kToken:
|
||||
return absl::StrCat("token", common);
|
||||
}
|
||||
@ -641,6 +649,11 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
case XlaCompiler::Argument::kTensorList: {
|
||||
TF_RET_CHECK(absl::holds_alternative<xla::Shape>(arg.shape));
|
||||
*xla_shape = absl::get<xla::Shape>(arg.shape);
|
||||
return Status::OK();
|
||||
}
|
||||
case XlaCompiler::Argument::kResource: {
|
||||
TF_RET_CHECK(arg.initialized);
|
||||
|
||||
@ -744,6 +757,7 @@ Status XlaCompiler::BuildArguments(
|
||||
break;
|
||||
}
|
||||
case XlaCompiler::Argument::kParameter:
|
||||
case XlaCompiler::Argument::kTensorList:
|
||||
case XlaCompiler::Argument::kToken: {
|
||||
input_to_args->push_back(i);
|
||||
break;
|
||||
@ -902,6 +916,10 @@ Status XlaCompiler::BuildArguments(
|
||||
arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
|
||||
}
|
||||
break;
|
||||
case XlaCompiler::Argument::kTensorList: {
|
||||
arg_expression = XlaExpression::TensorList(arg_handles[i]);
|
||||
break;
|
||||
}
|
||||
case XlaCompiler::Argument::kToken: {
|
||||
arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
|
||||
break;
|
||||
|
@ -116,6 +116,9 @@ class XlaCompiler {
|
||||
|
||||
// Argument is an XLA token.
|
||||
kToken,
|
||||
|
||||
// Argument is a TensorList.
|
||||
kTensorList,
|
||||
};
|
||||
|
||||
Kind kind = kInvalid;
|
||||
@ -226,6 +229,9 @@ class XlaCompiler {
|
||||
// When this output is a resource, i.e. `type == DT_RESOURCE`, this is
|
||||
// the index of the input that contains the resource.
|
||||
int input_index;
|
||||
|
||||
// Whether this output is a TensorList.
|
||||
bool is_tensor_list = false;
|
||||
};
|
||||
|
||||
// Describes a variable write side effect of the computation.
|
||||
@ -305,6 +311,12 @@ class XlaCompiler {
|
||||
// for CPU.
|
||||
bool allow_cpu_custom_calls = false;
|
||||
|
||||
// If both this and 'allow_cpu_custom_calls' are true then tf.fake_quant_*
|
||||
// ops will be emitted as custom calls to a 'fake_quant_with_min_max_vars'
|
||||
// function accepting the input, min, max, num_bits, and narrow_range values
|
||||
// as runtime arguments.
|
||||
bool custom_fake_quant_op_calls = false;
|
||||
|
||||
// If set, the XLA representation of variables represented to XLA as the
|
||||
// shape given by this shape function. Variables are reshaped to this shape
|
||||
// on write, and reshaped to their original shape on read.
|
||||
|
@ -14,10 +14,15 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/data_flow_ops.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/functional_ops.h"
|
||||
#include "tensorflow/cc/ops/list_ops.h"
|
||||
#include "tensorflow/cc/ops/math_ops.h"
|
||||
#include "tensorflow/cc/ops/resource_variable_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
@ -35,7 +40,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -1498,5 +1505,76 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(XlaCompilerTest, OpsWithTensorListInput) {
|
||||
FunctionDefLibrary fdef_lib;
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
|
||||
// Build cond fn for While.
|
||||
{
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
|
||||
auto result = ops::Const<bool>(scope, {true}, {});
|
||||
ops::_Retval(scope.WithOpName("ret"), result, 0);
|
||||
TF_ASSERT_OK(scope.ToGraph(graph.get()));
|
||||
FunctionDef fdef;
|
||||
TF_ASSERT_OK(GraphToFunctionDef(*graph, "cond", &fdef));
|
||||
TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
|
||||
}
|
||||
// Build body fn for While.
|
||||
{
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
auto arg = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
|
||||
ops::_Retval(scope.WithOpName("ret"), arg, 0);
|
||||
TF_ASSERT_OK(scope.ToGraph(graph.get()));
|
||||
FunctionDef fdef;
|
||||
TF_ASSERT_OK(GraphToFunctionDef(*graph, "body", &fdef));
|
||||
TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
|
||||
}
|
||||
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto element_shape = ops::Const<int32>(scope, {1}, {1});
|
||||
auto max_elements = ops::Const<int32>(scope, {10}, {});
|
||||
auto arg = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
|
||||
std::initializer_list<Output> out = {arg, arg};
|
||||
auto add_n = ops::AddN(scope, out);
|
||||
NameAttrList cond_fn, body_fn;
|
||||
cond_fn.set_name("cond");
|
||||
body_fn.set_name("body");
|
||||
auto while_op =
|
||||
ops::While(scope, std::initializer_list<Input>{arg}, cond_fn, body_fn);
|
||||
auto ret0 = ops::_Retval(scope.WithOpName("ret0"), add_n, 0);
|
||||
auto ret1 = ops::_Retval(scope.WithOpName("ret1"), while_op.output[0], 1);
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(scope.ToGraph(graph.get()));
|
||||
|
||||
// Builds a description of the arguments.
|
||||
std::vector<XlaCompiler::Argument> args(1);
|
||||
args[0].kind = XlaCompiler::Argument::kTensorList;
|
||||
xla::Shape tensor_list_element_shape;
|
||||
TF_ASSERT_OK(TensorShapeToXLAShape(DT_INT32, TensorShape{1},
|
||||
&tensor_list_element_shape));
|
||||
xla::Shape index_shape;
|
||||
TF_ASSERT_OK(TensorShapeToXLAShape(DT_INT32, TensorShape{}, &index_shape));
|
||||
std::vector<xla::Shape> shapes{tensor_list_element_shape, index_shape};
|
||||
xla::Shape arg_shape = xla::ShapeUtil::MakeTupleShape(shapes);
|
||||
args[0].shape = arg_shape;
|
||||
|
||||
// Compiles the graph.
|
||||
XlaCompiler::Options options = DefaultOptions();
|
||||
options.flib_def = &flib_def;
|
||||
XlaCompiler compiler(options);
|
||||
|
||||
XlaCompiler::CompilationResult result;
|
||||
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
|
||||
std::move(graph), args,
|
||||
/*user_aliases=*/{}, &result));
|
||||
ASSERT_EQ(result.outputs.size(), 2);
|
||||
const XlaCompiler::OutputDescription& output0 = result.outputs[0];
|
||||
ASSERT_TRUE(output0.is_tensor_list);
|
||||
const XlaCompiler::OutputDescription& output1 = result.outputs[1];
|
||||
ASSERT_TRUE(output1.is_tensor_list);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -153,7 +153,6 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
||||
cpu_global_jit
|
||||
? XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally
|
||||
: XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested;
|
||||
registration.compile_all_resource_ops = false;
|
||||
}
|
||||
if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) {
|
||||
DeviceRegistration& registration =
|
||||
@ -161,7 +160,6 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
||||
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
|
||||
registration.autoclustering_policy =
|
||||
XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
|
||||
registration.compile_all_resource_ops = false;
|
||||
}
|
||||
return nullptr;
|
||||
}();
|
||||
|
@ -88,8 +88,33 @@ class XlaOpRegistry {
|
||||
// When should we autocluster operators assigned to this device?
|
||||
AutoclusteringPolicy autoclustering_policy;
|
||||
|
||||
// Enable compilation of operators that use DT_RESOURCE types?
|
||||
bool compile_all_resource_ops = false;
|
||||
// If we should ignore the resource variable memory model when clustering
|
||||
// resource variable reads and writes placed on this device.
|
||||
bool cluster_resource_variable_ops_unsafely = false;
|
||||
|
||||
// If we should auto-cluster Stack operations placed on this device.
|
||||
bool cluster_stack_ops = false;
|
||||
|
||||
// If we should auto-cluster TensorArray operations placed on this device.
|
||||
bool cluster_tensor_array_ops = false;
|
||||
|
||||
// If we should auto-cluster stateful RNG operations placed on this device.
|
||||
// Stateful RNG semantics are not properly supported by XLA so it is not
|
||||
// necessarily correct to auto-cluster stateful RNG ops in general.
|
||||
bool cluster_stateful_rng_ops = false;
|
||||
|
||||
// If we should auto-cluster ControlTrigger operations placed on this
|
||||
// device. ControlTrigger operations are not necessarily safe to cluster
|
||||
// since they affect deadness (a dead ControlTrigger produces a live
|
||||
// output).
|
||||
bool cluster_control_trigger = false;
|
||||
|
||||
// If we should cluster Assert and CheckNumerics by eliding them (XLA does
|
||||
// not natively support Assert or CheckNumerics).
|
||||
bool elide_assert_and_checknumerics = false;
|
||||
|
||||
// If we should cluster operations returning DT_VARIANT.
|
||||
bool cluster_variant_ops = false;
|
||||
};
|
||||
|
||||
// Registers an XLA backend. `compilation_device_name` is the name of the
|
||||
|
@ -287,6 +287,7 @@ tf_cc_test(
|
||||
":xla_data_proto",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/hash:hash_testing",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -125,8 +125,60 @@ XlaOp Any(XlaOp predicates) {
|
||||
|
||||
namespace {
|
||||
|
||||
XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder,
|
||||
PrimitiveType value_type,
|
||||
PrimitiveType index_type, bool is_min) {
|
||||
auto sub_builder = outer_builder->CreateSubBuilder("minmax_func");
|
||||
XlaBuilder* b = sub_builder.get();
|
||||
XlaOp lhs_value =
|
||||
Parameter(b, 0, ShapeUtil::MakeShape(value_type, {}), "lhs_value");
|
||||
XlaOp lhs_index =
|
||||
Parameter(b, 1, ShapeUtil::MakeShape(index_type, {}), "lhs_index");
|
||||
XlaOp rhs_value =
|
||||
Parameter(b, 2, ShapeUtil::MakeShape(value_type, {}), "rhs_value");
|
||||
XlaOp rhs_index =
|
||||
Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index");
|
||||
|
||||
auto cmp = is_min ? Lt(lhs_value, rhs_value) : Gt(lhs_value, rhs_value);
|
||||
XlaOp max = Select(cmp, lhs_value, rhs_value);
|
||||
XlaOp arg_max = Select(cmp, lhs_index, rhs_index);
|
||||
Tuple(b, {max, arg_max});
|
||||
return b->Build().ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
|
||||
XlaBuilder* builder = input.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
|
||||
XlaOp value_init_value;
|
||||
if (is_min) {
|
||||
value_init_value = MaxValue(builder, input_shape.element_type());
|
||||
} else {
|
||||
value_init_value = MinValue(builder, input_shape.element_type());
|
||||
}
|
||||
int64 dimension_size = input_shape.dimensions(axis);
|
||||
auto index_type = dimension_size <= INT32_MAX ? S32 : output_type;
|
||||
XlaOp index_init_value = Zero(builder, index_type);
|
||||
auto iota_shape = input_shape;
|
||||
iota_shape.set_element_type(index_type);
|
||||
XlaOp iota = Iota(builder, iota_shape, axis);
|
||||
|
||||
XlaComputation reducer = CreateMinMaxComputation(
|
||||
builder, input_shape.element_type(), index_type, is_min);
|
||||
XlaOp max_argmax = Reduce(builder, {input, iota},
|
||||
{value_init_value, index_init_value}, reducer,
|
||||
/*dimensions_to_reduce=*/{axis});
|
||||
XlaOp argmax = GetTupleElement(max_argmax, 1);
|
||||
if (index_type != output_type) {
|
||||
argmax = ConvertElementType(argmax, output_type);
|
||||
}
|
||||
return argmax;
|
||||
});
|
||||
}
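For reference, a minimal sketch of how the single-pass reduction above is reached through the public `ArgMax` wrapper; the helper name and the input literal are assumptions for illustration, not part of this change:

```c++
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Hypothetical helper: argmax over axis 1 of a 2x3 constant.
xla::XlaOp BuildArgMaxExample(xla::XlaBuilder* b) {
  // input: f32[2, 3]
  xla::XlaOp input =
      xla::ConstantR2<float>(b, {{1.0f, 7.0f, 4.0f}, {9.0f, 2.0f, 3.0f}});
  // Internally reduces (value, iota) pairs with the min/max reducer built above.
  xla::XlaOp indices = xla::ArgMax(input, /*output_type=*/xla::S32, /*axis=*/1);
  // When executed, indices is s32[2] {1, 0}: the column of each row's maximum.
  return indices;
}
```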
|
||||
|
||||
XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
|
||||
bool is_min) {
|
||||
XlaBuilder* builder = input.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
|
||||
XlaOp init_value;
|
||||
@ -172,7 +224,6 @@ XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
|
||||
/*dimensions_to_reduce=*/{axis});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) {
|
||||
@ -183,4 +234,11 @@ XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) {
|
||||
return ArgMinMax(input, output_type, axis, /*is_min=*/true);
|
||||
}
|
||||
|
||||
XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis) {
|
||||
return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/false);
|
||||
}
|
||||
|
||||
XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis) {
|
||||
return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/true);
|
||||
}
|
||||
} // namespace xla
|
||||
|
@ -60,10 +60,12 @@ XlaOp Any(XlaOp predicates);
|
||||
// Returns the argmax of `input` along `axis`. `output_type` is the type to
|
||||
// use for the output.
|
||||
XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis);
|
||||
XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis);
|
||||
|
||||
// Returns the argmin of `input` along `axis`. `output_type` is the type to
|
||||
// use for the output.
|
||||
XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis);
|
||||
XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
|
@ -32,8 +32,12 @@ XlaOp RotateLeftU32(XlaOp v, int distance) {
|
||||
ShiftRightLogical(v, ConstantR0<uint32>(v.builder(), 32 - distance));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
// The internal state of the Three Fry implementation.
|
||||
using ThreeFry2x32State = std::array<XlaOp, 2>;
|
||||
|
||||
// Implements the ThreeFry counter-based PRNG algorithm.
|
||||
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
|
||||
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
|
||||
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
|
||||
XlaBuilder* builder = input[0].builder();
|
||||
key[0] = BitcastConvertType(key[0], U32);
|
||||
@ -104,56 +108,68 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
|
||||
return x;
|
||||
}
|
||||
|
||||
// Returns the inputs with unique counter values for ThreeFry2x32.
|
||||
ThreeFry2x32State GetInputs(const int64 size, XlaBuilder* builder) {
|
||||
ThreeFry2x32State inputs;
|
||||
inputs[0] = Iota(builder, U32, size);
|
||||
inputs[1] = inputs[0] + ConstantR0<uint32>(builder, size);
|
||||
return inputs;
|
||||
}
|
||||
|
||||
XlaOp StatelessRngUniformU32(std::array<XlaOp, 2> key, const Shape& shape) {
|
||||
XlaBuilder* builder = key[0].builder();
|
||||
const int64 size = ShapeUtil::ElementsIn(shape);
|
||||
const int64 half_size = CeilOfRatio<int64>(size, 2);
|
||||
const bool size_is_odd = (half_size * 2 != size);
|
||||
ThreeFry2x32State inputs = GetInputs(half_size, builder);
|
||||
ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
|
||||
if (size_is_odd) {
|
||||
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
|
||||
}
|
||||
auto result = ConcatInDim(builder, outputs, 0);
|
||||
return Reshape(result, AsInt64Slice(shape.dimensions()));
|
||||
}
|
||||
|
||||
// Converts a uint64 to two uint32s.
|
||||
ThreeFry2x32State Uint64ToUint32s(XlaOp u64) {
|
||||
auto builder = u64.builder();
|
||||
auto const32 = ConstantR0WithType(builder, U64, 32);
|
||||
auto fst = ConvertElementType(u64, U32);
|
||||
auto snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
|
||||
XlaBuilder* builder = u64.builder();
|
||||
XlaOp const32 = ConstantR0WithType(builder, U64, 32);
|
||||
XlaOp fst = ConvertElementType(u64, U32);
|
||||
XlaOp snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
|
||||
return {fst, snd};
|
||||
}
|
||||
|
||||
// Converts two uint32s to a uint64.
|
||||
XlaOp Uint32sToUint64(ThreeFry2x32State u32s) {
|
||||
auto builder = u32s[0].builder();
|
||||
XlaBuilder* builder = u32s[0].builder();
|
||||
return ConvertElementType(u32s[0], U64) |
|
||||
ShiftLeft(ConvertElementType(u32s[1], U64),
|
||||
ConstantR0WithType(builder, U64, 32));
|
||||
}
|
||||
|
||||
XlaOp StatelessRngUniformU64(std::array<XlaOp, 2> key, const Shape& shape) {
|
||||
XlaBuilder* builder = key[0].builder();
|
||||
const int64 size = ShapeUtil::ElementsIn(shape);
|
||||
ThreeFry2x32State inputs = GetInputs(size, builder);
|
||||
ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
|
||||
// low 32 bit: outputs[0], high 32 bit: outputs[1]
|
||||
auto result = Uint32sToUint64(outputs);
|
||||
return Reshape(result, AsInt64Slice(shape.dimensions()));
|
||||
// Given the initial state and the requested number of random numbers to be
|
||||
// generated, returns the input for the random number generator and a new state.
|
||||
std::pair<ThreeFry2x32State, XlaOp> GetThreeFryInputsAndUpdatedState(
|
||||
XlaOp initial_state, const int64 size) {
|
||||
XlaBuilder* builder = initial_state.builder();
|
||||
XlaOp input_u64 = Iota(builder, U64, size);
|
||||
input_u64 = input_u64 + initial_state;
|
||||
XlaOp new_state = initial_state + ConstantR0<uint64>(builder, size);
|
||||
return std::make_pair(Uint64ToUint32s(input_u64), new_state);
|
||||
}
|
||||
|
||||
XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
|
||||
XlaBuilder* builder = bits.builder();
|
||||
// Generates random 32bits with the given shape using the Three Fry
|
||||
// implementation. Returns the random bits and the new state.
|
||||
RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) {
|
||||
XlaBuilder* builder = key.builder();
|
||||
const int64 size = ShapeUtil::ElementsIn(shape);
|
||||
const int64 half_size = CeilOfRatio<int64>(size, 2);
|
||||
const bool size_is_odd = (half_size * 2 != size);
|
||||
std::pair<ThreeFry2x32State, XlaOp> inputs_state =
|
||||
GetThreeFryInputsAndUpdatedState(initial_state, half_size);
|
||||
ThreeFry2x32State inputs = inputs_state.first;
|
||||
ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
|
||||
if (size_is_odd) {
|
||||
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
|
||||
}
|
||||
XlaOp result = ConcatInDim(builder, outputs, 0);
|
||||
return {Reshape(result, AsInt64Slice(shape.dimensions())),
|
||||
inputs_state.second};
|
||||
}
|
||||
|
||||
// Generates random 64bits with the given shape using the Three Fry
|
||||
// implementation. Returns the random bits and the new state.
|
||||
RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) {
|
||||
const int64 size = ShapeUtil::ElementsIn(shape);
|
||||
std::pair<ThreeFry2x32State, XlaOp> inputs_state =
|
||||
GetThreeFryInputsAndUpdatedState(initial_state, size);
|
||||
ThreeFry2x32State inputs = inputs_state.first;
|
||||
ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
|
||||
XlaOp result = Uint32sToUint64(outputs);
|
||||
return {Reshape(result, AsInt64Slice(shape.dimensions())),
|
||||
inputs_state.second};
|
||||
}
|
||||
|
||||
XlaOp ConvertRandomBitsToUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
|
||||
XlaBuilder* builder = bits.builder();
|
||||
// Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit
|
||||
// forces the random bits into the mantissa.
|
||||
constexpr int kFloatBits = 32;
|
||||
@ -161,50 +177,119 @@ XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
|
||||
bits = ShiftRightLogical(
|
||||
bits, ConstantR0<uint32>(builder, kFloatBits - kMantissaBits)) |
|
||||
ConstantR0<uint32>(builder, absl::bit_cast<uint32>(1.0f));
|
||||
auto floats = BitcastConvertType(bits, F32);
|
||||
XlaOp values = BitcastConvertType(bits, F32);
|
||||
|
||||
// We have a floating point number in the range [1.0, 2.0).
|
||||
// Subtract 1.0f to shift to the range [0.0, 1.0)
|
||||
floats = floats - ConstantR0<float>(builder, 1.0f);
|
||||
values = values - ConstantR0<float>(builder, 1.0f);
|
||||
// Multiply and add to shift to the range [minval, maxval).
|
||||
return floats * (maxval - minval) + minval;
|
||||
return values * (maxval - minval) + minval;
|
||||
}
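The bit manipulation above can be checked on the host with ordinary integer code. A standalone sketch follows; the concrete bit pattern is made up purely for illustration, and the constants mirror `kFloatBits - kMantissaBits == 9` and `absl::bit_cast<uint32>(1.0f) == 0x3F800000`:

```c++
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  uint32_t bits = 0xDEADBEEFu;  // stand-in for the 32 random bits
  // Keep the top 23 bits as the mantissa and OR in the exponent of 1.0f
  // (0x3F800000); the resulting bit pattern is a float in [1.0, 2.0).
  uint32_t pattern = (bits >> 9) | 0x3F800000u;
  float value;
  std::memcpy(&value, &pattern, sizeof(value));
  std::printf("%f\n", value - 1.0f);  // shifted into [0.0, 1.0)
  return 0;
}
```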
|
||||
|
||||
XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
|
||||
PrimitiveType type, PrimitiveType unsigned_type) {
|
||||
XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
|
||||
PrimitiveType type,
|
||||
PrimitiveType unsigned_type) {
|
||||
XlaBuilder* builder = bits.builder();
|
||||
auto range = BitcastConvertType(maxval, unsigned_type) -
|
||||
BitcastConvertType(minval, unsigned_type);
|
||||
auto dist = Rem(bits, range);
|
||||
auto dist_div_2 =
|
||||
XlaOp range = BitcastConvertType(maxval, unsigned_type) -
|
||||
BitcastConvertType(minval, unsigned_type);
|
||||
XlaOp dist = Rem(bits, range);
|
||||
XlaOp dist_div_2 =
|
||||
ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1));
|
||||
|
||||
return minval + BitcastConvertType(dist_div_2, type) +
|
||||
BitcastConvertType(dist - dist_div_2, type);
|
||||
}
|
||||
|
||||
XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
|
||||
XlaOp minval, XlaOp maxval) {
|
||||
XlaBuilder* builder = seeds[0].builder();
|
||||
// Implements the Box-Muller transform, which converts random floats in the
|
||||
// range of [0, 1] from uniform distribution to normal distribution with mean 0
|
||||
// and variance 1. For more detail on the Box-Muller transform, see
|
||||
// http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
|
||||
std::pair<XlaOp, XlaOp> BoxMullerTransform(XlaOp x0, XlaOp x1) {
|
||||
// Do not send a really small number to log().
|
||||
XlaOp u1 = Max(x0, ScalarLike(x0, 1.0e-7f));
|
||||
|
||||
XlaOp v1 = ScalarLike(x1, 2.0f * M_PI) * x1;
|
||||
XlaOp u2 = Sqrt(ScalarLike(u1, -2.0f) * Log(u1));
|
||||
return {Sin(v1) * u2, Cos(v1) * u2};
|
||||
}
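For reference, in the variables used above (`u1` is the clamped first uniform input and `x1` the second), the returned pair is the standard Box-Muller pair, in the same order as the code:

$$z_0 = \sqrt{-2\ln u_1}\,\sin(2\pi x_1), \qquad z_1 = \sqrt{-2\ln u_1}\,\cos(2\pi x_1)$$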
|
||||
|
||||
} // namespace
|
||||
|
||||
RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
|
||||
const Shape& shape) {
|
||||
PrimitiveType type = shape.element_type();
|
||||
switch (type) {
|
||||
case F32: {
|
||||
auto bits = StatelessRngUniformU32(seeds, shape);
|
||||
return StatelessRngUniformF32(bits, minval, maxval);
|
||||
}
|
||||
case S32: {
|
||||
auto bits = StatelessRngUniformU32(seeds, shape);
|
||||
return StatelessRngUniformInt(bits, minval, maxval, type, U32);
|
||||
}
|
||||
case S64: {
|
||||
auto bits = StatelessRngUniformU64(seeds, shape);
|
||||
return StatelessRngUniformInt(bits, minval, maxval, type, U64);
|
||||
}
|
||||
case F32:
|
||||
case U32:
|
||||
case S32:
|
||||
return ThreeFryRngBit32(key, initial_state, shape);
|
||||
case U64:
|
||||
case S64:
|
||||
return ThreeFryRngBit64(key, initial_state, shape);
|
||||
default:
|
||||
return builder->ReportError(Unimplemented(
|
||||
"Types other than F32, S32 and S64 are not implemented by "
|
||||
"StatelessRngUniform."));
|
||||
return {key.builder()->ReportError(Unimplemented(
|
||||
"Types other than F32, U32, S32, U64 and S64 "
|
||||
"are not implemented by ThreeFryBitGenerator; got %s",
|
||||
primitive_util::LowercasePrimitiveTypeName(type))),
|
||||
initial_state};
|
||||
}
|
||||
}
|
||||
|
||||
RngOutput UniformF32Distribution(XlaOp key, XlaOp initial_state,
|
||||
BitGeneratorTy bit_generator, XlaOp minval,
|
||||
XlaOp maxval, const Shape& shape) {
|
||||
DCHECK_EQ(shape.element_type(), F32);
|
||||
RngOutput bits_state = bit_generator(key, initial_state, shape);
|
||||
XlaOp bits = bits_state.value;
|
||||
XlaOp new_state = bits_state.state;
|
||||
return {ConvertRandomBitsToUniformF32(bits, minval, maxval), new_state};
|
||||
}
|
||||
|
||||
RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
|
||||
BitGeneratorTy bit_generator, XlaOp minval,
|
||||
XlaOp maxval, const Shape& shape) {
|
||||
RngOutput bits_state = bit_generator(key, initial_state, shape);
|
||||
XlaOp bits = bits_state.value;
|
||||
XlaOp new_state = bits_state.state;
|
||||
PrimitiveType type = shape.element_type();
|
||||
PrimitiveType unsigned_type;
|
||||
if (type == U32 || type == S32) {
|
||||
unsigned_type = U32;
|
||||
} else {
|
||||
DCHECK(type == U64 || type == S64);
|
||||
unsigned_type = U64;
|
||||
}
|
||||
return {
|
||||
ConvertRandomBitsToUniformInt(bits, minval, maxval, type, unsigned_type),
|
||||
new_state};
|
||||
}
|
||||
|
||||
RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state,
|
||||
BitGeneratorTy bit_generator,
|
||||
const Shape& shape) {
|
||||
DCHECK_EQ(shape.element_type(), F32);
|
||||
XlaBuilder* builder = key.builder();
|
||||
const int64 num_elems = ShapeUtil::ElementsIn(shape);
|
||||
const int64 num_pairs = CeilOfRatio<int64>(num_elems, 2);
|
||||
RngOutput bits_state = UniformF32Distribution(
|
||||
key, initial_state, bit_generator, ConstantR0<float>(builder, 0.0),
|
||||
ConstantR0<float>(builder, 1.0),
|
||||
ShapeUtil::MakeShape(F32, {num_pairs * 2}));
|
||||
|
||||
// Separate the bits into two groups to perform the Box-Muller transform.
|
||||
XlaOp bits_0 = Slice(bits_state.value, {0}, {num_pairs}, {1});
|
||||
XlaOp bits_1 = Slice(bits_state.value, {num_pairs}, {2 * num_pairs}, {1});
|
||||
std::tie(bits_0, bits_1) = BoxMullerTransform(bits_0, bits_1);
|
||||
|
||||
// Put the numbers in the two groups back to form the requested shape.
|
||||
XlaOp normal = ConcatInDim(builder, {bits_0, bits_1}, /*dimension=*/0);
|
||||
if (num_elems != num_pairs * 2) {
|
||||
normal = Slice(normal, /*start_indices=*/{0}, /*limit_indices=*/{num_elems},
|
||||
/*strides=*/{1});
|
||||
}
|
||||
normal = Reshape(normal, shape.dimensions());
|
||||
|
||||
return {normal, bits_state.state};
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -23,37 +23,52 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Records the bits and state generated by a random number generator.
|
||||
struct RngOutput {
|
||||
XlaOp value;
|
||||
XlaOp state;
|
||||
};
|
||||
|
||||
// A BitGenerator returns random bits and updated random bit generator state.
|
||||
//
|
||||
// key: a value input to a random number generator that can affect the
// sequence of numbers it will generate. A random number generator constructs
|
||||
// its seed using the key and the initial state. The tf2xla bridge passes the
|
||||
// seed operand of a tensorflow random operation as a key to the random bit
|
||||
// generator, for example.
|
||||
// initial_state: the initial state of the current random
|
||||
// number generation. It could be 0 for a stateless random operation, and
|
||||
// the returned state from a previous execution for a stateful random
|
||||
// operation.
|
||||
// shape: the shape of the random bits.
|
||||
using BitGeneratorTy = std::function<RngOutput(XlaOp key, XlaOp initial_state,
|
||||
const xla::Shape& shape)>;
|
||||
|
||||
// Implements the ThreeFry counter-based PRNG algorithm.
|
||||
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
|
||||
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
|
||||
using ThreeFry2x32State = std::array<XlaOp, 2>;
|
||||
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key);
|
||||
RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
|
||||
const xla::Shape& shape);
|
||||
|
||||
// Returns a tensor containing 'shape' random values uniformly distributed in
|
||||
// the range [minval, maxval). Requires 2 32-bit integer seeds.
|
||||
// Currently only 'shape's of type F32, S32 and S64 are implemented.
|
||||
XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
|
||||
XlaOp minval, XlaOp maxval);
|
||||
// Uses the given bit generator to generate random bits and then converts the
|
||||
// random bits to random numbers of uniform distribution in the given range.
|
||||
// Returns the random numbers and the state of the random number generator.
|
||||
// This function is for shapes with a float element type.
|
||||
RngOutput UniformF32Distribution(XlaOp key, XlaOp initial_state,
|
||||
BitGeneratorTy bit_generator, XlaOp minval,
|
||||
XlaOp maxval, const xla::Shape& shape);
|
||||
|
||||
// Converts a 32-bit (signed or unsigned) integer random number `bits` into a
|
||||
// float32 in the range [minval, maxval).
|
||||
XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval);
|
||||
// Similar to UniformF32Distribution but for shapes with integer element types.
|
||||
RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
|
||||
BitGeneratorTy bit_generator, XlaOp minval,
|
||||
XlaOp maxval, const xla::Shape& shape);
|
||||
|
||||
// Converts an integer random number 'bits' of type 'type' to a random number
|
||||
// in the range [minval, maxval), of the same type. 'unsigned_type' is the
|
||||
// unsigned version of 'type' (could be the same) with the same bit width.
|
||||
// The algorithm is the same one that TF uses right now, but it's
|
||||
// uniform only when maxval - minval is a divisor of the range that bits is
|
||||
// generated from.
|
||||
// TODO(b/72573764): Generate real uniform integer distribution.
|
||||
XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
|
||||
PrimitiveType type, PrimitiveType unsigned_type);
|
||||
|
||||
// The following 2 functions, for converting between one uint64 and two uint32s,
|
||||
// use the contract "lower 32 bits for the first uint32, higher 32 bits for the
|
||||
// second".
|
||||
ThreeFry2x32State Uint64ToUint32s(XlaOp u64);
|
||||
XlaOp Uint32sToUint64(ThreeFry2x32State u32s);
|
||||
// Uses the given bit generator to generate random bits and then converts the
|
||||
// random bits to random numbers of normal distribution.
|
||||
// Returns the random numbers and the state of the random number generator.
|
||||
RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state,
|
||||
BitGeneratorTy bit_generator,
|
||||
const xla::Shape& shape);
|
||||
|
||||
} // namespace xla
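Putting the declarations above together, a caller seeds the generator with a U64 key and threads the returned state into the next call. The following is a minimal sketch; the helper name, the seed value, and the output shape are assumptions for illustration:

```c++
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Hypothetical helper: draws 128 uniform floats in [0, 1) using ThreeFry.
xla::RngOutput BuildUniformExample(xla::XlaBuilder* b) {
  xla::XlaOp key = xla::ConstantR0<xla::uint64>(b, 42);   // caller-chosen seed
  xla::XlaOp state = xla::ConstantR0<xla::uint64>(b, 0);  // stateless use
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {128});
  xla::RngOutput out = xla::UniformF32Distribution(
      key, state, xla::ThreeFryBitGenerator,
      /*minval=*/xla::ConstantR0<float>(b, 0.0f),
      /*maxval=*/xla::ConstantR0<float>(b, 1.0f), shape);
  // out.value holds the samples; out.state seeds a follow-up stateful call.
  return out;
}
```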
|
||||
|
||||
|
@ -292,6 +292,11 @@ static void AllocateFlags() {
|
||||
flag_values->xla_gpu_crash_on_verification_failures(),
|
||||
"Crashes the program on extra verification failures, e.g. cuDNN "
|
||||
"cross checking failures"),
|
||||
tensorflow::Flag(
|
||||
"xla_gpu_disable_autotune",
|
||||
bool_setter_for(&DebugOptions::set_xla_gpu_disable_autotune),
|
||||
flag_values->xla_gpu_disable_autotune(),
|
||||
"Disable GEMM and Convolution auto-tuning."),
|
||||
tensorflow::Flag(
|
||||
"xla_force_host_platform_device_count",
|
||||
int32_setter_for(
|
||||
|
@ -120,9 +120,14 @@ class Sharding(object):
|
||||
tile_assignment_dimensions=tile_assignment_dims,
|
||||
tile_assignment_devices=range(num_devices)))
|
||||
|
||||
def apply_to_tensor(self, tensor):
|
||||
"""Applies this Sharding attribute to `tensor`."""
|
||||
if len(tensor.op.outputs) > 1:
|
||||
def apply_to_tensor(self, tensor, assign_tuple_sharding=False):
|
||||
"""Applies this Sharding attribute to `tensor`.
|
||||
|
||||
Args:
|
||||
tensor: A tf.Tensor to split.
|
||||
assign_tuple_sharding: If the sharding type should be a tuple.
|
||||
"""
|
||||
if len(tensor.op.outputs) > 1 or assign_tuple_sharding:
|
||||
proto = self._get_or_create_tuple_proto(tensor.op)
|
||||
# We can't mutate an element of old_proto.tuple_shardings, so create
|
||||
# a new proto.
|
||||
@ -166,21 +171,30 @@ class Sharding(object):
|
||||
# tensor = xla_sharding.replicate(tensor)
|
||||
|
||||
|
||||
def replicate(tensor):
|
||||
Sharding.replicate().apply_to_tensor(tensor)
|
||||
def replicate(tensor, assign_tuple_sharding=False):
|
||||
Sharding.replicate().apply_to_tensor(
|
||||
tensor,
|
||||
assign_tuple_sharding=assign_tuple_sharding)
|
||||
return tensor
|
||||
|
||||
|
||||
def assign_device(tensor, device):
|
||||
Sharding.assign_device(device).apply_to_tensor(tensor)
|
||||
def assign_device(tensor, device, assign_tuple_sharding=False):
|
||||
Sharding.assign_device(device).apply_to_tensor(
|
||||
tensor,
|
||||
assign_tuple_sharding=assign_tuple_sharding)
|
||||
return tensor
|
||||
|
||||
|
||||
def tile(tensor, tile_assignment):
|
||||
Sharding.tile(tile_assignment).apply_to_tensor(tensor)
|
||||
def tile(tensor, tile_assignment, assign_tuple_sharding=False):
|
||||
Sharding.tile(tile_assignment).apply_to_tensor(
|
||||
tensor,
|
||||
assign_tuple_sharding=assign_tuple_sharding
|
||||
)
|
||||
return tensor
|
||||
|
||||
|
||||
def split(tensor, split_dimension, num_devices):
|
||||
Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(tensor)
|
||||
def split(tensor, split_dimension, num_devices, assign_tuple_sharding=False):
|
||||
Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(
|
||||
tensor,
|
||||
assign_tuple_sharding=assign_tuple_sharding)
|
||||
return tensor
|
||||
|
@ -303,6 +303,37 @@ For example, if `operand` is a scalar `f32` with value `2.0f`, and
|
||||
`broadcast_sizes` is `{2, 3}`, then the result will be an array with shape
|
||||
`f32[2, 3]` and all the values in the result will be `2.0f`.
|
||||
|
||||
## BroadcastInDim
|
||||
|
||||
See also
|
||||
[`XlaBuilder::BroadcastInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
|
||||
|
||||
Expands the size and rank of an array by duplicating the data in the array.
|
||||
|
||||
<b> `BroadcastInDim(operand, out_dim_size, broadcast_dimensions)` </b>
|
||||
|
||||
| Arguments | Type | Semantics |
|
||||
| ---------------------- | ------------------- | ----------------------------- |
|
||||
| `operand` | `XlaOp` | The array to duplicate |
|
||||
| `out_dim_size` | `ArraySlice<int64>` | The sizes of the dimensions |
|
||||
: : : of the target shape :
|
||||
| `broadcast_dimensions` | `ArraySlice<int64>` | Which dimension in the target |
|
||||
: : : shape each dimension of the :
|
||||
: : : operand shape corresponds to :
|
||||
|
||||
Similar to Broadcast, but allows adding dimensions anywhere and expanding
existing dimensions with size 1.

The `operand` is broadcast to the shape described by `out_dim_size`.
`broadcast_dimensions` maps the dimensions of `operand` to the dimensions of the
target shape, i.e. the i'th dimension of the operand is mapped to the
broadcast_dimension\[i\]'th dimension of the output shape. The dimensions of
`operand` must have size 1 or be the same size as the dimension in the output
shape they are mapped to. The remaining dimensions are filled with dimensions of
size 1. Degenerate-dimension broadcasting then broadcasts along these degenerate
dimensions to reach the output shape. The semantics are described in detail on
the [broadcasting page](broadcasting.md).
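As a concrete illustration (the helper name and operand values are assumed, not taken from the text above), broadcasting an `f32[2, 1]` operand to `f32[2, 3]`:

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Hypothetical helper: expands f32[2, 1] to f32[2, 3].
xla::XlaOp BuildBroadcastInDimExample(xla::XlaBuilder* b) {
  xla::XlaOp operand = xla::ConstantR2<float>(b, {{1.0f}, {2.0f}});  // f32[2, 1]
  // Dimension 0 of the operand maps to output dimension 0; dimension 1
  // (size 1) maps to output dimension 1 and is expanded to size 3.
  xla::XlaOp result = xla::BroadcastInDim(operand, /*out_dim_size=*/{2, 3},
                                          /*broadcast_dimensions=*/{0, 1});
  // result: f32[2, 3] = {{1, 1, 1}, {2, 2, 2}}
  return result;
}
```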
|
||||
|
||||
## Call
|
||||
|
||||
See also
|
||||
@ -1258,6 +1289,59 @@ Arguments | Type | Semantics
|
||||
The function is applied to each element in the `operand` array, resulting in an
|
||||
array with the same shape. It is allowed for `operand` to be a scalar (rank 0).
|
||||
|
||||
## Fft
|
||||
|
||||
The XLA FFT operation implements the forward and inverse Fourier Transforms for
real and complex inputs/outputs. Multidimensional FFTs on up to 3 axes are
supported, except on TPU, where only a single axis is supported (please file a
GitHub issue if you require support for more axes).
|
||||
|
||||
See also
|
||||
[`XlaBuilder::Fft`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
|
||||
|
||||
| Arguments | Type | Semantics |
|
||||
| ------------ | ------------------- | ------------------------ |
|
||||
| `operand` | `XlaOp` | The array we are Fourier |
|
||||
: : : transforming. :
|
||||
| `fft_type` | `FftType` | See the table below. |
|
||||
| `fft_length` | `ArraySlice<int64>` | The time-domain lengths |
|
||||
: : : of the axes being :
|
||||
: : : transformed. This is :
|
||||
: : : needed in particular for :
|
||||
: : : IRFFT to right-size the :
|
||||
: : : innermost axis, since :
|
||||
: : : `RFFT(fft_length=[16])` :
|
||||
: : : has the same output :
|
||||
: : : shape as :
|
||||
: : : `RFFT(fft_length=[17])`. :
|
||||
|
||||
| `FftType` | Semantics |
|
||||
| --------- | ---------------------------------------------------------------- |
|
||||
| `FFT` | Forward complex-to-complex FFT. Shape is unchanged. |
|
||||
| `IFFT` | Inverse complex-to-complex FFT. Shape is unchanged. |
|
||||
| `RFFT` | Forward real-to-complex FFT. Shape of the innermost axis is |
|
||||
: : reduced to `fft_length[-1] // 2 + 1` if `fft_length[-1]` is a :
|
||||
: : non-zero value, omitting the reversed conjugate part of the :
|
||||
: : transformed signal beyond the Nyquist frequency. :
|
||||
| `IRFFT` | Inverse real-to-complex FFT (i.e. takes complex, returns real). |
|
||||
: : Shape of the innermost axis is expanded to `fft_length[-1]` if :
|
||||
: : `fft_length[-1]` is a non-zero value, inferring the part of the :
|
||||
: : transformed signal beyond the Nyquist frequency from the reverse :
|
||||
: : conjugate of the `1` to `fft_length[-1] // 2 + 1` entries. :
|
||||
|
||||
#### Multidimensional FFT
|
||||
|
||||
When more than 1 `fft_length` is provided, this is equivalent to applying a
|
||||
cascade of FFT operations to each of the innermost axes. Note that for the
|
||||
real->complex and complex->real cases, the innermost axis transform is
|
||||
(effectively) performed first (RFFT; last for IRFFT), which is why the innermost
|
||||
axis is the one which changes size. Other axis transforms will then be
|
||||
complex->complex.
|
||||
|
||||
#### Implementation details
|
||||
|
||||
CPU FFT is backed by Eigen's TensorFFT. GPU FFT uses cuFFT.
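As a worked example of the `RFFT` shape rule in the table above (the helper name and zero-valued operand are assumptions for illustration only):

```c++
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Hypothetical helper: forward real-to-complex FFT of ten length-16 signals.
xla::XlaOp BuildRfftExample(xla::XlaBuilder* b) {
  // operand: f32[10, 16] (zeros here, just to fix the shape).
  xla::XlaOp operand =
      xla::Broadcast(xla::ConstantR0<float>(b, 0.0f), {10, 16});
  xla::XlaOp spectrum =
      xla::Fft(operand, xla::FftType::RFFT, /*fft_length=*/{16});
  // spectrum has shape c64[10, 9], since 16 // 2 + 1 == 9; the redundant
  // conjugate-symmetric upper half of each transform is omitted.
  return spectrum;
}
```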
|
||||
|
||||
## Gather
|
||||
|
||||
The XLA gather operation stitches together several slices (each slice at a
|
||||
|
@ -69,6 +69,11 @@ class Tile {
|
||||
// combined with the next minor dimension before tiling is applied.
|
||||
static constexpr int64 kCombineDimension = std::numeric_limits<int64>::min();
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const Tile& t) {
|
||||
return H::combine(std::move(h), t.dimensions_);
|
||||
}
|
||||
|
||||
private:
|
||||
// The bounds of the tile.
|
||||
std::vector<int64> dimensions_;
|
||||
@ -212,6 +217,13 @@ class Layout {
|
||||
element_size_in_bits_ = 0;
|
||||
}
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const Layout& l) {
|
||||
return H::combine(std::move(h), l.format_, l.minor_to_major_,
|
||||
l.max_sparse_elements_, l.tiles_,
|
||||
l.element_size_in_bits_);
|
||||
}
|
||||
|
||||
private:
|
||||
// The format of this layout.
|
||||
Format format_ = INVALID_FORMAT;
|
||||
|
@ -109,6 +109,9 @@ tf_pybind_extension(
|
||||
":xrt",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/time",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@pybind11",
|
||||
@ -132,11 +135,13 @@ tf_pybind_extension(
|
||||
"//tensorflow/compiler/xla/client/lib:svd",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
|
||||
"//tensorflow/compiler/xla/service:name_uniquer",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
# Do NOT remove this dependency. The XLA Python extension must not
|
||||
# depend on any part of TensorFlow at runtime, **including**
|
||||
# libtensorflow_framework.so. The XLA module is deployed self-contained
|
||||
|
@ -20,6 +20,11 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/synchronization/blocking_counter.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
@ -33,6 +38,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
|
||||
namespace xla {
|
||||
namespace xla_python {
|
||||
@ -55,7 +61,8 @@ Status RegisterCpuCustomCallTarget(const std::string& fn_name,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<LocalClient*> GetLocalClient(const std::string& platform_name) {
|
||||
StatusOr<std::unique_ptr<PyLocalClient>> PyLocalClient::Get(
|
||||
const std::string& platform_name) {
|
||||
TF_ASSIGN_OR_RETURN(se::Platform * platform,
|
||||
PlatformUtil::GetPlatform(platform_name));
|
||||
if (platform->VisibleDeviceCount() <= 0) {
|
||||
@ -64,47 +71,158 @@ StatusOr<LocalClient*> GetLocalClient(const std::string& platform_name) {
|
||||
}
|
||||
LocalClientOptions options;
|
||||
options.set_platform(platform);
|
||||
return ClientLibrary::GetOrCreateLocalClient(options);
|
||||
TF_ASSIGN_OR_RETURN(LocalClient * client,
|
||||
ClientLibrary::GetOrCreateLocalClient(options));
|
||||
return absl::make_unique<PyLocalClient>(client);
|
||||
}
|
||||
|
||||
/* static */
|
||||
StatusOr<LocalShapedBuffer> LocalShapedBuffer::FromPython(
|
||||
const py::object& argument, LocalClient* client, int device_ordinal) {
|
||||
VLOG(1) << "Creating shaped buffer from literal on device ordinal: "
|
||||
<< device_ordinal;
|
||||
TF_ASSIGN_OR_RETURN(PythonBufferTree tree, GetPythonBufferTree(argument));
|
||||
PyLocalClient::PyLocalClient(LocalClient* client)
|
||||
: client_(client),
|
||||
h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
|
||||
client->device_count()),
|
||||
execute_pool_(tensorflow::Env::Default(), "py_xla_execute",
|
||||
client->device_count()) {}
|
||||
|
||||
// We are done manipulating Python objects; release the GIL.
|
||||
Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal,
|
||||
int device_ordinal) {
|
||||
py::gil_scoped_release gil_release;
|
||||
DeviceMemoryAllocator* allocator = client->backend().memory_allocator();
|
||||
TransferManager* transfer_manager = client->backend().transfer_manager();
|
||||
return client_->TransferToInfeedLocal(literal, device_ordinal);
|
||||
}
|
||||
|
||||
StatusOr<pybind11::object> PyLocalClient::TransferFromOutfeed(
|
||||
const Shape& shape, int device_ordinal) {
|
||||
Literal literal;
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
literal, client_->TransferFromOutfeedLocal(shape, device_ordinal));
|
||||
}
|
||||
return LiteralToPython(absl::make_unique<Literal>(std::move(literal)));
|
||||
}
|
||||
|
||||
static StatusOr<LocalShapedBuffer> TransferHostToDeviceAsync(
|
||||
const PythonBufferTree& tree, int device_ordinal, PyLocalClient* client,
|
||||
se::Stream* stream) {
|
||||
DeviceMemoryAllocator* allocator =
|
||||
client->client()->backend().memory_allocator();
|
||||
TransferManager* transfer_manager =
|
||||
client->client()->backend().transfer_manager();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape shape, transfer_manager->ChooseCompactLayoutForShape(tree.shape));
|
||||
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
|
||||
transfer_manager->AllocateScopedShapedBuffer(
|
||||
tree.shape, allocator, device_ordinal));
|
||||
TF_ASSIGN_OR_RETURN(auto stream,
|
||||
client->mutable_backend()->BorrowStream(device_ordinal));
|
||||
shape, allocator, device_ordinal));
|
||||
TF_RETURN_IF_ERROR(
|
||||
transfer_manager->WriteTupleIndexTables(stream.get(), buffer));
|
||||
transfer_manager->WriteTupleIndexTablesAsync(stream, buffer));
|
||||
|
||||
auto it = tree.leaves.begin();
|
||||
for (const ShapeUtil::IndexedShape& indexed_shape :
|
||||
ShapeUtil::GetLeafShapes(tree.shape)) {
|
||||
ShapeUtil::GetLeafShapes(shape)) {
|
||||
TF_RET_CHECK(it != tree.leaves.end());
|
||||
ShapedBuffer leaf(
|
||||
indexed_shape.shape,
|
||||
transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
|
||||
client->platform(), device_ordinal);
|
||||
client->client()->platform(), device_ordinal);
|
||||
leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {});
|
||||
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDeviceAsync(
|
||||
stream.get(), *it, leaf));
|
||||
TF_RETURN_IF_ERROR(
|
||||
transfer_manager->TransferLiteralToDeviceAsync(stream, *it, leaf));
|
||||
++it;
|
||||
}
|
||||
stream->BlockHostUntilDone();
|
||||
return LocalShapedBuffer(std::move(buffer), client);
|
||||
}
|
||||
|
||||
/* static */
|
||||
StatusOr<LocalShapedBuffer> LocalShapedBuffer::FromPython(
|
||||
const py::object& argument, PyLocalClient* client, int device_ordinal) {
|
||||
tensorflow::profiler::TraceMe traceme("LocalShapedBuffer::FromPython");
|
||||
TF_ASSIGN_OR_RETURN(PythonBufferTree tree, GetPythonBufferTree(argument));
|
||||
|
||||
// We are done manipulating Python objects; release the GIL.
|
||||
py::gil_scoped_release gil_release;
|
||||
VLOG(1) << "LocalShapedBuffer::FromPython: shape: " << tree.shape.ToString()
|
||||
<< " device ordinal: " << device_ordinal;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
StreamPool::Ptr stream,
|
||||
client->client()->mutable_backend()->BorrowStream(device_ordinal));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
LocalShapedBuffer buffer,
|
||||
TransferHostToDeviceAsync(tree, device_ordinal, client, stream.get()));
|
||||
stream->BlockHostUntilDone();
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/*static */ StatusOr<std::vector<LocalShapedBuffer>>
|
||||
LocalShapedBuffer::FromPythonValues(
|
||||
const std::vector<std::pair<py::object, int>>& arguments,
|
||||
PyLocalClient* client) {
|
||||
tensorflow::profiler::TraceMe traceme("LocalShapedBuffer::FromPythonValues");
|
||||
int num_arguments = static_cast<int>(arguments.size());
|
||||
std::vector<LocalShapedBuffer> outputs(num_arguments);
|
||||
if (num_arguments == 0) {
|
||||
return outputs;
|
||||
}
|
||||
|
||||
struct H2DTransfer {
|
||||
PythonBufferTree tree;
|
||||
StreamPool::Ptr stream;
|
||||
StatusOr<LocalShapedBuffer> buffer;
|
||||
};
|
||||
|
||||
std::vector<H2DTransfer> transfers(num_arguments);
|
||||
for (int i = 0; i < num_arguments; ++i) {
|
||||
TF_ASSIGN_OR_RETURN(transfers[i].tree,
|
||||
GetPythonBufferTree(arguments[i].first));
|
||||
}
|
||||
// We are done manipulating Python objects; release the GIL.
|
||||
py::gil_scoped_release gil_release;
|
||||
|
||||
for (int i = 0; i < num_arguments; ++i) {
|
||||
int device_ordinal = arguments[i].second;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
transfers[i].stream,
|
||||
client->client()->mutable_backend()->BorrowStream(device_ordinal));
|
||||
}
|
||||
|
||||
auto transfer_h2d = [&](int i) -> StatusOr<LocalShapedBuffer> {
|
||||
int device_ordinal = arguments[i].second;
|
||||
return TransferHostToDeviceAsync(transfers[i].tree, device_ordinal, client,
|
||||
transfers[i].stream.get());
|
||||
};
|
||||
|
||||
// We perform the transfers on a thread pool in case XLA needs to do any
|
||||
// host-side preprocessing of the input data.
|
||||
if (num_arguments == 1) {
|
||||
transfers[0].buffer = transfer_h2d(0);
|
||||
} else {
|
||||
absl::BlockingCounter counter(num_arguments - 1);
|
||||
for (int i = 1; i < num_arguments; ++i) {
|
||||
client->h2d_transfer_pool()->Schedule([&, i]() {
|
||||
transfers[i].buffer = transfer_h2d(i);
|
||||
counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
// Perform the first transfer on the main thread.
|
||||
transfers[0].buffer = transfer_h2d(0);
|
||||
counter.Wait();
|
||||
}
|
||||
|
||||
// First, wait for all transfers to complete. We wait for all to complete
|
||||
// since currently we maintain the invariant that the device's view of the
|
||||
// state matches the host's view of the state. Returning early would mean that
|
||||
// we might deallocate device-side memory before a transfer completes, which
|
||||
// violates that invariant.
|
||||
for (int i = 0; i < num_arguments; ++i) {
|
||||
transfers[i].stream->BlockHostUntilDone();
|
||||
}
|
||||
for (int i = 0; i < num_arguments; ++i) {
|
||||
TF_ASSIGN_OR_RETURN(outputs[i], std::move(transfers[i].buffer));
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer,
|
||||
LocalClient* client)
|
||||
PyLocalClient* client)
|
||||
: shaped_buffer_(std::move(shaped_buffer)), client_(client) {}
|
||||
|
||||
const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const {
|
||||
@ -122,16 +240,18 @@ const Shape& LocalShapedBuffer::shape() const {
|
||||
}
|
||||
|
||||
StatusOr<py::object> LocalShapedBuffer::ToPython() const {
|
||||
tensorflow::profiler::TraceMe traceme("LocalShapedBuffer::ToPython");
|
||||
auto literal = absl::make_unique<Literal>();
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(*literal,
|
||||
client_->ShapedBufferToLiteral(*shaped_buffer()));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*literal, client_->client()->ShapedBufferToLiteral(*shaped_buffer()));
|
||||
}
|
||||
return LiteralToPython(std::move(literal));
|
||||
}
|
||||
|
||||
StatusOr<std::vector<LocalShapedBuffer>> LocalShapedBuffer::DestructureTuple() {
|
||||
tensorflow::profiler::TraceMe traceme("LocalShapedBuffer::DestructureTuple");
|
||||
const Shape tuple_shape = shape();
|
||||
|
||||
if (!tuple_shape.IsTuple()) {
|
||||
@ -173,14 +293,14 @@ StatusOr<std::vector<LocalShapedBuffer>> LocalShapedBuffer::DestructureTuple() {
|
||||
return results;
|
||||
}
|
||||
|
||||
LocalExecutableWrapper::LocalExecutableWrapper(
|
||||
PyLocalExecutable::PyLocalExecutable(
|
||||
std::unique_ptr<LocalExecutable> executable,
|
||||
DeviceAssignment device_assignment, LocalClient* client)
|
||||
DeviceAssignment device_assignment, PyLocalClient* client)
|
||||
: executable_(std::move(executable)),
|
||||
device_assignment_(std::move(device_assignment)),
|
||||
client_(client) {}
|
||||
|
||||
std::vector<int> LocalExecutableWrapper::DeviceOrdinals() const {
|
||||
std::vector<int> PyLocalExecutable::DeviceOrdinals() const {
|
||||
int num_replicas = device_assignment_.replica_count();
|
||||
std::vector<int> device_ordinals;
|
||||
device_ordinals.reserve(num_replicas);
|
||||
@ -190,8 +310,9 @@ std::vector<int> LocalExecutableWrapper::DeviceOrdinals() const {
|
||||
return device_ordinals;
|
||||
}
|
||||
|
||||
StatusOr<LocalShapedBuffer> LocalExecutableWrapper::Execute(
|
||||
StatusOr<LocalShapedBuffer> PyLocalExecutable::Execute(
|
||||
absl::Span<LocalShapedBuffer* const> argument_handles) {
|
||||
tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute");
|
||||
if (num_replicas() != 1) {
|
||||
return InvalidArgument(
|
||||
"Attempted to execute computation with %d replicas using Execute()",
|
||||
@ -210,26 +331,23 @@ StatusOr<LocalShapedBuffer> LocalExecutableWrapper::Execute(
|
||||
|
||||
ExecutableRunOptions options;
|
||||
options.set_device_ordinal(device_ordinal);
|
||||
options.set_allocator(client_->backend().memory_allocator());
|
||||
options.set_allocator(client_->client()->backend().memory_allocator());
|
||||
options.set_intra_op_thread_pool(
|
||||
client_->backend().eigen_intra_op_thread_pool_device());
|
||||
client_->client()->backend().eigen_intra_op_thread_pool_device());
|
||||
options.set_device_assignment(&device_assignment_);
|
||||
|
||||
result_buffer_status = executable_->Run(argument_buffers, options);
|
||||
|
||||
if (!result_buffer_status.ok()) {
|
||||
return InternalError(
|
||||
"Failed running replica 0 (other replicas may have failed as well): "
|
||||
"%s.",
|
||||
result_buffer_status.status().ToString());
|
||||
return result_buffer_status.status();
|
||||
}
|
||||
return LocalShapedBuffer(std::move(result_buffer_status).ValueOrDie(),
|
||||
client_);
|
||||
}
|
||||
|
||||
StatusOr<std::vector<LocalShapedBuffer>>
|
||||
LocalExecutableWrapper::ExecutePerReplica(
|
||||
StatusOr<std::vector<LocalShapedBuffer>> PyLocalExecutable::ExecutePerReplica(
|
||||
absl::Span<const std::vector<LocalShapedBuffer*>> argument_handles) {
|
||||
tensorflow::profiler::TraceMe traceme("LocalExecutable::ExecutePerReplica");
|
||||
const int num_devices = client_->device_count();
|
||||
|
||||
if (argument_handles.size() != num_replicas()) {
|
||||
@ -245,8 +363,8 @@ LocalExecutableWrapper::ExecutePerReplica(
|
||||
|
||||
VLOG(1) << "Executing with " << num_replicas() << " replicas.";
|
||||
|
||||
std::vector<StatusOr<ScopedShapedBuffer>> results(num_replicas());
|
||||
auto execute = [this, &argument_handles, &results](int replica) {
|
||||
auto execute =
|
||||
[this, &argument_handles](int replica) -> StatusOr<ScopedShapedBuffer> {
|
||||
const int device_ordinal = device_assignment_(replica, 0);
|
||||
VLOG(3) << "Replica " << replica
|
||||
<< " mapped to device ordinal for execution: " << device_ordinal;
|
||||
@ -259,39 +377,83 @@ LocalExecutableWrapper::ExecutePerReplica(
|
||||
|
||||
ExecutableRunOptions options;
|
||||
options.set_device_ordinal(device_ordinal);
|
||||
options.set_allocator(client_->backend().memory_allocator());
|
||||
options.set_allocator(client_->client()->backend().memory_allocator());
|
||||
options.set_intra_op_thread_pool(
|
||||
client_->backend().eigen_intra_op_thread_pool_device());
|
||||
client_->client()->backend().eigen_intra_op_thread_pool_device());
|
||||
options.set_device_assignment(&device_assignment_);
|
||||
StatusOr<ScopedShapedBuffer> result_buffer_status =
|
||||
executable_->Run(argument_buffers, options);
|
||||
|
||||
results[replica] = std::move(result_buffer_status);
|
||||
VLOG(1) << "Replica " << replica
|
||||
<< " completed; ok=" << result_buffer_status.ok();
|
||||
if (!result_buffer_status.ok()) {
|
||||
LOG(ERROR) << "Execution of replica " << replica
|
||||
<< " failed: " << result_buffer_status.status();
|
||||
}
|
||||
return result_buffer_status;
|
||||
};
|
||||
|
||||
VLOG(1) << "Executing replicated computation; num_replicas="
|
||||
<< num_replicas();
|
||||
std::vector<StatusOr<ScopedShapedBuffer>> results(num_replicas());
|
||||
if (num_replicas() == 1) {
|
||||
// Fast-path if there is only one replica — run the computation on the
|
||||
// current thread.
|
||||
execute(0);
|
||||
results[0] = execute(0);
|
||||
} else {
|
||||
// TODO(phawkins): don't recreate the threadpool for each execution.
|
||||
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun",
|
||||
num_replicas() - 1);
|
||||
absl::Mutex mu;
|
||||
int running GUARDED_BY(mu) = num_replicas();
|
||||
int failed GUARDED_BY(mu) = 0;
|
||||
|
||||
for (int replica = 0; replica < num_replicas() - 1; ++replica) {
|
||||
pool.Schedule([&execute, replica] { execute(replica); });
|
||||
for (int replica = 0; replica < num_replicas(); ++replica) {
|
||||
client_->execute_pool()->Schedule(
|
||||
[&execute, &mu, &running, &failed, &results, replica] {
|
||||
results[replica] = execute(replica);
|
||||
|
||||
absl::MutexLock lock(&mu);
|
||||
--running;
|
||||
if (!results[replica].ok()) {
|
||||
++failed;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
auto done_running_or_failed = [&]() {
|
||||
mu.AssertHeld();
|
||||
return running == 0 || failed > 0;
|
||||
};
|
||||
absl::MutexLock lock(&mu);
|
||||
mu.Await(absl::Condition(&done_running_or_failed));
|
||||
if (failed > 0) {
|
||||
auto done_running = [&]() {
|
||||
mu.AssertHeld();
|
||||
return running == 0;
|
||||
};
|
||||
// If execution does not terminate within a reasonable amount of time, we
|
||||
// may be stuck at a cross-replica barrier on-device. Terminate the
|
||||
// process since that's the only way we can escape this situation at the
|
||||
// moment (b/130629719).
|
||||
if (!mu.AwaitWithTimeout(absl::Condition(&done_running),
|
||||
absl::Seconds(10))) {
|
||||
LOG(FATAL)
|
||||
<< "Replicated computation launch failed, but not all replicas "
|
||||
"terminated. Aborting process to work around deadlock. See the "
|
||||
"error log for details of the failure.";
|
||||
}
|
||||
}
|
||||
execute(num_replicas() - 1);
|
||||
}
|
||||
VLOG(1) << "Replicated execution complete.";
|
||||
|
||||
std::vector<LocalShapedBuffer> wrapped_results(num_replicas());
|
||||
for (int replica = 0; replica < num_replicas(); ++replica) {
|
||||
auto& statusor = results[replica];
|
||||
if (!statusor.ok()) {
|
||||
return InternalError(
|
||||
"Failed running replica %d (other replicas may have failed as well): "
|
||||
"%s.",
|
||||
replica, statusor.status().ToString());
|
||||
return AppendStatus(
|
||||
statusor.status(),
|
||||
absl::StrFormat(
|
||||
"while running replica %d of a replicated computation (other "
|
||||
"replicas may have failed as well).",
|
||||
replica));
|
||||
}
|
||||
wrapped_results[replica] =
|
||||
LocalShapedBuffer(std::move(statusor).ValueOrDie(), client_);
|
||||
@ -334,29 +496,62 @@ StatusOr<std::string> GetComputationHloDotGraph(
|
||||
RenderedGraphFormat::kDot);
|
||||
}
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<LocalExecutableWrapper>>
|
||||
LocalExecutableWrapper::Compile(const XlaComputation& computation,
|
||||
const std::vector<Shape>& argument_shapes,
|
||||
const ExecutableBuildOptions* build_options,
|
||||
LocalClient* client) {
|
||||
std::vector<const Shape*> argument_shape_pointers;
|
||||
argument_shape_pointers.reserve(argument_shapes.size());
|
||||
for (auto& argument_shape : argument_shapes) {
|
||||
argument_shape_pointers.push_back(&argument_shape);
|
||||
/*static*/ StatusOr<std::unique_ptr<PyLocalExecutable>>
|
||||
PyLocalExecutable::Compile(const XlaComputation& computation,
|
||||
std::vector<Shape> argument_layouts,
|
||||
const ExecutableBuildOptions* build_options,
|
||||
PyLocalClient* client) {
|
||||
tensorflow::profiler::TraceMe traceme("LocalExecutable::Compile");
|
||||
std::vector<const Shape*> argument_layout_pointers;
|
||||
argument_layout_pointers.reserve(argument_layouts.size());
|
||||
|
||||
// Assign a default layout to any array subshapes that are missing layouts.
|
||||
auto assign_layouts = [client](Shape* shape) {
|
||||
return ShapeUtil::ForEachMutableSubshapeWithStatus(
|
||||
shape, [&](Shape* subshape, const ShapeIndex&) {
|
||||
if (subshape->IsArray() && !subshape->has_layout()) {
|
||||
LayoutUtil::SetToDefaultLayout(subshape);
|
||||
TF_ASSIGN_OR_RETURN(*subshape,
|
||||
client->client()
|
||||
->backend()
|
||||
.transfer_manager()
|
||||
->ChooseCompactLayoutForShape(*subshape));
|
||||
}
|
||||
return Status::OK();
|
||||
});
|
||||
};
|
||||
|
||||
for (Shape& layout : argument_layouts) {
|
||||
argument_layout_pointers.push_back(&layout);
|
||||
assign_layouts(&layout);
|
||||
}
|
||||
|
||||
ExecutableBuildOptions options;
|
||||
if (build_options != nullptr) {
|
||||
options = *build_options;
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto local_executable,
|
||||
client->Compile(computation, argument_shape_pointers, options));
|
||||
TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
|
||||
client->backend().computation_placer()->AssignDevices(
|
||||
options.num_replicas(), /*computation_count=*/1));
|
||||
|
||||
return absl::make_unique<LocalExecutableWrapper>(
|
||||
Shape result_layout;
|
||||
if (options.result_layout()) {
|
||||
result_layout = *options.result_layout();
|
||||
} else {
|
||||
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
|
||||
computation.GetProgramShape());
|
||||
result_layout = program_shape.result();
|
||||
LayoutUtil::ClearLayout(&result_layout);
|
||||
}
|
||||
assign_layouts(&result_layout);
|
||||
options.set_result_layout(result_layout);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<LocalExecutable> local_executable,
|
||||
client->client()->Compile(
|
||||
computation, argument_layout_pointers, options));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
DeviceAssignment device_assignment,
|
||||
client->client()->backend().computation_placer()->AssignDevices(
|
||||
options.num_replicas(), /*computation_count=*/1));
|
||||
|
||||
return absl::make_unique<PyLocalExecutable>(
|
||||
std::move(local_executable), std::move(device_assignment), client);
|
||||
}
|
||||
|
||||
|