Ambiq squashed commits

parent d391ba441b
commit 9caf68cf7b

.github/ISSUE_TEMPLATE/40-tflite-op-request.md (new file, vendored, 24 lines)

@@ -0,0 +1,24 @@
+---
+name: TensorFlow Lite Op Request
+about: Use this template for reporting ops you are using or missing.
+
+---
+
+
+**System information**
+- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
+- TensorFlow installed from (source or binary):
+- TensorFlow version (or github SHA if from source):
+
+
+**Provide the text output from tflite_convert**
+
+```
+# Copy and paste here
+```
+
+Also, please include a link to a GraphDef or the model if possible.
+
+**Any other info / logs**
+
+Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
@@ -14,7 +14,8 @@ data flow graphs. The graph nodes represent mathematical operations, while
 the graph edges represent the multidimensional data arrays (tensors) that flow
 between them. This flexible architecture enables you to deploy computation to one
 or more CPUs or GPUs in a desktop, server, or mobile device without rewriting
-code. TensorFlow also includes [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard), a data visualization toolkit.
+code. TensorFlow also includes [TensorBoard](https://github.com/tensorflow/tensorboard),
+a data visualization toolkit.
 
 TensorFlow was originally developed by researchers and engineers
 working on the Google Brain team within Google's Machine Intelligence Research

@@ -111,7 +112,7 @@ The TensorFlow project strives to abide by generally accepted best practices in
 Build Type | Status | Artifacts
 ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
 **IBM s390x** | [](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA
-**IBM ppc64le CPU** | [](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA
+**IBM ppc64le CPU** | [](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | TBA
 **IBM ppc64le GPU** Nightly | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
 **IBM ppc64le GPU** Stable Release | [](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)
 **Linux CPU with Intel® MKL-DNN** Nightly | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)

@@ -127,6 +128,7 @@ Build Type
 * [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
 * [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
 * [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
+* [TensorFlow Visualization Toolkit](https://github.com/tensorflow/tensorboard)
 
 Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate.
 
WORKSPACE (22 changed lines)

@@ -1,5 +1,7 @@
 workspace(name = "org_tensorflow")
 
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+
 http_archive(
     name = "io_bazel_rules_closure",
     sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae",

@@ -57,9 +59,9 @@ android_workspace()
 # Please add all new TensorFlow dependencies in workspace.bzl.
 tf_workspace()
 
-new_http_archive(
+http_archive(
     name = "inception_v1",
-    build_file = "models.BUILD",
+    build_file = "//:models.BUILD",
     sha256 = "7efe12a8363f09bc24d7b7a450304a15655a57a7751929b2c1593a71183bb105",
     urls = [
         "http://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip",

@@ -67,9 +69,9 @@ new_http_archive(
     ],
 )
 
-new_http_archive(
+http_archive(
     name = "mobile_ssd",
-    build_file = "models.BUILD",
+    build_file = "//:models.BUILD",
     sha256 = "bddd81ea5c80a97adfac1c9f770e6f55cbafd7cce4d3bbe15fbeb041e6b8f3e8",
     urls = [
         "http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip",

@@ -77,9 +79,9 @@ new_http_archive(
     ],
 )
 
-new_http_archive(
+http_archive(
     name = "mobile_multibox",
-    build_file = "models.BUILD",
+    build_file = "//:models.BUILD",
     sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96",
     urls = [
         "http://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip",

@@ -87,9 +89,9 @@ new_http_archive(
     ],
 )
 
-new_http_archive(
+http_archive(
     name = "stylize",
-    build_file = "models.BUILD",
+    build_file = "//:models.BUILD",
     sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa",
     urls = [
         "http://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip",

@@ -97,9 +99,9 @@ new_http_archive(
     ],
 )
 
-new_http_archive(
+http_archive(
     name = "speech_commands",
-    build_file = "models.BUILD",
+    build_file = "//:models.BUILD",
     sha256 = "c3ec4fea3158eb111f1d932336351edfe8bd515bb6e87aad4f25dbad0a600d0c",
     urls = [
         "http://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",
configure.py (22 changed lines)

@@ -238,6 +238,13 @@ def setup_python(environ_cp):
   write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
   environ_cp['PYTHON_BIN_PATH'] = python_bin_path
 
+  # If choosen python_lib_path is from a path specified in the PYTHONPATH
+  # variable, need to tell bazel to include PYTHONPATH
+  if environ_cp.get('PYTHONPATH'):
+    python_paths = environ_cp.get('PYTHONPATH').split(':')
+    if python_lib_path in python_paths:
+      write_action_env_to_bazelrc('PYTHONPATH', environ_cp.get('PYTHONPATH'))
+
   # Write tools/python_bin_path.sh
   with open(
       os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'),

@@ -445,11 +452,12 @@ def convert_version_to_int(version):
   return int(version_str)
 
 
-def check_bazel_version(min_version):
-  """Check installed bazel version is at least min_version.
+def check_bazel_version(min_version, max_version):
+  """Check installed bazel version is between min_version and max_version.
 
   Args:
     min_version: string for minimum bazel version.
+    max_version: string for maximum bazel version.
 
   Returns:
     The bazel version detected.

@@ -467,6 +475,7 @@ def check_bazel_version(min_version):
 
   min_version_int = convert_version_to_int(min_version)
   curr_version_int = convert_version_to_int(curr_version)
+  max_version_int = convert_version_to_int(max_version)
 
   # Check if current bazel version can be detected properly.
   if not curr_version_int:

@@ -480,6 +489,10 @@ def check_bazel_version(min_version):
     print('Please upgrade your bazel installation to version %s or higher to '
           'build TensorFlow!' % min_version)
     sys.exit(0)
+  if curr_version_int > max_version_int:
+    print('Please downgrade your bazel installation to version %s or lower to '
+          'build TensorFlow!' % max_version)
+    sys.exit(0)
   return curr_version
 
 

@@ -859,7 +872,7 @@ def set_tf_cuda_version(environ_cp):
     cuda_toolkit_paths_full = [
         os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths
     ]
-    if any([os.path.exists(x) for x in cuda_toolkit_paths_full]):
+    if any(os.path.exists(x) for x in cuda_toolkit_paths_full):
       break
 
     # Reset and retry

@@ -1552,7 +1565,7 @@ def main():
   # environment variables.
   environ_cp = dict(os.environ)
 
-  check_bazel_version('0.15.0')
+  check_bazel_version('0.15.0', '0.20.0')
 
   reset_tf_configure_bazelrc()
   # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later

@@ -1694,6 +1707,7 @@ def main():
   config_info_line('nohdfs', 'Disable HDFS support.')
   config_info_line('noignite', 'Disable Apacha Ignite support.')
   config_info_line('nokafka', 'Disable Apache Kafka support.')
+  config_info_line('nonccl', 'Disable NVIDIA NCCL support.')
 
 
 if __name__ == '__main__':
@@ -43,6 +43,11 @@ TENSORFLOW_API_INIT_FILES_V2 = (
     TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1)
 )
 
+# @unused
+TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT = (
+    TENSORFLOW_API_INIT_FILES_V1 + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1)
+)
+
 # Config setting used when building for products
 # which requires restricted licenses to be avoided.
 config_setting(

@@ -213,31 +218,37 @@ config_setting(
 #
 config_setting(
     name = "no_aws_support",
-    define_values = {"no_aws_support": "false"},
+    define_values = {"no_aws_support": "true"},
     visibility = ["//visibility:public"],
 )
 
 config_setting(
     name = "no_gcp_support",
-    define_values = {"no_gcp_support": "false"},
+    define_values = {"no_gcp_support": "true"},
     visibility = ["//visibility:public"],
 )
 
 config_setting(
     name = "no_hdfs_support",
-    define_values = {"no_hdfs_support": "false"},
+    define_values = {"no_hdfs_support": "true"},
     visibility = ["//visibility:public"],
 )
 
 config_setting(
     name = "no_ignite_support",
-    define_values = {"no_ignite_support": "false"},
+    define_values = {"no_ignite_support": "true"},
     visibility = ["//visibility:public"],
 )
 
 config_setting(
     name = "no_kafka_support",
-    define_values = {"no_kafka_support": "false"},
+    define_values = {"no_kafka_support": "true"},
+    visibility = ["//visibility:public"],
+)
+
+config_setting(
+    name = "no_nccl_support",
+    define_values = {"no_nccl_support": "true"},
     visibility = ["//visibility:public"],
 )
 

@@ -350,7 +361,7 @@ package_group(
         "-//third_party/tensorflow/python/estimator",
         "//learning/meta_rank/...",
         "//tensorflow/...",
-        "//tensorflow_estimator/...",
+        "//tensorflow_estimator/contrib/...",
         "//tensorflow_fold/llgtm/...",
         "//tensorflow_text/...",
         "//third_party/py/tensor2tensor/...",

@@ -554,18 +565,24 @@ genrule(
     }),
     outs = ["__init__.py"],
     cmd = select({
-        "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)",
-        "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)",
+        "api_version_2": "cp $(@D)/_api/v2/v2.py $(OUTS)",
+        "//conditions:default": "cp $(@D)/_api/v1/v1.py $(OUTS)",
     }),
 )
 
 gen_api_init_files(
     name = "tf_python_api_gen_v1",
-    srcs = ["api_template_v1.__init__.py"],
+    srcs = [
+        "api_template_v1.__init__.py",
+        "compat_template_v1.__init__.py",
+    ],
     api_version = 1,
+    compat_api_versions = [1],
+    compat_init_templates = ["compat_template_v1.__init__.py"],
     output_dir = "_api/v1/",
-    output_files = TENSORFLOW_API_INIT_FILES_V1,
+    output_files = TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT,
     output_package = "tensorflow._api.v1",
+    root_file_name = "v1.py",
     root_init_template = "api_template_v1.__init__.py",
 )
 

@@ -581,6 +598,7 @@ gen_api_init_files(
     output_dir = "_api/v2/",
     output_files = TENSORFLOW_API_INIT_FILES_V2,
     output_package = "tensorflow._api.v2",
+    root_file_name = "v2.py",
     root_init_template = "api_template.__init__.py",
 )
 
@@ -21,8 +21,6 @@ from __future__ import print_function as _print_function
 import os as _os
 
 # pylint: disable=g-bad-import-order
-from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
-
 from tensorflow.python.tools import component_api_helper as _component_api_helper
 _component_api_helper.package_hook(
     parent_package_str=__name__,

@@ -30,16 +28,16 @@ _component_api_helper.package_hook(
 
 # API IMPORTS PLACEHOLDER
 
-from tensorflow.python.platform import flags  # pylint: disable=g-import-not-at-top
-
 # Make sure directory containing top level submodules is in
 # the __path__ so that "from tensorflow.foo import bar" works.
-_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__))  # pylint: disable=undefined-variable
+# We're using bitwise, but there's nothing special about that.
+_tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__))  # pylint: disable=undefined-variable
 if _tf_api_dir not in __path__:
   __path__.append(_tf_api_dir)
 
-# Calls to enable and disable features.
-enable_eager_execution()  # pylint: disable=undefined-variable
+# Enable TF2 behaviors
+from tensorflow.python.compat import compat as _compat  # pylint: disable=g-import-not-at-top
+_compat.enable_v2_behavior()
 
 # These symbols appear because we import the python package which
 # in turn imports from tensorflow.core and tensorflow.python. They
@@ -60,6 +60,7 @@ tf_cuda_library(
             "//tensorflow/core:framework",
             "//tensorflow/core:lib",
             "//tensorflow/core:op_gen_lib",
+            "//tensorflow/core/distributed_runtime:server_lib",
         ],
     }),
 )

@@ -120,7 +121,8 @@ tf_cuda_library(
         ":c_api",
         ":c_api_internal",
         "//tensorflow/c/eager:c_api",
-        "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
+        "//tensorflow/c/eager:c_api_internal",
+        "//tensorflow/compiler/jit:flags",
         "//tensorflow/contrib/tpu:all_ops",
         "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",

@@ -173,6 +175,30 @@ tf_cuda_library(
     ],
 )
 
+tf_cuda_library(
+    name = "kernels",
+    srcs = [
+        "kernels.cc",
+    ],
+    hdrs = [
+        "kernels.h",
+    ],
+    copts = tf_copts(),
+    visibility = ["//visibility:public"],
+    deps = select({
+        "//tensorflow:android": [
+            ":c_api",
+            ":c_api_internal",
+            "//tensorflow/core:android_tensorflow_lib_lite",
+        ],
+        "//conditions:default": [
+            ":c_api",
+            ":c_api_internal",
+            "//tensorflow/core:framework",
+        ],
+    }),
+)
+
 # -----------------------------------------------------------------------------
 # Tests
 

@@ -208,7 +234,10 @@ tf_cuda_cc_test(
         "//tensorflow:darwin": ["-headerpad_max_install_names"],
         "//conditions:default": [],
     }),
-    tags = ["noasan"],
+    tags = [
+        "no_oss",  # http://b/119522529
+        "noasan",
+    ],
     # We must ensure that the dependencies can be dynamically linked since
     # the shared library must be able to use core:framework.
     # linkstatic = tf_kernel_tests_linkstatic(),

@@ -237,7 +266,7 @@ tf_cuda_cc_test(
 
 tf_cc_test(
     name = "c_api_experimental_test",
-    size = "small",
+    size = "medium",
     srcs = ["c_api_experimental_test.cc"],
     data = ["testdata/tf_record"],
     linkopts = select({

@@ -248,8 +277,11 @@ tf_cc_test(
     # the shared library must be able to use core:framework.
     # linkstatic = tf_kernel_tests_linkstatic(),
     deps = [
+        ":c_api",
         ":c_api_experimental",
         ":c_test_util",
+        "//tensorflow/c/eager:c_api",
+        "//tensorflow/c/eager:c_api_test_util",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",

@@ -300,6 +332,30 @@ tf_kernel_library(
     alwayslink = 1,
 )
 
+tf_cuda_cc_test(
+    name = "kernels_test",
+    size = "small",
+    srcs = ["kernels_test.cc"],
+    linkopts = select({
+        "//tensorflow:darwin": ["-headerpad_max_install_names"],
+        "//conditions:default": [],
+    }),
+    tags = ["noasan"],
+    # We must ensure that the dependencies can be dynamically linked since
+    # the shared library must be able to use core:framework.
+    # linkstatic = tf_kernel_tests_linkstatic(),
+    deps = [
+        ":c_api",
+        ":kernels",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:proto_text",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 # -----------------------------------------------------------------------------
 # Python API target
 
@@ -15,13 +15,18 @@ limitations under the License.
 
 #include "tensorflow/c/c_api_experimental.h"
 
+#include "tensorflow/c/c_api.h"
 #include "tensorflow/c/c_api_internal.h"
-#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_internal.h"
+#include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/net.h"
 #include "tensorflow/core/platform/platform.h"
 #include "tensorflow/core/protobuf/config.pb.h"
 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"

@@ -51,8 +56,8 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
     // These XLA flags are needed to trigger XLA properly from C (more generally
    // non-Python) clients. If this API is called again with `enable` set to
     // false, it is safe to keep these flag values as is.
-    tensorflow::legacy_flags::MarkForCompilationPassFlags* flags =
-        tensorflow::legacy_flags::GetMarkForCompilationPassFlags();
+    tensorflow::MarkForCompilationPassFlags* flags =
+        tensorflow::GetMarkForCompilationPassFlags();
     flags->tf_xla_cpu_global_jit = true;
     flags->tf_xla_min_cluster_size = 1;
   } else {

@@ -71,8 +76,8 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
     // These XLA flags are needed to trigger XLA properly from C (more generally
     // non-Python) clients. If this API is called again with `enable` set to
     // false, it is safe to keep these flag values as is.
-    tensorflow::legacy_flags::MarkForCompilationPassFlags* flags =
-        tensorflow::legacy_flags::GetMarkForCompilationPassFlags();
+    tensorflow::MarkForCompilationPassFlags* flags =
+        tensorflow::GetMarkForCompilationPassFlags();
     flags->tf_xla_cpu_global_jit = true;
     flags->tf_xla_min_cluster_size = 1;
   } else {

@@ -6525,7 +6530,7 @@ library {
       }
     }
     node_def {
-      name: "ParallelInterleaveDataset/cycle_length"
+      name: "ExperimentalParallelInterleaveDataset/cycle_length"
      op: "Const"
       attr {
         key: "dtype"

@@ -6546,7 +6551,7 @@ library {
       }
     }
     node_def {
-      name: "ParallelInterleaveDataset/block_length"
+      name: "ExperimentalParallelInterleaveDataset/block_length"
       op: "Const"
       attr {
         key: "dtype"

@@ -6567,7 +6572,7 @@ library {
       }
     }
     node_def {
-      name: "ParallelInterleaveDataset/sloppy"
+      name: "ExperimentalParallelInterleaveDataset/sloppy"
       op: "Const"
       attr {
         key: "dtype"

@@ -6588,7 +6593,7 @@ library {
       }
     }
     node_def {
-      name: "ParallelInterleaveDataset/buffer_output_elements"
+      name: "ExperimentalParallelInterleaveDataset/buffer_output_elements"
       op: "Const"
       attr {
         key: "dtype"

@@ -6609,7 +6614,7 @@ library {
       }
     }
     node_def {
-      name: "ParallelInterleaveDataset/prefetch_input_elements"
+      name: "ExperimentalParallelInterleaveDataset/prefetch_input_elements"
       op: "Const"
       attr {
         key: "dtype"

@@ -6630,14 +6635,14 @@ library {
       }
     }
     node_def {
-      name: "ParallelInterleaveDataset"
-      op: "ParallelInterleaveDataset"
+      name: "ExperimentalParallelInterleaveDataset"
+      op: "ExperimentalParallelInterleaveDataset"
       input: "RepeatDataset:handle:0"
-      input: "ParallelInterleaveDataset/cycle_length:output:0"
-      input: "ParallelInterleaveDataset/block_length:output:0"
-      input: "ParallelInterleaveDataset/sloppy:output:0"
-      input: "ParallelInterleaveDataset/buffer_output_elements:output:0"
-      input: "ParallelInterleaveDataset/prefetch_input_elements:output:0"
+      input: "ExperimentalParallelInterleaveDataset/cycle_length:output:0"
+      input: "ExperimentalParallelInterleaveDataset/block_length:output:0"
+      input: "ExperimentalParallelInterleaveDataset/sloppy:output:0"
+      input: "ExperimentalParallelInterleaveDataset/buffer_output_elements:output:0"
+      input: "ExperimentalParallelInterleaveDataset/prefetch_input_elements:output:0"
       attr {
         key: "Targuments"
         value {

@@ -6737,7 +6742,7 @@ library {
     node_def {
       name: "ShuffleDataset_2"
       op: "ShuffleDataset"
-      input: "ParallelInterleaveDataset:handle:0"
+      input: "ExperimentalParallelInterleaveDataset:handle:0"
       input: "ShuffleDataset_2/buffer_size_1:output:0"
       input: "ShuffleDataset_2/seed_2:output:0"
       input: "ShuffleDataset_2/seed2_2:output:0"

@@ -8739,14 +8744,65 @@ void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
   TF_DeleteStatus(status);
 }
 
-TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
-                                                      const char* errMsg) {
+struct TFE_ExecuteOpNotification {
+  TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {}
+  tensorflow::Notification n;
+  std::unique_ptr<tensorflow::Thread> thread;
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status;
+};
+
+TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
+                                                    TFE_TensorHandle** retvals,
+                                                    int* num_retvals,
+                                                    TF_Status* status) {
+  TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;
+
+  n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread(
+      tensorflow::ThreadOptions(), "ExecuteOpThread",
+      [op, retvals, num_retvals, n]() {
+        TFE_Execute(op, retvals, num_retvals, n->status.get());
+        n->n.Notify();
+      }));
+
+  return n;
+}
+
+void TFE_ExecuteOpNotificationWaitAndDelete(
+    TFE_ExecuteOpNotification* notification, TF_Status* status) {
+  if (notification == nullptr) {
+    status->status = tensorflow::errors::InvalidArgument(
+        "Passed in notification is a nullptr.");
+
+    return;
+  }
+  if (notification->thread == nullptr) {
+    status->status = tensorflow::errors::InvalidArgument(
+        "Passed in notification didn't start a thread correctly. Cleaning up "
+        "this notification. Please re-execute the operation to get a new "
+        "notification.");
+
+    delete notification;
+    return;
+  }
+
+  notification->n.WaitForNotification();
+
+  status->status = notification->status->status;
+
+  delete notification;
+}
+
+void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
   status->status = tensorflow::errors::Internal(errMsg);
 }
 
 // This builder is used in the eager API to build a NodeDef.
 struct TF_AttrBuilder : public tensorflow::AttrBuilder {
   using tensorflow::AttrBuilder::AttrBuilder;
+  // The string buffers to make sure that any `attr_name` we pass into
+  // `builder->Set()` will outlive the subsequent
+  // `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on the same `builder`.
+  std::set<std::string> attr_names;
 };
 
 TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) {

@@ -8757,13 +8813,15 @@ void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; }
 
 void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
                            TF_DataType value) {
-  builder->Set(attr_name, static_cast<tensorflow::DataType>(value));
+  auto iter = builder->attr_names.insert(attr_name).first;
+  builder->Set((*iter).c_str(), static_cast<tensorflow::DataType>(value));
 }
 
 void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
                                const TF_DataType* values, int num_values) {
+  auto iter = builder->attr_names.insert(attr_name).first;
   builder->Set(
-      attr_name,
+      (*iter).c_str(),
       tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
           reinterpret_cast<const tensorflow::DataType*>(values), num_values));
 }
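The `attr_names` set added to `TF_AttrBuilder` above addresses a lifetime issue: `AttrBuilder::Set` keeps a reference to the attribute name it is given, so the `char*` the caller passes must stay alive at least as long as the builder. Below is a minimal sketch of the same pattern, deliberately independent of TensorFlow; `AttrSink` and the function names are hypothetical stand-ins, not part of the API in this diff.

```cpp
#include <set>
#include <string>
#include <string_view>
#include <vector>

// Hypothetical sink that, like tensorflow::AttrBuilder, keeps only a
// non-owning view of each attribute name it is handed.
struct AttrSink {
  std::vector<std::string_view> names;  // views into caller-owned storage
  void Set(const char* name) { names.emplace_back(name); }
};

struct Builder {
  AttrSink sink;
  // Owns a copy of every name ever passed in, so the views handed to the
  // sink stay valid for the lifetime of the Builder. This mirrors the
  // TF_AttrBuilder::attr_names member introduced in the hunk above.
  std::set<std::string> attr_names;

  void SetAttr(const char* attr_name) {
    auto iter = attr_names.insert(attr_name).first;  // stable storage
    sink.Set(iter->c_str());                         // view remains valid
  }
};
```

std::set iterators and element addresses are stable across later insertions, which is why the c_str() handed to the sink cannot dangle while the builder exists.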
@@ -8800,3 +8858,31 @@ const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
   // The returned string is owned by OpRegistry, so liveness is not a concern.
   return input_arg.number_attr().c_str();
 }
+
+int TF_OpIsStateful(const char* op_type, TF_Status* status) {
+  const tensorflow::OpRegistrationData* op_reg_data;
+  status->status =
+      tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data);
+  if (!status->status.ok()) {
+    return 0;
+  }
+  return op_reg_data->op_def.is_stateful();
+}
+
+void TF_InitMain(const char* usage, int* argc, char*** argv) {
+  tensorflow::port::InitMain(usage, argc, argv);
+}
+
+int TF_PickUnusedPortOrDie() {
+  return tensorflow::internal::PickUnusedPortOrDie();
+}
+
+TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType dtype_arg,
+                                                void* data, size_t len) {
+  auto dtype = static_cast<tensorflow::DataType>(dtype_arg);
+  DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));
+
+  tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
+  std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
+  return new TFE_TensorHandle(tensor, nullptr, nullptr);
+}
@@ -180,6 +180,25 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
 TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString(
     TFE_TensorHandle* handle);
 
+typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification;
+
+// Allows invoking a kernel asynchronously, and explicitly returns a
+// notification that can be waited upon. This always executes the kernel in a
+// new thread.
+// 1. `retvals` and `num_retvals` can only be consumed after
+// `TFE_ExecuteOp` returns successfully. They shouldn't be used
+// if the return is unsuccessful
+// 2. These new APIs cannot be used together with the TFE context level async
+// support.
+TF_CAPI_EXPORT extern TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(
+    TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
+    TF_Status* status);
+
+// Waits to complete the op execution, and cleans up the notification.
+// Errors reported by op execution are set in `status`.
+TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete(
+    TFE_ExecuteOpNotification* notification, TF_Status* status);
+
 TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
                                                       const char* errMsg);
 
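The comments above describe the intended protocol: start the op, then wait on the returned notification before touching the result handles, and do not mix these calls with the context-level async support. The following is a hedged usage sketch only; it assumes `op_to_run` is an already built and configured `TFE_Op` (the tests added later in this commit show complete setups), and the helper name is hypothetical.

```cpp
// Sketch, not part of the API: run one op on a fresh thread and wait for it.
// Assumes the caller checks TF_GetCode(status) after each call.
void RunOpAsynchronously(TFE_Op* op_to_run, TF_Status* status) {
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;

  // Kick off execution; this returns immediately.
  TFE_ExecuteOpNotification* note =
      TFE_ExecuteOpInNewThread(op_to_run, &retvals[0], &num_retvals, status);

  // ... other work can overlap with the op here ...

  // Block until the op finishes; any execution error lands in `status`.
  // Only after this succeeds may `retvals` be read, per the comments above.
  TFE_ExecuteOpNotificationWaitAndDelete(note, status);

  if (TF_GetCode(status) == TF_OK && retvals[0] != nullptr) {
    TFE_DeleteTensorHandle(retvals[0]);
  }
}
```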
@@ -209,6 +228,24 @@ TF_CAPI_EXPORT extern void TF_AttrBuilderCheckCanRunOnDevice(
 TF_CAPI_EXPORT extern const char* TF_GetNumberAttrForOpListInput(
     const char* op_name, int input_index, TF_Status* status);
 
+// Returns 1 if the op is stateful, 0 otherwise. The return value is undefined
+// if the status is not ok.
+TF_CAPI_EXPORT extern int TF_OpIsStateful(const char* op_type,
+                                          TF_Status* status);
+
+// Platform specific initialization routine. Very few platforms actually require
+// this to be called.
+TF_CAPI_EXPORT void TF_InitMain(const char* usage, int* argc, char*** argv);
+
+// Platform-specific implementation to return an unused port. (This should used
+// in tests only.)
+TF_CAPI_EXPORT int TF_PickUnusedPortOrDie();
+
+// Fast path method that makes constructing a single scalar tensor require less
+// overhead and copies.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromScalar(
+    TF_DataType dtype, void* scalar, size_t len);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
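A short, hedged sketch of how the new utility declarations above might be exercised: building a rank-0 float handle through the fast-path constructor and querying statefulness from the op registry. The function name is hypothetical; the expected 1/0 results match the `IsStateful` test added further down in this commit, and error handling is abbreviated.

```cpp
#include <cstring>

// Sketch only: uses TFE_NewTensorHandleFromScalar and TF_OpIsStateful as
// declared in the hunk above.
void ScalarAndStatefulnessSketch() {
  TF_Status* status = TF_NewStatus();

  // A scalar float tensor holding 42.0f; `len` must match the dtype size.
  float value = 42.0f;
  TFE_TensorHandle* scalar =
      TFE_NewTensorHandleFromScalar(TF_FLOAT, &value, sizeof(value));

  // Stateful vs. stateless ops, as recorded in the op registry.
  int assign_stateful = TF_OpIsStateful("AssignAddVariableOp", status);  // 1
  int identity_stateful = TF_OpIsStateful("Identity", status);           // 0
  (void)assign_stateful;
  (void)identity_stateful;

  TFE_DeleteTensorHandle(scalar);
  TF_DeleteStatus(status);
}
```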
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/c/c_api_experimental.h"
 #include "tensorflow/c/c_test_util.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_test_util.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"

@@ -162,5 +164,137 @@ protocol: "grpc"
   TF_DeleteStatus(status);
 }
 
+TEST(CAPI_EXPERIMENTAL, IsStateful) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  int assign = TF_OpIsStateful("AssignAddVariableOp", status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  EXPECT_EQ(assign, 1);
+  int id = TF_OpIsStateful("Identity", status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  EXPECT_EQ(id, 0);
+}
+
+TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Simple) {
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_TensorHandle* m = TestMatrixTensorHandle();
+
+  TFE_Op* matmul_op = MatMulOp(ctx, m, m);
+
+  TFE_TensorHandle* retvals[1] = {nullptr};
+  int num_retvals = 1;
+
+  auto* r =
+      TFE_ExecuteOpInNewThread(matmul_op, &retvals[0], &num_retvals, status);
+
+  TFE_ExecuteOpNotificationWaitAndDelete(r, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  float product[4] = {0};
+  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+  TF_DeleteTensor(t);
+  EXPECT_EQ(7, product[0]);
+  EXPECT_EQ(10, product[1]);
+  EXPECT_EQ(15, product[2]);
+  EXPECT_EQ(22, product[3]);
+
+  TFE_DeleteOp(matmul_op);
+  TFE_DeleteTensorHandle(m);
+
+  TFE_DeleteTensorHandle(retvals[0]);
+  TFE_DeleteContext(ctx);
+  TF_DeleteStatus(status);
+}
+
+// Perform a send/recv test. Recv blocks, so they need to be executed
+// asynchronously.
+TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) {
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  // Returns a 2x2 float32 Tensor on the CPU, with data 1., 2., 3., 4.
+  TFE_TensorHandle* m = TestMatrixTensorHandle();
+
+  // Build a send op.
+  TFE_Op* send_op = TFE_NewOp(ctx, "_Send", status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpAddInput(send_op, m, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  string tensor_name = "Tensor";
+  TFE_OpSetAttrType(send_op, "T", TF_FLOAT);
+  TFE_OpSetAttrString(send_op, "tensor_name", tensor_name.c_str(),
+                      tensor_name.size());
+  string send_device = "/job:localhost/replica:0/task:0/device:CPU:0";
+  TFE_OpSetAttrString(send_op, "send_device", send_device.c_str(),
+                      send_device.size());
+  TFE_OpSetAttrInt(send_op, "send_device_incarnation", 1234);
+  string recv_device = "/job:localhost/replica:0/task:0/device:CPU:0";
+  TFE_OpSetAttrString(send_op, "recv_device", recv_device.c_str(),
+                      recv_device.size());
+  TFE_OpSetAttrBool(send_op, "client_terminated", true);
+
+  // Build a recv op.
+  TFE_Op* recv_op = TFE_NewOp(ctx, "_Recv", status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  TFE_OpSetAttrType(recv_op, "tensor_type", TF_FLOAT);
+  TFE_OpSetAttrString(recv_op, "tensor_name", tensor_name.c_str(),
+                      tensor_name.size());
+  TFE_OpSetAttrString(recv_op, "send_device", send_device.c_str(),
+                      send_device.size());
+  TFE_OpSetAttrInt(recv_op, "send_device_incarnation", 1234);
+  TFE_OpSetAttrString(recv_op, "recv_device", recv_device.c_str(),
+                      recv_device.size());
+  TFE_OpSetAttrBool(recv_op, "client_terminated", true);
+
+  TFE_TensorHandle* send_retvals;
+  int send_num_retvals = 0;
+  auto* send_result = TFE_ExecuteOpInNewThread(send_op, &send_retvals,
+                                               &send_num_retvals, status);
+
+  TFE_TensorHandle* recv_retvals[1] = {nullptr};
+  int recv_num_retvals = 1;
+  auto* recv_result = TFE_ExecuteOpInNewThread(recv_op, &recv_retvals[0],
+                                               &recv_num_retvals, status);
+
+  TFE_ExecuteOpNotificationWaitAndDelete(send_result, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_ExecuteOpNotificationWaitAndDelete(recv_result, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  TF_Tensor* t = TFE_TensorHandleResolve(recv_retvals[0], status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  float product[4] = {0};
+  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+  TF_DeleteTensor(t);
+  EXPECT_EQ(1, product[0]);
+  EXPECT_EQ(2, product[1]);
+  EXPECT_EQ(3, product[2]);
+  EXPECT_EQ(4, product[3]);
+
+  TFE_DeleteOp(send_op);
+  TFE_DeleteOp(recv_op);
+  TFE_DeleteTensorHandle(m);
+
+  TFE_DeleteTensorHandle(recv_retvals[0]);
+  TFE_DeleteContext(ctx);
+  TF_DeleteStatus(status);
+}
+
 }  // namespace
 }  // namespace tensorflow
@@ -392,26 +392,26 @@ Status ProcessInputs(
     EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
   input_tensors->reserve(ninputs);
   for (int i = 0; i < ninputs; ++i) {
-    const Node& node = inputs[i].oper->node;
+    Node* node = &inputs[i].oper->node;
     int idx = inputs[i].index;
 
     TF_RETURN_WITH_CONTEXT_IF_ERROR(
-        fn_body->graph.IsValidOutputTensor(&node, idx),
+        fn_body->graph.IsValidOutputTensor(node, idx),
         "Encountered while processing input ", i, " into function '", fn_name,
         "'");
-    TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx),
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
                                     "Encountered while processing input ", i,
                                     " into function '", fn_name, "'");
 
-    input_tensors->emplace_back(&node, idx);
+    input_tensors->emplace_back(node, idx);
 
-    const auto& iter = input_nodes->find(&node);
+    const auto& iter = input_nodes->find(node);
     if (iter == input_nodes->end()) {
-      input_nodes->insert({&node, {idx}});
+      input_nodes->insert({node, {idx}});
     } else {
       auto& indices = iter->second;
       if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
-        return InvalidArgument("TF_Output ", node.name(), ":", idx,
+        return InvalidArgument("TF_Output ", node->name(), ":", idx,
                                " appears more than once in the input list");
       }
       indices.push_back(idx);

@@ -428,16 +428,16 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
     EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
   output_tensors->reserve(noutputs);
   for (int i = 0; i < noutputs; ++i) {
-    const Node& node = outputs[i].oper->node;
+    Node* node = &outputs[i].oper->node;
     int idx = outputs[i].index;
     TF_RETURN_WITH_CONTEXT_IF_ERROR(
-        fn_body->graph.IsValidOutputTensor(&node, idx),
+        fn_body->graph.IsValidOutputTensor(node, idx),
         "Encountered while processing output ", i, " from function '", fn_name,
         "'");
-    TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx),
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
                                     "Encountered while creating function '",
                                     fn_name, "'");
-    output_tensors->emplace_back(&node, idx);
+    output_tensors->emplace_back(node, idx);
   }
   return Status::OK();
 }
@@ -50,6 +50,7 @@ tf_cuda_library(
         ],
         "//conditions:default": [],
     }) + [
+        "@com_google_absl//absl/memory",
         "//tensorflow/core/common_runtime/eager:eager_operation",
         "//tensorflow/core/distributed_runtime/eager:eager_client",
         "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",

@@ -143,6 +144,7 @@ tf_cuda_cc_test(
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -21,9 +21,11 @@ limitations under the License.
 #include <string>
 #include <vector>
 
+#include "absl/memory/memory.h"
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/eager/c_api_internal.h"
+#include "tensorflow/core/platform/host_info.h"
 #ifdef TENSORFLOW_EAGER_USE_XLA
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #endif  // TENSORFLOW_EAGER_USE_XLA

@@ -79,7 +81,7 @@ tensorflow::Status GetAllRemoteDevices(
     const std::vector<string>& remote_workers,
     tensorflow::WorkerCacheInterface* worker_cache,
     std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
-  std::vector<tensorflow::Device*> remote_devices;
+  std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
   tensorflow::Status status;
   // TODO(nareshmodi) do this in parallel instead of serially.
   for (const string& remote_worker : remote_workers) {

@@ -92,7 +94,7 @@ tensorflow::Status GetAllRemoteDevices(
           status = s;
           if (s.ok()) {
            for (tensorflow::Device* d : *devices) {
-              remote_devices.push_back(d);
+              remote_devices.emplace_back(d);
             }
           }
           n.Notify();

@@ -100,7 +102,7 @@ tensorflow::Status GetAllRemoteDevices(
     n.WaitForNotification();
   }
   std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
-      new tensorflow::DeviceMgr(remote_devices));
+      new tensorflow::DeviceMgr(std::move(remote_devices)));
 
   TF_RETURN_IF_ERROR(status);
 

@@ -261,13 +263,13 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
 void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
 
 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
-  std::vector<tensorflow::Device*> devices;
+  std::vector<std::unique_ptr<tensorflow::Device>> devices;
   status->status = tensorflow::DeviceFactory::AddDevices(
       opts->session_options.options, "/job:localhost/replica:0/task:0",
       &devices);
   if (!status->status.ok()) return nullptr;
   std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
-      new tensorflow::DeviceMgr(devices));
+      new tensorflow::DeviceMgr(std::move(devices)));
 
   tensorflow::Rendezvous* r =
       new tensorflow::IntraProcessRendezvous(device_mgr.get());
@ -409,6 +411,18 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
                         : d->name().c_str();
 }
 
+const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
+                                              TF_Status* status) {
+  if (h == nullptr || h->handle == nullptr) {
+    status->status = tensorflow::errors::InvalidArgument(
+        "The passed in handle is a nullptr");
+    return nullptr;
+  }
+  tensorflow::Device* d = h->handle->device();
+  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
+                        : d->name().c_str();
+}
+
 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
     TFE_TensorHandle* h, TF_Status* status) {
   if (h == nullptr || h->handle == nullptr) {
@ -458,13 +472,20 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                   TF_Status* status) {
   const char* name = op_or_function_name;  // Shorthand
   const tensorflow::AttrTypeMap* types;
-  status->status = tensorflow::AttrTypeMapForOp(name, &types);
-  if (status->status.ok()) return new TFE_Op(ctx, name, types);
-  if (TF_GetCode(status) == TF_NOT_FOUND) {
-    if (ctx->context.FindFunctionByName(name)) {
-      status->status = tensorflow::Status::OK();
-      return new TFE_Op(ctx, name, nullptr);
+  bool is_function = false;
+  status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
+  if (status->status.ok()) {
+    if (is_function && !ctx->context.FindFunctionByName(name)) {
+      status->status = tensorflow::errors::NotFound(
+          "'", name,
+          "' is neither a type of a primitive operation nor a name "
+          "of a function registered in binary running on ",
+          tensorflow::port::Hostname(),
+          ". Make sure the operation or function is "
+          "registered in the binary running in this process.");
+      return nullptr;
     }
+    return new TFE_Op(ctx, name, is_function, types);
   }
   return nullptr;
 }
@ -497,12 +518,6 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                               unsigned char* is_list, TF_Status* status) {
   TF_AttrType ret;
-  if (op->operation.is_function()) {
-    status->status = tensorflow::errors::Unimplemented(
-        "TODO(apassos): Support for attributes for TensorFlow functions is not "
-        "ready yet.");
-    return TF_ATTR_INT;  // The compiler requires that we return something.
-  }
   status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
                                               attr_name, &ret, is_list);
   return ret;
@ -169,10 +169,33 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h,
 TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
                                                   int dim_index,
                                                   TF_Status* status);
 
+// Returns the device of the operation that produced `h`.
+// If `h` was produced by a copy, returns the destination device of
+// the copy. Note that returned device name is not always the device
+// holding the tensor handle's memory. If you want the latter, use
+// TFE_TensorHandleBackingDeviceName.
+// This function will block till the operation that produces `h` has completed.
+//
+// Device on which the kernel of the operation that produced `h` ran.
+//
+// If `h` was produced by a copy, returns the destination device of
+// the copy.
+//
+// Note that returned device name is not always the device that owns the memory
+// that backs the tensor handle. For the latter see
+// TFE_TensorHandleBackingDeviceName.
+//
 // This function will block till the operation that produces `h` has completed.
 TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
     TFE_TensorHandle* h, TF_Status* status);
+
+// Returns the name of the device in whose memory `h` resides.
+//
+// This function will block till the operation that produces `h` has completed.
+TF_CAPI_EXPORT extern const char* TFE_TensorHandleBackingDeviceName(
+    TFE_TensorHandle* h, TF_Status* status);
+
 // Return a pointer to a new TFE_TensorHandle that shares the underlying tensor
 // with `h`. On success, `status` is set to OK. On failure, `status` reflects
 // the error and a nullptr is returned.
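For orientation, a minimal usage sketch (editorial, not part of this commit; `h` is assumed to be a valid TFE_TensorHandle and `status` a valid TF_Status, with error checks elided):

// Device that ran the op which produced `h` (may be a GPU device).
const char* op_device = TFE_TensorHandleDeviceName(h, status);
// Device that actually owns the memory backing `h` (often CPU, e.g. for a
// small shape output even when the op itself ran on GPU).
const char* mem_device = TFE_TensorHandleBackingDeviceName(h, status);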
@ -93,10 +93,9 @@ struct TFE_TensorDebugInfo {
 };
 
 struct TFE_Op {
-  // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a
-  // primitive operation.
-  TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
-      : operation(&ctx->context, op, t) {}
+  TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
+         const tensorflow::AttrTypeMap* t)
+      : operation(&ctx->context, op, is_function, t) {}
 
   tensorflow::EagerOperation operation;
 };
@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api.h"
 
 #include <string.h>
+#include "absl/strings/match.h"
 #include "tensorflow/c/eager/c_api_test_util.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
 #include "tensorflow/core/framework/function.pb.h"
@ -589,9 +590,22 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) {
   TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   const int num_devices = TF_DeviceListCount(devices);
+  bool has_gpu0 = false;
+  bool has_gpu1 = false;
+  for (int i = 0; i < num_devices; ++i) {
+    const char* dev = TF_DeviceListName(devices, i, status.get());
+    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+    string device_name(dev);
+    if (device_name.find("GPU:0") != string::npos) {
+      has_gpu0 = true;
+    }
+    if (device_name.find("GPU:1") != string::npos) {
+      has_gpu1 = true;
+    }
+  }
 
   const char* kCPUDevice = "CPU:0";
-  if (num_devices < 3) {
+  if (!has_gpu0 || !has_gpu1) {
     TF_DeleteDeviceList(devices);
     TF_DeleteTensor(t);
     TFE_DeleteTensorHandle(hcpu);
@ -781,6 +795,14 @@ TEST(CAPI, TensorHandleNullptr) {
 
   TF_SetStatus(status.get(), TF_OK, "");
 
+  device_name = TFE_TensorHandleBackingDeviceName(h, status.get());
+  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+  ASSERT_EQ(device_name, nullptr);
+  ASSERT_EQ("The passed in handle is a nullptr",
+            string(TF_Message(status.get())));
+
+  TF_SetStatus(status.get(), TF_OK, "");
+
   int num_dims = TFE_TensorHandleNumDims(h, status.get());
   ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
   ASSERT_EQ(num_dims, -1);
@ -796,6 +818,62 @@ TEST(CAPI, TensorHandleNullptr) {
             string(TF_Message(status.get())));
 }
 
+TEST(CAPI, TensorHandleDevices) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status.get());
+  TFE_DeleteContextOptions(opts);
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
+  const char* device_name = TFE_TensorHandleDeviceName(hcpu, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_TRUE(absl::StrContains(device_name, "CPU:0")) << device_name;
+  const char* backing_device_name =
+      TFE_TensorHandleBackingDeviceName(hcpu, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0"))
+      << backing_device_name;
+
+  // Disable the test if no GPU is present.
+  string gpu_device_name;
+  if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
+    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
+        hcpu, ctx, gpu_device_name.c_str(), status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    TFE_Op* shape_op = ShapeOp(ctx, hgpu);
+    TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    TFE_TensorHandle* retvals[1];
+    int num_retvals = 1;
+    TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    // .device of shape is GPU since the op is executed on GPU
+    device_name = TFE_TensorHandleDeviceName(retvals[0], status.get());
+    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+    ASSERT_TRUE(absl::StrContains(device_name, "GPU:0")) << device_name;
+
+    // .backing_device of shape is CPU since the tensor is backed by CPU
+    backing_device_name =
+        TFE_TensorHandleBackingDeviceName(retvals[0], status.get());
+    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+    ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0"))
+        << backing_device_name;
+
+    TFE_DeleteOp(shape_op);
+    TFE_DeleteTensorHandle(retvals[0]);
+    TFE_DeleteTensorHandle(hgpu);
+  }
+
+  TFE_DeleteTensorHandle(hcpu);
+  TFE_ContextAsyncWait(ctx, status.get());
+  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  TFE_DeleteContext(ctx);
+}
+
 void Execute_MatMul_CPU(bool async) {
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
@ -104,6 +104,19 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
   return op;
 }
 
+TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {
+  TF_Status* status = TF_NewStatus();
+
+  TFE_Op* op = TFE_NewOp(ctx, "Shape", status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_OpAddInput(op, a, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeleteStatus(status);
+  TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
+
+  return op;
+}
+
 TFE_TensorHandle* TestAxisTensorHandle() {
   int64_t dims[] = {1};
   int data[] = {1};
@ -37,6 +37,9 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2();
 // Return a matmul op multiplying `a` by `b`.
 TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
 
+// Return a shape op fetching the shape of `a`.
+TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a);
+
 // Return an 1-D INT32 tensor containing a single value 1.
 TFE_TensorHandle* TestAxisTensorHandle();
 
@ -141,8 +141,9 @@ class GradientTape {
   // null. The result is populated with one tensor per target element.
   Status ComputeGradient(
       const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
-      gtl::ArraySlice<int64> target_tensor_ids,
-      gtl::ArraySlice<int64> source_tensor_id,
+      const gtl::ArraySlice<int64> target_tensor_ids,
+      const gtl::ArraySlice<int64> source_tensor_ids,
+      const gtl::FlatMap<int64, TapeTensor> sources_that_are_targets,
       gtl::ArraySlice<Gradient*> output_gradients,
       std::vector<Gradient*>* result);
 
@ -396,6 +397,7 @@ template <typename Gradient, typename BackwardFunction, typename TapeTensor>
 Status InitialGradients(
     const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
     gtl::ArraySlice<int64> target_tensor_ids,
+    gtl::FlatMap<int64, TapeTensor> sources_that_are_targets,
    gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
     const OpTape<BackwardFunction, TapeTensor>& op_tape,
     gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
@ -425,8 +427,13 @@ Status InitialGradients(
               "none of operations outputs match expected tensor");
         }
       } else {
-        // No record of the target tensor found on the tape, so no gradient
-        // needs to be computed from it. Do nothing.
+        // This target tensor was not generated by any operation recorded on
+        // the tape, so no gradient needs to be computed from it unless this
+        // target is also a source.
+        auto source_tensor = sources_that_are_targets.find(id);
+        if (source_tensor != sources_that_are_targets.end()) {
+          (*result)[id].push_back(vspace.Ones(source_tensor->second));
+        }
       }
     } else {
       (*result)[id].push_back(output_gradients[i]);
@ -467,8 +474,9 @@ constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
 template <typename Gradient, typename BackwardFunction, typename TapeTensor>
 Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
     const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
-    gtl::ArraySlice<int64> target_tensor_ids,
-    gtl::ArraySlice<int64> source_tensor_ids,
+    const gtl::ArraySlice<int64> target_tensor_ids,
+    const gtl::ArraySlice<int64> source_tensor_ids,
+    const gtl::FlatMap<int64, TapeTensor> sources_that_are_targets,
     gtl::ArraySlice<Gradient*> output_gradients,
     std::vector<Gradient*>* result) {
   gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(),
@ -478,7 +486,8 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
   std::vector<int64> op_stack =
       InitialStack(state.op_tape, state.op_missing_tensor);
   gtl::FlatMap<int64, std::vector<Gradient*>> gradients;
-  Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
+  Status s = InitialGradients(vspace, target_tensor_ids,
+                              sources_that_are_targets, output_gradients,
                               tensor_tape_, state.op_tape, &gradients);
   auto cleanup = [this, &state]() {
     if (!persistent_) {
143  tensorflow/c/kernels.cc  Normal file
@ -0,0 +1,143 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/kernels.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"

// This file forms the basis of a stable ABI for third-party kernel
// implementations. It is crucial that changes to this file are made cautiously
// and with a focus on maintaining both source and binary compatibility.

struct TF_KernelBuilder {
  ::tensorflow::KernelDefBuilder* cc_builder;

  void* (*create_function)(TF_OpKernelConstruction*);
  void (*compute_function)(void*, TF_OpKernelContext*);
  void (*delete_function)(void*);
};

TF_KernelBuilder* TF_NewKernelBuilder(
    const char* op_name, const char* device_name,
    void* (*create_func)(TF_OpKernelConstruction*),
    void (*compute_func)(void*, TF_OpKernelContext*),
    void (*delete_func)(void*)) {
  TF_KernelBuilder* result = new TF_KernelBuilder;
  result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name);
  result->cc_builder->Device(device_name);
  result->create_function = create_func;
  result->compute_function = compute_func;
  result->delete_function = delete_func;
  return result;
}

void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) {
  DCHECK_NE(builder, nullptr);
  delete builder->cc_builder;
  delete builder;
}

namespace tensorflow {
namespace {

// An OpKernel whose methods delegate to C function pointers.
class COpKernel : public OpKernel {
 public:
  explicit COpKernel(OpKernelConstruction* ctx,
                     void* (*create_func)(TF_OpKernelConstruction*),
                     void (*compute_func)(void*, TF_OpKernelContext*),
                     void (*delete_func)(void*))
      : OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) {
    if (create_func != nullptr) {
      c_kernel_ =
          (*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx));
    } else {
      c_kernel_ = nullptr;
    }
  }

  void Compute(OpKernelContext* ctx) override {
    (*compute_func_)(c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx));
  }

  ~COpKernel() override {
    if (delete_func_ != nullptr) {
      (*delete_func_)(c_kernel_);
    }
  }

 private:
  void (*compute_func_)(void*, TF_OpKernelContext* context);
  void (*delete_func_)(void*);
  void* c_kernel_;
};

// A KernelFactory that returns COpKernel instances.
class KernelBuilderFactory
    : public ::tensorflow::kernel_factory::OpKernelFactory {
 public:
  explicit KernelBuilderFactory(TF_KernelBuilder* builder)
      : builder_(builder) {}
  ::tensorflow::OpKernel* Create(
      ::tensorflow::OpKernelConstruction* context) override {
    return new ::tensorflow::COpKernel(context, builder_->create_function,
                                       builder_->compute_function,
                                       builder_->delete_function);
  }
  ~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); }

 private:
  TF_KernelBuilder* builder_;
};
}  // namespace
}  // namespace tensorflow

void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
                              TF_Status* status) {
  using tensorflow::register_kernel::Name;

  tensorflow::kernel_factory::OpKernelRegistrar(
      builder->cc_builder->Build(), name,
      absl::make_unique<tensorflow::KernelBuilderFactory>(builder));

  TF_SetStatus(status, TF_OK, "");
}

int TF_NumInputs(TF_OpKernelContext* ctx) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  return cc_ctx->num_inputs();
}

int TF_NumOutputs(TF_OpKernelContext* ctx) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  return cc_ctx->num_outputs();
}

void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
                 TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (i < 0 || i >= cc_ctx->num_inputs()) {
    TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
    return;
  }
  const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
  TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status);
  if (TF_GetCode(status) == TF_OK) {
    *tensor = result;
  }
}
110  tensorflow/c/kernels.h  Normal file
@ -0,0 +1,110 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_KERNELS_H_
#define TENSORFLOW_C_KERNELS_H_

#include "tensorflow/c/c_api.h"

#ifdef __cplusplus
extern "C" {
#endif

// --------------------------------------------------------------------------
// C API for TensorFlow Kernels.
//
// This API allows developers to register custom kernel implementations for
// TensorFlow.
//
// See c_api.h header comments for a discussion about API conventions.
//
// Users wishing to extend TensorFlow with new kernels will call
// `TF_NewKernelBuilder`. The resulting kernel builder can be registered with
// `TF_RegisterKernelBuilder`, which will allow TF to construct user-provided
// kernels when necessary.

struct TF_KernelBuilder;
struct TF_OpKernelConstruction;
struct TF_OpKernelContext;

// Allocates a new kernel builder and returns a pointer to it.
//
// If non-null, TensorFlow will call create_func when it needs to instantiate
// the kernel. The pointer returned by create_func will be passed to
// compute_func and delete_func, thereby functioning as a "this" pointer for
// referring to kernel instances.
//
// The TF_OpKernelConstruction pointer passed to create_func is owned by
// TensorFlow and will be deleted once create_func returns. It must not be used
// after this.
//
// When TensorFlow needs to perform a computation with this kernel, it will
// call compute_func. This function will receive the pointer returned by
// create_func (or null if no create_func was provided), along with the inputs
// to the computation.
//
// The TF_OpKernelContext pointer received by compute_func is owned by
// TensorFlow and will be deleted once compute_func returns. It must not be used
// after this.
//
// Finally, when TensorFlow no longer needs the kernel, it will call
// delete_func if one is provided. This function will receive the pointer
// returned in `create_func` or nullptr if no `create_func` was provided.
//
// The caller should pass the result of this function to
// TF_RegisterKernelBuilder, which will take ownership of the pointer. If, for
// some reason, the kernel builder will not be registered, the caller should
// delete it with TF_DeleteKernelBuilder.
TF_CAPI_EXPORT extern TF_KernelBuilder* TF_NewKernelBuilder(
    const char* op_name, const char* device_name,
    void* (*create_func)(TF_OpKernelConstruction*),
    void (*compute_func)(void*, TF_OpKernelContext*),
    void (*delete_func)(void*));

// Register the given kernel builder with the TensorFlow runtime. If
// registration fails, the given status will be populated.
//
// This call takes ownership of the `builder` pointer.
TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name,
                                                    TF_KernelBuilder* builder,
                                                    TF_Status* status);

// Deletes the given TF_KernelBuilder. This should be called only if the kernel
// builder is not registered with TensorFlow via TF_RegisterKernelBuilder.
TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);

// --------------------------------------------------------------------------
// OpKernelContext routines

// TF_NumInputs returns the number of inputs available in ctx.
TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);

// TF_NumOutputs returns the number of outputs to be placed in *ctx by the
// kernel.
TF_CAPI_EXPORT extern int TF_NumOutputs(TF_OpKernelContext* ctx);

// Retrieves the ith input from ctx. If TF_GetCode(status) is TF_OK, *tensor is
// populated and its ownership is passed to the caller. In any other case,
// *tensor is not modified.
//
// If i < 0 or i >= TF_NumInputs(ctx), *status is set to TF_OUT_OF_RANGE.
TF_CAPI_EXPORT extern void TF_GetInput(TF_OpKernelContext* ctx, int i,
                                       TF_Tensor** tensor, TF_Status* status);

#ifdef __cplusplus
} /* end extern "C" */
#endif

#endif  // TENSORFLOW_C_KERNELS_H_
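A minimal usage sketch of this new C kernel API (editorial, not part of the commit; the op name "MyAddOne" and the kernel name "MyAddOneKernel" are hypothetical, and the op is assumed to already be registered):

// Compute-only kernel: no create_func, so TensorFlow passes nullptr as the
// kernel pointer to the compute callback.
static void MyAddOneCompute(void* kernel, TF_OpKernelContext* ctx) {
  TF_Status* s = TF_NewStatus();
  TF_Tensor* input = nullptr;
  TF_GetInput(ctx, 0, &input, s);  // on TF_OK, ownership of *input is ours
  if (TF_GetCode(s) == TF_OK) TF_DeleteTensor(input);
  TF_DeleteStatus(s);
}

static void RegisterMyAddOneKernel() {
  TF_Status* s = TF_NewStatus();
  TF_KernelBuilder* builder = TF_NewKernelBuilder(
      "MyAddOne", "CPU", /*create_func=*/nullptr, &MyAddOneCompute,
      /*delete_func=*/nullptr);
  TF_RegisterKernelBuilder("MyAddOneKernel", builder, s);  // takes ownership of builder
  TF_DeleteStatus(s);
}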
194  tensorflow/c/kernels_test.cc  Normal file
@ -0,0 +1,194 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/kernels.h"

#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb_text.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

struct MyCustomKernel {
  bool created;
  bool compute_called;
};

static bool delete_called = false;

static void* MyCreateFunc(TF_OpKernelConstruction* ctx) {
  struct MyCustomKernel* s = new struct MyCustomKernel;
  s->created = true;
  s->compute_called = false;
  return s;
}

static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) {
  struct MyCustomKernel* s = static_cast<struct MyCustomKernel*>(kernel);
  s->compute_called = true;
}

static void MyDeleteFunc(void* kernel) {
  struct MyCustomKernel* s = static_cast<struct MyCustomKernel*>(kernel);
  EXPECT_TRUE(s->created);
  EXPECT_TRUE(s->compute_called);
  delete_called = true;
  delete s;
}

namespace tensorflow {

static std::unique_ptr<OpKernel> GetFakeKernel(const char* device_name,
                                               const char* op_name,
                                               Status* status) {
  NodeDef def;
  def.set_op(op_name);
  def.set_device(device_name);
  def.add_input("input1");
  def.add_input("input2");
  return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1,
                        status);
}

// Tests registration of a single C kernel and checks that calls through the
// C/C++ boundary are being made.
TEST(TestKernel, TestRegisterKernelBuilder) {
  const char* kernel_name = "SomeKernelName";
  const char* op_name = "FooOp";
  const char* device_name = "FakeDeviceName1";

  REGISTER_OP(op_name)
      .Input("input1: double")
      .Input("input2: uint8")
      .Output("output1: uint8");

  TF_KernelBuilder* builder = TF_NewKernelBuilder(
      op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc);

  {
    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(kernel_name, builder, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    KernelList list;
    list.ParseFromArray(buf->data, buf->length);
    ASSERT_EQ(1, list.kernel_size());
    ASSERT_EQ(device_name, list.kernel(0).device_type());
    TF_DeleteBuffer(buf);
    TF_DeleteStatus(status);
  }

  {
    Status status;
    std::unique_ptr<OpKernel> kernel =
        GetFakeKernel(device_name, op_name, &status);
    TF_EXPECT_OK(status);
    ASSERT_NE(nullptr, kernel.get());
    kernel->Compute(nullptr);
  }

  ASSERT_TRUE(delete_called);
}

class DummyDevice : public DeviceBase {
 public:
  DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
  bool RequiresRecordingAccessedTensors() const override { return save_; }
  Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
    return cpu_allocator();
  }

 private:
  bool save_;
};

TEST(TestKernel, TestInputAndOutputCount) {
  const char* kernel_name = "InputOutputCounterKernel";
  const char* op_name = "BarOp";
  const char* device_name = "FakeDeviceName2";

  REGISTER_OP(op_name)
      .Input("input1: double")
      .Input("input2: uint8")
      .Output("output1: uint8");

  static int num_inputs = 0;
  static int num_outputs = 0;

  // A kernel whose Compute function has a side-effect of updating num_inputs
  // and num_outputs. Various functions on TF_OpKernelContext are also
  // exercised.
  auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
    num_inputs = TF_NumInputs(ctx);
    num_outputs = TF_NumOutputs(ctx);

    TF_Tensor* input = nullptr;
    TF_Status* s = TF_NewStatus();
    TF_GetInput(ctx, 0, &input, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s)) << "Failed to get input: " << TF_Message(s);
    EXPECT_EQ(123, *static_cast<tensorflow::uint8*>(TF_TensorData(input)));
    TF_GetInput(ctx, -1, &input, s);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));
    TF_GetInput(ctx, 3, &input, s);
    EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));
    TF_DeleteStatus(s);
    if (input != nullptr) {
      TF_DeleteTensor(input);
    }
  };

  TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
                                                  my_compute_func, nullptr);

  {
    TF_Status* status = TF_NewStatus();
    TF_RegisterKernelBuilder(kernel_name, builder, status);
    EXPECT_EQ(TF_OK, TF_GetCode(status));
    TF_DeleteStatus(status);
  }

  {
    OpKernelContext::Params p;
    DummyDevice dummy_device(nullptr, false);
    p.device = &dummy_device;

    Tensor t(tensorflow::uint8(123));

    gtl::InlinedVector<TensorValue, 4> inputs;
    // Simulate 2 inputs
    inputs.emplace_back(&t);
    inputs.emplace_back();
    p.inputs = &inputs;

    Status status;
    std::unique_ptr<OpKernel> kernel =
        GetFakeKernel(device_name, op_name, &status);
    TF_EXPECT_OK(status);
    ASSERT_NE(nullptr, kernel.get());

    p.op_kernel = kernel.get();
    OpKernelContext ctx(&p);
    kernel->Compute(&ctx);

    ASSERT_EQ(2, num_inputs);
    ASSERT_EQ(1, num_outputs);
  }
}

} // namespace tensorflow
@ -160,4 +160,17 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
   ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
 }
 
+void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
+                       TF_Status* status) {
+  mutex_lock l(graph->mu);
+  status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
+                                                  new_src.index, &dst->node);
+  if (status->status.ok()) {
+    // This modification only updates the destination node for
+    // the purposes of running this graph in a session. Thus, we don't
+    // record the source node as being modified.
+    RecordMutation(graph, *dst, "adding input tensor");
+  }
+}
+
 }  // namespace tensorflow
@ -34,6 +34,7 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
 
 void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
 
+// Updates 'dst' to consume 'new_src'.
 void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
                 TF_Status* status);
 
@ -65,6 +66,13 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output);
 // because I couldn't get SWIG to work otherwise.
 void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
                            size_t proto_len, TF_Status* status);
 
+// This method is used to add a new input edge to 'dst', which must be a While
+// op. The While op's "T" attribute must have already been updated to include
+// the new edge. This is used to construct tf.while_loop gradients.
+void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
+                       TF_Status* status);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_C_PYTHON_API_H_
@ -489,6 +489,7 @@ tf_gen_op_wrappers_cc(
         "image_ops",
         "io_ops",
         "linalg_ops",
+        "list_ops",
         "logging_ops",
         "lookup_ops",
         "manip_ops",
@ -133,5 +133,6 @@ filegroup(
         "testdata/half_plus_two_pbtxt/**",
         "testdata/half_plus_two_main_op/**",
         "testdata/half_plus_two/**",
+        "testdata/half_plus_two_v2/**",
     ]),
 )
@ -33,10 +33,10 @@ constexpr char kSavedModelFilenamePb[] = "saved_model.pb";
 /// SavedModel text format proto filename.
 constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
 
-/// SavedModel legacy init op key.
+/// SavedModel legacy init op collection key. Used in v1 SavedModels.
 constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op";
 
-/// SavedModel main op key.
+/// SavedModel main op collection key. Used in v1 SavedModels.
 constexpr char kSavedModelMainOpKey[] = "saved_model_main_op";
 
 /// Directory in which to save the SavedModel variables.
@ -45,6 +45,11 @@ constexpr char kSavedModelVariablesDirectory[] = "variables";
 /// SavedModel variables filename.
 constexpr char kSavedModelVariablesFilename[] = "variables";
 
+/// SavedModel SignatureDef keys for the initialization and train ops. Used in
+/// V2 SavedModels.
+constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op";
+constexpr char kSavedModelTrainOpSignatureKey[] = "__saved_model_train_op";
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
@ -122,38 +122,58 @@ Status RunOnce(const RunOptions& run_options,
   return run_status;
 }
 
-bool HasMainOp(const MetaGraphDef& meta_graph_def) {
-  const auto& collection_def_map = meta_graph_def.collection_def();
-  if (collection_def_map.find(kSavedModelMainOpKey) !=
-      collection_def_map.end()) {
-    return true;
-  }
-  return false;
-}
-
-Status RunMainOp(const RunOptions& run_options, const string& export_dir,
+// RunInitOp will return OK if the initialization op was run successfully.
+// An empty init_op_name indicates that there are no init ops to run.
+Status RunInitOp(const RunOptions& run_options, const string& export_dir,
                  const MetaGraphDef& meta_graph_def,
                  const std::vector<AssetFileDef>& asset_file_defs,
-                 Session* session, const string& main_op_key) {
-  LOG(INFO) << "Running MainOp with key " << main_op_key
-            << " on SavedModel bundle.";
-  const auto& collection_def_map = meta_graph_def.collection_def();
-  const auto main_op_it = collection_def_map.find(main_op_key);
-  if (main_op_it != collection_def_map.end()) {
-    if (main_op_it->second.node_list().value_size() != 1) {
-      return errors::FailedPrecondition(
-          strings::StrCat("Expected exactly one main op in : ", export_dir));
-    }
+                 Session* session, const string& init_op_name) {
+  if (!init_op_name.empty()) {
+    LOG(INFO) << "Running initialization op on SavedModel bundle.";
     std::vector<std::pair<string, Tensor>> inputs;
     AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
     RunMetadata run_metadata;
-    const StringPiece main_op_name = main_op_it->second.node_list().value(0);
-    return RunOnce(run_options, inputs, {}, {string(main_op_name)},
+    return RunOnce(run_options, inputs, {}, {init_op_name},
                    nullptr /* outputs */, &run_metadata, session);
   }
   return Status::OK();
 }
 
+// A SavedModel may store the name of the initialization op to run in the
+// in the SignatureDef (v2) or a collection (v1). If an init_op collection
+// exists, then the collection must contain exactly one op.
+Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
+                 string* init_op_name) {
+  const auto& sig_def_map = meta_graph_def.signature_def();
+  const auto& init_op_sig_it =
+      meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
+  if (init_op_sig_it != sig_def_map.end()) {
+    *init_op_name = init_op_sig_it->second.outputs()
+                        .find(kSavedModelInitOpSignatureKey)
+                        ->second.name();
+    return Status::OK();
+  }
+
+  const auto& collection_def_map = meta_graph_def.collection_def();
+  string init_op_collection_key;
+  if (collection_def_map.find(kSavedModelMainOpKey) !=
+      collection_def_map.end()) {
+    init_op_collection_key = kSavedModelMainOpKey;
+  } else {
+    init_op_collection_key = kSavedModelLegacyInitOpKey;
+  }
+
+  const auto init_op_it = collection_def_map.find(init_op_collection_key);
+  if (init_op_it != collection_def_map.end()) {
+    if (init_op_it->second.node_list().value_size() != 1) {
+      return errors::FailedPrecondition(
+          strings::StrCat("Expected exactly one main op in : ", export_dir));
+    }
+    *init_op_name = init_op_it->second.node_list().value(0);
+  }
+  return Status::OK();
+}
+
 Status RunRestore(const RunOptions& run_options, const string& export_dir,
                   const StringPiece restore_op_name,
                   const StringPiece variable_filename_const_op_name,
@ -193,6 +213,15 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
 
 Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
                         std::vector<AssetFileDef>* asset_file_defs) {
+  // With SavedModel v2, we write asset file def into metagraph instead of
+  // collection, so read from metagraph first.
+  if (meta_graph_def.asset_file_def_size() > 0) {
+    for (const auto& asset : meta_graph_def.asset_file_def()) {
+      asset_file_defs->push_back(asset);
+    }
+    return Status::OK();
+  }
+  // Fall back to read from collection to be backward compatible with v1.
   const auto& collection_def_map = meta_graph_def.collection_def();
   const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
   if (assets_it == collection_def_map.end()) {
@ -227,15 +256,12 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
       bundle->meta_graph_def.saver_def().restore_op_name(),
       bundle->meta_graph_def.saver_def().filename_tensor_name(),
       asset_file_defs, bundle->session.get()));
-  if (HasMainOp(bundle->meta_graph_def)) {
-    TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir,
-                                 bundle->meta_graph_def, asset_file_defs,
-                                 bundle->session.get(), kSavedModelMainOpKey));
-  } else {
-    TF_RETURN_IF_ERROR(RunMainOp(
-        run_options, export_dir, bundle->meta_graph_def, asset_file_defs,
-        bundle->session.get(), kSavedModelLegacyInitOpKey));
-  }
+  string init_op_name;
+  TF_RETURN_IF_ERROR(
+      GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
+  TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
+                               asset_file_defs, bundle->session.get(),
+                               init_op_name));
   return Status::OK();
 }
 
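For context, a short caller-side sketch (editorial, not part of the commit; the export directory path is a placeholder). Init-op resolution now happens inside LoadSavedModel via GetInitOp/RunInitOp, so callers are unchanged for both v1 and v2 SavedModels:

tensorflow::SavedModelBundle bundle;
tensorflow::SessionOptions session_options;
tensorflow::RunOptions run_options;
// Works for both collection-based (v1) and SignatureDef-based (v2) init ops.
TF_CHECK_OK(tensorflow::LoadSavedModel(session_options, run_options,
                                       "/tmp/half_plus_two_v2",
                                       {tensorflow::kSavedModelTagServe},
                                       &bundle));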
@ -36,6 +36,8 @@ constexpr char kTestDataMainOp[] =
     "cc/saved_model/testdata/half_plus_two_main_op/00000123";
 constexpr char kTestDataSharded[] =
     "cc/saved_model/testdata/half_plus_two/00000123";
+constexpr char kTestDataInitOpV2[] =
+    "cc/saved_model/testdata/half_plus_two_v2/00000123";
 
 class LoaderTest : public ::testing::Test {
  protected:
@ -227,5 +229,17 @@ TEST_F(LoaderTest, MaybeSavedModelDirectory) {
   EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir));
 }
 
+TEST_F(LoaderTest, SavedModelInitOpV2Format) {
+  SavedModelBundle bundle;
+  SessionOptions session_options;
+  RunOptions run_options;
+
+  const string export_dir =
+      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataInitOpV2);
+  TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
+                              {kSavedModelTagServe}, &bundle));
+  CheckSavedModelBundle(export_dir, bundle);
+}
+
 } // namespace
 } // namespace tensorflow
1  tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt  vendored  Normal file
@ -0,0 +1 @@
+asset-file-contents
BIN  tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb  vendored  Normal file
Binary file not shown.
BIN  tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001  vendored  Normal file
Binary file not shown.
BIN  tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index  vendored  Normal file
Binary file not shown.
@ -164,7 +164,8 @@ string RewriteWithName(const string& name, string code,
 }
 
 // Generate methods for args (inputs).
-Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
+Status GenArgMethods(const tf2xla::Config& config,
+                     const xla::ProgramShapeProto& ps,
                      const CompileResult& compile_result, string* methods) {
   size_t num_args = ps.parameters_size();
   if (config.feed_size() != num_args) {
@ -174,9 +175,10 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
   }
   for (int i = 0; i < num_args; ++i) {
     std::vector<std::pair<string, string>> rewrites;
-    TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites));
+    TF_RETURN_IF_ERROR(
+        AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
     const string code = R"(
-  void set_arg{{NAME}}_data(void* data) {
+  void set_arg{{NAME}}_data(const void* data) {
     set_arg_data({{I}}, data);
   }
   {{TYPE}}* arg{{NAME}}_data() {
@ -204,7 +206,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
 
 // Generate methods for results (outputs).
 Status GenResultMethods(const tf2xla::Config& config,
-                        const xla::ProgramShape& ps, string* methods) {
+                        const xla::ProgramShapeProto& ps, string* methods) {
   if (ps.result().element_type() != xla::TUPLE) {
     // The XlaCompiler we use to build the xla computation always generates a
     // tuple result, and we rely on this to simplify code generation.
@ -217,8 +219,8 @@ Status GenResultMethods(const tf2xla::Config& config,
   }
   for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) {
     std::vector<std::pair<string, string>> rewrites;
-    TF_RETURN_IF_ERROR(
-        AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites));
+    TF_RETURN_IF_ERROR(AddRewritesForShape(
+        i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites));
     string code = R"(
   {{TYPE}}* result{{NAME}}_data() {
     return static_cast<{{TYPE}}*>(result_data({{I}}));
@ -336,7 +338,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
       ExtractEntryParamBufferInfos(buffer_infos);
   std::vector<BufferInfo> buffer_infos_for_temps =
       ExtractTempBufferInfos(buffer_infos);
-  const xla::ProgramShape& ps = compile_result.program_shape;
+  const xla::ProgramShapeProto& ps = compile_result.program_shape;
   string methods_arg, methods_result;
   TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
   TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
@ -548,8 +550,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
|
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
|
||||||
|
|
||||||
// Shape of the args and results.
|
// Shape of the args and results.
|
||||||
static const xla::ProgramShape* StaticProgramShape() {
|
static const xla::ProgramShapeProto* StaticProgramShape() {
|
||||||
static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
|
static const xla::ProgramShapeProto* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
|
||||||
return kShape;
|
return kShape;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -587,7 +589,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
{"{{METHODS_RESULT}}\n", methods_result},
|
{"{{METHODS_RESULT}}\n", methods_result},
|
||||||
{"{{NS_END}}\n", ns_end},
|
{"{{NS_END}}\n", ns_end},
|
||||||
{"{{NS_START}}\n", ns_start},
|
{"{{NS_START}}\n", ns_start},
|
||||||
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
|
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
|
||||||
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
|
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
|
||||||
metadata_result.program_shape_access_shim},
|
metadata_result.program_shape_access_shim},
|
||||||
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
|
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
|
||||||
@ -615,11 +617,11 @@ static string CreateUniqueIdentifier(const CodegenOpts& opts,
|
|||||||
Status GenerateMetadata(const CodegenOpts& opts,
|
Status GenerateMetadata(const CodegenOpts& opts,
|
||||||
const CompileResult& compile_result,
|
const CompileResult& compile_result,
|
||||||
MetadataResult* metadata_result) {
|
MetadataResult* metadata_result) {
|
||||||
std::unique_ptr<xla::ProgramShape> program_shape;
|
std::unique_ptr<xla::ProgramShapeProto> program_shape;
|
||||||
|
|
||||||
if (opts.gen_program_shape) {
|
if (opts.gen_program_shape) {
|
||||||
program_shape =
|
program_shape =
|
||||||
absl::make_unique<xla::ProgramShape>(compile_result.program_shape);
|
absl::make_unique<xla::ProgramShapeProto>(compile_result.program_shape);
|
||||||
|
|
||||||
// The parameter names are currently meaningless, and redundant with the
|
// The parameter names are currently meaningless, and redundant with the
|
||||||
// rest of our metadata, so clear them out to avoid confusion and save
|
// rest of our metadata, so clear them out to avoid confusion and save
|
||||||
@ -631,8 +633,8 @@ Status GenerateMetadata(const CodegenOpts& opts,
|
|||||||
// a shim that evaluates to nullptr, which is what we want.
|
// a shim that evaluates to nullptr, which is what we want.
|
||||||
|
|
||||||
ProtobufToEmbed program_shape_protobuf{
|
ProtobufToEmbed program_shape_protobuf{
|
||||||
CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape",
|
CreateUniqueIdentifier(opts, "ProgramShapeProto"),
|
||||||
program_shape.get()};
|
"xla::ProgramShapeProto", program_shape.get()};
|
||||||
|
|
||||||
ProtobufToEmbed hlo_profile_printer_data_protobuf{
|
ProtobufToEmbed hlo_profile_printer_data_protobuf{
|
||||||
CreateUniqueIdentifier(opts, "HloProfilePrinterData"),
|
CreateUniqueIdentifier(opts, "HloProfilePrinterData"),
|
||||||
|
@@ -57,7 +57,7 @@ struct MetadataResult {
   std::vector<string> header_variable_decls;
 
   // program_shape_access_shim is a C++ expression that constructs the
-  // xla::ProgramShape instance for the CompileResult passed to
+  // xla::ProgramShapeProto instance for the CompileResult passed to
   // GenerateMetadata.
   string program_shape_access_shim;
 
@@ -181,13 +181,15 @@ TEST(CodegenTest, Golden) {
       BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
       BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
      5, {}));
-  compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
-      {
-          xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
-          xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
-      },
-      xla::ShapeUtil::MakeTupleShape(
-          {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}));
+  compile_result.program_shape =
+      xla::ShapeUtil::MakeProgramShape(
+          {
+              xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
+              xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
+          },
+          xla::ShapeUtil::MakeTupleShape(
+              {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}))
+          .ToProto();
   compile_result.entry_point = "entry_point";
   compile_result.pointer_size = 8;
 
@ -22,7 +22,7 @@ extern "C" void entry_point(
|
|||||||
void* result, const xla::ExecutableRunOptions* run_options,
|
void* result, const xla::ExecutableRunOptions* run_options,
|
||||||
const void** args, void** temps, tensorflow::int64* profile_counters);
|
const void** args, void** temps, tensorflow::int64* profile_counters);
|
||||||
|
|
||||||
extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[];
|
extern "C" char __tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[];
|
||||||
|
|
||||||
|
|
||||||
namespace foo {
|
namespace foo {
|
||||||
@ -114,7 +114,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
// with dim indices specifying which value. No bounds checking is performed
|
// with dim indices specifying which value. No bounds checking is performed
|
||||||
// on dim indices.
|
// on dim indices.
|
||||||
|
|
||||||
void set_arg0_data(void* data) {
|
void set_arg0_data(const void* data) {
|
||||||
set_arg_data(0, data);
|
set_arg_data(0, data);
|
||||||
}
|
}
|
||||||
float* arg0_data() {
|
float* arg0_data() {
|
||||||
@ -132,7 +132,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
arg_data(0)))[dim0][dim1];
|
arg_data(0)))[dim0][dim1];
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_arg_myfeed_data(void* data) {
|
void set_arg_myfeed_data(const void* data) {
|
||||||
set_arg_data(0, data);
|
set_arg_data(0, data);
|
||||||
}
|
}
|
||||||
float* arg_myfeed_data() {
|
float* arg_myfeed_data() {
|
||||||
@ -150,7 +150,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
arg_data(0)))[dim0][dim1];
|
arg_data(0)))[dim0][dim1];
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_arg1_data(void* data) {
|
void set_arg1_data(const void* data) {
|
||||||
set_arg_data(1, data);
|
set_arg_data(1, data);
|
||||||
}
|
}
|
||||||
tensorflow::int64* arg1_data() {
|
tensorflow::int64* arg1_data() {
|
||||||
@ -253,10 +253,10 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Shape of the args and results.
|
// Shape of the args and results.
|
||||||
static const xla::ProgramShape* StaticProgramShape() {
|
static const xla::ProgramShapeProto* StaticProgramShape() {
|
||||||
static const xla::ProgramShape* kShape = []() {
|
static const xla::ProgramShapeProto* kShape = []() {
|
||||||
xla::ProgramShape* proto = new xla::ProgramShape;
|
xla::ProgramShapeProto* proto = new xla::ProgramShapeProto;
|
||||||
proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[0], 52);
|
proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 52);
|
||||||
return proto;
|
return proto;
|
||||||
}();
|
}();
|
||||||
return kShape;
|
return kShape;
|
||||||
|
Binary file not shown.
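The golden header above now emits `set_arg*_data(const void*)` setters. A minimal caller sketch follows; the include path, the `foo::bar::MyClass` spelling, and the feed values are assumptions for illustration, not part of the commit:

```cpp
#include "myclass_generated.h"  // hypothetical tfcompile-generated header

int main() {
  // Feeds matching the golden program shape: arg0 is f32[1,2], arg1 is s64[3,4].
  const float arg0[1][2] = {{1.0f, 2.0f}};
  const tensorflow::int64 arg1[3][4] = {};

  foo::bar::MyClass computation;
  computation.set_arg0_data(arg0);  // const data accepted without a const_cast
  computation.set_arg1_data(arg1);
  return computation.Run() ? 0 : 1;
}
```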
@@ -56,17 +56,23 @@ Status CompileXla(xla::CompileOnlyClient* client,
     return errors::Unknown("Couldn't get XLA program shape: ",
                            pshape_or.status().error_message());
   }
-  compile_result->program_shape = *pshape_or.ValueOrDie();
-  xla::ProgramShape* pshape = &compile_result->program_shape;
-  std::vector<const xla::Shape*> arg_layouts;
-  arg_layouts.reserve(pshape->parameters_size());
+  compile_result->program_shape = pshape_or.ValueOrDie()->ToProto();
+  xla::ProgramShapeProto* pshape = &compile_result->program_shape;
+
+  // AotXlaComputationInstance::argument_layouts is a vector of Shape
+  // pointers. Accumulate the Shape objects themselves in a separate vector
+  // while building the vector of pointers.
+  std::vector<const xla::Shape*> arg_layout_ptrs(pshape->parameters_size());
+  std::vector<xla::Shape> arg_layouts(pshape->parameters_size());
   for (int i = 0; i < pshape->parameters_size(); ++i) {
-    arg_layouts.push_back(pshape->mutable_parameters(i));
+    arg_layouts[i] = xla::Shape(*pshape->mutable_parameters(i));
+    arg_layout_ptrs[i] = &arg_layouts[i];
   }
   xla::CompileOnlyClient::AotXlaComputationInstance instance;
   instance.computation = &computation;
-  instance.argument_layouts = std::move(arg_layouts);
-  instance.result_layout = &pshape->result();
+  instance.argument_layouts = std::move(arg_layout_ptrs);
+  xla::Shape result_shape(pshape->result());
+  instance.result_layout = &result_shape;
   xla::StatusOr<std::vector<std::unique_ptr<xla::AotCompilationResult>>>
       aot_or = client->CompileAheadOfTime({instance}, aot_opts);
   if (!aot_or.ok()) {
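The hunks above move `CompileResult::program_shape` to the proto form and convert back to `xla::Shape`/`xla::ProgramShape` at each use site. A small sketch of that round-trip, assuming the XLA headers and using only the conversions that appear in the diff (`ToProto()` and the proto-taking `Shape`/`ProgramShape` constructors):

```cpp
#include "tensorflow/compiler/xla/shape_util.h"

// Build an in-memory program shape (one f32[2,2] parameter, f32[2,2] result).
xla::ProgramShape program_shape = xla::ShapeUtil::MakeProgramShape(
    {xla::ShapeUtil::MakeShape(xla::F32, {2, 2})},
    xla::ShapeUtil::MakeShape(xla::F32, {2, 2}));

// Store the stable proto form, as CompileResult::program_shape now does.
xla::ProgramShapeProto proto = program_shape.ToProto();

// Rebuild Shape/ProgramShape objects where the in-memory types are required.
xla::ProgramShape restored(proto);
xla::Shape param0(proto.parameters(0));
```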
@@ -33,9 +33,9 @@ namespace tfcompile {
 struct CompileResult {
   // Contains object file and meta-info.
   std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
-  xla::ProgramShape program_shape;       // Static shape of args and results.
+  xla::ProgramShapeProto program_shape;  // Static shape of args and results.
   string entry_point;                    // Name of generated function.
   int pointer_size = 0;                  // Size of a pointer in bytes.
 };
 
 // CompileGraph compiles the graph_def into an object file containing a function
@@ -526,13 +526,15 @@ TEST(TFCompileTest, ProgramShape) {
 
   // muladd has the program shape defined.
   MatMulAndAddComp muladd;
-  const xla::ProgramShape* muladd_shape = muladd.ProgramShape();
+  const xla::ProgramShapeProto* muladd_shape = muladd.ProgramShape();
   ASSERT_TRUE(muladd_shape != nullptr);
   ASSERT_EQ(muladd_shape->parameters_size(), 2);
-  EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2));
-  EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(1), f32_2x2));
+  EXPECT_TRUE(
+      ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(0)), f32_2x2));
+  EXPECT_TRUE(
+      ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(1)), f32_2x2));
 
-  const xla::Shape& muladd_result = muladd_shape->result();
+  const xla::Shape muladd_result(muladd_shape->result());
   ASSERT_EQ(muladd_result.element_type(), xla::TUPLE);
   ASSERT_EQ(ShapeUtil::TupleElementCount(muladd_result), 2);
   const xla::Shape& muladd_result0 =
@@ -23,7 +23,6 @@ package(
 load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
 load("//tensorflow:tensorflow.bzl", "tf_cc_test")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
-load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
 load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
 
@@ -38,7 +37,7 @@ cc_library(
         ":xla_cpu_device",
         ":xla_cpu_jit",
         "//tensorflow/compiler/plugin",
-    ] + if_cuda_is_configured([
+    ] + if_cuda([
         ":xla_gpu_device",
         ":xla_gpu_jit",
     ]),
@@ -51,6 +50,7 @@ cc_library(
     deps = [
         ":jit_compilation_passes",
         "//tensorflow/compiler/jit/kernels:xla_ops",
+        "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/service:cpu_plugin",
@@ -76,10 +76,10 @@ cc_library(
     srcs = ["xla_cpu_device.cc"],
     visibility = [":friends"],
     deps = [
+        ":flags",
         ":jit_compilation_passes",
         ":xla_device",
         "//tensorflow/compiler/jit/kernels:xla_ops",
-        "//tensorflow/compiler/jit/legacy_flags:xla_device_flags",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/service:cpu_plugin",  # buildcleaner: keep
@@ -210,6 +210,18 @@ cc_library(
 
 # Internal targets below this point.
 
+cc_library(
+    name = "flags",
+    srcs = ["flags.cc"],
+    hdrs = ["flags.h"],
+    visibility = [":friends"],
+    deps = [
+        "//tensorflow/compiler/xla:parse_flags_from_env",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+    ],
+)
+
 cc_library(
     name = "common",
     srcs = [
@@ -256,6 +268,7 @@ cc_library(
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:dump_graph",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/xla:debug_options_flags",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/client:local_client",
@@ -268,6 +281,7 @@ cc_library(
         "//tensorflow/core/kernels:variable_ops",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
     ],
 )
@@ -487,6 +501,7 @@ cc_library(
     deps = [
         ":common",
         ":encapsulate_util",
+        ":flags",
         ":shape_inference_helpers",
         ":union_find",
         ":xla_cluster_util",
@@ -494,8 +509,6 @@ cc_library(
         "//tensorflow/cc:ops",
         "//tensorflow/cc:scope_internal",
         "//tensorflow/compiler/jit/graphcycles",
-        "//tensorflow/compiler/jit/legacy_flags:build_xla_ops_pass_flags",
-        "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
         "//tensorflow/compiler/jit/ops:xla_ops",
         "//tensorflow/compiler/tf2xla:dump_graph",
         "//tensorflow/compiler/tf2xla:resource_operation_table",
@@ -724,7 +737,10 @@ tf_custom_op_py_library(
     visibility = [
         ":friends",
     ],
-    deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"],
+    deps = [
+        "//tensorflow/compiler/jit/ops:xla_ops_grad",
+        "//tensorflow/compiler/jit/ops:xla_ops_wrapper_py",
+    ],
 )
 
 # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
@@ -23,7 +23,7 @@ limitations under the License.
 #include "tensorflow/cc/ops/control_flow_ops.h"
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h"
+#include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/xla_cluster_util.h"
 #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
 #include "tensorflow/compiler/tf2xla/dump_graph.h"
@@ -320,10 +320,10 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
     return IsXlaCompiledKernel(*n);
   });
 
-  bool lazy_compilation_enabled = enable_lazy_compilation_
-                                      ? *enable_lazy_compilation_
-                                      : legacy_flags::GetBuildXlaOpsPassFlags()
-                                            .tf_xla_enable_lazy_compilation;
+  bool lazy_compilation_enabled =
+      enable_lazy_compilation_
+          ? *enable_lazy_compilation_
+          : GetBuildXlaOpsPassFlags().tf_xla_enable_lazy_compilation;
 
   for (Node* n : xla_compiled_kernels) {
     TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
@@ -42,14 +42,8 @@ class BuildXlaOpsTest : public ::testing::Test {
             .ok());
   }
 
-  void TearDown() override {
-    for (Device* device : devices_) {
-      delete device;
-    }
-  }
-
  private:
-  std::vector<Device*> devices_;
+  std::vector<std::unique_ptr<Device>> devices_;
 };
 
 using ::tensorflow::testing::FindNodeByName;
@@ -59,8 +59,9 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
     SessionOptions options;
     auto* device_count = options.config.mutable_device_count();
     device_count->insert({"CPU", 1});
+    std::vector<std::unique_ptr<Device>> devices;
     TF_CHECK_OK(DeviceFactory::AddDevices(
-        options, "/job:localhost/replica:0/task:0", &devices_));
+        options, "/job:localhost/replica:0/task:0", &devices));
 
     FunctionDefLibrary proto;
     for (const auto& fdef : flib) {
@@ -69,7 +70,7 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
     lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
         OpRegistry::Global(), proto);
     OptimizerOptions opts;
-    device_mgr_ = absl::make_unique<DeviceMgr>(devices_);
+    device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
     pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
         device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
         opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
@@ -77,7 +78,6 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
   }
 
   FunctionLibraryRuntime* flr_;
-  std::vector<Device*> devices_;
   std::unique_ptr<DeviceMgr> device_mgr_;
   std::unique_ptr<FunctionLibraryDefinition> lib_def_;
   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
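Both test fixtures above switch from raw `Device*` vectors (freed by hand in `TearDown`) to `std::unique_ptr<Device>`, handing ownership to `DeviceMgr`. A sketch of the pattern, using only the calls that appear in the diff:

```cpp
// Devices are created into unique_ptrs and moved into the DeviceMgr, which
// then owns them; the fixture no longer deletes devices manually.
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
    options, "/job:localhost/replica:0/task:0", &devices));
auto device_mgr = absl::make_unique<DeviceMgr>(std::move(devices));
```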
@@ -86,7 +86,7 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name,
       continue;
     } else if (src_xla_computation && !dst_xla_computation) {
       if (src_outside_compilation) {
-        // Case 1d: outside compilation to host computation control edge.
+        // Case 1c: outside compilation to host computation control edge.
         edges_to_remove.push_back(e);
 
         TF_RETURN_IF_ERROR(AppendToListAttr<string>(
@@ -94,7 +94,7 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name,
       }
     } else if (!src_xla_computation && dst_xla_computation) {
       if (dst_outside_compilation) {
-        // Case 1d: host computation control to outside compilation edge.
+        // Case 1c: host computation control to outside compilation edge.
         edges_to_remove.push_back(e);
 
         TF_RETURN_IF_ERROR(AppendToListAttr<string>(
@@ -103,40 +103,24 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name,
     } else {  // src_xla_computation && dst_xla_computation
       if (*src_xla_computation != *dst_xla_computation) {
         if (src_outside_compilation && dst_outside_compilation) {
-          // Case 1c: outside compilation to outside compilation control edge.
+          // Case 1b: outside compilation to outside compilation control edge.
           edges_to_remove.push_back(e);
 
           TF_RETURN_IF_ERROR(AppendToListAttr<string>(
              e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
         } else if (src_outside_compilation && !dst_outside_compilation) {
-          // Case 1b: outside compilation to another XLA computaition control
+          // Case 1a: outside compilation to another XLA computaition control
           // edge.
           TF_RETURN_IF_ERROR(AppendToListAttr<string>(
              e->src(), kXlaConnectedToOtherXlaComputationAttrName,
              *dst_xla_computation));
         } else if (!src_outside_compilation && dst_outside_compilation) {
-          // Case 1b: another XLA computaition to outside compilation control
+          // Case 1a: another XLA computaition to outside compilation control
           // edge.
           TF_RETURN_IF_ERROR(AppendToListAttr<string>(
              e->dst(), kXlaConnectedFromOtherXlaComputationAttrName,
              *src_xla_computation));
         }
-      } else {  // *src_xla_computation == *dst_xla_computation
-        if (src_outside_compilation && dst_outside_compilation) {
-          if (*src_outside_compilation != *dst_outside_compilation) {
-            // Case 1c: outside compilation to outside compilation control edge.
-            edges_to_remove.push_back(e);
-
-            TF_RETURN_IF_ERROR(AppendToListAttr<string>(
-                e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
-          }
-        } else if (src_outside_compilation && !dst_outside_compilation) {
-          // Case 1a: outside compilation to its XLA computation control edge.
-          ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true);
-        } else if (!src_outside_compilation && dst_outside_compilation) {
-          // Case 1a: XLA computation to outside compilation in it control edge.
-          ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true);
-        }
       }
     }
   }
@@ -181,12 +165,6 @@ Status ProcessXlaToXlaDataEdges(Graph* g,
         edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
         VLOG(4) << "XLA -> XLA edge: " << e->DebugString();
       }
-    } else {  // *src_xla_computation == *dst_xla_computation
-      if (src_outside_compilation && dst_outside_compilation &&
-          *src_outside_compilation != *dst_outside_compilation) {
-        edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
-        VLOG(4) << "XLA -> XLA edge: " << e->DebugString();
-      }
     }
   }
 
@@ -263,7 +241,7 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
 
   // Remove the edge from host to outside compilation. Add a placeholder as
   // outside compilation node input.
-  std::map<string, Node*> placeholders;
+  std::map<std::pair<string, int>, Node*> placeholders;
   for (int i = 0; i < edges.size(); i++) {
     Node* dst = g->FindNodeId(edges[i].dst_node_id);
     const Edge* e;
@@ -275,9 +253,10 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
     // Find or create placeholder node.
     string new_name =
         edges[i].is_host_to_outside_compilation
-            ? absl::StrCat(src->name(), "_host_to_oc_placeholder")
-            : absl::StrCat(src->name(), "_oc_to_host_placeholder");
-    auto iter = placeholders.find(new_name);
+            ? absl::StrCat(src->name(), "_host_to_oc_placeholder_", src_output)
+            : absl::StrCat(src->name(), "_oc_to_host_placeholder_", src_output);
+    auto placeholder_index = std::make_pair(src->name(), src_output);
+    auto iter = placeholders.find(placeholder_index);
     Node* placeholder_node;
     if (iter == placeholders.end()) {
       NodeDefBuilder placeholder_builder(new_name, "Placeholder");
@@ -310,7 +289,7 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
       Status s;
       placeholder_node = g->AddNode(placeholder_def, &s);
       TF_RETURN_IF_ERROR(s);
-      placeholders[new_name] = placeholder_node;
+      placeholders[placeholder_index] = placeholder_node;
     } else {
       placeholder_node = iter->second;
     }
@@ -594,14 +573,244 @@ Status AddControlDependencies(
   return Status::OK();
 }
 
+// Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
+// `PreprocessEdgesBetweenOutsideCompilations` for details.
+Status PreprocessControlEdgesBetweenOutsideCompilations(
+    Graph* g, const string& outside_compilation_attr_name) {
+  // Gather edges to remove. We should not remove the edge while iterating.
+  std::vector<const Edge*> edges_to_remove;
+  for (const Edge* e : g->edges()) {
+    if (!e->IsControlEdge()) {
+      continue;
+    }
+
+    auto src_outside_compilation =
+        GetStringAttr(*e->src(), outside_compilation_attr_name);
+    auto dst_outside_compilation =
+        GetStringAttr(*e->dst(), outside_compilation_attr_name);
+
+    if (src_outside_compilation && dst_outside_compilation) {
+      if (*src_outside_compilation != *dst_outside_compilation) {
+        // Case 1a: outside compilation to outside compilation control edge.
+        edges_to_remove.push_back(e);
+
+        TF_RETURN_IF_ERROR(AppendToListAttr<string>(
+            e->dst(), kXlaControlDependenciesWithinXlaClusterAttrName,
+            e->src()->name()));
+      }
+    } else if (src_outside_compilation && !dst_outside_compilation) {
+      // Case 1b: outside compilation to its XLA computation control edge.
+      ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true);
+    } else if (!src_outside_compilation && dst_outside_compilation) {
+      // Case 1b: XLA computation to outside compilation in it control edge.
+      ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true);
+    }
+  }
+
+  for (auto e : edges_to_remove) {
+    g->RemoveEdge(e);
+  }
+  return Status::OK();
+}
+
+// Step 2 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
+// `PreprocessEdgesBetweenOutsideCompilations` for details.
+Status PreprocessDataEdgesBetweenOutsideCompilations(
+    Graph* g, const string& outside_compilation_attr_name) {
+  // Gather edges between outside compilation and host computation. Notice that
+  // we do not store `Edge*` directly because we remove some nodes while adding
+  // Identity nodes, and those Edge pointers might be invalidated.
+  struct EdgeInfo {
+    int dst_input, dst_node_id;
+  };
+  std::vector<EdgeInfo> edges;
+  for (const Edge* e : g->edges()) {
+    if (e->IsControlEdge()) {
+      continue;
+    }
+
+    auto src_outside_compilation =
+        GetStringAttr(*e->src(), outside_compilation_attr_name);
+    auto dst_outside_compilation =
+        GetStringAttr(*e->dst(), outside_compilation_attr_name);
+
+    if (src_outside_compilation && dst_outside_compilation &&
+        *src_outside_compilation != *dst_outside_compilation) {
+      edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
+      VLOG(4) << "Oc -> oc edge: " << e->DebugString();
+    }
+  }
+
+  // Remove the edge from host to outside compilation. Add a placeholder as
+  // outside compilation node input.
+  std::map<std::pair<string, int>, Node*> placeholders;
+  for (int i = 0; i < edges.size(); i++) {
+    Node* dst = g->FindNodeId(edges[i].dst_node_id);
+    const Edge* e;
+    TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
+    Node* src = e->src();
+    int src_output = e->src_output(), dst_input = e->dst_input();
+    g->RemoveEdge(e);
+
+    // Find or create placeholder node.
+    string new_name =
+        absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output);
+    auto placeholder_index = std::make_pair(src->name(), src_output);
+    auto iter = placeholders.find(placeholder_index);
+    Node* placeholder_node;
+    if (iter == placeholders.end()) {
+      NodeDefBuilder placeholder_builder(new_name, "Placeholder");
+      placeholder_builder.Attr("dtype", src->output_type(src_output));
+      string outside_compilation_attr;
+      TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(),
+                                     outside_compilation_attr_name,
+                                     &outside_compilation_attr));
+      placeholder_builder.Attr(outside_compilation_attr_name,
+                               outside_compilation_attr);
+      placeholder_builder.Attr(kOutsideCompilationOriginalNodeAttrName,
+                               src->name());
+      placeholder_builder.Attr(kOutsideCompilationSrcOutputAttrName,
+                               src_output);
+      NodeDef placeholder_def;
+      TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
+      Status s;
+      placeholder_node = g->AddNode(placeholder_def, &s);
+      TF_RETURN_IF_ERROR(s);
+      placeholders[placeholder_index] = placeholder_node;
+    } else {
+      placeholder_node = iter->second;
+    }
+    g->AddEdge(placeholder_node, 0, dst, dst_input);
+
+    // Replace `e->dst()` because its input node changed.
+    NodeDef new_def = dst->def();
+    *new_def.mutable_input(dst_input) = placeholder_node->name();
+    TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
+
+    // Other edge in `edges` might have `e->dst()` as src or dst
+    // node. Before removing `e->dst()`, replace those edges with
+    // corresponding edges for `dst_replace_node`.
+    for (int j = i + 1; j < edges.size(); j++) {
+      if (edges[j].dst_node_id == edges[i].dst_node_id) {
+        edges[j].dst_node_id = dst_replace_node->id();
+      }
+    }
+  }
+  return Status::OK();
+}
+
+// Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
+// `PostprocessEdgesBetweenOutsideCompilations` for details.
+Status PostprocessDataEdgesBetweenOutsideCompilations(
+    Graph* g, const string& outside_compilation_attr_name) {
+  // Gather all outside compilation to outside compilation nodes.
+  std::vector<Node*> placeholder_nodes;
+  for (Node* n : g->nodes()) {
+    if (n->type_string() == "Placeholder" &&
+        HasNodeAttr(n->def(), kOutsideCompilationOriginalNodeAttrName)) {
+      placeholder_nodes.push_back(n);
+    }
+  }
+
+  // Remove the placeholder nodes, and reconnect original edge.
+  auto node_name_index = g->BuildNodeNameIndex();
+  for (auto n : placeholder_nodes) {
+    string node_name;
+    int node_src_output;
+    TF_RETURN_IF_ERROR(GetNodeAttr(
+        n->attrs(), kOutsideCompilationOriginalNodeAttrName, &node_name));
+    TF_RETURN_IF_ERROR(GetNodeAttr(
+        n->attrs(), kOutsideCompilationSrcOutputAttrName, &node_src_output));
+    auto iter = node_name_index.find(node_name);
+    if (iter == node_name_index.end()) {
+      return errors::Internal(
+          "Cannot find original node for oc -> host placeholder node ",
+          node_name);
+    }
+
+    // Change all usage node to use the original node instead.
+    Node* original_node = iter->second;
+    std::vector<const Edge*> control_edges;
+    std::vector<OutEdgeInfo> data_edges;
+    for (auto e : n->out_edges()) {
+      if (e->IsControlEdge()) {
+        control_edges.push_back(e);
+      } else {
+        data_edges.push_back({e->dst(), e->src_output(), e->dst_input()});
+      }
+    }
+    for (const Edge* e : control_edges) {
+      g->AddControlEdge(original_node, e->dst());
+      g->RemoveEdge(e);
+    }
+    for (int i = 0; i < data_edges.size(); i++) {
+      Node* dst = data_edges[i].dst;
+      NodeDef new_def = dst->def();
+      int dst_input = data_edges[i].dst_input;
+      *new_def.mutable_input(dst_input) =
+          absl::StrCat(original_node->name(), ":", node_src_output);
+      TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def));
+
+      const Edge* edge_to_replace = nullptr;
+      TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace));
+      g->RemoveEdge(edge_to_replace);
+      g->AddEdge(original_node, node_src_output, replace_node, dst_input);
+
+      // Other edges might have `dst` as dst node. Update those edges with
+      // `replace_node`.
+      for (int j = i + 1; j < data_edges.size(); j++) {
+        if (data_edges[j].dst == dst) {
+          data_edges[j].dst = replace_node;
+        }
+      }
+
+      // Other placeholder node might have `dst` as original node. Update
+      // `node_name_index` with `replace_node`.
+      node_name_index[replace_node->name()] = replace_node;
+    }
+
+    // Remove placeholder node.
+    g->RemoveNode(n);
+  }
+  return Status::OK();
+}
+
+// Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
+// `PostprocessEdgesBetweenOutsideCompilations` for details.
+Status PostprocessControlEdgesBetweenOutsideCompilations(
+    Graph* g, const string& outside_compilation_attr_name) {
+  auto node_name_index = g->BuildNodeNameIndex();
+
+  // Reconnect outside compilation to outside compilation control edge.
+  for (Node* n : g->nodes()) {
+    std::vector<string> control_deps;
+    Status s =
+        GetNodeAttr(n->attrs(), kXlaControlDependenciesWithinXlaClusterAttrName,
+                    &control_deps);
+    if (!s.ok()) {
+      if (s.code() != error::NOT_FOUND) {
+        return s;
+      } else {
+        continue;
+      }
+    } else {
+      n->ClearAttr(kXlaControlDependenciesWithinXlaClusterAttrName);
+      for (const string& control_input : control_deps) {
+        auto iter = node_name_index.find(control_input);
+        if (iter == node_name_index.end()) {
+          return errors::Internal("Cannot find original node for ",
+                                  control_input);
+        }
+        g->AddControlEdge(iter->second, n);
+      }
+    }
+  }
+  return Status::OK();
+}
 }  // namespace
 
 const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes";
 
-const char kXlaConnectedToXlaComputationAttrName[] =
-    "_xla_connected_to_xla_computation";
-const char kXlaConnectedFromXlaComputationAttrName[] =
-    "_xla_connected_from_xla_computation";
 const char kXlaConnectedToOtherXlaComputationAttrName[] =
     "_xla_connected_to_other_xla_computation";
 const char kXlaConnectedFromOtherXlaComputationAttrName[] =
@@ -616,6 +825,15 @@ const char kHostToOutsideCompilationOriginalNodeAttrName[] =
     "_xla_host_to_oc_node_name";
 const char kHostToOutsideCompilationSrcOutputAttrName[] =
     "_xla_host_to_oc_src_output";
+const char kXlaConnectedToXlaComputationAttrName[] =
+    "_xla_connected_to_xla_computation";
+const char kXlaConnectedFromXlaComputationAttrName[] =
+    "_xla_connected_from_xla_computation";
+const char kOutsideCompilationOriginalNodeAttrName[] =
+    "_xla_oc_to_oc_node_name";
+const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output";
+const char kXlaControlDependenciesWithinXlaClusterAttrName[] =
+    "_xla_control_dependencies_within_xla_cluster";
 
 Status PerformStaticShapeInferenceBeforeEncapsulation(
     Graph* g, const string& xla_computation_attr_name,
@@ -699,4 +917,39 @@ Status PostprocessForEncapsulation(
   return Status::OK();
 }
 
+Status PreprocessEdgesBetweenOutsideCompilations(
+    Graph* g, const string& outside_compilation_attr_name) {
+  // Remove edges from source node to outside compilation nodes, and edges
+  // from outside compilation nodes to sink node.
+  std::vector<const Edge*> edges_to_remove;
+  for (const Edge* e : g->source_node()->out_edges()) {
+    if (HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
+      edges_to_remove.push_back(e);
+    }
+  }
+  for (const Edge* e : g->sink_node()->in_edges()) {
+    if (HasNodeAttr(e->src()->def(), outside_compilation_attr_name)) {
+      edges_to_remove.push_back(e);
+    }
+  }
+  for (auto e : edges_to_remove) {
+    g->RemoveEdge(e);
+  }
+
+  TF_RETURN_IF_ERROR(PreprocessControlEdgesBetweenOutsideCompilations(
+      g, outside_compilation_attr_name));
+  TF_RETURN_IF_ERROR(PreprocessDataEdgesBetweenOutsideCompilations(
+      g, outside_compilation_attr_name));
+  return Status::OK();
+}
+
+Status PostprocessEdgesBetweenOutsideCompilations(
+    Graph* g, const string& outside_compilation_attr_name) {
+  TF_RETURN_IF_ERROR(PostprocessDataEdgesBetweenOutsideCompilations(
+      g, outside_compilation_attr_name));
+  TF_RETURN_IF_ERROR(PostprocessControlEdgesBetweenOutsideCompilations(
+      g, outside_compilation_attr_name));
+  return Status::OK();
+}
+
 }  // namespace tensorflow
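In the new `PreprocessDataEdgesBetweenOutsideCompilations` above, placeholders are keyed by the (source node name, src_output) pair and named after both, so multiple edges fanning out from the same output share one Placeholder. A small sketch of that convention with illustrative values:

```cpp
// Illustrative values only; `string` is TensorFlow's string alias.
string src_name = "identity0";
int src_output = 1;
string placeholder_name =
    absl::StrCat(src_name, "_oc_to_oc_placeholder_", src_output);
// placeholder_name == "identity0_oc_to_oc_placeholder_1"
auto placeholder_index = std::make_pair(src_name, src_output);
// Every edge from ("identity0", 1) looks up the same map entry, so only one
// Placeholder node is created per (node, output) pair.
```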
@@ -44,14 +44,6 @@ Status PerformStaticShapeInferenceBeforeEncapsulation(
     Graph* g, const string& xla_computation_attr_name,
     const string& outside_compilation_attr_name);
 
-// Attribute indicating that some ops in this node's XLA computation has control
-// dependency on this node. Attribute value will always be "true".
-extern const char kXlaConnectedToXlaComputationAttrName[];
-
-// Attribute indicating that this node has control dependency on some ops in
-// this node's XLA computation. Attribute value will always be "true".
-extern const char kXlaConnectedFromXlaComputationAttrName[];
-
 // Attribute indicating that some ops in other XLA computation has control
 // dependency on this node. Attribute value will be a list of string (XLA
 // computation names).
@@ -81,6 +73,14 @@ extern const char kOutsideCompilationToHostOriginalNodeAttrName[];
 // int (src_output for original edge).
 extern const char kOutsideCompilationToHostSrcOutputAttrName[];
 
+// Attribute indicating that some ops in this node's XLA computation has control
+// dependency on this node. Attribute value will always be "true".
+extern const char kXlaConnectedToXlaComputationAttrName[];
+
+// Attribute indicating that this node has control dependency on some ops in
+// this node's XLA computation. Attribute value will always be "true".
+extern const char kXlaConnectedFromXlaComputationAttrName[];
+
 // Attribute indicating that this is an Placeholder node added to act as a
 // temporary input node for an host node. Attribute value will be string
 // (original input node name).
@@ -91,19 +91,31 @@ extern const char kHostToOutsideCompilationOriginalNodeAttrName[];
 // for original edge).
 extern const char kHostToOutsideCompilationSrcOutputAttrName[];
 
-// Preprocesses the graph for encapsulation. It will perform the following
-// operations in order:
+// Attribute indicating that this is an Placeholder node added to act as a
+// temporary input node for an outside compilation node. Attribute value will be
+// string (original input node name).
+extern const char kOutsideCompilationOriginalNodeAttrName[];
+
+// Attribute indicating that this is an Placeholder node added to act as a
+// temporary input node for an outside compilation node. Attribute value will be
+// int (src_output for original edge).
+extern const char kOutsideCompilationSrcOutputAttrName[];
+
+// Attribute indicating that this node has control dependencies on some other
+// nodes within the same XLA cluster. Attribute value will be a list of string
+// (node names).
+extern const char kXlaControlDependenciesWithinXlaClusterAttrName[];
+
+// Preprocesses edges between different XLA clusters for encapsulation. It will
+// perform the following operations in order:
 //
-// 1a. For control edges between outside compilation and its XLA computation,
-//     add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the
-//     outside compilation node.
-// 1b. For control edges between outside compilation and another XLA
+// 1a. For control edges between outside compilation and another XLA
 //     computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName
 //     = XLA computation node name" to the outside compilation node.
-// 1c. For control edges between different outside compilations, remove the edge
-//     and add attr "kXlaControlDependenciesAttrName = src node name" to dst
-//     node.
-// 1d. For control edges between outside compilation and host computation,
+// 1b. For control edges between different outside compilations (in different
+//     XLA computations), remove the edge and add attr
+//     "kXlaControlDependenciesAttrName = src node name" to dst node.
+// 1c. For control edges between outside compilation and host computation,
 //     remove the edge and add attr "kXlaControlDependenciesAttrName = src node
 //     name" to dst node.
 // 2. For data edges between different XLA computations, if either src or dst
@@ -146,26 +158,53 @@ struct XlaClusterInfo {
   const std::map<string, int> host_compute_core;
 };
 
-// Postprocesses the graph for encapsulation. This function reverts what
-// `PreprocessForEncapsulation` did. It will perform the following operations in
-// order:
+// Postprocesses edges between different XLA clusters for encapsulation. This
+// function reverts what `PreprocessForEncapsulation` did. It will perform the
+// following operations in order:
 //
 // 1. Remove Placeholder nodes between outside compilation and host computation
 //    (created in `PreprocessForEncapsulation` step 3).
 // 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2.
-// 3a. Reconnect control edges between different outside compilations (marked by
-//     `PreprocessForEncapsulation` step 1c) and control edges between outside
-//     compilation and host computation (marked by `PreprocessForEncapsulation`
-//     step 1d).
-// 3b. Reconnect control edges between outside compilation and another XLA
-//     computation (marked by `PreprocessForEncapsulation` step 1b).
-// Notice that control edges marked by `PreprocessForEncapsulation` step 1a are
-// not handled here. They are handled in `RewriteOutsideCompilationSubgraphFn`.
+// 3a. Reconnect control edges between outside compilation and another XLA
+//     computation (marked by `PreprocessForEncapsulation` step 1a).
+// 3b. Reconnect control edges between different outside compilations (marked by
+//     `PreprocessForEncapsulation` step 1b).
+// 3c. Reconnect control edges between outside compilation and host computation
+//     (marked by `PreprocessForEncapsulation` step 1c).
 Status PostprocessForEncapsulation(
     Graph* g, const string& xla_computation_attr_name,
     const string& outside_compilation_attr_name,
     const std::unordered_map<string, XlaClusterInfo>& clusters);
+
+// Preprocesses edges within the same XLA cluster. It will perform the following
+// operations in order:
+//
+// 0. Remove edges from source node to outside compilation nodes, and edges
+//    from outside compilation nodes to sink node.
+// 1a. For edges between different outside compilation clusters, remove the edge
+//     and add attr "kXlaControlDependenciesWithinXlaClusterAttrName = src node
+//     name" to dst node.
+// 1b. For control edges between outside compilation and its XLA computation,
+//     add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the
+//     outside compilation node.
+// 2. For data edges between different outside compilations, remove the edge
+//    and create a Placeholder node as dst node's input.
+Status PreprocessEdgesBetweenOutsideCompilations(
+    Graph* g, const string& outside_compilation_attr_name);
+
+// Postprocesses edges within the same XLA cluster. This function reverts what
+// `PreprocessEdgesBetweenOutsideCompilations` did. It will perform the
+// following operations in order:
+//
+// 1. Remove Placeholder nodes between different outside compilations (created
+//    in `PreprocessEdgesBetweenOutsideCompilations` step 2).
+// 2a. Reconnect control edges between different outside compilations (marked by
+//     `PreprocessEdgesBetweenOutsideCompilations` step 1a).
+// Notice that control edges marked by
+// `PreprocessEdgesBetweenOutsideCompilations` step 1b are not handled here.
+// They are handled in `RewriteOutsideCompilationSubgraphFn`.
+Status PostprocessEdgesBetweenOutsideCompilations(
+    Graph* g, const string& outside_compilation_attr_name);
 
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
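The header above pairs the new per-cluster entry points: preprocessing before the outside compilation subgraphs are rewritten, postprocessing after. A minimal sketch of the intended call order; the wrapper function name is hypothetical, and the "_oc" attribute name mirrors the unit tests below:

```cpp
// Hypothetical driver showing how the two entry points bracket the rewrite.
Status RunOutsideCompilationEdgeFixup(Graph* g) {
  TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(g, "_oc"));
  // ... encapsulate and rewrite outside compilation subgraphs here ...
  TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations(g, "_oc"));
  return Status::OK();
}
```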
@@ -107,28 +107,19 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) {
  identity4_node->AddAttr("_xla", "1");
  identity4_node->AddAttr("_oc", "0");
  identity5_node->AddAttr("_xla", "1");
-  // Case 1a: control edges between outside compilation and its XLA computation.
-  g.AddControlEdge(add_node, identity0_node);
-  g.AddControlEdge(identity0_node, identity1_node);
-  // Case 1b: control edges between outside compilation and another XLA
+  // Case 1a: control edges between outside compilation and another XLA
  // computation.
  g.AddControlEdge(identity0_node, identity3_node);
  g.AddControlEdge(identity1_node, identity4_node);
-  // Case 1c: control edges between different outside compilations.
+  // Case 1b: control edges between different outside compilations.
  g.AddControlEdge(identity0_node, identity4_node);
-  // Case 1d: control edges between outside compilation and host computation.
+  // Case 1c: control edges between outside compilation and host computation.
  g.AddControlEdge(const0_node, identity0_node);
  g.AddControlEdge(identity0_node, identity2_node);

  TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));

-  // Case 1a: add attr "_xla_connected_{from/to}_xla_computation = true" to the
-  // outside compilation node.
-  EXPECT_TRUE(HasNodeAttr(identity0_node->def(),
-                          kXlaConnectedFromXlaComputationAttrName));
-  EXPECT_TRUE(HasNodeAttr(identity0_node->def(),
-                          kXlaConnectedToXlaComputationAttrName));
-  // Case 1b: add attr "_xla_control_deps_{from/to} = XLA computation node name"
+  // Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name"
  // to the outside compilation node.
  std::vector<string> attr;
  TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
@@ -140,13 +131,13 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) {
                          kXlaConnectedFromOtherXlaComputationAttrName, &attr));
  EXPECT_EQ(attr.size(), 1);
  EXPECT_EQ(attr[0], "0");
-  // Case 1c: add attr "_xla_control_deps = src node name" to dst node.
+  // Case 1b: add attr "_xla_control_deps = src node name" to dst node.
  attr.clear();
  TF_CHECK_OK(GetNodeAttr(identity4_node->def(),
                          kXlaControlDependenciesAttrName, &attr));
  EXPECT_EQ(attr.size(), 1);
  EXPECT_EQ(attr[0], "identity0");
-  // Case 1d: add attr "_xla_control_deps = src node name" to dst node.
+  // Case 1c: add attr "_xla_control_deps = src node name" to dst node.
  attr.clear();
  TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
                          kXlaControlDependenciesAttrName, &attr));
@@ -162,23 +153,33 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) {
TEST(PreprocessForEncapsulationTest, DataEdges) {
  // Build the graph:
  // "const_0" and "const_1" in host computation
+  // "identityn0" = ("const_0", "const_1") in host computation 0
  // "add0" = "const_0" + "const_1" in XLA computation 0
  // "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0
  // "identity0" = "add1" in XLA computation 0
  // "add2" = "add1" + "identity0" in host computation
  // "add3" = "add1" + "add2" in XLA computation 1
-  // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 1
+  // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0
+  // "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 &
+  //          outside compilation 0
+  // "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 &
+  //                outside compilation 0
  // "identity1" = "add4" in XLA computation 1
  // "identity2" = "identity1" in host computation
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
  Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
+  auto identityn0 =
+      ops::IdentityN(s.WithOpName("identityn_0"), {const_0, const_1});
  Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1);
  Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0);
  Output identity0 = ops::Identity(s.WithOpName("identity0"), add1);
  Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0);
  Output add3 = ops::Add(s.WithOpName("add3"), add1, add2);
  Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2);
+  Output add5 = ops::Add(s.WithOpName("add5"), identityn0[0], identityn0[1]);
+  auto identityn1 = ops::IdentityN(s.WithOpName("identityn_1"),
+                                   {identityn0[0], identityn0[1]});
  Output identity1 = ops::Identity(s.WithOpName("identity1"), add4);
  Output identity2 = ops::Identity(s.WithOpName("identity2"), add4);
  Graph g(OpRegistry::Global());
@@ -189,6 +190,8 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
  Node *add0_node = node_index["add0"], *add1_node = node_index["add1"],
       *identity0_node = node_index["identity0"],
       *add3_node = node_index["add3"], *add4_node = node_index["add4"],
+       *add5_node = node_index["add5"],
+       *identityn1_node = node_index["identityn_1"],
       *identity1_node = node_index["identity1"];
  add0_node->AddAttr("_xla", "0");
  add1_node->AddAttr("_xla", "0");
@@ -197,6 +200,10 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
  add3_node->AddAttr("_xla", "1");
  add4_node->AddAttr("_xla", "1");
  add4_node->AddAttr("_oc", "0");
+  add5_node->AddAttr("_xla", "1");
+  add5_node->AddAttr("_oc", "0");
+  identityn1_node->AddAttr("_xla", "1");
+  identityn1_node->AddAttr("_oc", "0");
  identity1_node->AddAttr("_xla", "1");

  TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
@@ -214,8 +221,9 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
  EXPECT_NE(bridge_identity0_add4, nullptr);
  // Step 3: add placeholder for edges between host computation and outside
  // compilation.
-  EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder");
+  EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder_0");
-  Node *add1_oc_to_host_placeholder = node_index["add1_oc_to_host_placeholder"];
+  Node *add1_oc_to_host_placeholder =
+      node_index["add1_oc_to_host_placeholder_0"];
  TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
                          kOutsideCompilationToHostOriginalNodeAttrName, &str));
  EXPECT_EQ(str, "add1");
@@ -226,15 +234,34 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
  add4_node = node_index["add4"];
  ASSERT_NE(add4_node, nullptr);
  EXPECT_EQ(add4_node->def().input(0),
-            "bridge_identity0_add4_host_to_oc_placeholder");
+            "bridge_identity0_add4_host_to_oc_placeholder_0");
  Node *identity0_host_to_oc_placeholder =
-      node_index["bridge_identity0_add4_host_to_oc_placeholder"];
+      node_index["bridge_identity0_add4_host_to_oc_placeholder_0"];
  TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
                          kHostToOutsideCompilationOriginalNodeAttrName, &str));
  EXPECT_EQ(str, "bridge_identity0_add4");
  TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
                          kHostToOutsideCompilationSrcOutputAttrName, &i));
  EXPECT_EQ(i, 0);
+
+  // Check different placeholder nodes are created for different src_output.
+  Node *placeholder0 = node_index["identityn_0_host_to_oc_placeholder_0"],
+       *placeholder1 = node_index["identityn_0_host_to_oc_placeholder_1"];
+  EXPECT_NE(placeholder0, nullptr);
+  EXPECT_NE(placeholder1, nullptr);
+  // Check we only have 2 placeholder nodes created for "identityn_0".
+  int placeholder_count = 0;
+  for (Node *n : g.nodes()) {
+    if (HasNodeAttr(n->def(), kHostToOutsideCompilationOriginalNodeAttrName)) {
+      string attr;
+      TF_CHECK_OK(GetNodeAttr(
+          n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &attr));
+      if (attr == "identityn_0") {
+        ++placeholder_count;
+      }
+    }
+  }
+  EXPECT_EQ(placeholder_count, 2);
}

TEST(PostprocessForEncapsulationTest, ControlEdges) {
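The expectations added above encode the new bridge-placeholder naming scheme: one Placeholder per (source node, src_output) pair, with the output index appended as a suffix. A small sketch of that naming convention; the helper function is hypothetical and written only to illustrate the names the test expects:

```cpp
#include "absl/strings/str_cat.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Hypothetical helper mirroring the names asserted above, e.g.
// ("identityn_0", 1) -> "identityn_0_host_to_oc_placeholder_1".
string HostToOutsideCompilationPlaceholderName(const string& src_node,
                                               int src_output) {
  return absl::StrCat(src_node, "_host_to_oc_placeholder_", src_output);
}

}  // namespace tensorflow
```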
@@ -195,8 +195,11 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
        e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
        e->dst()->type_string() != kXlaClusterOutput) {
      return errors::InvalidArgument(
-          "Undeclared output of XLA computation. A common cause of this error "
-          "is variable initializers that depend on the XLA computation. Edge: ",
+          "Undeclared output of XLA computation. Some common causes of this "
+          "error are: 1) variable initializers that depend on the XLA "
+          "computation; 2) gradient computations that depend on the XLA "
+          "computation, which can be mitigated by moving gradient computations "
+          "inside XLA computation. Offending edge: ",
          e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":",
          e->dst_input());
    }
@@ -366,7 +366,7 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
//    replace this node with compilation result node.
// 3) all outside compilation graphs.
Status ConstructHostGraph(
-    const string& xla_cluster_name,
+    const string& xla_cluster_name, const string& outside_compilation_attr_name,
    const std::vector<string>& outside_compilation_host_graphs,
    FunctionLibraryDefinition* fld, std::unique_ptr<Graph>* host_graph) {
  host_graph->reset(new Graph(fld));
@@ -476,6 +476,10 @@ Status ConstructHostGraph(
      host_graph->get(),
      std::unordered_set<const Node*>{(*host_graph)->sink_node()});

+  // Postprocess edges between different outside compilations.
+  TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations(
+      host_graph->get(), outside_compilation_attr_name));
+
  if (VLOG_IS_ON(4)) {
    dump_graph::DumpGraphToFile(
        absl::StrCat("extract_outside_compilation_host_graph_for_",
@@ -801,6 +805,11 @@ Status ExtractOutsideCompilationForFunction(
      },
      &fbody));
  std::unique_ptr<FunctionBody> fbody_deleter(fbody);
+
+  // Preprocess edges between different outside compilations. They will be
+  // restored in `ConstructHostGraph()`.
+  TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
+      fbody->graph, outside_compilation_attr_name));
  if (VLOG_IS_ON(4)) {
    dump_graph::DumpGraphToFile(
        absl::StrCat("extract_outside_compilation_for_func_before_", func_name),
@@ -860,8 +869,9 @@ Status ExtractOutsideCompilationForFunction(

  // Construct host graph.
  if (!outside_compilation_host_graphs.empty()) {
-    TF_RETURN_IF_ERROR(ConstructHostGraph(
-        xla_cluster_name, outside_compilation_host_graphs, fld, host_graph));
+    TF_RETURN_IF_ERROR(
+        ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
+                           outside_compilation_host_graphs, fld, host_graph));
  }

  // Remove the outside compilation graphs from function library.
@@ -290,21 +290,18 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) {
  TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shapes", &shapes));
  EXPECT_EQ(shapes.size(), 1);
  EXPECT_EQ(shapes[0].dim_size(), 1);
-  // Check XlaHostCompute nodes' "shape_inference_graph" attr. "0" should have a
-  // non-empty value, and "1" should have an empty value.
+  // Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have
+  // empty values.
  string shape_inference_graph;
  TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph",
                          &shape_inference_graph));
-  EXPECT_EQ(shape_inference_graph,
-            "_outside_compilation_shape_inference_cluster_0");
+  EXPECT_EQ(shape_inference_graph, "");
  TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph",
                          &shape_inference_graph));
  EXPECT_EQ(shape_inference_graph, "");

  // Check `shape_inference_graphs`.
-  EXPECT_EQ(shape_inference_graphs.size(), 1);
-  EXPECT_EQ(shape_inference_graphs[0],
-            "_outside_compilation_shape_inference_cluster_0");
+  EXPECT_EQ(shape_inference_graphs.size(), 0);

  // Check `host_graph`: verify we have key placeholder and sequencer.
  Node *key_placeholder = nullptr, *sequencer = nullptr;
@@ -333,8 +330,8 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) {
      send_recv_nodes.push_back(n);
    }
  }
-  EXPECT_EQ(num_send_from_host, 2);
-  EXPECT_EQ(num_recv_at_host, 2);
+  EXPECT_EQ(num_send_from_host, 1);
+  EXPECT_EQ(num_recv_at_host, 1);
  for (Node *n : send_recv_nodes) {
    Node *input_node;
    TF_CHECK_OK(n->input_node(n->num_inputs() - 1, &input_node));
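The assertions above count the host-side transfer ops left in the host graph after extraction. A standalone sketch of that counting loop, using the `_XlaRecvAtHost`/`_XlaSendFromHost` op names the surrounding test already relies on (the helper itself is illustrative, not part of the commit):

```cpp
#include <utility>

#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// Returns {num_recv_at_host, num_send_from_host} for a host graph.
std::pair<int, int> CountHostTransferOps(const Graph& g) {
  int num_recv_at_host = 0, num_send_from_host = 0;
  for (Node* n : g.nodes()) {
    if (n->type_string() == "_XlaRecvAtHost") ++num_recv_at_host;
    if (n->type_string() == "_XlaSendFromHost") ++num_send_from_host;
  }
  return {num_recv_at_host, num_send_from_host};
}

}  // namespace tensorflow
```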
152  tensorflow/compiler/jit/flags.cc  Normal file
@@ -0,0 +1,152 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <mutex>  // NOLINT

#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {
namespace {

BuildXlaOpsPassFlags* build_ops_flags;
DumpGraphFlags* dump_graph_flags;
MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;

std::vector<Flag>* flag_list;
std::once_flag flags_init;

void AppendDumpGraphFlagsInternal(std::vector<Flag>* flag_list) {
  std::vector<Flag> new_flags = {
      Flag("tf_dump_graph_prefix", &dump_graph_flags->tf_dump_graph_prefix,
           "Path prefix to which graphs dumped during debugging should be "
           "written."),
  };
  flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end());
}

void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
  std::vector<Flag> new_flags = {
      Flag("tf_xla_auto_jit", &mark_for_compilation_flags->tf_xla_auto_jit,
           "Control compilation of operators into XLA computations on CPU and "
           "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
           "things very likely to be improved; 2 = on for everything. "
           "Experimental."),
      Flag("tf_xla_min_cluster_size",
           &mark_for_compilation_flags->tf_xla_min_cluster_size,
           "Minimum number of operators in an XLA compilation. Ignored for "
           "operators placed on an XLA device or operators explicitly marked "
           "for compilation."),
      Flag("tf_xla_max_cluster_size",
           &mark_for_compilation_flags->tf_xla_max_cluster_size,
           "Maximum number of operators in an XLA compilation."),
      Flag("tf_xla_clustering_debug",
           &mark_for_compilation_flags->tf_xla_clustering_debug,
           "Dump graphs during XLA compilation."),
      Flag("tf_xla_cpu_global_jit",
           &mark_for_compilation_flags->tf_xla_cpu_global_jit,
           "Enables global JIT compilation for CPU via SessionOptions."),
      Flag("tf_xla_clustering_fuel",
           &mark_for_compilation_flags->tf_xla_clustering_fuel,
           "Places an artificial limit on the number of ops marked as "
           "eligible for clustering."),
      Flag("tf_xla_fusion_only",
           &mark_for_compilation_flags->tf_xla_fusion_only,
           "enable fusion of element-wise operations only using XLA when "
           "global_jit_level is ON*.")};
  flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end());
}

void AllocateAndParseFlags() {
  build_ops_flags = new BuildXlaOpsPassFlags;
  build_ops_flags->tf_xla_enable_lazy_compilation = true;

  dump_graph_flags = new DumpGraphFlags;
  dump_graph_flags->tf_dump_graph_prefix = "/tmp/";

  mark_for_compilation_flags = new MarkForCompilationPassFlags;
  mark_for_compilation_flags->tf_xla_auto_jit = 0;
  mark_for_compilation_flags->tf_xla_min_cluster_size = 2;
  mark_for_compilation_flags->tf_xla_max_cluster_size =
      std::numeric_limits<int32>::max();
  mark_for_compilation_flags->tf_xla_clustering_debug = false;
  mark_for_compilation_flags->tf_xla_cpu_global_jit = false;
  mark_for_compilation_flags->tf_xla_clustering_fuel =
      std::numeric_limits<int64>::max();
  mark_for_compilation_flags->tf_xla_fusion_only = false;

  device_flags = new XlaDeviceFlags;
  device_flags->tf_xla_compile_on_demand = false;

  ops_flags = new XlaOpsCommonFlags;
  ops_flags->tf_xla_always_defer_compilation = false;

  flag_list = new std::vector<Flag>({
      Flag("tf_xla_enable_lazy_compilation",
           &build_ops_flags->tf_xla_enable_lazy_compilation, ""),

      Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
           "Switch a device into 'on-demand' mode, where instead of "
           "autoclustering ops are compiled one by one just-in-time."),

      Flag("tf_xla_always_defer_compilation",
           &ops_flags->tf_xla_always_defer_compilation, ""),
  });
  AppendDumpGraphFlagsInternal(flag_list);
  AppendMarkForCompilationPassFlagsInternal(flag_list);
  xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
}

}  // namespace

const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  return *build_ops_flags;
}

DumpGraphFlags* GetDumpGraphFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  return dump_graph_flags;
}

MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  return mark_for_compilation_flags;
}

XlaDeviceFlags* GetXlaDeviceFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  return device_flags;
}

const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  return *ops_flags;
}

void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
  std::call_once(flags_init, &AllocateAndParseFlags);
  AppendMarkForCompilationPassFlagsInternal(flag_list);
}

void AppendDumpGraphFlags(std::vector<Flag>* flag_list) {
  std::call_once(flags_init, &AllocateAndParseFlags);
  AppendDumpGraphFlagsInternal(flag_list);
}

}  // namespace tensorflow
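A minimal sketch of how a pass or kernel is expected to consume the consolidated flags above; the first getter call runs `AllocateAndParseFlags()` and hence parses TF_XLA_FLAGS. The wrapper functions here are illustrative only, not part of the commit:

```cpp
#include "tensorflow/compiler/jit/flags.h"

namespace tensorflow {

// Illustrative wrappers; real call sites simply invoke the getters directly.
bool ClusteringDebugDumpsEnabled() {
  return GetMarkForCompilationPassFlags()->tf_xla_clustering_debug;
}

bool LazyCompilationEnabled() {
  return GetBuildXlaOpsPassFlags().tf_xla_enable_lazy_compilation;
}

}  // namespace tensorflow
```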
@@ -13,10 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

-#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_
-#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_
+#ifndef TENSORFLOW_COMPILER_JIT_FLAGS_H_
+#define TENSORFLOW_COMPILER_JIT_FLAGS_H_

-// Legacy flags for the XLA bridge's mark_for_compilation_pass module.
-
#include <vector>

@@ -24,15 +22,8 @@ limitations under the License.
#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {
-namespace legacy_flags {

-// Append to *flag_list flag definitions associated with the XLA bridge's
-// mark_for_compilation_pass module.
-void AppendMarkForCompilationPassFlags(
-    std::vector<tensorflow::Flag>* flag_list);
-
-// The values of flags associated with the XLA bridge's
-// mark_for_compilation_pass module.
+// Flags associated with the XLA bridge's mark_for_compilation_pass module.
struct MarkForCompilationPassFlags {
  int32 tf_xla_auto_jit;  // Control compilation of operators into XLA
                          // computations on CPU and GPU devices. 0 = use
@@ -57,12 +48,56 @@ struct MarkForCompilationPassFlags {
  // only using XLA.
};

-// Return a pointer to the MarkForCompilationPassFlags struct;
+// Flags associated with the XLA bridge's xla_device module.
+struct XlaDeviceFlags {
+  // Switch the CPU device into "on-demand" mode, where instead of
+  // autoclustering ops are compiled one by one just-in-time.
+  // Enabling this mode by a legacy flag is a temporary mechanism. When this
+  // feature is battle-tested, we will switch this to be a session option.
+  bool tf_xla_compile_on_demand;
+};
+
+// Flags common to the _Xla* ops and their kernels.
+struct XlaOpsCommonFlags {
+  // If true, _XlaCompile always refuses to compile the cluster, which means the
+  // XLA clusters always run in the TF executor. Defaults to false.
+  bool tf_xla_always_defer_compilation;
+};
+
+// Flags for the build_xla_ops pass.
+struct BuildXlaOpsPassFlags {
+  // Enables lazy compilation for TF/XLA (only when auto-clustering) if true.
+  // Defaults to true.
+  bool tf_xla_enable_lazy_compilation;
+};
+
+// Flags for the XLA bridge's dump_graph module.
+struct DumpGraphFlags {
+  // Path prefix to which graphs dumped during debugging should be written.
+  string tf_dump_graph_prefix;
+};
+
+// Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
-MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();

-}  // namespace legacy_flags
+// Getters for flags structs defined above. The first call to any of these
+// parses TF_XLA_FLAGS for all of them. Those functions which return a pointer
+// always return the same pointer.
+MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();
+const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags();
+XlaDeviceFlags* GetXlaDeviceFlags();
+const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
+DumpGraphFlags* GetDumpGraphFlags();
+
+// Appends the flag definitions associated with
+// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
+//
+// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
+void AppendMarkForCompilationPassFlags(
+    std::vector<tensorflow::Flag>* flag_list);
+void AppendDumpGraphFlags(std::vector<tensorflow::Flag>* flag_list);
+
}  // namespace tensorflow

-#endif  // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_
+#endif  // TENSORFLOW_COMPILER_JIT_FLAGS_H_
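Since the getters parse the `TF_XLA_FLAGS` environment variable on first use, a test-style sketch of setting a couple of these flags before touching the structs; the specific values chosen here are assumptions for illustration:

```cpp
#include <cstdlib>

#include "tensorflow/compiler/jit/flags.h"

int main() {
  // Must be set before the first Get*Flags() call, which parses TF_XLA_FLAGS.
  setenv("TF_XLA_FLAGS", "--tf_xla_auto_jit=2 --tf_xla_min_cluster_size=4",
         /*overwrite=*/1);
  tensorflow::MarkForCompilationPassFlags* flags =
      tensorflow::GetMarkForCompilationPassFlags();
  return flags->tf_xla_auto_jit == 2 ? 0 : 1;
}
```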
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/math_ops.h"
-#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
+#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
@@ -208,8 +208,12 @@ Status ComputeSliceSize(const Scope& host_scope,
    DCHECK_EQ(slice_size.back().type(), DT_INT64);
  }

-  *size = ops::Concat(host_scope.WithOpName("slice_size"), slice_size,
-                      ops::Const(host_scope.WithOpName("concat_axis"), 0));
+  // Trivial ConcatV2 nodes (with exactly one input) are disallowed.
+  *size =
+      slice_size.size() == 1
+          ? slice_size[0]
+          : ops::Concat(host_scope.WithOpName("slice_size"), slice_size,
+                        ops::Const(host_scope.WithOpName("concat_axis"), 0));
  return Status::OK();
}

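The guard above avoids emitting a ConcatV2 with a single input, which the graph rejects as a trivial concat. A standalone sketch of the same decision, with scope and op names chosen purely for illustration:

```cpp
#include <vector>

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"

namespace tensorflow {

// Concatenates per-dimension size components into one 1-D tensor, but skips
// the ConcatV2 node entirely when there is only one component.
Output ConcatSliceSizeComponents(const Scope& scope,
                                 const std::vector<Output>& components) {
  if (components.size() == 1) {
    return components[0];
  }
  return ops::Concat(scope.WithOpName("slice_size"), components,
                     ops::Const(scope.WithOpName("concat_axis"), 0));
}

}  // namespace tensorflow
```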
@ -242,6 +246,9 @@ Status ConvertTensorFlowSliceToStaticShapedSlice(
|
|||||||
.WithOpName("static_shaped_slice"),
|
.WithOpName("static_shaped_slice"),
|
||||||
slice_inputs_int64.input, slice_inputs_int64.begin, slice_size)
|
slice_inputs_int64.input, slice_inputs_int64.begin, slice_size)
|
||||||
.node();
|
.node();
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(main_scope.status());
|
||||||
|
|
||||||
std::vector<string> compile_time_const_inputs;
|
std::vector<string> compile_time_const_inputs;
|
||||||
compile_time_const_inputs.push_back("size");
|
compile_time_const_inputs.push_back("size");
|
||||||
(*result)->AddAttr(kXlaCompileTimeConstantInputsAttr,
|
(*result)->AddAttr(kXlaCompileTimeConstantInputsAttr,
|
||||||
@ -284,49 +291,45 @@ Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// If `n` is a slice we can rewrite to have a static shape (i.e. have the output
|
// Return true if `n` is a slice we can rewrite to have a static shape
|
||||||
// shape only depend on the "size" input) then returns the a SliceInputs
|
// (i.e. have the output shape only depend on the "size" input).
|
||||||
// representing the inputs to `n`. Otherwise returns nullopt.
|
xla::StatusOr<bool> IsRewritableSlice(Node* n) {
|
||||||
StatusOrOptional<SliceInputs> IsRewritableSlice(Node* n) {
|
|
||||||
if (n->type_string() != "Slice") {
|
if (n->type_string() != "Slice") {
|
||||||
return {absl::nullopt};
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!GetXlaClusterForNode(*n).has_value()) {
|
if (!GetXlaClusterForNode(*n).has_value()) {
|
||||||
// There is no need to change slice ops outside XLA clusters.
|
// There is no need to change slice ops outside XLA clusters.
|
||||||
return {absl::nullopt};
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(absl::optional<SliceInputs> slice_inputs,
|
TF_ASSIGN_OR_RETURN(absl::optional<SliceInputs> slice_inputs,
|
||||||
GetSliceInputs(n));
|
GetSliceInputs(n));
|
||||||
if (!slice_inputs.has_value()) {
|
if (!slice_inputs.has_value()) {
|
||||||
return {absl::nullopt};
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If slice_size[i] < -1 for any i then executing the slice will throw an
|
// If slice_size[i] < -1 for any i then executing the slice will throw an
|
||||||
// error, and we don't do anything here.
|
// error, and we don't do anything here.
|
||||||
bool slice_is_ok = absl::c_all_of(slice_inputs->size_as_vector,
|
return absl::c_all_of(slice_inputs->size_as_vector,
|
||||||
[](int64 size_i) { return size_i >= -1; });
|
[](int64 size_i) { return size_i >= -1; });
|
||||||
if (!slice_is_ok) {
|
|
||||||
return {absl::nullopt};
|
|
||||||
}
|
|
||||||
|
|
||||||
return slice_inputs;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status FindAndRewriteSlices(Graph* g, bool* changed) {
|
Status FindAndRewriteSlices(Graph* g, bool* changed) {
|
||||||
std::vector<std::pair<Node*, SliceInputs>> slices_to_rewrite;
|
std::vector<Node*> slices_to_rewrite;
|
||||||
for (Node* n : g->nodes()) {
|
for (Node* n : g->nodes()) {
|
||||||
TF_ASSIGN_OR_RETURN(absl::optional<SliceInputs> slice_inputs,
|
TF_ASSIGN_OR_RETURN(bool is_rewritable, IsRewritableSlice(n));
|
||||||
IsRewritableSlice(n));
|
if (is_rewritable) {
|
||||||
if (slice_inputs.has_value()) {
|
slices_to_rewrite.push_back(n);
|
||||||
slices_to_rewrite.push_back({n, std::move(*slice_inputs)});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto& pair : slices_to_rewrite) {
|
for (Node* n : slices_to_rewrite) {
|
||||||
TF_RETURN_IF_ERROR(RewriteSlice(g, pair.first, pair.second,
|
TF_ASSIGN_OR_RETURN(absl::optional<SliceInputs> slice_inputs,
|
||||||
*GetXlaClusterForNode(*pair.first)));
|
GetSliceInputs(n));
|
||||||
|
TF_RET_CHECK(slice_inputs.has_value());
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
RewriteSlice(g, n, *slice_inputs, *GetXlaClusterForNode(*n)));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!slices_to_rewrite.empty()) {
|
if (!slices_to_rewrite.empty()) {
|
||||||
@ -342,8 +345,7 @@ Status FindAndRewriteSlices(Graph* g, bool* changed) {
|
|||||||
|
|
||||||
Status IncreaseDynamismForAutoJitPass::Run(
|
Status IncreaseDynamismForAutoJitPass::Run(
|
||||||
const GraphOptimizationPassOptions& options) {
|
const GraphOptimizationPassOptions& options) {
|
||||||
legacy_flags::MarkForCompilationPassFlags* flags =
|
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
|
||||||
legacy_flags::GetMarkForCompilationPassFlags();
|
|
||||||
if (flags->tf_xla_clustering_debug) {
|
if (flags->tf_xla_clustering_debug) {
|
||||||
dump_graph::DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass",
|
dump_graph::DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass",
|
||||||
**options.graph, options.flib_def);
|
**options.graph, options.flib_def);
|
||||||
|
@@ -27,6 +27,7 @@ limitations under the License.
namespace tensorflow {
namespace {

+using ::testing::_;
using testing::matchers::AssignedDevice;
using testing::matchers::Attr;
using testing::matchers::Const;
@@ -142,6 +143,26 @@ TEST(SliceToDynamicSliceRewriteTest, Basic) {
  EXPECT_THAT(static_shaped_slice, m_dynamic_slice);
}

+TEST(SliceToDynamicSliceRewriteTest, SliceFromVector) {
+  Scope root = Scope::NewRootScope()
+                   .ExitOnError()
+                   .WithAssignedDevice(kDeviceName)
+                   .WithXlaCluster("cluster_0");
+
+  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
+  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
+  Output size = ops::Const(root.WithOpName("size"), {-1});
+  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
+
+  std::unique_ptr<Graph> result;
+  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
+
+  Node* static_shaped_slice = testing::FindNodeByName(
+      result.get(), "slice/static_shaped_slice/static_shaped_slice");
+  EXPECT_NE(static_shaped_slice, nullptr);
+  EXPECT_THAT(result->nodes(), Not(Contains(NodeWith(Op("ConcatV2")))));
+}
+
TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
@@ -166,18 +187,18 @@ TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) {
                       CtrlDeps(NodeWith(Op("Placeholder"), Name("control")))));
}

+int64 ToInt64(int v) { return static_cast<int64>(v); }
+
TEST(SliceToDynamicSliceRewriteTest, Int64Indices) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

-  auto to_int64 = [](int v) { return static_cast<int64>(v); };
-
  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
  Output size =
-      ops::Const(root.WithOpName("size"), {to_int64(-1), to_int64(500)});
+      ops::Const(root.WithOpName("size"), {ToInt64(-1), ToInt64(500)});
  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);

  std::unique_ptr<Graph> result;
@@ -252,13 +273,35 @@ TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithNonConstSize) {
                           Attr(kXlaCompileTimeConstantInputsAttr)))));
}

+TEST(SliceToDynamicSliceRewriteTest, ScalarSlice) {
+  Scope root = Scope::NewRootScope()
+                   .ExitOnError()
+                   .WithAssignedDevice(kDeviceName)
+                   .WithXlaCluster("cluster_0");
+
+  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
+  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
+  Output size = ops::Const<int64>(root.WithOpName("size"), {});
+  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
+
+  std::unique_ptr<Graph> result;
+  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
+
+  Node* static_shaped_slice = testing::FindNodeByName(
+      result.get(), "slice/static_shaped_slice/static_shaped_slice");
+  ASSERT_NE(static_shaped_slice, nullptr);
+  EXPECT_THAT(static_shaped_slice,
+              NodeWith(Op("Slice"), Attr(kXlaCompileTimeConstantInputsAttr),
+                       Inputs(_, _, Out(NodeWith(Name(size.node()->name()))))));
+}
+
TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) {
  Scope root = Scope::NewRootScope()
                   .ExitOnError()
                   .WithAssignedDevice(kDeviceName)
                   .WithXlaCluster("cluster_0");

-  auto to_int64 = [](int v) { return static_cast<int64>(v); };
+  auto ToInt64 = [](int v) { return static_cast<int64>(v); };

  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
@@ -271,7 +314,7 @@ TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) {
      ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder);

  Output size =
-      ops::Const(root.WithOpName("size"), {{to_int64(-1)}, {to_int64(500)}});
+      ops::Const(root.WithOpName("size"), {{ToInt64(-1)}, {ToInt64(500)}});
  TF_ASSERT_OK(root.graph()->UpdateEdge(size.node(), 0, slice.node(), 2));

  std::unique_ptr<Graph> result;
@@ -281,5 +324,82 @@ TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) {
              Not(Contains(NodeWith(Op("Slice"),
                                    Attr(kXlaCompileTimeConstantInputsAttr)))));
}
+
+TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceInput) {
+  Scope root = Scope::NewRootScope()
+                   .ExitOnError()
+                   .WithAssignedDevice(kDeviceName)
+                   .WithXlaCluster("cluster_0");
+
+  Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
+  Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
+  Output size_a = ops::Const(root.WithOpName("size_a"), {-1, 500});
+  Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size_a);
+
+  Output size_b = ops::Const(root.WithOpName("size_a"), {-1, 200});
+  Output slice_with_slice_input = ops::Slice(
+      root.WithOpName("slice_with_slice_input"), slice, begin, size_b);
+
+  std::unique_ptr<Graph> result;
+  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
+
+  Node* static_shaped_slice = testing::FindNodeByName(
+      result.get(),
+      "slice_with_slice_input/static_shaped_slice/static_shaped_slice");
+  ASSERT_NE(static_shaped_slice, nullptr);
+  EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT)
+      << "Expected DT_FLOAT, was "
+      << DataType_Name(static_shaped_slice->output_type(0));
+  EXPECT_THAT(
+      static_shaped_slice,
+      NodeWith(
+          Op("Slice"),
+          Inputs(Out(NodeWith(
+                     Op("Slice"),
+                     Name("slice/static_shaped_slice/static_shaped_slice"))),
+                 _, _)));
+}
+
+TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceBegin) {
+  Scope root = Scope::NewRootScope()
+                   .ExitOnError()
+                   .WithAssignedDevice(kDeviceName)
+                   .WithXlaCluster("cluster_0");
+
+  Output input_float =
+      ops::Placeholder(root.WithOpName("input_float"), DT_FLOAT);
+  Output input_i64 = ops::Placeholder(root.WithOpName("input_i64"), DT_INT64);
+
+  Output begin_begin =
+      ops::Placeholder(root.WithOpName("begin_begin"), DT_INT32);
+  Output begin_size = ops::Const(root.WithOpName("begin_size"), {-1});
+  Output begin =
+      ops::Slice(root.WithOpName("begin"), input_i64, begin_begin, begin_size);
+
+  Output size =
+      ops::Const(root.WithOpName("size"), {ToInt64(-1), ToInt64(200)});
+  Output slice_with_slice_begin = ops::Slice(
+      root.WithOpName("slice_with_slice_begin"), input_float, begin, size);
+
+  std::unique_ptr<Graph> result;
+  TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
+
+  Node* static_shaped_slice = testing::FindNodeByName(
+      result.get(),
+      "slice_with_slice_begin/static_shaped_slice/static_shaped_slice");
+  ASSERT_NE(static_shaped_slice, nullptr);
+  EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT)
+      << "Expected DT_FLOAT, was "
+      << DataType_Name(static_shaped_slice->output_type(0));
+  EXPECT_THAT(
+      static_shaped_slice,
+      NodeWith(
+          Op("Slice"),
+          Inputs(_,
+                 Out(NodeWith(
+                     Op("Slice"),
+                     Name("begin/static_shaped_slice/static_shaped_slice"))),
+                 _)));
+}
}  // namespace
}  // namespace tensorflow
@@ -12,10 +12,10 @@ cc_library(
    hdrs = ["xla_ops.h"],
    deps = [
        "//tensorflow/compiler/jit:common",
+        "//tensorflow/compiler/jit:flags",
        "//tensorflow/compiler/jit:xla_compilation_cache",
        "//tensorflow/compiler/jit:xla_device",
        "//tensorflow/compiler/jit:xla_launch_util",
-        "//tensorflow/compiler/jit/legacy_flags:xla_ops_common_flags",
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
@@ -18,7 +18,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h"
+#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -418,7 +418,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
    cannot_compile_cluster = cannot_compile_cluster_;
  }

-  if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
+  if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
      cannot_compile_cluster) {
    executable = nullptr;
  } else {
@@ -1,65 +0,0 @@
# Legacy command line flags for the XLA bridge libraries.

# Please do not add more flags to this package.

# The XLA bridge libraries were written in an environment that allowed
# command-line flags to be scattered freely throughout the libraries. This
# model, while initially convenient, leads to a proliferation in unused command
# line flags in tests and binaries, and serious problems in servers, where one
# might wish parameters to be different in independent RPC calls to the same
# routine.
#
# Please don't add more flags. If you're a library author, pass options and
# parameters explicitly through the library's interface.

licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//tensorflow:internal"])

cc_library(
    name = "mark_for_compilation_pass_flags",
    srcs = ["mark_for_compilation_pass_flags.cc"],
    hdrs = ["mark_for_compilation_pass_flags.h"],
    deps =
        [
            "//tensorflow/compiler/xla:parse_flags_from_env",
            "//tensorflow/core:framework_internal",
            "//tensorflow/core:lib",
        ],
)

cc_library(
    name = "xla_device_flags",
    srcs = ["xla_device_flags.cc"],
    hdrs = ["xla_device_flags.h"],
    deps =
        [
            "//tensorflow/compiler/xla:parse_flags_from_env",
            "//tensorflow/core:framework_internal",
            "//tensorflow/core:lib",
        ],
)

cc_library(
    name = "build_xla_ops_pass_flags",
    srcs = ["build_xla_ops_pass_flags.cc"],
    hdrs = ["build_xla_ops_pass_flags.h"],
    deps =
        [
            "//tensorflow/compiler/xla:parse_flags_from_env",
            "//tensorflow/core:framework_internal",
            "//tensorflow/core:lib",
        ],
)

cc_library(
    name = "xla_ops_common_flags",
    srcs = ["xla_ops_common_flags.cc"],
    hdrs = ["xla_ops_common_flags.h"],
    deps =
        [
            "//tensorflow/compiler/xla:parse_flags_from_env",
            "//tensorflow/core:framework_internal",
            "//tensorflow/core:lib",
        ],
)
@@ -1,47 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <mutex>  // NOLINT

#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {
namespace legacy_flags {
namespace {

BuildXlaOpsPassFlags* flags;
std::vector<Flag>* flag_list;
std::once_flag flags_init;

void AllocateAndParseFlags() {
  flags = new BuildXlaOpsPassFlags;
  flags->tf_xla_enable_lazy_compilation = true;
  flag_list = new std::vector<Flag>({
      Flag("tf_xla_enable_lazy_compilation",
           &flags->tf_xla_enable_lazy_compilation, ""),
  });
  xla::ParseFlagsFromEnv(*flag_list);
}

}  // namespace

const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  return *flags;
}
}  // namespace legacy_flags
}  // namespace tensorflow
@@ -1,37 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_
#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_

namespace tensorflow {
namespace legacy_flags {

// Flags for the build_xla_ops pass.
struct BuildXlaOpsPassFlags {
  // Enables lazy compilation for TF/XLA (only when auto-clustering) if true.
  // Defaults to true.
  bool tf_xla_enable_lazy_compilation;
};

// Parses the flags in BuildXlaOpsPassFlags from the TF_XLA_FLAGS environment
// variable and returns a reference to the parsed copy. Parses TF_XLA_FLAGS
// only the first time this routine is called.
const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags();

}  // namespace legacy_flags
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_
@@ -1,98 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Legacy flags for the XLA bridge's mark_for_compilation_pass module.

#include <mutex>
#include <vector>

#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {
namespace legacy_flags {

// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static MarkForCompilationPassFlags* flags;
static std::vector<Flag>* flag_list;
static std::once_flag flags_init;

// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
  flags = new MarkForCompilationPassFlags;
  flags->tf_xla_auto_jit = 0;
  flags->tf_xla_min_cluster_size = 2;
  flags->tf_xla_max_cluster_size = std::numeric_limits<int32>::max();
  flags->tf_xla_clustering_debug = false;
  flags->tf_xla_cpu_global_jit = false;
  flags->tf_xla_clustering_fuel = std::numeric_limits<int64>::max();
  flags->tf_xla_fusion_only = false;
  flag_list = new std::vector<Flag>(
      {Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit,
            "Control compilation of operators into XLA computations on CPU and "
            "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
            "things very likely to be improved; 2 = on for everything. "
            "Experimental."),
       Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size,
            "Minimum number of operators in an XLA compilation. Ignored for "
            "operators placed on an XLA device or operators explicitly marked "
            "for compilation."),
       Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size,
            "Maximum number of operators in an XLA compilation."),
       Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug,
            "Dump graphs during XLA compilation."),
       Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit,
            "Enables global JIT compilation for CPU via SessionOptions."),
       Flag("tf_xla_clustering_fuel", &flags->tf_xla_clustering_fuel,
            "Places an artificial limit on the number of ops marked as "
            "eligible for clustering."),
       Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only,
            "enable fusion of element-wise operations only using XLA when "
            "global_jit_level is ON*.")});
  xla::ParseFlagsFromEnv(*flag_list);

  if (VLOG_IS_ON(1)) {
    VLOG(1) << "Parsed MarkForCompilationPassFlags:";
    VLOG(1) << "  tf_xla_auto_jit = " << flags->tf_xla_auto_jit;
    VLOG(1) << "  tf_xla_min_cluster_size = " << flags->tf_xla_min_cluster_size;
    VLOG(1) << "  tf_xla_max_cluster_size = " << flags->tf_xla_max_cluster_size;
    VLOG(1) << "  tf_xla_clustering_debug = " << flags->tf_xla_clustering_debug;
    VLOG(1) << "  tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit;
    VLOG(1) << "  tf_xla_clustering_fuel = " << flags->tf_xla_clustering_fuel;
    VLOG(1) << "  tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
  }
}

// Append to *append_to flag definitions associated with the XLA bridge's
// mark_for_compilation_pass module.
void AppendMarkForCompilationPassFlags(std::vector<Flag>* append_to) {
  std::call_once(flags_init, &AllocateFlags);
  append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}

// Return a pointer to the MarkForCompilationPassFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
  std::call_once(flags_init, &AllocateFlags);
  return flags;
}

}  // namespace legacy_flags
}  // namespace tensorflow
@ -1,56 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Legacy flags for the XLA bridge's xla_device module.

#include <mutex>
#include <vector>

#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {
namespace legacy_flags {

// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static XlaDeviceFlags* flags;
static std::vector<Flag>* flag_list;
static std::once_flag flags_init;

// Allocate *flags.  Called via call_once(&flags_init,...).
static void AllocateFlags() {
  flags = new XlaDeviceFlags;
  flags->tf_xla_compile_on_demand = false;
  flag_list = new std::vector<Flag>({
      Flag("tf_xla_compile_on_demand", &flags->tf_xla_compile_on_demand,
           "Switch a device into 'on-demand' mode, where instead of "
           "autoclustering ops are compiled one by one just-in-time."),
  });
  xla::ParseFlagsFromEnv(*flag_list);
}

// Return a pointer to the XlaDeviceFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
XlaDeviceFlags* GetXlaDeviceFlags() {
  std::call_once(flags_init, &AllocateFlags);
  return flags;
}

}  // namespace legacy_flags
}  // namespace tensorflow
@ -1,47 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_
#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_

// Legacy flags for the XLA bridge's xla_device module.

#include <vector>

#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {
namespace legacy_flags {

// The values of flags associated with the XLA bridge's
// xla_device module.
typedef struct {
  // Switch the CPU device into "on-demand" mode, where instead of
  // autoclustering ops are compiled one by one just-in-time.
  // Enabling this mode by a legacy flag is a temporary mechanism. When this
  // feature is battle-tested, we will switch this to be a session option.
  bool tf_xla_compile_on_demand;
} XlaDeviceFlags;

// Return a pointer to the XlaDeviceFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
XlaDeviceFlags* GetXlaDeviceFlags();

}  // namespace legacy_flags
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_
@ -1,52 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <mutex>  // NOLINT
#include <vector>

#include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {
namespace legacy_flags {

XlaOpsCommonFlags* flags;
std::vector<Flag>* flag_list;
std::once_flag flags_init;

void AllocateAndParseFlags() {
  flags = new XlaOpsCommonFlags;
  flags->tf_xla_always_defer_compilation = false;
  flag_list = new std::vector<Flag>({
      Flag("tf_xla_always_defer_compilation",
           &flags->tf_xla_always_defer_compilation, ""),
  });
  xla::ParseFlagsFromEnv(*flag_list);

  if (VLOG_IS_ON(1)) {
    VLOG(1) << "Parsed XlaOpsCommonFlags:";
    VLOG(1) << "  tf_xla_always_defer_compilation = "
            << flags->tf_xla_always_defer_compilation;
  }
}

const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  return *flags;
}
}  // namespace legacy_flags
}  // namespace tensorflow
@ -1,36 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_
#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_

namespace tensorflow {
namespace legacy_flags {

// Flags common to the _Xla* ops and their kernels.
struct XlaOpsCommonFlags {
  // If true, _XlaCompile always refuses to compile the cluster, which means the
  // XLA clusters always run in the TF executor.  Defaults to false.
  bool tf_xla_always_defer_compilation;
};

// Parses the flags in XlaOpsCommonFlags from the TF_XLA_FLAGS environment
// variable and returns a reference to the parsed copy. Parses TF_XLA_FLAGS
// only the first time this routine is called.
const XlaOpsCommonFlags& GetXlaOpsCommonFlags();

}  // namespace legacy_flags
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_
@ -24,8 +24,8 @@ limitations under the License.
 #include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/jit/deadness_analysis.h"
 #include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
-#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
 #include "tensorflow/compiler/jit/union_find.h"
 #include "tensorflow/compiler/jit/xla_cluster_util.h"
 #include "tensorflow/compiler/tf2xla/const_analysis.h"
@ -72,6 +72,11 @@ struct OperationFilter {
   // to resort to a dummy implementation. Currently Assert and CheckNumerics ops
   // have dummy XLA implementations.
   bool allow_dummy_ops;
+
+  // Whether ops that produce or consume DT_VARIANT values are allowed. We
+  // don't auto-cluster these ops because we don't yet support live-in or
+  // live-out DT_VARIANT values.
+  bool allow_ops_producing_or_consuming_variant;
 };
 
 bool IsDummyImplOp(absl::string_view op_name) {
@ -81,7 +86,13 @@ bool IsDummyImplOp(absl::string_view op_name) {
 bool IsStatefulRandomOp(absl::string_view op_name) {
   return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
          op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
-         op_name == "TruncatedNormal";
+         op_name == "TruncatedNormal" || op_name == "Multinomial";
+}
+
+bool OpProducesOrConsumesVariant(const Node& node) {
+  auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
+  return absl::c_any_of(node.input_types(), is_variant) ||
+         absl::c_any_of(node.output_types(), is_variant);
 }
 
 bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
@ -246,6 +257,10 @@ bool IsCompilableCall(const NodeDef& call_def,
     if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) {
       return false;
     }
+    if (!op_filter.allow_ops_producing_or_consuming_variant &&
+        OpProducesOrConsumesVariant(*node)) {
+      return false;
+    }
     if (!HasXLAKernel(*node, jit_device_type) &&
         !IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1,
                           lib_runtime)) {
@ -427,8 +442,7 @@ Status FindCompilationCandidates(
       BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr,
                              &compile_time_const_nodes));
 
-  int64& fuel =
-      legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel;
+  int64& fuel = GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel;
 
   // Iterate over nodes in sorted order so that compiler fuel is deterministic.
   // We can't simply pass op_nodes().begin() and op_nodes().end to the
@ -471,16 +485,15 @@ Status FindCompilationCandidates(
         XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration));
     DeviceType jit_device_type(registration->compilation_device_name);
 
+    bool always_auto_cluster = registration->autoclustering_policy ==
+                               XlaOpRegistry::AutoclusteringPolicy::kAlways;
+
     OperationFilter op_filter;
     op_filter.allow_resource_ops = registration->compile_resource_ops;
-    op_filter.allow_stateful_rng_ops =
-        (registration->autoclustering_policy ==
-         XlaOpRegistry::AutoclusteringPolicy::kAlways);
-    op_filter.allow_control_trigger =
-        (registration->autoclustering_policy ==
-         XlaOpRegistry::AutoclusteringPolicy::kAlways);
-    op_filter.allow_dummy_ops = (registration->autoclustering_policy ==
-                                 XlaOpRegistry::AutoclusteringPolicy::kAlways);
+    op_filter.allow_stateful_rng_ops = always_auto_cluster;
+    op_filter.allow_control_trigger = always_auto_cluster;
+    op_filter.allow_dummy_ops = always_auto_cluster;
+    op_filter.allow_ops_producing_or_consuming_variant = always_auto_cluster;
 
     if (!HasXLAKernel(*node, jit_device_type) &&
         !IsCompilableCall(node->def(), jit_device_type, op_filter, 0,
@ -504,6 +517,12 @@ Status FindCompilationCandidates(
               << node->type_string() << ")";
       continue;
     }
+    if (!op_filter.allow_ops_producing_or_consuming_variant &&
+        OpProducesOrConsumesVariant(*node)) {
+      VLOG(2) << "Rejecting " << node->name()
+              << ": produces or consumes DT_VARIANT";
+      continue;
+    }
 
     if (!op_filter.allow_resource_ops &&
         (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) {
@ -607,8 +626,7 @@ OptimizerOptions::GlobalJitLevel GetGlobalJitLevel(
     // To set compilation to be on by default, change the following line.
     global_jit_level = OptimizerOptions::OFF;
   }
-  legacy_flags::MarkForCompilationPassFlags* flags =
-      legacy_flags::GetMarkForCompilationPassFlags();
+  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
   if (flags->tf_xla_auto_jit == -1 ||
       (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) {
     // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides
@ -641,6 +659,7 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
   op_filter.allow_stateful_rng_ops = true;
   op_filter.allow_control_trigger = true;
   op_filter.allow_dummy_ops = true;
+  op_filter.allow_ops_producing_or_consuming_variant = true;
 
   return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr);
 }
@ -651,8 +670,7 @@ Status MarkForCompilationPass::Run(
   // device ahead of time.
   OptimizerOptions::GlobalJitLevel global_jit_level =
       GetGlobalJitLevel(options);
-  legacy_flags::MarkForCompilationPassFlags* flags =
-      legacy_flags::GetMarkForCompilationPassFlags();
+  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
   bool fusion_only = flags->tf_xla_fusion_only;
 
   VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
@ -953,8 +971,7 @@ Status MarkForCompilationPass::RunImpl(
 
   OptimizerOptions::GlobalJitLevel global_jit_level =
       GetGlobalJitLevel(options);
-  legacy_flags::MarkForCompilationPassFlags* flags =
-      legacy_flags::GetMarkForCompilationPassFlags();
+  MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
 
   // Repeatedly contract edges between clusters that are on the same device,
   // provided the contraction would not create a cycle.
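The FindCompilationCandidates hunks above take a reference to the tf_xla_clustering_fuel flag (default: int64 max) and visit nodes in sorted order so that fuel consumption stays deterministic. The decrement logic itself sits outside these hunks; the sketch below only illustrates the general budget pattern with hypothetical names, it is not code from this commit.

```
#include <cstdint>
#include <vector>

// Illustrative only: a deterministic "fuel" budget over a sorted node list.
// Each candidate consumes one unit; once the budget hits zero no further
// nodes are marked, which is how an artificial clustering limit behaves.
int MarkWithFuel(const std::vector<int>& sorted_node_ids, int64_t fuel) {
  int marked = 0;
  for (int node_id : sorted_node_ids) {
    (void)node_id;          // a real pass would inspect the node here
    if (fuel <= 0) break;   // budget exhausted
    --fuel;
    ++marked;
  }
  return marked;
}
```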
@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/cc/ops/array_ops.h"
 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
 #include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/list_ops.h"
 #include "tensorflow/cc/ops/resource_variable_ops.h"
 #include "tensorflow/cc/ops/sendrecv_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
@ -1147,5 +1148,80 @@ TEST(XlaCompilationTest, DontAutoClusterDummyOps) {
   EXPECT_EQ(clusters["test/check"], "");
 }
 
+TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
+  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);
+
+  Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
+  Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);
+
+  Output tensor_list_reserve = ops::TensorListReserve(
+      root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);
+
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+  std::unordered_map<string, string> clusters = GetClusters(*graph);
+  EXPECT_EQ(clusters["test/tensor_list_reserve"], "");
+}
+
+TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+  Output dummy_input =
+      ops::Placeholder(root.WithOpName("test/dummy_input"), DT_INT64);
+  Output variant_input =
+      ops::Placeholder(root.WithOpName("test/variant_input"), DT_VARIANT);
+
+  // Create one more node so that we don't avoid creating a cluster solely
+  // because it would be trivial.
+  Output dummy_cast =
+      ops::Cast(root.WithOpName("test/dummy_cast"), dummy_input, DT_INT32);
+
+  Output tensor_list_element_shape = ops::TensorListElementShape(
+      root.WithOpName("test/tensor_list_element_shape"), variant_input,
+      DT_INT32);
+
+  root.graph()->AddControlEdge(dummy_cast.node(),
+                               tensor_list_element_shape.node());
+
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+  std::unordered_map<string, string> clusters = GetClusters(*graph);
+  EXPECT_EQ(clusters["test/tensor_list_element_shape"], "");
+}
+
+TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
+  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);
+
+  Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
+  Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);
+
+  Output tensor_list_reserve = ops::TensorListReserve(
+      root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);
+
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+  string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
+  for (Node* n : graph->nodes()) {
+    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
+      n->set_assigned_device_name(xla_cpu_device);
+    }
+  }
+  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+  std::unordered_map<string, string> clusters = GetClusters(*graph);
+  EXPECT_NE(clusters["test/tensor_list_reserve"], "");
+}
+
 }  // namespace
 }  // namespace tensorflow
@ -34,15 +34,9 @@ namespace tensorflow {
   //
   // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to
   // make this more direct, but probably not worth it solely for this test.
-  std::vector<Device*> devices;
+  std::vector<std::unique_ptr<Device>> devices;
   TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices));
 
-  auto delete_devices = gtl::MakeCleanup([&] {
-    for (Device* d : devices) {
-      delete d;
-    }
-  });
-
   GraphOptimizationPassOptions opt_options;
   opt_options.graph = graph;
   opt_options.session_options = session_options;
@ -18,3 +18,9 @@ tf_gen_op_wrapper_py(
     out = "xla_ops.py",
     deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
 )
+
+py_library(
+    name = "xla_ops_grad",
+    srcs = ["xla_ops_grad.py"],
+    deps = ["//tensorflow/python:framework_ops"],
+)
@ -1,3 +1,4 @@
+"""Gradients for XLA ops."""
 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@ -12,21 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""dnn python module.
-
-Importing from tensorflow.python.estimator is unsupported
-and will soon break!
-"""
-# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow_estimator.contrib.estimator.python.estimator import dnn
-
-# Include attrs that start with single underscore.
-_HAS_DYNAMIC_ATTRIBUTES = True
-dnn.__all__ = [s for s in dir(dnn) if not s.startswith('__')]
-
-from tensorflow_estimator.contrib.estimator.python.estimator.dnn import *
+from tensorflow.python.framework import ops
+
+
+@ops.RegisterGradient("XlaClusterOutput")
+def _XlaClusterOutputGrad(_, grad):
+  del grad  # unused
+  raise RuntimeError("Gradient computation of graph in xla.compile() is "
+                     "prohibited because it can cause performance degradation."
+                     "Please move gradient computation inside xla.compile().")
@ -26,6 +26,10 @@ limitations under the License.
 
 namespace tensorflow {
 namespace {
 
+bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
+
+namespace reduce_device_to_host_copies {
+
 Status FindNodesToDecluster(const Graph& graph,
                             absl::flat_hash_set<Node*>* result,
                             absl::Span<Node* const> post_order) {
@ -140,8 +144,6 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
   return Status::OK();
 }
 
-bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
-
 // Clones nodes to outside their cluster to avoid device-to-host copies. For
 // instance, converts this:
 //
@ -168,7 +170,7 @@ bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
 // where the ===> arrow has a hostmem source and destination and would entail a
 // device to host copy if the source and destination were not in the same XLA
 // cluster.
-Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
+Status PartiallyDeclusterGraph(Graph* graph) {
   // When deciding whether to decluster a particular node, we base our decision
   // on if we've decided that some of its consumers have to be declustered too.
   // Iterating the graph in post-order guarantees that consumers have been
@ -206,7 +208,9 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
 
   return Status::OK();
 }
+}  // namespace reduce_device_to_host_copies
 
+namespace reduce_recompilation {
 bool IsIntraClusterEdge(const Edge& edge) {
   absl::optional<absl::string_view> src_cluster_name =
       GetXlaClusterForNode(*edge.src());
@ -269,7 +273,7 @@ Status MustCompileNode(const Node* n, bool* must_compile) {
 // regress performance in any significant manner. We will have to revisit this
 // algorith with a more complex cost model if this assumption turns out to be
 // incorrect.
-Status DeclusterNodesToReduceRecompilations(Graph* graph) {
+Status PartiallyDeclusterGraph(Graph* graph) {
   std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
   TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
       *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge));
@ -322,7 +326,7 @@ Status DeclusterNodesToReduceRecompilations(Graph* graph) {
 
   return Status::OK();
 }
+}  // namespace reduce_recompilation
 }  // namespace
 
 Status PartiallyDeclusterPass::Run(
@ -334,8 +338,9 @@ Status PartiallyDeclusterPass::Run(
 
   Graph* graph = options.graph->get();
 
-  TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph));
-  TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph));
+  TF_RETURN_IF_ERROR(
+      reduce_device_to_host_copies::PartiallyDeclusterGraph(graph));
+  TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph(graph));
 
   return Status::OK();
 }
@ -386,7 +386,7 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
   TF_ASSERT_OK(s.ToGraph(graph.get()));
 
   // This is needed to register the XLA_GPU device.
-  std::vector<Device*> devices;
+  std::vector<std::unique_ptr<Device>> devices;
   TF_ASSERT_OK(DeviceFactory::AddDevices(
       SessionOptions(), "/job:localhost/replica:0/task:0", &devices));
 
@ -400,10 +400,6 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
   TF_ASSERT_OK(PartiallyDecluster(&graph));
 
   EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
-
-  for (Device* d : devices) {
-    delete d;
-  }
 }
 
 TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) {
@ -17,8 +17,8 @@ limitations under the License.
 // operators using XLA via the XLA "Host" (CPU) backend.
 
 #include "absl/memory/memory.h"
+#include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/kernels/xla_ops.h"
-#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
 #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
 #include "tensorflow/compiler/jit/xla_device.h"
 #include "tensorflow/compiler/jit/xla_device_ops.h"
@ -31,13 +31,13 @@ namespace tensorflow {
 class XlaCpuDeviceFactory : public DeviceFactory {
  public:
   Status CreateDevices(const SessionOptions& options, const string& name_prefix,
-                       std::vector<Device*>* devices) override;
+                       std::vector<std::unique_ptr<Device>>* devices) override;
 };
 
-Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
-                                          const string& name_prefix,
-                                          std::vector<Device*>* devices) {
-  legacy_flags::XlaDeviceFlags* flags = legacy_flags::GetXlaDeviceFlags();
+Status XlaCpuDeviceFactory::CreateDevices(
+    const SessionOptions& session_options, const string& name_prefix,
+    std::vector<std::unique_ptr<Device>>* devices) {
+  XlaDeviceFlags* flags = GetXlaDeviceFlags();
   bool compile_on_demand = flags->tf_xla_compile_on_demand;
 
   XlaOpRegistry::DeviceRegistration registration;
@ -63,8 +63,7 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
   options.device_ordinal = 0;
   options.compilation_device_name = DEVICE_CPU_XLA_JIT;
   options.use_multiple_streams = false;
-  auto device = absl::make_unique<XlaDevice>(session_options, options);
-  devices->push_back(device.release());
+  devices->push_back(absl::make_unique<XlaDevice>(session_options, options));
 
   return Status::OK();
 }
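Several factories in this commit switch CreateDevices from filling std::vector<Device*> to std::vector<std::unique_ptr<Device>>, which is why the manual delete loops and MakeCleanup blocks disappear from the callers and tests. A minimal sketch of that ownership hand-off in plain C++ follows; the Device and CpuDevice types here are stand-ins, not TensorFlow classes.

```
#include <memory>
#include <vector>

struct Device { virtual ~Device() = default; };
struct CpuDevice : Device {};

// The factory fills the vector with owning pointers; when the caller's vector
// goes out of scope the devices are destroyed, so no explicit delete loop is
// needed on the caller side.
void CreateDevices(std::vector<std::unique_ptr<Device>>* devices) {
  devices->push_back(std::make_unique<CpuDevice>());
}

int main() {
  std::vector<std::unique_ptr<Device>> devices;
  CreateDevices(&devices);
  return 0;  // devices cleaned up automatically here
}
```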
@ -218,6 +218,9 @@ XlaDevice::XlaDevice(const SessionOptions& session_options,
 XlaDevice::~XlaDevice() {
   VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this;
   mutex_lock lock(mu_);
+  while (outstanding_asynchronous_operations_ > 0) {
+    outstanding_asynchronous_operations_cv_.wait(lock);
+  }
   if (device_context_) {
     device_context_->Unref();
   }
@ -384,6 +387,7 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
 
 Status XlaDevice::Sync() {
   VLOG(1) << "XlaDevice::Sync";
+  tracing::ScopedActivity activity("XlaDevice::Sync", /*is_expensive=*/true);
   std::shared_ptr<se::Stream> stream;
   {
     mutex_lock lock(mu_);
@ -391,13 +395,46 @@ Status XlaDevice::Sync() {
   }
   if (!stream) return Status::OK();
 
-  if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) {
+  Status status = stream->BlockHostUntilDone();
+  {
+    mutex_lock lock(mu_);
+    while (outstanding_asynchronous_operations_ > 0) {
+      outstanding_asynchronous_operations_cv_.wait(lock);
+    }
+  }
+  TF_RETURN_IF_ERROR(status);
+  if (!stream->ok()) {
     return errors::Internal("XlaDevice::Sync() failed.");
   }
   VLOG(1) << "XlaDevice::Sync completed";
   return Status::OK();
 }
 
+void XlaDevice::Sync(const DoneCallback& done) {
+  VLOG(1) << "XlaDevice::Sync (asynchronous)";
+  std::shared_ptr<se::Stream> stream;
+  {
+    mutex_lock lock(mu_);
+    stream = stream_;
+  }
+  if (!stream) {
+    done(Status::OK());
+    return;
+  }
+
+  stream->ThenEnqueueOnBackgroundThread(
+      [this, stream, done](se::StreamExecutor*) {
+        tracing::ScopedActivity activity("XlaDevice::Sync::Callback",
+                                         /*is_expensive=*/true);
+        mutex_lock lock(mu_);
+        while (outstanding_asynchronous_operations_ > 0) {
+          outstanding_asynchronous_operations_cv_.wait(lock);
+        }
+        done(stream->ok() ? Status::OK()
+                          : errors::Internal("XlaDevice::Sync() failed."));
+      });
+}
+
 Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                       const AllocatorAttributes alloc_attrs,
                                       Tensor* tensor) {
@ -441,6 +478,49 @@ bool XlaDevice::RequiresSyncOnCompletion() const {
   return sync_on_completion_;
 }
 
+XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle(
+    XlaDevice* device)
+    : device_(device) {
+  mutex_lock lock(device_->mu_);
+  ++device_->outstanding_asynchronous_operations_;
+}
+
+XlaDevice::AsynchronousOperationHandle::~AsynchronousOperationHandle() {
+  if (device_) {
+    mutex_lock lock(device_->mu_);
+    --device_->outstanding_asynchronous_operations_;
+    device_->outstanding_asynchronous_operations_cv_.notify_all();
+  }
+}
+
+XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle(
+    const XlaDevice::AsynchronousOperationHandle& other)
+    : device_(other.device_) {
+  mutex_lock lock(device_->mu_);
+  ++device_->outstanding_asynchronous_operations_;
+}
+
+XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle(
+    XlaDevice::AsynchronousOperationHandle&& other)
+    : device_(other.device_) {
+  other.device_ = nullptr;
+}
+
+XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle::
+operator=(const XlaDevice::AsynchronousOperationHandle& other) {
+  device_ = other.device_;
+  mutex_lock lock(device_->mu_);
+  ++device_->outstanding_asynchronous_operations_;
+  return *this;
+}
+
+XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle::
+operator=(XlaDevice::AsynchronousOperationHandle&& other) {
+  device_ = other.device_;
+  other.device_ = nullptr;
+  return *this;
+}
+
 XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                    const char* jit_device) {
   // Any op assigned to the device that isn't rewritten by the graph rewriter
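Both Sync paths above block (or defer their callback) until outstanding_asynchronous_operations_ drops to zero, using a mutex plus condition variable. A standalone sketch of that wait/notify pattern in plain C++ is below; the real code uses TensorFlow's mutex_lock and condition_variable wrappers, and the class and method names here are illustrative only.

```
#include <condition_variable>
#include <cstdint>
#include <mutex>

// Sketch of the counting/wait pattern used above, in standard C++ types.
class PendingOps {
 public:
  void Add() {                      // called when an async op is enqueued
    std::lock_guard<std::mutex> l(mu_);
    ++outstanding_;
  }
  void Remove() {                   // called when an async op finishes
    std::lock_guard<std::mutex> l(mu_);
    --outstanding_;
    cv_.notify_all();               // wake any thread blocked in WaitForAll()
  }
  void WaitForAll() {               // what a Sync()-style call does before returning
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return outstanding_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int64_t outstanding_ = 0;
};
```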
@ -135,6 +135,7 @@ class XlaDevice : public LocalDevice {
   void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                     AsyncOpKernel::DoneCallback done) override;
   Status Sync() override;
+  void Sync(const DoneCallback& done) override;
 
   Status FillContextMap(const Graph* graph,
                         DeviceContextMap* device_context_map) override
@ -164,7 +165,30 @@ class XlaDevice : public LocalDevice {
 
   bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
 
+  // A simple RAII handle. On construction the device's
+  // outstanding_asynchronous_operations_ field is incremented; on destruction
+  // it is decremented.
+  class AsynchronousOperationHandle {
+   public:
+    AsynchronousOperationHandle(XlaDevice* device);
+    ~AsynchronousOperationHandle();
+    AsynchronousOperationHandle(const AsynchronousOperationHandle& other);
+    AsynchronousOperationHandle(AsynchronousOperationHandle&& other);
+    AsynchronousOperationHandle& operator=(
+        const AsynchronousOperationHandle& other);
+    AsynchronousOperationHandle& operator=(AsynchronousOperationHandle&& other);
+
+   private:
+    XlaDevice* device_ = nullptr;
+  };
+
+  AsynchronousOperationHandle CreateAsynchronousOperationHandle() {
+    return AsynchronousOperationHandle(this);
+  }
+
  private:
+  friend class AsynchronousOperationHandle;
+
   xla::LocalClient* client() const;
   Allocator* GetAllocatorLocked(AllocatorAttributes attr)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
@ -227,6 +251,11 @@ class XlaDevice : public LocalDevice {
   // True if the device requires XlaDevice::Sync to be called on completion
   // regardless of status.
   bool sync_on_completion_ GUARDED_BY(mu_) = false;
+
+  // Count of outstanding asynchronous operations which must be zero on Sync()
+  // completion.
+  int64 outstanding_asynchronous_operations_ GUARDED_BY(mu_) = 0;
+  condition_variable outstanding_asynchronous_operations_cv_;
 };
 
 // Builds OpKernel registrations on 'device' for the JIT operators
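The header comment above describes the handle as a simple RAII wrapper: construction and copy bump the device's outstanding-operation count, destruction decrements it, and a moved-from handle decrements nothing. A compact sketch of that idiom with hypothetical names, not the TensorFlow class itself:

```
#include <cstdint>
#include <mutex>

// Illustrative RAII counter handle, standard C++ types only.
struct OpCounter {
  std::mutex mu;
  int64_t outstanding = 0;
};

class ScopedOp {
 public:
  explicit ScopedOp(OpCounter* c) : c_(c) { Add(); }
  ScopedOp(const ScopedOp& other) : c_(other.c_) { Add(); }  // copies also count
  ScopedOp(ScopedOp&& other) : c_(other.c_) { other.c_ = nullptr; }
  ~ScopedOp() {
    if (c_ == nullptr) return;  // moved-from handle owns nothing
    std::lock_guard<std::mutex> l(c_->mu);
    --c_->outstanding;
  }

 private:
  void Add() {
    std::lock_guard<std::mutex> l(c_->mu);
    ++c_->outstanding;
  }
  OpCounter* c_;
};
```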
@ -29,12 +29,12 @@ namespace tensorflow {
 class XlaGpuDeviceFactory : public DeviceFactory {
  public:
   Status CreateDevices(const SessionOptions& options, const string& name_prefix,
-                       std::vector<Device*>* devices) override;
+                       std::vector<std::unique_ptr<Device>>* devices) override;
 };
 
-Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
-                                          const string& name_prefix,
-                                          std::vector<Device*>* devices) {
+Status XlaGpuDeviceFactory::CreateDevices(
+    const SessionOptions& session_options, const string& name_prefix,
+    std::vector<std::unique_ptr<Device>>* devices) {
   XlaOpRegistry::DeviceRegistration registration;
   registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
   registration.autoclustering_policy =
@ -70,7 +70,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
       return status;
     }
 
-    devices->push_back(device.release());
+    devices->push_back(std::move(device));
   }
   return Status::OK();
 }
@ -33,12 +33,12 @@ constexpr std::array<DataType, 9> kExecAllTypes = {
 class XlaInterpreterDeviceFactory : public DeviceFactory {
  public:
   Status CreateDevices(const SessionOptions& options, const string& name_prefix,
-                       std::vector<Device*>* devices) override;
+                       std::vector<std::unique_ptr<Device>>* devices) override;
 };
 
 Status XlaInterpreterDeviceFactory::CreateDevices(
     const SessionOptions& session_options, const string& name_prefix,
-    std::vector<Device*>* devices) {
+    std::vector<std::unique_ptr<Device>>* devices) {
   static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(
       DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT);
   (void)registrations;
@ -61,8 +61,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
   options.device_ordinal = 0;
   options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
   options.use_multiple_streams = false;
-  auto device = absl::make_unique<XlaDevice>(session_options, options);
-  devices->push_back(device.release());
+  devices->push_back(absl::make_unique<XlaDevice>(session_options, options));
 
   return Status::OK();
 }
@ -375,27 +375,6 @@ tf_xla_py_test(
     ],
 )
 
-tf_xla_py_test(
-    name = "resampler_ops_test",
-    size = "small",
-    srcs = ["resampler_ops_test.py"],
-    disabled_backends = [
-        # TODO(b/74459949) Support BatchDot in CPU backend.
-        "cpu",
-        "cpu_ondemand",
-    ],
-    # TODO(b/112295522): figure out how to make OSS build pass.
-    tags = ["no_oss"],
-    deps = [
-        ":xla_test",
-        "//tensorflow/contrib/resampler:resampler_ops",
-        "//tensorflow/contrib/resampler:resampler_py",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:platform_test",
-    ],
-)
-
 tf_xla_py_test(
     name = "dynamic_stitch_test",
     size = "small",
@ -474,7 +453,6 @@ tf_xla_py_test(
         "//tensorflow/python:extra_py_tests_deps",
         "//tensorflow/python:framework",
         "//tensorflow/python:platform_test",
-        "//tensorflow/python:spectral_ops",
         "//tensorflow/python/ops/signal",
     ],
 )
@ -50,8 +50,8 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
             zip([grads0, grads1], [var0, var1]), global_step=global_step)
         variables.global_variables_initializer().run()
 
-        self.assertAllClose([0.0, 0.0], var0.eval())
-        self.assertAllClose([0.0, 0.0], var1.eval())
+        self.assertAllClose([0.0, 0.0], self.evaluate(var0))
+        self.assertAllClose([0.0, 0.0], self.evaluate(var1))
 
         # Run a step of AdagradDA
         update.run()
@ -63,9 +63,9 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
         # For -0.1*3.0*(0.1 - 0)/(0 + sqrt(0.1 + 0.1*0.1)) = -0.904534
         # similarly for others.
         self.assertAllCloseAccordingToType(
-            np.array([-0.904534, -1.603567]), var0.eval())
+            np.array([-0.904534, -1.603567]), self.evaluate(var0))
         self.assertAllCloseAccordingToType(
-            np.array([-0.094821, -0.189358]), var1.eval())
+            np.array([-0.094821, -0.189358]), self.evaluate(var1))
 
   def testAdagradDAwithoutRegularizationBasic2(self):
     for dtype in self.float_types:
@ -87,16 +87,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
             zip([grads0, grads1], [var0, var1]), global_step=global_step)
         variables.global_variables_initializer().run()
 
-        self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
-        self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+        self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0))
+        self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1))
 
         # Run a step of AdagradDA
         update.run()
 
         self.assertAllCloseAccordingToType(
-            np.array([-0.904534, -1.603567]), var0.eval())
+            np.array([-0.904534, -1.603567]), self.evaluate(var0))
         self.assertAllCloseAccordingToType(
-            np.array([-0.094821, -0.189358]), var1.eval())
+            np.array([-0.094821, -0.189358]), self.evaluate(var1))
 
   def testAdagradDAWithL1(self):
     for dtype in self.float_types:
@ -118,16 +118,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
            zip([grads0, grads1], [var0, var1]), global_step=global_step)
         variables.global_variables_initializer().run()
 
-        self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
-        self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+        self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0))
+        self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1))
 
         # Run a step of AdagradDA
         update.run()
 
         self.assertAllCloseAccordingToType(
-            np.array([-0.895489, -1.59555]), var0.eval())
+            np.array([-0.895489, -1.59555]), self.evaluate(var0))
         self.assertAllCloseAccordingToType(
-            np.array([-0.085339, -0.17989]), var1.eval())
+            np.array([-0.085339, -0.17989]), self.evaluate(var1))
 
   def testAdagradDAWithL1_L2(self):
     for dtype in self.float_types:
@ -149,16 +149,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
            zip([grads0, grads1], [var0, var1]), global_step=global_step)
         variables.global_variables_initializer().run()
 
-        self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
-        self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+        self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0))
+        self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1))
 
         # Run a step of AdagradDA
         update.run()
 
         self.assertAllCloseAccordingToType(
-            np.array([-0.046907, -0.093659]), var0.eval())
+            np.array([-0.046907, -0.093659]), self.evaluate(var0))
         self.assertAllCloseAccordingToType(
-            np.array([-0.004275, -0.009023]), var1.eval())
+            np.array([-0.004275, -0.009023]), self.evaluate(var1))
 
 
 if __name__ == "__main__":
@@ -42,17 +42,19 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
-self.assertAllClose([1.0, 2.0], var0.eval())
+self.assertAllClose([1.0, 2.0], self.evaluate(var0))
-self.assertAllClose([3.0, 4.0], var1.eval())
+self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 3 steps of adagrad
for _ in range(3):
ada_update.run()
# Validate updated params
self.assertAllCloseAccordingToType(
-np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(),
+np.array([-1.6026098728179932, -0.6026098728179932]),
+self.evaluate(var0),
float_rtol=1e-5)
self.assertAllCloseAccordingToType(
-np.array([2.715679168701172, 3.715679168701172]), var1.eval(),
+np.array([2.715679168701172, 3.715679168701172]),
+self.evaluate(var1),
float_rtol=1e-5)

def testTensorLearningRate(self):
@@ -68,17 +70,19 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
-self.assertAllClose([1.0, 2.0], var0.eval())
+self.assertAllClose([1.0, 2.0], self.evaluate(var0))
-self.assertAllClose([3.0, 4.0], var1.eval())
+self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 3 steps of adagrad
for _ in range(3):
ada_update.run()
# Validate updated params
self.assertAllCloseAccordingToType(
-np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(),
+np.array([-1.6026098728179932, -0.6026098728179932]),
+self.evaluate(var0),
float_rtol=1e-5)
self.assertAllCloseAccordingToType(
-np.array([2.715679168701172, 3.715679168701172]), var1.eval(),
+np.array([2.715679168701172, 3.715679168701172]),
+self.evaluate(var1),
float_rtol=1e-5)

def testSharing(self):
@@ -103,18 +107,20 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
variables.global_variables_initializer().run()

# Fetch params to validate initial values.
-self.assertAllClose([1.0, 2.0], var0.eval())
+self.assertAllClose([1.0, 2.0], self.evaluate(var0))
-self.assertAllClose([3.0, 4.0], var1.eval())
+self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Mix the first and the second adagrad for 3 steps.
ada_update1.run()
ada_update2.run()
ada_update1.run()
# Validate updated params (the same as with only 1 Adagrad).
self.assertAllCloseAccordingToType(
-np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(),
+np.array([-1.6026098728179932, -0.6026098728179932]),
+self.evaluate(var0),
float_rtol=1e-5)
self.assertAllCloseAccordingToType(
-np.array([2.715679168701172, 3.715679168701172]), var1.eval(),
+np.array([2.715679168701172, 3.715679168701172]),
+self.evaluate(var1),
float_rtol=1e-5)


@@ -75,23 +75,24 @@ class AdamOptimizerTest(xla_test.XLATestCase):
variables.global_variables_initializer().run()

# Fetch params to validate initial values
-self.assertAllClose([1.0, 2.0], var0.eval())
+self.assertAllClose([1.0, 2.0], self.evaluate(var0))
-self.assertAllClose([3.0, 4.0], var1.eval())
+self.assertAllClose([3.0, 4.0], self.evaluate(var1))

beta1_power, beta2_power = opt._get_beta_accumulators()

# Run 3 steps of Adam
for t in range(1, 4):
-self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
-self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+self.assertAllCloseAccordingToType(0.999**t,
+self.evaluate(beta2_power))
update.run(feed_dict={grads0: grads0_np, grads1: grads1_np})

var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)

# Validate updated params
-self.assertAllCloseAccordingToType(var0_np, var0.eval())
+self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
-self.assertAllCloseAccordingToType(var1_np, var1.eval())
+self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))

def testTensorLearningRate(self):
for dtype in self.float_types:
@@ -117,23 +118,24 @@ class AdamOptimizerTest(xla_test.XLATestCase):
variables.global_variables_initializer().run()

# Fetch params to validate initial values
-self.assertAllClose([1.0, 2.0], var0.eval())
+self.assertAllClose([1.0, 2.0], self.evaluate(var0))
-self.assertAllClose([3.0, 4.0], var1.eval())
+self.assertAllClose([3.0, 4.0], self.evaluate(var1))

beta1_power, beta2_power = opt._get_beta_accumulators()

# Run 3 steps of Adam
for t in range(1, 4):
-self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
-self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+self.assertAllCloseAccordingToType(0.999**t,
+self.evaluate(beta2_power))
update.run(feed_dict={grads0: grads0_np, grads1: grads1_np})

var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)

# Validate updated params
-self.assertAllCloseAccordingToType(var0_np, var0.eval())
+self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
-self.assertAllCloseAccordingToType(var1_np, var1.eval())
+self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))

def testSharing(self):
for dtype in self.float_types:
@@ -162,13 +164,14 @@ class AdamOptimizerTest(xla_test.XLATestCase):
beta1_power, beta2_power = opt._get_beta_accumulators()

# Fetch params to validate initial values
-self.assertAllClose([1.0, 2.0], var0.eval())
+self.assertAllClose([1.0, 2.0], self.evaluate(var0))
-self.assertAllClose([3.0, 4.0], var1.eval())
+self.assertAllClose([3.0, 4.0], self.evaluate(var1))

# Run 3 steps of intertwined Adam1 and Adam2.
for t in range(1, 4):
-self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
-self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+self.assertAllCloseAccordingToType(0.999**t,
+self.evaluate(beta2_power))
if t % 2 == 0:
update1.run(feed_dict={grads0: grads0_np, grads1: grads1_np})
else:
@@ -178,8 +181,8 @@ class AdamOptimizerTest(xla_test.XLATestCase):
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)

# Validate updated params
-self.assertAllCloseAccordingToType(var0_np, var0.eval())
+self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
-self.assertAllCloseAccordingToType(var1_np, var1.eval())
+self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))


if __name__ == "__main__":
@@ -78,8 +78,8 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase):

variables.global_variables_initializer().run()
# Fetch params to validate initial values
-self.assertAllClose([1.0, 2.0], var0.eval())
+self.assertAllClose([1.0, 2.0], self.evaluate(var0))
-self.assertAllClose([3.0, 4.0], var1.eval())
+self.assertAllClose([3.0, 4.0], self.evaluate(var1))

beta1_power = opt._get_beta_accumulators()

@@ -87,14 +87,17 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase):
for t in range(1, 4):
update.run()

-self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval())
+self.assertAllCloseAccordingToType(0.9**(t + 1),
+self.evaluate(beta1_power))

var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)

# Validate updated params
-self.assertAllCloseAccordingToType(var0_np, var0.eval(), rtol=1e-2)
-self.assertAllCloseAccordingToType(var1_np, var1.eval(), rtol=1e-2)
+self.assertAllCloseAccordingToType(
+var0_np, self.evaluate(var0), rtol=1e-2)
+self.assertAllCloseAccordingToType(
+var1_np, self.evaluate(var1), rtol=1e-2)
self.assertEqual("var0_%d/AdaMax:0" % (i,),
opt.get_slot(var=var0, name="m").name)

@@ -118,22 +121,23 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase):
variables.global_variables_initializer().run()

# Fetch params to validate initial values
-self.assertAllClose([1.0, 2.0], var0.eval())
+self.assertAllClose([1.0, 2.0], self.evaluate(var0))
-self.assertAllClose([3.0, 4.0], var1.eval())
+self.assertAllClose([3.0, 4.0], self.evaluate(var1))

beta1_power = opt._get_beta_accumulators()

# Run 3 steps of AdaMax
for t in range(1, 4):
-self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
update.run()

var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)

# Validate updated params
-self.assertAllCloseAccordingToType(var0_np, var0.eval())
+self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
-self.assertAllCloseAccordingToType(var1_np, var1.eval())
+self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))


if __name__ == "__main__":
test.main()
@@ -90,8 +90,8 @@ class AddSignTest(xla_test.XLATestCase):
variables.global_variables_initializer().run()

# Fetch params to validate initial values
-self.assertAllClose([1.0, 2.0], var0.eval())
+self.assertAllClose([1.0, 2.0], self.evaluate(var0))
-self.assertAllClose([3.0, 4.0], var1.eval())
+self.assertAllClose([3.0, 4.0], self.evaluate(var1))

# Run 7 steps of AddSign
# first 4 steps with positive gradient
@@ -125,8 +125,8 @@ class AddSignTest(xla_test.XLATestCase):

# Validate updated params
self.assertAllCloseAccordingToType(
-var0_np, var0.eval(), half_rtol=1e-2)
+var0_np, self.evaluate(var0), half_rtol=1e-2)
-self.assertAllCloseAccordingToType(var1_np, var1.eval())
+self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))

def testDense(self):
decay_steps = 10
@@ -218,6 +218,21 @@ class BinaryOpsTest(xla_test.XLATestCase):
],
equality_test=self.ListsAreClose)

+# TF doesn't define these for bf16.
+if dtype != dtypes.bfloat16.as_numpy_dtype:
+self._testBinary(
+gen_math_ops.xdivy,
+np.array([0, 4, 3, 2, 1, 0], dtype=dtype),
+np.array([0, 5, 6, 7, 8, float("NaN")], dtype=dtype),
+expected=np.array([0, 0.8, 0.5, 0.285714, 0.125, 0], dtype=dtype))
+
+self._testBinary(
+gen_math_ops.xlogy,
+np.array([0, 4, 3, 2, 1, 0], dtype=dtype),
+np.array([0, 5, 6, 7, 8, float("NaN")], dtype=dtype),
+expected=np.array([0, 6.437752, 5.375278, 3.89182, 2.079442, 0],
+dtype=dtype))

def testIntOps(self):
for dtype in self.signed_int_types:
self._testBinary(
@@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.platform import googletest


@@ -56,11 +57,11 @@ class CategoricalTest(xla_test.XLATestCase):
Returns:
Frequencies from sampled classes; shape [batch_size, num_classes].
"""
-with self.cached_session() as sess, self.test_scope():
+with self.cached_session(), self.test_scope():
random_seed.set_random_seed(1618)
op = random_ops.multinomial(logits, num_samples,
output_dtype=dtypes.int32)
-d = sess.run(op)
+d = self.evaluate(op)

batch_size, num_classes = logits.shape
freqs_mat = []
@@ -79,15 +80,15 @@ class CategoricalTest(xla_test.XLATestCase):

def _testRngIsNotConstant(self, rng, dtype, output_dtype):
# Tests that 'rng' does not always return the same value.
-with self.cached_session() as sess:
+with self.cached_session():
with self.test_scope():
x = rng(dtype, output_dtype)

# The random-number generator, if working correctly, should produce the
# same output multiple times with low probability.
-y = sess.run(x)
+y = self.evaluate(x)
-z = sess.run(x)
+z = self.evaluate(x)
-w = sess.run(x)
+w = self.evaluate(x)

# We use exact equality here. If the random-number generator is producing
# deterministic output, all three outputs will be bitwise identical.
@@ -107,12 +108,12 @@ class CategoricalTest(xla_test.XLATestCase):
def testCategoricalIsInRange(self):
for dtype in self.float_types:
for output_dtype in self.output_dtypes():
-with self.cached_session() as sess:
+with self.cached_session():
with self.test_scope():
x = random_ops.multinomial(
array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
output_dtype=output_dtype)
-y = sess.run(x)
+y = self.evaluate(x)
self.assertTrue((y >= 0).sum() == 1000)
self.assertTrue((y < 20).sum() == 1000)

@@ -138,6 +139,57 @@ class CategoricalTest(xla_test.XLATestCase):
chi2 = self._chi2(probs, freqs)
self.assertLess(chi2, 1e-3)

+def testStatelessMultinomialIsInRange(self):
+for dtype in self.float_types:
+for output_dtype in self.output_dtypes():
+with self.cached_session() as sess:
+with self.test_scope():
+seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
+x = stateless_random_ops.stateless_multinomial(
+array_ops.ones(shape=[1, 20], dtype=dtype),
+1000,
+seed_t,
+output_dtype=output_dtype)
+y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
+self.assertTrue((y >= 0).sum() == 1000)
+self.assertTrue((y < 20).sum() == 1000)
+
+def testDeterminismMultinomial(self):
+# Stateless values should be equal iff the seeds are equal (roughly)
+num_samples = 10
+with self.cached_session(), self.test_scope():
+seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
+seeds = [(x, y) for x in range(5) for y in range(5)] * 3
+for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
+[0.25, 0.75]]):
+pure = stateless_random_ops.stateless_multinomial(
+logits, num_samples, seed=seed_t)
+values = [(seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds]
+for s0, v0 in values:
+for s1, v1 in values:
+self.assertEqual(s0 == s1, np.all(v0 == v1))
+
+def testEmpty(self):
+with self.cached_session():
+with self.test_scope():
+x = random_ops.multinomial(
+array_ops.zeros([42, 40]), 0, output_dtype=dtypes.int32)
+y = self.evaluate(x)
+self.assertEqual(y.shape, (42, 0))
+
+def testEmptyStateless(self):
+with self.cached_session() as sess:
+with self.test_scope():
+seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
+x = stateless_random_ops.stateless_multinomial(
+array_ops.zeros([42, 40]),
+0,
+seed=seed_t,
+output_dtype=dtypes.int32)
+y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
+self.assertEqual(y.shape, (42, 0))



if __name__ == '__main__':
googletest.main()
@@ -43,7 +43,7 @@ class ClusteringTest(xla_test.XLATestCase):
input1 = constant_op.constant(val1, name="const1")
input2 = constant_op.constant(val2, name="const2")
output = math_ops.add(input1, input2)
-result = output.eval()
+result = self.evaluate(output)
self.assertAllClose(result, expected, rtol=1e-3)

def testAddFromCpuMultiple(self):
@@ -57,7 +57,7 @@ class ClusteringTest(xla_test.XLATestCase):
with self.test_scope():
output = math_ops.add(input1, input2)
for _ in xrange(10):
-result = output.eval()
+result = self.evaluate(output)
self.assertAllClose(result, expected, rtol=1e-3)

def testDeadlock(self):
@@ -72,7 +72,7 @@ class ConcatTest(xla_test.XLATestCase):
x2 = constant_op.constant(p2)
with self.test_scope():
c = array_ops.concat([x1, x2], 0)
-result = c.eval()
+result = self.evaluate(c)
self.assertAllEqual(result[:2, :], p1)
self.assertAllEqual(result[2:, :], p2)

@@ -150,7 +150,7 @@ class ConcatTest(xla_test.XLATestCase):
[float(x) for x in grad_inp.flatten()], shape=output_shape)
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
concated_grad = array_ops.concat(grad, 1)
-result = concated_grad.eval()
+result = self.evaluate(concated_grad)
self.assertAllEqual(result, grad_inp)

def testGradientsSimpleAll(self):
@@ -177,7 +177,7 @@ class ConcatTest(xla_test.XLATestCase):
[float(x) for x in grad_inp.flatten()], shape=output_shape)
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
concated_grad = array_ops.concat(grad, 0)
-result = concated_grad.eval()
+result = self.evaluate(concated_grad)

self.assertAllEqual(result, grad_inp)

@@ -205,7 +205,7 @@ class ConcatTest(xla_test.XLATestCase):
[float(x) for x in grad_inp.flatten()], shape=output_shape)
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
concated_grad = array_ops.concat(grad, 2)
-result = concated_grad.eval()
+result = self.evaluate(concated_grad)

self.assertAllEqual(result, grad_inp)

@@ -242,7 +242,7 @@ class ConcatTest(xla_test.XLATestCase):
[float(x) for x in grad_inp.flatten()], shape=output_shape)
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
concated_grad = array_ops.concat(grad, concat_dim)
-result = concated_grad.eval()
+result = self.evaluate(concated_grad)

self.assertAllEqual(result, grad_inp)

@@ -254,7 +254,7 @@ class ConcatTest(xla_test.XLATestCase):
def DISABLED_testZeroSize(self):
# Verify that concat doesn't crash and burn for zero size inputs
np.random.seed(7)
-with self.cached_session() as sess:
+with self.cached_session():
with self.test_scope():
for shape0 in (), (2,):
axis = len(shape0)
@@ -270,7 +270,7 @@ class ConcatTest(xla_test.XLATestCase):
self.assertAllEqual(c.eval(), correct)
# Check gradients
dc = np.random.randn(*c.get_shape().as_list())
-dxs = sess.run(gradients_impl.gradients(c, xs, dc))
+dxs = self.evaluate(gradients_impl.gradients(c, xs, dc))
self.assertAllEqual(dc, np.concatenate(dxs, axis=axis))

def testConcatTuple(self):
@@ -280,7 +280,7 @@ class ConcatTest(xla_test.XLATestCase):
with self.test_scope():
concat_list_t = array_ops.concat([c1, c2], 0)
concat_tuple_t = array_ops.concat((c1, c2), 0)
-self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval())
+self.assertAllEqual(concat_list_t.eval(), self.evaluate(concat_tuple_t))

def testConcatNoScalars(self):
with self.cached_session():
@@ -330,47 +330,47 @@ class ConcatTest(xla_test.XLATestCase):
class ConcatOffsetTest(xla_test.XLATestCase):

def testBasic(self):
-with self.cached_session() as sess:
+with self.cached_session():
with self.test_scope():
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
off = gen_array_ops.concat_offset(cdim, [s0, s1, s2])
-ans = sess.run(off)
+ans = self.evaluate(off)
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])


class PackTest(xla_test.XLATestCase):

def testBasic(self):
-with self.cached_session() as sess:
+with self.cached_session():
with self.test_scope():
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
s2 = constant_op.constant([2, 20, 5], dtypes.int32)
packed = array_ops.stack([s0, s1, s2])
-ans = sess.run(packed)
+ans = self.evaluate(packed)
self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])

def testScalars(self):
-with self.cached_session() as sess:
+with self.cached_session():
with self.test_scope():
s0 = constant_op.constant(2, dtypes.int32)
s1 = constant_op.constant(3, dtypes.int32)
s2 = constant_op.constant(5, dtypes.int32)
packed = array_ops.stack([s0, s1, s2])
-ans = sess.run(packed)
+ans = self.evaluate(packed)
self.assertAllEqual(ans, [2, 3, 5])

def testEmpty(self):
-with self.cached_session() as sess:
+with self.cached_session():
with self.test_scope():
s0 = constant_op.constant([[]], dtypes.int32)
s1 = constant_op.constant([[]], dtypes.int32)
s2 = constant_op.constant([[]], dtypes.int32)
packed = array_ops.stack([s0, s1, s2])
-ans = sess.run(packed)
+ans = self.evaluate(packed)
self.assertAllEqual(ans, [[[]], [[]], [[]]])


@@ -85,7 +85,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv3d_transpose(
x, f, y_shape, strides=strides, padding="SAME")
-value = output.eval()
+value = self.evaluate(output)

# We count the number of cells being added at the locations in the output.
# At the center, #cells = kernel_depth * kernel_height * kernel_width
@@ -135,7 +135,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv3d_transpose(
x, f, y_shape, strides=strides, padding="SAME")
-value = output.eval()
+value = self.evaluate(output)

for n in xrange(x_shape[0]):
for k in xrange(f_shape[3]):
@@ -173,7 +173,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv3d_transpose(
x, f, y_shape, strides=strides, padding="VALID")
-value = output.eval()
+value = self.evaluate(output)

cache_values = np.zeros(y_shape, dtype=np.float32)

@@ -42,7 +42,7 @@ def GetRunMetadataLabels(run_metadata):

def InLabels(labels, substr):
"""Returns true iff one of the labels contains substr."""
-return any([substr in x for x in labels])
+return any(substr in x for x in labels)


class DenseLayerTest(test.TestCase):
@@ -72,7 +72,7 @@ class DenseLayerTest(test.TestCase):
x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
y = layers.dense(x, 3)

-sess.run(variables.initialize_all_variables())
+self.evaluate(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata()
test_utils.RunWithWarmup(
sess,
@@ -97,7 +97,7 @@ class DenseLayerTest(test.TestCase):
with jit_scope():
y = layers.dense(x, 3)

-sess.run(variables.initialize_all_variables())
+self.evaluate(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata()
test_utils.RunWithWarmup(
sess,
@@ -126,7 +126,7 @@ class DenseLayerTest(test.TestCase):
with jit_scope():
y = layers.dense(x, 3)

-sess.run(variables.initialize_all_variables())
+self.evaluate(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata()
test_utils.RunWithWarmup(
sess,
@@ -58,6 +58,15 @@ class DynamicStitchTest(xla_test.XLATestCase):
[idx1, idx2], [val1, val2],
expected=np.array([[], [], [], []], np.int32))

+def testEmptyIndex(self):
+idx1 = np.array([], dtype=np.int32)
+idx2 = np.array([[], []], dtype=np.int32)
+val1 = np.ndarray(shape=(0, 9), dtype=np.int32)
+val2 = np.ndarray(shape=(2, 0, 9), dtype=np.int32)
+self._AssertDynamicStitchResultIs([idx1, idx2], [val1, val2],
+expected=np.ndarray(
+shape=(0, 9), dtype=np.int32))
+
def testSimple1D(self):
val1 = np.array([0, 4, 7], dtype=np.int32)
val2 = np.array([1, 6, 2, 3, 5], dtype=np.int32)
@@ -101,12 +101,12 @@ class EagerTest(xla_test.XLATestCase):
self.assertAllEqual(15, product)

# Run some ops graphly
-with context.graph_mode(), self.cached_session() as sess:
+with context.graph_mode(), self.cached_session():
with self.test_scope():
three = constant_op.constant(3)
five = constant_op.constant(5)
product = three * five
-self.assertAllEqual(15, sess.run(product))
+self.assertAllEqual(15, self.evaluate(product))

def testDegenerateSlices(self):
with self.test_scope():
@@ -27,8 +27,7 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import signal
+from tensorflow.python.ops.signal import signal
-from tensorflow.python.ops import spectral_ops
from tensorflow.python.platform import googletest

BATCH_DIMS = (3, 5)
@@ -107,39 +106,39 @@ class FFTTest(xla_test.XLATestCase):

def testFFT(self):
self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft,
-spectral_ops.fft)
+signal.fft)

def testFFT2D(self):
self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.fft2,
-spectral_ops.fft2d)
+signal.fft2d)

def testFFT3D(self):
self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x,
lambda x: np.fft.fftn(x, axes=(-3, -2, -1)),
-spectral_ops.fft3d)
+signal.fft3d)

def testIFFT(self):
self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.ifft,
-spectral_ops.ifft)
+signal.ifft)

def testIFFT2D(self):
self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.ifft2,
-spectral_ops.ifft2d)
+signal.ifft2d)

def testIFFT3D(self):
self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x,
lambda x: np.fft.ifftn(x, axes=(-3, -2, -1)),
-spectral_ops.ifft3d)
+signal.ifft3d)

def testRFFT(self):
self._VerifyFftMethod(
INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]),
-lambda x: spectral_ops.rfft(x, fft_length=[x.shape[-1].value]))
+lambda x: signal.rfft(x, fft_length=[x.shape[-1].value]))

def testRFFT2D(self):

def _tf_fn(x):
-return spectral_ops.rfft2d(
+return signal.rfft2d(
x, fft_length=[x.shape[-2].value, x.shape[-1].value])

self._VerifyFftMethod(
@@ -153,16 +152,33 @@ class FFTTest(xla_test.XLATestCase):
x, axes=(-3, -2, -1), s=[x.shape[-3], x.shape[-2], x.shape[-1]])

def _tf_fn(x):
-return spectral_ops.rfft3d(
+return signal.rfft3d(
x,
fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value])

self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)

+def testRFFT3DMismatchedSize(self):
+
+def _to_expected(x):
+return np.fft.rfftn(
+x,
+axes=(-3, -2, -1),
+s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
+
+def _tf_fn(x):
+return signal.rfft3d(
+x,
+fft_length=[
+x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2
+])
+
+self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
+
def testIRFFT(self):

def _tf_fn(x):
-return spectral_ops.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)])
+return signal.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)])

self._VerifyFftMethod(
INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]),
@@ -171,7 +187,7 @@ class FFTTest(xla_test.XLATestCase):
def testIRFFT2D(self):

def _tf_fn(x):
-return spectral_ops.irfft2d(
+return signal.irfft2d(
x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)])

self._VerifyFftMethod(
@@ -195,7 +211,7 @@ class FFTTest(xla_test.XLATestCase):
s=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)])

def _tf_fn(x):
-return spectral_ops.irfft3d(
+return signal.irfft3d(
x,
fft_length=[
x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1)
@@ -203,6 +219,30 @@ class FFTTest(xla_test.XLATestCase):

self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)

+def testIRFFT3DMismatchedSize(self):
+
+def _to_input(x):
+return np.fft.rfftn(
+np.real(x),
+axes=(-3, -2, -1),
+s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
+
+def _to_expected(x):
+return np.fft.irfftn(
+x,
+axes=(-3, -2, -1),
+s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
+
+def _tf_fn(x):
+return signal.irfft3d(
+x,
+fft_length=[
+x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2
+])
+
+self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)



if __name__ == "__main__":
googletest.main()
@@ -129,7 +129,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
enqueue_op.run()

for i in xrange(len(elems)):
-vals = dequeued_t.eval()
+vals = self.evaluate(dequeued_t)
self.assertEqual([elems[i]], vals)

def testEnqueueAndBlockingDequeue(self):
@@ -192,9 +192,9 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertEqual([], size.get_shape())

enqueue_op.run()
-self.assertEqual(1, size.eval())
+self.assertEqual(1, self.evaluate(size))
dequeued_t.op.run()
-self.assertEqual(0, size.eval())
+self.assertEqual(0, self.evaluate(size))


if __name__ == "__main__":
@@ -50,14 +50,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
-self.assertAllClose([0.0, 0.0], var0.eval())
+self.assertAllClose([0.0, 0.0], self.evaluate(var0))
-self.assertAllClose([0.0, 0.0], var1.eval())
+self.assertAllClose([0.0, 0.0], self.evaluate(var1))

# Run Ftrl for a few steps
for _ in range(steps):
ftrl_update.run()

-return var0.eval(), var1.eval()
+return self.evaluate(var0), self.evaluate(var1)

def equivAdagradTest_AdagradPart(self, steps, dtype):
var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype)
@@ -65,14 +65,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
adagrad_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
-self.assertAllClose([0.0, 0.0], var0.eval())
+self.assertAllClose([0.0, 0.0], self.evaluate(var0))
-self.assertAllClose([0.0, 0.0], var1.eval())
+self.assertAllClose([0.0, 0.0], self.evaluate(var1))

# Run Adagrad for a few steps
for _ in range(steps):
adagrad_update.run()

-return var0.eval(), var1.eval()
+return self.evaluate(var0), self.evaluate(var1)

def equivGradientDescentTest_FtrlPart(self, steps, dtype):
var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype)
@@ -85,14 +85,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
-self.assertAllClose([0.0, 0.0], var0.eval())
+self.assertAllClose([0.0, 0.0], self.evaluate(var0))
-self.assertAllClose([0.0, 0.0], var1.eval())
+self.assertAllClose([0.0, 0.0], self.evaluate(var1))

# Run Ftrl for a few steps
for _ in range(steps):
ftrl_update.run()

-return var0.eval(), var1.eval()
+return self.evaluate(var0), self.evaluate(var1)

def equivGradientDescentTest_GradientDescentPart(self, steps, dtype):
var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype)
@@ -100,14 +100,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
sgd_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
-self.assertAllClose([0.0, 0.0], var0.eval())
+self.assertAllClose([0.0, 0.0], self.evaluate(var0))
-self.assertAllClose([0.0, 0.0], var1.eval())
+self.assertAllClose([0.0, 0.0], self.evaluate(var1))

# Run GradientDescent for a few steps
for _ in range(steps):
sgd_update.run()

-return var0.eval(), var1.eval()
+return self.evaluate(var0), self.evaluate(var1)

def testFtrlwithoutRegularization(self):
for dtype in self.float_types:
@@ -124,8 +124,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
-self.assertAllClose([0.0, 0.0], var0.eval())
+self.assertAllClose([0.0, 0.0], self.evaluate(var0))
-self.assertAllClose([0.0, 0.0], var1.eval())
+self.assertAllClose([0.0, 0.0], self.evaluate(var1))

# Run 3 steps FTRL
for _ in range(3):
@@ -134,12 +134,12 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
# Validate updated params
self.assertAllCloseAccordingToType(
np.array([-2.60260963, -4.29698515]),
-var0.eval(),
+self.evaluate(var0),
float_rtol=1e-4,
half_rtol=1e-2)
self.assertAllCloseAccordingToType(
np.array([-0.28432083, -0.56694895]),
-var1.eval(),
+self.evaluate(var1),
float_rtol=1e-5,
half_rtol=1e-2)

@@ -158,8 +158,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
-self.assertAllClose([1.0, 2.0], var0.eval())
+self.assertAllClose([1.0, 2.0], self.evaluate(var0))
-self.assertAllClose([4.0, 3.0], var1.eval())
+self.assertAllClose([4.0, 3.0], self.evaluate(var1))

# Run 3 steps FTRL
for _ in range(3):
@@ -167,10 +167,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):

# Validate updated params
self.assertAllCloseAccordingToType(
-np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5,
+np.array([-2.55607247, -3.98729396]),
+self.evaluate(var0),
+1e-5,
+1e-5,
float_rtol=1e-4)
self.assertAllCloseAccordingToType(
-np.array([-0.28232238, -0.56096673]), var1.eval(), 1e-5, 1e-5)
+np.array([-0.28232238, -0.56096673]), self.evaluate(var1), 1e-5,
+1e-5)

def testFtrlWithL1(self):
for dtype in self.float_types:
@@ -187,8 +191,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
-self.assertAllClose([1.0, 2.0], var0.eval())
+self.assertAllClose([1.0, 2.0], self.evaluate(var0))
-self.assertAllClose([4.0, 3.0], var1.eval())
+self.assertAllClose([4.0, 3.0], self.evaluate(var1))

# Run 10 steps FTRL
for _ in range(10):
@@ -197,12 +201,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
# Validate updated params
self.assertAllCloseAccordingToType(
np.array([-7.66718769, -10.91273689]),
-var0.eval(),
+self.evaluate(var0),
rtol=1e-4,
bfloat16_rtol=1e-1,
bfloat16_atol=1e-1)
self.assertAllCloseAccordingToType(
-np.array([-0.93460727, -1.86147261]), var1.eval(), rtol=1e-4)
+np.array([-0.93460727, -1.86147261]),
+self.evaluate(var1),
+rtol=1e-4)

def testFtrlWithL1_L2(self):
for dtype in self.float_types:
@@ -219,8 +225,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
-self.assertAllClose([1.0, 2.0], var0.eval())
+self.assertAllClose([1.0, 2.0], self.evaluate(var0))
-self.assertAllClose([4.0, 3.0], var1.eval())
+self.assertAllClose([4.0, 3.0], self.evaluate(var1))

# Run 10 steps FTRL
for _ in range(10):
@@ -228,9 +234,13 @@ class FtrlOptimizerTest(xla_test.XLATestCase):

# Validate updated params
self.assertAllCloseAccordingToType(
-np.array([-0.24059935, -0.46829352]), var0.eval(), rtol=1e-5)
+np.array([-0.24059935, -0.46829352]),
+self.evaluate(var0),
+rtol=1e-5)
self.assertAllCloseAccordingToType(
-np.array([-0.02406147, -0.04830509]), var1.eval(), rtol=1e-5)
+np.array([-0.02406147, -0.04830509]),
+self.evaluate(var1),
+rtol=1e-5)

def testFtrlWithL1_L2_L2Shrinkage(self):
"""Test the new FTRL op with support for l2 shrinkage.
@@ -254,8 +264,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
-self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0))
-self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1))

# Run 10 steps FTRL
for _ in range(10):
@@ -263,9 +273,13 @@ class FtrlOptimizerTest(xla_test.XLATestCase):

# Validate updated params
self.assertAllCloseAccordingToType(
-np.array([-0.22578996, -0.44345799]), var0.eval(), rtol=1e-4)
+np.array([-0.22578996, -0.44345799]),
+self.evaluate(var0),
+rtol=1e-4)
self.assertAllCloseAccordingToType(
-np.array([-0.14378493, -0.13229476]), var1.eval(), rtol=1e-4)
+np.array([-0.14378493, -0.13229476]),
+self.evaluate(var1),
+rtol=1e-4)

def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
"""Verifies that l2 shrinkage in FTRL does not change lr schedule."""
@@ -291,8 +305,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
update1 = opt1.apply_gradients([(grads1, var1)])
variables.global_variables_initializer().run()

-self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0))
-self.assertAllCloseAccordingToType([1.0, 2.0], var1.eval())
+self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var1))

# Run 10 steps FTRL
for _ in range(10):
@@ -301,7 +315,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):

# var0 is experiencing L2 shrinkage so it should be smaller than var1
# in magnitude.
self.assertTrue((var0.eval()**2 < var1.eval()**2).all())
|
self.assertTrue((var0.eval()**2 < self.evaluate(var1)**2).all())
|
||||||
accum0 = list(opt0._slots["accum"].values())[0].eval()
|
accum0 = list(opt0._slots["accum"].values())[0].eval()
|
||||||
accum1 = list(opt1._slots["accum"].values())[0].eval()
|
accum1 = list(opt1._slots["accum"].values())[0].eval()
|
||||||
# L2 shrinkage should not change how we update grad accumulator.
|
# L2 shrinkage should not change how we update grad accumulator.
|
||||||
|
@@ -40,7 +40,7 @@ class FunctionTest(xla_test.XLATestCase):
     bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
     expected = APlus2B(aval, bval)

-    with self.cached_session() as sess:
+    with self.cached_session():

       @function.Defun(dtypes.float32, dtypes.float32)
       def Foo(a, b):
@@ -50,7 +50,7 @@ class FunctionTest(xla_test.XLATestCase):
       b = constant_op.constant(bval, name="b")
       with self.test_scope():
         call_f = Foo(a, b)
-      result = sess.run(call_f)
+      result = self.evaluate(call_f)
     self.assertAllClose(result, expected, rtol=1e-3)

   def testNestedFunctions(self):
@@ -66,7 +66,7 @@ class FunctionTest(xla_test.XLATestCase):
     bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
     expected = APlus2B(aval, bval)

-    with self.cached_session() as sess:
+    with self.cached_session():

       @function.Defun(dtypes.float32, dtypes.float32)
       def Foo(a, b):
@@ -76,7 +76,7 @@ class FunctionTest(xla_test.XLATestCase):
       b = constant_op.constant(bval, name="b")
       with self.test_scope():
         call_g = Foo(a, b)
-      result = sess.run(call_g)
+      result = self.evaluate(call_g)
     self.assertAllClose(result, expected, rtol=1e-3)

   def testFunctionMultipleRetvals(self):
@@ -90,7 +90,7 @@ class FunctionTest(xla_test.XLATestCase):
     bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
     expected = Func(aval, bval)

-    with self.cached_session() as sess:
+    with self.cached_session():

       @function.Defun(dtypes.float32, dtypes.float32)
       def Foo(a, b):
@@ -100,7 +100,7 @@ class FunctionTest(xla_test.XLATestCase):
       b = constant_op.constant(bval, name="b")
       with self.test_scope():
         call_f = Foo(a, b)
-      result = sess.run(call_f)
+      result = self.evaluate(call_f)
     self.assertAllClose(result, expected, rtol=1e-3)

   def testCompileTimeConstantsInDefun(self):
@@ -75,7 +75,7 @@ def RunMetadataLabels(run_metadata):

 def InLabels(labels, substr):
   """Returns true iff one of the labels contains substr."""
-  return any([substr in x for x in labels])
+  return any(substr in x for x in labels)


 def MetadataHasXlaRunOp(run_metadata):
@@ -33,13 +33,13 @@ class ListDiffTest(xla_test.XLATestCase):
   def _testListDiff(self, x, y, out, idx):
     for dtype in [dtypes.int32, dtypes.int64]:
       for index_dtype in [dtypes.int32, dtypes.int64]:
-        with self.cached_session() as sess:
+        with self.cached_session():
           x_tensor = ops.convert_to_tensor(x, dtype=dtype)
           y_tensor = ops.convert_to_tensor(y, dtype=dtype)
           with self.test_scope():
             out_tensor, idx_tensor = array_ops.listdiff(
                 x_tensor, y_tensor, out_idx=index_dtype)
-            tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+            tf_out, tf_idx = self.evaluate([out_tensor, idx_tensor])
         self.assertAllEqual(out, tf_out)
         self.assertAllEqual(idx, tf_idx)
         self.assertEqual(1, out_tensor.get_shape().ndims)
@@ -120,8 +120,8 @@ class LRNTest(xla_test.XLATestCase):
       with self.test_scope():
         actual = gen_nn_ops.lrn_grad(out_grads, in_image, out_image,
                                      depth_radius, bias, alpha, beta)
-      expected_val = expected.eval()
-      actual_val = actual.eval()
+      expected_val = self.evaluate(expected)
+      actual_val = self.evaluate(actual)
       self.assertAllClose(actual_val, expected_val, rtol=1e-3)


@@ -88,8 +88,8 @@ class LSTMTest(test.TestCase):
           (basename, m_prev_scalar, c_prev_scalar, pad_scalar))

       # Initialize variables and run the unrolled LSTM step.
-      sess.run(variables.global_variables_initializer())
-      return sess.run([m, c])
+      self.evaluate(variables.global_variables_initializer())
+      return self.evaluate([m, c])

   def testLSTMCell(self):
     # Run with all-0 weights, no padding.
@@ -173,8 +173,8 @@ class LSTMTest(test.TestCase):
           (basename, m_init_scalar, c_init_scalar, pad_scalar))

       # Initialize variables and run the unrolled LSTM layer.
-      sess.run(variables.global_variables_initializer())
-      return sess.run(out_seq)
+      self.evaluate(variables.global_variables_initializer())
+      return self.evaluate(out_seq)

   def testLSTMLayer(self):
     # Run with all-0 weights, no padding.
@@ -61,37 +61,43 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
         self.assertFalse(slot1 in variables.trainable_variables())

         # Fetch params to validate initial values
-        self.assertAllClose([1.0, 2.0], var0.eval())
-        self.assertAllClose([3.0, 4.0], var1.eval())
+        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
         # Step 1: the momentum accumulators where 0. So we should see a normal
         # update: v -= grad * learning_rate
         mom_update.run()
         # Check that the momentum accumulators have been updated.
-        self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
-        self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
+        self.assertAllCloseAccordingToType(
+            np.array([0.1, 0.1]), self.evaluate(slot0))
+        self.assertAllCloseAccordingToType(
+            np.array([0.01, 0.01]), self.evaluate(slot1))
         # Check that the parameters have been updated.
         self.assertAllCloseAccordingToType(
-            np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
+            np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
+            self.evaluate(var0))
         self.assertAllCloseAccordingToType(
-            np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
+            np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
+            self.evaluate(var1))
         # Step 2: the momentum accumulators contain the previous update.
         mom_update.run()
         # Check that the momentum accumulators have been updated.
         self.assertAllCloseAccordingToType(
-            np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
+            np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
+            self.evaluate(slot0))
         self.assertAllCloseAccordingToType(
-            np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
+            np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+            self.evaluate(slot1))
         # Check that the parameters have been updated.
         self.assertAllCloseAccordingToType(
             np.array([
                 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
                 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
-            ]), var0.eval())
+            ]), self.evaluate(var0))
         self.assertAllCloseAccordingToType(
             np.array([
-                2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
-                    (0.9 * 0.01 + 0.01) * 2.0)
-            ]), var1.eval())
+                2.98 - ((0.9 * 0.01 + 0.01) * 2.0),
+                3.98 - ((0.9 * 0.01 + 0.01) * 2.0)
+            ]), self.evaluate(var1))

   def testNesterovMomentum(self):
     for dtype in self.float_types:
@@ -115,8 +121,8 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
             var0_np, accum0_np, var0_np * 0.8, 0.1, 0.9)
         var1_np, accum1_np = self._update_nesterov_momentum_numpy(
             var1_np, accum1_np, 0.9, 0.1, 0.9)
-        self.assertAllCloseAccordingToType(var0_np, var0.eval())
-        self.assertAllCloseAccordingToType(var1_np, var1.eval())
+        self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+        self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))

   def testTensorLearningRateAndMomentum(self):
     for dtype in self.float_types:
@@ -141,37 +147,43 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
         self.assertFalse(slot1 in variables.trainable_variables())

         # Fetch params to validate initial values
-        self.assertAllClose([1.0, 2.0], var0.eval())
-        self.assertAllClose([3.0, 4.0], var1.eval())
+        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
         # Step 1: the momentum accumulators where 0. So we should see a normal
         # update: v -= grad * learning_rate
         mom_update.run()
         # Check that the momentum accumulators have been updated.
-        self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
-        self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
+        self.assertAllCloseAccordingToType(
+            np.array([0.1, 0.1]), self.evaluate(slot0))
+        self.assertAllCloseAccordingToType(
+            np.array([0.01, 0.01]), self.evaluate(slot1))
         # Check that the parameters have been updated.
         self.assertAllCloseAccordingToType(
-            np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
+            np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
+            self.evaluate(var0))
         self.assertAllCloseAccordingToType(
-            np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
+            np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
+            self.evaluate(var1))
         # Step 2: the momentum accumulators contain the previous update.
         mom_update.run()
         # Check that the momentum accumulators have been updated.
         self.assertAllCloseAccordingToType(
-            np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
+            np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
+            self.evaluate(slot0))
         self.assertAllCloseAccordingToType(
-            np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
+            np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+            self.evaluate(slot1))
         # Check that the parameters have been updated.
         self.assertAllCloseAccordingToType(
             np.array([
                 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
                 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
-            ]), var0.eval())
+            ]), self.evaluate(var0))
         self.assertAllCloseAccordingToType(
             np.array([
-                2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
-                    (0.9 * 0.01 + 0.01) * 2.0)
-            ]), var1.eval())
+                2.98 - ((0.9 * 0.01 + 0.01) * 2.0),
+                3.98 - ((0.9 * 0.01 + 0.01) * 2.0)
+            ]), self.evaluate(var1))


 if __name__ == "__main__":
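The hunks above all apply one pattern: test code that fetched values with `Tensor.eval()` or `sess.run(...)` now goes through `self.evaluate(...)`, the `tf.test.TestCase` helper that resolves tensors under a graph-mode test session or directly in eager mode. A minimal sketch of that pattern follows; the test class and variable names are illustrative only and are not part of this commit.

```python
import tensorflow as tf


class EvaluateStyleTest(tf.test.TestCase):
  """Illustrative only: shows the self.evaluate() idiom used in the diff above."""

  def testReadVariable(self):
    var = tf.Variable([1.0, 2.0])
    if not tf.executing_eagerly():
      # In graph mode the variable still needs explicit initialization.
      self.evaluate(tf.compat.v1.global_variables_initializer())
    # self.evaluate() replaces var.eval() / sess.run(var): the same assertion
    # works whether the test runs under a session or eagerly.
    self.assertAllClose([1.0, 2.0], self.evaluate(var))


if __name__ == "__main__":
  tf.test.main()
```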