Merge branch 'master' into yang/bn_relu_fwd
This commit is contained in:
commit
2f29011259
65
.bazelrc
65
.bazelrc
@ -19,10 +19,10 @@
|
||||
# Compiler options:
|
||||
# cuda_clang: Use clang when building CUDA code.
|
||||
# c++17: Build with C++17 options
|
||||
# C++1z: Build with C++17 options
|
||||
# c++1z: Build with C++17 options
|
||||
# avx_linux: Build with avx instruction set on linux.
|
||||
# avx2_linux: Build with avx2 instruction set on linux.
|
||||
# arch_native_linux: Build with instruction sets available to the host machine on linux
|
||||
# native_arch_linux: Build with instruction sets available to the host machine on linux
|
||||
# avx_win: Build with avx instruction set on windows
|
||||
# avx2_win: Build with avx2 instruction set on windows
|
||||
#
|
||||
@ -73,6 +73,10 @@
|
||||
# tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
|
||||
# tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
|
||||
#
|
||||
# Embedded Linux options (experimental and only tested with TFLite build yet)
|
||||
# elinux: General Embedded Linux options shared by all flavors.
|
||||
# elinux_aarch64: Embedded Linux options for aarch64 (ARM64) CPU support.
|
||||
# elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.
|
||||
|
||||
|
||||
|
||||
@ -352,9 +356,10 @@ build:rbe_linux --linkopt=-lm
|
||||
build:rbe_cpu_linux --config=rbe_linux
|
||||
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
|
||||
build:rbe_cpu_linux --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8"
|
||||
build:rbe_cpu_linux --extra_execution_platforms"=@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
|
||||
build:rbe_cpu_linux --host_platform="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
|
||||
build:rbe_cpu_linux --platforms="@org_tensorflow//third_party/toolchains:rbe_ubuntu16.04-manylinux2010"
|
||||
build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
|
||||
build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
|
||||
build:rbe_cpu_linux --host_platform="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
|
||||
build:rbe_cpu_linux --platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
|
||||
|
||||
build:rbe_linux_cuda_base --config=rbe_linux
|
||||
build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1
|
||||
@ -376,17 +381,37 @@ build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_
|
||||
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
|
||||
test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
|
||||
|
||||
build:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_clang --crosstool_top="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_clang --extra_toolchains="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_clang --extra_execution_platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang --host_platform="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang --platforms="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_clang --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_clang --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_clang --define=using_cuda_clang=true
|
||||
test:rbe_linux_cuda_clang --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_nvcc_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_nvcc_base --define=using_cuda_nvcc=true
|
||||
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
|
||||
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
|
||||
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
|
||||
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
|
||||
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
|
||||
|
||||
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
|
||||
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
|
||||
build:rbe_linux_cuda_clang_base --extra_toolchains="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
|
||||
build:rbe_linux_cuda_clang_base --extra_execution_platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang_base --host_platform="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang_base --platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
|
||||
build:rbe_linux_cuda_clang_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
|
||||
build:rbe_linux_cuda_clang_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
|
||||
build:rbe_linux_cuda_clang_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
|
||||
build:rbe_linux_cuda_clang_base --define=using_cuda_clang=true
|
||||
build:rbe_linux_cuda_clang_py27 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
|
||||
build:rbe_linux_cuda_clang_py35 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
|
||||
build:rbe_linux_cuda_clang_py36 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
|
||||
build:rbe_linux_cuda_clang_py37 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
|
||||
build:rbe_linux_cuda_clang_py38 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
|
||||
|
||||
common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
|
||||
|
||||
@ -432,6 +457,14 @@ build:tensorflow_testing_rbe_linux --config=rbe_linux
|
||||
|
||||
common:tensorflow_testing_rbe_win --remote_instance_name=projects/tensorflow-testing/instances/windows
|
||||
build:tensorflow_testing_rbe_win --config=tensorflow_testing_rbe
|
||||
|
||||
# TFLite build configs for generic embedded Linux
|
||||
build:elinux --crosstool_top=@local_config_embedded_arm//:toolchain
|
||||
build:elinux --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
|
||||
build:elinux_aarch64 --config=elinux
|
||||
build:elinux_aarch64 --cpu=aarch64
|
||||
build:elinux_armhf --config=elinux
|
||||
build:elinux_armhf --cpu=armhf
|
||||
# END TF REMOTE BUILD EXECUTION OPTIONS
|
||||
|
||||
# Default options should come above this line
|
||||
|
@ -1 +1 @@
|
||||
2.0.0
|
||||
3.0.0
|
||||
|
34
.github/ISSUE_TEMPLATE/00-bug-issue.md
vendored
34
.github/ISSUE_TEMPLATE/00-bug-issue.md
vendored
@ -10,32 +10,30 @@ labels: 'type:bug'
|
||||
we only address code/doc bugs, performance issues, feature requests and
|
||||
build/installation issues on GitHub. tag:bug_template</em>
|
||||
|
||||
**System information**
|
||||
- Have I written custom code (as opposed to using a stock
|
||||
example script provided in TensorFlow):
|
||||
- OS Platform and Distribution (e.g.,
|
||||
Linux Ubuntu 16.04):
|
||||
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
|
||||
the issue happens on mobile device:
|
||||
- TensorFlow installed from (source or
|
||||
binary): - TensorFlow version (use command below):
|
||||
- Python version: - Bazel
|
||||
version (if compiling from source):
|
||||
- GCC/Compiler version (if compiling from
|
||||
source):
|
||||
- CUDA/cuDNN version: - GPU model and memory:
|
||||
**System information**
|
||||
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
|
||||
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
|
||||
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
|
||||
- TensorFlow installed from (source or binary):
|
||||
- TensorFlow version (use command below):
|
||||
- Python version:
|
||||
- Bazel version (if compiling from source):
|
||||
- GCC/Compiler version (if compiling from source):
|
||||
- CUDA/cuDNN version:
|
||||
- GPU model and memory:
|
||||
|
||||
You can collect some of this information using our environment capture
|
||||
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
|
||||
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
|
||||
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
|
||||
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
||||
You can also obtain the TensorFlow version with:
|
||||
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
|
||||
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
||||
|
||||
|
||||
**Describe the current behavior**
|
||||
|
||||
**Describe the expected behavior**
|
||||
|
||||
**Standalone code to reproduce the issue**
|
||||
**Standalone code to reproduce the issue**
|
||||
Provide a reproducible test case that is the bare minimum necessary to generate
|
||||
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
|
||||
|
||||
|
@ -38,6 +38,9 @@ state what is wrong:
|
||||
- Producing correct results, but the model is slower than expected (model generated from old converter)
|
||||
|
||||
|
||||
**RNN conversion support**
|
||||
If converting TF RNN to TFLite fused RNN ops, please prefix [RNN] in the title.
|
||||
|
||||
**Any other info / logs**
|
||||
|
||||
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
|
||||
|
33
.github/ISSUE_TEMPLATE/80-performance-issue.md
vendored
33
.github/ISSUE_TEMPLATE/80-performance-issue.md
vendored
@ -11,32 +11,29 @@ As per our
|
||||
we only address code/doc bugs, performance issues, feature requests and
|
||||
build/installation issues on GitHub. tag:performance_template</em>
|
||||
|
||||
**System information**
|
||||
- Have I written custom code (as opposed to using a stock
|
||||
example script provided in TensorFlow):
|
||||
- OS Platform and Distribution (e.g.,
|
||||
Linux Ubuntu 16.04):
|
||||
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if
|
||||
the issue happens on mobile device:
|
||||
- TensorFlow installed from (source or
|
||||
binary): - TensorFlow version (use command below):
|
||||
- Python version: - Bazel
|
||||
version (if compiling from source):
|
||||
- GCC/Compiler version (if compiling from
|
||||
source):
|
||||
- CUDA/cuDNN version: - GPU model and memory:
|
||||
**System information**
|
||||
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
|
||||
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
|
||||
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
|
||||
- TensorFlow installed from (source or binary):
|
||||
- TensorFlow version (use command below):
|
||||
- Python version:
|
||||
- Bazel version (if compiling from source):
|
||||
- GCC/Compiler version (if compiling from source):
|
||||
- CUDA/cuDNN version:
|
||||
- GPU model and memory:
|
||||
|
||||
You can collect some of this information using our environment capture
|
||||
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
|
||||
You can also obtain the TensorFlow version with: 1. TF 1.0: `python -c "import
|
||||
tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"` 2. TF 2.0: `python -c
|
||||
"import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
||||
You can also obtain the TensorFlow version with:
|
||||
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
|
||||
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
|
||||
|
||||
**Describe the current behavior**
|
||||
|
||||
**Describe the expected behavior**
|
||||
|
||||
**Standalone code to reproduce the issue**
|
||||
**Standalone code to reproduce the issue**
|
||||
Provide a reproducible test case that is the bare minimum necessary to generate
|
||||
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
|
||||
|
||||
|
39
.github/stale.yml
vendored
Normal file
39
.github/stale.yml
vendored
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
#
|
||||
# THIS IS A GENERATED DOCKERFILE.
|
||||
#
|
||||
# This file was assembled from multiple pieces, whose use is documented
|
||||
# throughout. Please refer to the TensorFlow dockerfiles documentation
|
||||
# for more information.
|
||||
|
||||
# Number of days of inactivity before an Issue or Pull Request becomes stale
|
||||
daysUntilStale: 7
|
||||
# Number of days of inactivity before a stale Issue or Pull Request is closed
|
||||
daysUntilClose: 7
|
||||
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
|
||||
onlyLabels:
|
||||
- stat:awaiting response
|
||||
# Comment to post when marking as stale. Set to `false` to disable
|
||||
markComment: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity. It will be closed if no further activity occurs. Thank you.
|
||||
# Comment to post when removing the stale label. Set to `false` to disable
|
||||
unmarkComment: false
|
||||
closeComment: >
|
||||
Closing as stale. Please reopen if you'd like to work on this further.
|
||||
limitPerRun: 30
|
||||
# Limit to only `issues` or `pulls`
|
||||
only: issues
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -38,6 +38,7 @@ gradleBuild
|
||||
*.pbxproj
|
||||
*.xcworkspace
|
||||
/*.podspec
|
||||
/tensorflow/lite/**/coreml/**/BUILD
|
||||
/tensorflow/lite/**/ios/BUILD
|
||||
/tensorflow/lite/**/objc/BUILD
|
||||
/tensorflow/lite/**/swift/BUILD
|
||||
|
@ -2,6 +2,10 @@
|
||||
<img src="https://www.tensorflow.org/images/tf_logo_social.png">
|
||||
</div>
|
||||
|
||||
[](https://badge.fury.io/py/tensorflow)
|
||||
[](https://badge.fury.io/py/tensorflow)
|
||||
|
||||
|
||||
**`Documentation`** |
|
||||
------------------- |
|
||||
[](https://www.tensorflow.org/api_docs/) |
|
||||
@ -135,6 +139,7 @@ Build Type | Status
|
||||
* [TensorFlow Examples](https://github.com/tensorflow/examples)
|
||||
* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice)
|
||||
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
|
||||
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
|
||||
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
|
||||
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
|
||||
* [TensorFlow Blog](https://blog.tensorflow.org)
|
||||
|
@ -2,58 +2,42 @@ package(default_visibility = ["//visibility:public"])
|
||||
|
||||
filegroup(
|
||||
name = "gcc",
|
||||
srcs = [
|
||||
"bin/arm-rpi-linux-gnueabihf-gcc",
|
||||
],
|
||||
srcs = glob(["bin/*-gcc"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "ar",
|
||||
srcs = [
|
||||
"bin/arm-rpi-linux-gnueabihf-ar",
|
||||
],
|
||||
srcs = glob(["bin/*-ar"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "ld",
|
||||
srcs = [
|
||||
"bin/arm-rpi-linux-gnueabihf-ld",
|
||||
],
|
||||
srcs = glob(["bin/*-ld"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "nm",
|
||||
srcs = [
|
||||
"bin/arm-rpi-linux-gnueabihf-nm",
|
||||
],
|
||||
srcs = glob(["bin/*-nm"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "objcopy",
|
||||
srcs = [
|
||||
"bin/arm-rpi-linux-gnueabihf-objcopy",
|
||||
],
|
||||
srcs = glob(["bin/*-objcopy"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "objdump",
|
||||
srcs = [
|
||||
"bin/arm-rpi-linux-gnueabihf-objdump",
|
||||
],
|
||||
srcs = glob(["bin/*-objdump"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "strip",
|
||||
srcs = [
|
||||
"bin/arm-rpi-linux-gnueabihf-strip",
|
||||
],
|
||||
srcs = glob(["bin/*-strip"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "as",
|
||||
srcs = [
|
||||
"bin/arm-rpi-linux-gnueabihf-as",
|
||||
],
|
||||
srcs = glob(["bin/*-as"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
@ -66,6 +50,16 @@ filegroup(
|
||||
]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "aarch64_compiler_pieces",
|
||||
srcs = glob([
|
||||
"aarch64-none-linux-gnu/**",
|
||||
"libexec/**",
|
||||
"lib/gcc/aarch64-none-linux-gnu/**",
|
||||
"include/**",
|
||||
]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "compiler_components",
|
||||
srcs = [
|
||||
|
2
configure
vendored
2
configure
vendored
@ -4,7 +4,7 @@ set -e
|
||||
set -o pipefail
|
||||
|
||||
if [ -z "$PYTHON_BIN_PATH" ]; then
|
||||
PYTHON_BIN_PATH=$(which python || which python3 || true)
|
||||
PYTHON_BIN_PATH=$(which python3 || which python || true)
|
||||
fi
|
||||
|
||||
# Set all env variables
|
||||
|
12
configure.py
12
configure.py
@ -50,7 +50,7 @@ _TF_WORKSPACE_ROOT = ''
|
||||
_TF_BAZELRC = ''
|
||||
_TF_CURRENT_BAZEL_VERSION = None
|
||||
_TF_MIN_BAZEL_VERSION = '2.0.0'
|
||||
_TF_MAX_BAZEL_VERSION = '2.0.0'
|
||||
_TF_MAX_BAZEL_VERSION = '3.99.0'
|
||||
|
||||
NCCL_LIB_PATHS = [
|
||||
'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
|
||||
@ -1171,14 +1171,16 @@ def system_specific_test_config(environ_cp):
|
||||
test_only_filters = ['-oss_serial']
|
||||
if is_windows():
|
||||
test_and_build_filters.append('-no_windows')
|
||||
if environ_cp.get('TF_NEED_CUDA', None) == '1':
|
||||
if ((environ_cp.get('TF_NEED_CUDA', None) == '1') or
|
||||
(environ_cp.get('TF_NEED_ROCM', None) == '1')):
|
||||
test_and_build_filters += ['-no_windows_gpu', '-no_gpu']
|
||||
else:
|
||||
test_and_build_filters.append('-gpu')
|
||||
elif is_macos():
|
||||
test_and_build_filters += ['-gpu', '-nomac', '-no_mac']
|
||||
elif is_linux():
|
||||
if environ_cp.get('TF_NEED_CUDA', None) == '1':
|
||||
if ((environ_cp.get('TF_NEED_CUDA', None) == '1') or
|
||||
(environ_cp.get('TF_NEED_ROCM', None) == '1')):
|
||||
test_and_build_filters.append('-no_gpu')
|
||||
write_to_bazelrc('test --test_env=LD_LIBRARY_PATH')
|
||||
else:
|
||||
@ -1416,6 +1418,10 @@ def main():
|
||||
write_action_env_to_bazelrc('LD_LIBRARY_PATH',
|
||||
environ_cp.get('LD_LIBRARY_PATH'))
|
||||
|
||||
if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('ROCM_PATH')):
|
||||
write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH'))
|
||||
write_action_env_to_bazelrc('ROCM_ROOT', environ_cp.get('ROCM_PATH'))
|
||||
|
||||
environ_cp['TF_NEED_CUDA'] = str(
|
||||
int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)))
|
||||
if (environ_cp.get('TF_NEED_CUDA') == '1' and
|
||||
|
@ -214,6 +214,12 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "linux_armhf",
|
||||
values = {"cpu": "armhf"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "linux_x86_64",
|
||||
values = {"cpu": "k8"},
|
||||
@ -517,6 +523,12 @@ package_group(
|
||||
],
|
||||
)
|
||||
|
||||
package_group(name = "ndarray_tensor_allow_list")
|
||||
|
||||
# Packages that use composite tensors or dispatch.
|
||||
# TODO(b/154762408) Remove this package group once it's no longer needed.
|
||||
package_group(name = "composite_tensor_whitelist")
|
||||
|
||||
filegroup(
|
||||
name = "intel_binary_blob",
|
||||
data = if_mkl_ml(
|
||||
@ -639,7 +651,7 @@ tf_cc_shared_object(
|
||||
"//tensorflow/cc/saved_model:loader_lite_impl",
|
||||
"//tensorflow/core:core_cpu_impl",
|
||||
"//tensorflow/core:framework_internal_impl",
|
||||
"//tensorflow/core:gpu_runtime_impl",
|
||||
"//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
|
||||
"//tensorflow/core:lib_internal_impl",
|
||||
"//tensorflow/core/profiler:profiler_impl",
|
||||
@ -703,8 +715,8 @@ tf_cc_shared_object(
|
||||
"//tensorflow/c:version_script.lds",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/core:distributed_tensorflow_dependencies",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -116,7 +116,7 @@ from tensorflow.python.lib.io import file_io as _fi
|
||||
|
||||
# Get sitepackages directories for the python installation.
|
||||
_site_packages_dirs = []
|
||||
_site_packages_dirs += [_site.USER_SITE]
|
||||
_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE]
|
||||
_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
|
||||
if 'getsitepackages' in dir(_site):
|
||||
_site_packages_dirs += _site.getsitepackages()
|
||||
|
@ -126,7 +126,7 @@ from tensorflow.python.lib.io import file_io as _fi
|
||||
|
||||
# Get sitepackages directories for the python installation.
|
||||
_site_packages_dirs = []
|
||||
_site_packages_dirs += [_site.USER_SITE]
|
||||
_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE]
|
||||
_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
|
||||
if 'getsitepackages' in dir(_site):
|
||||
_site_packages_dirs += _site.getsitepackages()
|
||||
|
@ -23,6 +23,7 @@ filegroup(
|
||||
srcs = [
|
||||
"c_api.h",
|
||||
"c_api_experimental.h",
|
||||
"tensor_interface.h",
|
||||
"tf_attrtype.h",
|
||||
"tf_datatype.h",
|
||||
"tf_file_statistics.h",
|
||||
@ -58,6 +59,7 @@ filegroup(
|
||||
srcs = [
|
||||
"c_api_internal.h",
|
||||
"python_api.h",
|
||||
"tensor_interface.h",
|
||||
"tf_status_helper.h",
|
||||
"tf_status_internal.h",
|
||||
"tf_tensor_internal.h",
|
||||
@ -116,6 +118,13 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "c_api_macros",
|
||||
hdrs = ["c_api_macros.h"],
|
||||
copts = tf_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api",
|
||||
hdrs = [
|
||||
@ -238,6 +247,16 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_interface",
|
||||
hdrs = ["tensor_interface.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_datatype",
|
||||
srcs = ["tf_datatype.cc"],
|
||||
@ -264,6 +283,7 @@ cc_library(
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tensor_interface",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
":tf_status_helper",
|
||||
@ -271,6 +291,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:casts",
|
||||
],
|
||||
}),
|
||||
)
|
||||
@ -281,16 +302,18 @@ tf_cuda_library(
|
||||
"tf_tensor.h",
|
||||
"tf_tensor_internal.h",
|
||||
],
|
||||
visibility = ["//tensorflow/c:__subpackages__"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":tensor_interface",
|
||||
":tf_datatype",
|
||||
":tf_status",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:casts",
|
||||
],
|
||||
}),
|
||||
)
|
||||
@ -311,6 +334,9 @@ tf_cuda_library(
|
||||
":checkpoint_reader",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/c/eager:tfe_context_internal",
|
||||
"//tensorflow/c/eager:tfe_op_internal",
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
@ -318,6 +344,8 @@ tf_cuda_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:core",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/platform",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -499,12 +527,12 @@ tf_cuda_cc_test(
|
||||
":test_op1.so",
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
],
|
||||
kernels = [":test_op_kernel"],
|
||||
linkopts = select({
|
||||
"//tensorflow:macos": ["-headerpad_max_install_names"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
tags = [
|
||||
"no_windows", # TODO(b/155444728)
|
||||
"noasan",
|
||||
],
|
||||
# We must ensure that the dependencies can be dynamically linked since
|
||||
@ -513,6 +541,7 @@ tf_cuda_cc_test(
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_test_util",
|
||||
":test_op_kernel",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:grad_ops",
|
||||
"//tensorflow/cc/saved_model:signature_constants",
|
||||
@ -579,6 +608,7 @@ tf_cc_test(
|
||||
":c_api",
|
||||
":c_api_internal",
|
||||
":c_test_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -703,3 +733,11 @@ tf_cuda_library(
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "conversion_macros",
|
||||
hdrs = [
|
||||
"conversion_macros.h",
|
||||
],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
@ -39,6 +39,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eval_const_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/kernel_def.pb.h"
|
||||
@ -53,19 +54,18 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/validate.h"
|
||||
#include "tensorflow/core/lib/core/coding.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/coding.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/str_util.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
@ -21,20 +21,24 @@ limitations under the License.
|
||||
#include "tensorflow/c/checkpoint_reader.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/net.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
@ -685,9 +689,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
|
||||
std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
|
||||
|
||||
status->status = tensorflow::Status::OK();
|
||||
return new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(
|
||||
tensorflow::TensorHandle::CreateLocalHandle(tensor))};
|
||||
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -708,7 +710,7 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
|
||||
|
||||
// New server created for new server_def. Unused if updating server_def.
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server =
|
||||
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
|
||||
if (grpc_server == nullptr) {
|
||||
@ -822,15 +824,13 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
||||
|
||||
const int num_inputs = input_shapes->num_items;
|
||||
NodeDef node_def;
|
||||
node_def.set_name(tfe_op->operation->Name());
|
||||
node_def.set_op(tfe_op->operation->Name());
|
||||
tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op);
|
||||
node_def.set_name(op->Name());
|
||||
node_def.set_op(op->Name());
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
node_def.add_input("dummy_input");
|
||||
}
|
||||
tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
tfe_op->operation.get())
|
||||
->Attrs()
|
||||
.FillAttrValueMap(node_def.mutable_attr());
|
||||
OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr());
|
||||
|
||||
const tensorflow::OpRegistrationData* op_reg_data;
|
||||
status->status =
|
||||
|
@ -218,7 +218,7 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
|
||||
TFE_OpSetAttrType(fill_op, "Tshape", TF_INT32);
|
||||
|
||||
float five = 5.0;
|
||||
TFE_TensorHandle* scalar = TestScalarTensorHandle(five);
|
||||
TFE_TensorHandle* scalar = TestScalarTensorHandle(tfe_context_, five);
|
||||
TF_Tensor* scalarTensor = TFE_TensorHandleResolve(scalar, status_);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||
CheckOutputShapes(fill_op,
|
||||
|
@ -27,8 +27,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/strings/base64.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/base64.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
|
||||
using tensorflow::errors::InvalidArgument;
|
||||
|
||||
|
@ -14,17 +14,17 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/c_test_util.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/strings/proto_serialization.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/str_util.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -38,10 +38,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
@ -186,10 +186,6 @@ struct TF_Server {
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
|
||||
|
||||
TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
|
||||
|
||||
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
|
||||
TF_Buffer* out);
|
||||
|
||||
|
33
tensorflow/c/c_api_macros.h
Normal file
33
tensorflow/c/c_api_macros.h
Normal file
@ -0,0 +1,33 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_C_API_MACROS_H_
|
||||
#define TENSORFLOW_C_C_API_MACROS_H_
|
||||
|
||||
#ifdef SWIG
|
||||
#define TF_CAPI_EXPORT
|
||||
#else
|
||||
#if defined(_WIN32)
|
||||
#ifdef TF_COMPILE_LIBRARY
|
||||
#define TF_CAPI_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __declspec(dllimport)
|
||||
#endif // TF_COMPILE_LIBRARY
|
||||
#else
|
||||
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
|
||||
#endif // _WIN32
|
||||
#endif // SWIG
|
||||
|
||||
#endif // TENSORFLOW_C_C_API_MACROS_H_
|
@ -43,10 +43,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
#include "tensorflow/core/platform/resource_loader.h"
|
||||
#include "tensorflow/core/platform/str_util.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
|
@ -19,8 +19,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
|
||||
using tensorflow::GraphDef;
|
||||
|
@ -18,9 +18,9 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/saved_tensor_slice_util.h"
|
||||
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
|
||||
#include "tensorflow/core/util/tensor_slice_reader.h"
|
||||
|
30
tensorflow/c/conversion_macros.h
Normal file
30
tensorflow/c/conversion_macros.h
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_
|
||||
#define TENSORFLOW_C_CONVERSION_MACROS_H_
|
||||
|
||||
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
|
||||
inline cpp_impl *unwrap(wrapper *w) { \
|
||||
return reinterpret_cast<cpp_impl *>(w); \
|
||||
} \
|
||||
\
|
||||
inline const cpp_impl *unwrap(const wrapper *w) { \
|
||||
return reinterpret_cast<const cpp_impl *>(w); \
|
||||
} \
|
||||
\
|
||||
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); }
|
||||
|
||||
#endif // TENSORFLOW_C_CONVERSION_MACROS_H_
|
@ -16,6 +16,7 @@ load(
|
||||
"//tensorflow/core/platform:build_config_root.bzl",
|
||||
"tf_cuda_tests_tags",
|
||||
)
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
@ -29,11 +30,6 @@ tf_cuda_library(
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
"c_api_unified_experimental.h",
|
||||
"context_interface.cc",
|
||||
"context_interface.h",
|
||||
"operation_interface.cc",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
],
|
||||
hdrs = ["c_api.h"],
|
||||
copts = tf_copts() + tfe_xla_copts(),
|
||||
@ -43,29 +39,38 @@ tf_cuda_library(
|
||||
"//tensorflow/core:android_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":context_interface",
|
||||
":operation_interface",
|
||||
":tensor_handle_interface",
|
||||
":tfe_context_internal",
|
||||
":tfe_cancellation_manager_internal",
|
||||
":tfe_executor_internal",
|
||||
":tfe_monitoring_internal",
|
||||
":tfe_op_attrs_internal",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensor_debug_info_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:fixed_array",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:core",
|
||||
"//tensorflow/core/common_runtime/eager:eager_executor",
|
||||
"//tensorflow/core/common_runtime/eager:execute",
|
||||
"//tensorflow/core/common_runtime/eager:kernel_and_device",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"//tensorflow/core/common_runtime/eager:copy_to_device_node",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/platform:casts",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
],
|
||||
}) + select({
|
||||
"//tensorflow:with_xla_support": [
|
||||
@ -104,6 +109,11 @@ filegroup(
|
||||
"dlpack.h",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
"tfe_cancellation_manager_internal.h",
|
||||
"tfe_executor_internal.h",
|
||||
"tfe_monitoring_internal.h",
|
||||
"tfe_op_attrs_internal.h",
|
||||
"tfe_tensor_debug_info_internal.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
@ -111,38 +121,170 @@ filegroup(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
cc_library(
|
||||
name = "c_api_internal",
|
||||
srcs = [
|
||||
hdrs = [
|
||||
"c_api_experimental.h",
|
||||
"c_api_unified_experimental.h",
|
||||
"context_interface.h",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
"c_api_internal.h",
|
||||
],
|
||||
hdrs = ["c_api_internal.h"],
|
||||
visibility = [
|
||||
"//learning/deepmind/courier:__subpackages__",
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
"//tensorflow/c:c_api",
|
||||
":tfe_cancellation_manager_internal",
|
||||
":tfe_context_internal",
|
||||
":tfe_executor_internal",
|
||||
":tfe_monitoring_internal",
|
||||
":tfe_op_attrs_internal",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensor_debug_info_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_handle_interface",
|
||||
hdrs = ["tensor_handle_interface.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "operation_interface",
|
||||
hdrs = ["operation_interface.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tensor_handle_interface",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "context_interface",
|
||||
hdrs = ["context_interface.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":operation_interface",
|
||||
":tensor_handle_interface",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_context_internal",
|
||||
hdrs = ["tfe_context_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":context_interface",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_cancellation_manager_internal",
|
||||
hdrs = ["tfe_cancellation_manager_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_executor_internal",
|
||||
hdrs = ["tfe_executor_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core/common_runtime/eager:eager_executor",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/common_runtime/eager:kernel_and_device",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"@com_google_absl//absl/container:fixed_array",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_monitoring_internal",
|
||||
hdrs = ["tfe_monitoring_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_op_attrs_internal",
|
||||
hdrs = ["tfe_op_attrs_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_op_internal",
|
||||
hdrs = ["tfe_op_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":operation_interface",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_tensor_debug_info_internal",
|
||||
hdrs = ["tfe_tensor_debug_info_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tfe_tensorhandle_internal",
|
||||
hdrs = ["tfe_tensorhandle_internal.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tensor_handle_interface",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
|
||||
@ -157,6 +299,7 @@ tf_cuda_library(
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -174,7 +317,8 @@ tf_cuda_cc_test(
|
||||
],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = [
|
||||
"guitar",
|
||||
"noguitar", # TODO(b/155445984): flaky
|
||||
#"guitar",
|
||||
"multi_gpu",
|
||||
],
|
||||
deps = [
|
||||
@ -182,11 +326,16 @@ tf_cuda_cc_test(
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":c_api_test_util",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
@ -197,32 +346,63 @@ tf_cuda_cc_test(
|
||||
srcs = [
|
||||
"c_api_remote_test.cc",
|
||||
],
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = [
|
||||
"guitar",
|
||||
"multi_gpu",
|
||||
"no_oss",
|
||||
],
|
||||
tags = ["noasan"], # leaks gRPC server instances
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":c_api_test_util",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "c_api_cluster_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"c_api_cluster_test.cc",
|
||||
],
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = ["noasan"], # leaks gRPC server instances
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":c_api_test_util",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/platform:env",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_experimental",
|
||||
srcs = [
|
||||
"c_api_experimental.cc",
|
||||
"c_api_unified_experimental.cc",
|
||||
"c_api_unified_experimental_eager.cc",
|
||||
"c_api_unified_experimental_graph.cc",
|
||||
"c_api_unified_experimental_internal.h",
|
||||
],
|
||||
hdrs = [
|
||||
"c_api_experimental.h",
|
||||
@ -237,12 +417,16 @@ tf_cuda_library(
|
||||
"//conditions:default": [
|
||||
":c_api",
|
||||
":c_api_internal",
|
||||
":tfe_context_internal",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:eager_executor",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/common_runtime/eager:execute",
|
||||
"//tensorflow/core/common_runtime/eager:kernel_and_device",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
@ -265,7 +449,6 @@ tf_cuda_library(
|
||||
}) + [
|
||||
"@com_google_absl//absl/memory",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/distributed_runtime/eager:eager_client",
|
||||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
|
||||
@ -319,6 +502,7 @@ tf_cuda_cc_test(
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/core:lib",
|
||||
@ -329,6 +513,22 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "custom_device_testutil",
|
||||
testonly = True,
|
||||
srcs = ["custom_device_testutil.cc"],
|
||||
hdrs = ["custom_device_testutil.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "custom_device_test",
|
||||
size = "small",
|
||||
@ -339,6 +539,7 @@ tf_cc_test(
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
":custom_device_testutil",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/cc/profiler",
|
||||
@ -383,11 +584,13 @@ cc_library(
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"@dlpack",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -405,6 +608,7 @@ filegroup(
|
||||
],
|
||||
exclude = [
|
||||
"c_api_experimental.cc",
|
||||
"*c_api_tfrt*",
|
||||
"*test*",
|
||||
"*dlpack*",
|
||||
],
|
||||
|
File diff suppressed because it is too large
Load Diff
313
tensorflow/c/eager/c_api_cluster_test.cc
Normal file
313
tensorflow/c/eager/c_api_cluster_test.cc
Normal file
@ -0,0 +1,313 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using ::tensorflow::string;
|
||||
|
||||
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
|
||||
tensorflow::ServerDef server_def;
|
||||
server_def.set_protocol("grpc");
|
||||
server_def.set_job_name(job_name);
|
||||
server_def.set_task_index(0);
|
||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||
tensorflow::JobDef* job_def = cluster_def->add_job();
|
||||
job_def->set_name(job_name);
|
||||
for (int i = 0; i < num_tasks; i++) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{i, tensorflow::strings::StrCat("localhost:", port)});
|
||||
}
|
||||
return server_def;
|
||||
}
|
||||
|
||||
tensorflow::ServerDef GetServerDef(int num_tasks) {
|
||||
return GetServerDef("localhost", num_tasks);
|
||||
}
|
||||
|
||||
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
|
||||
const std::vector<float>& expected_values) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
|
||||
EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
|
||||
memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
|
||||
for (int i = 0; i < expected_values.size(); i++) {
|
||||
EXPECT_EQ(expected_values[i], actual_values[i])
|
||||
<< "Mismatch in expected values at (zero-based) index " << i;
|
||||
}
|
||||
}
|
||||
|
||||
void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
|
||||
const char* remote_device_name,
|
||||
const char* local_device_name) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
|
||||
TFE_OpSetDevice(matmul, remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
auto* retval_task0 =
|
||||
TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
|
||||
|
||||
TFE_DeleteTensorHandle(retval_task0);
|
||||
TFE_DeleteTensorHandle(h0_task0);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteChangeServerDef(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char local_device_name[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
|
||||
// Update the server def with a new set of names (worker instead of
|
||||
// localhost).
|
||||
tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
|
||||
serialized = updated_server_def.SerializeAsString();
|
||||
|
||||
updated_server_def.set_task_index(1);
|
||||
tensorflow::Status s = tensorflow::GrpcServer::Create(
|
||||
updated_server_def, tensorflow::Env::Default(), &worker_server);
|
||||
ASSERT_TRUE(s.ok()) << s.error_message();
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Create a new tensor_handle.
|
||||
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx);
|
||||
|
||||
// Check that copying it to the old remote device (named localhost) fails.
|
||||
TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Copying and executing on the new remote device works.
|
||||
const char new_remote_device_name[] =
|
||||
"/job:worker/replica:0/task:1/device:CPU:0";
|
||||
const char new_local_device_name[] =
|
||||
"/job:worker/replica:0/task:0/device:CPU:0";
|
||||
|
||||
auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
|
||||
h0_task0_new, ctx, new_remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_DeleteTensorHandle(h0_task0_new);
|
||||
TFE_DeleteTensorHandle(h0_task1_new);
|
||||
|
||||
CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
|
||||
new_local_device_name);
|
||||
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteChangeServerDef) {
|
||||
TestRemoteExecuteChangeServerDef(false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
|
||||
TestRemoteExecuteChangeServerDef(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteUpdateServerDef(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
const char local_device_name[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
|
||||
status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteUpdateServerDef) {
|
||||
TestRemoteExecuteUpdateServerDef(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
|
||||
TestRemoteExecuteUpdateServerDef(true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
|
||||
// Fail fast on GetStatus requests so we can get errors instead of timeout
|
||||
// when updating cluster with non-exsitent worker
|
||||
tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
|
||||
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
const char local_device_name[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
// Adding a non-existent remote worker to cluster def. This should cause the
|
||||
// UpdateServerDef call to fail.
|
||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||
tensorflow::JobDef* job_def = cluster_def->mutable_job(0);
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{2, tensorflow::strings::StrCat("localhost:", port)});
|
||||
string serialized_update = server_def.SerializeAsString();
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
|
||||
serialized_update.size(), status);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Even after the prevoiusly failed cluster update, another update and op
|
||||
// execution should work fine as long as the provided server_def is valid.
|
||||
TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
|
||||
status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
tensorflow::unsetenv("GRPC_FAIL_FAST");
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailures) {
|
||||
TestRemoteExecuteUpdateServerDefWithFailures(false);
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
|
||||
TestRemoteExecuteUpdateServerDefWithFailures(true);
|
||||
}
|
||||
|
||||
} // namespace
|
@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
#include "tensorflow/compiler/jit/xla_device.h"
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
@ -54,36 +57,33 @@ extern "C" {
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
|
||||
TFE_TensorHandle* h, TF_Status* status) {
|
||||
return h->handle->TensorDebugInfo(&status->status);
|
||||
}
|
||||
|
||||
TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
|
||||
Status* status) {
|
||||
tensorflow::TensorHandle* handle =
|
||||
TensorHandleFromInterface(tensorflow::unwrap(h));
|
||||
const tensorflow::Tensor* tensor;
|
||||
*status = handle_->Tensor(&tensor);
|
||||
if (!status->ok()) {
|
||||
status->status = handle->Tensor(&tensor);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
tensorflow::Device* device = absl::get<Device*>(handle_->device());
|
||||
auto* device = absl::get<tensorflow::Device*>(handle->device());
|
||||
|
||||
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
|
||||
tensorflow::XlaDevice* xla_device =
|
||||
dynamic_cast<tensorflow::XlaDevice*>(device);
|
||||
auto* xla_device = dynamic_cast<tensorflow::XlaDevice*>(device);
|
||||
if (xla_device != nullptr) {
|
||||
tensorflow::XlaDevice::PaddedShapeFn shape_fn =
|
||||
xla_device->metadata().padded_shape_fn();
|
||||
xla::Shape padded_shape;
|
||||
*status = shape_fn(*tensor, &padded_shape);
|
||||
if (!status->ok()) {
|
||||
status->status = shape_fn(*tensor, &padded_shape);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (VLOG_IS_ON(3)) {
|
||||
std::vector<int64> shape_to_log = TensorShapeAsVector(*handle_, status);
|
||||
if (!status->ok()) {
|
||||
std::vector<int64> shape_to_log =
|
||||
TensorShapeAsVector(*handle, &status->status);
|
||||
if (!status->status.ok()) {
|
||||
// Ignore the status here as we are simply logging.
|
||||
*status = tensorflow::Status::OK();
|
||||
status->status = tensorflow::Status::OK();
|
||||
} else {
|
||||
VLOG(3) << "Fully padded shape of ["
|
||||
<< absl::StrJoin(shape_to_log, ", ") << "] is "
|
||||
@ -96,7 +96,7 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
|
||||
// Currently, the only case of XlaTensor containing a tuple shape is to
|
||||
// represent 64 bit ints, doubles, and complex numbers (we don't support
|
||||
// 64bit complex numbers).
|
||||
*status = tensorflow::errors::InvalidArgument(
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"XlaTensors should only contain tuples of size 2. Shape: ",
|
||||
padded_shape.DebugString());
|
||||
return nullptr;
|
||||
@ -108,13 +108,13 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
|
||||
const xla::Shape& shape1 =
|
||||
xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
|
||||
if (shape0.IsTuple() || shape1.IsTuple()) {
|
||||
*status = tensorflow::errors::InvalidArgument(
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"XlaTensors should not contain nested tuples. Shape: ",
|
||||
padded_shape.DebugString());
|
||||
return nullptr;
|
||||
}
|
||||
if (!xla::ShapeUtil::Equal(shape0, shape1)) {
|
||||
*status = tensorflow::errors::InvalidArgument(
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"Subshapes of XlaTensors should be the same. Shape: ",
|
||||
padded_shape.DebugString());
|
||||
return nullptr;
|
||||
@ -139,15 +139,15 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
|
||||
dev_dims.push_back(padded_shape.dimensions(dim_index));
|
||||
}
|
||||
}
|
||||
*status = tensorflow::Status::OK();
|
||||
status->status = tensorflow::Status::OK();
|
||||
return new TFE_TensorDebugInfo(dev_dims);
|
||||
}
|
||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||
|
||||
// If the tensor is not an XLA tensor, the device shape is
|
||||
// the same as regular tensor shape.
|
||||
std::vector<int64> dev_dims = TensorShapeAsVector(*handle_, status);
|
||||
if (!status->ok()) {
|
||||
std::vector<int64> dev_dims = TensorShapeAsVector(*handle, &status->status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TFE_TensorDebugInfo(dev_dims);
|
||||
|
@ -21,8 +21,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
TEST(CApiDebug, ScalarCPU) {
|
||||
TFE_TensorHandle* h = TestScalarTensorHandle(1.0f);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* h = TestScalarTensorHandle(ctx, 1.0f);
|
||||
TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
@ -30,12 +35,18 @@ TEST(CApiDebug, ScalarCPU) {
|
||||
|
||||
TFE_DeleteTensorDebugInfo(debug_info);
|
||||
TFE_DeleteTensorHandle(h);
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(CApiDebug, 2DCPU) {
|
||||
TFE_TensorHandle* h = TestMatrixTensorHandle3X2();
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* h = TestMatrixTensorHandle3X2(ctx);
|
||||
TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
@ -46,5 +57,6 @@ TEST(CApiDebug, 2DCPU) {
|
||||
|
||||
TFE_DeleteTensorDebugInfo(debug_info);
|
||||
TFE_DeleteTensorHandle(h);
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
@ -15,24 +15,32 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
const char* raw_device_name, TF_Status* status) {
|
||||
if (op_to_reset) {
|
||||
status->status =
|
||||
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
|
||||
tensorflow::AbstractOperationInterface* op =
|
||||
tensorflow::unwrap(op_to_reset);
|
||||
op->Clear();
|
||||
status->status = op->Reset(op_or_function_name, raw_device_name);
|
||||
} else {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"op_to_reset should not be nullptr");
|
||||
@ -41,13 +49,13 @@ void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
|
||||
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(true);
|
||||
}
|
||||
|
||||
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetShouldStoreGraphs(false);
|
||||
}
|
||||
|
||||
@ -479,7 +487,7 @@ void TFE_ContextOptionsSetMirroringPolicy(TFE_ContextOptions* options,
|
||||
void TFE_ContextSetThreadLocalMirroringPolicy(
|
||||
TFE_Context* ctx, TFE_ContextMirroringPolicy policy) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetThreadLocalMirroringPolicy(
|
||||
static_cast<tensorflow::ContextMirroringPolicy>(policy));
|
||||
}
|
||||
@ -490,7 +498,7 @@ void TFE_ContextSetThreadLocalMirroringPolicy(
|
||||
extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
|
||||
TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return static_cast<TFE_ContextMirroringPolicy>(context->GetMirroringPolicy());
|
||||
}
|
||||
|
||||
@ -525,7 +533,11 @@ void TFE_DeleteCancellationManager(
|
||||
void TFE_OpSetCancellationManager(TFE_Op* op,
|
||||
TFE_CancellationManager* cancellation_manager,
|
||||
TF_Status* status) {
|
||||
status->status = op->operation->SetCancellationManager(cancellation_manager);
|
||||
tensorflow::EagerOperation* operation =
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(op));
|
||||
operation->SetCancellationManager(
|
||||
&cancellation_manager->cancellation_manager);
|
||||
status->status = tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
TFE_Executor* TFE_NewExecutor(bool is_async) {
|
||||
@ -549,19 +561,19 @@ void TFE_ExecutorClearError(TFE_Executor* executor) {
|
||||
|
||||
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
context->SetExecutorForThread(executor->executor());
|
||||
}
|
||||
|
||||
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
return new TFE_Executor(&context->Executor());
|
||||
}
|
||||
|
||||
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
|
||||
context->HostCPU()->parsed_name());
|
||||
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
|
||||
@ -574,16 +586,10 @@ void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||
};
|
||||
}
|
||||
|
||||
void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
|
||||
TF_Status* status) {
|
||||
h->handle->EnableImplicitMirroring();
|
||||
status->status = tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||
TF_Buffer* buf, TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(ctx->context);
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
auto* function_def = context->FindFunctionDef(function_name);
|
||||
if (function_def == nullptr) {
|
||||
status->status = tensorflow::errors::NotFound(
|
||||
@ -600,3 +606,35 @@ void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||
};
|
||||
status->status = tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
|
||||
const int64_t* dims, int num_dims,
|
||||
TF_Status* status) {
|
||||
std::vector<tensorflow::int64> dimvec(num_dims);
|
||||
for (int i = 0; i < num_dims; ++i) {
|
||||
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
|
||||
}
|
||||
|
||||
if (ctx == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid Context");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::AbstractTensorInterface* t =
|
||||
tensorflow::unwrap(ctx)->CreateTensor(
|
||||
static_cast<tensorflow::DataType>(dtype), dimvec);
|
||||
|
||||
if (t == nullptr) {
|
||||
status->status =
|
||||
tensorflow::errors::InvalidArgument("Unsupported dtype: ", dtype);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return new TF_Tensor{t};
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
|
||||
TF_Status* status) {
|
||||
return tensorflow::wrap(
|
||||
tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
|
||||
}
|
||||
|
@ -392,12 +392,6 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
|
||||
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
|
||||
TF_Status* status);
|
||||
|
||||
// If the TensorHandle is copied to another device as part of an op execution,
|
||||
// the copy is destroyed after the op has executed. Enabling implicit mirroring
|
||||
// causes the copy to be held as a mirror for the lifetime of the TensorHandle.
|
||||
TF_CAPI_EXPORT extern void TFE_TensorHandleEnableImplicitMirroring(
|
||||
TFE_TensorHandle*, TF_Status*);
|
||||
|
||||
// This function will block till the operation that produces `h` has
|
||||
// completed. This is only valid on local TFE_TensorHandles. The pointer
|
||||
// returned will be on the device in which the TFE_TensorHandle resides (so e.g.
|
||||
@ -437,11 +431,6 @@ TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
|
||||
// A reference to an op's name -> attribute mapping
|
||||
typedef struct TFE_OpAttrs TFE_OpAttrs;
|
||||
|
||||
// Fetch a struct with a reference to information about attributes of `op`.
|
||||
//
|
||||
// The `attrs` struct does not own any memory, and `op` must outlive it.
|
||||
TF_CAPI_EXPORT extern void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs);
|
||||
|
||||
// Add attributes in `attrs` to `op`.
|
||||
//
|
||||
// Does not overwrite or update existing attributes, but adds new ones.
|
||||
@ -521,15 +510,34 @@ typedef struct TFE_CustomDevice {
|
||||
// This API is highly experimental, and in particular is expected to change when
|
||||
// it starts supporting operations with attributes and when tf.function support
|
||||
// is added.
|
||||
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
|
||||
const char* device_name, void* device_info,
|
||||
TF_Status* status);
|
||||
TF_CAPI_EXPORT extern void TFE_RegisterCustomDevice(TFE_Context* ctx,
|
||||
TFE_CustomDevice device,
|
||||
const char* device_name,
|
||||
void* device_info,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
|
||||
const char* function_name,
|
||||
TF_Buffer* buf,
|
||||
TF_Status* status);
|
||||
|
||||
// Allocate and return a new Tensor on the host.
|
||||
//
|
||||
// The caller must set the Tensor values by writing them to the pointer returned
|
||||
// by TF_TensorData with length TF_TensorByteSize.
|
||||
TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx,
|
||||
TF_DataType dtype,
|
||||
const int64_t* dims,
|
||||
int num_dims,
|
||||
TF_Status* status);
|
||||
|
||||
// Given a Tensor, wrap it with a TensorHandle
|
||||
//
|
||||
// Similar to TFE_NewTensorHandle, but includes a pointer to the TFE_Context.
|
||||
// The context should be identical to that of the Tensor.
|
||||
TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor(
|
||||
TFE_Context* ctx, TF_Tensor* t, TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -21,9 +21,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/cc/profiler/profiler.h"
|
||||
#include "tensorflow/core/lib/monitoring/collection_registry.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/str_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
@ -378,7 +378,7 @@ void Executor_MatMul_CPU(bool async) {
|
||||
TFE_Executor* executor = TFE_NewExecutor(async);
|
||||
TFE_ContextSetExecutorForThread(ctx, executor);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
|
||||
int num_retvals = 2;
|
||||
@ -423,7 +423,7 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TF_Tensor* m_data = TFE_TensorHandleResolve(m, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
float* m_float = static_cast<float*>(TF_TensorData(m_data));
|
||||
|
@ -15,42 +15,17 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_executor_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_monitoring_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_op_attrs_internal.h" // IWYU pragma: export
|
||||
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h" // IWYU pragma: export
|
||||
|
||||
// TODO(b/154564140): Move this to its own header. This requires splitting
|
||||
// c_api_experimental.h
|
||||
struct TFE_ContextOptions {
|
||||
TF_SessionOptions session_options;
|
||||
// true if async execution is enabled.
|
||||
@ -64,181 +39,4 @@ struct TFE_ContextOptions {
|
||||
bool use_tfrt = false;
|
||||
};
|
||||
|
||||
struct TFE_Context {
|
||||
std::unique_ptr<tensorflow::AbstractContextInterface> context;
|
||||
};
|
||||
|
||||
struct TFE_TensorHandle {
|
||||
std::unique_ptr<tensorflow::AbstractTensorHandleInterface> handle;
|
||||
};
|
||||
|
||||
struct TFE_TensorDebugInfo {
|
||||
explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
|
||||
: dev_dims(dims) {}
|
||||
|
||||
// Fully-padded, minor-to-major.
|
||||
std::vector<tensorflow::int64> dev_dims;
|
||||
};
|
||||
|
||||
struct TFE_Op {
|
||||
std::unique_ptr<tensorflow::AbstractOperationInterface> operation;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringCounterCell {
|
||||
tensorflow::monitoring::CounterCell cell;
|
||||
};
|
||||
|
||||
template <int NumLabels>
|
||||
struct TFE_MonitoringCounter {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringCounter(const char* name, const char* description,
|
||||
LabelDesc&&... label) {
|
||||
counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
|
||||
name, description, label...));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringIntGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
|
||||
};
|
||||
struct TFE_MonitoringStringGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
|
||||
};
|
||||
struct TFE_MonitoringBoolGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<bool> cell;
|
||||
};
|
||||
|
||||
template <typename ValueType, int NumLabels>
|
||||
struct TFE_MonitoringGauge {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringGauge(const char* name, const char* description,
|
||||
LabelDesc&&... label) {
|
||||
gauge = absl::WrapUnique(
|
||||
tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
|
||||
name, description, label...));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringBuckets {
|
||||
explicit TFE_MonitoringBuckets(
|
||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||
fn) {
|
||||
create_buckets = fn;
|
||||
}
|
||||
|
||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||
create_buckets;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringSamplerCell {
|
||||
tensorflow::monitoring::SamplerCell cell;
|
||||
};
|
||||
|
||||
template <int NumLabels>
|
||||
struct TFE_MonitoringSampler {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringSampler(
|
||||
const char* name,
|
||||
std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
|
||||
const char* description, LabelDesc&&... label) {
|
||||
sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
|
||||
{name, description, label...}, std::move(buckets)));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
// Set an AttrValue on the op. Doesn't handle the list types.
|
||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
const tensorflow::AttrValue& default_value,
|
||||
const char* attr_name, TF_Status* status);
|
||||
} // namespace tensorflow
|
||||
|
||||
struct TFE_CancellationManager {
|
||||
tensorflow::CancellationManager cancellation_manager;
|
||||
};
|
||||
|
||||
struct TFE_Executor {
|
||||
explicit TFE_Executor(bool async)
|
||||
: owned_executor(new tensorflow::EagerExecutor(async)) {}
|
||||
|
||||
explicit TFE_Executor(tensorflow::EagerExecutor* executor)
|
||||
: owned_executor(nullptr), unowned_executor(executor) {}
|
||||
|
||||
tensorflow::EagerExecutor* executor() {
|
||||
return owned_executor == nullptr ? unowned_executor : owned_executor.get();
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
|
||||
tensorflow::EagerExecutor* unowned_executor;
|
||||
};
|
||||
|
||||
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
|
||||
// that sometimes do not require serialization.
|
||||
struct TFE_OpAttrs {
|
||||
explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
|
||||
|
||||
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
|
||||
const char* op_name)
|
||||
: name(op_name), attributes(value) {}
|
||||
|
||||
const char* name;
|
||||
const tensorflow::AttrBuilder* attributes;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
@ -74,8 +76,8 @@ void TestRemoteExecute(bool async) {
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
auto* h0_task1 =
|
||||
@ -128,7 +130,45 @@ void TestRemoteExecute(bool async) {
|
||||
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
|
||||
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
|
||||
|
||||
void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
string MatMulFunction() {
|
||||
tensorflow::FunctionDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
" signature {"
|
||||
" name: 'MatMulFunction'"
|
||||
" input_arg {"
|
||||
" name: 'a'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" input_arg {"
|
||||
" name: 'b'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" output_arg {"
|
||||
" name: 'm'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'matmul'"
|
||||
" op: 'MatMul'"
|
||||
" input: 'a'"
|
||||
" input: 'b'"
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" ret {"
|
||||
" key: 'm'"
|
||||
" value: 'matmul:product'"
|
||||
" }",
|
||||
&def));
|
||||
return def.SerializeAsString();
|
||||
}
|
||||
|
||||
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
@ -159,23 +199,46 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
|
||||
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
auto* h1_task2 =
|
||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandleEnableImplicitMirroring(h1_task2, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Handles are on task0 (local), and task2, but op is on task1.
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
|
||||
TFE_Op* matmul = nullptr;
|
||||
if (func) {
|
||||
string function_def = MatMulFunction();
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, h0_task0, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, h1_task2, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
} else {
|
||||
// Handles are on task0 (local), and task2, but op is on task1.
|
||||
matmul = MatMulOp(ctx, h0_task0, h1_task2);
|
||||
}
|
||||
if (remote) {
|
||||
TFE_OpSetDevice(matmul, task1_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
} else if (!async) {
|
||||
// Set the local device to CPU to easily validate mirroring
|
||||
string cpu_device_name;
|
||||
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
|
||||
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
auto remote_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
|
||||
}
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
@ -183,12 +246,11 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// TODO(gjn): Add support for waiting on async local mirrors
|
||||
if (!async) {
|
||||
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
matmul->operation.get());
|
||||
if (!remote && !async) {
|
||||
auto remote_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_EQ(op->GetInput(1), remote_arg);
|
||||
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
|
||||
}
|
||||
|
||||
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
|
||||
@ -218,6 +280,9 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
if (func) {
|
||||
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
|
||||
}
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
@ -228,16 +293,23 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteSilentCopies) {
|
||||
TestRemoteExecuteSilentCopies(false, true);
|
||||
TestRemoteExecuteSilentCopies(false, true, false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
|
||||
TestRemoteExecuteSilentCopies(true, true);
|
||||
TestRemoteExecuteSilentCopies(true, true, false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
|
||||
TestRemoteExecuteSilentCopies(true, true, true);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
|
||||
TestRemoteExecuteSilentCopies(false, false);
|
||||
TestRemoteExecuteSilentCopies(false, false, false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
|
||||
TestRemoteExecuteSilentCopies(true, false);
|
||||
TestRemoteExecuteSilentCopies(true, false, false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
|
||||
// TODO(b/155493048): skip test due to flakiness
|
||||
// TestRemoteExecuteSilentCopies(true, false, true);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
@ -268,8 +340,8 @@ void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||
|
||||
// Use large matrices so that RPCs don't return before we get a chance
|
||||
// to call TFE_DeleteContext.
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100();
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100();
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
auto* h0_task1 =
|
||||
@ -310,150 +382,4 @@ TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
|
||||
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
|
||||
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
|
||||
}
|
||||
|
||||
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
|
||||
const std::vector<float>& expected_values) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
|
||||
EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
|
||||
memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
|
||||
for (int i = 0; i < expected_values.size(); i++) {
|
||||
EXPECT_EQ(expected_values[i], actual_values[i])
|
||||
<< "Mismatch in expected values at (zero-based) index " << i;
|
||||
}
|
||||
}
|
||||
|
||||
void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
|
||||
const char* remote_device_name,
|
||||
const char* local_device_name) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
|
||||
TFE_OpSetDevice(matmul, remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
auto* retval_task0 =
|
||||
TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
|
||||
|
||||
TFE_DeleteTensorHandle(retval_task0);
|
||||
TFE_DeleteTensorHandle(h0_task0);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
void TestRemoteExecuteChangeServerDef(bool async) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
const char remote_device_name[] =
|
||||
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char local_device_name[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
|
||||
// Update the server def with a new set of names (worker instead of
|
||||
// localhost).
|
||||
tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
|
||||
serialized = updated_server_def.SerializeAsString();
|
||||
|
||||
updated_server_def.set_task_index(1);
|
||||
tensorflow::Status s = tensorflow::GrpcServer::Create(
|
||||
updated_server_def, tensorflow::Env::Default(), &worker_server);
|
||||
ASSERT_TRUE(s.ok()) << s.error_message();
|
||||
ASSERT_TRUE(worker_server->Start().ok());
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Create a new tensor_handle.
|
||||
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle();
|
||||
|
||||
// Check that copying it to the old remote device (named localhost) fails.
|
||||
TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
|
||||
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
// Copying and executing on the new remote device works.
|
||||
const char new_remote_device_name[] =
|
||||
"/job:worker/replica:0/task:1/device:CPU:0";
|
||||
const char new_local_device_name[] =
|
||||
"/job:worker/replica:0/task:0/device:CPU:0";
|
||||
|
||||
auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
|
||||
h0_task0_new, ctx, new_remote_device_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_DeleteTensorHandle(h0_task0_new);
|
||||
TFE_DeleteTensorHandle(h0_task1_new);
|
||||
|
||||
CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
|
||||
new_local_device_name);
|
||||
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server.release();
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteChangeServerDef) {
|
||||
TestRemoteExecuteChangeServerDef(false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
|
||||
TestRemoteExecuteChangeServerDef(true);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -19,16 +19,24 @@ limitations under the License.
|
||||
|
||||
#include <string>
|
||||
|
||||
// clang-format off
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
@ -47,7 +55,7 @@ void BM_InitOp(int iters) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
@ -71,12 +79,19 @@ void BM_Execute(int iters, int async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
TFE_OpReset(matmul, "MatMul", nullptr, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, m, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, m, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
@ -106,12 +121,16 @@ void BM_Execute_Identity(int iters, int async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_Op* identity = IdentityOp(ctx, m);
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* identity = TFE_NewOp(ctx, "Identity", status);
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
TFE_OpReset(identity, "Identity", nullptr, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(identity, m, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_Execute(identity, &retvals[0], &num_retvals, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
@ -153,11 +172,16 @@ TEST(CAPI, Context) {
|
||||
}
|
||||
|
||||
TEST(CAPI, TensorHandle) {
|
||||
TFE_TensorHandle* h = TestMatrixTensorHandle();
|
||||
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
|
||||
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status.get());
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* h = TestMatrixTensorHandle(ctx);
|
||||
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
|
||||
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
|
||||
ASSERT_EQ(16, TF_TensorByteSize(t));
|
||||
float data[4] = {0};
|
||||
@ -168,6 +192,7 @@ TEST(CAPI, TensorHandle) {
|
||||
EXPECT_EQ(4.0, data[3]);
|
||||
TF_DeleteTensor(t);
|
||||
TFE_DeleteTensorHandle(h);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
void TensorHandleCopyBetweenDevices(bool async) {
|
||||
@ -179,7 +204,7 @@ void TensorHandleCopyBetweenDevices(bool async) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
@ -255,7 +280,7 @@ void TensorHandleCopyBetweenDevicesError(bool async) {
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
|
||||
const char* kErrorDevice = "NoSuchDevice:0";
|
||||
TFE_TensorHandle* hdevice =
|
||||
TFE_TensorHandleCopyToDevice(hcpu, ctx, kErrorDevice, status.get());
|
||||
@ -296,7 +321,7 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
@ -382,7 +407,7 @@ void TensorHandleSilentCopy(bool async,
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
|
||||
@ -393,6 +418,13 @@ void TensorHandleSilentCopy(bool async,
|
||||
hcpu, ctx, gpu_device_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
|
||||
auto cpu_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hcpu));
|
||||
auto gpu_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hgpu));
|
||||
auto gpu_device = absl::get<tensorflow::Device*>(gpu_arg->device());
|
||||
ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));
|
||||
|
||||
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
|
||||
if (cpu_op) {
|
||||
string cpu_device_name;
|
||||
@ -408,15 +440,8 @@ void TensorHandleSilentCopy(bool async,
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
|
||||
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Validate if the input was replaced with a different TensorHandle
|
||||
auto arg0 = tensorflow::TensorHandleFromInterface(hcpu->handle);
|
||||
auto arg1 = tensorflow::TensorHandleFromInterface(hgpu->handle);
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
matmul->operation.get());
|
||||
|
||||
// The input handles should never change since they have been mirrored.
|
||||
EXPECT_EQ(op->GetInput(0), arg0);
|
||||
EXPECT_EQ(op->GetInput(1), arg1);
|
||||
// The CPU handle should have been copied and have a mirror on the GPU
|
||||
ASSERT_TRUE(cpu_arg->HasLocalMirror(gpu_device));
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
@ -455,7 +480,7 @@ void SetAndGetOpDevices(bool async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
|
||||
// Disable the test if no GPU is present.
|
||||
@ -487,40 +512,35 @@ TEST(CAPI, TensorHandleNullptr) {
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(h, status.get());
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
ASSERT_EQ(t, nullptr);
|
||||
ASSERT_EQ("The passed in handle is a nullptr",
|
||||
string(TF_Message(status.get())));
|
||||
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
|
||||
|
||||
TF_SetStatus(status.get(), TF_OK, "");
|
||||
|
||||
const char* device_name = TFE_TensorHandleDeviceName(h, status.get());
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
ASSERT_EQ(device_name, nullptr);
|
||||
ASSERT_EQ("The passed in handle is a nullptr",
|
||||
string(TF_Message(status.get())));
|
||||
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
|
||||
|
||||
TF_SetStatus(status.get(), TF_OK, "");
|
||||
|
||||
device_name = TFE_TensorHandleBackingDeviceName(h, status.get());
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
ASSERT_EQ(device_name, nullptr);
|
||||
ASSERT_EQ("The passed in handle is a nullptr",
|
||||
string(TF_Message(status.get())));
|
||||
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
|
||||
|
||||
TF_SetStatus(status.get(), TF_OK, "");
|
||||
|
||||
int num_dims = TFE_TensorHandleNumDims(h, status.get());
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
ASSERT_EQ(num_dims, -1);
|
||||
ASSERT_EQ("The passed in handle is a nullptr",
|
||||
string(TF_Message(status.get())));
|
||||
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
|
||||
|
||||
TF_SetStatus(status.get(), TF_OK, "");
|
||||
|
||||
int dim = TFE_TensorHandleDim(h, 0, status.get());
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
ASSERT_EQ(dim, -1);
|
||||
ASSERT_EQ("The passed in handle is a nullptr",
|
||||
string(TF_Message(status.get())));
|
||||
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
|
||||
}
|
||||
|
||||
TEST(CAPI, TensorHandleDevices) {
|
||||
@ -531,7 +551,7 @@ TEST(CAPI, TensorHandleDevices) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
|
||||
const char* device_name = TFE_TensorHandleDeviceName(hcpu, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_TRUE(absl::StrContains(device_name, "CPU:0")) << device_name;
|
||||
@ -581,15 +601,16 @@ TEST(CAPI, TensorHandleDevices) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
void ExecuteAdd(bool async, bool forward_input) {
|
||||
void ExecuteAdd(bool async, bool forward_input, bool tfrt) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetTfrt(opts, tfrt);
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* n = TestMatrixTensorHandle100x100();
|
||||
TFE_TensorHandle* n = TestMatrixTensorHandle100x100(ctx);
|
||||
// If a GPU exists, copy the handle to GPU so that we can exercise
|
||||
// unprotecting a mirror.
|
||||
std::string gpu_device_name;
|
||||
@ -597,12 +618,11 @@ void ExecuteAdd(bool async, bool forward_input) {
|
||||
TFE_TensorHandle* n_gpu =
|
||||
TFE_TensorHandleCopyToDevice(n, ctx, gpu_device_name.c_str(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandleEnableImplicitMirroring(n_gpu, status);
|
||||
TFE_DeleteTensorHandle(n);
|
||||
n = n_gpu;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle100x100();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle100x100(ctx);
|
||||
|
||||
// Store pointer to raw buffer for validation of forwarding behaviour.
|
||||
TF_Tensor* orig = TFE_TensorHandleResolve(n, status);
|
||||
@ -619,17 +639,6 @@ void ExecuteAdd(bool async, bool forward_input) {
|
||||
}
|
||||
|
||||
int num_retvals = 1;
|
||||
|
||||
if (async) {
|
||||
// Enqueue dummy ops so we backlog async execution & actually test async.
|
||||
for (int i = 0; i < 10000; ++i) {
|
||||
TFE_TensorHandle* dummy = nullptr;
|
||||
TFE_Execute(add_op, &dummy, &num_retvals, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(dummy);
|
||||
}
|
||||
}
|
||||
|
||||
TFE_TensorHandle* retval = nullptr;
|
||||
TFE_Execute(add_op, &retval, &num_retvals, status);
|
||||
EXPECT_EQ(1, num_retvals);
|
||||
@ -649,7 +658,6 @@ void ExecuteAdd(bool async, bool forward_input) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(m);
|
||||
TFE_DeleteTensorHandle(retval);
|
||||
TFE_DeleteContext(ctx);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
float result[100 * 100] = {0};
|
||||
@ -659,12 +667,42 @@ void ExecuteAdd(bool async, bool forward_input) {
|
||||
for (int i = 0; i < 100 * 100; ++i) {
|
||||
EXPECT_EQ(2.0f, result[i]);
|
||||
}
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
TEST(CAPI, ExecuteAdd) { ExecuteAdd(false, false); }
|
||||
TEST(CAPI, ExecuteAddAsync) { ExecuteAdd(true, false); }
|
||||
TEST(CAPI, ExecuteAddForward) { ExecuteAdd(false, true); }
|
||||
TEST(CAPI, ExecuteAddForwardAsync) { ExecuteAdd(true, true); }
|
||||
TEST(CAPI, ExecuteAdd) {
|
||||
ExecuteAdd(
|
||||
/*async=*/false,
|
||||
/*forward_input*/ false,
|
||||
/*tfrt*/ false);
|
||||
}
|
||||
TEST(CAPI, ExecuteAddAsync) {
|
||||
ExecuteAdd(
|
||||
/*async=*/true,
|
||||
/*forward_input*/ false,
|
||||
/*tfrt*/ false);
|
||||
}
|
||||
TEST(CAPI, ExecuteAddForward) {
|
||||
ExecuteAdd(
|
||||
/*async=*/false,
|
||||
/*forward_input*/ true,
|
||||
/*tfrt*/ false);
|
||||
}
|
||||
TEST(CAPI, ExecuteAddForwardAsync) {
|
||||
ExecuteAdd(
|
||||
/*async=*/true,
|
||||
/*forward_input*/ true,
|
||||
/*tfrt*/ false);
|
||||
}
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
// TODO(b/153349425): Add add forwarding tests for TFRT
|
||||
TEST(CAPI, ExecuteAddTfrt) {
|
||||
ExecuteAdd(
|
||||
/*async=*/false,
|
||||
/*forward_input*/ false,
|
||||
/*tfrt*/ true);
|
||||
}
|
||||
#endif
|
||||
|
||||
void Execute_MatMul_CPU(bool async) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
@ -674,7 +712,7 @@ void Execute_MatMul_CPU(bool async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
|
||||
int num_retvals = 2;
|
||||
@ -710,8 +748,8 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m1 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2();
|
||||
TFE_TensorHandle* m1 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m1, m2);
|
||||
TFE_OpSetDevice(matmul, "/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
status);
|
||||
@ -782,8 +820,8 @@ void Execute_MatMul_CPU_Type_Error(bool async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m1 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m1 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m1, m2);
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
@ -812,8 +850,8 @@ TEST(CAPI, Execute_Min_CPU) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle();
|
||||
TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
|
||||
TFE_Op* minOp = MinOp(ctx, input, axis);
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
@ -847,7 +885,7 @@ void Execute_MatMul_XLA_CPU(bool async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
|
||||
TFE_OpSetXLACompilation(matmul, true);
|
||||
@ -889,8 +927,8 @@ void Execute_Min_XLA_CPU(bool async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle();
|
||||
TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
|
||||
TFE_Op* minOp = MinOp(ctx, input, axis);
|
||||
|
||||
TFE_OpSetXLACompilation(minOp, true);
|
||||
@ -930,7 +968,7 @@ void ExecuteWithTracing(bool async) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
@ -1016,7 +1054,7 @@ void FunctionDefAndExecute(bool async) {
|
||||
if (clear_cache) {
|
||||
TFE_ContextClearCaches(ctx);
|
||||
}
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* retval[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
TFE_Op* op = TFE_NewOp(ctx, "MatMulFunction", status);
|
||||
@ -1065,7 +1103,7 @@ void BM_ExecuteFunction(int iters, int async) {
|
||||
status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, m, status);
|
||||
@ -1280,11 +1318,15 @@ TEST(CAPI, StringAttributes) {
|
||||
}
|
||||
|
||||
TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
|
||||
TFE_TensorHandle* h = TestMatrixTensorHandle();
|
||||
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
|
||||
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status.get());
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* h = TestMatrixTensorHandle(ctx);
|
||||
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
|
||||
|
||||
TFE_TensorHandle* h_shares_tensor =
|
||||
TFE_TensorHandleCopySharingTensor(h, status.get());
|
||||
@ -1302,13 +1344,14 @@ TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
|
||||
|
||||
TFE_DeleteTensorHandle(h);
|
||||
TFE_DeleteTensorHandle(h_shares_tensor);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
tensorflow::AttrValueMap ExtractAttrs(TFE_Op* op) {
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
tensorflow::down_cast<tensorflow::OperationInterface*>(op->operation.get())
|
||||
->Attrs()
|
||||
.FillAttrValueMap(&attr_values);
|
||||
tensorflow::EagerOperation* operation =
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(op));
|
||||
operation->Attrs().FillAttrValueMap(&attr_values);
|
||||
return attr_values;
|
||||
}
|
||||
|
||||
@ -1319,8 +1362,8 @@ TEST(CAPI, TestTFE_OpInferSingleInputAttrs) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle();
|
||||
TFE_TensorHandle* input = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* axis = TestAxisTensorHandle(ctx);
|
||||
TFE_Op* minOp = TFE_NewOp(ctx, "Min", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(minOp, input, status);
|
||||
@ -1356,9 +1399,9 @@ TEST(CAPI, TestTFE_OpInferSingleTypeInputListAttrs) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* dim = TestScalarTensorHandle(0);
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* dim = TestScalarTensorHandle(ctx, 0);
|
||||
TFE_Op* concatOp = TFE_NewOp(ctx, "Concat", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||
@ -1396,9 +1439,9 @@ TEST(CAPI, TestTFE_OpInferMixedTypeInputListAttrs) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* condition = TestScalarTensorHandle(true);
|
||||
TFE_TensorHandle* t1 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* t2 = TestAxisTensorHandle();
|
||||
TFE_TensorHandle* condition = TestScalarTensorHandle(ctx, true);
|
||||
TFE_TensorHandle* t1 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* t2 = TestAxisTensorHandle(ctx);
|
||||
TFE_Op* assertOp = TFE_NewOp(ctx, "Assert", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpAddInput(assertOp, condition, status);
|
||||
@ -1435,18 +1478,18 @@ TEST(CAPI, TestTFE_OpAttrsInferenceDisabledWhenNotCallingOpAddInputList) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* dim = TestScalarTensorHandle(0);
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* dim = TestScalarTensorHandle(ctx, 0);
|
||||
TFE_Op* concatOp = TFE_NewOp(ctx, "Concat", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||
TFE_OpAddInput(concatOp, dim, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
CHECK(concatOp->operation->OpDef());
|
||||
CHECK(tensorflow::unwrap(concatOp)->OpDef());
|
||||
TFE_OpAddInput(concatOp, inputs[0], status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
EXPECT_FALSE(concatOp->operation->OpDef())
|
||||
EXPECT_FALSE(tensorflow::unwrap(concatOp)->OpDef())
|
||||
<< "Inference context is still present";
|
||||
TFE_OpAddInput(concatOp, inputs[1], status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
@ -1470,8 +1513,8 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengths) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
@ -1518,8 +1561,8 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* input1 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* input2 = TestMatrixTensorHandle(ctx);
|
||||
TFE_Op* identityOp = TFE_NewOp(ctx, "IdentityN", status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_TensorHandle* inputs[] = {input1, input2};
|
||||
@ -1538,7 +1581,7 @@ TEST(CAPI, TestTFE_OpGetInputAndOutputLengthsFailForUnknownArguments) {
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST(CAPI, TestTFE_OpGetAttrs) {
|
||||
TEST(CAPI, TestTFE_OpAddAttrs) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
@ -1548,8 +1591,11 @@ TEST(CAPI, TestTFE_OpGetAttrs) {
|
||||
TFE_Op* var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||
TFE_OpAttrs attributes;
|
||||
TFE_OpGetAttrs(var_op, &attributes);
|
||||
// There is currently no API to fetch attributes from an operation, fetching
|
||||
// happens only as an implementation detail of custom devices.
|
||||
tensorflow::EagerOperation* operation =
|
||||
OperationFromInterface(tensorflow::unwrap(var_op));
|
||||
TFE_OpAttrs attributes{&operation->Attrs()};
|
||||
|
||||
TFE_Op* copy_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
TFE_OpSetAttrType(copy_op, "dtype", TF_FLOAT);
|
||||
@ -1563,8 +1609,8 @@ TEST(CAPI, TestTFE_OpGetAttrs) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
copy_op->operation.get());
|
||||
tensorflow::EagerOperation* op =
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(copy_op));
|
||||
op->Attrs().FillAttrValueMap(&attr_values);
|
||||
EXPECT_EQ(tensorflow::DT_FLOAT, attr_values.find("dtype")->second.type());
|
||||
|
||||
@ -1585,8 +1631,11 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_OpSetAttrType(var_op, "dtype", TF_INT64);
|
||||
TFE_OpSetAttrShape(var_op, "shape", {}, 0, status);
|
||||
TFE_OpAttrs attributes;
|
||||
TFE_OpGetAttrs(var_op, &attributes);
|
||||
// There is currently no API to fetch attributes from an operation, fetching
|
||||
// happens only as an implementation detail of custom devices.
|
||||
tensorflow::EagerOperation* operation =
|
||||
OperationFromInterface(tensorflow::unwrap(var_op));
|
||||
TFE_OpAttrs attributes{&operation->Attrs()};
|
||||
|
||||
TF_Buffer* serialized_attr_values = TF_NewBuffer();
|
||||
TFE_OpAttrsSerialize(&attributes, serialized_attr_values, status);
|
||||
@ -1599,26 +1648,26 @@ TEST(CAPI, TestTFE_OpAttrsSerialize) {
|
||||
name_and_attrs.attr().find("dtype")->second.type());
|
||||
TF_DeleteBuffer(serialized_attr_values);
|
||||
|
||||
TFE_Op* second_var_op = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
TFE_Op* var_op_2 = TFE_NewOp(ctx, "VarHandleOp", status);
|
||||
|
||||
string serialized_dtype;
|
||||
ASSERT_TRUE(name_and_attrs.attr().find("dtype")->second.SerializeToString(
|
||||
&serialized_dtype));
|
||||
TFE_OpSetAttrValueProto(
|
||||
second_var_op, "dtype",
|
||||
var_op_2, "dtype",
|
||||
reinterpret_cast<const void*>(serialized_dtype.c_str()),
|
||||
serialized_dtype.length(), status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
tensorflow::AttrValueMap attr_values;
|
||||
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
|
||||
second_var_op->operation.get());
|
||||
tensorflow::EagerOperation* op =
|
||||
tensorflow::OperationFromInterface(tensorflow::unwrap(var_op_2));
|
||||
op->Attrs().FillAttrValueMap(&attr_values);
|
||||
EXPECT_EQ(tensorflow::DT_INT64, attr_values.find("dtype")->second.type());
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteOp(var_op);
|
||||
TFE_DeleteOp(second_var_op);
|
||||
TFE_DeleteOp(var_op_2);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
|
@ -16,115 +16,117 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
TFE_TensorHandle* TestScalarTensorHandle(float value) {
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
|
||||
float data[] = {value};
|
||||
TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestScalarTensorHandle(int value) {
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value) {
|
||||
int data[] = {value};
|
||||
TF_Tensor* t = TF_AllocateTensor(TF_INT32, nullptr, 0, sizeof(int));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_INT32, nullptr, 0, status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestScalarTensorHandle(bool value) {
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value) {
|
||||
bool data[] = {value};
|
||||
TF_Tensor* t = TF_AllocateTensor(TF_BOOL, nullptr, 0, sizeof(bool));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_BOOL, nullptr, 0, status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle() {
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx) {
|
||||
int64_t dims[] = {2, 2};
|
||||
double data[] = {1.0, 2.0, 3.0, 4.0};
|
||||
TF_Tensor* t = TF_AllocateTensor(
|
||||
TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_DOUBLE, &dims[0],
|
||||
sizeof(dims) / sizeof(int64_t), status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestMatrixTensorHandle() {
|
||||
TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx) {
|
||||
int64_t dims[] = {2, 2};
|
||||
float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
|
||||
TF_Tensor* t = TF_AllocateTensor(
|
||||
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
|
||||
sizeof(dims) / sizeof(int64_t), status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestMatrixTensorHandle100x100() {
|
||||
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) {
|
||||
constexpr int64_t dims[] = {100, 100};
|
||||
constexpr int num_elements = dims[0] * dims[1];
|
||||
float data[num_elements];
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
data[i] = 1.0f;
|
||||
}
|
||||
TF_Tensor* t = TF_AllocateTensor(
|
||||
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
|
||||
sizeof(dims) / sizeof(int64_t), status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() {
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx) {
|
||||
int64_t dims[] = {3, 2};
|
||||
double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
|
||||
TF_Tensor* t = TF_AllocateTensor(
|
||||
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
|
||||
sizeof(dims) / sizeof(int64_t), status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestMatrixTensorHandle3X2() {
|
||||
TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) {
|
||||
int64_t dims[] = {3, 2};
|
||||
float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
|
||||
TF_Tensor* t = TF_AllocateTensor(
|
||||
TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
|
||||
sizeof(dims) / sizeof(int64_t), status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
@ -187,14 +189,14 @@ TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {
|
||||
return op;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TestAxisTensorHandle() {
|
||||
TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx) {
|
||||
int64_t dims[] = {1};
|
||||
int data[] = {1};
|
||||
TF_Tensor* t = TF_AllocateTensor(
|
||||
TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data));
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandle(t, status);
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0],
|
||||
sizeof(dims) / sizeof(int64_t), status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
|
@ -19,28 +19,28 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
// Return a tensor handle containing a float scalar
|
||||
TFE_TensorHandle* TestScalarTensorHandle(float value);
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value);
|
||||
|
||||
// Return a tensor handle containing a int scalar
|
||||
TFE_TensorHandle* TestScalarTensorHandle(int value);
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value);
|
||||
|
||||
// Return a tensor handle containing a bool scalar
|
||||
TFE_TensorHandle* TestScalarTensorHandle(bool value);
|
||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value);
|
||||
|
||||
// Return a tensor handle containing a 2x2 matrix of doubles
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle();
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx);
|
||||
|
||||
// Return a tensor handle containing a 2x2 matrix of floats
|
||||
TFE_TensorHandle* TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx);
|
||||
|
||||
// Return a tensor handle containing a 100x100 matrix of floats
|
||||
TFE_TensorHandle* TestMatrixTensorHandle100x100();
|
||||
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx);
|
||||
|
||||
// Return a tensor handle containing a 3x2 matrix of doubles
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2();
|
||||
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx);
|
||||
|
||||
// Return a tensor handle containing a 3x2 matrix of floats
|
||||
TFE_TensorHandle* TestMatrixTensorHandle3X2();
|
||||
TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx);
|
||||
|
||||
// Return an add op multiplying `a` by `b`.
|
||||
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
|
||||
@ -55,7 +55,7 @@ TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a);
|
||||
TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a);
|
||||
|
||||
// Return an 1-D INT32 tensor containing a single value 1.
|
||||
TFE_TensorHandle* TestAxisTensorHandle();
|
||||
TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx);
|
||||
|
||||
// Return an op taking minimum of `input` long `axis` dimension.
|
||||
TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
|
||||
|
@ -15,247 +15,78 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using tensorflow::string;
|
||||
using tensorflow::internal::OutputList;
|
||||
using tensorflow::internal::unwrap;
|
||||
|
||||
// =============================================================================
|
||||
// Unified Execution APIs for Eager and tracing backends.
|
||||
// Public C API entry points
|
||||
//
|
||||
// These are only the generic entry points for the C API. This file does not
|
||||
// have any visibility into the graph/eager implementation and is only providing
|
||||
// C bindings to the abstract classes defined in the
|
||||
// c_api_unified_experimental_internal.h header.
|
||||
//
|
||||
// =============================================================================
|
||||
|
||||
typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs,
|
||||
TF_OutputList* o, TF_ExecutionContext* ctx,
|
||||
TF_Status* s);
|
||||
struct TF_ExecutionContext {
|
||||
explicit TF_ExecutionContext() {}
|
||||
absl::variant<TFE_Context*, TF_GraphContext*> ctx;
|
||||
ExecuteOperation execution_callback;
|
||||
};
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
|
||||
|
||||
struct TF_AbstractTensor {
|
||||
absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
|
||||
};
|
||||
|
||||
struct TF_AbstractOp {
|
||||
string op_type;
|
||||
string op_name;
|
||||
};
|
||||
|
||||
TF_ExecutionContext* TF_NewExecutionContext() {
|
||||
return new TF_ExecutionContext();
|
||||
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
|
||||
return wrap(unwrap(c)->CreateOperation());
|
||||
}
|
||||
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; }
|
||||
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete unwrap(op); }
|
||||
|
||||
TF_AbstractOp* TF_NewAbstractOp() {
|
||||
TF_AbstractOp* op = new TF_AbstractOp;
|
||||
return op;
|
||||
}
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete unwrap(t); }
|
||||
|
||||
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }
|
||||
|
||||
TF_AbstractTensor* TF_NewAbstractTensor() {
|
||||
TF_AbstractTensor* t = new TF_AbstractTensor;
|
||||
return t;
|
||||
}
|
||||
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }
|
||||
|
||||
struct TF_GraphContext {
|
||||
TF_Graph* graph;
|
||||
// TODO(srbs): Handle captures.
|
||||
};
|
||||
|
||||
TF_GraphContext* TF_NewGraphContext(TF_Graph* g) {
|
||||
auto ctx = new TF_GraphContext;
|
||||
ctx->graph = g;
|
||||
return ctx;
|
||||
}
|
||||
|
||||
void TF_DeleteGraphContext(TF_GraphContext* ctx) { delete ctx; }
|
||||
|
||||
struct TF_GraphTensor {
|
||||
TF_Output output;
|
||||
TF_GraphContext* ctx;
|
||||
};
|
||||
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* ctx, TF_Output output,
|
||||
TF_Status* s) {
|
||||
TF_GraphTensor* t = new TF_GraphTensor;
|
||||
t->output = output;
|
||||
t->ctx = ctx;
|
||||
return t;
|
||||
}
|
||||
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s) {
|
||||
return t->output;
|
||||
}
|
||||
void TF_DeleteGraphTensor(TF_GraphTensor* t) { delete t; }
|
||||
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
|
||||
TF_Status* s) {
|
||||
at->t = t;
|
||||
}
|
||||
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s) {
|
||||
if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
|
||||
string msg = absl::StrCat("Not an eager tensor handle.",
|
||||
reinterpret_cast<uintptr_t>(at));
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
return absl::get<TFE_TensorHandle*>(at->t);
|
||||
}
|
||||
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
|
||||
TF_Status* s) {
|
||||
at->t = t;
|
||||
}
|
||||
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s) {
|
||||
if (!absl::holds_alternative<TF_GraphTensor*>(at->t)) {
|
||||
string msg = absl::StrCat("Not an graph tensor handle.");
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
return absl::get<TF_GraphTensor*>(at->t);
|
||||
}
|
||||
|
||||
bool IsEagerTensor(const TF_AbstractTensor* const t) {
|
||||
return absl::holds_alternative<TFE_TensorHandle*>(t->t);
|
||||
}
|
||||
|
||||
struct TF_OutputList {
|
||||
std::vector<TF_AbstractTensor*> outputs;
|
||||
int expected_num_outputs = -1;
|
||||
};
|
||||
|
||||
TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
|
||||
void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
|
||||
TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
|
||||
void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
|
||||
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
|
||||
TF_Status* s) {
|
||||
o->expected_num_outputs = num_outputs;
|
||||
unwrap(o)->expected_num_outputs = num_outputs;
|
||||
}
|
||||
int TF_OutputListNumOutputs(TF_OutputList* o) {
|
||||
return unwrap(o)->outputs.size();
|
||||
}
|
||||
int TF_OutputListNumOutputs(TF_OutputList* o) { return o->outputs.size(); }
|
||||
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
|
||||
return o->outputs[i];
|
||||
}
|
||||
|
||||
void ExecuteOperationEager(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_ExecutionContext* ctx, TF_Status* s) {
|
||||
auto* tfe_op =
|
||||
TFE_NewOp(absl::get<TFE_Context*>(ctx->ctx), op->op_type.c_str(), s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
if (!IsEagerTensor(inputs[i])) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
|
||||
return;
|
||||
}
|
||||
TFE_OpAddInput(tfe_op, absl::get<TFE_TensorHandle*>(inputs[i]->t), s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
}
|
||||
if (o->expected_num_outputs == -1) {
|
||||
string msg =
|
||||
"The number of outputs must be provided in eager mode. Use "
|
||||
"TF_OutputListSetNumOutputs.";
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return;
|
||||
}
|
||||
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
|
||||
int num_retvals = o->expected_num_outputs;
|
||||
retvals.resize(num_retvals);
|
||||
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
|
||||
TFE_DeleteOp(tfe_op);
|
||||
if (TF_GetCode(s) != TF_OK) {
|
||||
return;
|
||||
}
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_retvals);
|
||||
for (int i = 0; i < num_retvals; ++i) {
|
||||
auto* t = TF_NewAbstractTensor();
|
||||
t->t = retvals[i];
|
||||
o->outputs.push_back(t);
|
||||
}
|
||||
}
|
||||
|
||||
TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
|
||||
return absl::get<TF_GraphTensor*>(t->t)->ctx;
|
||||
}
|
||||
|
||||
void ExecuteOperationGraph(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_ExecutionContext* ctx, TF_Status* s) {
|
||||
TF_GraphContext* graph_ctx = absl::get<TF_GraphContext*>(ctx->ctx);
|
||||
TF_Graph* g = graph_ctx->graph;
|
||||
auto* tf_opdesc =
|
||||
TF_NewOperation(g, op->op_type.c_str(), op->op_name.c_str());
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
auto* input = inputs[i];
|
||||
if (IsEagerTensor(input)) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Capturing eager tensors is not supported yet.");
|
||||
return;
|
||||
} else {
|
||||
if (GetGraphContext(input) != graph_ctx) {
|
||||
TF_SetStatus(
|
||||
s, TF_INVALID_ARGUMENT,
|
||||
"Capturing tensors from other graphs is not supported yet.");
|
||||
return;
|
||||
}
|
||||
TF_AddInput(tf_opdesc, absl::get<TF_GraphTensor*>(input->t)->output);
|
||||
}
|
||||
}
|
||||
auto* operation = TF_FinishOperation(tf_opdesc, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
int num_outputs = TF_OperationNumOutputs(operation);
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_outputs);
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
auto* t = TF_NewAbstractTensor();
|
||||
TF_GraphTensor* output_t = TF_NewGraphTensor(graph_ctx, {operation, i}, s);
|
||||
if (TF_GetCode(s) != TF_OK) {
|
||||
return;
|
||||
}
|
||||
t->t = output_t;
|
||||
o->outputs.push_back(t);
|
||||
}
|
||||
}
|
||||
|
||||
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
|
||||
TFE_Context* eager_context,
|
||||
TF_Status* s) {
|
||||
context->ctx = eager_context;
|
||||
context->execution_callback = &ExecuteOperationEager;
|
||||
}
|
||||
|
||||
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
|
||||
TF_GraphContext* graph_context,
|
||||
TF_Status* s) {
|
||||
context->ctx = graph_context;
|
||||
context->execution_callback = &ExecuteOperationGraph;
|
||||
return wrap(unwrap(o)->outputs[i]);
|
||||
}
|
||||
|
||||
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
|
||||
TF_Status* s) {
|
||||
op->op_type = op_type;
|
||||
unwrap(op)->SetOpType(op_type, s);
|
||||
}
|
||||
|
||||
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
|
||||
TF_Status* s) {
|
||||
op->op_name = op_name;
|
||||
unwrap(op)->SetOpName(op_name, s);
|
||||
}
|
||||
|
||||
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
|
||||
TF_DataType value, TF_Status* s) {
|
||||
unwrap(op)->SetAttrType(attr_name, value, s);
|
||||
}
|
||||
|
||||
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_ExecutionContext* ctx, TF_Status* s) {
|
||||
ctx->execution_callback(op, num_inputs, inputs, o, ctx, s);
|
||||
unwrap(ctx)->ExecuteOperation(unwrap(op), num_inputs, &unwrap(*inputs),
|
||||
unwrap(o), s);
|
||||
}
|
||||
|
||||
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
|
||||
delete unwrap(func);
|
||||
}
|
||||
|
||||
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
|
||||
TF_AbstractFunction* func,
|
||||
TF_Status* s) {
|
||||
unwrap(ctx)->RegisterFunction(unwrap(func), s);
|
||||
}
|
||||
|
@ -15,8 +15,9 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
|
||||
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
@ -34,39 +35,34 @@ extern "C" {
|
||||
// E.g. it could know whether we're in eager mode or in graph mode, keeps track
|
||||
// of gradient tapes, etc.
|
||||
typedef struct TF_ExecutionContext TF_ExecutionContext;
|
||||
|
||||
// A TF_AbstractTensor is an input to an operation. E.g. it could be a union
|
||||
// type of eager and graph tensors.
|
||||
// type of eager and graph tensors. It is also the result of executing an
|
||||
// operation.
|
||||
typedef struct TF_AbstractTensor TF_AbstractTensor;
|
||||
|
||||
// A TF_AbstractOp is the metadata we need to execute an operation. E.g. this
|
||||
// could contain the op type and other attributes.
|
||||
typedef struct TF_AbstractOp TF_AbstractOp;
|
||||
|
||||
TF_ExecutionContext* TF_NewExecutionContext();
|
||||
// Stores a function representation that can be used for execution or for
|
||||
// setting functional attributes of other composite ops e.g. control flow.
|
||||
typedef struct TF_AbstractFunction TF_AbstractFunction;
|
||||
|
||||
// Creates a context for tracing the execution of operations into a function.
|
||||
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s);
|
||||
|
||||
// Creates a context for eager execution of operations.
|
||||
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
|
||||
TF_Status* s);
|
||||
|
||||
void TF_DeleteExecutionContext(TF_ExecutionContext*);
|
||||
|
||||
TF_AbstractOp* TF_NewAbstractOp();
|
||||
// Create an operation suitable to use with the provided context. The operation
|
||||
// requires its type (e.g. "AddV2") to be set independently.
|
||||
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
|
||||
void TF_DeleteAbstractOp(TF_AbstractOp*);
|
||||
|
||||
TF_AbstractTensor* TF_NewAbstractTensor();
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// APIs for Eager and graph modes
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// Keeps track of the current graph and other state e.g. captures etc.
|
||||
typedef struct TF_GraphContext TF_GraphContext;
|
||||
TF_GraphContext* TF_NewGraphContext(TF_Graph*);
|
||||
void TF_DeleteGraphContext(TF_GraphContext*);
|
||||
|
||||
// `eager_context` must outlive `context`.
|
||||
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
|
||||
TFE_Context* eager_context, TF_Status*);
|
||||
// `graph_context` must outlive `context`.
|
||||
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
|
||||
TF_GraphContext* graph_context,
|
||||
TF_Status*);
|
||||
|
||||
// TODO(srbs): Add APIs for specifying attrs etc.
|
||||
// `op_type` must outlive `op`.
|
||||
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
|
||||
@ -74,29 +70,20 @@ void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
|
||||
// `op_name` must outlive `op`.
|
||||
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
|
||||
TF_Status* s);
|
||||
// `attr_name` must outlive `op`.
|
||||
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
|
||||
TF_DataType value, TF_Status* s);
|
||||
|
||||
// Wrapper for TF_Output but contains a pointer to TF_GraphContext as well.
|
||||
typedef struct TF_GraphTensor TF_GraphTensor;
|
||||
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* c, TF_Output t,
|
||||
TF_Status* s);
|
||||
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s);
|
||||
void TF_DeleteGraphTensor(TF_GraphTensor* t);
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
|
||||
|
||||
// `t` must outlive `at`.
|
||||
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
|
||||
TF_Status* s);
|
||||
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s);
|
||||
|
||||
// `t` must outlive `at`.
|
||||
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
|
||||
TF_Status* s);
|
||||
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s);
|
||||
|
||||
// TF_OutputList just lets us not specify the number of outputs of an operation
|
||||
// TF_OutputList holds the list of TF_AbstractTensor that results from executing
|
||||
// an operation.
|
||||
// It just lets us not specify the number of outputs of an operation
|
||||
// beforehand. This forces a memory allocation in the runtime, which is bad, but
|
||||
// it allows for generic code.
|
||||
// TODO(aminim): the description above isn't clear with respect to
|
||||
// TF_OutputListNumOutputs and the current eager implementation which requires
|
||||
// the number of outputs to be set by the client.
|
||||
typedef struct TF_OutputList TF_OutputList;
|
||||
TF_OutputList* TF_NewOutputList();
|
||||
void TF_DeleteOutputList(TF_OutputList* o);
|
||||
@ -105,13 +92,41 @@ int TF_OutputListNumOutputs(TF_OutputList* o);
|
||||
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
|
||||
|
||||
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
|
||||
// capture some inputs and then add a node in the graph, and after
|
||||
// execution/node creation it'll go and record things that happened in any tape
|
||||
// which happens to be active.
|
||||
// capture some inputs and then add a node in the graph. The output tensors are
|
||||
// returned through the provided TF_OutputList.
|
||||
// Any active tape will observe the effects of this execution.
|
||||
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||
TF_AbstractTensor* const* inputs, TF_OutputList* o,
|
||||
TF_ExecutionContext* ctx, TF_Status* s);
|
||||
|
||||
// Creates a new TF_AbstractFunction from the current tracing states in the
|
||||
// context. The returned TF_GraphToFunction must be deleted by the client.
|
||||
// TODO(aminim): clarify the contract on the state of the context after this
|
||||
// call.
|
||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const TF_AbstractTensor* inputs, int num_outputs,
|
||||
const TF_AbstractTensor* outputs, TF_Status* status);
|
||||
|
||||
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
|
||||
|
||||
// Register the function with the given context. This is particularly useful for
|
||||
// making a function available to an eager context.
|
||||
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*,
|
||||
TF_AbstractFunction*, TF_Status*);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// APIs specific to Eager modes
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// Temporary APIs till we figure out how to create scalar valued Eager
|
||||
// tensors and how to get value out of eager abstract tensors.
|
||||
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
|
||||
TF_Status* s);
|
||||
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s);
|
||||
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
183
tensorflow/c/eager/c_api_unified_experimental_eager.cc
Normal file
183
tensorflow/c/eager/c_api_unified_experimental_eager.cc
Normal file
@ -0,0 +1,183 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
// Simple wrapper over a TFE_TensorHandle
|
||||
struct EagerTensor : public AbstractTensor {
|
||||
TFE_TensorHandle* t = nullptr;
|
||||
EagerTensor() : AbstractTensor(kKind) {}
|
||||
explicit EagerTensor(TFE_TensorHandle* t) : AbstractTensor(kKind), t(t) {}
|
||||
~EagerTensor() override { TFE_DeleteTensorHandle(t); }
|
||||
static constexpr AbstractTensorKind kKind = kEagerTensor;
|
||||
};
|
||||
|
||||
// Simple wrapper over a TFE_Op
|
||||
class EagerOp : public AbstractOp {
|
||||
public:
|
||||
explicit EagerOp(TFE_Context* ctx) : AbstractOp(kKind), ctx_(ctx) {}
|
||||
void SetOpType(const char* const op_type, TF_Status* s) override {
|
||||
op_ = TFE_NewOp(ctx_, op_type, s);
|
||||
}
|
||||
void SetOpName(const char* const op_name, TF_Status* s) override {
|
||||
// Name is ignored in eager mode.
|
||||
}
|
||||
void SetAttrType(const char* const attr_name, TF_DataType value,
|
||||
TF_Status* s) override {
|
||||
if (op_ == nullptr) {
|
||||
TF_SetStatus(s, TF_FAILED_PRECONDITION,
|
||||
"op_type must be specified before specifying attrs.");
|
||||
return;
|
||||
}
|
||||
TFE_OpSetAttrType(op_, attr_name, value);
|
||||
}
|
||||
|
||||
~EagerOp() override { TFE_DeleteOp(op_); }
|
||||
static constexpr AbstractOpKind kKind = kEagerOp;
|
||||
|
||||
private:
|
||||
friend class EagerContext; // For access to op_.
|
||||
TFE_Op* op_ = nullptr;
|
||||
TFE_Context* ctx_;
|
||||
};
|
||||
|
||||
// Wraps a TFE_Context and dispatch EagerOp with EagerTensor inputs.
|
||||
class EagerContext : public ExecutionContext {
|
||||
public:
|
||||
EagerContext() : ExecutionContext(kKind) {}
|
||||
|
||||
void Build(TFE_ContextOptions* options, TF_Status* status) {
|
||||
eager_ctx_ = TFE_NewContext(options, status);
|
||||
}
|
||||
|
||||
AbstractOp* CreateOperation() override {
|
||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
||||
return new EagerOp(eager_ctx_);
|
||||
}
|
||||
|
||||
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||
AbstractTensor* const* inputs, OutputList* o,
|
||||
TF_Status* s) override {
|
||||
auto* eager_op = dyncast<EagerOp>(op);
|
||||
if (eager_op == nullptr) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Unable to cast AbstractOp to TF_EagerOp.");
|
||||
return;
|
||||
}
|
||||
auto* tfe_op = eager_op->op_;
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
auto* eager_tensor = dyncast<const EagerTensor>(inputs[i]);
|
||||
if (!eager_tensor) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
|
||||
return;
|
||||
}
|
||||
TFE_OpAddInput(tfe_op, eager_tensor->t, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
}
|
||||
if (o->expected_num_outputs == -1) {
|
||||
string msg =
|
||||
"The number of outputs must be provided in eager mode. Use "
|
||||
"TF_OutputListSetNumOutputs.";
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return;
|
||||
}
|
||||
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
|
||||
int num_retvals = o->expected_num_outputs;
|
||||
retvals.resize(num_retvals);
|
||||
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
|
||||
if (TF_GetCode(s) != TF_OK) {
|
||||
return;
|
||||
}
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_retvals);
|
||||
for (int i = 0; i < num_retvals; ++i) {
|
||||
o->outputs.push_back(new EagerTensor(retvals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
|
||||
auto* func = afunc->GetTfFunction(s);
|
||||
if (!func) {
|
||||
return;
|
||||
}
|
||||
TFE_ContextAddFunction(eager_ctx_, func, s);
|
||||
}
|
||||
|
||||
~EagerContext() override { TFE_DeleteContext(eager_ctx_); }
|
||||
|
||||
static constexpr ExecutionContextKind kKind = kEagerContext;
|
||||
|
||||
private:
|
||||
friend TFE_Context* ::TF_ExecutionContextGetTFEContext(
|
||||
TF_ExecutionContext* ctx);
|
||||
TFE_Context* eager_ctx_;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
// =============================================================================
|
||||
// Public C API entry points
|
||||
// These are only the entry points specific to the Eager API.
|
||||
// =============================================================================
|
||||
|
||||
using tensorflow::internal::dyncast;
|
||||
using tensorflow::internal::unwrap;
|
||||
|
||||
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options,
|
||||
TF_Status* s) {
|
||||
auto* ctx = new tensorflow::internal::EagerContext();
|
||||
ctx->Build(options, s);
|
||||
return wrap(ctx);
|
||||
}
|
||||
|
||||
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
|
||||
TF_Status* s) {
|
||||
return wrap(new tensorflow::internal::EagerTensor(t));
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
|
||||
TF_Status* s) {
|
||||
auto* eager_tensor = dyncast<tensorflow::internal::EagerTensor>(unwrap(at));
|
||||
if (!eager_tensor) {
|
||||
string msg = tensorflow::strings::StrCat("Not an eager tensor handle.",
|
||||
reinterpret_cast<uintptr_t>(at));
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
return eager_tensor->t;
|
||||
}
|
||||
|
||||
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
|
||||
auto* eager_ctx = dyncast<tensorflow::internal::EagerContext>(unwrap(ctx));
|
||||
if (!eager_ctx) return nullptr;
|
||||
return eager_ctx->eager_ctx_;
|
||||
}
|
248
tensorflow/c/eager/c_api_unified_experimental_graph.cc
Normal file
248
tensorflow/c/eager/c_api_unified_experimental_graph.cc
Normal file
@ -0,0 +1,248 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
class GraphContext;
|
||||
|
||||
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
|
||||
// into the list of outputs for the operation.
|
||||
struct GraphTensor : public AbstractTensor {
|
||||
TF_Output output{};
|
||||
GraphContext* ctx = nullptr;
|
||||
GraphTensor() : AbstractTensor(kKind) {}
|
||||
GraphTensor(TF_Output output, GraphContext* ctx)
|
||||
: AbstractTensor(kKind), output(output), ctx(ctx) {}
|
||||
static constexpr AbstractTensorKind kKind = kGraphTensor;
|
||||
};
|
||||
|
||||
// GraphOp wraps and populate a TF_OperationDescription.
|
||||
class GraphOp : public AbstractOp {
|
||||
public:
|
||||
explicit GraphOp(TF_Graph* g) : AbstractOp(kKind), g_(g) {}
|
||||
void SetOpType(const char* const op_type, TF_Status* s) override {
|
||||
if (op_) {
|
||||
TF_SetStatus(
|
||||
s, TF_FAILED_PRECONDITION,
|
||||
strings::StrCat("SetOpType called on already built op.").c_str());
|
||||
return;
|
||||
}
|
||||
if (op_name_ != nullptr) {
|
||||
op_.reset(TF_NewOperation(g_, op_type, op_name_));
|
||||
op_name_ = nullptr;
|
||||
} else {
|
||||
op_type_ = op_type;
|
||||
}
|
||||
}
|
||||
void SetOpName(const char* const op_name, TF_Status* s) override {
|
||||
if (op_) {
|
||||
TF_SetStatus(
|
||||
s, TF_FAILED_PRECONDITION,
|
||||
strings::StrCat("SetOpName called on already built op.").c_str());
|
||||
return;
|
||||
}
|
||||
if (op_type_ != nullptr) {
|
||||
op_.reset(TF_NewOperation(g_, op_type_, op_name));
|
||||
op_type_ = nullptr;
|
||||
} else {
|
||||
op_name_ = op_name;
|
||||
}
|
||||
}
|
||||
void SetAttrType(const char* const attr_name, TF_DataType value,
|
||||
TF_Status* s) override {
|
||||
if (!op_) {
|
||||
TF_SetStatus(
|
||||
s, TF_FAILED_PRECONDITION,
|
||||
"op_type and op_name must be specified before specifying attrs.");
|
||||
return;
|
||||
}
|
||||
TF_SetAttrType(op_.get(), attr_name, value);
|
||||
}
|
||||
~GraphOp() override {}
|
||||
|
||||
static constexpr AbstractOpKind kKind = kGraphOp;
|
||||
|
||||
private:
|
||||
friend class GraphContext; // For access to op_.
|
||||
TF_Graph* g_;
|
||||
std::unique_ptr<TF_OperationDescription> op_;
|
||||
// Hold `op_type` and `op_name` till both are available since we need both
|
||||
// to build a graph operation.
|
||||
const char* op_type_ = nullptr;
|
||||
const char* op_name_ = nullptr;
|
||||
};
|
||||
|
||||
// GraphFunction is a thin wrapper over a TF_Function.
|
||||
struct GraphFunction : public AbstractFunction {
|
||||
TF_Function* func = nullptr;
|
||||
GraphFunction() : AbstractFunction(kKind) {}
|
||||
explicit GraphFunction(TF_Function* func)
|
||||
: AbstractFunction(kKind), func(func) {}
|
||||
~GraphFunction() override {
|
||||
if (func) TF_DeleteFunction(func);
|
||||
}
|
||||
|
||||
TF_Function* GetTfFunction(TF_Status* s) override { return func; }
|
||||
|
||||
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
||||
};
|
||||
|
||||
// GraphContext wraps a TF_Graph and manages the "execution" of operation, i.e.
|
||||
// adding them to the graph.
|
||||
class GraphContext : public ExecutionContext {
|
||||
public:
|
||||
GraphContext()
|
||||
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
|
||||
|
||||
AbstractOp* CreateOperation() override {
|
||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
||||
return new GraphOp(graph_.get());
|
||||
}
|
||||
|
||||
void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||
AbstractTensor* const* inputs, OutputList* o,
|
||||
TF_Status* s) override {
|
||||
auto* graph_op = dyncast<GraphOp>(op);
|
||||
if (graph_op == nullptr) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Unable to cast AbstractOp to TF_GraphOp.");
|
||||
return;
|
||||
}
|
||||
auto* tf_opdesc = graph_op->op_.release();
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
|
||||
if (!graph_tensor) {
|
||||
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||
"Capturing eager tensors is not supported yet.");
|
||||
return;
|
||||
} else {
|
||||
if (graph_tensor->ctx != this) {
|
||||
TF_SetStatus(
|
||||
s, TF_INVALID_ARGUMENT,
|
||||
"Capturing tensors from other graphs is not supported yet.");
|
||||
return;
|
||||
}
|
||||
TF_AddInput(tf_opdesc, graph_tensor->output);
|
||||
}
|
||||
}
|
||||
auto* operation = TF_FinishOperation(tf_opdesc, s);
|
||||
// TF_FinishOperation deletes `tf_opdesc` so clear its reference.
|
||||
graph_op->op_ = nullptr;
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
int num_outputs = TF_OperationNumOutputs(operation);
|
||||
o->outputs.clear();
|
||||
o->outputs.reserve(num_outputs);
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
o->outputs.push_back(new GraphTensor({operation, i}, this));
|
||||
}
|
||||
}
|
||||
|
||||
TF_Function* ToFunction(const char* fn_name, int num_inputs,
|
||||
const GraphTensor* inputs, int num_outputs,
|
||||
const GraphTensor* outputs, TF_Status* status) const {
|
||||
std::vector<TF_Output> graph_inputs;
|
||||
graph_inputs.resize(num_inputs);
|
||||
std::vector<TF_Output> graph_outputs;
|
||||
graph_outputs.resize(num_outputs);
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
graph_inputs[i] = inputs[i].output;
|
||||
}
|
||||
for (int i = 0; i < num_outputs; i++) {
|
||||
graph_outputs[i] = outputs[i].output;
|
||||
}
|
||||
|
||||
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
|
||||
graph_inputs.size(), graph_inputs.data(),
|
||||
graph_outputs.size(), graph_outputs.data(),
|
||||
nullptr, nullptr, fn_name, status);
|
||||
}
|
||||
|
||||
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
|
||||
TF_SetStatus(s, TF_UNIMPLEMENTED,
|
||||
"Registering graph functions has not been implemented yet.");
|
||||
}
|
||||
|
||||
~GraphContext() override {}
|
||||
|
||||
static constexpr ExecutionContextKind kKind = kGraphContext;
|
||||
|
||||
private:
|
||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
||||
};
|
||||
|
||||
// Helper that converts the graph currently held in the context into a function.
|
||||
static AbstractFunction* ExecutionContextToFunction(
|
||||
const ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const AbstractTensor* inputs, int num_outputs,
|
||||
const AbstractTensor* outputs, TF_Status* status) {
|
||||
auto* graph_ctx = dyncast<const GraphContext>(fn_body);
|
||||
if (graph_ctx == nullptr) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"fn_body is not a TF_GraphContext.");
|
||||
return nullptr;
|
||||
}
|
||||
auto* graph_inputs = dyncast<const GraphTensor>(inputs);
|
||||
if (!graph_inputs) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors.");
|
||||
return nullptr;
|
||||
}
|
||||
auto* graph_outputs = dyncast<const GraphTensor>(outputs);
|
||||
if (!graph_outputs) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
|
||||
return nullptr;
|
||||
}
|
||||
GraphFunction* func = new GraphFunction;
|
||||
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
|
||||
num_outputs, graph_outputs, status);
|
||||
return func;
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
// =============================================================================
|
||||
// Public C API entry points
|
||||
// These are only the entry points specific to the Graph API.
|
||||
// =============================================================================
|
||||
|
||||
using tensorflow::internal::unwrap;
|
||||
|
||||
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
|
||||
return wrap(new tensorflow::internal::GraphContext());
|
||||
}
|
||||
|
||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
||||
const TF_AbstractTensor* inputs, int num_outputs,
|
||||
const TF_AbstractTensor* outputs, TF_Status* status) {
|
||||
return wrap(ExecutionContextToFunction(unwrap(fn_body), fn_name, num_inputs,
|
||||
unwrap(inputs), num_outputs,
|
||||
unwrap(outputs), status));
|
||||
}
|
184
tensorflow/c/eager/c_api_unified_experimental_internal.h
Normal file
184
tensorflow/c/eager/c_api_unified_experimental_internal.h
Normal file
@ -0,0 +1,184 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
// =============================================================================
|
||||
// Implementation detail for the unified execution APIs for Eager and tracing
|
||||
// backends (graph/MLIR).
|
||||
//
|
||||
// This defines a set of abstract classes that are intended to provide the
|
||||
// functionality of the opaque C types exposed in the public APIs defined in the
|
||||
// `c_api_unified_experimental.h` header.
|
||||
// =============================================================================
|
||||
|
||||
// We can't depend on C++ rtti, but we still want to be able to have a safe
|
||||
// dynamic_cast to provide diagnostics to the user when the API is misused.
|
||||
// Instead we model RTTI by listing all the possible subclasses for each
|
||||
// abstract base. Each subclass initializes the base class with the right
|
||||
// `kind`, which allows an equivalent to `std::dynamic_cast` provided by this
|
||||
// utility.
|
||||
template <typename T, typename S>
|
||||
T* dyncast(S source) {
|
||||
if (source->getKind() != T::kKind) {
|
||||
return nullptr;
|
||||
}
|
||||
return tensorflow::down_cast<T*>(source);
|
||||
}
|
||||
|
||||
// Represents either an EagerTensor or a GraphTensor.
|
||||
// This base class does not expose any public methods other than to distinguish
|
||||
// which subclass it actually is. The user is responsible to use the right
|
||||
// type of AbstractTensor in their context (do not pass an EagerTensor to a
|
||||
// GraphContext and vice-versa).
|
||||
class AbstractTensor {
|
||||
protected:
|
||||
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
|
||||
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
// Returns which subclass is this instance of.
|
||||
AbstractTensorKind getKind() const { return kind_; }
|
||||
virtual ~AbstractTensor() = default;
|
||||
|
||||
private:
|
||||
const AbstractTensorKind kind_;
|
||||
};
|
||||
|
||||
// Represents the results of the execution of an operation.
|
||||
struct OutputList {
|
||||
std::vector<AbstractTensor*> outputs;
|
||||
int expected_num_outputs = -1;
|
||||
};
|
||||
|
||||
// Holds the result of tracing a function.
|
||||
class AbstractFunction {
|
||||
protected:
|
||||
enum AbstractFunctionKind { kGraphFunc };
|
||||
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
// Returns which subclass is this instance of.
|
||||
AbstractFunctionKind getKind() const { return kind_; }
|
||||
virtual ~AbstractFunction() = default;
|
||||
|
||||
// Temporary API till we figure the right abstraction for AbstractFunction.
|
||||
// At the moment both Eager and Graph needs access to a "TF_Function" object.
|
||||
virtual TF_Function* GetTfFunction(TF_Status* s) = 0;
|
||||
|
||||
private:
|
||||
const AbstractFunctionKind kind_;
|
||||
};
|
||||
|
||||
// An abstract operation describes an operation by its type, name, and
|
||||
// attributes. It can be "executed" by the context with some input tensors.
|
||||
// It is allowed to reusing the same abstract operation for multiple execution
|
||||
// on a given context, with the same or different input tensors.
|
||||
class AbstractOp {
|
||||
protected:
|
||||
enum AbstractOpKind { kGraphOp, kEagerOp };
|
||||
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
// Returns which subclass is this instance of.
|
||||
AbstractOpKind getKind() const { return kind_; }
|
||||
virtual ~AbstractOp() = default;
|
||||
|
||||
// Sets the type of the operation (for example `AddV2`).
|
||||
virtual void SetOpType(const char* op_type, TF_Status* s) = 0;
|
||||
|
||||
// Sets the name of the operation: this is an optional identifier that is
|
||||
// not intended to carry semantics and preserved/propagated without
|
||||
// guarantees.
|
||||
virtual void SetOpName(const char* op_name, TF_Status* s) = 0;
|
||||
|
||||
// Add a `TypeAttribute` on the operation.
|
||||
virtual void SetAttrType(const char* attr_name, TF_DataType value,
|
||||
TF_Status* s) = 0;
|
||||
|
||||
private:
|
||||
const AbstractOpKind kind_;
|
||||
};
|
||||
|
||||
// This holds the context for the execution: dispatching operations either to an
|
||||
// eager implementation or to a graph implementation.
|
||||
struct ExecutionContext {
|
||||
protected:
|
||||
enum ExecutionContextKind { kGraphContext, kEagerContext };
|
||||
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
|
||||
|
||||
public:
|
||||
// Returns which subclass is this instance of.
|
||||
ExecutionContextKind getKind() const { return k; }
|
||||
virtual ~ExecutionContext() = default;
|
||||
|
||||
// Executes the operation on the provided inputs and populate the OutputList
|
||||
// with the results. The input tensors must match the current context.
|
||||
// The effect of "executing" an operation depends on the context: in an Eager
|
||||
// context it will dispatch it to the runtime for execution, while in a
|
||||
// tracing context it will add the operation to the current function.
|
||||
virtual void ExecuteOperation(AbstractOp* op, int num_inputs,
|
||||
AbstractTensor* const* inputs, OutputList* o,
|
||||
TF_Status* s) = 0;
|
||||
|
||||
// Creates an empty AbstractOperation suitable to use with this context.
|
||||
virtual AbstractOp* CreateOperation() = 0;
|
||||
|
||||
// Registers a functions with this context, after this the function is
|
||||
// available to be called/referenced by its name in this context.
|
||||
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
|
||||
|
||||
private:
|
||||
const ExecutionContextKind k;
|
||||
};
|
||||
|
||||
// Create utilities to wrap/unwrap: this convert from the C opaque types to the
|
||||
// C++ implementation, and back.
|
||||
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \
|
||||
static inline CPP_CLASS* const& unwrap(C_TYPEDEF* const& o) { \
|
||||
return reinterpret_cast<CPP_CLASS* const&>(o); \
|
||||
} \
|
||||
static inline const CPP_CLASS* const& unwrap(const C_TYPEDEF* const& o) { \
|
||||
return reinterpret_cast<const CPP_CLASS* const&>(o); \
|
||||
} \
|
||||
static inline C_TYPEDEF* const& wrap(CPP_CLASS* const& o) { \
|
||||
return reinterpret_cast<C_TYPEDEF* const&>(o); \
|
||||
} \
|
||||
static inline const C_TYPEDEF* const& wrap(const CPP_CLASS* const& o) { \
|
||||
return reinterpret_cast<const C_TYPEDEF* const&>(o); \
|
||||
}
|
||||
|
||||
MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext)
|
||||
MAKE_WRAP_UNWRAP(TF_AbstractFunction, AbstractFunction)
|
||||
MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor)
|
||||
MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp)
|
||||
MAKE_WRAP_UNWRAP(TF_OutputList, OutputList)
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
|
@ -15,17 +15,14 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
|
||||
#include <string.h>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/cc/profiler/profiler.h"
|
||||
#include "tensorflow/core/lib/monitoring/collection_registry.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
@ -33,26 +30,24 @@ namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(UnifedCAPI, TestBasicEager) {
|
||||
TF_ExecutionContext* ctx = TF_NewExecutionContext();
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
|
||||
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
// Enter the eager context.
|
||||
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract input tensor.
|
||||
TFE_TensorHandle* t = TestScalarTensorHandle(2.0f);
|
||||
TF_AbstractTensor* at = TF_NewAbstractTensor();
|
||||
TF_AbstractTensorSetEagerTensor(at, t, status.get());
|
||||
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
|
||||
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
|
||||
TF_AbstractTensor* at =
|
||||
TF_CreateAbstractTensorFromEagerTensor(t, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract operation.
|
||||
auto* op = TF_NewAbstractOp();
|
||||
auto* op = TF_NewAbstractOp(ctx);
|
||||
TF_AbstractOpSetOpType(op, "Add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
@ -69,7 +64,6 @@ TEST(UnifedCAPI, TestBasicEager) {
|
||||
// Clean up operation and inputs.
|
||||
TF_DeleteAbstractOp(op);
|
||||
TF_DeleteAbstractTensor(at);
|
||||
TFE_DeleteTensorHandle(t);
|
||||
|
||||
// Verify the results.
|
||||
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
|
||||
@ -83,100 +77,93 @@ TEST(UnifedCAPI, TestBasicEager) {
|
||||
|
||||
TF_DeleteTensor(result_tensor);
|
||||
TF_DeleteAbstractTensor(result);
|
||||
TFE_DeleteTensorHandle(result_t);
|
||||
TF_DeleteOutputList(o);
|
||||
TFE_DeleteContext(eager_ctx);
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TestBasicGraph) {
|
||||
TF_ExecutionContext* ctx = TF_NewExecutionContext();
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
|
||||
// Enter a graph context.
|
||||
TF_Graph* g = TF_NewGraph();
|
||||
TF_GraphContext* graph_context = TF_NewGraphContext(g);
|
||||
TF_ExecutionContextSetGraphContext(ctx, graph_context, status.get());
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
auto* placeholder_op = TF_NewOperation(g, "Placeholder", "Placeholder");
|
||||
TF_SetAttrType(placeholder_op, "dtype", TF_FLOAT);
|
||||
auto* operation = TF_FinishOperation(placeholder_op, status.get());
|
||||
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_Output placeholder_t = {operation, 0};
|
||||
TF_GraphTensor* graph_t =
|
||||
TF_NewGraphTensor(graph_context, placeholder_t, status.get());
|
||||
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractTensor* t = TF_NewAbstractTensor();
|
||||
TF_AbstractTensorSetGraphTensor(t, graph_t, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract operation.
|
||||
auto* op = TF_NewAbstractOp();
|
||||
TF_AbstractOpSetOpType(op, "Add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetOpName(op, "my_add", status.get());
|
||||
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
TF_AbstractTensor* inputs[2] = {t, t};
|
||||
TF_OutputList* o = TF_NewOutputList();
|
||||
TF_OutputList* placeholder_outputs = TF_NewOutputList();
|
||||
|
||||
// Execute.
|
||||
TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
|
||||
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
|
||||
graph_ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
|
||||
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
|
||||
|
||||
// Delete placeholder op.
|
||||
TF_DeleteAbstractOp(placeholder_op);
|
||||
|
||||
// Build an abstract operation.
|
||||
auto* add_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(add_op, "Add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetOpName(add_op, "my_add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
|
||||
TF_OutputList* add_outputs = TF_NewOutputList();
|
||||
|
||||
// Execute.
|
||||
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0);
|
||||
|
||||
// Clean up operation and inputs.
|
||||
TF_DeleteAbstractOp(op);
|
||||
TF_DeleteAbstractTensor(t);
|
||||
TF_DeleteGraphTensor(graph_t);
|
||||
TF_DeleteAbstractOp(add_op);
|
||||
|
||||
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
|
||||
TF_GraphTensor* result_graph_tensor =
|
||||
TF_AbstractTensorGetGraphTensor(result, status.get());
|
||||
TF_DeleteAbstractTensor(result);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_Output result_output =
|
||||
TF_GraphTensorToOutput(result_graph_tensor, status.get());
|
||||
TF_DeleteGraphTensor(result_graph_tensor);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
string fn_name = "double";
|
||||
TF_Function* f = TF_GraphToFunction(
|
||||
g, fn_name.c_str(), 0, -1, nullptr, 1, &placeholder_t, 1, &result_output,
|
||||
nullptr, nullptr, fn_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
|
||||
graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get());
|
||||
TF_DeleteAbstractTensor(placeholder_t);
|
||||
TF_DeleteAbstractTensor(output_t);
|
||||
|
||||
// Build an eager context to run the function.
|
||||
// Build eager context.
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
|
||||
TF_ExecutionContext* eager_execution_ctx =
|
||||
TF_NewEagerExecutionContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
// Build the abstract op to run the function.
|
||||
TFE_ContextAddFunction(eager_ctx, f, status.get());
|
||||
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOp* fn_op = TF_NewAbstractOp();
|
||||
// Build the abstract op to run the function.
|
||||
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
|
||||
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract input tensor.
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(2.0f);
|
||||
TF_AbstractTensor* input_t = TF_NewAbstractTensor();
|
||||
TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get());
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
|
||||
TF_AbstractTensor* input_t =
|
||||
TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Enter the eager context.
|
||||
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
|
||||
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_OutputListSetNumOutputs(o, 1, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_ExecuteOperation(fn_op, 1, &input_t, o, ctx, status.get());
|
||||
TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, eager_execution_ctx,
|
||||
status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
|
||||
TF_AbstractTensor* final_result = TF_OutputListGet(o, 0);
|
||||
ASSERT_EQ(1, TF_OutputListNumOutputs(add_outputs));
|
||||
TF_AbstractTensor* final_result = TF_OutputListGet(add_outputs, 0);
|
||||
TFE_TensorHandle* final =
|
||||
TF_AbstractTensorGetEagerTensor(final_result, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
@ -185,20 +172,181 @@ TEST(UnifedCAPI, TestBasicGraph) {
|
||||
float* f_value = static_cast<float*>(TF_TensorData(f_t));
|
||||
ASSERT_EQ(*f_value, 4.0);
|
||||
|
||||
TF_DeleteOutputList(o);
|
||||
TF_DeleteOutputList(add_outputs);
|
||||
TF_DeleteOutputList(placeholder_outputs);
|
||||
TF_DeleteAbstractOp(fn_op);
|
||||
TF_DeleteAbstractTensor(input_t);
|
||||
TFE_DeleteTensorHandle(input_eager);
|
||||
TF_DeleteAbstractTensor(final_result);
|
||||
TFE_DeleteTensorHandle(final);
|
||||
TF_DeleteTensor(f_t);
|
||||
TF_DeleteFunction(f);
|
||||
TF_DeleteAbstractFunction(func);
|
||||
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
|
||||
ctx, nullptr, 0, nullptr, 0, nullptr, status.get());
|
||||
ASSERT_EQ(nullptr, func);
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
|
||||
TF_DeleteGraphContext(graph_context);
|
||||
TF_DeleteGraph(g);
|
||||
TFE_DeleteContext(eager_ctx);
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// This should fail.
|
||||
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
|
||||
ASSERT_EQ(TF_FAILED_PRECONDITION, TF_GetCode(status.get()));
|
||||
|
||||
TF_DeleteAbstractOp(placeholder_op);
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// This should fail.
|
||||
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
|
||||
ASSERT_EQ(TF_FAILED_PRECONDITION, TF_GetCode(status.get()));
|
||||
|
||||
TF_DeleteAbstractOp(placeholder_op);
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||
// Build an Eager context.
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an Eager operation.
|
||||
auto* op = TF_NewAbstractOp(ctx);
|
||||
TF_AbstractOpSetOpType(op, "Add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build an abstract input tensor.
|
||||
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
|
||||
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
|
||||
TF_AbstractTensor* at =
|
||||
TF_CreateAbstractTensorFromEagerTensor(t, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
TF_AbstractTensor* inputs[2] = {at, at};
|
||||
TF_OutputList* o = TF_NewOutputList();
|
||||
TF_OutputListSetNumOutputs(o, 1, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build a Graph context.
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Execute eager op using graph context.
|
||||
TF_ExecuteOperation(op, 2, inputs, o, graph_ctx, status.get());
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
|
||||
// Clean up operation and inputs.
|
||||
TF_DeleteAbstractOp(op);
|
||||
TF_DeleteAbstractTensor(at);
|
||||
|
||||
TF_DeleteOutputList(o);
|
||||
TF_DeleteExecutionContext(ctx);
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
|
||||
TEST(UnifedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Add a placeholder to the graph.
|
||||
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
TF_OutputList* placeholder_outputs = TF_NewOutputList();
|
||||
|
||||
// Execute.
|
||||
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
|
||||
graph_ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
|
||||
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
|
||||
|
||||
// Delete placeholder op.
|
||||
TF_DeleteAbstractOp(placeholder_op);
|
||||
|
||||
// Build an abstract operation.
|
||||
auto* add_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(add_op, "Add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetOpName(add_op, "my_add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
|
||||
TF_OutputList* add_outputs = TF_NewOutputList();
|
||||
|
||||
// Build eager context.
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TF_ExecutionContext* eager_execution_ctx =
|
||||
TF_NewEagerExecutionContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
// Execute.
|
||||
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, eager_execution_ctx,
|
||||
status.get());
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
|
||||
// Clean up operation and inputs.
|
||||
TF_DeleteAbstractTensor(placeholder_t);
|
||||
TF_DeleteAbstractOp(add_op);
|
||||
TF_DeleteOutputList(add_outputs);
|
||||
TF_DeleteOutputList(placeholder_outputs);
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -1,143 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/framework/tensor_interface.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt64Scalar(
|
||||
int64 value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateUint64Scalar(
|
||||
uint64 value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt32Scalar(
|
||||
int32 value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateFloatScalar(
|
||||
float value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateDoubleScalar(
|
||||
double value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateHalfScalar(
|
||||
Eigen::half value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateStringScalar(
|
||||
tstring value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface>
|
||||
ContextInterface::CreateComplex128Scalar(complex128 value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateBoolScalar(
|
||||
bool value) {
|
||||
return std::make_unique<TensorInterface>(Tensor(value));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt64Tensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_INT64, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateUint64Tensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_UINT64, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateInt32Tensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_INT32, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateFloatTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_FLOAT, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateDoubleTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_DOUBLE, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateHalfTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_HALF, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateStringTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_STRING, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface>
|
||||
ContextInterface::CreateComplex128Tensor(absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_COMPLEX128, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateBoolTensor(
|
||||
absl::Span<const int64> dim_sizes) {
|
||||
return std::make_unique<TensorInterface>(
|
||||
Tensor(DT_BOOL, TensorShape(dim_sizes)));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractTensorHandleInterface>
|
||||
ContextInterface::CreateLocalHandle(
|
||||
const std::unique_ptr<AbstractTensorInterface> t) {
|
||||
Tensor tensor = tensorflow::down_cast<TensorInterface*>(t.get())->Tensor();
|
||||
return std::make_unique<TensorHandleInterface>(
|
||||
TensorHandle::CreateLocalHandle(std::move(tensor), /*d=*/ctx_->HostCPU(),
|
||||
/*op_device=*/nullptr, ctx_));
|
||||
}
|
||||
|
||||
std::unique_ptr<AbstractOperationInterface>
|
||||
ContextInterface::CreateOperation() {
|
||||
return std::make_unique<tensorflow::OperationInterface>(ctx_);
|
||||
}
|
||||
|
||||
void ContextInterface::ListDevices(
|
||||
std::vector<tensorflow::DeviceAttributes>* devices) {
|
||||
ctx_->ListDevices(devices);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -15,13 +15,17 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/tensor_interface.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/tstring.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -32,123 +36,55 @@ namespace tensorflow {
|
||||
// TensorHandles & Operations.
|
||||
class AbstractContextInterface {
|
||||
public:
|
||||
virtual ~AbstractContextInterface() {}
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
// clients from directly destroying this object since it may manage it's own
|
||||
// lifetime through ref counting. Thus clients MUST call Release() in order to
|
||||
// destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
// Scalar creation functions
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateInt64Scalar(
|
||||
int64 value) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateUint64Scalar(
|
||||
uint64 value) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateInt32Scalar(
|
||||
int32 value) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateFloatScalar(
|
||||
float value) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateDoubleScalar(
|
||||
double value) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateHalfScalar(
|
||||
Eigen::half value) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateStringScalar(
|
||||
tstring value) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateComplex128Scalar(
|
||||
complex128 value) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateBoolScalar(
|
||||
bool value) = 0;
|
||||
// Optimized scalar creation functions
|
||||
virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0;
|
||||
virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
|
||||
virtual AbstractTensorInterface* CreateInt32Scalar(int32 value) = 0;
|
||||
virtual AbstractTensorInterface* CreateFloatScalar(float value) = 0;
|
||||
virtual AbstractTensorInterface* CreateDoubleScalar(double value) = 0;
|
||||
virtual AbstractTensorInterface* CreateHalfScalar(Eigen::half value) = 0;
|
||||
virtual AbstractTensorInterface* CreateStringScalar(tstring value) = 0;
|
||||
virtual AbstractTensorInterface* CreateComplex128Scalar(complex128 value) = 0;
|
||||
virtual AbstractTensorInterface* CreateBoolScalar(bool value) = 0;
|
||||
|
||||
// Tensor creation functions
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateInt64Tensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateUint64Tensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateInt32Tensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateFloatTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateDoubleTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateHalfTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateStringTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateComplex128Tensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual std::unique_ptr<AbstractTensorInterface> CreateBoolTensor(
|
||||
absl::Span<const int64> dim_sizes) = 0;
|
||||
virtual AbstractTensorInterface* CreateTensor(
|
||||
DataType dtype, absl::Span<const int64> dim_sizes) = 0;
|
||||
|
||||
// Create a handle to wrap and manage a Tensor
|
||||
virtual std::unique_ptr<AbstractTensorHandleInterface> CreateLocalHandle(
|
||||
const std::unique_ptr<AbstractTensorInterface> t) = 0;
|
||||
virtual AbstractTensorHandleInterface* CreateLocalHandle(
|
||||
AbstractTensorInterface* t) = 0;
|
||||
// Copy the handle to another device.
|
||||
virtual AbstractTensorHandleInterface* CopyTensorHandleToDevice(
|
||||
AbstractTensorHandleInterface* handle, const char* device_name,
|
||||
Status* status) = 0;
|
||||
|
||||
// Create an operation to perform op execution
|
||||
virtual std::unique_ptr<AbstractOperationInterface> CreateOperation() = 0;
|
||||
virtual AbstractOperationInterface* CreateOperation() = 0;
|
||||
|
||||
// Load a SavedModelAPI object from the given directory and tags
|
||||
virtual std::unique_ptr<SavedModelAPI> LoadSavedModelAPI(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
tensorflow::Status* status) = 0;
|
||||
|
||||
// List attributes of available devices
|
||||
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
|
||||
|
||||
virtual void ClearCachesAndThreadExecutors() = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractContextInterface() {}
|
||||
};
|
||||
|
||||
// TODO(gjn): Try to move these all to EagerContext and make it implement
|
||||
// AbstractContextInterface. Currently, this is not so straightforward because
|
||||
// of various BUILD file dependencies.
|
||||
class ContextInterface : public AbstractContextInterface {
|
||||
public:
|
||||
explicit ContextInterface(EagerContext* ctx) : ctx_(ctx) {}
|
||||
~ContextInterface() override {}
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> CreateInt64Scalar(
|
||||
int64 value) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateUint64Scalar(
|
||||
uint64 value) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateInt32Scalar(
|
||||
int32 value) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateFloatScalar(
|
||||
float value) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateDoubleScalar(
|
||||
double value) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateHalfScalar(
|
||||
Eigen::half value) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateStringScalar(
|
||||
tensorflow::tstring value) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateComplex128Scalar(
|
||||
tensorflow::complex128 value) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateBoolScalar(
|
||||
bool value) override;
|
||||
|
||||
std::unique_ptr<AbstractTensorInterface> CreateInt64Tensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateUint64Tensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateInt32Tensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateFloatTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateDoubleTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateHalfTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateStringTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateComplex128Tensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
std::unique_ptr<AbstractTensorInterface> CreateBoolTensor(
|
||||
absl::Span<const int64> dim_sizes) override;
|
||||
|
||||
std::unique_ptr<AbstractTensorHandleInterface> CreateLocalHandle(
|
||||
const std::unique_ptr<AbstractTensorInterface> t) override;
|
||||
std::unique_ptr<AbstractOperationInterface> CreateOperation() override;
|
||||
|
||||
void ListDevices(std::vector<DeviceAttributes>* devices) override;
|
||||
|
||||
// For runtime specific APIs, provide ability to get the underlying context.
|
||||
EagerContext* Context() const { return ctx_; }
|
||||
|
||||
private:
|
||||
EagerContext* ctx_;
|
||||
};
|
||||
|
||||
inline EagerContext* ContextFromInterface(
|
||||
const std::unique_ptr<AbstractContextInterface>& context) {
|
||||
return down_cast<ContextInterface*>(context.get())->Context();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
|
||||
|
@ -16,134 +16,16 @@ limitations under the License.
|
||||
// A simple logging device to test custom device registration.
|
||||
#include <memory>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/custom_device_testutil.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace {
|
||||
|
||||
struct LoggingDevice {
|
||||
tensorflow::string device_name;
|
||||
tensorflow::string underlying_device;
|
||||
// Set to true whenever a TensorHandle is copied onto the device
|
||||
bool* arrived_flag;
|
||||
// Set to true whenever an operation is executed
|
||||
bool* executed_flag;
|
||||
};
|
||||
|
||||
struct LoggedTensor {
|
||||
TFE_TensorHandle* tensor;
|
||||
LoggedTensor() = delete;
|
||||
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
|
||||
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
|
||||
};
|
||||
|
||||
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
|
||||
delete reinterpret_cast<LoggedTensor*>(data);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* MakeLoggedTensorHandle(
|
||||
TFE_Context* context, const tensorflow::string& logging_device_name,
|
||||
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
|
||||
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
auto dtype = TFE_TensorHandleDataType(t->tensor);
|
||||
return TFE_NewTensorHandleFromDeviceMemory(
|
||||
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
|
||||
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status, void* device_info) {
|
||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
|
||||
tensor, context, dev->underlying_device.c_str(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
auto dst = std::make_unique<LoggedTensor>(t);
|
||||
*(dev->arrived_flag) = true;
|
||||
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
|
||||
status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
const char* target_device_name,
|
||||
TF_Status* status,
|
||||
void* device_info) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Trying to copy a tensor out of a logging device.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
|
||||
TFE_TensorHandle** inputs, const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s,
|
||||
void* device_info) {
|
||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||
TFE_Op* op(TFE_NewOp(context, operation_name, s));
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
TFE_OpAddAttrs(op, attributes);
|
||||
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
|
||||
for (int j = 0; j < num_inputs; ++j) {
|
||||
TFE_TensorHandle* input = inputs[j];
|
||||
const char* input_device = TFE_TensorHandleDeviceName(input, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
if (dev->device_name == input_device) {
|
||||
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
|
||||
TFE_TensorHandleDevicePointer(input, s));
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
TFE_OpAddInput(op, t->tensor, s);
|
||||
} else {
|
||||
TFE_OpAddInput(op, input, s);
|
||||
}
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
|
||||
TFE_Execute(op, op_outputs.data(), num_outputs, s);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
std::vector<TFE_TensorHandle*> unwrapped_outputs;
|
||||
for (auto* handle : op_outputs) {
|
||||
unwrapped_outputs.push_back(handle);
|
||||
}
|
||||
for (int i = 0; i < *num_outputs; ++i) {
|
||||
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
|
||||
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
|
||||
std::move(logged_tensor), s);
|
||||
}
|
||||
*(dev->executed_flag) = true;
|
||||
}
|
||||
|
||||
void DeleteLoggingDevice(void* device_info) {
|
||||
delete reinterpret_cast<LoggingDevice*>(device_info);
|
||||
}
|
||||
|
||||
void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||
bool* arrived_flag, bool* executed_flag,
|
||||
TF_Status* status) {
|
||||
TFE_CustomDevice custom_device;
|
||||
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
|
||||
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
|
||||
custom_device.delete_device = &DeleteLoggingDevice;
|
||||
custom_device.execute = &LoggingDeviceExecute;
|
||||
LoggingDevice* device = new LoggingDevice;
|
||||
device->arrived_flag = arrived_flag;
|
||||
device->executed_flag = executed_flag;
|
||||
device->device_name = name;
|
||||
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -156,7 +38,7 @@ TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context, name, &arrived, &executed, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context);
|
||||
ASSERT_FALSE(arrived);
|
||||
TFE_TensorHandle* hdevice =
|
||||
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
|
||||
@ -245,7 +127,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
|
||||
|
||||
// Assign to the variable, copying to the custom device.
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
|
||||
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
|
||||
TestScalarTensorHandle(context.get(), 111.f), TFE_DeleteTensorHandle);
|
||||
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
@ -276,9 +158,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
|
||||
tensorflow::string(
|
||||
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
|
||||
TFE_TensorHandle* var_value_unpacked =
|
||||
reinterpret_cast<LoggedTensor*>(
|
||||
TFE_TensorHandleDevicePointer(var_value, status.get()))
|
||||
->tensor;
|
||||
UnpackTensorHandle(var_value, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
|
||||
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
|
||||
@ -296,7 +176,7 @@ TEST(CUSTOM_DEVICE, MakeVariable) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
|
||||
TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
@ -331,7 +211,7 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
|
||||
|
||||
// Assign to the variable, copying to the custom device.
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
|
||||
TestScalarTensorHandle(111.f), TFE_DeleteTensorHandle);
|
||||
TestScalarTensorHandle(context.get(), 111.f), TFE_DeleteTensorHandle);
|
||||
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
@ -346,16 +226,21 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
|
||||
|
||||
// Read the variable's value.
|
||||
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
|
||||
executed = false;
|
||||
num_retvals = 1;
|
||||
TFE_TensorHandle* var_value = nullptr;
|
||||
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
|
||||
EXPECT_FALSE(TF_GetCode(status.get()) == TF_OK)
|
||||
<< "Execution should fail because the variable is being used on the "
|
||||
"wrong device.";
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
ASSERT_EQ(
|
||||
tensorflow::string(name),
|
||||
tensorflow::string(TFE_TensorHandleDeviceName(var_value, status.get())));
|
||||
TFE_DeleteTensorHandle(var_value);
|
||||
|
||||
// Free the backing buffer for the variable.
|
||||
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
|
||||
TFE_OpAddInput(op.get(), var_handle, status.get());
|
||||
@ -366,6 +251,79 @@ TEST(CUSTOM_DEVICE, AccessVariableOnWrongDevice) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, InputBasedPlacement) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* custom0 = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
|
||||
bool arrived = false;
|
||||
bool executed = false;
|
||||
RegisterLoggingDevice(context.get(), custom0, &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
RegisterLoggingDevice(context.get(), custom1, &arrived, &executed,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcpu(
|
||||
TestMatrixTensorHandle(context.get()), TFE_DeleteTensorHandle);
|
||||
ASSERT_FALSE(arrived);
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom0(
|
||||
TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom0,
|
||||
status.get()),
|
||||
TFE_DeleteTensorHandle);
|
||||
ASSERT_TRUE(arrived);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
arrived = false;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom1(
|
||||
TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom1,
|
||||
status.get()),
|
||||
TFE_DeleteTensorHandle);
|
||||
ASSERT_TRUE(arrived);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Base case: two CPU inputs executes fine.
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
|
||||
MatMulOp(context.get(), hcpu.get(), hcpu.get()), TFE_DeleteOp);
|
||||
TFE_TensorHandle* retval;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_DeleteTensorHandle(retval);
|
||||
|
||||
// Custom device: inputs in same custom device works.
|
||||
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom0.get()));
|
||||
num_retvals = 1;
|
||||
executed = false;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ASSERT_TRUE(executed);
|
||||
TFE_DeleteTensorHandle(retval);
|
||||
|
||||
// Custom device: inputs in different custom devices fails.
|
||||
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom1.get()));
|
||||
num_retvals = 1;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
|
||||
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
|
||||
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
|
||||
|
||||
// Custom device: mix of custom/physical fails.
|
||||
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
|
||||
num_retvals = 1;
|
||||
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
|
||||
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
|
||||
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
|
||||
ASSERT_TRUE(
|
||||
absl::StrContains(TF_Message(status.get()), "[]")); // kVariantDeviceNull
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -394,5 +352,3 @@ TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
|
||||
<< TF_Message(status.get());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
172
tensorflow/c/eager/custom_device_testutil.cc
Normal file
172
tensorflow/c/eager/custom_device_testutil.cc
Normal file
@ -0,0 +1,172 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// A simple logging device to test custom device registration.
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace {
|
||||
|
||||
struct LoggingDevice {
|
||||
tensorflow::string device_name;
|
||||
tensorflow::string underlying_device;
|
||||
// Set to true whenever a TensorHandle is copied onto the device
|
||||
bool* arrived_flag;
|
||||
// Set to true whenever an operation is executed
|
||||
bool* executed_flag;
|
||||
};
|
||||
|
||||
struct LoggedTensor {
|
||||
TFE_TensorHandle* tensor;
|
||||
LoggedTensor() = delete;
|
||||
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
|
||||
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
|
||||
};
|
||||
|
||||
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
|
||||
delete reinterpret_cast<LoggedTensor*>(data);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* MakeLoggedTensorHandle(
|
||||
TFE_Context* context, const tensorflow::string& logging_device_name,
|
||||
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
|
||||
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
auto dtype = TFE_TensorHandleDataType(t->tensor);
|
||||
return TFE_NewTensorHandleFromDeviceMemory(
|
||||
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
|
||||
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status, void* device_info) {
|
||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
|
||||
tensor, context, dev->underlying_device.c_str(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
auto dst = std::make_unique<LoggedTensor>(t);
|
||||
*(dev->arrived_flag) = true;
|
||||
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
|
||||
status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
const char* target_device_name,
|
||||
TF_Status* status,
|
||||
void* device_info) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Trying to copy a tensor out of a logging device.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void LoggingDeviceExecute(TFE_Context* context, int num_inputs,
|
||||
TFE_TensorHandle** inputs, const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s,
|
||||
void* device_info) {
|
||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||
TFE_Op* op(TFE_NewOp(context, operation_name, s));
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
TFE_OpAddAttrs(op, attributes);
|
||||
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
|
||||
for (int j = 0; j < num_inputs; ++j) {
|
||||
TFE_TensorHandle* input = inputs[j];
|
||||
const char* input_device = TFE_TensorHandleDeviceName(input, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
if (dev->device_name == input_device) {
|
||||
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
|
||||
TFE_TensorHandleDevicePointer(input, s));
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
TFE_OpAddInput(op, t->tensor, s);
|
||||
} else {
|
||||
TFE_OpAddInput(op, input, s);
|
||||
}
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
|
||||
TFE_Execute(op, op_outputs.data(), num_outputs, s);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
std::vector<TFE_TensorHandle*> unwrapped_outputs;
|
||||
for (auto* handle : op_outputs) {
|
||||
unwrapped_outputs.push_back(handle);
|
||||
}
|
||||
for (int i = 0; i < *num_outputs; ++i) {
|
||||
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
|
||||
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
|
||||
std::move(logged_tensor), s);
|
||||
}
|
||||
*(dev->executed_flag) = true;
|
||||
}
|
||||
|
||||
void DeleteLoggingDevice(void* device_info) {
|
||||
delete reinterpret_cast<LoggingDevice*>(device_info);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||
bool* arrived_flag, bool* executed_flag,
|
||||
TF_Status* status) {
|
||||
TFE_CustomDevice custom_device;
|
||||
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
|
||||
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
|
||||
custom_device.delete_device = &DeleteLoggingDevice;
|
||||
custom_device.execute = &LoggingDeviceExecute;
|
||||
LoggingDevice* device = new LoggingDevice;
|
||||
device->arrived_flag = arrived_flag;
|
||||
device->executed_flag = executed_flag;
|
||||
device->device_name = name;
|
||||
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
|
||||
TF_Status* status) {
|
||||
return reinterpret_cast<LoggedTensor*>(
|
||||
TFE_TensorHandleDevicePointer(logged_tensor_handle, status))
|
||||
->tensor;
|
||||
}
|
||||
|
||||
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
|
||||
bool* executed_flag, TFE_CustomDevice** device,
|
||||
void** device_info) {
|
||||
TFE_CustomDevice* custom_device = new TFE_CustomDevice;
|
||||
custom_device->copy_tensor_to_device = &CopyToLoggingDevice;
|
||||
custom_device->copy_tensor_from_device = &CopyTensorFromLoggingDevice;
|
||||
custom_device->delete_device = &DeleteLoggingDevice;
|
||||
custom_device->execute = &LoggingDeviceExecute;
|
||||
*device = custom_device;
|
||||
LoggingDevice* logging_device = new LoggingDevice;
|
||||
logging_device->arrived_flag = arrived_flag;
|
||||
logging_device->executed_flag = executed_flag;
|
||||
logging_device->device_name = name;
|
||||
logging_device->underlying_device =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
*device_info = reinterpret_cast<void*>(logging_device);
|
||||
}
|
36
tensorflow/c/eager/custom_device_testutil.h
Normal file
36
tensorflow/c/eager/custom_device_testutil.h
Normal file
@ -0,0 +1,36 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
|
||||
#define TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
|
||||
|
||||
// A simple logging device to test custom device registration.
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||
bool* arrived_flag, bool* executed_flag,
|
||||
TF_Status* status);
|
||||
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
|
||||
bool* executed_flag, TFE_CustomDevice** device,
|
||||
void** device_info);
|
||||
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
|
||||
TF_Status* status);
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
|
@ -16,8 +16,11 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/dlpack.h"
|
||||
|
||||
#include "include/dlpack/dlpack.h" // from @dlpack
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_reference.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -40,16 +43,15 @@ struct TfDlManagedTensorCtx {
|
||||
|
||||
// Gets tensor from eager tensor handle.
|
||||
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
|
||||
if (h == nullptr || !h->handle->IsValid(&status->status)) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"The passed in handle is a nullptr");
|
||||
if (h == nullptr) {
|
||||
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::TensorHandleFromInterface(h->handle);
|
||||
if (handle->IsRemote()) {
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
|
||||
if (handle->Type() != TensorHandle::LOCAL) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
"DLPack doesn't support remote tensor");
|
||||
"DLPack doesn't support ", handle->TypeString(), " tensor");
|
||||
return nullptr;
|
||||
}
|
||||
const tensorflow::Tensor* tensor;
|
||||
@ -107,7 +109,7 @@ DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
|
||||
// Gets DLPack's DLContext from eager tensor handle.
|
||||
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
|
||||
DLContext ctx;
|
||||
const char* device_name = h->handle->DeviceName(&status->status);
|
||||
const char* device_name = tensorflow::unwrap(h)->DeviceName(&status->status);
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
|
||||
std::string device_type = parsed_name.type;
|
||||
@ -286,7 +288,8 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
|
||||
return static_cast<void*>(dlm_tensor);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
|
||||
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status,
|
||||
TFE_Context* ctx) {
|
||||
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
|
||||
DLTensor* dl_tensor = &dlmt->dl_tensor;
|
||||
absl::optional<std::string> device_name =
|
||||
@ -319,16 +322,10 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
|
||||
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
|
||||
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
|
@ -30,7 +30,8 @@ TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
|
||||
|
||||
// Converts DLPack (DLManagedTensor*) to eager tensor handle.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
|
||||
TF_Status* status);
|
||||
TF_Status* status,
|
||||
TFE_Context* ctx);
|
||||
|
||||
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
|
||||
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
|
||||
|
@ -1,309 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
|
||||
#include "absl/container/fixed_array.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/eager/execute.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
OperationInterface::OperationInterface(EagerContext* ctx) : operation_(ctx) {}
|
||||
|
||||
const string& OperationInterface::DeviceName() const {
|
||||
absl::variant<Device*, CustomDevice*> variant_device =
|
||||
(operation_.Device() == kVariantDeviceNull)
|
||||
? operation_.EagerContext().HostCPU()
|
||||
: operation_.Device();
|
||||
return absl::visit([](auto* d) -> const string& { return d->name(); },
|
||||
variant_device);
|
||||
}
|
||||
|
||||
Status OperationInterface::SetDeviceName(const char* name) {
|
||||
return operation_.SetDeviceName(name);
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrString(const char* attr_name,
|
||||
const char* data, size_t length) {
|
||||
operation_.MutableAttrs()->Set(attr_name, StringPiece(data, length));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrInt(const char* attr_name, int64_t value) {
|
||||
operation_.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrFloat(const char* attr_name, float value) {
|
||||
operation_.MutableAttrs()->Set(attr_name, value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrBool(const char* attr_name, bool value) {
|
||||
operation_.MutableAttrs()->Set(attr_name, value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrType(const char* attr_name,
|
||||
TF_DataType value) {
|
||||
operation_.MutableAttrs()->Set(attr_name, static_cast<DataType>(value));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrShape(const char* attr_name,
|
||||
const int64_t* dims,
|
||||
const int num_dims) {
|
||||
if (num_dims > TensorShape::MaxDimensions()) {
|
||||
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
|
||||
num_dims,
|
||||
" dimensions which is over the limit of ",
|
||||
TensorShape::MaxDimensions(), ".");
|
||||
}
|
||||
|
||||
TensorShapeProto proto;
|
||||
if (num_dims < 0) {
|
||||
proto.set_unknown_rank(true);
|
||||
} else {
|
||||
for (int d = 0; d < num_dims; ++d) {
|
||||
proto.add_dim()->set_size(dims[d]);
|
||||
}
|
||||
}
|
||||
|
||||
operation_.MutableAttrs()->Set(attr_name, proto);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrFunction(
|
||||
const char* attr_name,
|
||||
const std::unique_ptr<AbstractOperationInterface>& value) {
|
||||
AttrValue attr_value;
|
||||
NameAttrList* func = attr_value.mutable_func();
|
||||
func->set_name(value->Name());
|
||||
OperationInterface* value_operation =
|
||||
tensorflow::down_cast<OperationInterface*>(value.get());
|
||||
value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr());
|
||||
operation_.MutableAttrs()->Set(attr_name, attr_value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrFunctionName(const char* attr_name,
|
||||
const char* data,
|
||||
size_t length) {
|
||||
AttrValue attr_value;
|
||||
NameAttrList* func = attr_value.mutable_func();
|
||||
func->set_name(data, length);
|
||||
operation_.MutableAttrs()->Set(attr_name, attr_value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrTensor(const char* attr_name,
|
||||
TF_Tensor* tensor) {
|
||||
Tensor t;
|
||||
TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t));
|
||||
operation_.MutableAttrs()->Set(attr_name, t);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrStringList(const char* attr_name,
|
||||
const void* const* values,
|
||||
const size_t* lengths,
|
||||
int num_values) {
|
||||
std::vector<StringPiece> v(num_values);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
|
||||
}
|
||||
operation_.MutableAttrs()->Set(attr_name, v);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrFloatList(const char* attr_name,
|
||||
const float* values,
|
||||
int num_values) {
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<const float>(values, num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrIntList(const char* attr_name,
|
||||
const int64_t* values,
|
||||
int num_values) {
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<const int64>(
|
||||
reinterpret_cast<const int64*>(values), num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrTypeList(const char* attr_name,
|
||||
const TF_DataType* values,
|
||||
int num_values) {
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<const DataType>(
|
||||
reinterpret_cast<const DataType*>(values), num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrBoolList(const char* attr_name,
|
||||
const unsigned char* values,
|
||||
int num_values) {
|
||||
std::unique_ptr<bool[]> b(new bool[num_values]);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
b[i] = values[i];
|
||||
}
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<const bool>(b.get(), num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrShapeList(const char* attr_name,
|
||||
const int64_t** dims,
|
||||
const int* num_dims,
|
||||
int num_values) {
|
||||
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
|
||||
for (int i = 0; i < num_values; ++i) {
|
||||
const auto num_dims_i = num_dims[i];
|
||||
|
||||
if (num_dims_i > TensorShape::MaxDimensions()) {
|
||||
return errors::InvalidArgument(
|
||||
strings::StrCat("Value specified for `", attr_name, "` has ",
|
||||
num_dims_i, " dimensions which is over the limit of ",
|
||||
TensorShape::MaxDimensions(), "."));
|
||||
}
|
||||
if (num_dims_i < 0) {
|
||||
proto[i].set_unknown_rank(true);
|
||||
} else {
|
||||
const int64_t* dims_i = dims[i];
|
||||
auto proto_i = &proto[i];
|
||||
for (int d = 0; d < num_dims_i; ++d) {
|
||||
proto_i->add_dim()->set_size(dims_i[d]);
|
||||
}
|
||||
}
|
||||
}
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetAttrFunctionList(const char* attr_name,
|
||||
const TFE_Op** value,
|
||||
int num_values) {
|
||||
std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
|
||||
for (int i = 0; i < num_values; i++) {
|
||||
auto value_operation =
|
||||
tensorflow::down_cast<OperationInterface*>(value[i]->operation.get());
|
||||
funcs[i].set_name(value_operation->operation_.Name());
|
||||
value_operation->operation_.Attrs().FillAttrValueMap(
|
||||
funcs[i].mutable_attr());
|
||||
}
|
||||
operation_.MutableAttrs()->Set(
|
||||
attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const OpDef* OperationInterface::GetOpDef(Status* status) {
|
||||
const tensorflow::OpDef* op_def = operation_.OpDef();
|
||||
if (op_def) return op_def;
|
||||
*status = OpDefForOp(Name(), &op_def);
|
||||
return op_def;
|
||||
}
|
||||
|
||||
Status OperationInterface::InputLength(const char* input_name, int* length) {
|
||||
Status status;
|
||||
const tensorflow::OpDef* op_def = GetOpDef(&status);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
AttrValueMap attrs;
|
||||
operation_.Attrs().FillAttrValueMap(&attrs);
|
||||
NameRangeMap name_ranges;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NameRangesForNode(AttrSlice(&attrs), *op_def, &name_ranges, nullptr));
|
||||
auto iter = name_ranges.find(input_name);
|
||||
if (iter == name_ranges.end()) {
|
||||
return errors::InvalidArgument("Input '", input_name, "' not found");
|
||||
}
|
||||
*length = iter->second.second - iter->second.first;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::OutputLength(const char* output_name, int* length) {
|
||||
Status status;
|
||||
const tensorflow::OpDef* op_def = GetOpDef(&status);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
AttrValueMap attrs;
|
||||
operation_.Attrs().FillAttrValueMap(&attrs);
|
||||
NameRangeMap name_ranges;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NameRangesForNode(AttrSlice(&attrs), *op_def, nullptr, &name_ranges));
|
||||
auto iter = name_ranges.find(output_name);
|
||||
if (iter == name_ranges.end()) {
|
||||
return errors::InvalidArgument("Output '", output_name, "' not found");
|
||||
}
|
||||
*length = iter->second.second - iter->second.first;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::AddInput(
|
||||
const std::unique_ptr<AbstractTensorHandleInterface>& input) {
|
||||
TensorHandle* h = TensorHandleFromInterface(input);
|
||||
operation_.AddInput(h);
|
||||
return operation_.MaybeInferSingleInputAttrs(h);
|
||||
}
|
||||
|
||||
Status OperationInterface::AddInputList(
|
||||
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||
inputs) {
|
||||
for (auto& input : inputs) {
|
||||
TensorHandle* h = TensorHandleFromInterface(input);
|
||||
operation_.AddInput(h);
|
||||
}
|
||||
return operation_.InferInputListAttrs(inputs.size());
|
||||
}
|
||||
|
||||
Status OperationInterface::Execute(
|
||||
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
|
||||
int* num_retvals) {
|
||||
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
|
||||
TF_RETURN_IF_ERROR(
|
||||
EagerExecute(&operation_, handle_retvals.data(), num_retvals));
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals->at(i).reset(
|
||||
new tensorflow::TensorHandleInterface(handle_retvals[i]));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetCancellationManager(
|
||||
TFE_CancellationManager* cancellation_manager) {
|
||||
operation_.SetCancellationManager(
|
||||
&cancellation_manager->cancellation_manager);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OperationInterface::SetUseXla(bool enable) {
|
||||
operation_.SetUseXla(enable);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -15,36 +15,62 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/container/fixed_array.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
struct TFE_Op;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to an operation.
|
||||
class AbstractOperationInterface {
|
||||
public:
|
||||
virtual ~AbstractOperationInterface() {}
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
// clients from directly destroying this object since it may manage it's own
|
||||
// lifetime through ref counting. Thus this must be allocated on the heap and
|
||||
// clients MUST call Release() in order to destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
virtual void Clear() = 0;
|
||||
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
|
||||
|
||||
virtual const string& Name() const = 0;
|
||||
|
||||
// Returns the operation's device name.
|
||||
//
|
||||
// The value returned may be different from the one set by SetDeviceName, but
|
||||
// it will be compatible with it: the name will be updated by device placement
|
||||
// logic to refer to the specific device chosen.
|
||||
//
|
||||
// Example: If one calls `op->SetDeviceName("/device:GPU")`, the value
|
||||
// returned by DeviceName should be "/device:GPU:*" until a particular GPU is
|
||||
// chosen for the operation by the device placement logic in the
|
||||
// executor. After that, the value returned by DeviceName will be a full
|
||||
// device name such as "/job:localhost/replica:0/task:0/device:GPU:1".
|
||||
virtual const string& DeviceName() const = 0;
|
||||
|
||||
// Sets the operation device name.
|
||||
//
|
||||
// The given `name` must be parseable by DeviceNameUtils::ParseFullName, and
|
||||
// the result will be used as a constraint for device placement. See the
|
||||
// documentation for DeviceName for more details.
|
||||
//
|
||||
// The value will override the previous value - that is, no "merging" of
|
||||
// existing and given constraints will be performed.
|
||||
virtual Status SetDeviceName(const char* name) = 0;
|
||||
|
||||
virtual Status AddInput(
|
||||
const std::unique_ptr<AbstractTensorHandleInterface>& input) = 0;
|
||||
virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;
|
||||
virtual Status AddInputList(
|
||||
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||
inputs) = 0;
|
||||
virtual Status Execute(
|
||||
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
|
||||
int* num_retvals) = 0;
|
||||
absl::Span<AbstractTensorHandleInterface*> inputs) = 0;
|
||||
virtual Status Execute(absl::Span<AbstractTensorHandleInterface*> retvals,
|
||||
int* num_retvals) = 0;
|
||||
virtual const tensorflow::OpDef* OpDef() const = 0;
|
||||
|
||||
virtual Status SetAttrString(const char* attr_name, const char* data,
|
||||
@ -52,15 +78,15 @@ class AbstractOperationInterface {
|
||||
virtual Status SetAttrInt(const char* attr_name, int64_t value) = 0;
|
||||
virtual Status SetAttrFloat(const char* attr_name, float value) = 0;
|
||||
virtual Status SetAttrBool(const char* attr_name, bool value) = 0;
|
||||
virtual Status SetAttrType(const char* attr_name, TF_DataType value) = 0;
|
||||
virtual Status SetAttrType(const char* attr_name, DataType value) = 0;
|
||||
virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||
const int num_dims) = 0;
|
||||
virtual Status SetAttrFunction(
|
||||
const char* attr_name,
|
||||
const std::unique_ptr<AbstractOperationInterface>& value) = 0;
|
||||
virtual Status SetAttrFunction(const char* attr_name,
|
||||
const AbstractOperationInterface* value) = 0;
|
||||
virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
|
||||
size_t length) = 0;
|
||||
virtual Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) = 0;
|
||||
virtual Status SetAttrTensor(const char* attr_name,
|
||||
AbstractTensorInterface* tensor) = 0;
|
||||
virtual Status SetAttrStringList(const char* attr_name,
|
||||
const void* const* values,
|
||||
const size_t* lengths, int num_values) = 0;
|
||||
@ -68,103 +94,25 @@ class AbstractOperationInterface {
|
||||
int num_values) = 0;
|
||||
virtual Status SetAttrIntList(const char* attr_name, const int64_t* values,
|
||||
int num_values) = 0;
|
||||
virtual Status SetAttrTypeList(const char* attr_name,
|
||||
const TF_DataType* values, int num_values) = 0;
|
||||
virtual Status SetAttrTypeList(const char* attr_name, const DataType* values,
|
||||
int num_values) = 0;
|
||||
virtual Status SetAttrBoolList(const char* attr_name,
|
||||
const unsigned char* values,
|
||||
int num_values) = 0;
|
||||
virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||
const int* num_dims, int num_values) = 0;
|
||||
virtual Status SetAttrFunctionList(const char* attr_name,
|
||||
const TFE_Op** value, int num_values) = 0;
|
||||
virtual Status SetAttrFunctionList(
|
||||
const char* attr_name,
|
||||
absl::Span<const AbstractOperationInterface*> values) = 0;
|
||||
|
||||
virtual Status InputLength(const char* input_name, int* length) = 0;
|
||||
virtual Status OutputLength(const char* output_name, int* length) = 0;
|
||||
|
||||
// Experimental
|
||||
virtual Status SetUseXla(bool enable) {
|
||||
return errors::Unimplemented("SetUseXla not implemented");
|
||||
}
|
||||
virtual Status SetUseXla(bool enable) = 0;
|
||||
|
||||
virtual Status SetCancellationManager(
|
||||
TFE_CancellationManager* cancellation_manager) {
|
||||
return errors::Unimplemented("SetCancellationManager not implemented");
|
||||
}
|
||||
};
|
||||
|
||||
class OpDef;
|
||||
|
||||
class OperationInterface : public AbstractOperationInterface {
|
||||
public:
|
||||
explicit OperationInterface(EagerContext* ctx);
|
||||
~OperationInterface() override{};
|
||||
|
||||
void Clear() override { operation_.Clear(); }
|
||||
Status Reset(const char* op, const char* raw_device_name) override {
|
||||
return operation_.Reset(op, raw_device_name, false, nullptr);
|
||||
}
|
||||
|
||||
const string& Name() const override { return operation_.Name(); }
|
||||
const string& DeviceName() const override;
|
||||
Status SetDeviceName(const char* name) override;
|
||||
|
||||
Status AddInput(
|
||||
const std::unique_ptr<AbstractTensorHandleInterface>& input) override;
|
||||
Status AddInputList(
|
||||
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
|
||||
inputs) override;
|
||||
Status Execute(
|
||||
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
|
||||
int* num_retvals) override;
|
||||
const tensorflow::OpDef* OpDef() const override {
|
||||
return operation_.OpDef();
|
||||
};
|
||||
|
||||
Status SetAttrString(const char* attr_name, const char* data,
|
||||
size_t length) override;
|
||||
Status SetAttrInt(const char* attr_name, int64_t value) override;
|
||||
Status SetAttrFloat(const char* attr_name, float value) override;
|
||||
Status SetAttrBool(const char* attr_name, bool value) override;
|
||||
Status SetAttrType(const char* attr_name, TF_DataType value) override;
|
||||
Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||
const int num_dims) override;
|
||||
Status SetAttrFunction(
|
||||
const char* attr_name,
|
||||
const std::unique_ptr<AbstractOperationInterface>& value) override;
|
||||
Status SetAttrFunctionName(const char* attr_name, const char* data,
|
||||
size_t length) override;
|
||||
Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override;
|
||||
Status SetAttrStringList(const char* attr_name, const void* const* values,
|
||||
const size_t* lengths, int num_values) override;
|
||||
Status SetAttrFloatList(const char* attr_name, const float* values,
|
||||
int num_values) override;
|
||||
Status SetAttrIntList(const char* attr_name, const int64_t* values,
|
||||
int num_values) override;
|
||||
Status SetAttrTypeList(const char* attr_name, const TF_DataType* values,
|
||||
int num_values) override;
|
||||
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
|
||||
int num_values) override;
|
||||
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||
const int* num_dims, int num_values) override;
|
||||
Status SetAttrFunctionList(const char* attr_name, const TFE_Op** value,
|
||||
int num_values) override;
|
||||
|
||||
Status InputLength(const char* input_name, int* length) override;
|
||||
Status OutputLength(const char* output_name, int* length) override;
|
||||
|
||||
Status SetUseXla(bool enable) override;
|
||||
Status SetCancellationManager(
|
||||
TFE_CancellationManager* cancellation_manager) override;
|
||||
|
||||
// TODO(gjn): Remove once TFE_InferShapes is removed
|
||||
const AttrBuilder& Attrs() const { return operation_.Attrs(); }
|
||||
AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
|
||||
|
||||
const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }
|
||||
|
||||
private:
|
||||
const tensorflow::OpDef* GetOpDef(Status* status);
|
||||
EagerOperation operation_;
|
||||
protected:
|
||||
virtual ~AbstractOperationInterface() {}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
54
tensorflow/c/eager/parallel_device/BUILD
Normal file
54
tensorflow/c/eager/parallel_device/BUILD
Normal file
@ -0,0 +1,54 @@
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
# Currently pybind extension shared objects must use only C API headers since
|
||||
# the C API has static initializers duplicated in the Python bindings. So we
|
||||
# need a second rule that omits .cc files, in
|
||||
# tensorflow/python:_pywrap_parallel_device.
|
||||
filegroup(
|
||||
name = "headers",
|
||||
srcs = ["parallel_device.h"],
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "sources",
|
||||
srcs = ["parallel_device.cc"],
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "parallel_device",
|
||||
srcs = [":sources"],
|
||||
hdrs = [":headers"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "parallel_device_test",
|
||||
srcs = ["parallel_device_test.cc"],
|
||||
deps = [
|
||||
":parallel_device",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_experimental",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
595
tensorflow/c/eager/parallel_device/parallel_device.cc
Normal file
595
tensorflow/c/eager/parallel_device/parallel_device.cc
Normal file
@ -0,0 +1,595 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace eager {
|
||||
namespace {
|
||||
|
||||
// Functor for making unique_ptrs slightly more ergonomic. Using
|
||||
// decltype(delete_fn) in the unique_ptr's second template argument requires
|
||||
// passing a function pointer to delete_fn when constructing the unique_ptr.
|
||||
class TensorHandleDeleter {
|
||||
public:
|
||||
void operator()(TFE_TensorHandle* to_delete) const {
|
||||
TFE_DeleteTensorHandle(to_delete);
|
||||
}
|
||||
};
|
||||
|
||||
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
|
||||
|
||||
class OpDeleter {
|
||||
public:
|
||||
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
|
||||
};
|
||||
|
||||
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
|
||||
|
||||
class ExecutorDeleter {
|
||||
public:
|
||||
void operator()(TFE_Executor* to_delete) const {
|
||||
TFE_DeleteExecutor(to_delete);
|
||||
}
|
||||
};
|
||||
|
||||
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
|
||||
|
||||
class ParallelTensor;
|
||||
|
||||
using MaybeParallelTensorOwned =
|
||||
absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
|
||||
using MaybeParallelTensorUnowned =
|
||||
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
|
||||
|
||||
// Creates a vector of `count` new executors (threads).
|
||||
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
|
||||
std::vector<ExecutorPtr> executors;
|
||||
executors.reserve(count);
|
||||
for (int i = 0; i < count; ++i) {
|
||||
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
|
||||
}
|
||||
return executors;
|
||||
}
|
||||
|
||||
// A representation of the custom device passed in and out of the TFE custom
// device APIs, providing context about the parallel device to
// ParallelDeviceExecute.
class ParallelDevice {
 public:
  // Constructs a parallel device named `name` which forwards operations to
  // `devices`; one async executor is created per underlying device (see
  // MakeExecutors).
  ParallelDevice(const std::string& name,
                 const std::vector<std::string>& devices);

  // Helper to copy a tensor handle from another device once for each component
  // of the ParallelDevice.
  //
  // Sets a bad status and returns a nullptr if `tensor` is already on the
  // ParallelDevice, or if the individual copies fail.
  std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
                                                       TFE_TensorHandle* tensor,
                                                       TF_Status* status) const;

  // Takes a description of a single operation being executed on the
  // ParallelDevice, and in turn runs one operation per component device with
  // its corresponding inputs from the input ParallelTensors (or
  // implicitly-mirrored tensors on other devices). Wraps the resulting
  // per-device and per-output TFE_TensorHandles into one ParallelTensor per
  // output of the original operation.
  //
  // `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or
  // un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput
  // requires non-parallel tensors, and TPUReplicatedOutput requires a parallel
  // tensor, but other operations will implicitly broadcast non-parallel input
  // tensors across the ParallelDevice's component devices.
  //
  // Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput,
  // pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput
  // causes `Execute` to return non-parallel tensors.
  //
  // Attributes are forwarded to executed operations unmodified.
  //
  // The returned optional has a value if and only if `status` evaluates to
  // TF_OK.
  absl::optional<std::vector<MaybeParallelTensorOwned>> Execute(
      TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
      const char* operation_name, const TFE_OpAttrs* attributes,
      int expected_max_outputs, TF_Status* status) const;

  // Implements the parallel case for `Execute`, where all of the outputs of the
  // operation are ParallelTensors, and all inputs are either ParallelTensors or
  // should be implicitly broadcast. This means the operation is not
  // TPUReplicatedInput or TPUReplicatedOutput.
  //
  // The returned optional has a value if and only if `status` evaluates to
  // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
  // if sanity checks on dtypes/metadata fail.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
  ExecuteParallelOperation(TFE_Context* context,
                           std::vector<MaybeParallelTensorUnowned> inputs,
                           const char* operation_name,
                           const TFE_OpAttrs* attributes,
                           int expected_max_outputs, TF_Status* status) const;

  // The fully-qualified name of this parallel device.
  const std::string& device_name() const { return device_name_; }

 private:
  // The name of the parallel device
  // (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0")
  const std::string device_name_;
  // A sequence of device names, indicating which devices replicated operations
  // are forwarded to.
  const std::vector<std::string> underlying_devices_;
  // A sequence of TFE_Executors, one per device, for executing operations in
  // parallel.
  const std::vector<ExecutorPtr> executors_;
};
|
||||
|
||||
// The internal representation of a TFE_TensorHandle placed on a
// ParallelDevice. Contains a tuple of tensors, one on each of the
// `underlying_devices_` of the ParallelDevice.
class ParallelTensor {
 public:
  // Construct a ParallelTensor from TensorHandles placed on the component
  // devices of a ParallelDevice.
  static std::unique_ptr<ParallelTensor> FromTensorHandles(
      const ParallelDevice& parallel_device,
      std::vector<TensorHandlePtr> components, TF_Status* status);

  // Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it.
  static TensorHandlePtr AsTensorHandle(TFE_Context* context,
                                        std::unique_ptr<ParallelTensor> t,
                                        TF_Status* status);

  // The number of component tensors (one per underlying device).
  size_t num_tensors() const { return tensors_.size(); }
  // Borrows the component tensor at `index`; this ParallelTensor retains
  // ownership of the handle.
  TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }

 private:
  ParallelTensor(const ParallelDevice& device,
                 std::vector<TensorHandlePtr> tensors,
                 std::vector<int64_t> shape, const TF_DataType dtype)
      : device_(device),
        tensors_(std::move(tensors)),
        shape_(std::move(shape)),
        dtype_(dtype) {}

  // NOTE(review): stored by reference — the ParallelDevice is assumed to
  // outlive this tensor.
  const ParallelDevice& device_;
  // One component tensor per underlying device, in device order.
  const std::vector<TensorHandlePtr> tensors_;
  // Shape shared by all components (FromTensorHandles enforces the match).
  const std::vector<int64_t> shape_;
  // Dtype shared by all components (FromTensorHandles enforces the match).
  const TF_DataType dtype_;
};
|
||||
|
||||
// Creates one async executor per underlying device so component operations can
// run in parallel.
ParallelDevice::ParallelDevice(const std::string& name,
                               const std::vector<std::string>& devices)
    : device_name_(name),
      underlying_devices_(devices),
      executors_(MakeExecutors(underlying_devices_.size())) {}
|
||||
|
||||
// Copies `tensor` onto each component device, producing a ParallelTensor.
// Sets a bad status and returns nullptr if `tensor` is already on this device
// or any per-device copy fails.
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
    TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
  const char* current_device = TFE_TensorHandleDeviceName(tensor, status);
  // `current_device` is unspecified if the name lookup failed; check `status`
  // before comparing against it (previously the comparison ran unconditionally).
  if (TF_GetCode(status) != TF_OK) return nullptr;
  if (device_name_ == current_device) {
    std::string message(absl::StrCat(
        "Tried to copy a TensorHandle to its existing device: ", device_name_));
    TF_SetStatus(status, TF_INTERNAL, message.c_str());
    return nullptr;
  }
  std::vector<TensorHandlePtr> components;
  components.reserve(underlying_devices_.size());
  for (const std::string& underlying_device_name : underlying_devices_) {
    TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
        tensor, context, underlying_device_name.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    components.emplace_back(t);
  }
  return ParallelTensor::FromTensorHandles(*this, std::move(components),
                                           status);
}
|
||||
|
||||
// Runs one operation per component device (or one of the two special pack/
// un-pack operations); see the declaration for the full contract. The returned
// optional has a value iff `status` is TF_OK on return.
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
    TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
    const char* operation_name, const TFE_OpAttrs* attributes,
    int expected_max_outputs, TF_Status* status) const {
  absl::optional<std::vector<MaybeParallelTensorOwned>> result;
  // TODO(allenl): We should remove "TPU" from these op names at the very least,
  // or consider other ways of packing/unpacking parallel tensors.
  if (operation_name == std::string("TPUReplicatedInput")) {
    // Special-cased operation for packing per-device tensors into one parallel
    // tensor.
    if (inputs.size() != underlying_devices_.size()) {
      std::string message(absl::StrCat(
          "The parallel device ", device_name_, " expected ",
          underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ",
          inputs.size()));
      TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
      return result;
    }
    std::vector<TensorHandlePtr> components;
    components.reserve(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (absl::holds_alternative<ParallelTensor*>(inputs[i])) {
        std::string message(absl::StrCat(
            "Expected all inputs to TPUReplicatedInput to be non-parallel "
            "TensorHandles. The input ",
            i,
            " was a parallel tensor (already "
            "placed on the parallel device)."));
        TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
        return result;
      }
      components.emplace_back(TFE_TensorHandleCopySharingTensor(
          absl::get<TFE_TensorHandle*>(inputs[i]), status));
      // Bail out eagerly: FromTensorHandles dereferences the components, so a
      // null handle from a failed copy would crash instead of reporting the
      // error through `status`.
      if (TF_GetCode(status) != TF_OK) return result;
    }
    std::vector<MaybeParallelTensorOwned> result_content;
    result_content.reserve(1);
    result_content.push_back(ParallelTensor::FromTensorHandles(
        *this, std::move(components), status));
    if (TF_GetCode(status) != TF_OK) return result;
    result.emplace(std::move(result_content));
    return result;
  } else if (operation_name == std::string("TPUReplicatedOutput")) {
    // Special-cased operation for un-packing one parallel tensor into
    // per-device tensors.
    OpPtr op(TFE_NewOp(context, operation_name, status));
    // `op` may be null if creation failed; check before dereferencing it.
    if (TF_GetCode(status) != TF_OK) return result;
    TFE_OpAddAttrs(op.get(), attributes);
    int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
    if (TF_GetCode(status) != TF_OK) return result;
    if (static_cast<size_t>(expected_outputs) != underlying_devices_.size()) {
      std::string message(absl::StrCat(
          "The parallel device ", device_name_, " expected ",
          underlying_devices_.size(),
          " outputs for TPUReplicatedOutput, but got ", expected_outputs));
      TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
      return result;
    }
    if (absl::holds_alternative<TFE_TensorHandle*>(inputs[0])) {
      TF_SetStatus(status, TF_INVALID_ARGUMENT,
                   "Expected the input to "
                   "TPUReplicatedOutput to be a parallel tensor (placed on the "
                   "parallel device).");
      return result;
    }
    ParallelTensor* t = absl::get<ParallelTensor*>(inputs[0]);
    std::vector<MaybeParallelTensorOwned> outputs;
    outputs.reserve(t->num_tensors());
    for (size_t i = 0; i < t->num_tensors(); ++i) {
      TensorHandlePtr this_output(
          TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
      outputs.emplace_back(std::move(this_output));
      if (TF_GetCode(status) != TF_OK) return result;
    }
    result.emplace(std::move(outputs));
    return result;
  }
  // General case: run the operation in parallel on every component device.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
      maybe_parallel_results(
          ExecuteParallelOperation(context, std::move(inputs), operation_name,
                                   attributes, expected_max_outputs, status));
  if (!maybe_parallel_results.has_value()) return result;
  std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
      std::move(maybe_parallel_results.value()));
  std::vector<MaybeParallelTensorOwned> result_content;
  result_content.reserve(parallel_results.size());
  for (std::unique_ptr<ParallelTensor>& parallel_result : parallel_results) {
    result_content.push_back(
        MaybeParallelTensorOwned(std::move(parallel_result)));
  }
  result.emplace(std::move(result_content));
  return result;
}
|
||||
|
||||
// Runs `operation_name` once per component device, splitting parallel inputs
// by device and broadcasting non-parallel inputs, then packs the per-device
// outputs into ParallelTensors. See the declaration for the full contract.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::ExecuteParallelOperation(
    TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
    const char* operation_name, const TFE_OpAttrs* attributes,
    int expected_max_outputs, TF_Status* status) const {
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
  // Compute per-device per-output tensors
  std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
  per_device_output_tensors.reserve(underlying_devices_.size());
  // TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
  // setting the thread-local executor like this.
  TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
  auto reset_executor = gtl::MakeCleanup([context, previous_executor]() {
    TFE_ContextSetExecutorForThread(context, previous_executor);
    TFE_DeleteExecutor(previous_executor);
  });
  // Initialized to 0 so the packing loop below is a no-op (rather than reading
  // an uninitialized value) when `underlying_devices_` is empty.
  int first_op_output_count = 0;
  for (size_t device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    TFE_Executor* executor = executors_[device_index].get();
    // Note that the `reset_executor` cleanup sets the thread's executor back to
    // the value before this function ran.
    TFE_ContextSetExecutorForThread(context, executor);
    OpPtr op(TFE_NewOp(context, operation_name, status));
    if (TF_GetCode(status) != TF_OK) return result;
    TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
                    status);
    // Previously unchecked: a placement failure would silently continue.
    if (TF_GetCode(status) != TF_OK) return result;
    TFE_OpAddAttrs(op.get(), attributes);
    for (size_t input_index = 0; input_index < inputs.size(); ++input_index) {
      if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
        // Non-parallel tensors are implicitly broadcast, i.e. set as the input
        // to each parallel operation.
        //
        // TODO(allenl): There may be smarter ways to do this copy in some
        // cases, i.e. with a collective broadcast. We'll need to be careful
        // about things that are taken as inputs on the host or on their
        // existing device (for multi-device functions).
        TFE_OpAddInput(op.get(),
                       absl::get<TFE_TensorHandle*>(inputs[input_index]),
                       status);
        if (TF_GetCode(status) != TF_OK) return result;
      } else {
        // Parallel tensors are divided between operations by device.
        TFE_OpAddInput(op.get(),
                       absl::get<ParallelTensor*>(inputs[input_index])
                           ->tensor(device_index),
                       status);
        if (TF_GetCode(status) != TF_OK) return result;
      }
    }
    std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
    int real_num_outputs = expected_max_outputs;
    // For nested devices, the inner device sees the async executor we've
    // set. Inner parallel devices will just overwrite this with their own and
    // then set it back to ours before returning. This means parallel devices
    // which consist of several aliased parallel devices would hypothetically
    // deadlock if the outer parallel device ran one collective with a group
    // size equal to the total number of aliased physical devices. Currently
    // physical devices cannot participate in a single collective reduction
    // multiple times, so this would fail earlier.
    //
    // TODO(allenl): Keep a map from outer executor to list of inner executors
    // rather than a single list of executors so aliased nested parallel devices
    // don't re-use an executor.
    TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
    // Check `status` before using `real_num_outputs`: if execution failed the
    // output count is meaningless, and comparing it first could overwrite the
    // real error with a generic TF_INTERNAL message.
    if (TF_GetCode(status) != TF_OK) return result;
    if (device_index == 0) {
      first_op_output_count = real_num_outputs;
    } else if (real_num_outputs != first_op_output_count) {
      TF_SetStatus(status, TF_INTERNAL,
                   "Parallel ops produced different numbers of tensors.");
      return result;
    }
    std::vector<TensorHandlePtr> this_outputs;
    this_outputs.reserve(real_num_outputs);
    for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
      this_outputs.emplace_back(op_outputs[output_num]);
    }
    per_device_output_tensors.push_back(std::move(this_outputs));
  }
  // For each output of the original operation, pack the per-device
  // TensorHandles we've computed into a single parallel TensorHandle.
  std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
  per_device_outputs.reserve(first_op_output_count);
  for (int i = 0; i < first_op_output_count; ++i) {
    std::vector<TensorHandlePtr> components;
    components.reserve(underlying_devices_.size());
    for (size_t j = 0; j < underlying_devices_.size(); ++j) {
      components.push_back(std::move(per_device_output_tensors[j][i]));
    }
    per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
        *this, std::move(components), status));
    if (TF_GetCode(status) != TF_OK) return result;
  }
  result.emplace(std::move(per_device_outputs));
  return result;
}
|
||||
|
||||
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
|
||||
const ParallelDevice& parallel_device,
|
||||
std::vector<TensorHandlePtr> components, TF_Status* status) {
|
||||
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
|
||||
std::vector<int64_t> shape(
|
||||
TFE_TensorHandleNumDims(components[0].get(), status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
|
||||
// Verify that the TensorHandle's shape and dtype match all of the component
|
||||
// shapes and dtypes.
|
||||
for (TensorHandlePtr& component : components) {
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
if (tensor_dim != shape[i]) {
|
||||
// TODO(allenl): Allow shapes to differ.
|
||||
TF_SetStatus(status, TF_UNIMPLEMENTED,
|
||||
"Components of a ParallelTensor must currently all have "
|
||||
"the same shape");
|
||||
return nullptr;
|
||||
}
|
||||
if (TFE_TensorHandleDataType(component.get()) != dtype) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Components of a ParallelTensor must all have "
|
||||
"the same dtype");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
|
||||
parallel_device, std::move(components), std::move(shape), dtype));
|
||||
}
|
||||
|
||||
// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
|
||||
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
|
||||
// reference counts drop to zero.
|
||||
void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
|
||||
delete reinterpret_cast<ParallelTensor*>(data);
|
||||
}
|
||||
|
||||
// Wraps an owned ParallelTensor in a TFE_TensorHandle on the parallel device.
TensorHandlePtr ParallelTensor::AsTensorHandle(
    TFE_Context* context, std::unique_ptr<ParallelTensor> t,
    TF_Status* status) {
  // The resulting TensorHandle owns an opaque pointer to "device memory", which
  // for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
  // deleted, it will call ParallelTensorDeallocator to free the struct.
  ParallelTensor* released = t.release();
  TFE_TensorHandle* wrapped = TFE_NewTensorHandleFromDeviceMemory(
      context, released->device_.device_name().c_str(), released->dtype_,
      released->shape_.data(), released->shape_.size(), released, 1,
      &ParallelTensorDeallocator, nullptr, status);
  return TensorHandlePtr(wrapped);
}
|
||||
|
||||
// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
|
||||
// registration.
|
||||
//
|
||||
// Replicates a single TFE_TensorHandle, producing a TFE_TensorHandle containing
|
||||
// a ParallelTensor with one copy of `tensor` for each device in the
|
||||
// ParallelDevice.
|
||||
//
|
||||
// Since this function is used to satisfy the TFE_CustomDevice C API,
|
||||
// device_info is passed in using a C-style generic. It must always be a
|
||||
// ParallelDevice.
|
||||
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status, void* device_info) {
|
||||
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
|
||||
std::unique_ptr<ParallelTensor> parallel_tensor(
|
||||
dev->CopyToParallelDevice(context, tensor, status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor),
|
||||
status)
|
||||
.release();
|
||||
}
|
||||
|
||||
// For TFE_CustomDevice::copy_tensor_from_device in the parallel device
|
||||
// registration.
|
||||
//
|
||||
// Currently this is an error, and un-packing ParallelTensors must be performed
|
||||
// explicitly by running a TPUReplicatedOutput operation on the parallel device.
|
||||
//
|
||||
// TODO(allenl): There are some use-cases that are only supported by copying to
|
||||
// host at the moment (e.g. debug print on a tensor, .numpy(), etc.). We either
|
||||
// need to return something here or address these use-cases one by one.
|
||||
TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
|
||||
TFE_TensorHandle* tensor,
|
||||
const char* target_device_name,
|
||||
TF_Status* status,
|
||||
void* device_info) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Trying to copy a tensor out of a parallel device.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// For TFE_CustomDevice::execute in the parallel device registration.
|
||||
//
|
||||
// Since this function is used to satisfy the TFE_CustomDevice C API,
|
||||
// device_info is passed in using a C-style generic. It must always be a
|
||||
// ParallelDevice.
|
||||
void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
|
||||
TFE_TensorHandle** inputs,
|
||||
const char* operation_name,
|
||||
const TFE_OpAttrs* attributes, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* status,
|
||||
void* device_info) {
|
||||
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
|
||||
std::vector<MaybeParallelTensorUnowned> typed_inputs;
|
||||
typed_inputs.reserve(num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
const char* tensor_handle_device =
|
||||
TFE_TensorHandleDeviceName(inputs[i], status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
if (dev->device_name() == tensor_handle_device) {
|
||||
// We assume that any tensors already placed on this device are
|
||||
// ParallelTensors.
|
||||
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
|
||||
TFE_TensorHandleDevicePointer(inputs[i], status)));
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
} else {
|
||||
typed_inputs.emplace_back(inputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
|
||||
dev->Execute(context, std::move(typed_inputs), operation_name, attributes,
|
||||
*num_outputs, status));
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
if (!maybe_typed_outputs.has_value()) {
|
||||
TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<MaybeParallelTensorOwned> typed_outputs(
|
||||
std::move(maybe_typed_outputs.value()));
|
||||
|
||||
if (typed_outputs.size() > *num_outputs) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"The allocated output buffer was too small.");
|
||||
return;
|
||||
}
|
||||
|
||||
for (int i = 0; i < typed_outputs.size(); ++i) {
|
||||
MaybeParallelTensorOwned typed_output(std::move(typed_outputs[i]));
|
||||
if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
|
||||
outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
|
||||
} else {
|
||||
outputs[i] = ParallelTensor::AsTensorHandle(
|
||||
context,
|
||||
std::move(absl::get<std::unique_ptr<ParallelTensor>>(
|
||||
typed_output)),
|
||||
status)
|
||||
.release();
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
}
|
||||
}
|
||||
*num_outputs = typed_outputs.size();
|
||||
}
|
||||
|
||||
// For TFE_CustomDevice::delete_device in the parallel device registration.
|
||||
//
|
||||
// Since this function is used to satisfy the TFE_CustomDevice C API,
|
||||
// device_info is passed in using a C-style generic. It must always be a
|
||||
// ParallelDevice.
|
||||
void DeleteParallelDevice(void* device_info) {
|
||||
delete reinterpret_cast<ParallelDevice*>(device_info);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void AllocateParallelDevice(const char* device_name,
|
||||
const char* const* underlying_devices,
|
||||
int num_underlying_devices,
|
||||
TFE_CustomDevice* device, void** device_info) {
|
||||
device->copy_tensor_to_device = &CopyToParallelDevice;
|
||||
device->copy_tensor_from_device = &CopyTensorFromParallelDevice;
|
||||
device->delete_device = &DeleteParallelDevice;
|
||||
device->execute = &ParallelDeviceExecute;
|
||||
std::vector<std::string> underlying_devices_vector;
|
||||
underlying_devices_vector.reserve(num_underlying_devices);
|
||||
for (int device_index = 0; device_index < num_underlying_devices;
|
||||
++device_index) {
|
||||
underlying_devices_vector.push_back(underlying_devices[device_index]);
|
||||
}
|
||||
*device_info = new ParallelDevice(device_name, underlying_devices_vector);
|
||||
}
|
||||
|
||||
} // namespace eager
|
||||
} // namespace tensorflow
|
65
tensorflow/c/eager/parallel_device/parallel_device.h
Normal file
65
tensorflow/c/eager/parallel_device/parallel_device.h
Normal file
@ -0,0 +1,65 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
|
||||
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace eager {
|
||||
|
||||
// Allocate a parallel device named `device_name` which forwards operations to
// `underlying_devices`, maintaining "parallel tensors" with components placed
// on each underlying device.
//
// For example if `device_name` is
//   "/job:localhost/replica:0/task:0/device:CUSTOM:0"
// and `underlying_devices` is
//   {"/job:localhost/replica:0/task:0/device:GPU:0",
//    "/job:localhost/replica:0/task:0/device:GPU:1"}
// Then executing an operation on CUSTOM:0 will execute it on GPU:0 and GPU:1.
//
// Implicit copies onto `device_name` are allowed, replicating the value once
// per device in `underlying_devices`. Implicit copies off of the device throw
// an error.
//
// All component tensors must have the same dtype. Currently they must also have
// the same shape, although this requirement may be relaxed in the future.
//
// `device_name` must not name an existing physical or custom device (see
// the documentation for TFE_RegisterCustomDevice for more information).
//
// Tensors may be copied on or off the device explicitly using
// TPUReplicatedInput and TPUReplicatedOutput respectively. For example, with
// two component devices, running `x = TPUReplicatedInput(inputs=[a, b])` on the
// parallel device creates a parallel tensor `x` with `a` on the first of
// `underlying_devices` and `b` on the second. Running `a_unpacked, b_unpacked =
// TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel tensor
// into its components.
//
// The filled `device` struct and the allocated `device_info` struct may be
// passed to TFE_RegisterCustomDevice. The `device_name` arguments must match.
// On success `*device_info` points to heap-allocated state which is released
// by the device's delete_device callback.
void AllocateParallelDevice(const char* device_name,
                            const char* const* underlying_devices,
                            int num_underlying_devices,
                            TFE_CustomDevice* device, void** device_info);
|
||||
|
||||
} // namespace eager
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
|
921
tensorflow/c/eager/parallel_device/parallel_device_test.cc
Normal file
921
tensorflow/c/eager/parallel_device/parallel_device_test.cc
Normal file
@ -0,0 +1,921 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
// NOTE(allenl): These tests currently go through TFE_Execute and so are
|
||||
// integration testing rather than purely testing the parallel device. They
|
||||
// correspond fairly well to the implementation, but testing the C++ directly is
|
||||
// another option.
|
||||
|
||||
// Functor for making unique_ptr to TFE_TensorHandle slightly more
|
||||
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
|
||||
// template argument requires passing a function pointer to
|
||||
// TFE_DeleteTensorHandle when constructing the unique_ptr.
|
||||
class TensorHandleDeleter {
|
||||
public:
|
||||
void operator()(TFE_TensorHandle* to_delete) {
|
||||
TFE_DeleteTensorHandle(to_delete);
|
||||
}
|
||||
};
|
||||
|
||||
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
|
||||
|
||||
// A helper for performing common operations on variables. A much more
// restricted stand-in for tf.Variable in Python.
class Variable {
 public:
  // Construct a Variable from a resource-dtype TFE_TensorHandle and an
  // indication of the dtype of the variable's value.
  //
  // Note that creating this resource-dtype handle can fail, so `Create` is a
  // separate static method which returns a status.
  Variable(TFE_TensorHandle* handle, TF_DataType type)
      : handle_(handle), type_(type) {}

  // Helper for constructing a resource handle and wrapping it in a `Variable`
  // object. Returns nullptr (with `status` set) on failure.
  static Variable* Create(TFE_Context* context, TF_DataType type,
                          const int64_t* dims, const int num_dims,
                          const char* device, TF_Status* status);
  // Dereferences the backing buffer for the variable. Note that since this can
  // fail (it runs operations), it must be called explicitly and the resulting
  // `status` checked.
  void Destroy(TFE_Context* context, TF_Status* status);

  // Reads from the variable.
  TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
  // Assigns a new value to the variable.
  void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
  // Adds `value` to the existing value of the variable.
  void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
                 TF_Status* status);

 private:
  // Helper for running any single-argument assignment ops (Assign, AssignAdd,
  // AssignSub, ...).
  void GeneralAssignment(const char* op_name, TFE_Context* context,
                         TFE_TensorHandle* value, TF_Status* status);

  // A handle for the resource-dtype tensor pointing to the variable's
  // buffer.
  TFE_TensorHandle* handle_;
  // The dtype of the variable's buffer (input dtype for assignments, output
  // dtype of read operations).
  TF_DataType type_;
};
|
||||
|
||||
// Runs VarHandleOp on `device` and wraps the resulting resource handle in a
// heap-allocated Variable. Returns nullptr (with `status` set) on failure.
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
                           const int64_t* dims, const int num_dims,
                           const char* device, TF_Status* status) {
  // Consistent `&TFE_DeleteOp` deleter spelling (other helpers in this file
  // take the function's address explicitly).
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "VarHandleOp", status), &TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrType(op.get(), "dtype", type);
  TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
  // Previously unchecked: a bad shape would continue into execution.
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrString(op.get(), "container", "", 0);
  // Use the special GUID for no buffer sharing
  //
  // TODO(allenl): Should we provide a better API for this? AFAIK this is the
  // only reasonable way to make variables with no aliasing using the eager C
  // API.
  std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
  TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
                      no_sharing.length());
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_TensorHandle* var_handle = nullptr;
  int num_retvals = 1;
  TFE_Execute(op.get(), &var_handle, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return new Variable(var_handle, type);
}
|
||||
|
||||
// Destroys the variable's backing buffer by running DestroyResourceOp on the
// device the handle lives on, then deletes the handle itself. Errors are
// reported through `status`.
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
  // Free the backing buffer for the variable.
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpAddInput(op.get(), handle_, status);
  if (TF_GetCode(status) != TF_OK) return;
  // Place the destroy op on the same device as the variable handle.
  const char* device = TFE_TensorHandleDeviceName(handle_, status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return;
  int num_retvals = 0;
  TFE_Execute(op.get(), nullptr, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return;
  // Delete the variable handle itself.
  // NOTE(review): on any early error return above, `handle_` is not deleted;
  // acceptable for a test helper but worth confirming intentional.
  TFE_DeleteTensorHandle(handle_);
}
|
||||
|
||||
// Reads the variable's current value by running ReadVariableOp on the device
// where the variable's handle lives. Returns nullptr and sets `status` on
// failure.
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> read_op(
      TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpAddInput(read_op.get(), handle_, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  // Place the read on the same device as the variable handle.
  const char* variable_device = TFE_TensorHandleDeviceName(handle_, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetDevice(read_op.get(), variable_device, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrType(read_op.get(), "dtype", type_);
  // Execute and hand ownership of the single output to the caller.
  TFE_TensorHandle* current_value = nullptr;
  int num_retvals = 1;
  TFE_Execute(read_op.get(), &current_value, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return TensorHandlePtr(current_value);
}
|
||||
|
||||
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
|
||||
TFE_TensorHandle* value, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetAttrType(op.get(), "dtype", type_);
|
||||
TFE_OpAddInput(op.get(), handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpAddInput(op.get(), value, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
const char* device = TFE_TensorHandleDeviceName(handle_, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op.get(), device, status);
|
||||
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op.get(), nullptr, &num_retvals, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
}
|
||||
|
||||
// Adds `value` to the variable's existing value via AssignAddVariableOp.
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
                         TF_Status* status) {
  GeneralAssignment("AssignAddVariableOp", context, value, status);
}
|
||||
|
||||
// Replaces the variable's value with `value` via AssignVariableOp.
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
                      TF_Status* status) {
  GeneralAssignment("AssignVariableOp", context, value, status);
}
|
||||
|
||||
// Passed to `TF_NewTensor` to indicate how an array of floats should be
// deleted: the buffer was allocated with new[], so release it the same way.
static void FloatDeallocator(void* data, size_t, void* arg) {
  float* buffer = static_cast<float*>(data);
  delete[] buffer;
}
|
||||
|
||||
// Creates a TFE_TensorHandle with value `v`.
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
  const int num_bytes = sizeof(float);
  // Heap-allocate the scalar; ownership transfers to the tensor, which frees
  // it through FloatDeallocator.
  float* values = new float[1];
  values[0] = v;
  // Rank-0 tensor: no dims array, zero dimensions.
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
      TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
                   nullptr),
      TF_DeleteTensor);
  return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
|
||||
|
||||
// Creates a rank-one TFE_TensorHandle with value `v`.
|
||||
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
|
||||
TF_Status* status) {
|
||||
const int num_bytes = v.size() * sizeof(float);
|
||||
float* values = new float[v.size()];
|
||||
memcpy(values, v.data(), num_bytes);
|
||||
int64_t dims = v.size();
|
||||
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
|
||||
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
|
||||
&FloatDeallocator, nullptr),
|
||||
TF_DeleteTensor);
|
||||
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
|
||||
}
|
||||
|
||||
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
// On success, ownership of each component handle is transferred into
// `components`; on failure `status` is set and `components` is untouched.
template <std::size_t num_replicas>
void ExtractPerDeviceValues(
    TFE_Context* context, TFE_TensorHandle* input,
    std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
  TFE_OpAddInput(op.get(), input, status);
  if (TF_GetCode(status) != TF_OK) return;
  // Run the unpack op on the (parallel) device where `input` lives.
  const char* device = TFE_TensorHandleDeviceName(input, status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return;

  // `num_replicas` is a compile-time constant, so this is a fixed-size array.
  TFE_TensorHandle* result_handles[num_replicas];
  int num_retvals = num_replicas;
  TFE_Execute(op.get(), result_handles, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return;
  // Hand ownership of each raw output handle to the caller's array.
  for (int i = 0; i < num_replicas; ++i) {
    (*components)[i].reset(result_handles[i]);
  }
}
|
||||
|
||||
// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
// Returns nullptr and sets `status` on failure.
template <std::size_t num_replicas>
TensorHandlePtr CreatePerDeviceValues(
    TFE_Context* context,
    const std::array<TFE_TensorHandle*, num_replicas>& components,
    const char* device, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrInt(op.get(), "N", num_replicas);
  // Feed one component per replica, in order.
  for (int i = 0; i < num_replicas; ++i) {
    TFE_OpAddInput(op.get(), components[i], status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  // The pack op produces a single parallel handle.
  TFE_TensorHandle* result_handle;
  int num_retvals = 1;
  TFE_Execute(op.get(), &result_handle, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return TensorHandlePtr(result_handle);
}
|
||||
|
||||
// Computes `first` * `second` with the eager "Mul" op, placed on `first`'s
// device. Returns nullptr and sets `status` on failure.
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
                         TFE_TensorHandle* second, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "Mul", status), &TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpAddInput(op.get(), first, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpAddInput(op.get(), second, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  const char* first_device = TFE_TensorHandleDeviceName(first, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetDevice(op.get(), first_device, status);
  // Check device placement before executing (previously missing), matching
  // the error handling of the other helpers in this file.
  if (TF_GetCode(status) != TF_OK) return nullptr;

  TFE_TensorHandle* result_handle;
  int num_retvals = 1;
  TFE_Execute(op.get(), &result_handle, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return TensorHandlePtr(result_handle);
}
|
||||
|
||||
// Assert that `handle` is equal to `expected_value`.
void AssertScalarFloatEq(TFE_TensorHandle* handle, float expected_value) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  // Resolve the handle to a concrete host tensor before comparing.
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
      TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  // Compare against the first (scalar) element of the tensor's buffer.
  ASSERT_EQ(expected_value,
            *static_cast<float*>(TF_TensorData(value_zero.get())));
}
|
||||
|
||||
// Registers a parallel custom device named `device_name` backed by
// `underlying_devices`. Errors are reported through `status`.
template <std::size_t num_devices>
void RegisterParallelDevice(
    TFE_Context* context, const char* device_name,
    const std::array<const char*, num_devices>& underlying_devices,
    TF_Status* status) {
  TFE_CustomDevice device;
  void* device_info;
  // Allocate the parallel device implementation, then hand it (and its
  // opaque per-device state) to the eager runtime.
  tensorflow::eager::AllocateParallelDevice(
      device_name, underlying_devices.data(), underlying_devices.size(),
      &device, &device_info);
  TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
}
|
||||
|
||||
// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`. Exercises register/create/assign/read/
// assign-add and verifies per-component values and placements.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
                             const char* second_device) {
  // Register the custom device
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{first_device, second_device};
  RegisterParallelDevice(context, device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a variable handle (uninitialized to start) placed on the parallel
  // device.
  // The deleter destroys the backing buffer before freeing the Variable.
  std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
    to_delete->Destroy(context, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    delete to_delete;
  };
  std::unique_ptr<Variable, decltype(variable_deleter)> variable(
      Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
                       status.get()),
      variable_deleter);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Assign an initial value to the variable, implicitly mirroring it to each
  // component device.
  {
    TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    variable->Assign(context, initial_value.get(), status.get());
  }

  // Read from the variable and verify that we have a parallel tensor.
  {
    TensorHandlePtr read = variable->Read(context, status.get());
    std::array<TensorHandlePtr, 2> components;
    ExtractPerDeviceValues(context, read.get(), &components, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    // Both mirrors hold the assigned value.
    AssertScalarFloatEq(components[0].get(), 20.);
    AssertScalarFloatEq(components[1].get(), 20.);

    // Each component should live on its underlying device.
    std::string first_device =
        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
    ASSERT_EQ(underlying_devices[0], first_device);
    std::string second_device =
        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
    ASSERT_EQ(underlying_devices[1], second_device);
  }

  // Add a parallel tensor with different values on each device to the variable.
  {
    TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
    TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
    std::array<TFE_TensorHandle*, 2> components{value_one.get(),
                                                value_two.get()};
    TensorHandlePtr combined_value =
        CreatePerDeviceValues(context, components, device_name, status.get());
    variable->AssignAdd(context, combined_value.get(), status.get());
  }

  // Read the variable and verify that each component has the right modified
  // value.
  {
    TensorHandlePtr read = variable->Read(context, status.get());
    std::array<TensorHandlePtr, 2> components;
    ExtractPerDeviceValues(context, read.get(), &components, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    // 20 + 3 and 20 + (-2), per component.
    AssertScalarFloatEq(components[0].get(), 23.);
    AssertScalarFloatEq(components[1].get(), 18.);

    std::string first_device =
        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
    ASSERT_EQ(underlying_devices[0], first_device);
    std::string second_device =
        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
    ASSERT_EQ(underlying_devices[1], second_device);
  }
}
|
||||
|
||||
// Runs the shared variable tests with two distinct virtual CPU devices.
TEST(PARALLEL_DEVICE, TestBasicCPU) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  // Configure the context with two virtual CPU devices.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  BasicTestsForTwoDevices(context.get(),
                          "/job:localhost/replica:0/task:0/device:CPU:0",
                          "/job:localhost/replica:0/task:0/device:CPU:1");
}
|
||||
|
||||
// Runs the shared variable tests with both components aliased to the same
// CPU device.
TEST(PARALLEL_DEVICE, TestBasicCPUAliased) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  BasicTestsForTwoDevices(context.get(),
                          "/job:localhost/replica:0/task:0/device:CPU:0",
                          "/job:localhost/replica:0/task:0/device:CPU:0");
}
|
||||
|
||||
// Runs the shared variable tests with both components aliased to the same
// TPU device, when a TPU is present.
TEST(PARALLEL_DEVICE, TestBasicTPUAliased) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Skip the test if no TPU is available.
  std::unique_ptr<TF_DeviceList, decltype(&TF_DeleteDeviceList)> devices(
      TFE_ContextListDevices(context.get(), status.get()), TF_DeleteDeviceList);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool has_tpu = false;
  for (int device_index = 0; device_index < TF_DeviceListCount(devices.get());
       ++device_index) {
    std::string device_type =
        TF_DeviceListType(devices.get(), device_index, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    if (device_type == "TPU") {
      has_tpu = true;
      break;
    }
  }
  if (has_tpu) {
    BasicTestsForTwoDevices(context.get(),
                            "/job:localhost/replica:0/task:0/device:TPU:0",
                            "/job:localhost/replica:0/task:0/device:TPU:0");
  }
}
|
||||
|
||||
// Copies onto a parallel device implicitly mirror; copies off of one must be
// explicit and are expected to fail with TF_INTERNAL.
TEST(PARALLEL_DEVICE, TestExplicitCopies) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  // Configure the context with two virtual CPU devices.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  const char* first_device_name =
      "/job:localhost/replica:0/task:0/device:CPU:0";
  const char* second_device_name =
      "/job:localhost/replica:0/task:0/device:CPU:1";
  std::array<const char*, 2> underlying_devices{first_device_name,
                                                second_device_name};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Copying on to a parallel device is OK.
  TensorHandlePtr device_value(TFE_TensorHandleCopyToDevice(
      cpu_value.get(), context.get(), device_name, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  const char* backing_device =
      TFE_TensorHandleBackingDeviceName(device_value.get(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_EQ(std::string(device_name), backing_device);

  // Un-pack the parallel tensor to verify that the copy was successful.
  {
    std::array<TensorHandlePtr, 2> components;
    ExtractPerDeviceValues(context.get(), device_value.get(), &components,
                           status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    // The value of the original tensor is replicated on each device.
    AssertScalarFloatEq(components[0].get(), 3.);
    AssertScalarFloatEq(components[1].get(), 3.);

    // Verify that the mirrors are placed on the component devices.
    std::string first_device =
        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
    ASSERT_EQ(underlying_devices[0], first_device);
    std::string second_device =
        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
    ASSERT_EQ(underlying_devices[1], second_device);
  }

  // Copies off of parallel devices must be explicit.
  TensorHandlePtr copy_back(TFE_TensorHandleCopyToDevice(
      device_value.get(), context.get(), first_device_name, status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_INTERNAL);
}
|
||||
|
||||
// Packing components with mismatched shapes into one parallel tensor is
// expected to fail with TF_UNIMPLEMENTED.
TEST(PARALLEL_DEVICE, TestDifferentShapes) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  // Configure the context with two virtual CPU devices.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create two vectors with different lengths
  std::vector<float> size_two_value{1., 2.};
  std::vector<float> size_three_value{1., 2., 3.};
  TensorHandlePtr size_two(
      VectorFloatTensorHandle(size_two_value, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr size_three(
      VectorFloatTensorHandle(size_three_value, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Try to combine these values into a single parallel tensor.
  std::array<TFE_TensorHandle*, 2> components{size_two.get(), size_three.get()};
  TensorHandlePtr combined_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_UNIMPLEMENTED)
      << TF_Message(status.get());
}
|
||||
|
||||
// Nests one parallel device inside another and checks that an elementwise op
// applies recursively: second_device{first_device{1*3, 2*3}, 3*3}.
TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  // Configure the context with three virtual CPU devices.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          3),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a parallel device with two CPUs
  const char* first_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> first_underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), first_device_name,
                         first_underlying_devices, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a second parallel device with the first parallel device and one
  // additional CPU.
  const char* second_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:1";
  std::array<const char*, 2> second_underlying_devices{
      "/job:localhost/replica:0/task:0/device:CUSTOM:0",
      "/job:localhost/replica:0/task:0/device:CPU:2"};
  RegisterParallelDevice(context.get(), second_device_name,
                         second_underlying_devices, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a tensor on the first parallel device
  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
  TensorHandlePtr first_combined_value = CreatePerDeviceValues(
      context.get(), components, first_device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Nest the first parallel tensor into a second
  TensorHandlePtr value_three(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  components[0] = first_combined_value.get();
  components[1] = value_three.get();
  TensorHandlePtr second_combined_value = CreatePerDeviceValues(
      context.get(), components, second_device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Multiply the nested parallel tensor by a scalar three. (This local was
  // previously named `negative_one` despite holding 3.; renamed to match its
  // value and the expected results below.)
  TensorHandlePtr three(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr multiply_result(Multiply(context.get(),
                                           second_combined_value.get(),
                                           three.get(), status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Un-pack the parallel tensor to verify that the operation was
  // successful. The resulting structure should be:
  // second_device{first_device{1. * 3., 2. * 3.}, 3. * 3.}.
  std::array<TensorHandlePtr, 2> second_components;
  ExtractPerDeviceValues(context.get(), multiply_result.get(),
                         &second_components, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  AssertScalarFloatEq(second_components[1].get(), 9.);

  // Verify that the mirrors are placed on the component devices.
  std::string first_device = TFE_TensorHandleBackingDeviceName(
      second_components[0].get(), status.get());
  ASSERT_EQ(second_underlying_devices[0], first_device);
  std::string second_device = TFE_TensorHandleBackingDeviceName(
      second_components[1].get(), status.get());
  ASSERT_EQ(second_underlying_devices[1], second_device);

  // Un-pack the first parallel device's tensor too
  std::array<TensorHandlePtr, 2> first_components;
  ExtractPerDeviceValues(context.get(), second_components[0].get(),
                         &first_components, status.get());
  AssertScalarFloatEq(first_components[0].get(), 3.);
  AssertScalarFloatEq(first_components[1].get(), 6.);

  first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
                                                   status.get());
  ASSERT_EQ(first_underlying_devices[0], first_device);
  second_device = TFE_TensorHandleBackingDeviceName(first_components[1].get(),
                                                    status.get());
  ASSERT_EQ(first_underlying_devices[1], second_device);
}
|
||||
|
||||
// Mismatched component counts and nested packing through the raw
// TPUReplicatedInput/Output ops should fail with TF_INVALID_ARGUMENT.
TEST(PARALLEL_DEVICE, TestInvalidPacking) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  // Single-component parallel device.
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 1> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  {
    // Try to pack two TensorHandles onto a parallel device with a single
    // component.
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    std::array<TFE_TensorHandle*, 2> components{value_one.get(),
                                                value_two.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), components, device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to extract the wrong number of components from a parallel tensor
    std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), correct_components, device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    std::array<TensorHandlePtr, 2> incorrect_components;
    ExtractPerDeviceValues(context.get(), combined_value.get(),
                           &incorrect_components, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to pass a ParallelTensor to TPUReplicatedInput
    std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), correct_components, device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    std::array<TFE_TensorHandle*, 1> incorrect_components{combined_value.get()};
    TensorHandlePtr recombined_value = CreatePerDeviceValues(
        context.get(), incorrect_components, device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to pass a non-parallel tensor to TPUReplicatedOutput
    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
        TFE_NewOp(context.get(), "TPUReplicatedOutput", status.get()),
        TFE_DeleteOp);
    if (TF_GetCode(status.get()) != TF_OK) return;
    TFE_OpSetAttrInt(op.get(), "num_replicas", 1);
    TFE_OpAddInput(op.get(), value_one.get(), status.get());
    if (TF_GetCode(status.get()) != TF_OK) return;
    TFE_OpSetDevice(op.get(), device_name, status.get());
    if (TF_GetCode(status.get()) != TF_OK) return;

    TFE_TensorHandle* result_handles;
    int num_retvals = 1;
    TFE_Execute(op.get(), &result_handles, &num_retvals, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }
}
|
||||
|
||||
// Runs a CollectiveReduce (sum across `group_size` participants) on `input`,
// placed on `input`'s device. Returns nullptr and sets `status` on failure.
TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
                              int group_size, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "CollectiveReduce", status), TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  const char* device = TFE_TensorHandleDeviceName(input, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrType(op.get(), "T", TFE_TensorHandleDataType(input));
  TFE_OpSetAttrInt(op.get(), "group_size", group_size);
  // group_key/instance_key 0: all callers in this test share one collective.
  TFE_OpSetAttrInt(op.get(), "group_key", 0);
  TFE_OpSetAttrInt(op.get(), "instance_key", 0);
  // Sum the contributions ("Add") and return them unchanged ("Id").
  const std::string merge_op("Add");
  TFE_OpSetAttrString(op.get(), "merge_op", merge_op.c_str(),
                      merge_op.length());
  const std::string final_op("Id");
  TFE_OpSetAttrString(op.get(), "final_op", final_op.c_str(),
                      final_op.length());
  TFE_OpSetAttrIntList(op.get(), "subdiv_offsets", nullptr, 0);

  TFE_OpAddInput(op.get(), input, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  TFE_TensorHandle* result_handle;
  int num_retvals = 1;
  TFE_Execute(op.get(), &result_handle, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return TensorHandlePtr(result_handle);
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestCollective) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
|
||||
TF_CreateConfig(
|
||||
/*xla*/ false,
|
||||
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
|
||||
2),
|
||||
TF_DeleteBuffer);
|
||||
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
|
||||
status.get());
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::array<const char*, 2> underlying_devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1"};
|
||||
RegisterParallelDevice(context.get(), device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Create a tensor on the parallel device
|
||||
TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
|
||||
TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
|
||||
TensorHandlePtr parallel_value = CreatePerDeviceValues(
|
||||
context.get(), components, device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
// Run a collective sum, so each component should now be the same.
|
||||
TensorHandlePtr reduced(
|
||||
CollectiveSum(context.get(), parallel_value.get(), 2, status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
std::array<TensorHandlePtr, 2> result_components;
|
||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
AssertScalarFloatEq(result_components[0].get(), 3.);
|
||||
AssertScalarFloatEq(result_components[1].get(), 3.);
|
||||
}
|
||||
|
||||
void RegisterCollectiveMulFunction(TFE_Context* context,
|
||||
const char* function_name, int group_size,
|
||||
TF_Status* status) {
|
||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> body(TF_NewGraph(),
|
||||
TF_DeleteGraph);
|
||||
TF_OperationDescription* placeholder_desc =
|
||||
TF_NewOperation(body.get(), "Placeholder", "Placeholder");
|
||||
TF_SetAttrType(placeholder_desc, "dtype", TF_FLOAT);
|
||||
TF_Operation* placeholder_op = TF_FinishOperation(placeholder_desc, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TF_Output x{placeholder_op, 0};
|
||||
|
||||
TF_OperationDescription* reduce_desc =
|
||||
TF_NewOperation(body.get(), "CollectiveReduce", "CollectiveReduce");
|
||||
TF_SetAttrType(reduce_desc, "T", TF_FLOAT);
|
||||
TF_SetAttrInt(reduce_desc, "group_size", group_size);
|
||||
TF_SetAttrInt(reduce_desc, "group_key", 0);
|
||||
TF_SetAttrInt(reduce_desc, "instance_key", 0);
|
||||
|
||||
const std::string merge_op("Mul");
|
||||
TF_SetAttrString(reduce_desc, "merge_op", merge_op.c_str(),
|
||||
merge_op.length());
|
||||
const std::string final_op("Id");
|
||||
TF_SetAttrString(reduce_desc, "final_op", final_op.c_str(),
|
||||
final_op.length());
|
||||
TF_SetAttrIntList(reduce_desc, "subdiv_offsets", nullptr, 0);
|
||||
TF_AddInput(reduce_desc, x);
|
||||
TF_Operation* reduce_op = TF_FinishOperation(reduce_desc, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TF_Operation* operations[]{placeholder_op, reduce_op};
|
||||
TF_Output y{reduce_op, 0};
|
||||
const char* output_name = "y";
|
||||
std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> function(
|
||||
TF_GraphToFunction(
|
||||
/* fn_body */ body.get(), /* fn_name */ function_name,
|
||||
/* append_hash_to_fn_name */ 0, /* num_opers */ 2,
|
||||
/* opers */ operations, /* ninputs */ 1, /* inputs */ &x,
|
||||
/* noutputs */ 1, /* outputs */ &y, /* output_names */ &output_name,
|
||||
/* opts */ nullptr, /* description */ "", /* status */ status),
|
||||
TF_DeleteFunction);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_ContextAddFunction(context, function.get(), status);
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestFunction) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
|
||||
TF_CreateConfig(
|
||||
/*xla*/ false,
|
||||
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
|
||||
2),
|
||||
TF_DeleteBuffer);
|
||||
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
|
||||
status.get());
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::array<const char*, 2> underlying_devices{
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0",
|
||||
"/job:localhost/replica:0/task:0/device:CPU:1"};
|
||||
RegisterParallelDevice(context.get(), device_name, underlying_devices,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
const char* function_name = "test_reduce_mul";
|
||||
RegisterCollectiveMulFunction(context.get(), function_name, 2, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TensorHandlePtr value_one(FloatTensorHandle(7., status.get()));
|
||||
TensorHandlePtr value_two(FloatTensorHandle(9., status.get()));
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
|
||||
TensorHandlePtr parallel_value = CreatePerDeviceValues(
|
||||
context.get(), components, device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
|
||||
TFE_NewOp(context.get(), function_name, status.get()), TFE_DeleteOp);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpSetDevice(op.get(), device_name, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TFE_OpAddInput(op.get(), parallel_value.get(), status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
|
||||
TFE_TensorHandle* raw_result_handle;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op.get(), &raw_result_handle, &num_retvals, status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
TensorHandlePtr reduced(raw_result_handle);
|
||||
|
||||
std::array<TensorHandlePtr, 2> result_components;
|
||||
ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
|
||||
status.get());
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
AssertScalarFloatEq(result_components[0].get(), 7. * 9.);
|
||||
AssertScalarFloatEq(result_components[1].get(), 7. * 9.);
|
||||
|
||||
std::string first_device = TFE_TensorHandleBackingDeviceName(
|
||||
result_components[0].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[0], first_device);
|
||||
std::string second_device = TFE_TensorHandleBackingDeviceName(
|
||||
result_components[1].get(), status.get());
|
||||
ASSERT_EQ(underlying_devices[1], second_device);
|
||||
}
|
@ -856,7 +856,7 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
|
||||
}
|
||||
VLOG(1) << "Final gradients size: "
|
||||
<< gradients.size() - used_gradient_ids.size();
|
||||
for (auto grad_pair : gradients) {
|
||||
for (const auto& grad_pair : gradients) {
|
||||
if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) {
|
||||
for (const auto& g : grad_pair.second) {
|
||||
vspace.DeleteGradient(g);
|
||||
|
@ -15,11 +15,9 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -34,75 +32,37 @@ namespace tensorflow {
|
||||
// is needed a static_cast can be applied.
|
||||
class AbstractTensorHandleInterface {
|
||||
public:
|
||||
virtual ~AbstractTensorHandleInterface() {}
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
// clients from directly destroying this object since it may manage it's own
|
||||
// lifetime through ref counting. Thus this must be allocated on the heap and
|
||||
// clients MUST call Release() in order to destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
// Check if the handle is in a valid initialized state.
|
||||
virtual bool IsValid(Status* status) const = 0;
|
||||
// Returns tensor dtype.
|
||||
virtual TF_DataType DataType() const = 0;
|
||||
virtual tensorflow::DataType DataType() const = 0;
|
||||
// Returns number of dimensions.
|
||||
virtual int NumDims(Status* status) const = 0;
|
||||
virtual Status NumDims(int* num_dims) const = 0;
|
||||
// Returns number of elements across all dimensions.
|
||||
virtual int64_t NumElements(Status* status) const = 0;
|
||||
virtual Status NumElements(int64* num_elements) const = 0;
|
||||
// Returns size of specified dimension
|
||||
virtual int64_t Dim(int dim_index, Status* status) const = 0;
|
||||
virtual Status Dim(int dim_index, int64* dim) const = 0;
|
||||
|
||||
// Returns the device which created the handle.
|
||||
virtual const char* DeviceName(Status* status) const = 0;
|
||||
// Returns the device where the tensor was placed.
|
||||
virtual const char* BackingDeviceName(Status* status) const = 0;
|
||||
// Returns a tensor for the handle. If tensor is remote, it will be copied.
|
||||
virtual TF_Tensor* Resolve(Status* status) = 0;
|
||||
// Returns debug information about the tensor.
|
||||
virtual TFE_TensorDebugInfo* TensorDebugInfo(Status* status) = 0;
|
||||
virtual AbstractTensorInterface* Resolve(Status* status) = 0;
|
||||
|
||||
// Return a copy of the handle.
|
||||
virtual AbstractTensorHandleInterface* Copy() = 0;
|
||||
|
||||
// Maintain mirror tensors for any implicit copies to local devices. This
|
||||
// setting is offered on a per tensor handle basis to avoid potential memory
|
||||
// over utilization due to holding on to mirrors as well as the original
|
||||
// tensor. Note this setting overrides the context mirroring policy whereby if
|
||||
// the mirroring policy is MIRRORING_NONE, we will still continue to mirror
|
||||
// this tensor.
|
||||
virtual void EnableImplicitMirroring() = 0;
|
||||
protected:
|
||||
virtual ~AbstractTensorHandleInterface() {}
|
||||
};
|
||||
|
||||
// TODO(gjn): Try to move these all to TensorHandle and make it implement
|
||||
// AbstractTensorHandleInterface. Currently, this is not so straightforward
|
||||
// because of various BUILD file dependencies.
|
||||
class TensorHandleInterface : public AbstractTensorHandleInterface {
|
||||
public:
|
||||
explicit TensorHandleInterface(TensorHandle* h) : handle_(h) {}
|
||||
~TensorHandleInterface() override;
|
||||
|
||||
bool IsValid(Status* status) const override;
|
||||
TF_DataType DataType() const override;
|
||||
int NumDims(Status* status) const override;
|
||||
int64_t NumElements(Status* status) const override;
|
||||
int64_t Dim(int dim_index, Status* status) const override;
|
||||
|
||||
const char* DeviceName(Status* status) const override;
|
||||
const char* BackingDeviceName(Status* status) const override;
|
||||
TF_Tensor* Resolve(Status* status) override;
|
||||
TFE_TensorDebugInfo* TensorDebugInfo(Status* status) override;
|
||||
|
||||
AbstractTensorHandleInterface* Copy() override;
|
||||
|
||||
void EnableImplicitMirroring() override;
|
||||
|
||||
// For runtime specific APIs, provide ability to get the underlying handle.
|
||||
TensorHandle* Handle() { return handle_; }
|
||||
|
||||
private:
|
||||
TensorHandle* handle_;
|
||||
};
|
||||
|
||||
inline TensorHandle* TensorHandleFromInterface(
|
||||
const std::unique_ptr<AbstractTensorHandleInterface>& handle) {
|
||||
return down_cast<TensorHandleInterface*>(handle.get())->Handle();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||
|
24
tensorflow/c/eager/tfe_cancellation_manager_internal.h
Normal file
24
tensorflow/c/eager/tfe_cancellation_manager_internal.h
Normal file
@ -0,0 +1,24 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
|
||||
struct TFE_CancellationManager {
|
||||
tensorflow::CancellationManager cancellation_manager;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
|
35
tensorflow/c/eager/tfe_context_internal.h
Normal file
35
tensorflow/c/eager/tfe_context_internal.h
Normal file
@ -0,0 +1,35 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
|
||||
// Wraps a pointer to a context implementation.
|
||||
//
|
||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||
// interface cannot destruct the underlying context object. Instead, call
|
||||
// TFE_DeleteContext who calls Release() on the context pointer and deletes
|
||||
// the TFE_Context structure.
|
||||
typedef struct TFE_Context TFE_Context;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
37
tensorflow/c/eager/tfe_executor_internal.h
Normal file
37
tensorflow/c/eager/tfe_executor_internal.h
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||
|
||||
struct TFE_Executor {
|
||||
explicit TFE_Executor(bool async)
|
||||
: owned_executor(new tensorflow::EagerExecutor(async)) {}
|
||||
|
||||
explicit TFE_Executor(tensorflow::EagerExecutor* executor)
|
||||
: owned_executor(nullptr), unowned_executor(executor) {}
|
||||
|
||||
tensorflow::EagerExecutor* executor() {
|
||||
return owned_executor == nullptr ? unowned_executor : owned_executor.get();
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
|
||||
tensorflow::EagerExecutor* unowned_executor;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
|
146
tensorflow/c/eager/tfe_monitoring_internal.h
Normal file
146
tensorflow/c/eager/tfe_monitoring_internal.h
Normal file
@ -0,0 +1,146 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
struct TFE_MonitoringCounterCell {
|
||||
tensorflow::monitoring::CounterCell cell;
|
||||
};
|
||||
|
||||
template <int NumLabels>
|
||||
struct TFE_MonitoringCounter {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringCounter(const char* name, const char* description,
|
||||
LabelDesc&&... label) {
|
||||
counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
|
||||
name, description, label...));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
|
||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringIntGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
|
||||
};
|
||||
struct TFE_MonitoringStringGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
|
||||
};
|
||||
struct TFE_MonitoringBoolGaugeCell {
|
||||
tensorflow::monitoring::GaugeCell<bool> cell;
|
||||
};
|
||||
|
||||
template <typename ValueType, int NumLabels>
|
||||
struct TFE_MonitoringGauge {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringGauge(const char* name, const char* description,
|
||||
LabelDesc&&... label) {
|
||||
gauge = absl::WrapUnique(
|
||||
tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
|
||||
name, description, label...));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
|
||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringBuckets {
|
||||
explicit TFE_MonitoringBuckets(
|
||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||
fn) {
|
||||
create_buckets = fn;
|
||||
}
|
||||
|
||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||
create_buckets;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringSamplerCell {
|
||||
tensorflow::monitoring::SamplerCell cell;
|
||||
};
|
||||
|
||||
template <int NumLabels>
|
||||
struct TFE_MonitoringSampler {
|
||||
template <typename... LabelDesc>
|
||||
TFE_MonitoringSampler(
|
||||
const char* name,
|
||||
std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
|
||||
const char* description, LabelDesc&&... label) {
|
||||
sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
|
||||
{name, description, label...}, std::move(buckets)));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
|
||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
|
51
tensorflow/c/eager/tfe_op_attrs_internal.h
Normal file
51
tensorflow/c/eager/tfe_op_attrs_internal.h
Normal file
@ -0,0 +1,51 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
|
||||
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
|
||||
// that sometimes do not require serialization.
|
||||
typedef struct TFE_Context TFE_Context;
|
||||
typedef struct TFE_Op TFE_Op;
|
||||
|
||||
struct TFE_OpAttrs {
|
||||
explicit TFE_OpAttrs() : attributes(nullptr) {}
|
||||
|
||||
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value)
|
||||
: attributes(value) {}
|
||||
|
||||
const tensorflow::AttrBuilder* attributes;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
// Set an AttrValue on the op. Doesn't handle the list types.
|
||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
const tensorflow::AttrValue& default_value,
|
||||
const char* attr_name, TF_Status* status);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
36
tensorflow/c/eager/tfe_op_internal.h
Normal file
36
tensorflow/c/eager/tfe_op_internal.h
Normal file
@ -0,0 +1,36 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
|
||||
// Wraps a pointer to an operation implementation.
|
||||
//
|
||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||
// interface cannot destruct the underlying operation object. Instead, call
|
||||
// TFE_DeleteOp who calls Release() on the operation pointer and deletes
|
||||
// the TFE_Op structure.
|
||||
typedef struct TFE_Op TFE_Op;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
30
tensorflow/c/eager/tfe_tensor_debug_info_internal.h
Normal file
30
tensorflow/c/eager/tfe_tensor_debug_info_internal.h
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
struct TFE_TensorDebugInfo {
|
||||
explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
|
||||
: dev_dims(dims) {}
|
||||
|
||||
// Fully-padded, minor-to-major.
|
||||
std::vector<tensorflow::int64> dev_dims;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_
|
38
tensorflow/c/eager/tfe_tensorhandle_internal.h
Normal file
38
tensorflow/c/eager/tfe_tensorhandle_internal.h
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
||||
#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
|
||||
// Wraps a pointer to a tensor handle implementation.
|
||||
//
|
||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||
// interface cannot destruct the underlying handle object. Instead, call
|
||||
// TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes
|
||||
// the TFE_TensorHandle structure.
|
||||
typedef struct TFE_TensorHandle TFE_TensorHandle;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface,
|
||||
TFE_TensorHandle);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*,
|
||||
TFE_TensorHandle*);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
@ -24,8 +24,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
using tensorflow::ServerFactory;
|
||||
|
||||
|
@ -22,8 +22,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/experimental/rendezvous.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -37,9 +37,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
@ -24,9 +24,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
66
tensorflow/c/experimental/saved_model/README.md
Normal file
66
tensorflow/c/experimental/saved_model/README.md
Normal file
@ -0,0 +1,66 @@
|
||||
# Tensorflow C SavedModel API
|
||||
|
||||
## Overview
|
||||
|
||||
These are the new experimental C SavedModel APIs for loading and running
|
||||
SavedModels in a TF2-idiomatic fashion. See
|
||||
[RFC 207](https://github.com/tensorflow/community/pull/207) for additional
|
||||
context.
|
||||
|
||||
The directory structure is as follows:
|
||||
|
||||
```none
|
||||
saved_model/
|
||||
|
||||
public/
|
||||
|
||||
internal/
|
||||
|
||||
core/
|
||||
|
||||
```
|
||||
|
||||
## saved_model/public
|
||||
|
||||
`saved_model/public` is intended to house *only the public headers* of the
|
||||
SavedModel C API.
|
||||
|
||||
These headers:
|
||||
|
||||
1. declare opaque C types (like `TF_SavedModel`),
|
||||
|
||||
2. declare the functions that operate on these types (like `TF_LoadSavedModel`).
|
||||
|
||||
Once they leave experimental, these APIs should be considered stable for use
|
||||
by external clients.
|
||||
|
||||
These headers are in a separate directory to make it obvious to clients which
|
||||
headers they should depend on, and which headers are implementation details.
|
||||
Separating these public headers by directory also allow future programmatic
|
||||
checks to ensure that TF public headers only `#include` other public TF headers.
|
||||
|
||||
## saved_model/internal
|
||||
|
||||
`saved_model/internal` is the "glue" between the C API and the internal C++
|
||||
implementation.
|
||||
|
||||
Its role is to:
|
||||
|
||||
1. implement the C API functions declared in `saved_model/public`
|
||||
|
||||
2. define the C API types declared in `saved_model/public`
|
||||
|
||||
The files fulfilling 1. are named `*.cc` (eg: `concrete_function.cc`), while
|
||||
the files fulfilling 2. are `*type.h` (eg: `concrete_function_type.h`).
|
||||
|
||||
The headers exposing the internal implementation of the opaque C types are only
|
||||
visible to other implementors of the C API. This is similar to how other
|
||||
TF C API implementations use `tf_status_internal.h` (to extract the underlying
|
||||
`tensorflow::Status`). All other targets in this directory are private.
|
||||
|
||||
## saved_model/core
|
||||
|
||||
`saved_model/core` contains pure C++ "Classes" underlying the C API types
|
||||
in `saved_model/public/`. These are implementation
|
||||
details subject to change, and have limited visibility to implementors only.
|
||||
This is the bottom-most layer of the `C++ -> C -> C++` sandwich.
|
85
tensorflow/c/experimental/saved_model/core/BUILD
Normal file
85
tensorflow/c/experimental/saved_model/core/BUILD
Normal file
@ -0,0 +1,85 @@
|
||||
# Experimental SavedModel C APIs for TensorFlow. See RFC
|
||||
# https://github.com/tensorflow/community/pull/207
|
||||
# Targets in this directory are pure C++ "Classes" underlying the C API types
|
||||
# under tf/c/experimental/saved_model/public/. They are subject to change and
|
||||
# have visibility limited to Tensorflow's implementation only.
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow/c:__subpackages__",
|
||||
"//tensorflow/c/experimental/saved_model/internal:__pkg__",
|
||||
"//tensorflow/core:__subpackages__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function",
|
||||
srcs = [
|
||||
"concrete_function.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"concrete_function.h",
|
||||
],
|
||||
deps = [
|
||||
":function_metadata",
|
||||
"//tensorflow/c/eager:operation_interface",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "function_metadata",
|
||||
hdrs = [
|
||||
"function_metadata.h",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "saved_model_api",
|
||||
hdrs = [
|
||||
"saved_model_api.h",
|
||||
],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_saved_model_impl",
|
||||
srcs = [
|
||||
"tf_saved_model_impl.cc",
|
||||
],
|
||||
hdrs = ["tf_saved_model_impl.h"],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
":saved_model_api",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "pywrap_required_hdrs",
|
||||
textual_hdrs = [
|
||||
"concrete_function.h",
|
||||
"function_metadata.h",
|
||||
"saved_model_api.h",
|
||||
],
|
||||
visibility = ["//tensorflow/python:__pkg__"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "mobile_srcs_only_runtime",
|
||||
srcs = [
|
||||
"concrete_function.cc",
|
||||
"concrete_function.h",
|
||||
"function_metadata.h",
|
||||
"saved_model_api.h",
|
||||
"tf_saved_model_impl.cc",
|
||||
"tf_saved_model_impl.h",
|
||||
],
|
||||
visibility = ["//tensorflow/core:__pkg__"],
|
||||
)
|
@ -0,0 +1,32 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const std::vector<tensorflow::AbstractTensorHandleInterface*>&
|
||||
ConcreteFunction::GetCaptures() const {
|
||||
return captures_;
|
||||
}
|
||||
|
||||
const FunctionMetadata& ConcreteFunction::GetFunctionMetadata() const {
|
||||
return metadata_;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Note that ConcreteFunctions's lifetimes are effectively bound
|
||||
// to the SavedModel they are loaded from, since they retain pointers
|
||||
// to the TensorHandles owned by the SavedModel, and the FunctionDef
|
||||
// of the SavedModel.
|
||||
// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
|
||||
// TFRT integration with TF Serving. Do not add more virtual implementations of
|
||||
// this class. Eventually we want to remove this virtual base class indirection
|
||||
// and have only a single implementation.
|
||||
class ConcreteFunction {
|
||||
public:
|
||||
virtual ~ConcreteFunction() = 0;
|
||||
|
||||
// This method returns the "Call" Op used to execute the function.
|
||||
virtual AbstractOperationInterface* GetCallOp() = 0;
|
||||
|
||||
const std::vector<tensorflow::AbstractTensorHandleInterface*>& GetCaptures()
|
||||
const;
|
||||
const FunctionMetadata& GetFunctionMetadata() const;
|
||||
|
||||
private:
|
||||
FunctionMetadata metadata_;
|
||||
std::vector<tensorflow::AbstractTensorHandleInterface*> captures_;
|
||||
FunctionDef* function_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
|
@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
|
||||
|
||||
class Fused2Ops<string kernel> : NativeCodeCall<
|
||||
"fuseOps(&$_builder, {$0, $1}, \"" # kernel # "\")">;
|
||||
class Fused3Ops<string kernel> : NativeCodeCall<
|
||||
"fuseOps(&$_builder, {$0, $1, $2}, \"" # kernel # "\")">;
|
||||
namespace tensorflow {
|
||||
|
||||
def : Pat<(HLO_AddOp:$add (HLO_MulOp:$mul $_, $_, $_), $_, $_),
|
||||
(Fused2Ops<"generic.mul_add"> $mul, $add)>;
|
||||
class FunctionMetadata {
|
||||
// TODO(bmzhao): Fill in with fields as necessary
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_
|
55
tensorflow/c/experimental/saved_model/core/saved_model_api.h
Normal file
55
tensorflow/c/experimental/saved_model/core/saved_model_api.h
Normal file
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock
|
||||
// TFRT integration with TF Serving. Do not add more virtual implementations of
|
||||
// this class. Eventually we want to remove this virtual base class indirection
|
||||
// and have only a single implementation.
|
||||
class SavedModelAPI {
|
||||
public:
|
||||
// Retrieve a function from the TF2 SavedModel, using the "path" to a function
|
||||
// in a TF2 savedmodel.
|
||||
// Note: `function` is a double pointer, so that implementations are
|
||||
// able to return a pointer to an internal member.
|
||||
virtual Status GetFunction(const std::string& function_path,
|
||||
ConcreteFunction** function) = 0;
|
||||
|
||||
// Retrieve a function from a SavedModel, using the key of the
|
||||
// SignatureDef map:
|
||||
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
|
||||
virtual Status GetSignatureDefFunction(const std::string& signature_def_key,
|
||||
ConcreteFunction** function) = 0;
|
||||
|
||||
virtual std::vector<ConcreteFunction*> ListFunctions() = 0;
|
||||
|
||||
virtual ~SavedModelAPI() = default;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_
|
@ -0,0 +1,60 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status TFSavedModelAPIImpl::GetFunction(const std::string& function_path,
|
||||
ConcreteFunction** function) {
|
||||
// TODO(bmzhao): Add support for retrieving a function.
|
||||
return errors::Unimplemented(
|
||||
"Retrieving functions is unimplemented currently");
|
||||
}
|
||||
|
||||
Status TFSavedModelAPIImpl::GetSignatureDefFunction(
|
||||
const std::string& signature_def_key, ConcreteFunction** function) {
|
||||
// TODO(bmzhao): Add support for retrieving a signaturedef function.
|
||||
return errors::Unimplemented(
|
||||
"Retrieving functions is unimplemented currently");
|
||||
}
|
||||
|
||||
std::vector<ConcreteFunction*> TFSavedModelAPIImpl::ListFunctions() {
|
||||
std::vector<ConcreteFunction*> result;
|
||||
result.reserve(functions_.size());
|
||||
for (ConcreteFunction& function : functions_) {
|
||||
result.push_back(&function);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Status TFSavedModelAPIImpl::Load(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
TFSavedModelAPIImpl* out) {
|
||||
// TODO(bmzhao): Add support for loading a TFSavedModelImpl.
|
||||
return errors::Unimplemented(
|
||||
"TFSavedModelAPIImpl loading is unimplemented currently");
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TFSavedModelAPIImpl : public SavedModelAPI {
|
||||
public:
|
||||
TFSavedModelAPIImpl() = default;
|
||||
|
||||
Status GetFunction(const std::string& function_path,
|
||||
ConcreteFunction** function) override;
|
||||
|
||||
Status GetSignatureDefFunction(const std::string& signature_def_key,
|
||||
ConcreteFunction** function) override;
|
||||
|
||||
static Status Load(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
TFSavedModelAPIImpl* out);
|
||||
|
||||
std::vector<ConcreteFunction*> ListFunctions() override;
|
||||
|
||||
~TFSavedModelAPIImpl() override = default;
|
||||
|
||||
private:
|
||||
std::vector<ConcreteFunction> functions_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
|
181
tensorflow/c/experimental/saved_model/internal/BUILD
Normal file
181
tensorflow/c/experimental/saved_model/internal/BUILD
Normal file
@ -0,0 +1,181 @@
|
||||
# Experimental Implementation of SavedModel C APIs for TensorFlow. See RFC
|
||||
# https://github.com/tensorflow/community/pull/207
|
||||
# External clients should not worry about this directory; all contents are implementation details.
|
||||
# Code in this directory is intended to form the glue between the C API and the internal C++
|
||||
# implementation by
|
||||
# 1. mapping C API calls onto correponding methods of C++ objects
|
||||
# 2. mapping opaque C types onto C++ classes
|
||||
|
||||
# Note(bmzhao): The *.cc files in this directory form the direct implementation of the
|
||||
# C API functions exposed in tf/c/experimental/saved_model/public/.
|
||||
|
||||
# Note(bmzhao): All *type.h files in this directory are the internal definitions of
|
||||
# the opaque C types. These headers should only be visible to internal tensorflow
|
||||
# implementors.
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
"tf_copts",
|
||||
)
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function",
|
||||
srcs = [
|
||||
"concrete_function.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:concrete_function.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
# TODO(bmzhao): Remove this as we refactor C API to granular targets,
|
||||
# so that we can depend on c/eager/c_api_unified_experimental.h.
|
||||
features = ["-layering_check"],
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":concrete_function_type",
|
||||
":function_metadata",
|
||||
":function_metadata_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_internal",
|
||||
"//tensorflow/c/eager:tfe_op_internal",
|
||||
"//tensorflow/c/experimental/saved_model/core:concrete_function",
|
||||
"//tensorflow/c/experimental/saved_model/core:function_metadata",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function_list",
|
||||
srcs = [
|
||||
"concrete_function_list.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:concrete_function_list.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
":concrete_function_list_type",
|
||||
":concrete_function_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:concrete_function",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function_list_type",
|
||||
hdrs = [
|
||||
"concrete_function_list_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/experimental/saved_model/core:concrete_function",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "concrete_function_type",
|
||||
hdrs = [
|
||||
"concrete_function_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:concrete_function",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "function_metadata",
|
||||
srcs = [
|
||||
"function_metadata.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:function_metadata.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":function_metadata_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:function_metadata",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "function_metadata_type",
|
||||
hdrs = [
|
||||
"function_metadata_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/experimental/saved_model/core:function_metadata",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "saved_model_api",
|
||||
srcs = [
|
||||
"saved_model_api.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"//tensorflow/c/experimental/saved_model/public:saved_model_api.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
visibility = [
|
||||
"//tensorflow/c/experimental/saved_model/public:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
":concrete_function_list",
|
||||
":concrete_function_list_type",
|
||||
":concrete_function_type",
|
||||
":saved_model_api_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c:tf_status_internal",
|
||||
"//tensorflow/c/eager:tfe_context_internal",
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "saved_model_api_type",
|
||||
hdrs = [
|
||||
"saved_model_api_type.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "saved_model_api_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"saved_model_api_test.cc",
|
||||
],
|
||||
data = [
|
||||
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
@ -0,0 +1,42 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) {
|
||||
return tensorflow::wrap(const_cast<tensorflow::FunctionMetadata*>(
|
||||
&tensorflow::unwrap(func)->GetFunctionMetadata()));
|
||||
}
|
||||
|
||||
TF_OutputList* TF_ConcreteFunctionGetCaptures(TF_ConcreteFunction* func) {
|
||||
// TODO(bmzhao): Refactor TF_OutputList struct definition into a separate
|
||||
// internal header, and implement this function.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {
|
||||
return tensorflow::wrap(tensorflow::unwrap(func)->GetCallOp());
|
||||
}
|
||||
|
||||
} // end extern "C"
|
@ -0,0 +1,37 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
size_t TF_ConcreteFunctionListNumOutputs(TF_ConcreteFunctionList* list) {
|
||||
return list->list.size();
|
||||
}
|
||||
|
||||
TF_ConcreteFunction* TF_ConcreteFunctionListGet(TF_ConcreteFunctionList* list,
|
||||
int i) {
|
||||
return tensorflow::wrap(list->list[i]);
|
||||
}
|
||||
|
||||
void TF_DeleteConcreteFunctionList(TF_ConcreteFunctionList* list) {
|
||||
delete list;
|
||||
}
|
||||
|
||||
} // end extern "C"
|
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to change
|
||||
// and should not be depended on.
|
||||
|
||||
struct TF_ConcreteFunctionList {
|
||||
std::vector<tensorflow::ConcreteFunction*> list;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
|
@ -0,0 +1,36 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to change
|
||||
// and should not be depended on.
|
||||
|
||||
// It doesn't make sense to wrap tensorflow::ConcreteFunction* in a separate
|
||||
// struct, since the lifetime of the struct and the raw pointer it wraps would
|
||||
// be different. Therefore TF_ConcreteFunction* = tensorflow::ConcreteFunction*.
|
||||
typedef struct TF_ConcreteFunction TF_ConcreteFunction;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ConcreteFunction, TF_ConcreteFunction)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_
|
@ -0,0 +1,20 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
|
||||
|
||||
// TODO(bmzhao): Add getter functions here as necessary.
|
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
|
||||
typedef struct TF_FunctionMetadata TF_FunctionMetadata;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::FunctionMetadata, TF_FunctionMetadata)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_
|
@ -0,0 +1,97 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
extern "C" {
|
||||
|
||||
TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
|
||||
TF_Status* status) {
|
||||
std::string saved_model_dir(dirname);
|
||||
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> result =
|
||||
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
|
||||
&status->status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TF_SavedModel{std::move(result)};
|
||||
}
|
||||
|
||||
TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
|
||||
const char* const* tags, int tags_len,
|
||||
TF_Status* status) {
|
||||
std::string saved_model_dir(dirname);
|
||||
|
||||
std::unordered_set<std::string> tagset;
|
||||
for (int i = 0; i < tags_len; ++i) {
|
||||
tagset.insert(std::string(tags[i]));
|
||||
}
|
||||
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> result =
|
||||
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset),
|
||||
&status->status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return new TF_SavedModel{std::move(result)};
|
||||
}
|
||||
|
||||
void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; }
|
||||
|
||||
TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model,
|
||||
const char* function_path,
|
||||
TF_Status* status) {
|
||||
tensorflow::ConcreteFunction* result = nullptr;
|
||||
tensorflow::Status get_function_status =
|
||||
model->saved_model->GetFunction(function_path, &result);
|
||||
status->status.Update(get_function_status);
|
||||
if (!get_function_status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return tensorflow::wrap(result);
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
|
||||
TF_SavedModel* model, const char* signature_def_key, TF_Status* status) {
|
||||
tensorflow::ConcreteFunction* result = nullptr;
|
||||
tensorflow::Status get_function_status =
|
||||
model->saved_model->GetSignatureDefFunction(signature_def_key, &result);
|
||||
status->status.Update(get_function_status);
|
||||
if (!get_function_status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return tensorflow::wrap(result);
|
||||
}
|
||||
|
||||
TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) {
|
||||
return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()};
|
||||
}
|
||||
|
||||
} // end extern "C"
|
@ -0,0 +1,109 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kTestData[] = "cc/saved_model/testdata";
|
||||
const char* kServeTag[] = {"serve"};
|
||||
|
||||
std::string SavedModelPath(tensorflow::StringPiece saved_model_dir) {
|
||||
return tensorflow::io::JoinPath(tensorflow::testing::TensorFlowSrcRoot(),
|
||||
kTestData, saved_model_dir);
|
||||
}
|
||||
|
||||
// This value parameterized test allows us to test both TFRT
|
||||
// and non TFRT runtimes.
|
||||
// https://github.com/google/googletest/blob/dcc92d0ab6c4ce022162a23566d44f673251eee4/googletest/docs/advanced.md#value-parameterized-tests
|
||||
class CSavedModelAPITest : public ::testing::TestWithParam<bool> {};
|
||||
|
||||
TEST_P(CSavedModelAPITest, LoadsSavedModelWithTags) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TF_DeleteStatus(status);
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
|
||||
TF_SavedModel* saved_model =
|
||||
TF_LoadSavedModelWithTags(model_dir.c_str(), ctx, kServeTag, 1, status);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
// like loading a ConcreteFunction. This test at least checks that the
|
||||
// C API builds and can be minimally run.
|
||||
EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED);
|
||||
|
||||
TF_DeleteSavedModel(saved_model);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
TEST_P(CSavedModelAPITest, LoadsSavedModel) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
bool use_tfrt = GetParam();
|
||||
if (use_tfrt) {
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TF_DeleteStatus(status);
|
||||
GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced.
|
||||
}
|
||||
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
std::string model_dir = SavedModelPath("VarsAndArithmeticObjectGraph");
|
||||
|
||||
TF_SavedModel* saved_model =
|
||||
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
|
||||
|
||||
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
|
||||
// That unblocks writing other tests that require a TF_SavedModel*,
|
||||
// like loading a ConcreteFunction. This test at least checks that the
|
||||
// C API builds and can be minimally run.
|
||||
EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED);
|
||||
|
||||
TF_DeleteSavedModel(saved_model);
|
||||
TF_DeleteStatus(status);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest,
|
||||
::testing::Bool());
|
||||
|
||||
} // namespace
|
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to change
|
||||
// and should not be depended on.
|
||||
|
||||
struct TF_SavedModel {
|
||||
std::unique_ptr<tensorflow::SavedModelAPI> saved_model;
|
||||
};
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_
|
63
tensorflow/c/experimental/saved_model/public/BUILD
Normal file
63
tensorflow/c/experimental/saved_model/public/BUILD
Normal file
@ -0,0 +1,63 @@
|
||||
# Experimental SavedModel C APIs for TensorFlow.
|
||||
# See RFC https://github.com/tensorflow/community/pull/207
|
||||
# All headers are on the public surface of Tensorflow's C API.
|
||||
# Once moved out of experimental, these will be stable.
|
||||
# The idea behind a separate public/ directory is to make apparent
|
||||
# which headers are part of TF's public interface (and which headers)
|
||||
# are implementation details. This structure allows us to also perform future
|
||||
# programmatic checks that all "public" headers only include other "public"
|
||||
# headers.
|
||||
|
||||
package(
|
||||
# This is intentionally public
|
||||
default_visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
# TODO(bmzhao): Remove these exports_files and rules, swap with cc_public_library instead.
|
||||
# cc_public_library would allows us to separate the header dep graph from header+srcs dep graph.
|
||||
exports_files(
|
||||
[
|
||||
"concrete_function.h",
|
||||
"concrete_function_list.h",
|
||||
"function_metadata.h",
|
||||
"saved_model_api.h",
|
||||
],
|
||||
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
|
||||
)
|
||||
|
||||
# The purpose of this header is to provide insulation against
|
||||
# future changes where we rename/move a public header, without
|
||||
# forcing all clients to change their "#includes".
|
||||
cc_library(
|
||||
name = "c_saved_model_api",
|
||||
hdrs = ["c_saved_model_api.h"],
|
||||
deps = [
|
||||
":concrete_function",
|
||||
":concrete_function_list",
|
||||
":function_metadata",
|
||||
":saved_model_api",
|
||||
],
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "concrete_function",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:concrete_function",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "concrete_function_list",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:concrete_function_list",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "function_metadata",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:function_metadata",
|
||||
)
|
||||
|
||||
alias(
|
||||
name = "saved_model_api",
|
||||
actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
|
||||
)
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2019 Google LLC. All Rights Reserved.
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -13,19 +13,14 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_
|
||||
|
||||
#include "tensorflow/lite/experimental/ruy/ruy/platform.h"
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#if RUY_PLATFORM(NEON)
|
||||
#include "tensorflow/lite/experimental/ruy/ruy/kernel_arm.h"
|
||||
#elif RUY_PLATFORM(X86)
|
||||
#include "tensorflow/lite/experimental/ruy/ruy/kernel_x86.h"
|
||||
#else
|
||||
#include "tensorflow/lite/experimental/ruy/ruy/kernel_common.h"
|
||||
#endif
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_
|
@ -0,0 +1,50 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type that corresponds to a Function loaded from a SavedModel.
|
||||
// TODO(bmzhao): Work together w/srbs@ to make sure this composes w/the
|
||||
// C++ Unified Eager/Graph API's AbstractFunction
|
||||
typedef struct TF_ConcreteFunction TF_ConcreteFunction;
|
||||
|
||||
// Returns FunctionMetadata associated with `func`. Metadata's lifetime is
|
||||
// bound to `func`, which is bound to the TF_SavedModel it was loaded from.
|
||||
TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(
|
||||
TF_ConcreteFunction* func);
|
||||
|
||||
// Returns a list of TensorHandles implicitly captured by this function.
|
||||
TF_CAPI_EXPORT extern TF_OutputList* TF_ConcreteFunctionGetCaptures(
|
||||
TF_ConcreteFunction* func);
|
||||
|
||||
// Returns a TFE_Op suitable for executing this function.
|
||||
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp(
|
||||
TF_ConcreteFunction* func);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_
|
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
|
||||
// An opaque type that is acts like a list of TF_ConcreteFunction pointers.
|
||||
typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList;
|
||||
|
||||
// Returns the size of `list`.
|
||||
TF_CAPI_EXPORT size_t
|
||||
TF_ConcreteFunctionListSize(TF_ConcreteFunctionList* list);
|
||||
|
||||
// Returns the `i`th TF_ConcreteFunction in the list.
|
||||
TF_CAPI_EXPORT TF_ConcreteFunction* TF_ConcreteFunctionListGet(
|
||||
TF_ConcreteFunctionList* list, int i);
|
||||
|
||||
// Deletes `list`.
|
||||
TF_CAPI_EXPORT void TF_DeleteConcreteFunctionList(
|
||||
TF_ConcreteFunctionList* list);
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_
|
@ -0,0 +1,35 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type used to store any metadata associated with a function.
|
||||
typedef struct TF_FunctionMetadata TF_FunctionMetadata;
|
||||
|
||||
// TODO(bmzhao): Add getters for fields as we determine what metadata
|
||||
// we want to expose.
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_
|
108
tensorflow/c/experimental/saved_model/public/saved_model_api.h
Normal file
108
tensorflow/c/experimental/saved_model/public/saved_model_api.h
Normal file
@ -0,0 +1,108 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
|
||||
|
||||
#include "tensorflow/c/c_api_macros.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
|
||||
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
// An opaque type representing a Tensorflow "SavedModel"
|
||||
// (https://www.tensorflow.org/guide/saved_model) that we always pass by pointer
|
||||
// to achieve ABI stability.
|
||||
typedef struct TF_SavedModel TF_SavedModel;
|
||||
|
||||
// Load a SavedModel from `dirname`. We expect the SavedModel to contain a
|
||||
// single Metagraph (as for those exported from TF2's `tf.saved_model.save`).
|
||||
//
|
||||
// Params:
|
||||
// dirname - A directory filepath that the SavedModel is at.
|
||||
// ctx - A TFE_Context containing optional load/TF runtime options.
|
||||
// `ctx` must outlive the returned TF_SavedModel pointer.
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a newly created
|
||||
// TF_SavedModel instance. It must be deleted by calling TF_DeleteSavedModel.
|
||||
TF_CAPI_EXPORT extern TF_SavedModel* TF_LoadSavedModel(const char* dirname,
|
||||
TFE_Context* ctx,
|
||||
TF_Status* status);
|
||||
|
||||
// Load a SavedModel from `dirname`.
|
||||
//
|
||||
// Params:
|
||||
// dirname - A directory filepath that the SavedModel is at.
|
||||
// ctx - A TFE_Context containing optional load/TF runtime options.
|
||||
// `ctx` must outlive the returned TF_SavedModel pointer.
|
||||
// tags - char* array of SavedModel tags. We will load the metagraph matching
|
||||
// the tags.
|
||||
// tags_len - number of elements in the `tags` array.
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a newly created
|
||||
// TF_SavedModel instance. It must be deleted by calling TF_DeleteSavedModel.
|
||||
TF_CAPI_EXPORT extern TF_SavedModel* TF_LoadSavedModelWithTags(
|
||||
const char* dirname, TFE_Context* ctx, const char* const* tags,
|
||||
int tags_len, TF_Status* status);
|
||||
|
||||
// Deletes a TF_SavedModel, and frees any resources owned by it.
|
||||
TF_CAPI_EXPORT extern void TF_DeleteSavedModel(TF_SavedModel* model);
|
||||
|
||||
// Retrieve a function from the TF2 SavedModel via function path.
|
||||
//
|
||||
// Params:
|
||||
// model - The TF2 SavedModel to load a function from.
|
||||
// function_path - A string containing the path from the root saved python
|
||||
// object to a tf.function method.
|
||||
// TODO(bmzhao): Add a detailed example of this with a
|
||||
// python tf.module before moving this out of experimental.
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a
|
||||
// TF_ConcreteFunction instance. The lifetime of this instance is
|
||||
// "conceptually" bound to `model`. Once `model` is deleted, all
|
||||
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(
|
||||
TF_SavedModel* model, const char* function_path, TF_Status* status);
|
||||
|
||||
// Retrieve a function from the TF SavedModel via a SignatureDef key.
|
||||
//
|
||||
// Params:
|
||||
// model - The SavedModel to load a function from.
|
||||
// signature_def_key - The string key of the SignatureDef map of a SavedModel:
|
||||
// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89
|
||||
// status - Set to OK on success and an appropriate error on failure.
|
||||
// Returns:
|
||||
// If status is not OK, returns nullptr. Otherwise, returns a
|
||||
// TF_ConcreteFunction instance. Once `model` is deleted, all
|
||||
// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted.
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction(
|
||||
TF_SavedModel* model, const char* signature_def_key, TF_Status* status);
|
||||
|
||||
// Returns a list of all ConcreteFunctions stored in this SavedModel.
|
||||
// The lifetime of the returned list is bound to `model`.
|
||||
TF_CAPI_EXPORT extern TF_ConcreteFunctionList* TF_ListSavedModelFunctions(
|
||||
TF_SavedModel* model);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // end extern "C"
|
||||
#endif // __cplusplus
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_
|
@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_
|
||||
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_
|
||||
#ifndef TENSORFLOW_C_TENSOR_INTERFACE_H_
|
||||
#define TENSORFLOW_C_TENSOR_INTERFACE_H_
|
||||
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to a Tensor.
|
||||
//
|
||||
@ -28,10 +29,11 @@ limitations under the License.
|
||||
// is needed a static_cast can be applied.
|
||||
class AbstractTensorInterface {
|
||||
public:
|
||||
virtual ~AbstractTensorInterface() {}
|
||||
// Release any underlying resources, including the interface object.
|
||||
virtual void Release() = 0;
|
||||
|
||||
// Returns tensor dtype.
|
||||
virtual TF_DataType Type() const = 0;
|
||||
virtual DataType Type() const = 0;
|
||||
// Returns number of dimensions.
|
||||
virtual int NumDims() const = 0;
|
||||
// Returns size of specified dimension
|
||||
@ -47,37 +49,11 @@ class AbstractTensorInterface {
|
||||
virtual bool IsAligned() const = 0;
|
||||
// Returns if their is sole ownership of this Tensor and thus it can be moved.
|
||||
virtual bool CanMove() const = 0;
|
||||
};
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorInterface : public AbstractTensorInterface {
|
||||
public:
|
||||
TensorInterface() {}
|
||||
explicit TensorInterface(tensorflow::Tensor t) : tensor_(std::move(t)) {}
|
||||
~TensorInterface() override {}
|
||||
|
||||
TF_DataType Type() const override;
|
||||
int NumDims() const override;
|
||||
int64_t Dim(int dim_index) const override;
|
||||
int64_t NumElements() const override;
|
||||
size_t ByteSize() const override;
|
||||
void* Data() const override;
|
||||
bool IsAligned() const override;
|
||||
bool CanMove() const override;
|
||||
|
||||
Status ToTensor(tensorflow::Tensor* dst) const;
|
||||
Status BitcastFrom(const TensorInterface& from, TF_DataType type,
|
||||
const int64_t* new_dims, int num_new_dims);
|
||||
|
||||
// TODO(gjn): This is not a very generic interface, but is needed for specific
|
||||
// use cases.
|
||||
tensorflow::Tensor Tensor() { return tensor_; }
|
||||
|
||||
private:
|
||||
tensorflow::Tensor tensor_;
|
||||
protected:
|
||||
virtual ~AbstractTensorInterface() {}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_
|
||||
#endif // TENSORFLOW_C_TENSOR_INTERFACE_H_
|
@ -16,8 +16,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
#include "tensorflow/c/tf_status_internal.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/error.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
using ::tensorflow::IOError;
|
||||
using ::tensorflow::Status;
|
||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_C_TF_STATUS_HELPER_H_
|
||||
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user