Merge pull request from tensorflow/master

sync to upstream
bigcat-himax authored 2020-06-30 15:19:02 +08:00, committed by GitHub
commit a28bac9244
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
3249 changed files with 134311 additions and 100366 deletions
.bazelrc
.bazelversion
.github
README.md
RELEASE.md
WORKSPACE
configure.py
tensorflow
BUILD
c
cc
compiler

114
.bazelrc
View File

@ -30,6 +30,7 @@
# short_logs: Only log errors during build, skip warnings.
# monolithic: Build all TF C++ code into a single shared object.
# dynamic_kernels: Try to link all kernels dynamically (experimental).
# libc++: Link against libc++ instead of stdlibc++
#
#
# TF version options;
@ -38,6 +39,7 @@
#
# Feature and Third party library support options:
# xla: Build TF with XLA
# tpu: Build TF with TPU support
# using_cuda: CUDA is available to build system.
# cuda: Build with full cuda support.
# rocm: Build with AMD GPU support (rocm).
@ -56,13 +58,12 @@
#
#
# Remote build execution options (only configured to work with TF team projects for now.)
# rbe: General RBE options shared by all flavors.
# rbe_linux: General RBE options used on all linux builds.
# rbe_win: General RBE options used on all windows builds.
# rbe: General RBE options shared by all flavors.
# rbe_linux: General RBE options used on all linux builds.
# rbe_win: General RBE options used on all windows builds.
#
# rbe_cpu_linux: RBE options to build with only CPU support.
# rbe_linux_cuda_nvcc: RBE options to build with GPU support using nvcc.
# rbe_gpu_linux: An alias for rbe_linux_cuda_nvcc
# rbe_cpu_linux: RBE options to build with only CPU support.
# rbe_linux_cuda_nvcc_py*: RBE options to build with GPU support using nvcc.
#
# rbe_linux_py2: Linux Python 2 RBE config.
# rbe_linux_py3: Linux Python 3 RBE config
@ -79,6 +80,14 @@
# elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.
# Allow builds using libc++ as a linker library
# This is mostly for OSSFuzz, so we also pass in the flags from environment to clean build file
build:libc++ --action_env=CC
build:libc++ --action_env=CXX
build:libc++ --action_env=CXXFLAGS=-stdlib=libc++
build:libc++ --action_env=PATH
build:libc++ --define force_libcpp=enabled
build:libc++ --linkopt -fuse-ld=lld
# Android configs. Bazel needs to have --cpu and --fat_apk_cpu both set to the
# target CPU to build transient dependencies correctly. See
@ -171,6 +180,9 @@ build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
build:dbg --copt -DDEBUG_BUILD
# Config to build TPU backend
build:tpu --define=with_tpu_support=true
build:tensorrt --action_env TF_NEED_TENSORRT=1
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
@ -200,6 +212,8 @@ build:nogcp --define=no_gcp_support=true
build:nohdfs --define=no_hdfs_support=true
build:nonccl --define=no_nccl_support=true
build:stackdriver_support --define=stackdriver_support=true
build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true
@ -385,33 +399,48 @@ build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda10.1_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda10.1_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.1_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.1_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda10.1_nvcc_py2.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda10.1_nvcc_py3.5 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda10.1_nvcc_py3.6 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda10.1_nvcc_py3.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.0_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda11.0_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
build:rbe_linux_cuda11.0_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
build:rbe_linux_cuda11.0_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda"
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_tensorrt"
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_nccl"
build:rbe_linux_cuda11.0_nvcc_py2.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python2.7"
build:rbe_linux_cuda11.0_nvcc_py3.5 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.5"
build:rbe_linux_cuda11.0_nvcc_py3.6 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.6"
build:rbe_linux_cuda11.0_nvcc_py3.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.7"
build:rbe_linux_cuda11.0_nvcc_py3.8 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.8"
# Map default to CUDA 10.1.
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda10.1_nvcc_py2.7
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda10.1_nvcc_py3.5
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda10.1_nvcc_py3.6
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda10.1_nvcc_py3.7
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda10.1_nvcc_py3.8
# Deprecated configs that people might still use.
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_nvcc_py36
build:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
@ -429,8 +458,23 @@ build:rbe_linux_cuda_clang_py36 --config=rbe_linux_cuda_clang_base --repo_env=TF
build:rbe_linux_cuda_clang_py37 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_clang_py38 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
# ROCm
build:rbe_linux_rocm_base --config=rbe_linux
build:rbe_linux_rocm_base --repo_env=TF_NEED_ROCM=1
build:rbe_linux_rocm_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-rocm_config_rocm//crosstool:toolchain"
build:rbe_linux_rocm_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-rocm_config_rocm//crosstool:toolchain-linux-x86_64"
build:rbe_linux_rocm_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-rocm_config_platform//:platform"
build:rbe_linux_rocm_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-rocm_config_platform//:platform"
build:rbe_linux_rocm_base --platforms="@ubuntu18.04-gcc7_manylinux2010-rocm_config_platform//:platform"
build:rbe_linux_rocm_base --action_env=TF_ROCM_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_rocm"
build:rbe_linux_rocm_base --define=using_rocm_hipcc=true
build:rbe_linux_rocm_py2.7 --config=rbe_linux_rocm_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_python2.7"
build:rbe_linux_rocm_py3.5 --config=rbe_linux_rocm_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_python3.5"
build:rbe_linux_rocm_py3.6 --config=rbe_linux_rocm_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_python3.6"
build:rbe_linux_rocm_py3.7 --config=rbe_linux_rocm_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_python3.7"
build:rbe_linux_rocm_py3.8 --config=rbe_linux_rocm_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_python3.8"
# Linux CPU
build:rbe_linux_py2 --config=rbe_linux
build:rbe_linux_py2 --repo_env=PYTHON_BIN_PATH="/usr/bin/python2"
build:rbe_linux_py2 --python_path="/usr/bin/python2"
@ -441,8 +485,8 @@ build:rbe_linux_py3 --python_path="/usr/bin/python3"
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-manylinux2010-py3_config_python"
build:rbe_win --config=rbe
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:toolchain"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:cc-toolchain-x64_windows"
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/tf_win_08062020:toolchain"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/tf_win_08062020:cc-toolchain-x64_windows"
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"

View File

@ -1 +1 @@
3.0.0
3.1.0

View File

@ -27,6 +27,19 @@ assignees:
# A list of assignees for compiler folder
compiler_assignees:
- joker-eph
# filesystem path
filesystem_path:
- tensorflow/c/experimental/filesystem
# security path
security_path:
- tensorflow/security
# words checklist
segfault_memory:
- segfault
- memory leaks
# assignees
filesystem_security_assignee:
- mihaimaruseac
# Cuda Comment
cuda_comment: >
From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:

2
.github/stale.yml vendored
View File

@ -23,7 +23,7 @@
daysUntilStale: 7
# Number of days of inactivity before a stale Issue or Pull Request is closed
daysUntilClose: 7
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
# Only issues or pull requests with all of these labels are checked if stale. Defaults to `[]` (disabled)
onlyLabels:
- stat:awaiting response
# Comment to post when marking as stale. Set to `false` to disable

View File

@ -61,7 +61,6 @@ commands.
*Nightly binaries are available for testing using the
[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
[tf-nightly-cpu](https://pypi.python.org/pypi/tf-nightly-cpu) packages on PyPi.*
#### *Try your first TensorFlow program*
```shell
@ -96,6 +95,7 @@ for general questions and discussion, and please direct specific questions to
The TensorFlow project strives to abide by generally accepted best practices in
open-source software development:
[![Fuzzing Status](https://oss-fuzz-build-logs.storage.googleapis.com/badges/tensorflow.svg)](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:tensorflow)
[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486)
[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v1.4%20adopted-ff69b4.svg)](CODE_OF_CONDUCT.md)
@ -114,6 +114,12 @@ Build Type | Status
**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
**Libtensorflow MacOS CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
**Libtensorflow Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
**Libtensorflow Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
**Libtensorflow Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
**Libtensorflow Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
### Community Supported Builds

View File

@ -1,3 +1,63 @@
# Release 2.4.0
<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
## Breaking Changes
* <DOCUMENT BREAKING CHANGES HERE>
* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
* The byte layout for string tensors across the C-API has been updated to match
TF Core/C++; i.e., a contiguous array of `tensorflow::tstring`/`TF_TString`s.
* C-API functions `TF_StringDecode`, `TF_StringEncode`, and
`TF_StringEncodedSize` are no longer relevant and have been removed; see
core/platform/ctstring.h for string access/modification in C.
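  For illustration (not part of the upstream release notes), a minimal sketch of the
  new layout, using only the `TF_TString` helpers referenced above and C-API calls
  exercised elsewhere in this commit; the payload string is illustrative:

```cpp
// Sketch: a scalar TF_STRING tensor is now a contiguous array of TF_TString,
// built with the TF_TString_* helpers from core/platform/ctstring.h instead of
// the removed TF_StringEncode/TF_StringDecode functions.
#include <cstdio>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_tstring.h"

int main() {
  const char payload[] = "hello tstring";  // illustrative contents
  TF_TString tstr[1];
  TF_TString_Init(&tstr[0]);
  TF_TString_Copy(&tstr[0], payload, sizeof(payload) - 1);

  // The tensor's data buffer is the TF_TString array itself.
  auto deallocator = [](void* data, size_t len, void* arg) {};
  TF_Tensor* t = TF_NewTensor(TF_STRING, /*dims=*/nullptr, /*num_dims=*/0,
                              &tstr[0], sizeof(tstr), deallocator, nullptr);

  // Reading the value back goes through the same TF_TString view.
  TF_TString* view = static_cast<TF_TString*>(TF_TensorData(t));
  std::printf("%.*s\n", static_cast<int>(TF_TString_GetSize(view)),
              TF_TString_GetDataPointer(view));

  TF_DeleteTensor(t);
  TF_TString_Dealloc(&tstr[0]);
  return 0;
}
```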
## Known Caveats
* <CAVEATS REGARDING THE RELEASE (BUT NOT BREAKING CHANGES). E.G. ADDING A NEW DEPENDENCY, BUMPING A DEPENDENCY NUMBER, LACK OF SUPPORT ON SOME PLATFORM, ETC>
## Major Features and Improvements
* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
## Bug Fixes and Other Changes
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* TF Core:
* <ADD RELEASE NOTES HERE>
* `tf.data`:
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
the complement of `select_cols`; at most one of these should be specified.
* `tf.distribute`:
* <ADD RELEASE NOTES HERE>
* `tf.keras`:
* <ADD RELEASE NOTES HERE>
* `tf.function`/AutoGraph:
* <ADD RELEASE NOTES HERE>
* `tf.lite`:
* <ADD RELEASE NOTES HERE>
* `tf.random`:
* <ADD RELEASE NOTES HERE>
* Math and Linear Algebra:
* <ADD RELEASE NOTES HERE>
* TPU Enhancements:
* <ADD RELEASE NOTES HERE>
* XLA Support:
* <ADD RELEASE NOTES HERE>
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>
* Other:
* <ADD RELEASE NOTES HERE>
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
# Release 2.3.0
## Breaking Changes
@ -11,6 +71,15 @@
exsiting C++ kernel `ExtractGlimpse` does not change as well, so saved
models will not be impacted.
## Bug Fixes and Other Changes
* `tf.keras`:
* Deprecated the `tf.keras.experimental.PeepholeLSTMCell` layer, which was
moved to `tensorflow_addons` as
`tensorflow_addons.rnn.PeepholeLSTMCell`. This experimental API is
expected to be removed from TF in the next public release (2.4).
* Mutable tables now restore checkpointed values when loaded from SavedModel.
# Release 2.1.1
## Bug Fixes and Other Changes

View File

@ -36,18 +36,6 @@ load(
bazel_toolchains_repositories()
load(
"@io_bazel_rules_docker//repositories:repositories.bzl",
container_repositories = "repositories",
)
container_repositories()
load("//third_party/toolchains/preconfig/generate:workspace.bzl",
"remote_config_workspace")
remote_config_workspace()
# Use `swift_rules_dependencies` to fetch the toolchains. With the
# `git_repository` rules above, the following call will skip redefining them.
load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")

View File

@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '2.0.0'
_TF_MIN_BAZEL_VERSION = '3.1.0'
_TF_MAX_BAZEL_VERSION = '3.99.0'
NCCL_LIB_PATHS = [
@ -480,7 +480,7 @@ def check_bazel_version(min_version, max_version):
"""
if which('bazel') is None:
print('Cannot find bazel. Please install bazel.')
sys.exit(0)
sys.exit(1)
stderr = open(os.devnull, 'wb')
curr_version = run_shell(['bazel', '--version'],

View File

@ -298,6 +298,13 @@ config_setting(
visibility = ["//visibility:public"],
)
# Experimental features
config_setting(
name = "stackdriver_support",
define_values = {"stackdriver_support": "true"},
visibility = ["//visibility:public"],
)
# Crosses between platforms and file system libraries not supported on those
# platforms due to limitations in nested select() statements.
config_setting(
@ -460,6 +467,13 @@ config_setting(
visibility = ["//visibility:public"],
)
# This flag enables experimental TPU support
config_setting(
name = "with_tpu_support",
values = {"define": "with_tpu_support=true"},
visibility = ["//visibility:public"],
)
# Specifies via a config setting if this is a mobile build or not, makes
# it easier to combine settings later.
selects.config_setting_group(
@ -531,6 +545,7 @@ package_group(
# Packages that use composite tensors or dispatch.
# TODO(b/154762408) Remove this package group once it's no longer needed.
# If this is modified, then copy.bara.sky must also be modified.
package_group(name = "composite_tensor_whitelist")
# Packages that use private types symbols, until they are exported.
@ -540,6 +555,11 @@ package_group(
packages = ["//learning/deepmind/tensorflow/replicator/..."],
)
# Packages that use StructuredTensors.
# TODO(b/159007891) Remove this package once StructuredTensor is exported.
# If this is modified, then copy.bara.sky must also be modified.
package_group(name = "structured_tensor_whitelist")
filegroup(
name = "intel_binary_blob",
data = if_mkl_ml(

View File

@ -29,6 +29,8 @@ filegroup(
"tf_file_statistics.h",
"tf_status.h",
"tf_tensor.h",
"tf_tstring.h",
"//tensorflow/core/platform:ctstring",
],
visibility = ["//tensorflow:__subpackages__"],
)
@ -48,6 +50,7 @@ filegroup(
"*test*",
],
) + [
"//tensorflow/core/platform:ctstring",
"//tensorflow/cc:srcs_no_runtime",
"//tensorflow/core/distributed_runtime:server_lib.h",
],
@ -78,6 +81,7 @@ tf_cuda_library(
"c_api_internal.h",
"tf_datatype.h",
"tf_tensor.h",
"tf_tstring.h",
],
visibility = [
"//tensorflow:internal",
@ -173,6 +177,7 @@ tf_cuda_library(
copts = tf_copts(),
visibility = [
"//tensorflow/c:__subpackages__",
"//tensorflow/python:__subpackages__",
"//third_party/llvm/llvm-project:__subpackages__",
],
deps = [

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/c/tf_tstring.h"
// --------------------------------------------------------------------------
// C API for TensorFlow.

View File

@ -624,7 +624,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
const int num_inputs = input_shapes->num_items;
NodeDef node_def;
tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op);
tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op);
node_def.set_name(op->Name());
node_def.set_op(op->Name());
for (int i = 0; i < num_inputs; ++i) {

View File

@ -54,7 +54,7 @@ Status ProcessInputs(
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
input_tensors->reserve(ninputs);
for (int i = 0; i < ninputs; ++i) {
Node* node = &inputs[i].oper->node;
Node* node = inputs[i].oper ? &inputs[i].oper->node : nullptr;
int idx = inputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
@ -90,7 +90,7 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
output_tensors->reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
Node* node = &outputs[i].oper->node;
Node* node = outputs[i].oper ? &outputs[i].oper->node : nullptr;
int idx = outputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
fn_body->graph.IsValidOutputTensor(node, idx),

View File

@ -2286,14 +2286,15 @@ TEST_F(CApiAttributesTest, Tensor) {
TEST_F(CApiAttributesTest, StringTensor) {
// Create the string-Tensor "attribute" value.
char encoded[] = {
0, 0, 0, 0, 0, 0, 0, 0, // array[uint64] offsets
1, // varint encoded string length
'A',
};
const char test_string[] =
"borkborkborkborkborkborkborkbork"; // >24bytes to force heap alloc
TF_TString tstr[1];
TF_TString_Init(&tstr[0]);
TF_TString_Copy(&tstr[0], test_string, sizeof(test_string) - 1);
auto deallocator = [](void* data, size_t len, void* arg) {};
unique_tensor_ptr t_in(TF_NewTensor(TF_STRING, nullptr, 0, &encoded[0],
sizeof(encoded), deallocator, nullptr),
unique_tensor_ptr t_in(TF_NewTensor(TF_STRING, nullptr, 0, &tstr[0],
sizeof(tstr), deallocator, nullptr),
TF_DeleteTensor);
// Create a TF_Operation with the attribute t_in
@ -2312,9 +2313,17 @@ TEST_F(CApiAttributesTest, StringTensor) {
EXPECT_EQ(TF_STRING, TF_TensorType(t_out));
EXPECT_EQ(0, TF_NumDims(t_out));
ASSERT_EQ(TF_TensorByteSize(t_in.get()), TF_TensorByteSize(t_out));
EXPECT_EQ(0, memcmp(TF_TensorData(t_in.get()), TF_TensorData(t_out),
TF_TensorByteSize(t_out)));
TF_TString* t_in_tstr = static_cast<TF_TString*>(TF_TensorData(t_in.get()));
TF_TString* t_out_tstr = static_cast<TF_TString*>(TF_TensorData(t_out));
EXPECT_EQ(absl::string_view(test_string),
absl::string_view(TF_TString_GetDataPointer(t_out_tstr),
TF_TString_GetSize(t_out_tstr)));
EXPECT_EQ(absl::string_view(TF_TString_GetDataPointer(t_in_tstr),
TF_TString_GetSize(t_in_tstr)),
absl::string_view(TF_TString_GetDataPointer(t_out_tstr),
TF_TString_GetSize(t_out_tstr)));
TF_DeleteTensor(t_out);
TF_TString_Dealloc(&tstr[0]);
}
TEST_F(CApiAttributesTest, TensorList) {

View File

@ -38,9 +38,10 @@ tf_cuda_library(
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":context_interface",
":operation_interface",
":tensor_handle_interface",
":immediate_execution_context",
":immediate_execution_operation",
":immediate_execution_tensor_handle",
":abstract_tensor_handle",
":tfe_context_internal",
":tfe_cancellation_manager_internal",
":tfe_executor_internal",
@ -101,13 +102,17 @@ tf_cuda_library(
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"abstract_context.h",
"abstract_function.h",
"abstract_operation.h",
"abstract_tensor_handle.h",
"c_api_experimental.h",
"c_api_internal.h",
"c_api_unified_experimental.h",
"context_interface.h",
"dlpack.h",
"operation_interface.h",
"tensor_handle_interface.h",
"immediate_execution_context.h",
"immediate_execution_operation.h",
"immediate_execution_tensor_handle.h",
"tfe_cancellation_manager_internal.h",
"tfe_executor_internal.h",
"tfe_monitoring_internal.h",
@ -153,9 +158,13 @@ cc_library(
"//tensorflow:internal",
],
deps = [
":abstract_context",
":abstract_operation",
":abstract_tensor_handle",
":c_api",
":c_api_experimental",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:conversion_macros",
"//tensorflow/c:tf_status",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:types",
@ -163,12 +172,22 @@ cc_library(
)
cc_library(
name = "tensor_handle_interface",
hdrs = ["tensor_handle_interface.h"],
name = "abstract_tensor_handle",
hdrs = ["abstract_tensor_handle.h"],
visibility = [
"//tensorflow:internal",
],
deps = [],
)
cc_library(
name = "immediate_execution_tensor_handle",
hdrs = ["immediate_execution_tensor_handle.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_tensor_handle",
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -177,13 +196,13 @@ cc_library(
)
cc_library(
name = "operation_interface",
hdrs = ["operation_interface.h"],
name = "abstract_operation",
hdrs = ["abstract_operation.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":tensor_handle_interface",
":abstract_tensor_handle",
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -193,16 +212,58 @@ cc_library(
)
cc_library(
name = "context_interface",
hdrs = ["context_interface.h"],
name = "immediate_execution_operation",
hdrs = ["immediate_execution_operation.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":operation_interface",
":tensor_handle_interface",
":abstract_operation",
":abstract_tensor_handle",
":immediate_execution_tensor_handle",
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "abstract_context",
hdrs = ["abstract_context.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_function",
":abstract_operation",
],
)
cc_library(
name = "abstract_function",
hdrs = ["abstract_function.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:status",
],
)
cc_library(
name = "immediate_execution_context",
hdrs = ["immediate_execution_context.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_context",
":immediate_execution_operation",
":immediate_execution_tensor_handle",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -218,7 +279,7 @@ cc_library(
"//tensorflow:internal",
],
deps = [
":context_interface",
":immediate_execution_context",
"//tensorflow/c:conversion_macros",
],
)
@ -278,7 +339,7 @@ cc_library(
"//tensorflow:internal",
],
deps = [
":operation_interface",
":immediate_execution_operation",
"//tensorflow/c:conversion_macros",
],
)
@ -301,7 +362,7 @@ cc_library(
"//tensorflow:internal",
],
deps = [
":tensor_handle_interface",
":immediate_execution_tensor_handle",
"//tensorflow/c:conversion_macros",
],
)
@ -481,6 +542,12 @@ tf_cuda_library(
":tfe_context_internal",
":tfe_op_internal",
":tfe_tensorhandle_internal",
":abstract_operation",
":abstract_context",
":abstract_tensor_handle",
":immediate_execution_tensor_handle",
":immediate_execution_context",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu",
@ -499,6 +566,7 @@ tf_cuda_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:variant",
"//tensorflow/c:conversion_macros",
],
}) + select({
"//tensorflow:with_xla_support": [
@ -672,6 +740,10 @@ filegroup(
],
exclude = [
"c_api_experimental.cc",
"c_api_unified_experimental.cc",
"c_api_unified_experimental_eager.cc",
"c_api_unified_experimental_graph.cc",
"c_api_unified_experimental_internal.h",
"*test*",
"*dlpack*",
],

View File

@ -0,0 +1,82 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
#include <memory>
#include "tensorflow/c/eager/abstract_function.h"
#include "tensorflow/c/eager/abstract_operation.h"
namespace tensorflow {
// Abstract interface to a context.
//
// This serves as a factory for creating `AbstractOperation`s and for
// registering traced functions.
// Operations creation within a context can only be executed in that context
// (for now at least).
// Implementations of the context may contain some state e.g. an execution
// environment, a traced representation etc.
class AbstractContext {
protected:
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt };
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
virtual ~AbstractContext() {}
public:
AbstractContextKind getKind() const { return kind_; }
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage it's own
// lifetime through ref counting. Thus clients MUST call Release() in order to
// destroy an instance of this class.
virtual void Release() = 0;
// Creates an operation builder and ties it to this context.
// The returned object can be used for setting operation's attributes,
// adding inputs and finally executing (immediately or lazily as in tracing)
// it in this context.
virtual AbstractOperation* CreateOperation() = 0;
// Registers a function with this context, after this the function is
// available to be called/referenced by its name in this context.
virtual Status RegisterFunction(AbstractFunction*) = 0;
// Remove a function. 'func' argument is the name of a previously added
// FunctionDef. The name is in fdef.signature.name.
virtual Status RemoveFunction(const string& func) = 0;
private:
const AbstractContextKind kind_;
};
namespace internal {
struct AbstractContextDeleter {
void operator()(AbstractContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractContextPtr =
std::unique_ptr<AbstractContext, internal::AbstractContextDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
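The `AbstractContextPtr` alias above pairs `std::unique_ptr` with a deleter that calls `Release()` rather than `delete`. A self-contained sketch of the same pattern (the `Refcounted` names below are hypothetical, not TensorFlow APIs) shows why this matters when an object manages its own lifetime:

```cpp
// Hypothetical illustration of the Release()-based deleter pattern used by
// AbstractContextPtr / AbstractOperationPtr; none of these names are TF APIs.
#include <cstdio>
#include <memory>

class Refcounted {
 public:
  static Refcounted* Create() { return new Refcounted(); }
  // Clients never call delete directly; Release() lets the object decide
  // when (and whether) it is actually destroyed.
  void Release() {
    if (--refs_ == 0) delete this;
  }

 private:
  Refcounted() = default;
  ~Refcounted() { std::puts("destroyed"); }
  int refs_ = 1;
};

// The deleter forwards to Release() instead of calling delete.
struct ReleaseDeleter {
  void operator()(Refcounted* p) const {
    if (p != nullptr) p->Release();
  }
};
using RefcountedPtr = std::unique_ptr<Refcounted, ReleaseDeleter>;

int main() {
  RefcountedPtr obj(Refcounted::Create());
  // When obj goes out of scope, the deleter calls Release(), not delete.
  return 0;
}
```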

View File

@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
// A traced function: this hides the complexity of converting the serialized
// representation between various supported formats e.g. FunctionDef and Mlir
// function.
class AbstractFunction {
protected:
enum AbstractFunctionKind { kGraph, kMlir };
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
public:
// Returns which subclass is this instance of.
AbstractFunctionKind getKind() const { return kind_; }
virtual ~AbstractFunction() = default;
// Returns the AbstractFunction as a FunctionDef.
virtual Status GetFunctionDef(FunctionDef**) = 0;
private:
const AbstractFunctionKind kind_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_

View File

@ -12,24 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
struct TFE_Op;
namespace tensorflow {
// Abstract interface to an operation.
class AbstractOperationInterface {
// This interface allows building and executing an operation in either
// tracing or immediate execution mode.
class AbstractOperation {
protected:
enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt };
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
virtual ~AbstractOperation() {}
public:
AbstractOperationKind getKind() const { return kind_; }
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
@ -38,7 +45,6 @@ class AbstractOperationInterface {
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
virtual void Clear() = 0;
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
virtual const string& Name() const = 0;
@ -66,12 +72,10 @@ class AbstractOperationInterface {
// existing and given constraints will be performed.
virtual Status SetDeviceName(const char* name) = 0;
virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;
virtual Status AddInputList(
absl::Span<AbstractTensorHandleInterface*> inputs) = 0;
virtual Status Execute(absl::Span<AbstractTensorHandleInterface*> retvals,
virtual Status AddInput(AbstractTensorHandle* input) = 0;
virtual Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) = 0;
virtual Status Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) = 0;
virtual const tensorflow::OpDef* OpDef() const = 0;
virtual Status SetAttrString(const char* attr_name, const char* data,
size_t length) = 0;
@ -82,7 +86,7 @@ class AbstractOperationInterface {
virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) = 0;
virtual Status SetAttrFunction(const char* attr_name,
const AbstractOperationInterface* value) = 0;
const AbstractOperation* value) = 0;
virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) = 0;
virtual Status SetAttrTensor(const char* attr_name,
@ -102,19 +106,25 @@ class AbstractOperationInterface {
virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) = 0;
virtual Status SetAttrFunctionList(
const char* attr_name,
absl::Span<const AbstractOperationInterface*> values) = 0;
const char* attr_name, absl::Span<const AbstractOperation*> values) = 0;
virtual Status InputLength(const char* input_name, int* length) = 0;
virtual Status OutputLength(const char* output_name, int* length) = 0;
// Experimental
virtual Status SetUseXla(bool enable) = 0;
protected:
virtual ~AbstractOperationInterface() {}
private:
const AbstractOperationKind kind_;
};
namespace internal {
struct AbstractOperationDeleter {
void operator()(AbstractOperation* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractOperationPtr =
std::unique_ptr<AbstractOperation, internal::AbstractOperationDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
#endif // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_

View File

@ -0,0 +1,60 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
#include <memory>
namespace tensorflow {
// Abstract interface to a Tensor handle in either tracing or immediate
// execution mode.
class AbstractTensorHandle {
protected:
enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt };
explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
virtual ~AbstractTensorHandle() {}
public:
AbstractTensorHandleKind getKind() const { return kind_; }
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage it's own
// lifetime through ref counting. Thus this must be allocated on the heap and
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
private:
const AbstractTensorHandleKind kind_;
};
namespace internal {
struct AbstractTensorHandleDeleter {
void operator()(AbstractTensorHandle* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractTensorHandlePtr =
std::unique_ptr<AbstractTensorHandle,
internal::AbstractTensorHandleDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_

View File

@ -21,6 +21,8 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/c/eager/abstract_tensor_handle.h"
// clang-format off
#include "tensorflow/core/platform/platform.h"
// clang-format on
@ -31,8 +33,8 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
@ -713,8 +715,8 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
status->status = tfrt::ListOpHandlerChains(
opts->session_options.options, &op_handler_chains, &device_attributes);
if (!status->status.ok()) return nullptr;
return tensorflow::wrap(
new tfrt::ContextInterface(op_handler_chains, device_attributes));
return tensorflow::wrap(new tfrt::ContextInterface(
op_handler_chains, device_attributes, opts->async));
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr;
@ -1119,7 +1121,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
tensorflow::AbstractOperationInterface* new_op =
tensorflow::ImmediateExecutionOperation* new_op =
tensorflow::unwrap(ctx)->CreateOperation();
status->status = new_op->Reset(op_or_function_name, nullptr);
if (!status->status.ok()) {
@ -1164,7 +1166,9 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
status->status = tensorflow::unwrap(op)->AddInputList(
{tensorflow::unwrap(inputs), static_cast<size_t>(num_inputs)});
{reinterpret_cast<tensorflow::AbstractTensorHandle**>(
tensorflow::unwrap(inputs)),
static_cast<size_t>(num_inputs)});
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
@ -1324,7 +1328,9 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
const TFE_Op** value, int num_values) {
auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
attr_name, {tensorflow::unwrap(value), static_cast<size_t>(num_values)});
attr_name, {reinterpret_cast<const tensorflow::AbstractOperation**>(
tensorflow::unwrap(value)),
static_cast<size_t>(num_values)});
if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
@ -1368,7 +1374,10 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
status->status = tensorflow::unwrap(op)->Execute(
absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals);
absl::MakeSpan(reinterpret_cast<tensorflow::AbstractTensorHandle**>(
tensorflow::unwrap(retvals)),
*num_retvals),
num_retvals);
}
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
@ -1473,14 +1482,10 @@ const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
}
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m;
tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(op));
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
for (const auto& attribute : m) {
destination->Set(attribute.first, attribute.second);
}
destination->CopyAttributes(*tensorflow::unwrap(attrs));
}
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,

View File

@ -38,7 +38,7 @@ using tensorflow::string;
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
tensorflow::AbstractOperationInterface* op =
tensorflow::ImmediateExecutionOperation* op =
tensorflow::unwrap(op_to_reset);
op->Clear();
status->status = op->Reset(op_or_function_name, raw_device_name);
@ -60,6 +60,12 @@ void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
context->SetShouldStoreGraphs(false);
}
uint64_t TFE_GetContextId(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return context->GetContextId();
}
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
int64_t value) {
cell->cell.IncrementBy(value);

View File

@ -300,6 +300,14 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*,
bool use_tfrt);
// Returns the context_id from the EagerContext which is used by the
// EagerService to maintain consistency between client and worker. The
// context_id is initialized with a dummy value and is later set when the worker
// is initialized (either locally or remotely). The context_id can change during
// the process lifetime although this should cause the worker to be
// reinitialized (e.g. cleared caches) as well.
TF_CAPI_EXPORT extern uint64_t TFE_GetContextId(TFE_Context* ctx);
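A minimal usage sketch (not part of this commit) of the new query, assuming the eager C headers from this tree:

```cpp
// Sketch: create an eager context and read its context_id. The id starts out
// as a dummy value and may change if the worker is (re)initialized.
#include <cstdint>
#include <cstdio>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  if (TF_GetCode(status) == TF_OK) {
    uint64_t id = TFE_GetContextId(ctx);
    std::printf("context id: %llu\n", static_cast<unsigned long long>(id));
  } else {
    std::printf("failed to create context: %s\n", TF_Message(status));
  }
  if (ctx != nullptr) TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
  return 0;
}
```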
// -----------------------------------------------------------------------------
// Cancellation APIs.

View File

@ -212,6 +212,35 @@ TEST(CAPI, CancellationManager) {
TFE_DeleteCancellationManager(c_mgr);
}
TEST(CAPI, ExecutorContextDestructionOrder) {
TF_Status* status = TF_NewStatus();
{
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
TFE_ContextSetExecutorForThread(ctx, executor);
TFE_DeleteContext(ctx);
TFE_DeleteExecutor(executor);
}
{
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
TFE_ContextSetExecutorForThread(ctx, executor);
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}
TF_DeleteStatus(status);
}
TEST(CAPI, Function_ident_CPU) {
// First create a simple identity function.
TF_Graph* function_graph = TF_NewGraph();

View File

@ -22,15 +22,17 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::string;
using tensorflow::internal::OutputList;
using tensorflow::internal::unwrap;
namespace tensorflow {
namespace internal {
typedef absl::flat_hash_map<std::string, FactoryFunction> FactoriesMap;
namespace tracing {
typedef absl::flat_hash_map<std::string, tracing::FactoryFunction> FactoriesMap;
static FactoriesMap& GetFactories() {
static FactoriesMap* factories = new FactoriesMap;
@ -48,8 +50,8 @@ void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
void SetDefaultTracingEngine(const char* name) { default_factory = name; }
static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
TF_Status* s) {
static TracingContext* CreateTracingExecutionContext(const char* fn_name,
TF_Status* s) {
auto entry = GetFactories().find(default_factory);
if (entry != GetFactories().end()) return entry->second(fn_name, s);
string msg = absl::StrCat(
@ -70,7 +72,7 @@ static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
return nullptr;
}
} // end namespace internal
} // end namespace tracing
} // end namespace tensorflow
// =============================================================================
@ -83,43 +85,77 @@ static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
//
// =============================================================================
using tensorflow::AbstractFunction;
using tensorflow::AbstractTensorHandle;
using tensorflow::DataType;
using tensorflow::dyn_cast;
using tensorflow::OutputList;
using tensorflow::Status;
using tensorflow::unwrap;
using tensorflow::wrap;
using tensorflow::tracing::CreateTracingExecutionContext;
using tensorflow::tracing::SetDefaultTracingEngine;
using tensorflow::tracing::TracingContext;
using tensorflow::tracing::TracingOperation;
using tensorflow::tracing::TracingTensorHandle;
void TF_SetTracingImplementation(const char* name) {
tensorflow::internal::SetDefaultTracingEngine(name);
SetDefaultTracingEngine(name);
}
// Creates a new TensorFlow function, it is an execution context attached to a
// given tracing context.
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s));
return wrap(CreateTracingExecutionContext(fn_name, s));
}
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
TF_OutputList* outputs, TF_Status* s) {
auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s));
AbstractFunction* func;
TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(ctx));
if (!tracing_ctx) {
Set_TF_Status_from_Status(
s, tensorflow::errors::InvalidArgument(
"Only TracingContext can be converted into a function."));
return nullptr;
}
Set_TF_Status_from_Status(s, tracing_ctx->Finalize(unwrap(outputs), &func));
TF_DeleteExecutionContext(ctx);
return func;
return wrap(func);
}
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
TF_DataType dtype, TF_Status* s) {
return wrap(unwrap(func)->AddParameter(dtype, s));
TracingTensorHandle* t;
TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func));
if (!tracing_ctx) {
Set_TF_Status_from_Status(
s, tensorflow::errors::InvalidArgument(
"TF_AddFunctionParameter must be called on a TracingContext."));
return nullptr;
}
Set_TF_Status_from_Status(
s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), &t));
return wrap(t);
}
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { unwrap(c)->Release(); }
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
return wrap(unwrap(c)->CreateOperation());
return wrap((unwrap(c)->CreateOperation()));
}
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete unwrap(op); }
void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); }
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete unwrap(t); }
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Release(); }
TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
TF_Status* s) {
unwrap(o)->expected_num_outputs = num_outputs;
unwrap(o)->outputs.clear();
unwrap(o)->outputs.resize(num_outputs);
}
int TF_OutputListNumOutputs(TF_OutputList* o) {
return unwrap(o)->outputs.size();
@ -134,24 +170,46 @@ void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s) {
unwrap(op)->SetOpType(op_type, s);
Set_TF_Status_from_Status(s, unwrap(op)->Reset(op_type,
/*raw_device_name=*/nullptr));
}
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s) {
unwrap(op)->SetOpName(op_name, s);
TracingOperation* tracing_op = dyn_cast<TracingOperation>(unwrap(op));
if (!tracing_op) {
Set_TF_Status_from_Status(
s, tensorflow::errors::InvalidArgument(
"TF_AbstractOpSetOpName must be called on a TracingOperation."));
return;
}
Set_TF_Status_from_Status(s, tracing_op->SetOpName(op_name));
}
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
TF_DataType value, TF_Status* s) {
unwrap(op)->SetAttrType(attr_name, value, s);
Status status =
unwrap(op)->SetAttrType(attr_name, static_cast<DataType>(value));
TF_SetStatus(s, static_cast<TF_Code>(status.code()),
status.error_message().c_str());
}
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
unwrap(ctx)->ExecuteOperation(unwrap(op), num_inputs, &unwrap(*inputs),
unwrap(o), s);
TF_Status* s) {
for (int i = 0; i < num_inputs; i++) {
Set_TF_Status_from_Status(s, unwrap(op)->AddInput(unwrap(inputs[i])));
if (TF_GetCode(s) != TF_OK) {
return;
}
}
int num_outputs = unwrap(o)->expected_num_outputs;
Set_TF_Status_from_Status(
s, unwrap(op)->Execute(
absl::MakeSpan(reinterpret_cast<AbstractTensorHandle**>(
unwrap(o)->outputs.data()),
unwrap(o)->outputs.size()),
&num_outputs));
}
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
@ -161,5 +219,5 @@ void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
TF_AbstractFunction* func,
TF_Status* s) {
unwrap(ctx)->RegisterFunction(unwrap(func), s);
Set_TF_Status_from_Status(s, unwrap(ctx)->RegisterFunction(unwrap(func)));
}

@ -110,7 +110,7 @@ void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
// Any active tape will observe the effects of this execution.
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s);
TF_Status* s);
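// A minimal usage sketch, assuming `ctx`, `at` (a TF_AbstractTensor) and `s`
// were created as in the eager unit test:
//   TF_AbstractOp* add_op = TF_NewAbstractOp(ctx);
//   TF_AbstractOpSetOpType(add_op, "Add", s);
//   TF_AbstractTensor* inputs[2] = {at, at};
//   TF_OutputList* o = TF_NewOutputList();
//   TF_OutputListSetNumOutputs(o, 1, s);
//   TF_ExecuteOperation(add_op, 2, inputs, o, s);  // the context is implied by `add_op`
//   TF_DeleteAbstractOp(add_op);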
// Creates a new TF_AbstractFunction from the current tracing states in the
// context. The provided `ctx` is consumed by this API call and deleted.
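// A minimal end-to-end tracing sketch, assuming a valid TF_Status* `s`
// (mirrors the graph unit test):
//   TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", s);
//   TF_AbstractTensor* x = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
//   TF_AbstractOp* add_op = TF_NewAbstractOp(graph_ctx);
//   TF_AbstractOpSetOpType(add_op, "Add", s);
//   TF_AbstractOpSetOpName(add_op, "my_add", s);
//   TF_AbstractTensor* inputs[2] = {x, x};
//   TF_OutputList* outputs = TF_NewOutputList();
//   TF_OutputListSetNumOutputs(outputs, 1, s);
//   TF_ExecuteOperation(add_op, 2, inputs, outputs, s);
//   TF_AbstractFunction* fn = TF_FinalizeFunction(graph_ctx, outputs, s);
//   // graph_ctx is consumed by TF_FinalizeFunction; do not delete it again.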
@ -137,7 +137,8 @@ TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
TF_Status* s);
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s);
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*);
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*,
TF_Status* s);
#ifdef __cplusplus
} /* end extern "C" */

@ -15,180 +15,68 @@ limitations under the License.
#include <vector>
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::string;
namespace tensorflow {
namespace internal {
// Simple wrapper over a TFE_TensorHandle
struct EagerTensor : public AbstractTensor {
TFE_TensorHandle* t = nullptr;
EagerTensor() : AbstractTensor(kKind) {}
explicit EagerTensor(TFE_TensorHandle* t) : AbstractTensor(kKind), t(t) {}
~EagerTensor() override { TFE_DeleteTensorHandle(t); }
static constexpr AbstractTensorKind kKind = kEagerTensor;
};
// Simple wrapper over a TFE_Op
class EagerOp : public AbstractOp {
public:
explicit EagerOp(TFE_Context* ctx) : AbstractOp(kKind), ctx_(ctx) {}
void SetOpType(const char* const op_type, TF_Status* s) override {
op_ = TFE_NewOp(ctx_, op_type, s);
}
void SetOpName(const char* const op_name, TF_Status* s) override {
// Name is ignored in eager mode.
}
void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) override {
if (op_ == nullptr) {
TF_SetStatus(s, TF_FAILED_PRECONDITION,
"op_type must be specified before specifying attrs.");
return;
}
TFE_OpSetAttrType(op_, attr_name, value);
}
~EagerOp() override { TFE_DeleteOp(op_); }
static constexpr AbstractOpKind kKind = kEagerOp;
private:
friend class EagerContext; // For access to op_.
TFE_Op* op_ = nullptr;
TFE_Context* ctx_;
};
// Wraps a TFE_Context and dispatches EagerOp with EagerTensor inputs.
class EagerContext : public ExecutionContext {
public:
EagerContext() : ExecutionContext(kKind) {}
void Build(TFE_ContextOptions* options, TF_Status* status) {
eager_ctx_ = TFE_NewContext(options, status);
}
AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context.
return new EagerOp(eager_ctx_);
}
void ExecuteOperation(AbstractOp* op, int num_inputs,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) override {
auto* eager_op = dyncast<EagerOp>(op);
if (eager_op == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Unable to cast AbstractOp to TF_EagerOp.");
return;
}
auto* tfe_op = eager_op->op_;
if (TF_GetCode(s) != TF_OK) return;
for (int i = 0; i < num_inputs; ++i) {
auto* eager_tensor = dyncast<const EagerTensor>(inputs[i]);
if (!eager_tensor) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
return;
}
TFE_OpAddInput(tfe_op, eager_tensor->t, s);
if (TF_GetCode(s) != TF_OK) return;
}
if (o->expected_num_outputs == -1) {
string msg =
"The number of outputs must be provided in eager mode. Use "
"TF_OutputListSetNumOutputs.";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return;
}
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
int num_retvals = o->expected_num_outputs;
retvals.resize(num_retvals);
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
if (TF_GetCode(s) != TF_OK) {
return;
}
o->outputs.clear();
o->outputs.reserve(num_retvals);
for (int i = 0; i < num_retvals; ++i) {
o->outputs.push_back(new EagerTensor(retvals[i]));
}
}
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Can't add function parameter on an eager context.");
return nullptr;
}
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Can't use finalize function on an eager context.");
return nullptr;
}
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
auto* func = afunc->GetTfFunction(s);
if (!func) {
return;
}
TFE_ContextAddFunction(eager_ctx_, func, s);
}
~EagerContext() override { TFE_DeleteContext(eager_ctx_); }
static constexpr ExecutionContextKind kKind = kEagerContext;
private:
friend TFE_Context* ::TF_ExecutionContextGetTFEContext(
TF_ExecutionContext* ctx);
TFE_Context* eager_ctx_;
};
} // namespace internal
} // namespace tensorflow
// =============================================================================
// Public C API entry points
// These are only the entry points specific to the Eager API.
// =============================================================================
using tensorflow::internal::dyncast;
using tensorflow::internal::unwrap;
using tensorflow::AbstractContext;
using tensorflow::AbstractTensorHandle;
using tensorflow::dyn_cast;
using tensorflow::ImmediateExecutionContext;
using tensorflow::ImmediateExecutionTensorHandle;
using tensorflow::string;
using tensorflow::unwrap;
using tensorflow::wrap;
using tensorflow::strings::StrCat;
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options,
TF_Status* s) {
auto* ctx = new tensorflow::internal::EagerContext();
ctx->Build(options, s);
return wrap(ctx);
TFE_Context* c_ctx = TFE_NewContext(options, s);
if (TF_GetCode(s) != TF_OK) {
return nullptr;
}
return wrap(static_cast<AbstractContext*>(unwrap(c_ctx)));
}
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
TF_Status* s) {
return wrap(new tensorflow::internal::EagerTensor(t));
return wrap(static_cast<AbstractTensorHandle*>(unwrap(t)));
}
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s) {
auto* eager_tensor = dyncast<tensorflow::internal::EagerTensor>(unwrap(at));
if (!eager_tensor) {
string msg = tensorflow::strings::StrCat("Not an eager tensor handle.",
reinterpret_cast<uintptr_t>(at));
auto handle = dyn_cast<ImmediateExecutionTensorHandle>(unwrap(at));
if (!handle) {
string msg =
StrCat("Not an eager tensor handle.", reinterpret_cast<uintptr_t>(at));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return eager_tensor->t;
return wrap(handle);
}
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
auto* eager_ctx = dyncast<tensorflow::internal::EagerContext>(unwrap(ctx));
if (!eager_ctx) return nullptr;
return eager_ctx->eager_ctx_;
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx,
TF_Status* s) {
auto imm_ctx = dyn_cast<ImmediateExecutionContext>(unwrap(ctx));
if (!imm_ctx) {
string msg =
StrCat("Not an eager context.", reinterpret_cast<uintptr_t>(ctx));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return wrap(imm_ctx);
}
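// A minimal usage sketch, assuming `ctx` and `status` were set up as in the
// unit tests; the TF_Status out-parameter now reports when the context is not
// an eager context instead of silently returning nullptr:
//   TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status);
//   if (TF_GetCode(status) != TF_OK) {
//     // `ctx` was a tracing context; no TFE_Context is available.
//   }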

@ -18,77 +18,198 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::dyn_cast;
using tensorflow::string;
namespace tensorflow {
namespace internal {
namespace tracing {
namespace graph {
class GraphContext;
class GraphOperation;
class GraphTensor;
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
// into the list of outputs for the operation.
struct GraphTensor : public AbstractTensor {
TF_Output output{};
GraphContext* ctx = nullptr;
GraphTensor() : AbstractTensor(kKind) {}
GraphTensor(TF_Output output, GraphContext* ctx)
: AbstractTensor(kKind), output(output), ctx(ctx) {}
static constexpr AbstractTensorKind kKind = kGraphTensor;
class GraphTensor : public TracingTensorHandle {
public:
explicit GraphTensor(TF_Output output)
: TracingTensorHandle(kGraph), output_(output) {}
void Release() override { delete this; }
TF_Output output_;
// For LLVM style RTTI.
static bool classof(const AbstractTensorHandle* ptr) {
return ptr->getKind() == kGraph;
}
};
// GraphOp wraps and populate a TF_OperationDescription.
class GraphOp : public AbstractOp {
// GraphOperation wraps and populates a TF_OperationDescription.
class GraphOperation : public TracingOperation {
public:
explicit GraphOp(TF_Graph* g) : AbstractOp(kKind), g_(g) {}
void SetOpType(const char* const op_type, TF_Status* s) override {
explicit GraphOperation(TF_Graph* g) : TracingOperation(kGraph), g_(g) {}
void Release() override { delete this; }
Status Reset(const char* op, const char* raw_device_name) override {
if (op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
strings::StrCat("SetOpType called on already built op.").c_str());
return;
return errors::FailedPrecondition("Reset called on already built op.");
}
if (op_name_ != nullptr) {
op_.reset(TF_NewOperation(g_, op_type, op_name_));
op_name_ = nullptr;
} else {
op_type_ = op_type;
if (raw_device_name) {
device_name_ = raw_device_name;
}
op_type_ = op;
return Status::OK();
}
void SetOpName(const char* const op_name, TF_Status* s) override {
Status SetOpName(const char* const op_name) override {
if (op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
strings::StrCat("SetOpName called on already built op.").c_str());
return;
return errors::FailedPrecondition(
"SetOpName called on already built op.");
}
if (op_type_ != nullptr) {
op_.reset(TF_NewOperation(g_, op_type_, op_name));
op_type_ = nullptr;
} else {
op_name_ = op_name;
if (op_type_.empty()) {
return errors::FailedPrecondition(
"GraphOperation::Reset must be called before calling SetOpName.");
}
op_.reset(TF_NewOperation(g_, op_type_.c_str(), op_name));
return Status::OK();
}
void SetAttrType(const char* const attr_name, TF_DataType value,
TF_Status* s) override {
if (!op_) {
TF_SetStatus(
s, TF_FAILED_PRECONDITION,
"op_type and op_name must be specified before specifying attrs.");
return;
}
TF_SetAttrType(op_.get(), attr_name, value);
}
~GraphOp() override {}
const string& Name() const override { return op_type_; }
const string& DeviceName() const override { return device_name_; }
static constexpr AbstractOpKind kKind = kGraphOp;
Status SetDeviceName(const char* name) override {
// TODO(srbs): Implement this.
device_name_ = name;
return Status::OK();
}
Status AddInput(AbstractTensorHandle* input) override {
GraphTensor* t = dyn_cast<GraphTensor>(input);
if (!t) {
return tensorflow::errors::InvalidArgument(
"Unable to cast input to GraphTensor");
}
TF_AddInput(op_.get(), t->output_);
return Status::OK();
}
Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) override {
return tensorflow::errors::Unimplemented(
"AddInputList has not been implemented yet.");
}
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) override {
auto* tf_opdesc = op_.release();
if (tf_opdesc == nullptr) {
return errors::InvalidArgument("AbstractOp is incomplete.");
}
TF_Status* s = TF_NewStatus();
auto* operation = TF_FinishOperation(tf_opdesc, s);
TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
TF_DeleteStatus(s);
*num_retvals = TF_OperationNumOutputs(operation);
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new GraphTensor({operation, i});
}
return Status::OK();
}
Status SetAttrString(const char* attr_name, const char* data,
size_t length) override {
return tensorflow::errors::Unimplemented(
"SetAttrString has not been implemented yet.");
}
Status SetAttrInt(const char* attr_name, int64_t value) override {
return tensorflow::errors::Unimplemented(
"SetAttrInt has not been implemented yet.");
}
Status SetAttrFloat(const char* attr_name, float value) override {
return tensorflow::errors::Unimplemented(
"SetAttrFloat has not been implemented yet.");
}
Status SetAttrBool(const char* attr_name, bool value) override {
return tensorflow::errors::Unimplemented(
"SetAttrBool has not been implemented yet.");
}
Status SetAttrType(const char* const attr_name, DataType value) override {
if (!op_) {
return Status(
error::Code::FAILED_PRECONDITION,
"op_type and op_name must be specified before specifying attrs.");
}
op_->node_builder.Attr(attr_name, value);
return Status::OK();
}
Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) override {
return tensorflow::errors::Unimplemented(
"SetAttrShape has not been implemented yet.");
}
Status SetAttrFunction(const char* attr_name,
const AbstractOperation* value) override {
return tensorflow::errors::Unimplemented(
"SetAttrFunction has not been implemented yet.");
}
Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) override {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionName has not been implemented yet.");
}
Status SetAttrTensor(const char* attr_name,
AbstractTensorInterface* tensor) override {
return tensorflow::errors::Unimplemented(
"SetAttrTensor has not been implemented yet.");
}
Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrStringList has not been implemented yet.");
}
Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrFloatList has not been implemented yet.");
}
Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrIntList has not been implemented yet.");
}
Status SetAttrTypeList(const char* attr_name, const DataType* values,
int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrTypeList has not been implemented yet.");
}
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrBoolList has not been implemented yet.");
}
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrShapeList has not been implemented yet.");
}
Status SetAttrFunctionList(
const char* attr_name,
absl::Span<const AbstractOperation*> values) override {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionList has not been implemented yet.");
}
// For LLVM style RTTI.
static bool classof(const AbstractOperation* ptr) {
return ptr->getKind() == kGraph;
}
~GraphOperation() override {}
private:
friend class GraphContext; // For access to op_.
@ -96,123 +217,109 @@ class GraphOp : public AbstractOp {
std::unique_ptr<TF_OperationDescription> op_;
// Hold `op_type` and `op_name` till both are available since we need both
// to build a graph operation.
const char* op_type_ = nullptr;
string op_type_;
const char* op_name_ = nullptr;
// TODO(srbs): Use this.
string device_name_;
};
// GraphFunction is a thin wrapper over a TF_Function.
struct GraphFunction : public AbstractFunction {
TF_Function* func = nullptr;
GraphFunction() : AbstractFunction(kKind) {}
GraphFunction() : AbstractFunction(kGraph) {}
explicit GraphFunction(TF_Function* func)
: AbstractFunction(kKind), func(func) {}
: AbstractFunction(kGraph), func(func) {}
~GraphFunction() override {
if (func) TF_DeleteFunction(func);
}
TF_Function* GetTfFunction(TF_Status* s) override { return func; }
Status GetFunctionDef(FunctionDef** fdef) override {
*fdef = &func->fdef;
return Status::OK();
}
static constexpr AbstractFunctionKind kKind = kGraphFunc;
// For LLVM style RTTI.
static bool classof(const AbstractFunction* ptr) {
return ptr->getKind() == kGraph;
}
};
// GraphContext wraps a TF_Graph modeling a single function and manages the
// "execution" of operation, i.e. adding them to the function.
class GraphContext : public ExecutionContext {
class GraphContext : public TracingContext {
public:
explicit GraphContext(const char* name)
: ExecutionContext(kKind),
: TracingContext(kGraph),
graph_(new TF_Graph(), TF_DeleteGraph),
name_(name) {}
AbstractOp* CreateOperation() override {
// TODO(srbs): Should the lifetime of this op be tied to the context.
return new GraphOp(graph_.get());
void Release() override { delete this; }
TracingOperation* CreateOperation() override {
return new GraphOperation(graph_.get());
}
void ExecuteOperation(AbstractOp* op, int num_inputs,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) override {
auto* graph_op = dyncast<GraphOp>(op);
if (graph_op == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Unable to cast AbstractOp to TF_GraphOp.");
return;
Status AddParameter(DataType dtype, TracingTensorHandle** output) override {
auto operation = CreateOperation();
TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
TF_RETURN_IF_ERROR(
operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(num_outputs);
TF_RETURN_IF_ERROR(operation->Execute(
absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
if (num_outputs != 1) {
return errors::Internal("Expected 1 output but found ", num_outputs);
}
auto* tf_opdesc = graph_op->op_.release();
if (tf_opdesc == nullptr) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete.");
return;
}
for (int i = 0; i < num_inputs; ++i) {
auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
if (!graph_tensor) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Capturing eager tensors is not supported yet.");
return;
} else {
if (graph_tensor->ctx != this) {
TF_SetStatus(
s, TF_INVALID_ARGUMENT,
"Capturing tensors from other graphs is not supported yet.");
return;
}
TF_AddInput(tf_opdesc, graph_tensor->output);
}
}
auto* operation = TF_FinishOperation(tf_opdesc, s);
// TF_FinishOperation deletes `tf_opdesc` so clear its reference.
graph_op->op_ = nullptr;
if (TF_GetCode(s) != TF_OK) return;
int num_outputs = TF_OperationNumOutputs(operation);
o->outputs.clear();
o->outputs.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
o->outputs.push_back(new GraphTensor({operation, i}, this));
auto* t = dyn_cast<GraphTensor>(outputs[0]);
if (!t) {
return tensorflow::errors::InvalidArgument(
"Unable to cast input to GraphTensor");
}
inputs_.push_back(t->output_);
*output = tensorflow::down_cast<TracingTensorHandle*>(outputs[0]);
return Status::OK();
}
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
TF_OperationDescription* opdesc =
TF_NewOperation(graph_.get(), "Placeholder",
absl::StrCat("_input_", inputs_.size()).c_str());
TF_SetAttrType(opdesc, "dtype", dtype);
auto* operation = TF_FinishOperation(opdesc, s);
if (!s->status.ok()) return nullptr;
inputs_.push_back(TF_Output{operation, 0});
return new GraphTensor(inputs_.back(), this);
}
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
Status Finalize(OutputList* outputs, AbstractFunction** f) override {
std::unique_ptr<GraphFunction> func(new GraphFunction);
std::vector<TF_Output> graph_outputs;
graph_outputs.reserve(outputs->outputs.size());
for (AbstractTensor* abstract_output : outputs->outputs) {
GraphTensor* output = dyncast<GraphTensor>(abstract_output);
for (auto* abstract_output : outputs->outputs) {
GraphTensor* output = dyn_cast<GraphTensor>(abstract_output);
if (!output) {
TF_SetStatus(s, TF_UNIMPLEMENTED,
"Returning a non-graph tensor from a function has not "
"been implemented yet.");
return nullptr;
return errors::Unimplemented(
"Returning a non-graph tensor from a function has not "
"been implemented yet.");
}
graph_outputs.push_back(output->output);
graph_outputs.push_back(output->output_);
}
auto s = TF_NewStatus();
func->func = TF_GraphToFunction(
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
if (TF_GetCode(s) != TF_OK) return nullptr;
return func.release();
TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
TF_DeleteStatus(s);
*f = func.release();
return Status::OK();
}
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
TF_SetStatus(s, TF_UNIMPLEMENTED,
"Registering graph functions has not been implemented yet.");
Status RegisterFunction(AbstractFunction* func) override {
return errors::Unimplemented(
"Registering graph functions has not been implemented yet.");
}
~GraphContext() override {}
static constexpr ExecutionContextKind kKind = kGraphContext;
Status RemoveFunction(const string& func) override {
return errors::Unimplemented(
"GraphContext::RemoveFunction has not been implemented yet.");
}
// For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kGraph;
}
private:
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
@ -220,7 +327,7 @@ class GraphContext : public ExecutionContext {
const char* name_;
};
static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) {
static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
return new GraphContext(name);
}
@ -231,5 +338,6 @@ static bool register_tracing = [] {
return true;
}();
} // namespace internal
} // namespace graph
} // namespace tracing
} // namespace tensorflow
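// A minimal sketch of selecting the tracing engine registered above through the
// unified C API, assuming a valid TF_Status* `s` (mirrors the unit tests; the
// function name "my_fn" is only illustrative):
//   TF_SetTracingImplementation("graphdef");
//   TF_ExecutionContext* graph_ctx = TF_CreateFunction("my_fn", s);
//   // ... trace operations into graph_ctx ...
//   TF_DeleteExecutionContext(graph_ctx);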

@ -19,6 +19,10 @@ limitations under the License.
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
@ -26,7 +30,14 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace internal {
// Represents the results of the execution of an operation.
struct OutputList {
std::vector<AbstractTensorHandle*> outputs;
int expected_num_outputs = -1;
};
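// A minimal usage sketch of how the C API above drives an OutputList; `op`,
// `inputs`, and `s` are assumed to be prepared as in the unified C API tests.
inline void OutputListUsageSketch(TF_AbstractOp* op,
                                  TF_AbstractTensor* const* inputs,
                                  TF_Status* s) {
  TF_OutputList* o = TF_NewOutputList();
  // Pre-declares how many outputs the next execution should produce.
  TF_OutputListSetNumOutputs(o, 1, s);
  TF_ExecuteOperation(op, 2, inputs, o, s);
  if (TF_GetCode(s) == TF_OK) {
    TF_AbstractTensor* result = TF_OutputListGet(o, 0);
    TF_DeleteAbstractTensor(result);  // the caller releases returned handles
  }
  TF_DeleteOutputList(o);
}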
namespace tracing {
// =============================================================================
// Implementation detail for the unified execution APIs for Eager and tracing
@ -37,165 +48,75 @@ namespace internal {
// `c_api_unified_experimental.h` header.
// =============================================================================
// We can't depend on C++ rtti, but we still want to be able to have a safe
// dynamic_cast to provide diagnostics to the user when the API is misused.
// Instead we model RTTI by listing all the possible subclasses for each
// abstract base. Each subclass initializes the base class with the right
// `kind`, which lets this utility provide an equivalent of `std::dynamic_cast`.
template <typename T, typename S>
T* dyncast(S source) {
if (source->getKind() != T::kKind) {
return nullptr;
}
return tensorflow::down_cast<T*>(source);
}
// Represents either an EagerTensor or a GraphTensor.
// Represents either a MlirTensor or a GraphTensor.
// This base class does not expose any public methods other than to distinguish
// which subclass it actually is. The user is responsible for using the right
// type of AbstractTensor in their context (do not pass an EagerTensor to a
// type of AbstractTensor in their context (do not pass an MlirTensor to a
// GraphContext and vice-versa).
class AbstractTensor {
class TracingTensorHandle : public AbstractTensorHandle {
protected:
enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor };
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
explicit TracingTensorHandle(AbstractTensorHandleKind kind)
: AbstractTensorHandle(kind) {}
public:
// Returns which subclass this is an instance of.
AbstractTensorKind getKind() const { return kind_; }
virtual ~AbstractTensor() = default;
private:
const AbstractTensorKind kind_;
};
// Represents the results of the execution of an operation.
struct OutputList {
std::vector<AbstractTensor*> outputs;
int expected_num_outputs = -1;
};
// Holds the result of tracing a function.
class AbstractFunction {
protected:
enum AbstractFunctionKind { kGraphFunc };
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
public:
// Returns which subclass this is an instance of.
AbstractFunctionKind getKind() const { return kind_; }
virtual ~AbstractFunction() = default;
// Temporary API until we figure out the right abstraction for AbstractFunction.
// At the moment both Eager and Graph need access to a "TF_Function" object.
virtual TF_Function* GetTfFunction(TF_Status* s) = 0;
private:
const AbstractFunctionKind kind_;
// For LLVM style RTTI.
static bool classof(const AbstractTensorHandle* ptr) {
return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
}
};
// An abstract operation describes an operation by its type, name, and
// attributes. It can be "executed" by the context with some input tensors.
// The same abstract operation may be reused for multiple executions on a given
// context, with the same or different input tensors.
class AbstractOp {
class TracingOperation : public AbstractOperation {
protected:
enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp };
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
explicit TracingOperation(AbstractOperationKind kind)
: AbstractOperation(kind) {}
public:
// Returns which subclass this is an instance of.
AbstractOpKind getKind() const { return kind_; }
virtual ~AbstractOp() = default;
// Sets the type of the operation (for example `AddV2`).
virtual void SetOpType(const char* op_type, TF_Status* s) = 0;
// Sets the name of the operation: this is an optional identifier that is
// not intended to carry semantics and is preserved/propagated without
// guarantees.
virtual void SetOpName(const char* op_name, TF_Status* s) = 0;
virtual Status SetOpName(const char* op_name) = 0;
// Add a `TypeAttribute` on the operation.
virtual void SetAttrType(const char* attr_name, TF_DataType value,
TF_Status* s) = 0;
private:
const AbstractOpKind kind_;
// For LLVM style RTTI.
static bool classof(const AbstractOperation* ptr) {
return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
}
};
// This holds the context for the execution: dispatching operations either to an
// eager implementation or to a graph implementation.
struct ExecutionContext {
// MLIR implementation or to a graph implementation.
class TracingContext : public AbstractContext {
protected:
enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext };
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
explicit TracingContext(AbstractContextKind kind) : AbstractContext(kind) {}
public:
// Returns which subclass this is an instance of.
ExecutionContextKind getKind() const { return k; }
virtual ~ExecutionContext() = default;
// Executes the operation on the provided inputs and populates the OutputList
// with the results. The input tensors must match the current context.
// The effect of "executing" an operation depends on the context: in an Eager
// context it will dispatch it to the runtime for execution, while in a
// tracing context it will add the operation to the current function.
virtual void ExecuteOperation(AbstractOp* op, int num_inputs,
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) = 0;
// Creates an empty AbstractOperation suitable to use with this context.
virtual AbstractOp* CreateOperation() = 0;
// Add a function parameter and return the corresponding tensor.
// This is only valid with an ExecutionContext obtained from a TracingContext;
// it will always error out with an eager context.
virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0;
virtual Status AddParameter(DataType dtype, TracingTensorHandle**) = 0;
// Finalize this context and make a function out of it. The context is in an
// invalid state after this call and must be destroyed.
// This is only valid with an ExecutionContext obtained from a TracingContext;
// it will always error out with an eager context.
virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0;
virtual Status Finalize(OutputList* outputs, AbstractFunction**) = 0;
// Registers a function with this context; after this, the function is
// available to be called/referenced by its name in this context.
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
private:
const ExecutionContextKind k;
// For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
}
};
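// A minimal sketch of how the classof() hooks above are consumed: dyn_cast<>
// (assumed to be in scope via the llvm_rtti helpers, as in the .cc files)
// checks the kind and returns nullptr on mismatch, which is how the C entry
// points reject eager contexts for tracing-only calls such as
// TF_FinalizeFunction.
inline TracingContext* AsTracingContextSketch(AbstractContext* ctx) {
  return dyn_cast<TracingContext>(ctx);  // nullptr unless the kind is kGraph/kMlir
}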
typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
void SetDefaultTracingEngine(const char* name);
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
FactoryFunction factory);
} // namespace tracing
// Create utilities to wrap/unwrap: these convert from the C opaque types to
// the C++ implementation, and back.
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \
static inline CPP_CLASS* const& unwrap(C_TYPEDEF* const& o) { \
return reinterpret_cast<CPP_CLASS* const&>(o); \
} \
static inline const CPP_CLASS* const& unwrap(const C_TYPEDEF* const& o) { \
return reinterpret_cast<const CPP_CLASS* const&>(o); \
} \
static inline C_TYPEDEF* const& wrap(CPP_CLASS* const& o) { \
return reinterpret_cast<C_TYPEDEF* const&>(o); \
} \
static inline const C_TYPEDEF* const& wrap(const CPP_CLASS* const& o) { \
return reinterpret_cast<const C_TYPEDEF* const&>(o); \
}
MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext)
MAKE_WRAP_UNWRAP(TF_AbstractFunction, AbstractFunction)
MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor)
MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp)
MAKE_WRAP_UNWRAP(TF_OutputList, OutputList)
} // namespace internal
DEFINE_CONVERSION_FUNCTIONS(AbstractContext, TF_ExecutionContext)
DEFINE_CONVERSION_FUNCTIONS(AbstractTensorHandle, TF_AbstractTensor)
DEFINE_CONVERSION_FUNCTIONS(AbstractFunction, TF_AbstractFunction)
DEFINE_CONVERSION_FUNCTIONS(AbstractOperation, TF_AbstractOp)
DEFINE_CONVERSION_FUNCTIONS(OutputList, TF_OutputList)
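// For example, the C entry points convert in both directions:
//   TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
//     return wrap(unwrap(c)->CreateOperation());
//   }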
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_

@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
@ -29,15 +30,19 @@ using tensorflow::string;
namespace tensorflow {
namespace {
class UnifiedCAPI : public ::testing::TestWithParam<const char*> {
class UnifiedCAPI
: public ::testing::TestWithParam<std::tuple<const char*, bool>> {
protected:
void SetUp() override { TF_SetTracingImplementation(GetParam()); }
void SetUp() override {
TF_SetTracingImplementation(std::get<0>(GetParam()));
}
};
TEST_P(UnifiedCAPI, TestBasicEager) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam()));
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
@ -45,7 +50,8 @@ TEST_P(UnifiedCAPI, TestBasicEager) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* at =
TF_CreateAbstractTensorFromEagerTensor(t, status.get());
@ -63,7 +69,7 @@ TEST_P(UnifiedCAPI, TestBasicEager) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute.
TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
TF_ExecuteOperation(op, 2, inputs, o, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Clean up operation and inputs.
@ -109,9 +115,11 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
TF_OutputList* add_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute.
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Clean up operation and inputs.
@ -123,6 +131,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam()));
TF_ExecutionContext* eager_execution_ctx =
TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
@ -137,16 +146,14 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
// Build an abstract input tensor.
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
TF_ExecutionContextGetTFEContext(eager_execution_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* input_t =
TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, eager_execution_ctx,
status.get());
TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(add_outputs));
@ -195,8 +202,10 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg0, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(add_op);
// Extract the resulting tensor.
@ -215,8 +224,10 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg1, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(add_op);
// Extract the resulting tensor.
@ -256,6 +267,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam()));
TF_ExecutionContext* eager_execution_ctx =
TF_NewEagerExecutionContext(opts, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
@ -273,7 +285,8 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
std::vector<TF_AbstractTensor*> func_args;
{
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(eager_execution_ctx);
TF_ExecutionContextGetTFEContext(eager_execution_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
@ -286,7 +299,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
TF_OutputListSetNumOutputs(func_outputs, 2, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs,
eager_execution_ctx, s);
s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(fn_op);
for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t);
@ -314,20 +327,21 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
TF_DeleteAbstractFunction(func);
}
TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
TEST_P(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam()));
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get());
ASSERT_EQ(nullptr, func);
TF_AbstractFunction* f = TF_FinalizeFunction(ctx, nullptr, status.get());
ASSERT_EQ(nullptr, f);
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
}
TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
TEST_P(UnifiedCAPI, TF_AbstractOpSetOpTypeAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
@ -348,7 +362,7 @@ TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
TEST_P(UnifiedCAPI, TF_AbstractOpSetOpNameAfterFinishingOpBuildingRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
@ -369,116 +383,44 @@ TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
// Build an Eager context.
TEST_P(UnifiedCAPI, TF_AbstractTensorGetEagerTensorOnGraphTensorRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an Eager operation.
auto* op = TF_NewAbstractOp(ctx);
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
TF_AbstractTensor* at =
TF_CreateAbstractTensorFromEagerTensor(t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {at, at};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build a Graph context.
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute eager op using graph context.
TF_ExecuteOperation(op, 2, inputs, o, graph_ctx, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at);
TF_DeleteOutputList(o);
TF_DeleteExecutionContext(ctx);
TF_DeleteExecutionContext(graph_ctx);
}
TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_OutputList* placeholder_outputs = TF_NewOutputList();
// Execute.
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
graph_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
// Delete placeholder op.
TF_DeleteAbstractOp(placeholder_op);
// Build an abstract operation.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(add_op, "my_add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
TF_OutputList* add_outputs = TF_NewOutputList();
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* eager_execution_ctx =
TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
// Execute.
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, eager_execution_ctx,
status.get());
auto placeholder_t =
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
TF_AbstractTensorGetEagerTensor(placeholder_t, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
// Clean up operation and inputs.
TF_DeleteAbstractTensor(placeholder_t);
TF_DeleteAbstractOp(add_op);
TF_DeleteOutputList(add_outputs);
TF_DeleteOutputList(placeholder_outputs);
TF_DeleteExecutionContext(graph_ctx);
TF_DeleteExecutionContext(eager_execution_ctx);
}
TEST_P(UnifiedCAPI, TF_ExecutionContextGetTFEContextFromFunctionContextRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecutionContextGetTFEContext(graph_ctx, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
TF_DeleteExecutionContext(graph_ctx);
}
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Values("graphdef", "mlir"));
::testing::Combine(::testing::Values("graphdef",
"mlir"),
::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Combine(::testing::Values("graphdef",
"mlir"),
::testing::Values(false)));
#endif
} // namespace
} // namespace tensorflow

@ -12,17 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
#include <memory>
#include <vector>
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
@ -34,16 +36,8 @@ namespace tensorflow {
//
// A context is responsible for creating key objects such as Tensors,
// TensorHandles & Operations.
class AbstractContextInterface {
class ImmediateExecutionContext : public AbstractContext {
public:
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage its own
// lifetime through ref counting. Thus clients MUST call Release() in order to
// destroy an instance of this class.
virtual void Release() = 0;
// Optimized scalar creation functions
virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0;
virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
@ -74,21 +68,20 @@ class AbstractContextInterface {
void* memory_releaser_arg) = 0;
// Create a handle to wrap and manage a Tensor
virtual AbstractTensorHandleInterface* CreateLocalHandle(
virtual ImmediateExecutionTensorHandle* CreateLocalHandle(
AbstractTensorInterface* t) = 0;
// Copy the handle to another device.
virtual AbstractTensorHandleInterface* CopyTensorHandleToDevice(
AbstractTensorHandleInterface* handle, const char* device_name,
virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
ImmediateExecutionTensorHandle* handle, const char* device_name,
Status* status) = 0;
// Create an operation to perform op execution
virtual AbstractOperationInterface* CreateOperation() = 0;
ImmediateExecutionOperation* CreateOperation() override = 0;
// Load a SavedModelAPI object from the given directory and tags
virtual std::unique_ptr<SavedModelAPI> LoadSavedModelAPI(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
tensorflow::Status* status) = 0;
// Returns whether the runtime is backed by TFRT or the legacy TF Eager
// Runtime. This is necessary to decouple runtime-dependent
// code that is layered on top of the runtime.
virtual bool UsesTFRT() = 0;
// List attributes of available devices
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
@ -108,14 +101,32 @@ class AbstractContextInterface {
// be executed as an op. Return error if the function with the same name
// already exists.
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
// Remove a function. 'func' argument is the name of a previously added
// FunctionDef. The name is in fdef.signature.name.
virtual Status RemoveFunction(const string& func) = 0;
// For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
}
protected:
virtual ~AbstractContextInterface() {}
explicit ImmediateExecutionContext(AbstractContextKind kind)
: AbstractContext(kind) {}
~ImmediateExecutionContext() override {}
};
namespace internal {
struct ImmediateExecutionContextDeleter {
void operator()(ImmediateExecutionContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using ImmediateContextPtr =
std::unique_ptr<ImmediateExecutionContext,
internal::ImmediateExecutionContextDeleter>;
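// A minimal usage sketch: ImmediateContextPtr destroys through Release() rather
// than `delete`, since the context may manage its own lifetime via ref
// counting. `raw` stands in for whatever API call produced the context; the
// same pattern is used below for operations and tensor handles.
inline void OwnImmediateContextSketch(ImmediateExecutionContext* raw) {
  ImmediateContextPtr ctx(raw);
  std::vector<DeviceAttributes> devices;
  ctx->ListDevices(&devices);  // used like a raw pointer
}  // ctx->Release() runs here via ImmediateExecutionContextDeleter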
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_

@ -0,0 +1,74 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/status.h"
struct TFE_Op;
namespace tensorflow {
// Abstract interface to an operation.
class ImmediateExecutionOperation : public AbstractOperation {
public:
virtual void Clear() = 0;
virtual const tensorflow::OpDef* OpDef() const = 0;
virtual Status InputLength(const char* input_name, int* length) = 0;
virtual Status OutputLength(const char* output_name, int* length) = 0;
// Experimental
virtual Status SetUseXla(bool enable) = 0;
// For LLVM style RTTI.
static bool classof(const AbstractOperation* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
}
protected:
explicit ImmediateExecutionOperation(AbstractOperationKind kind)
: AbstractOperation(kind) {}
~ImmediateExecutionOperation() override {}
};
namespace internal {
struct ImmediateExecutionOperationDeleter {
void operator()(ImmediateExecutionOperation* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using ImmediateOpPtr =
std::unique_ptr<ImmediateExecutionOperation,
internal::ImmediateExecutionOperationDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_

@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
@ -30,16 +31,8 @@ namespace tensorflow {
// files. The interface lists the common functionality that must be provided by
// any concrete implementation. However, in cases where the true concrete class
// is needed, a static_cast can be applied.
class AbstractTensorHandleInterface {
class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
public:
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage its own
// lifetime through ref counting. Thus this must be allocated on the heap and
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
// Returns tensor dtype.
virtual tensorflow::DataType DataType() const = 0;
// Returns number of dimensions.
@ -57,12 +50,33 @@ class AbstractTensorHandleInterface {
virtual AbstractTensorInterface* Resolve(Status* status) = 0;
// Return a copy of the handle.
virtual AbstractTensorHandleInterface* Copy() = 0;
virtual ImmediateExecutionTensorHandle* Copy() = 0;
// For LLVM style RTTI.
static bool classof(const AbstractTensorHandle* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
}
protected:
virtual ~AbstractTensorHandleInterface() {}
explicit ImmediateExecutionTensorHandle(AbstractTensorHandleKind kind)
: AbstractTensorHandle(kind) {}
~ImmediateExecutionTensorHandle() override {}
};
namespace internal {
struct ImmediateExecutionTensorHandleDeleter {
void operator()(ImmediateExecutionTensorHandle* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using ImmediateTensorHandlePtr =
std::unique_ptr<ImmediateExecutionTensorHandle,
internal::ImmediateExecutionTensorHandleDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_

@ -40,6 +40,9 @@ using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
using MaybeParallelTensorOwned =
absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
using MaybeParallelTensorUnowned =
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
// A ParallelDevice on its own is not registered with a TFE_Context, and so has
// no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
// name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
@ -141,9 +144,32 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
result.emplace(std::move(result_content));
return result;
}
std::vector<ParallelTensor*> parallel_inputs;
std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
parallel_inputs.reserve(inputs.size());
implicitly_broadcast_tensors.reserve(inputs.size()); // not tight
for (const auto& input : inputs) {
if (absl::holds_alternative<TFE_TensorHandle*>(input)) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
std::unique_ptr<ParallelTensor> parallel_tensor(
parallel_device.CopyToParallelDevice(
context, absl::get<TFE_TensorHandle*>(input), status));
if (TF_GetCode(status) != TF_OK) return result;
parallel_inputs.push_back(parallel_tensor.get());
implicitly_broadcast_tensors.emplace_back(std::move(parallel_tensor));
} else {
parallel_inputs.push_back(absl::get<ParallelTensor*>(input));
}
}
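// At this point `parallel_inputs` holds exactly one ParallelTensor* per
// original input, and `implicitly_broadcast_tensors` keeps the broadcast
// copies alive until the execution below has consumed them.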
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results(
parallel_device.Execute(context, std::move(inputs), operation_name,
parallel_device.Execute(context, parallel_inputs, operation_name,
attributes, expected_max_outputs, status));
if (!maybe_parallel_results.has_value()) return result;
std::vector<std::unique_ptr<ParallelTensor>> parallel_results(

@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace parallel_device {
@ -28,21 +30,216 @@ class OpDeleter {
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
// Creates a vector of `count` new executors (threads).
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
std::vector<ExecutorPtr> executors;
executors.reserve(count);
for (int i = 0; i < count; ++i) {
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
class StatusDeleter {
public:
void operator()(TF_Status* to_delete) const { TF_DeleteStatus(to_delete); }
};
using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;
class ExecutorDeleter {
public:
void operator()(TFE_Executor* to_delete) const {
TFE_DeleteExecutor(to_delete);
}
return executors;
}
};
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
} // namespace
// Allows a single op at a time to be launched without blocking.
//
// DeviceThread itself is thread-safe, in that StartExecute will block if there
// is a pending execution. Since StartExecute is equivalent to grabbing a lock,
// multiple DeviceThreads should always be accessed in the same order to avoid
// deadlocks.
class DeviceThread {
public:
// Starts a background thread waiting for `StartExecute`.
explicit DeviceThread(const std::string& device)
: status_(TF_NewStatus()),
device_(device),
// If the context's default executor is set to async, reusing that in
// each thread would cause collectives to deadlock. For consistency we
// create a new sync executor for every thread.
//
// TODO(allenl): We should have an async API that works with the
// parallel device.
executor_(TFE_NewExecutor(/*is_async=*/false)),
op_(nullptr),
thread_(tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "parallel_device_execute",
std::bind(&DeviceThread::Run, this))) {}
~DeviceThread();
// Requests that the worker thread execute the specified operation. Blocks
// until the previously pending operation (a StartExecute without a Join) has
// finished, if any.
void StartExecute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes, int expected_max_outputs);
// Block until the previous `StartExecute` operation has executed. Forwards
// the status from `TFE_Execute` and returns outputs if the status is OK.
std::vector<TensorHandlePtr> Join(TF_Status* status);
private:
void Run();
void Execute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes, int expected_max_outputs,
std::vector<TensorHandlePtr>* outputs, TF_Status* status) const
TF_EXCLUSIVE_LOCKS_REQUIRED(execution_mutex_);
enum class ExecutionState {
kReadyToExecute,
kHasResult,
kIdle,
kShuttingDown,
};
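// State transitions (as implemented below): StartExecute moves kIdle ->
// kReadyToExecute, the worker thread moves kReadyToExecute -> kHasResult,
// Join moves kHasResult -> kIdle, and the destructor moves any state to
// kShuttingDown.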
tensorflow::mutex execution_mutex_;
ExecutionState execution_state_ TF_GUARDED_BY(execution_mutex_) =
ExecutionState::kIdle;
// Tells the worker thread that there is new work.
tensorflow::condition_variable start_execute_;
// The worker thread notifies that work has finished.
tensorflow::condition_variable finished_execute_;
// Notifies a StartExecute that the previous Join has finished.
tensorflow::condition_variable finished_join_;
// Temporary state between `StartExecute` and `Join`.
// Inputs
TFE_Context* context_ TF_GUARDED_BY(execution_mutex_);
const char* operation_name_ TF_GUARDED_BY(execution_mutex_);
std::vector<TFE_TensorHandle*> op_inputs_ TF_GUARDED_BY(execution_mutex_);
const TFE_OpAttrs* attributes_ TF_GUARDED_BY(execution_mutex_);
int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
// Outputs
std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
const std::string device_;
ExecutorPtr executor_ TF_GUARDED_BY(execution_mutex_);
mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
std::unique_ptr<Thread> thread_;
};
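// Illustrative only: a hypothetical helper (not part of this change) showing
// the intended StartExecute/Join protocol. All threads are started first and
// joined afterwards, always iterating in the same order, since StartExecute
// behaves like acquiring a lock on its DeviceThread.
void ExampleLaunchAndJoin(
    TFE_Context* context, const char* operation_name,
    const TFE_OpAttrs* attributes, int expected_max_outputs,
    const std::vector<std::unique_ptr<DeviceThread>>& threads,
    const std::vector<std::vector<TFE_TensorHandle*>>& per_thread_inputs,
    TF_Status* status) {
  for (size_t i = 0; i < threads.size(); ++i) {
    threads[i]->StartExecute(context, operation_name, per_thread_inputs[i],
                             attributes, expected_max_outputs);
  }
  for (size_t i = 0; i < threads.size(); ++i) {
    std::vector<TensorHandlePtr> outputs = threads[i]->Join(status);
    if (TF_GetCode(status) != TF_OK) return;
    // ... consume `outputs` for thread i ...
  }
}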
DeviceThread::~DeviceThread() {
{
tensorflow::mutex_lock l(execution_mutex_);
execution_state_ = ExecutionState::kShuttingDown;
}
start_execute_.notify_one();
}
void DeviceThread::Run() {
while (true) {
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ == ExecutionState::kIdle ||
execution_state_ == ExecutionState::kHasResult) {
start_execute_.wait(l);
}
if (execution_state_ == ExecutionState::kShuttingDown) {
return;
} else if (execution_state_ == ExecutionState::kReadyToExecute) {
// op_outputs_ may have been std::moved
op_outputs_ = std::vector<TensorHandlePtr>();
Execute(context_, operation_name_, std::move(op_inputs_), attributes_,
expected_max_outputs_, &op_outputs_, status_.get());
execution_state_ = ExecutionState::kHasResult;
}
}
finished_execute_.notify_one();
}
}
void DeviceThread::StartExecute(TFE_Context* context,
const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes,
int expected_max_outputs) {
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ != ExecutionState::kIdle) {
// If there's already a pending execution, wait until Join finishes before
// starting on the next operation.
finished_join_.wait(l);
}
context_ = context;
operation_name_ = operation_name;
op_inputs_ = inputs;
attributes_ = attributes;
expected_max_outputs_ = expected_max_outputs;
execution_state_ = ExecutionState::kReadyToExecute;
}
start_execute_.notify_one();
}
std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
std::vector<TensorHandlePtr> result;
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ != ExecutionState::kHasResult) {
finished_execute_.wait(l);
}
if (TF_GetCode(status_.get()) != TF_OK) {
TF_SetStatus(status, TF_GetCode(status_.get()),
TF_Message(status_.get()));
}
execution_state_ = ExecutionState::kIdle;
result = std::move(op_outputs_);
}
finished_join_.notify_one();
return result;
}
void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes,
int expected_max_outputs,
std::vector<TensorHandlePtr>* outputs,
TF_Status* status) const {
if (op_ == nullptr) {
TFE_ContextSetExecutorForThread(context, executor_.get());
op_.reset(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op_.get(), device_.c_str(), status);
if (TF_GetCode(status) != TF_OK) return;
} else {
TFE_OpReset(op_.get(), operation_name, device_.c_str(), status);
if (TF_GetCode(status) != TF_OK) return;
}
TFE_OpAddAttrs(op_.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
TFE_OpAddInput(op_.get(), inputs[input_index], status);
if (TF_GetCode(status) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> unwrapped_results(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
if (TF_GetCode(status) != TF_OK) return;
TFE_Execute(op_.get(), unwrapped_results.data(), &real_num_outputs, status);
if (TF_GetCode(status) != TF_OK) return;
unwrapped_results.resize(real_num_outputs);
outputs->reserve(real_num_outputs);
for (TFE_TensorHandle* unwrapped_result : unwrapped_results) {
outputs->emplace_back(unwrapped_result);
}
}
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
: underlying_devices_(devices),
executors_(MakeExecutors(underlying_devices_.size())) {}
: underlying_devices_(devices) {
device_threads_.reserve(devices.size());
for (int device_index = 0; device_index < devices.size(); ++device_index) {
device_threads_.emplace_back(
new DeviceThread(devices[device_index].c_str()));
}
}
// Necessary for a unique_ptr to a forward-declared type.
ParallelDevice::~ParallelDevice() = default;
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
@ -65,14 +262,14 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
int32_t* device_id = new int32_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int32_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
delete reinterpret_cast<int32_t*>(data);
},
nullptr),
TF_DeleteTensor);
@ -86,7 +283,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT32);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
@ -100,7 +297,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
std::vector<MaybeParallelTensorUnowned> inputs,
const std::vector<ParallelTensor*>& inputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int expected_max_outputs,
TF_Status* status) const {
@ -108,88 +305,34 @@ ParallelDevice::Execute(TFE_Context* context,
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
per_device_output_tensors.reserve(underlying_devices_.size());
// TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
// setting the thread-local executor like this.
TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
auto reset_executor =
tensorflow::gtl::MakeCleanup([context, previous_executor]() {
TFE_ContextSetExecutorForThread(context, previous_executor);
TFE_DeleteExecutor(previous_executor);
});
int first_op_output_count;
int first_op_output_count = 0;
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// Note that the `reset_executor` cleanup sets the thread's executor back to
// the value before this function ran.
TFE_ContextSetExecutorForThread(context, executor);
OpPtr op(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return result;
TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
status);
TFE_OpAddAttrs(op.get(), attributes);
DeviceThread* device_thread = device_threads_[device_index].get();
std::vector<TFE_TensorHandle*> device_inputs;
device_inputs.reserve(inputs.size());
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
TFE_OpAddInput(op.get(),
absl::get<TFE_TensorHandle*>(inputs[input_index]),
status);
if (TF_GetCode(status) != TF_OK) return result;
} else {
// Parallel tensors are divided between operations by device.
TFE_OpAddInput(op.get(),
absl::get<ParallelTensor*>(inputs[input_index])
->tensor(device_index),
status);
if (TF_GetCode(status) != TF_OK) return result;
}
// Parallel tensors are divided between operations by device.
device_inputs.push_back(inputs[input_index]->tensor(device_index));
}
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
// For nested devices, the inner device sees the async executor we've
// set. Inner parallel devices will just overwrite this with their own and
// then set it back to ours before returning. This means parallel devices
// which consist of several aliased parallel devices would hypothetically
// deadlock if the outer parallel device ran one collective with a group
// size equal to the total number of aliased physical devices. Currently
// physical devices cannot participate in a single collective reduction
// multiple times, so this would fail earlier.
//
// TODO(allenl): Keep a map from outer executor to list of inner executors
// rather than a single list of executors so aliased nested parallel devices
// don't re-use an executor.
TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
device_thread->StartExecute(context, operation_name,
std::move(device_inputs), attributes,
expected_max_outputs);
}
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
DeviceThread* device_thread = device_threads_[device_index].get();
per_device_output_tensors.push_back(device_thread->Join(status));
if (TF_GetCode(status) != TF_OK) return result;
if (device_index == 0) {
first_op_output_count = real_num_outputs;
first_op_output_count = per_device_output_tensors.rbegin()->size();
} else {
if (real_num_outputs != first_op_output_count) {
if (per_device_output_tensors.rbegin()->size() != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
if (TF_GetCode(status) != TF_OK) return result;
std::vector<TensorHandlePtr> this_outputs;
this_outputs.reserve(real_num_outputs);
for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
this_outputs.emplace_back(op_outputs[output_num]);
}
per_device_output_tensors.push_back(std::move(this_outputs));
}
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// TODO(b/157523095): Syncing the executor here shouldn't be
// necessary. Currently async+remote is missing cross-executor
// coordination.
TFE_ExecutorWaitForAllPendingNodes(executor, status);
if (TF_GetCode(status) != TF_OK) return result;
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.

@ -41,19 +41,8 @@ class TensorHandleDeleter {
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
class ExecutorDeleter {
public:
void operator()(TFE_Executor* to_delete) const {
TFE_DeleteExecutor(to_delete);
}
};
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class ParallelTensor;
using MaybeParallelTensorUnowned =
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
class DeviceThread;
// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.
@ -61,6 +50,8 @@ class ParallelDevice {
public:
explicit ParallelDevice(const std::vector<std::string>& devices);
~ParallelDevice();
// Helper to copy a tensor handle from another device once for each component
// of the ParallelDevice.
//
@ -79,10 +70,9 @@ class ParallelDevice {
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
// implicitly-mirrored tensors on other devices). Wraps the resulting
// per-device and per-output TFE_TensorHandles into one ParallelTensor per
// output of the original operation.
// its corresponding inputs from the input ParallelTensors. Wraps the
// resulting per-device and per-output TFE_TensorHandles into one
// ParallelTensor per output of the original operation.
//
// Attributes are forwarded to executed operations unmodified.
//
@ -90,7 +80,7 @@ class ParallelDevice {
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
// if sanity checks on dtypes/metadata fail.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
@ -98,9 +88,19 @@ class ParallelDevice {
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
const std::vector<std::string> underlying_devices_;
// A sequence of TFE_Executors, one per device, for executing operations in
// A sequence of thread wrappers, one per device, for executing operations in
// parallel.
const std::vector<ExecutorPtr> executors_;
//
// Conceptually this is a thread pool with one thread per device. It requires
// less synchronization than a thread pool would for this task, since Execute
// acquires each thread in order (and so only one Execute will schedule
// blocking collective operations at a time), and avoids some dynamic
// allocation/scheduling.
//
// TODO(allenl): Keep a map from outer thread to list of inner threads rather
// than a single list of threads so aliased nested parallel devices don't
// re-use a thread.
std::vector<std::unique_ptr<DeviceThread>> device_threads_;
};
// Contains a tuple of tensors, one on each of the `underlying_devices_` of the

@ -407,11 +407,12 @@ TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
return TensorHandlePtr(result_handle);
}
TEST(PARALLEL_DEVICE, TestCollective) {
void TestCollective(bool async) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
TFE_ContextOptionsSetAsync(opts.get(), async);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
@ -454,6 +455,12 @@ TEST(PARALLEL_DEVICE, TestCollective) {
ExpectScalarEq<float>(result_components[1].get(), 3.);
}
TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }
// Note that ops on the parallel device currently don't execute
// asynchronously. The test just checks that we don't deadlock.
TEST(PARALLEL_DEVICE, TestCollectiveAsync) { TestCollective(/*async=*/true); }
void RegisterCollectiveMulFunction(TFE_Context* context,
const char* function_name, int group_size,
TF_Status* status) {

@ -296,8 +296,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int64_t>(components[0].get(), 0);
ExpectScalarEq<int64_t>(components[1].get(), 1);
ExpectScalarEq<int32_t>(components[0].get(), 0);
ExpectScalarEq<int32_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);

@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
// Wraps a pointer to a context implementation.
//
@ -28,7 +28,7 @@ typedef struct TFE_Context TFE_Context;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionContext, TFE_Context);
} // namespace tensorflow
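// A minimal sketch, assuming conversion_macros.h emits the usual
// tensorflow::wrap()/tensorflow::unwrap() overloads (illustration only, not
// part of this change):
//   TFE_Context* c_ctx = ...;  // handle coming from the C API
//   tensorflow::ImmediateExecutionContext* ctx = tensorflow::unwrap(c_ctx);
//   TFE_Context* same_handle = tensorflow::wrap(ctx);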

@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
// Wraps a pointer to an operation implementation.
//
@ -28,8 +28,8 @@ typedef struct TFE_Op TFE_Op;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation, TFE_Op);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation*, TFE_Op*);
} // namespace tensorflow

@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
// Wraps a pointer to a tensor handle implementation.
//
@ -28,9 +28,9 @@ typedef struct TFE_TensorHandle TFE_TensorHandle;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface,
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle,
TFE_TensorHandle);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*,
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle*,
TFE_TensorHandle*);
} // namespace tensorflow

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/types.h"
struct TF_StringStream {
@ -146,6 +147,10 @@ TF_StringStream* TF_GetLocalTempDirectories() {
return list;
}
char* TF_GetTempFileName(const char* extension) {
return strdup(::tensorflow::io::GetTempFilename(extension).c_str());
}
TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void) {
return ::tensorflow::Env::Default()->NowNanos();
}

@ -152,6 +152,10 @@ TF_CAPI_EXPORT extern TF_StringStream* TF_GetChildren(const char* filename,
// The caller is responsible for freeing the list (see TF_StringStreamDone).
TF_CAPI_EXPORT extern TF_StringStream* TF_GetLocalTempDirectories(void);
// Creates a temporary file name with an extension.
// The caller is responsible for freeing the returned pointer.
TF_CAPI_EXPORT extern char* TF_GetTempFileName(const char* extension);
// Returns the number of nanoseconds since the Unix epoch.
TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void);
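// A minimal usage sketch for the new API (illustration only): the returned
// buffer is heap-allocated, so the caller must free it.
//   char* name = TF_GetTempFileName("");
//   /* ... create and use the temporary file at `name` ... */
//   free(name);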

@ -1,5 +1,5 @@
# Experimental gcs filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test")
package(
licenses = ["notice"], # Apache 2.0
@ -19,12 +19,45 @@ tf_cc_shared_object(
cc_library(
name = "gcs_filesystem_impl",
srcs = ["gcs_filesystem.cc"],
hdrs = ["gcs_filesystem.h"],
copts = select({
"//conditions:default": [],
"//tensorflow:windows": get_win_copts(),
}),
deps = [
":gcs_helper",
"//tensorflow/c:env",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "gcs_helper",
srcs = ["gcs_helper.cc"],
hdrs = ["gcs_helper.h"],
linkstatic = 1,
deps = [
"//tensorflow/c:env",
],
)
tf_cc_test(
name = "gcs_filesystem_test",
srcs = [
"gcs_filesystem_test.cc",
],
tags = [
"manual",
"notap",
],
deps = [
":gcs_filesystem_impl",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core/platform:stacktrace_handler",
"//tensorflow/core/platform:test",
"@com_google_absl//absl/strings",
],
)

@ -12,31 +12,237 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
#include <stdlib.h>
#include <string.h>
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "absl/strings/numbers.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for GCS environments.
// This filesystem will support `gs://` URI schemes.
namespace gcs = google::cloud::storage;
// How to upload new data when Flush() is called multiple times.
// By default the entire file is reuploaded.
constexpr char kAppendMode[] = "GCS_APPEND_MODE";
// If GCS_APPEND_MODE=compose then instead the new data is uploaded to a
// temporary object and composed with the original object. This is disabled by
// default as the multiple API calls required add a risk of stranding temporary
// objects.
constexpr char kComposeAppend[] = "compose";
// We can cast `google::cloud::StatusCode` to `TF_Code` because they have the
// same integer values. See
// https://github.com/googleapis/google-cloud-cpp/blob/6c09cbfa0160bc046e5509b4dd2ab4b872648b4a/google/cloud/status.h#L32-L52
static inline void TF_SetStatusFromGCSStatus(
const google::cloud::Status& gcs_status, TF_Status* status) {
TF_SetStatus(status, static_cast<TF_Code>(gcs_status.code()),
gcs_status.message().c_str());
}
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
void ParseGCSPath(const std::string& fname, bool object_empty_ok,
std::string* bucket, std::string* object, TF_Status* status) {
size_t scheme_end = fname.find("://") + 2;
if (fname.substr(0, scheme_end + 1) != "gs://") {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't start with 'gs://'.");
return;
}
size_t bucket_end = fname.find("/", scheme_end + 1);
if (bucket_end == std::string::npos) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain a bucket name.");
return;
}
*bucket = fname.substr(scheme_end + 1, bucket_end - scheme_end - 1);
*object = fname.substr(bucket_end + 1);
if (object->empty() && !object_empty_ok) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain an object name.");
}
}
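// For example (illustration only): "gs://bucket/path/to/object" parses to
// bucket="bucket" and object="path/to/object", while "gs://bucket/" is only
// accepted when `object_empty_ok` is true.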
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
typedef struct GCSFile {
const std::string bucket;
const std::string object;
gcs::Client* gcs_client; // not owned
} GCSFile;
// TODO(vnvo2409): Implement later
void Cleanup(TF_RandomAccessFile* file) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
delete gcs_file;
}
// TODO(vnvo2409): Add caching.
// `google-cloud-cpp` is working on a feature that we may want to use.
// See https://github.com/googleapis/google-cloud-cpp/issues/4013.
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
auto stream = gcs_file->gcs_client->ReadObject(
gcs_file->bucket, gcs_file->object, gcs::ReadRange(offset, offset + n));
TF_SetStatusFromGCSStatus(stream.status(), status);
if ((TF_GetCode(status) != TF_OK) &&
(TF_GetCode(status) != TF_OUT_OF_RANGE)) {
return -1;
}
int64_t read;
if (!absl::SimpleAtoi(stream.headers().find("content-length")->second,
&read)) {
TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
return -1;
}
if (read != n) {
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read fewer bytes than requested");
}
stream.read(buffer, read);
return read;
}
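// A short read is reported as TF_OUT_OF_RANGE rather than as a hard error, so
// a caller sketch looks like (illustration only):
//   int64_t n_read = Read(file, offset, n, buffer, status);
//   if (TF_GetCode(status) == TF_OUT_OF_RANGE) { /* hit EOF; n_read < n */ }
//   else if (TF_GetCode(status) != TF_OK)      { /* real error; n_read == -1 */ }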
} // namespace tf_random_access_file
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
typedef struct GCSFile {
const std::string bucket;
const std::string object;
gcs::Client* gcs_client; // not owned
TempFile outfile;
bool sync_need;
// `offset` tells us how many bytes of this file are already uploaded to
// the server. If `offset == -1`, we always upload the entire temporary file.
int64_t offset;
} GCSFile;
// TODO(vnvo2409): Implement later
static void SyncImpl(const std::string& bucket, const std::string& object,
int64_t* offset, TempFile* outfile,
gcs::Client* gcs_client, TF_Status* status) {
outfile->flush();
// `*offset == 0` means this file does not exist on the server.
if (*offset == -1 || *offset == 0) {
// UploadFile will automatically switch to resumable upload based on Client
// configuration.
auto metadata = gcs_client->UploadFile(outfile->getName(), bucket, object);
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
}
if (*offset == 0) {
if (!outfile->truncate()) {
TF_SetStatus(status, TF_INTERNAL,
"Could not truncate internal temporary file.");
return;
}
*offset = static_cast<int64_t>(metadata->size());
}
outfile->clear();
outfile->seekp(std::ios::end);
TF_SetStatus(status, TF_OK, "");
} else {
std::string temporary_object =
gcs::CreateRandomPrefixName("tf_writable_file_gcs");
auto metadata =
gcs_client->UploadFile(outfile->getName(), bucket, temporary_object);
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
}
const std::vector<gcs::ComposeSourceObject> source_objects = {
{object, {}, {}}, {temporary_object, {}, {}}};
metadata = gcs_client->ComposeObject(bucket, source_objects, object);
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
}
// We truncate the data that has already been uploaded.
if (!outfile->truncate()) {
TF_SetStatus(status, TF_INTERNAL,
"Could not truncate internal temporary file.");
return;
}
*offset = static_cast<int64_t>(metadata->size());
TF_SetStatus(status, TF_OK, "");
}
}
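// In summary (as implemented above): `*offset == -1` always reuploads the
// whole temporary file, `*offset == 0` uploads everything once and then
// switches to the compose path by recording the object size, and `*offset > 0`
// uploads only the newly appended bytes and composes them onto the existing
// object.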
void Cleanup(TF_WritableFile* file) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
delete gcs_file;
}
void Append(const TF_WritableFile* file, const char* buffer, size_t n,
TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
if (!gcs_file->outfile.is_open()) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"The internal temporary file is not writable.");
return;
}
gcs_file->sync_need = true;
gcs_file->outfile.write(buffer, n);
if (!gcs_file->outfile)
TF_SetStatus(status, TF_INTERNAL,
"Could not append to the internal temporary file.");
else
TF_SetStatus(status, TF_OK, "");
}
int64_t Tell(const TF_WritableFile* file, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
int64_t position = int64_t(gcs_file->outfile.tellp());
if (position == -1)
TF_SetStatus(status, TF_INTERNAL,
"tellp on the internal temporary file failed");
else
TF_SetStatus(status, TF_OK, "");
return position == -1
? -1
: position + (gcs_file->offset == -1 ? 0 : gcs_file->offset);
}
void Flush(const TF_WritableFile* file, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
if (gcs_file->sync_need) {
if (!gcs_file->outfile) {
TF_SetStatus(status, TF_INTERNAL,
"Could not append to the internal temporary file.");
return;
}
SyncImpl(gcs_file->bucket, gcs_file->object, &gcs_file->offset,
&gcs_file->outfile, gcs_file->gcs_client, status);
if (TF_GetCode(status) != TF_OK) return;
gcs_file->sync_need = false;
} else {
TF_SetStatus(status, TF_OK, "");
}
}
void Sync(const TF_WritableFile* file, TF_Status* status) {
Flush(file, status);
}
void Close(const TF_WritableFile* file, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
if (gcs_file->sync_need) {
Flush(file, status);
}
gcs_file->outfile.close();
}
} // namespace tf_writable_file
@ -51,8 +257,109 @@ namespace tf_read_only_memory_region {
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_gcs_filesystem {
typedef struct GCSFile {
gcs::Client gcs_client; // owned
bool compose;
} GCSFile;
// TODO(vnvo2409): Add lazy-loading and customizable parameters.
void Init(TF_Filesystem* filesystem, TF_Status* status) {
google::cloud::StatusOr<gcs::Client> client =
gcs::Client::CreateDefaultClient();
if (!client) {
TF_SetStatusFromGCSStatus(client.status(), status);
return;
}
const char* append_mode = std::getenv(kAppendMode);
bool compose =
    (append_mode != nullptr) && (!strcmp(kComposeAppend, append_mode));
filesystem->plugin_filesystem =
new GCSFile({std::move(client.value()), compose});
TF_SetStatus(status, TF_OK, "");
}
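// To opt in to compose-append (illustration only), set the environment
// variable before the filesystem is initialized, e.g. from the embedding
// process:
//   setenv("GCS_APPEND_MODE", "compose", /*overwrite=*/1);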
void Cleanup(TF_Filesystem* filesystem) {
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
delete gcs_file;
}
// TODO(vnvo2409): Implement later
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
file->plugin_file = new tf_random_access_file::GCSFile(
{std::move(bucket), std::move(object), &gcs_file->gcs_client});
TF_SetStatus(status, TF_OK, "");
}
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
char* temp_file_name = TF_GetTempFileName("");
file->plugin_file = new tf_writable_file::GCSFile(
{std::move(bucket), std::move(object), &gcs_file->gcs_client,
TempFile(temp_file_name, std::ios::binary | std::ios::out), true,
(gcs_file->compose ? 0 : -1)});
// We are responsible for freeing the pointer returned by TF_GetTempFileName
free(temp_file_name);
TF_SetStatus(status, TF_OK, "");
}
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
char* temp_file_name_c_str = TF_GetTempFileName("");
std::string temp_file_name(temp_file_name_c_str);  // To prevent a memory leak
free(temp_file_name_c_str);
if (!gcs_file->compose) {
auto gcs_status =
gcs_file->gcs_client.DownloadToFile(bucket, object, temp_file_name);
TF_SetStatusFromGCSStatus(gcs_status, status);
auto status_code = TF_GetCode(status);
if (status_code != TF_OK && status_code != TF_NOT_FOUND) return;
// If this file does not exist on the server, we will need to sync it.
bool sync_need = (status_code == TF_NOT_FOUND);
file->plugin_file = new tf_writable_file::GCSFile(
{std::move(bucket), std::move(object), &gcs_file->gcs_client,
TempFile(temp_file_name, std::ios::binary | std::ios::app), sync_need,
-1});
} else {
// If compose is true, we do not download anything.
// Instead, we only check whether this file exists on the server.
auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) == TF_OK) {
file->plugin_file = new tf_writable_file::GCSFile(
{std::move(bucket), std::move(object), &gcs_file->gcs_client,
TempFile(temp_file_name, std::ios::binary | std::ios::trunc), false,
static_cast<int64_t>(metadata->size())});
} else if (TF_GetCode(status) == TF_NOT_FOUND) {
file->plugin_file = new tf_writable_file::GCSFile(
{std::move(bucket), std::move(object), &gcs_file->gcs_client,
TempFile(temp_file_name, std::ios::binary | std::ios::trunc), true,
0});
} else {
return;
}
}
TF_SetStatus(status, TF_OK, "");
}
} // namespace tf_gcs_filesystem
@ -60,6 +367,25 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
ops->random_access_file_ops->read = tf_random_access_file::Read;
ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
ops->filesystem_ops->cleanup = tf_gcs_filesystem::Cleanup;
ops->filesystem_ops->new_random_access_file =
tf_gcs_filesystem::NewRandomAccessFile;
ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_gcs_filesystem::NewAppendableFile;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
@ -69,4 +395,4 @@ void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
info->ops = static_cast<TF_FilesystemPluginOps*>(
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
ProvideFilesystemSupportFor(&info->ops[0], "gs");
}
}

@ -0,0 +1,42 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
void ParseGCSPath(const std::string& fname, bool object_empty_ok,
std::string* bucket, std::string* object, TF_Status* status);
namespace tf_random_access_file {
void Cleanup(TF_RandomAccessFile* file);
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status);
} // namespace tf_random_access_file
namespace tf_gcs_filesystem {
void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status);
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
} // namespace tf_gcs_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_

@ -0,0 +1,199 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
#include <random>
#include "absl/strings/string_view.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/stacktrace_handler.h"
#include "tensorflow/core/platform/test.h"
#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)
static const char* content = "abcdefghijklmnopqrstuvwxyz1234567890";
// We will work with content_view instead of content.
static const absl::string_view content_view = content;
namespace gcs = google::cloud::storage;
namespace tensorflow {
namespace {
class GCSFilesystemTest : public ::testing::Test {
public:
void SetUp() override {
root_dir_ = io::JoinPath(
tmp_dir_,
::testing::UnitTest::GetInstance()->current_test_info()->name());
status_ = TF_NewStatus();
filesystem_ = new TF_Filesystem;
tf_gcs_filesystem::Init(filesystem_, status_);
ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
<< TF_Message(status_);
}
void TearDown() override {
TF_DeleteStatus(status_);
tf_gcs_filesystem::Cleanup(filesystem_);
delete filesystem_;
}
static bool InitializeTmpDir() {
// This environment variable should be something like `gs://bucket/path`.
const char* test_dir = getenv("GCS_TEST_TMPDIR");
if (test_dir != nullptr) {
std::string bucket, object;
TF_Status* status = TF_NewStatus();
ParseGCSPath(test_dir, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) {
TF_DeleteStatus(status);
return false;
}
TF_DeleteStatus(status);
// We add a random value into `test_dir` to ensure that two consecutive
// runs are unlikely to clash.
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> distribution;
std::string rng_val = std::to_string(distribution(gen));
tmp_dir_ = io::JoinPath(string(test_dir), rng_val);
return true;
} else {
return false;
}
}
std::string GetURIForPath(absl::string_view path) {
const std::string translated_name =
tensorflow::io::JoinPath(root_dir_, path);
return translated_name;
}
protected:
TF_Filesystem* filesystem_;
TF_Status* status_;
private:
std::string root_dir_;
static std::string tmp_dir_;
};
std::string GCSFilesystemTest::tmp_dir_;
::testing::AssertionResult WriteToServer(const std::string& path, size_t length,
gcs::Client* gcs_client,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) {
return ::testing::AssertionFailure() << TF_Message(status);
}
auto writer = gcs_client->WriteObject(bucket, object);
writer.write(content, length);
writer.Close();
if (writer.metadata()) {
return ::testing::AssertionSuccess();
} else {
return ::testing::AssertionFailure()
<< writer.metadata().status().message();
}
}
::testing::AssertionResult CompareSubString(int64_t offset, size_t n,
absl::string_view result,
size_t read) {
// Result isn't a null-terminated string so we have to wrap it inside a
// `string_view`
if (n == read && content_view.substr(offset, n) ==
absl::string_view(result).substr(0, read)) {
return ::testing::AssertionSuccess();
} else {
return ::testing::AssertionFailure()
<< "Result: " << absl::string_view(result).substr(0, read)
<< " Read:" << read;
}
}
TEST_F(GCSFilesystemTest, ParseGCSPath) {
std::string bucket, object;
ParseGCSPath("gs://bucket/path/to/object", false, &bucket, &object, status_);
ASSERT_TF_OK(status_);
ASSERT_EQ(bucket, "bucket");
ASSERT_EQ(object, "path/to/object");
ParseGCSPath("gs://bucket/", true, &bucket, &object, status_);
ASSERT_TF_OK(status_);
ASSERT_EQ(bucket, "bucket");
ParseGCSPath("bucket/path/to/object", false, &bucket, &object, status_);
ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);
// bucket name must end with "/"
ParseGCSPath("gs://bucket", true, &bucket, &object, status_);
ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);
ParseGCSPath("gs://bucket/", false, &bucket, &object, status_);
ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);
}
TEST_F(GCSFilesystemTest, RandomAccessFile) {
std::string filepath = GetURIForPath("a_file");
TF_RandomAccessFile* file = new TF_RandomAccessFile;
tf_gcs_filesystem::NewRandomAccessFile(filesystem_, filepath.c_str(), file,
status_);
ASSERT_TF_OK(status_);
char* result = new char[content_view.length()];
int64_t read = tf_random_access_file::Read(file, 0, 1, result, status_);
ASSERT_EQ(read, -1) << "Read: " << read;
ASSERT_EQ(TF_GetCode(status_), TF_NOT_FOUND) << TF_Message(status_);
TF_SetStatus(status_, TF_OK, "");
auto gcs_client = static_cast<gcs::Client*>(filesystem_->plugin_filesystem);
ASSERT_TRUE(
WriteToServer(filepath, content_view.length(), gcs_client, status_));
read = tf_random_access_file::Read(file, 0, content_view.length(), result,
status_);
ASSERT_TF_OK(status_);
ASSERT_TRUE(CompareSubString(0, content_view.length(), result, read));
read = tf_random_access_file::Read(file, 0, 4, result, status_);
ASSERT_TF_OK(status_);
ASSERT_TRUE(CompareSubString(0, 4, result, read));
read = tf_random_access_file::Read(file, content_view.length() - 2, 4, result,
status_);
ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_);
ASSERT_TRUE(CompareSubString(content_view.length() - 2, 2, result, read));
delete[] result;
tf_random_access_file::Cleanup(file);
delete file;
}
} // namespace
} // namespace tensorflow
GTEST_API_ int main(int argc, char** argv) {
tensorflow::testing::InstallStacktraceHandler();
if (!tensorflow::GCSFilesystemTest::InitializeTmpDir()) {
std::cerr << "Could not read GCS_TEST_TMPDIR env";
return -1;
}
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include <stdio.h>
#include <fstream>
#include <string>
#include <utility>
TempFile::TempFile(const std::string& temp_file_name, std::ios::openmode mode)
: std::fstream(temp_file_name, mode), name_(temp_file_name) {}
TempFile::TempFile(TempFile&& rhs)
: std::fstream(std::move(rhs)), name_(std::move(rhs.name_)) {}
TempFile::~TempFile() {
std::fstream::close();
std::remove(name_.c_str());
}
const std::string TempFile::getName() const { return name_; }
bool TempFile::truncate() {
std::fstream::close();
std::fstream::open(name_, std::ios::binary | std::ios::out);
return std::fstream::is_open();
}

@ -12,21 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <cstdlib>
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
// This is a demo fuzzer to test that the entire framework functions correctly.
// Once we start moving the existing fuzzers to this framework we will delete
// this.
// TODO(mihaimaruseac): Delete this when no longer needed
void DemoFuzzer(const uint8_t* data, size_t size) {
// Trigger a small bug that should be found by the fuzzer quite quickly
if (size > 10 && size % 3 == 2)
if (data[0] > data[1])
if (data[5] % data[2] == data[3]) abort();
}
#include <fstream>
#include <string>
extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
DemoFuzzer(data, size);
return 0;
}
class TempFile : public std::fstream {
public:
// The open mode must be specified each time a TempFile is constructed.
TempFile(const std::string& temp_file_name, std::ios::openmode mode);
TempFile(TempFile&& rhs);
~TempFile() override;
const std::string getName() const;
bool truncate();
private:
const std::string name_;
};
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
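A minimal usage sketch for the helper (the path below is hypothetical; in the plugin it comes from TF_GetTempFileName):
void ExampleTempFileUse() {
  TempFile tmp("/tmp/example_upload", std::ios::binary | std::ios::out);
  tmp.write("abc", 3);
  tmp.flush();
  // Reopen the file truncated, e.g. after its contents have been uploaded.
  if (!tmp.truncate()) {
    // Handle the error.
  }
}  // `tmp` is closed and the file is removed here.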

@ -3,6 +3,10 @@
# Targets in this directory are pure C++ "Classes" underlying the C API types
# under tf/c/experimental/saved_model/public/. They are subject to change and
# have visibility limited to Tensorflow's implementation only.
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
package(
default_visibility = [
@ -23,8 +27,8 @@ cc_library(
],
deps = [
":function_metadata",
"//tensorflow/c/eager:operation_interface",
"//tensorflow/c/eager:tensor_handle_interface",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:protos_all_cc",
],
)
@ -47,6 +51,46 @@ cc_library(
],
)
cc_library(
name = "saved_model_utils",
srcs = [
"saved_model_utils.cc",
],
hdrs = [
"saved_model_utils.h",
],
deps = [
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
"//tensorflow/core:protos_all_cc",
],
)
cc_library(
name = "test_utils",
testonly = True,
srcs = [
"test_utils.cc",
],
hdrs = [
"test_utils.h",
],
deps = [
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core/common_runtime:core_cpu_lib",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "tf_saved_model_impl",
srcs = [
@ -57,6 +101,7 @@ cc_library(
":concrete_function",
":saved_model_api",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:optional",
],
)
@ -83,3 +128,47 @@ filegroup(
],
visibility = ["//tensorflow/core:__pkg__"],
)
tf_cc_test(
name = "constant_loading_test",
srcs = [
"constant_loading_test.cc",
],
deps = [
":saved_model_utils",
":test_utils",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:core_cpu_lib",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
],
)
tf_cc_test(
name = "saved_variable_loading_test",
srcs = [
"saved_variable_loading_test.cc",
],
deps = [
":saved_model_utils",
":test_utils",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/core:all_kernels",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:core_cpu_lib",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
],
)

@ -15,12 +15,12 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
namespace tensorflow {
const std::vector<tensorflow::AbstractTensorHandleInterface*>&
const std::vector<tensorflow::ImmediateExecutionTensorHandle*>&
ConcreteFunction::GetCaptures() const {
return captures_;
}

@ -18,8 +18,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/core/framework/function.pb.h"
@ -38,15 +38,15 @@ class ConcreteFunction {
virtual ~ConcreteFunction() = 0;
// This method returns the "Call" Op used to execute the function.
virtual AbstractOperationInterface* GetCallOp() = 0;
virtual ImmediateExecutionOperation* GetCallOp() = 0;
const std::vector<tensorflow::AbstractTensorHandleInterface*>& GetCaptures()
const std::vector<tensorflow::ImmediateExecutionTensorHandle*>& GetCaptures()
const;
const FunctionMetadata& GetFunctionMetadata() const;
private:
FunctionMetadata metadata_;
std::vector<tensorflow::AbstractTensorHandleInterface*> captures_;
std::vector<tensorflow::ImmediateExecutionTensorHandle*> captures_;
FunctionDef* function_;
};

@ -0,0 +1,111 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class ConstantTest : public ::testing::TestWithParam<
std::tuple<DataType, std::vector<int64>, bool>> {
public:
ConstantTest()
: device_mgr_(testing::CreateTestingDeviceMgr()),
ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
EagerContext* context() { return ctx_.get(); }
private:
std::unique_ptr<StaticDeviceMgr> device_mgr_;
EagerContextPtr ctx_;
};
// Basic sanity check that roundtripping a Tensor->Tensorproto->Constant
// preserves values.
TEST_P(ConstantTest, CreateConstantSuccessful) {
// Get test parameters
auto& test_params = GetParam();
DataType dtype = std::get<0>(test_params);
TensorShape shape(std::get<1>(test_params));
bool tensorproto_use_tensor_content = std::get<2>(test_params);
// Construct a Tensor with the given dtype + shape
Tensor expected(dtype, shape);
testing::FillNumericTensorBuffer(expected.dtype(), expected.NumElements(),
expected.data(), 42);
// Serialize it to a Tensorproto
TensorProto proto;
if (tensorproto_use_tensor_content) {
expected.AsProtoTensorContent(&proto);
} else {
expected.AsProtoField(&proto);
}
// Revival should succeed w/o errors
std::unique_ptr<Constant> revived;
TF_EXPECT_OK(internal::TensorProtoToConstant(context(), proto, &revived));
// The revived tensorhandle should have the exact same dtype and shape, and
// approximately equivalent data to the original.
ImmediateExecutionTensorHandle* handle = revived->handle();
Status status;
AbstractTensorPtr revived_tensor(handle->Resolve(&status));
TF_EXPECT_OK(status) << "Failed to convert tensorhandle to tensor";
EXPECT_EQ(revived_tensor->Type(), expected.dtype());
EXPECT_EQ(revived_tensor->NumElements(), expected.NumElements());
EXPECT_EQ(revived_tensor->NumDims(), expected.dims());
for (int i = 0; i < expected.dims(); ++i) {
EXPECT_EQ(revived_tensor->Dim(i), expected.dim_size(i));
}
testing::CheckBufferDataIsEqual(expected.dtype(), expected.NumElements(),
revived_tensor->Data(), expected.data());
}
// Test against combinations of tensors that are
// 1. Varying dtypes
// 2. Varying shapes
// 3. TensorProto serialized using tensor_content vs repeated type
INSTANTIATE_TEST_SUITE_P(
ConstantIntegerDtypesTest, ConstantTest,
::testing::Combine(
::testing::ValuesIn(testing::DataTypeSetToVector(kDataTypeIsInteger)),
::testing::ValuesIn(testing::InterestingShapes()),
::testing::Values(false, true)));
INSTANTIATE_TEST_SUITE_P(
ConstantFloatingDtypesTest, ConstantTest,
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
::testing::ValuesIn(testing::InterestingShapes()),
::testing::Values(false, true)));
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,60 @@
# This package contains hand-written convenience helpers for Eager Operations
# used by SavedModel. Once we autogenerate C++ Eager Op wrappers, we can remove these.
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
package(
default_visibility = [
# Restricting visibility for now
"//tensorflow/c/experimental/saved_model/core:__subpackages__",
"//tensorflow/c/experimental/saved_model/internal:__subpackages__",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "variable_ops",
srcs = [
"variable_ops.cc",
],
hdrs = [
"variable_ops.h",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/types:span",
],
)
tf_cc_test(
name = "variable_ops_test",
srcs = [
"variable_ops_test.cc",
],
deps = [
":variable_ops",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:all_kernels",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:core_cpu_lib",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
],
)

View File

@ -0,0 +1,118 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace internal {
static const char kNoSharingResourceID[] =
"cd2c89b7-88b7-44c8-ad83-06c2a9158347";
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
ImmediateTensorHandlePtr* handle) {
ImmediateOpPtr varhandle_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr));
TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype));
// Note that if shape is unknown rank, shape.dim_sizes() will be empty, and
// shape.dims() will be -1.
gtl::InlinedVector<int64, 4> dim_sizes = shape.dim_sizes();
TF_RETURN_IF_ERROR(varhandle_op->SetAttrShape(
"shape", reinterpret_cast<const int64_t*>(dim_sizes.data()),
shape.dims()));
TF_RETURN_IF_ERROR(varhandle_op->SetAttrString("container", "", 0));
TF_RETURN_IF_ERROR(varhandle_op->SetAttrString(
"shared_name", kNoSharingResourceID, strlen(kNoSharingResourceID)));
AbstractTensorHandle* var_handle = nullptr;
int num_retvals = 1;
TF_RETURN_IF_ERROR(varhandle_op->Execute(
absl::MakeSpan(&var_handle, num_retvals), &num_retvals));
AbstractTensorHandlePtr owned_var_handle(var_handle);
if (!tensorflow::isa<ImmediateExecutionTensorHandle>(
owned_var_handle.get())) {
return errors::Internal("Unexpected tensor handle kind.");
}
handle->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(
owned_var_handle.release()));
return Status();
}
Status AssignVariable(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, ImmediateExecutionTensorHandle* value) {
ImmediateOpPtr assign_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr));
TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype));
TF_RETURN_IF_ERROR(assign_op->AddInput(variable_handle));
TF_RETURN_IF_ERROR(assign_op->AddInput(value));
int num_retvals = 0;
TF_RETURN_IF_ERROR(assign_op->Execute({}, &num_retvals));
return Status();
}
Status ReadVariable(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, ImmediateTensorHandlePtr* output) {
ImmediateOpPtr read_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr));
TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype));
TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle));
AbstractTensorHandle* value = nullptr;
int num_retvals = 1;
TF_RETURN_IF_ERROR(
read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals));
AbstractTensorHandlePtr owned_value(value);
if (!tensorflow::isa<ImmediateExecutionTensorHandle>(owned_value.get())) {
return errors::Internal("Unexpected tensor handle kind.");
}
output->reset(
reinterpret_cast<ImmediateExecutionTensorHandle*>(owned_value.release()));
return Status();
}
Status DestroyResource(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* handle) {
ImmediateOpPtr destroy_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(destroy_op->Reset("DestroyResourceOp", nullptr));
TF_RETURN_IF_ERROR(destroy_op->SetAttrBool("ignore_lookup_error", true));
TF_RETURN_IF_ERROR(destroy_op->AddInput(handle));
int num_retvals = 0;
TF_RETURN_IF_ERROR(destroy_op->Execute({}, &num_retvals));
return Status();
}
} // namespace internal
} // namespace tensorflow

View File

@ -0,0 +1,61 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace internal {
// Executes a VarHandleOp using `ctx`, and fills `handle` with the DT_RESOURCE
// TensorHandle associated with the variable. This is equivalent to creating an
// uninitialized TF2 tf.Variable.
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
ImmediateTensorHandlePtr* handle);
// Executes an AssignVariableOp using `ctx`, assigning the variable associated
// with `variable_handle` with `value`. `dtype` must be the datatype of the
// underlying variable for `variable_handle`. Note that it is illegal to assign
// a variable to a Tensor with a different dtype than what the variable was
// created with.
Status AssignVariable(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, ImmediateExecutionTensorHandle* value);
// Executes a ReadVariableOp using `ctx`. This reads the underlying variable
// value of `variable_handle` and copies the value to `output`. `dtype` must be
// the dtype of the variable associated with `variable_handle`.
Status ReadVariable(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, ImmediateTensorHandlePtr* output);
// Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to
// the cleanup that occurs in a tf.Variable's EagerResourceDeleter:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L289-L290
Status DestroyResource(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* handle);
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
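For orientation, here is a minimal sketch (not part of this change) of how the four helpers above compose inside namespace tensorflow. The function name is hypothetical, and it assumes a valid ImmediateExecutionContext whose CreateFloatScalar/CreateLocalHandle factories are available (as in the accompanying tests), plus the usual errors.h / tensor_interface.h includes for TF_RETURN_IF_ERROR and AbstractTensorPtr.
// Sketch only: create a scalar float variable, assign 42.0f to it, read the
// value back, then destroy the underlying resource.
inline Status VariableRoundTripSketch(ImmediateExecutionContext* ctx) {
  ImmediateTensorHandlePtr var;
  TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
      ctx, DT_FLOAT, /*shape=*/{}, &var));
  AbstractTensorPtr value_tensor(ctx->CreateFloatScalar(42.0f));
  ImmediateTensorHandlePtr value(ctx->CreateLocalHandle(value_tensor.get()));
  TF_RETURN_IF_ERROR(
      internal::AssignVariable(ctx, var.get(), DT_FLOAT, value.get()));
  ImmediateTensorHandlePtr read_back;
  TF_RETURN_IF_ERROR(
      internal::ReadVariable(ctx, var.get(), DT_FLOAT, &read_back));
  return internal::DestroyResource(ctx, var.get());
}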

View File

@ -0,0 +1,106 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
#include <memory>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
ImmediateTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
float value) {
AbstractTensorPtr tensor(context->CreateFloatScalar(value));
ImmediateTensorHandlePtr handle(context->CreateLocalHandle(tensor.get()));
return handle;
}
class VariableOpsTest : public ::testing::Test {
public:
VariableOpsTest()
: device_mgr_(std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
"CPU", {}, "/job:localhost/replica:0/task:0"))),
ctx_(new EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
/* async= */ false,
/* lazy_copy_function_remote_inputs= */ false, device_mgr_.get(),
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
/* custom_kernel_creator= */ nullptr,
/* cluster_flr= */ nullptr)) {}
EagerContext* context() { return ctx_.get(); }
private:
std::unique_ptr<StaticDeviceMgr> device_mgr_;
EagerContextPtr ctx_;
};
// Sanity check for variable creation
TEST_F(VariableOpsTest, CreateVariableSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
ImmediateTensorHandlePtr handle;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &handle));
// The created TensorHandle should be a DT_Resource
EXPECT_EQ(handle->DataType(), DT_RESOURCE);
}
// Sanity check for variable destruction
TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
ImmediateTensorHandlePtr handle;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &handle));
// Destroy the variable
TF_EXPECT_OK(internal::DestroyResource(context(), handle.get()));
}
// Sanity check for handle assignment and reading
TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
ImmediateTensorHandlePtr variable;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &variable));
// Create a Scalar float TensorHandle with value 42, and assign it to
// the variable.
ImmediateTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0);
TF_EXPECT_OK(internal::AssignVariable(context(), variable.get(), DT_FLOAT,
my_value.get()));
// Read back the value from the variable, and check that it is 42.
ImmediateTensorHandlePtr read_value_handle;
TF_EXPECT_OK(internal::ReadVariable(context(), variable.get(), DT_FLOAT,
&read_value_handle));
Status status;
AbstractTensorPtr read_value(read_value_handle->Resolve(&status));
TF_EXPECT_OK(status);
EXPECT_FLOAT_EQ(42.0, *static_cast<float*>(read_value->Data()));
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,60 @@
# This package contains classes corresponding to Revived SavedObjectGraph types
# used by SavedModel. See https://cs.opensource.google/tensorflow/tensorflow/+/c575e2ba93c442121d98d3f125d83fed1339924d:tensorflow/core/protobuf/saved_object_graph.proto;l=56-62
package(
default_visibility = [
# Restricting visibility for now
"//tensorflow/c/experimental/saved_model/core:__pkg__",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "constant",
srcs = [
"constant.cc",
],
hdrs = [
"constant.h",
],
deps = [
":tensorhandle_convertible",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:tensor_handle",
],
)
cc_library(
name = "variable",
srcs = [
"variable.cc",
],
hdrs = [
"variable.h",
],
deps = [
":tensorhandle_convertible",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/ops:variable_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "tensorhandle_convertible",
hdrs = [
"tensorhandle_convertible.h",
],
deps = [
"//tensorflow/c/eager:immediate_execution_tensor_handle",
],
)

View File

@ -0,0 +1,46 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include <memory>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
Constant::Constant(ImmediateTensorHandlePtr handle)
: TensorHandleConvertible(std::move(handle)) {}
Status Constant::Create(ImmediateExecutionContext* ctx,
AbstractTensorInterface* tensor,
std::unique_ptr<Constant>* output) {
ImmediateExecutionTensorHandle* handle = ctx->CreateLocalHandle(tensor);
if (handle == nullptr) {
return errors::Internal("Failed to convert tensor to tensorhandle");
}
output->reset(new Constant(ImmediateTensorHandlePtr(handle)));
return Status();
}
} // namespace tensorflow

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
#include <memory>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/tensor.pb.h"
namespace tensorflow {
// This class corresponds to python's tf.constant, which is effectively a
// TensorHandle explicitly initialized to some value.
// For now this doesn't do much beyond wrap Context's CreateLocalHandle method,
// and offer a subclass of TensorHandleConvertible. Note that similar to
// the python's eager mode logic, we bypass calling the "Const" op:
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/framework/constant_op.py#L301
class Constant : public TensorHandleConvertible {
public:
static Status Create(ImmediateExecutionContext* ctx,
AbstractTensorInterface* tensor,
std::unique_ptr<Constant>* output);
// Constant is movable, but not copyable.
Constant(Constant&& other) = default;
Constant& operator=(Constant&& other) = default;
~Constant() override = default;
private:
explicit Constant(ImmediateTensorHandlePtr handle);
Constant(const Constant&) = delete;
Constant& operator=(const Constant&) = delete;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
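A minimal usage sketch (hypothetical helper inside namespace tensorflow; assumes a valid ImmediateExecutionContext with the CreateFloatScalar factory used in the tests, the AbstractTensorPtr alias from tensor_interface.h, and errors.h):
// Sketch only: build a scalar host tensor and revive it as a Constant.
inline Status MakeScalarConstantSketch(ImmediateExecutionContext* ctx,
                                       std::unique_ptr<Constant>* out) {
  AbstractTensorPtr scalar(ctx->CreateFloatScalar(1.0f));
  if (scalar == nullptr) {
    return errors::Internal("Failed to create scalar tensor");
  }
  return Constant::Create(ctx, scalar.get(), out);
}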

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
namespace tensorflow {
// A common interface for objects that can be converted to a TensorHandle.
// Examples of objects that implement this include Variables, Constants, Assets,
// etc. This is used to convert captured objects into a ConcreteFunction's
// captured TensorHandles:
// https://github.com/tensorflow/tensorflow/blob/676a68963ea4b64fe479b9cede06aa8f5b290ab8/tensorflow/python/saved_model/load.py#L229-L240
class TensorHandleConvertible {
public:
explicit TensorHandleConvertible(ImmediateTensorHandlePtr handle)
: handle_(std::move(handle)) {}
ImmediateExecutionTensorHandle* handle() { return handle_.get(); }
// TensorHandleConvertible is movable, but not copyable.
TensorHandleConvertible(TensorHandleConvertible&& other) = default;
TensorHandleConvertible& operator=(TensorHandleConvertible&& other) = default;
virtual ~TensorHandleConvertible() = default;
protected:
TensorHandleConvertible(const TensorHandleConvertible&) = delete;
TensorHandleConvertible& operator=(const TensorHandleConvertible&) = delete;
ImmediateTensorHandlePtr handle_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
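For illustration, a minimal sketch (the class name is hypothetical and not part of this change) of how a revived type builds on this interface:
// Sketch only: a revived type just forwards its TensorHandle to the base.
class AssetSketch : public TensorHandleConvertible {
 public:
  explicit AssetSketch(ImmediateTensorHandlePtr handle)
      : TensorHandleConvertible(std::move(handle)) {}
  // handle() is inherited, so callers can treat this like any other
  // TensorHandleConvertible when wiring up ConcreteFunction captures.
};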

View File

@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include <memory>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
Variable::Variable(ImmediateExecutionContext* ctx, DataType dtype,
TensorShape shape, absl::optional<std::string> name,
ImmediateTensorHandlePtr handle)
: TensorHandleConvertible(std::move(handle)),
name_(name.has_value() ? *name : "Variable"),
dtype_(dtype),
shape_(shape),
ctx_(ctx) {}
Variable::~Variable() {
// If the handle is null (perhaps because variable was std::moved from), then
// we don't have to do anything.
if (handle_ == nullptr) {
return;
}
Status status = internal::DestroyResource(ctx_, handle_.get());
if (!status.ok()) {
LOG(ERROR) << "Error destroying variable: " << name_
<< "due to: " << status;
}
}
DataType Variable::dtype() { return dtype_; }
TensorShape Variable::shape() { return shape_; }
Status Variable::Assign(ImmediateExecutionTensorHandle* handle) {
return internal::AssignVariable(ctx_, handle_.get(), dtype_, handle);
}
Status Variable::ReadValue(ImmediateTensorHandlePtr* out) {
return internal::ReadVariable(ctx_, handle_.get(), dtype_, out);
}
Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
absl::optional<std::string> name,
std::unique_ptr<Variable>* output) {
ImmediateTensorHandlePtr handle;
TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
ctx, dtype, shape, &handle));
output->reset(
new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
return Status();
}
} // namespace tensorflow

View File

@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_
#include <memory>
#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
class Variable : public TensorHandleConvertible {
public:
// Creates an uninitialized resource variable. Note that a caller must
// call "assign" to associate a value with the variable.
static Status CreateUninitialized(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
absl::optional<std::string> name,
std::unique_ptr<Variable>* output);
// The dtype of the underlying variable.
DataType dtype();
// The shape of the underlying variable.
TensorShape shape();
// Updates the variable's contents with `handle`.
Status Assign(ImmediateExecutionTensorHandle* handle);
// Reads the value of the variable, and stores it in `out`
Status ReadValue(ImmediateTensorHandlePtr* out);
// Variable is movable, but not copyable.
Variable(Variable&& other) = default;
Variable& operator=(Variable&& other) = default;
~Variable() override;
private:
Variable(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
absl::optional<std::string> name, ImmediateTensorHandlePtr handle);
Variable(const Variable& variable) = delete;
Variable& operator=(const Variable&) = delete;
std::string name_;
DataType dtype_;
TensorShape shape_;
// ctx_ must outlive Variable.
ImmediateExecutionContext* ctx_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_
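A minimal usage sketch (hypothetical function inside namespace tensorflow; assumes a valid ImmediateExecutionContext, a DT_FLOAT value handle, and the usual TF_RETURN_IF_ERROR macro):
// Sketch only: create an uninitialized variable, give it a value, read back.
inline Status VariableUsageSketch(ImmediateExecutionContext* ctx,
                                  ImmediateExecutionTensorHandle* value,
                                  ImmediateTensorHandlePtr* out) {
  std::unique_ptr<Variable> var;
  TF_RETURN_IF_ERROR(Variable::CreateUninitialized(
      ctx, DT_FLOAT, /*shape=*/{}, absl::nullopt, &var));
  TF_RETURN_IF_ERROR(var->Assign(value));
  // ~Variable() issues DestroyResourceOp for the underlying resource.
  return var->ReadValue(out);
}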

View File

@ -0,0 +1,58 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include <memory>
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
namespace internal {
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
const TensorProto& proto,
std::unique_ptr<Constant>* output) {
tensorflow::Tensor tensor;
bool parse_result = tensor.FromProto(proto);
if (!parse_result) {
return errors::Internal("Failed to parse tensor from tensorproto");
}
TensorInterface tensor_interface(std::move(tensor));
return Constant::Create(ctx, &tensor_interface, output);
}
// This follows the python variable restoration logic:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407
Status LoadSavedVariable(ImmediateExecutionContext* ctx,
const SavedVariable& variable,
std::unique_ptr<Variable>* output) {
const std::string& name = variable.name();
tensorflow::TensorShape shape(variable.shape());
tensorflow::DataType dtype = variable.dtype();
TF_RETURN_IF_ERROR(
Variable::CreateUninitialized(ctx, dtype, shape, name, output));
return Status();
}
} // namespace internal
} // namespace tensorflow

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
// Some internal utility functions for the SavedModelAPI, factored out into a
// separately unit-testable header.
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
namespace internal {
// Load a TensorProto into a tensorflow::Constant. This is similar to the
// constant loading logic in python:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L437
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
const TensorProto& proto,
std::unique_ptr<Constant>* output);
// Creates a tensorflow::Variable from a SavedVariable. This is similar to the
// logic in:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407
// Note that the caller **must assign a value** to the loaded variable.
Status LoadSavedVariable(ImmediateExecutionContext* ctx,
const SavedVariable& variable,
std::unique_ptr<Variable>* output);
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
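A minimal sketch (hypothetical helper inside namespace tensorflow, assuming TF_RETURN_IF_ERROR is available) of how these functions are expected to be used during loading; note the caller still has to assign the restored value, per the comment above:
// Sketch only: revive a SavedVariable and immediately assign its restored
// value, since LoadSavedVariable leaves the variable uninitialized.
inline Status ReviveAndRestoreVariableSketch(
    ImmediateExecutionContext* ctx, const SavedVariable& saved_variable,
    ImmediateExecutionTensorHandle* restored_value,
    std::unique_ptr<Variable>* out) {
  TF_RETURN_IF_ERROR(internal::LoadSavedVariable(ctx, saved_variable, out));
  return (*out)->Assign(restored_value);
}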

View File

@ -0,0 +1,122 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
namespace {
class SavedVariableLoadingTest : public ::testing::TestWithParam<
std::tuple<DataType, std::vector<int64>>> {
public:
SavedVariableLoadingTest()
: device_mgr_(testing::CreateTestingDeviceMgr()),
ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
EagerContext* context() { return ctx_.get(); }
private:
std::unique_ptr<StaticDeviceMgr> device_mgr_;
EagerContextPtr ctx_;
};
// Sanity check that constructing a tensorflow::Variable from a SavedVariable
// 1. does not cause an error
// 2. preserves dtype and shape.
TEST_P(SavedVariableLoadingTest, LoadSavedVariableSuccessful) {
auto& test_params = GetParam();
DataType dtype = std::get<0>(test_params);
TensorShape shape(std::get<1>(test_params));
SavedVariable saved_variable;
saved_variable.set_dtype(dtype);
shape.AsProto(saved_variable.mutable_shape());
std::unique_ptr<Variable> var;
TF_EXPECT_OK(internal::LoadSavedVariable(context(), saved_variable, &var));
EXPECT_EQ(var->dtype(), dtype);
EXPECT_EQ(var->shape(), shape);
}
// Assigning and reading values should yield
// consistent results.
TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccessful) {
auto& test_params = GetParam();
DataType dtype = std::get<0>(test_params);
std::vector<int64> shape_vector = std::get<1>(test_params);
TensorShape shape(shape_vector);
// Create the variable.
Status status;
std::unique_ptr<Variable> var;
TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
absl::nullopt, &var));
// Create a TensorHandle
ImmediateTensorHandlePtr expected_handle =
testing::CreateTensorHandle(context(), dtype, shape_vector, 42);
AbstractTensorPtr expected_tensor(expected_handle->Resolve(&status));
TF_EXPECT_OK(status) << status.error_message();
// Assign the tensorhandle to the variable.
TF_EXPECT_OK(var->Assign(expected_handle.get()));
// Read back the value from the variable
ImmediateTensorHandlePtr output_handle;
TF_EXPECT_OK(var->ReadValue(&output_handle));
AbstractTensorPtr output_tensor(output_handle->Resolve(&status));
TF_EXPECT_OK(status) << status.error_message();
// Check that output_tensor == expected_tensor
EXPECT_EQ(output_tensor->Type(), expected_tensor->Type());
EXPECT_EQ(output_tensor->NumElements(), expected_tensor->NumElements());
testing::CheckBufferDataIsEqual(
output_tensor->Type(), output_tensor->NumElements(),
output_tensor->Data(), expected_tensor->Data());
}
// Test against combinations of SavedVariables of
// 1. Varying dtypes
// 2. Varying shapes
INSTANTIATE_TEST_SUITE_P(
SavedVariableIntegerDtypesTest, SavedVariableLoadingTest,
::testing::Combine(
::testing::ValuesIn(testing::DataTypeSetToVector(kDataTypeIsInteger)),
::testing::ValuesIn(testing::InterestingShapes())));
INSTANTIATE_TEST_SUITE_P(
SavedVariableFloatingDtypesTest, SavedVariableLoadingTest,
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
::testing::ValuesIn(testing::InterestingShapes())));
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,143 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include <memory>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace testing {
std::unique_ptr<StaticDeviceMgr> CreateTestingDeviceMgr() {
return std::make_unique<StaticDeviceMgr>(
DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
}
EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr) {
return EagerContextPtr(new EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
/* async= */ false,
/* lazy_copy_function_remote_inputs= */ false, device_mgr,
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
/* custom_kernel_creator= */ nullptr,
/* cluster_flr= */ nullptr));
}
std::vector<DataType> DataTypeSetToVector(DataTypeSet set) {
std::vector<DataType> result;
result.reserve(set.size());
for (DataType dt : set) {
result.push_back(dt);
}
return result;
}
std::vector<std::vector<int64>> InterestingShapes() {
std::vector<std::vector<int64>> interesting_shapes;
interesting_shapes.push_back({}); // Scalar
interesting_shapes.push_back({10}); // 1D Vector
interesting_shapes.push_back({3, 3}); // 2D Matrix
interesting_shapes.push_back({1, 4, 6, 10}); // Higher Dimension Tensor
return interesting_shapes;
}
ImmediateTensorHandlePtr CreateTensorHandle(ImmediateExecutionContext* ctx,
DataType dtype,
absl::Span<const int64> shape,
int8 value) {
AbstractTensorPtr tensor(ctx->CreateTensor(dtype, shape));
CHECK_NE(tensor.get(), nullptr)
<< "Tensor creation failed for tensor of dtype: "
<< DataTypeString(dtype);
CHECK_EQ(tensor->Type(), dtype);
for (int i = 0; i < shape.size(); ++i) {
CHECK_EQ(tensor->Dim(i), shape[i]);
}
FillNumericTensorBuffer(tensor->Type(), tensor->NumElements(), tensor->Data(),
value);
ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(tensor.get()));
CHECK_NE(handle.get(), nullptr);
return handle;
}
void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer,
int8 value) {
switch (dtype) {
#define CASE(type) \
case DataTypeToEnum<type>::value: { \
type* typed_buffer = static_cast<type*>(buffer); \
for (size_t i = 0; i < num_elements; ++i) { \
typed_buffer[i] = value; \
} \
break; \
}
TF_CALL_INTEGRAL_TYPES(CASE);
TF_CALL_double(CASE);
TF_CALL_float(CASE);
#undef CASE
default:
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
break;
}
}
// Checks the underlying data is equal for the buffers for two numeric tensors.
// Note: The caller must ensure to check that the dtypes and sizes of the
// underlying buffers are the same before calling this.
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
void* b) {
switch (dtype) {
#define CASE(type) \
case DataTypeToEnum<type>::value: { \
type* typed_a = static_cast<type*>(a); \
type* typed_b = static_cast<type*>(b); \
for (int64 i = 0; i < num_elements; ++i) { \
if (DataTypeIsFloating(dtype)) { \
EXPECT_FLOAT_EQ(typed_a[i], typed_b[i]); \
} else { \
EXPECT_EQ(typed_a[i], typed_b[i]); \
} \
} \
break; \
}
TF_CALL_INTEGRAL_TYPES(CASE);
TF_CALL_double(CASE);
TF_CALL_float(CASE);
#undef CASE
default:
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
}
}
} // namespace testing
} // namespace tensorflow

View File

@ -0,0 +1,75 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_
#include <memory>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace testing {
// Creates a DeviceMgr suitable for local tests.
std::unique_ptr<StaticDeviceMgr> CreateTestingDeviceMgr();
// Creates an EagerContext suitable for local tests. Does not take ownership
// of `device_mgr`.
EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr);
// Converts a tensorflow::DataTypeSet to std::vector<DataType>.
// This is useful for tests using GTest's ::testing::ValuesIn, since
// DataTypeSet doesn't fulfill all the constraints of an STL-like iterable.
std::vector<DataType> DataTypeSetToVector(DataTypeSet set);
// Returns a vector of shapes intended to be "interesting" test cases.
// Currently, this returns scalar, 1D vector, 2D matrix, and 4D tensor shapes.
std::vector<std::vector<int64>> InterestingShapes();
// Returns a TensorHandle of `dtype` and `shape`, filled with `value`.
// `dtype` must be an integer dtype, float, or double.
// If a TensorHandle cannot be created successfully, this function will
// CHECK fail. This should only be used for testing purposes.
ImmediateTensorHandlePtr CreateTensorHandle(ImmediateExecutionContext* ctx,
DataType dtype,
absl::Span<const int64> shape,
int8 value);
// Fills a numeric tensor's buffer with `value`.
// dtype must be any integer dtype, float or double.
void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer,
int8 value);
// Checks the underlying data is equal for the buffers for two numeric tensors.
// Note: The caller must ensure to check that the dtypes and sizes of the
// underlying buffers are the same before calling this.
// dtype must be any integer dtype, float, or double.
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
void* b);
} // namespace testing
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_
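A minimal sketch (hypothetical test inside namespace tensorflow, assuming the GTest main from //tensorflow/core:test_main and the TF_EXPECT_OK macro from status_test_util.h) of how these utilities compose:
// Sketch only: build a context, create a filled handle, resolve it, and
// compare its buffer against an identically filled host tensor.
TEST(TestUtilsSketch, CreateResolveAndCompare) {
  std::unique_ptr<StaticDeviceMgr> device_mgr =
      testing::CreateTestingDeviceMgr();
  EagerContextPtr ctx = testing::CreateTestingEagerContext(device_mgr.get());
  ImmediateTensorHandlePtr handle =
      testing::CreateTensorHandle(ctx.get(), DT_FLOAT, {2, 2}, /*value=*/7);
  Status status;
  AbstractTensorPtr actual(handle->Resolve(&status));
  TF_EXPECT_OK(status);
  Tensor expected(DT_FLOAT, TensorShape({2, 2}));
  testing::FillNumericTensorBuffer(DT_FLOAT, expected.NumElements(),
                                   expected.data(), 7);
  testing::CheckBufferDataIsEqual(DT_FLOAT, actual->NumElements(),
                                  actual->Data(), expected.data());
}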

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
@ -51,7 +52,7 @@ std::vector<ConcreteFunction*> TFSavedModelAPIImpl::ListFunctions() {
Status TFSavedModelAPIImpl::Load(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
TFSavedModelAPIImpl* out) {
EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out) {
// TODO(bmzhao): Add support for loading a TFSavedModelImpl.
return errors::Unimplemented(
"TFSavedModelAPIImpl loading is unimplemented currently");

View File

@ -23,14 +23,13 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
class TFSavedModelAPIImpl : public SavedModelAPI {
public:
TFSavedModelAPIImpl() = default;
Status GetFunction(const std::string& function_path,
ConcreteFunction** function) override;
@ -40,13 +39,14 @@ class TFSavedModelAPIImpl : public SavedModelAPI {
static Status Load(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
TFSavedModelAPIImpl* out);
EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out);
std::vector<ConcreteFunction*> ListFunctions() override;
~TFSavedModelAPIImpl() override = default;
private:
TFSavedModelAPIImpl() = default;
std::vector<ConcreteFunction> functions_;
};

View File

@ -144,7 +144,9 @@ cc_library(
"//tensorflow/c:tf_status_internal",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_impl",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:optional",
],
)
@ -176,7 +178,7 @@ cc_library(
":tensorhandle_list_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:tensor_handle_interface",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
],
)
@ -188,7 +190,7 @@ cc_library(
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/eager:tensor_handle_interface",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
],
)

View File

@ -22,11 +22,15 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
extern "C" {
@ -34,10 +38,21 @@ extern "C" {
TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
TF_Status* status) {
std::string saved_model_dir(dirname);
std::unique_ptr<tensorflow::SavedModelAPI> result;
if (tensorflow::unwrap(ctx)->UsesTFRT()) {
status->status = tensorflow::errors::Unimplemented(
"TFRT SavedModel implementation will be added in the future");
} else {
std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
status->status = tensorflow::TFSavedModelAPIImpl::Load(
dirname, absl::nullopt,
tensorflow::down_cast<tensorflow::EagerContext*>(
tensorflow::unwrap(ctx)),
&saved_model);
result = std::move(saved_model);
}
std::unique_ptr<tensorflow::SavedModelAPI> result =
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
&status->status);
if (!status->status.ok()) {
return nullptr;
}
@ -54,9 +69,20 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
tagset.insert(std::string(tags[i]));
}
std::unique_ptr<tensorflow::SavedModelAPI> result =
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset),
&status->status);
std::unique_ptr<tensorflow::SavedModelAPI> result;
if (tensorflow::unwrap(ctx)->UsesTFRT()) {
status->status = tensorflow::errors::Unimplemented(
"TFRT SavedModel implementation will be added in the future");
} else {
std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
status->status = tensorflow::TFSavedModelAPIImpl::Load(
dirname, tagset,
tensorflow::down_cast<tensorflow::EagerContext*>(
tensorflow::unwrap(ctx)),
&saved_model);
result = std::move(saved_model);
}
if (!status->status.ok()) {
return nullptr;
}
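For context, a minimal client-side sketch of the path this dispatch serves (hypothetical function; assumes the public headers tensorflow/c/eager/c_api.h, tensorflow/c/tf_status.h, and tensorflow/c/experimental/saved_model/public/saved_model_api.h, and that TF_DeleteSavedModel is the matching deleter):
// Sketch only: drive TF_LoadSavedModel from C. With this change the non-TFRT
// path still reports TF_UNIMPLEMENTED, so `model` comes back null.
void LoadSavedModelSketch(const char* dirname) {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  TF_SavedModel* model = TF_LoadSavedModel(dirname, ctx, status);
  if (TF_GetCode(status) == TF_OK && model != nullptr) {
    TF_DeleteSavedModel(model);  // assumed deleter in the public API
  }
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}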

View File

@ -17,7 +17,7 @@ limitations under the License.
#include <stddef.h>
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"

View File

@ -19,7 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
// Internal structures used by the SavedModel C API. These are likely to
// change and should not be depended on.
@ -29,7 +29,7 @@ typedef struct TF_TensorHandleList TF_TensorHandleList;
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(
std::vector<tensorflow::AbstractTensorHandleInterface*>,
std::vector<tensorflow::ImmediateExecutionTensorHandle*>,
TF_TensorHandleList)
} // namespace tensorflow

View File

@ -54,6 +54,20 @@ class AbstractTensorInterface {
virtual ~AbstractTensorInterface() {}
};
namespace internal {
struct AbstractTensorInterfaceDeleter {
void operator()(AbstractTensorInterface* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractTensorPtr =
std::unique_ptr<AbstractTensorInterface,
internal::AbstractTensorInterfaceDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_TENSOR_INTERFACE_H_

View File

@ -228,57 +228,6 @@ Status TensorInterface::BitcastFrom(const TensorInterface& from, DataType type,
} // namespace tensorflow
// --------------------------------------------------------------------------
void StringEncode(const char* src, size_t src_len, char* dst) {
dst = tensorflow::core::EncodeVarint64(dst, src_len);
memcpy(dst, src, src_len);
}
size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
size_t dst_len, TF_Status* status) {
const size_t sz = TF_StringEncodedSize(src_len);
if (sz < src_len) {
Set_TF_Status_from_Status(
status, InvalidArgument("src string is too large to encode"));
return 0;
}
if (dst_len < sz) {
Set_TF_Status_from_Status(
status,
InvalidArgument("dst_len (", dst_len, ") too small to encode a ",
src_len, "-byte string"));
return 0;
}
StringEncode(src, src_len, dst);
return sz;
}
static Status TF_StringDecode_Impl(const char* src, size_t src_len,
const char** dst, size_t* dst_len) {
tensorflow::uint64 len64 = 0;
const char* p = tensorflow::core::GetVarint64Ptr(src, src + src_len, &len64);
if (p == nullptr) {
return InvalidArgument("invalid string encoding or truncated src buffer");
}
if (len64 > std::numeric_limits<size_t>::max()) {
return InvalidArgument("encoded string is ", len64,
"-bytes, which is too large for this architecture");
}
*dst = p;
*dst_len = static_cast<size_t>(len64);
return Status::OK();
}
size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
size_t* dst_len, TF_Status* status) {
Set_TF_Status_from_Status(status,
TF_StringDecode_Impl(src, src_len, dst, dst_len));
if (TF_GetCode(status) != TF_OK) return 0;
return static_cast<size_t>(*dst - src) + *dst_len;
}
size_t TF_StringEncodedSize(size_t len) {
return static_cast<size_t>(tensorflow::core::VarintLength(len)) + len;
}
static void DeleteArray(void* data, size_t size, void* arg) {
DCHECK_EQ(data, arg);
@ -334,58 +283,12 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
std::memcpy(TF_TensorData(t), str.c_str(), str.size());
return t;
}
if (src.dtype() != tensorflow::DT_STRING) {
Tensor tensor;
if (!tensor.CopyFrom(src, src.shape())) {
return nullptr;
}
return new TF_Tensor{new tensorflow::TensorInterface(tensor)};
}
// DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
// encoded sequence of strings.
// Compute bytes needed for encoding.
size_t size = 0;
const auto& srcarray = src.flat<tstring>();
for (int i = 0; i < srcarray.size(); ++i) {
const string& s = srcarray(i);
// uint64 starting_offset, TF_StringEncode-d string.
size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size());
}
// Encode all strings.
char* base = new char[size];
char* data_start = base + sizeof(tensorflow::uint64) * srcarray.size();
char* dst = data_start; // Where next string is encoded.
size_t dst_len = size - static_cast<size_t>(data_start - base);
tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
for (int i = 0; i < srcarray.size(); ++i) {
*offsets = (dst - data_start);
offsets++;
const string& s = srcarray(i);
const size_t consumed = TF_StringEncodedSize(s.size());
StringEncode(s.data(), s.size(), dst);
dst += consumed;
dst_len -= consumed;
}
if (dst != base + size) {
*status = InvalidArgument(
"invalid string tensor encoding (decoded ", (dst - base),
" bytes, but the tensor is encoded in ", size, " bytes");
delete[] base;
Tensor tensor;
if (!tensor.CopyFrom(src, src.shape())) {
return nullptr;
}
auto dims = src.shape().dim_sizes();
std::vector<tensorflow::int64> dimvec(dims.size());
for (size_t i = 0; i < dims.size(); ++i) {
dimvec[i] = dims[i];
}
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
return TF_NewTensor(TF_STRING,
reinterpret_cast<const int64_t*>(dimvec.data()),
dimvec.size(), base, size, DeleteArray, base);
return new TF_Tensor{new tensorflow::TensorInterface(tensor)};
}
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
@ -409,39 +312,7 @@ Status TensorInterface::ToTensor(tensorflow::Tensor* dst) const {
}
return Status::OK();
}
if (tensor_.dtype() != DT_STRING) {
*dst = tensor_;
return Status::OK();
}
// TF_STRING tensors require copying since Tensor class expects a sequence of
// string objects.
const tensorflow::int64 num_elements = tensor_.NumElements();
const char* input = reinterpret_cast<const char*>(Data());
const size_t src_size = ByteSize();
if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
num_elements) {
return InvalidArgument(
"Malformed TF_STRING tensor; too short to hold number of elements");
}
const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
const char* limit = input + src_size;
*dst = tensorflow::Tensor(tensor_.dtype(), tensor_.shape());
auto dstarray = dst->flat<tstring>();
for (tensorflow::int64 i = 0; i < num_elements; ++i) {
tensorflow::uint64 offset =
reinterpret_cast<const tensorflow::uint64*>(input)[i];
if (static_cast<ptrdiff_t>(offset) >= (limit - data_start)) {
return InvalidArgument("Malformed TF_STRING tensor; element ", i,
" out of range");
}
size_t len;
const char* p;
const char* srcp = data_start + offset;
Status status = TF_StringDecode_Impl(srcp, limit - srcp, &p, &len);
if (!status.ok()) return status;
dstarray(i).assign(p, len);
}
*dst = tensor_;
return Status::OK();
}

View File

@ -148,34 +148,6 @@ TF_CAPI_EXPORT extern void TF_TensorBitcastFrom(const TF_Tensor* from,
int num_new_dims,
TF_Status* status);
// --------------------------------------------------------------------------
// Encode the string `src` (`src_len` bytes long) into `dst` in the format
// required by TF_STRING tensors. Does not write to memory more than `dst_len`
// bytes beyond `*dst`. `dst_len` should be at least
// TF_StringEncodedSize(src_len).
//
// On success returns the size in bytes of the encoded string.
// Returns an error into `status` otherwise.
TF_CAPI_EXPORT extern size_t TF_StringEncode(const char* src, size_t src_len,
char* dst, size_t dst_len,
TF_Status* status);
// Decode a string encoded using TF_StringEncode.
//
// On success, sets `*dst` to the start of the decoded string and `*dst_len` to
// its length. Returns the number of bytes starting at `src` consumed while
// decoding. `*dst` points to memory within the encoded buffer. On failure,
// `*dst` and `*dst_len` are undefined and an error is set in `status`.
//
// Does not read memory more than `src_len` bytes beyond `src`.
TF_CAPI_EXPORT extern size_t TF_StringDecode(const char* src, size_t src_len,
const char** dst, size_t* dst_len,
TF_Status* status);
// Return the size in bytes required to encode a string `len` bytes long into a
// TF_STRING tensor.
TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len);
// Returns bool iff this tensor is aligned.
TF_CAPI_EXPORT extern bool TF_TensorIsAligned(const TF_Tensor*);

20
tensorflow/c/tf_tstring.h Normal file
View File

@ -0,0 +1,20 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_TF_TSTRING_H_
#define TENSORFLOW_C_TF_TSTRING_H_
#include "tensorflow/core/platform/ctstring.h"
#endif  // TENSORFLOW_C_TF_TSTRING_H_

View File

@ -259,9 +259,6 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
RunTest(x, x_init_value, y, y_shape);
}
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, MaxPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1});
@ -274,7 +271,6 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) {
SetRandomValuesForMaxPooling<float>(&x_init_value);
RunTest(x, x_init_value, y, y_shape);
}
#endif
TEST_F(NNGradTest, AvgPoolGradHelper) {
TensorShape x_shape({1, 2, 2, 1});
@ -287,9 +283,6 @@ TEST_F(NNGradTest, AvgPoolGradHelper) {
RunTest(x, x_shape, y, y_shape);
}
// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, AvgPool3DGradHelper) {
TensorShape x_shape({1, 3, 3, 3, 1});
TensorShape y_shape({1, 1, 1, 1, 1});
@ -300,7 +293,6 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
RunTest(x, x_shape, y, y_shape);
}
#endif
TEST_F(NNGradTest, LRN) {
TensorShape x_shape({1, 1, 2, 1});

View File

@ -106,6 +106,7 @@ cc_library(
hdrs = ["loader.h"],
deps = [
":constants",
":loader_util",
":reader",
] + if_not_mobile([
"//tensorflow/core:core_cpu",
@ -132,6 +133,17 @@ cc_library(
],
)
cc_library(
name = "loader_util",
srcs = ["loader_util.cc"],
hdrs = ["loader_util.h"],
deps = [":constants"] + if_not_mobile([
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
]),
)
tf_cc_test(
name = "bundle_v2_test",
srcs = ["bundle_v2_test.cc"],

View File

@ -10,6 +10,9 @@ tf_cc_test(
srcs = [
"saved_model_api_test.cc",
],
data = [
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
],
deps = [
"//tensorflow/cc/experimental/base/public:runtime",
"//tensorflow/cc/experimental/base/public:runtime_builder",

View File

@ -90,7 +90,7 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
// That unblocks writing other tests that require a TF_SavedModel*,
// like loading a ConcreteFunction. This test at least checks that the
// C API builds and can be minimally run.
EXPECT_EQ(status.code(), TF_UNIMPLEMENTED);
EXPECT_EQ(status.code(), TF_UNIMPLEMENTED) << status.message();
}
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
@ -191,41 +191,6 @@ Status RunInitOp(const RunOptions& run_options, const string& export_dir,
return Status::OK();
}
// A SavedModel may store the name of the initialization op to run in the
// SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name) {
const auto& sig_def_map = meta_graph_def.signature_def();
const auto& init_op_sig_it =
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
if (init_op_sig_it != sig_def_map.end()) {
*init_op_name = init_op_sig_it->second.outputs()
.find(kSavedModelInitOpSignatureKey)
->second.name();
return Status::OK();
}
const auto& collection_def_map = meta_graph_def.collection_def();
string init_op_collection_key;
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
init_op_collection_key = kSavedModelMainOpKey;
} else {
init_op_collection_key = kSavedModelLegacyInitOpKey;
}
const auto init_op_it = collection_def_map.find(init_op_collection_key);
if (init_op_it != collection_def_map.end()) {
if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
*init_op_name = init_op_it->second.node_list().value(0);
}
return Status::OK();
}
Status RunRestore(const RunOptions& run_options, const string& export_dir,
const StringPiece restore_op_name,
const StringPiece variable_filename_const_op_name,
@ -263,32 +228,6 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
nullptr /* outputs */, &run_metadata, session);
}
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs) {
// With SavedModel v2, we write asset file def into metagraph instead of
// collection, so read from metagraph first.
if (meta_graph_def.asset_file_def_size() > 0) {
for (const auto& asset : meta_graph_def.asset_file_def()) {
asset_file_defs->push_back(asset);
}
return Status::OK();
}
// Fall back to read from collection to be backward compatible with v1.
const auto& collection_def_map = meta_graph_def.collection_def();
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
if (assets_it == collection_def_map.end()) {
return Status::OK();
}
const auto& any_assets = assets_it->second.any_list().value();
for (const auto& any_asset : any_assets) {
AssetFileDef asset_file_def;
TF_RETURN_IF_ERROR(
ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
asset_file_defs->push_back(asset_file_def);
}
return Status::OK();
}
Status ReadSavedModelDebugInfoIfPresent(
const string& export_dir,
std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
@ -322,7 +261,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
std::vector<AssetFileDef> asset_file_defs;
TF_RETURN_IF_ERROR(
GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
internal::GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
TF_RETURN_IF_ERROR(
RunRestore(run_options, export_dir,
bundle->meta_graph_def.saver_def().restore_op_name(),
@ -336,7 +275,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
string init_op_name;
TF_RETURN_IF_ERROR(
GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
internal::GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
asset_file_defs, bundle->session.get(),
init_op_name));

View File

@ -0,0 +1,90 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/loader_util.h"
#include <vector>
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf_internal.h"
namespace tensorflow {
namespace internal {
// A SavedModel may store the name of the initialization op to run in the
// SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name) {
const auto& sig_def_map = meta_graph_def.signature_def();
const auto& init_op_sig_it =
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
if (init_op_sig_it != sig_def_map.end()) {
*init_op_name = init_op_sig_it->second.outputs()
.find(kSavedModelInitOpSignatureKey)
->second.name();
return Status::OK();
}
const auto& collection_def_map = meta_graph_def.collection_def();
string init_op_collection_key;
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
init_op_collection_key = kSavedModelMainOpKey;
} else {
init_op_collection_key = kSavedModelLegacyInitOpKey;
}
const auto init_op_it = collection_def_map.find(init_op_collection_key);
if (init_op_it != collection_def_map.end()) {
if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
*init_op_name = init_op_it->second.node_list().value(0);
}
return Status::OK();
}
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs) {
// With SavedModel v2, we write asset file def into metagraph instead of
// collection, so read from metagraph first.
if (meta_graph_def.asset_file_def_size() > 0) {
for (const auto& asset : meta_graph_def.asset_file_def()) {
asset_file_defs->push_back(asset);
}
return Status::OK();
}
// Fall back to read from collection to be backward compatible with v1.
const auto& collection_def_map = meta_graph_def.collection_def();
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
if (assets_it == collection_def_map.end()) {
return Status::OK();
}
const auto& any_assets = assets_it->second.any_list().value();
for (const auto& any_asset : any_assets) {
AssetFileDef asset_file_def;
TF_RETURN_IF_ERROR(
ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
asset_file_defs->push_back(asset_file_def);
}
return Status::OK();
}
} // namespace internal
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
#define TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
#include <string>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
namespace tensorflow {
namespace internal {
// A SavedModel may store the name of the initialization op to run in the
// SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name);
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs);
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
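These are the helpers that loader.cc now calls in place of its former file-local GetInitOp and GetAssetFileDefs (see the updated call sites earlier in this diff). A hedged usage sketch; the errors.h include path is assumed:

```c++
// Sketch: resolving the init op and asset files via the new shared helpers.
#include <string>
#include <vector>

#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/core/platform/errors.h"  // assumed home of TF_RETURN_IF_ERROR

tensorflow::Status ResolveInitAndAssets(
    const std::string& export_dir,
    const tensorflow::MetaGraphDef& meta_graph_def) {
  std::string init_op_name;  // stays empty if the SavedModel defines no init op
  TF_RETURN_IF_ERROR(tensorflow::internal::GetInitOp(export_dir, meta_graph_def,
                                                     &init_op_name));
  std::vector<tensorflow::AssetFileDef> asset_file_defs;
  TF_RETURN_IF_ERROR(tensorflow::internal::GetAssetFileDefs(meta_graph_def,
                                                            &asset_file_defs));
  return tensorflow::Status::OK();
}
```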

View File

@ -67,13 +67,14 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@llvm-project//llvm:arm_target", # fixdeps: keep
"@llvm-project//llvm:powerpc_target", # fixdeps: keep
"@llvm-project//llvm:target_base",
"@llvm-project//llvm:x86_target", # fixdeps: keep
"@llvm-project//llvm:ARMCodeGen", # fixdeps: keep
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
"//tensorflow/core:regexp_internal",
] + if_llvm_aarch64_available([
"@llvm-project//llvm:aarch64_target", # fixdeps: keep
"@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep
]),
)
@ -94,8 +95,8 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core/platform:resource_loader",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support", # fixdeps: keep
"@llvm-project//llvm:x86_target", # fixdeps: keep
"@llvm-project//llvm:Support", # fixdeps: keep
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
],
)
@ -109,12 +110,12 @@ cc_library(
name = "llvm_targets",
visibility = ["//tensorflow/python:__pkg__"],
deps = [
"@llvm-project//llvm:arm_target", # fixdeps: keep
"@llvm-project//llvm:powerpc_target", # fixdeps: keep
"@llvm-project//llvm:target_base",
"@llvm-project//llvm:x86_target", # fixdeps: keep
"@llvm-project//llvm:ARMCodeGen", # fixdeps: keep
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
"@llvm-project//llvm:Target",
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
] + if_llvm_aarch64_available([
"@llvm-project//llvm:aarch64_target", # fixdeps: keep
"@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep
]),
)
@ -286,9 +287,9 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:ir",
"@llvm-project//llvm:support",
"@llvm-project//llvm:target_base",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
],
)

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "absl/base/call_once.h"
#include "llvm-c/Target.h"
#include "llvm/Support/ManagedStatic.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/quantize.h"

View File

@ -1,5 +1,5 @@
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s -dump-input-on-failure
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s -dump-input-on-failure
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s
# Checks the error message produced by tfcompile with mlir_component
# Checks that source debug information is used in the output error message and

View File

@ -4,6 +4,7 @@ traces: {
value: {
file_line_cols: {
line: 1
col: 1
}
}
}
@ -12,9 +13,11 @@ traces: {
value: {
file_line_cols: {
line: 3
col: 1
}
file_line_cols: {
line: 4
col: 1
}
}
}
@ -23,6 +26,7 @@ traces: {
value: {
file_line_cols: {
line: 2
col: 1
}
}
}

View File

@ -33,6 +33,7 @@ MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;
MlirCommonFlags* mlir_flags;
std::vector<Flag>* flag_list;
absl::once_flag flags_init;
@ -166,6 +167,9 @@ void AllocateAndParseFlags() {
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
jitter_flags->jitter_amount = 1e-5;
mlir_flags = new MlirCommonFlags;
mlir_flags->tf_mlir_enable_mlir_bridge = false;
auto setter_for_jitter_tensor_names = [](string sequence) {
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
return true;
@ -211,7 +215,11 @@ void AllocateAndParseFlags() {
Flag("tf_introduce_floating_point_jitter_amount",
&jitter_flags->jitter_amount,
"The amount of jitter to introduce. This amount is added to each "
"element in the tensors named in `tensor_names.")});
"element in the tensors named in `tensor_names."),
Flag("tf_mlir_enable_mlir_bridge",
&mlir_flags->tf_mlir_enable_mlir_bridge,
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
@ -250,6 +258,11 @@ GetIntroduceFloatingPointJitterPassFlags() {
return *jitter_flags;
}
MlirCommonFlags* GetMlirCommonFlags() {
absl::call_once(flags_init, &AllocateAndParseFlags);
return mlir_flags;
}
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
absl::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list);

View File

@ -133,6 +133,11 @@ struct IntroduceFloatingPointJitterPassFlags {
std::vector<string> tensor_names;
};
// Flags for common MLIR configurations.
struct MlirCommonFlags {
bool tf_mlir_enable_mlir_bridge;
};
// Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
@ -148,6 +153,8 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags();
MlirCommonFlags* GetMlirCommonFlags();
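A minimal sketch of how a caller might consult the new flag (the header path is assumed to be tensorflow/compiler/jit/flags.h):

```c++
// Sketch: reading the MLIR bridge flag introduced by this change.
#include "tensorflow/compiler/jit/flags.h"  // assumed header path

bool UseMlirBridge() {
  // GetMlirCommonFlags() parses TF_XLA_FLAGS on first use (absl::call_once),
  // so this reflects --tf_mlir_enable_mlir_bridge if it was set.
  return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
}
```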
// Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
//

View File

@ -108,8 +108,7 @@ class XlaExecutableClosure {
explicit XlaExecutableClosure(
xla::LocalClient* client, xla::LocalExecutable* executable,
const XlaCompiler::CompilationResult* compilation_result,
std::map<int, OptionalTensor> resource_var_snapshots,
int num_constant_args)
ResourceVarsSnapshot resource_var_snapshots, int num_constant_args)
: client_(client),
executable_(executable),
compilation_result_(compilation_result),
@ -124,7 +123,7 @@ class XlaExecutableClosure {
const XlaCompiler::CompilationResult* compilation_result() const {
return compilation_result_;
}
const std::map<int, OptionalTensor>& resource_var_snapshots() const {
const ResourceVarsSnapshot& resource_var_snapshots() const {
return resource_var_snapshots_;
}
int num_constant_args() const { return num_constant_args_; }
@ -133,7 +132,7 @@ class XlaExecutableClosure {
xla::LocalClient* client_;
xla::LocalExecutable* executable_;
const XlaCompiler::CompilationResult* compilation_result_;
std::map<int, OptionalTensor> resource_var_snapshots_;
ResourceVarsSnapshot resource_var_snapshots_;
int num_constant_args_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
@ -276,10 +275,10 @@ static Status BuildCompilationCache(OpKernelContext* ctx,
static Status CompileToLocalExecutable(
OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
const XlaPlatformInfo& platform_info,
absl::Span<VariableInfo const> variable_infos,
absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
std::map<int, OptionalTensor>* variables,
const XlaCompiler::CompilationResult** kernel,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable) {
// We store information about the JIT-compiled XLA computation
// in the ResourceMgr.
@ -299,7 +298,6 @@ static Status CompileToLocalExecutable(
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables));
*client = static_cast<xla::LocalClient*>(cache->client());
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
@ -337,11 +335,11 @@ static Status CompileToLocalExecutable(
std::vector<XlaCompiler::Argument> args;
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_args, *variables, ctx, &args));
constant_args, variable_infos, ctx, &args));
return cache->Compile(options, function, args, compile_options,
lazy ? XlaCompilationCache::CompileMode::kLazy
: XlaCompilationCache::CompileMode::kStrict,
kernel, executable);
compilation_result, executable);
}
void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
@ -349,16 +347,22 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
xla::LocalClient* client;
const XlaCompiler::CompilationResult* kernel;
const XlaCompiler::CompilationResult* compilation_result;
xla::LocalExecutable* executable;
std::map<int, OptionalTensor> variables;
ResourceVarsSnapshot variables;
{
std::vector<VariableInfo> variable_infos;
OP_REQUIRES_OK(
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
Status s = CompileToLocalExecutable(
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
resources_, constants_, /*lazy=*/false, &client, &variables, &kernel,
&executable);
variable_infos, constants_, /*lazy=*/false, &client,
&compilation_result, &executable);
OP_REQUIRES_OK(ctx, s);
OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
variable_infos, &variables));
}
se::Stream* stream =
@ -373,7 +377,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
client, allocator,
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
platform_info_.UseMultipleStreams());
launch_context.PopulateInputs(ctx, kernel, variables,
launch_context.PopulateInputs(ctx, compilation_result, variables,
/*missing_ctx_input_prefix=*/0);
// Execute the computation.
@ -413,7 +417,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
executable->executable()->module().input_output_alias_config();
OP_REQUIRES_OK(
ctx, launch_context.PopulateOutputs(
ctx, kernel, run_result.ConsumeValueOrDie(),
ctx, compilation_result, run_result.ConsumeValueOrDie(),
/*missing_ctx_input_prefix=*/0, input_output_alias, variables));
VLOG(1) << "Done";
}
@ -494,7 +498,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
xla::LocalClient* client;
const XlaCompiler::CompilationResult* kernel;
xla::LocalExecutable* executable;
std::map<int, OptionalTensor> variables;
ResourceVarsSnapshot variables;
bool cannot_compile_cluster;
{
@ -506,9 +510,16 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
cannot_compile_cluster) {
executable = nullptr;
} else {
std::vector<VariableInfo> variable_infos;
OP_REQUIRES_OK(
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
Status status = CompileToLocalExecutable(
ctx, function_, has_ref_vars_, platform_info_, resources_, constants_,
/*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
ctx, function_, has_ref_vars_, platform_info_, variable_infos,
constants_,
/*lazy=*/!must_compile_, &client, &kernel, &executable);
OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
variable_infos, &variables));
if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
OP_REQUIRES_OK(ctx, status);
}

View File

@ -1837,7 +1837,7 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
"ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse",
"ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV",
"StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
"Tile", "Transpose", "InvertPermutation", "Unpack"}}};
"Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex"}}};
// clang-format on
return result;
}

View File

@ -28,32 +28,23 @@ limitations under the License.
namespace tensorflow {
namespace {
std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
std::map<int, OptionalTensor> variables;
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
// Returns argument indices corresponding to the resource variable inputs of
// the kernel context `ctx`.
static std::vector<int> GetResourceVariableIndices(OpKernelContext* ctx) {
std::vector<int> out;
for (int64 i = 0; i < ctx->num_inputs(); i++) {
if (ctx->input(i).dtype() == DT_RESOURCE) {
core::RefCountPtr<Var> variable;
ResourceHandle handle = HandleFromInput(ctx, i);
OptionalTensor& optional = variables[i];
optional.name = handle.name();
if (LookupResource(ctx, handle, &variable).ok()) {
tf_shared_lock lock(*variable->mu());
optional.present = true;
optional.value = *variable->tensor();
}
out.push_back(i);
}
}
return variables;
return out;
}
} // namespace
Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
const XlaDevice::Metadata& metadata,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable) {
std::map<int, OptionalTensor> variables = GetVariables(ctx);
xla::LocalExecutable* executable,
const ResourceVarsSnapshot& variable_args) {
xla::LocalClient* client = metadata.client();
// Builds an XLA allocator for the device.
@ -62,7 +53,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
/*allocate_xla_tensors=*/true,
/*use_multiple_streams=*/metadata.UseMultipleStreams());
launch_context.PopulateInputs(ctx, result, variables,
launch_context.PopulateInputs(ctx, result, variable_args,
/*missing_ctx_input_prefix=*/0);
se::Stream* stream =
@ -87,7 +78,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
executable->executable()->module().input_output_alias_config();
TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
ctx, result, run_result.ConsumeValueOrDie(),
/*missing_ctx_input_prefix=*/0, input_output_alias, variables));
/*missing_ctx_input_prefix=*/0, input_output_alias, variable_args));
return Status::OK();
}
@ -115,7 +106,7 @@ Status XlaCompileOnDemandOp::ShouldArgumentBeConstant(
Status XlaCompileOnDemandOp::Compile(
OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
const XlaCompiler::CompilationResult** result,
xla::LocalExecutable** executable) {
ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable) {
std::map<int, Tensor> constant_arguments;
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
const Tensor& device_tensor = ctx->input(i);
@ -190,12 +181,18 @@ Status XlaCompileOnDemandOp::Compile(
// rather than a one-element tuple.
compile_options.always_return_tuple = false;
std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
std::vector<int> variables_indices = GetResourceVariableIndices(ctx);
std::vector<XlaCompiler::Argument> args;
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_arguments, variable_args, ctx, &args));
{
std::vector<VariableInfo> variable_infos;
TF_RETURN_IF_ERROR(
GetVariableInfosFromCtxInputs(ctx, variables_indices, &variable_infos));
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
TF_RETURN_IF_ERROR(SnapshotResourceVariables(
ctx, variables_indices, variable_infos, variable_args));
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_arguments, variable_infos, ctx, &args));
}
return cache->CompileSingleOp(options, args, ctx, compile_options, result,
executable);
@ -206,8 +203,10 @@ void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
xla::LocalExecutable* executable;
const XlaDevice::Metadata* metadata;
OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
OP_REQUIRES_OK(ctx, Compile(ctx, *metadata, &result, &executable));
OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable));
ResourceVarsSnapshot variable_args;
OP_REQUIRES_OK(ctx,
Compile(ctx, *metadata, &result, &variable_args, &executable));
OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable, variable_args));
}
} // namespace tensorflow

View File

@ -20,6 +20,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/framework/function.h"
@ -47,10 +48,12 @@ class XlaCompileOnDemandOp : public OpKernel {
bool* result);
Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
const XlaCompiler::CompilationResult** result,
ResourceVarsSnapshot* variable_args,
xla::LocalExecutable** executable);
Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable);
xla::LocalExecutable* executable,
const ResourceVarsSnapshot& variable_args);
};
} // namespace tensorflow

View File

@ -52,7 +52,8 @@ const char kPossibleNonVariableResourceHintMessage[] =
"resource inputs to XLA.";
} // anonymous namespace
VariableInfo::VariableInfo(int index, Var* var) : index_(index), var_(var) {}
VariableInfo::VariableInfo(int index, absl::string_view name, Var* var)
: index_(index), name_(name), var_(var) {}
VariableInfo::VariableInfo(VariableInfo&& other)
: index_(other.index_), var_(other.var_), lock_held_(other.lock_held_) {
other.index_ = -1;
@ -87,16 +88,15 @@ VariableInfo::~VariableInfo() {
// Returns a vector of VariableInfo instances for the resource variable inputs
// to the kernel with context `ctx`. The input indices for the resource
// variable inputs are in `variable_indices`.
static Status GetVariableInfosFromCtxInputs(
OpKernelContext* ctx, absl::Span<const int> variable_indices,
std::vector<VariableInfo>* result) {
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
std::vector<VariableInfo>* result) {
std::vector<const ResourceHandle*> resource_handles;
absl::c_transform(
variable_indices, std::back_inserter(resource_handles),
[&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });
std::vector<core::RefCountPtr<Var>> variables;
Status s = LookupResources(ctx, resource_handles, &variables);
if (!s.ok()) {
errors::AppendToMessage(&s, kPossibleNonVariableResourceHintMessage);
@ -109,7 +109,9 @@ static Status GetVariableInfosFromCtxInputs(
// *Release* the variable because we're going to unref it later in
// ~VariableInfo.
Var* variable = variables[i].release();
result->emplace_back(variable_indices[i], variable);
int input_idx = variable_indices[i];
std::string var_name = HandleFromInput(ctx, input_idx).name();
result->emplace_back(input_idx, var_name, variable);
}
return Status::OK();
@ -162,21 +164,12 @@ Status LockVariables(absl::Span<VariableInfo> variables) {
Status SnapshotResourceVariables(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
std::map<int, OptionalTensor>* result) {
std::vector<VariableInfo> variable_infos;
TF_RETURN_IF_ERROR(
GetVariableInfosFromCtxInputs(ctx, variable_indices, &variable_infos));
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
absl::Span<VariableInfo const> variable_infos,
ResourceVarsSnapshot* result) {
for (int i = 0; i < variable_indices.size(); i++) {
if (variable_infos[i].var()) {
OptionalTensor& tensor = (*result)[variable_indices[i]];
tensor.name = HandleFromInput(ctx, variable_indices[i]).name();
tensor.present = true;
tensor.value = *variable_infos[i].var()->tensor();
} else {
(*result)[variable_indices[i]] = OptionalTensor();
}
Var* var = variable_infos[i].var();
(*result)[variable_indices[i]] =
var ? absl::make_optional(*var->tensor()) : absl::nullopt;
}
return Status::OK();
}
@ -195,53 +188,45 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
}
void XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
const std::map<int, OptionalTensor>& variables,
int missing_ctx_input_prefix) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
const ResourceVarsSnapshot& variables, int missing_ctx_input_prefix) {
// Build ShapedBuffers that point directly to the Tensor buffers.
arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
arg_ptrs_ =
std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
// Pass remaining parameters.
const Tensor* t;
for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
int arg_num = kernel->input_mapping[i];
DCHECK_GE(arg_num, missing_ctx_input_prefix);
const xla::Shape& shape = kernel->xla_input_shapes[i];
if (variables.count(arg_num)) {
t = &(variables.at(arg_num).value);
CHECK(t);
} else {
t = &(ctx->input(arg_num - missing_ctx_input_prefix));
}
xla::TransferManager* transfer_manager =
client_->backend().transfer_manager();
for (int i = 0; i < compilation_result->xla_input_shapes.size(); ++i) {
int arg_num = compilation_result->input_mapping[i];
CHECK_GE(arg_num, missing_ctx_input_prefix);
const xla::Shape& shape = compilation_result->xla_input_shapes[i];
const Tensor* t = variables.count(arg_num)
? &(variables.at(arg_num).value())
: &(ctx->input(arg_num - missing_ctx_input_prefix));
CHECK(t);
if (use_multiple_streams_) {
CHECK(stream) << "Must have a stream available when using XLA tensors!";
CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
<< "Must have a stream available when using XLA tensors!";
XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
CHECK(xla_tensor);
xla_tensor->WaitForDefinitionEventOnStream(stream);
xla_tensor->WaitForDefinitionEventOnStream(
ctx->op_device_context()->stream());
}
const xla::Shape on_device_shape =
client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
if (on_device_shape.IsTuple()) {
const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
} else {
CHECK(xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape,
on_device_shape))
<< "On-device shape "
<< xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
<< " not the same as on-host shape "
<< xla::ShapeUtil::HumanStringWithLayout(shape);
if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(
shape, transfer_manager->HostShapeToDeviceShape(shape))) {
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
arg_buffers_.emplace_back(
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
client_->platform(), client_->default_device_ordinal());
arg_buffers_.back().set_buffer(dmem, /*index=*/{});
arg_ptrs_[i] = &arg_buffers_.back();
} else {
const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
}
}
}
@ -269,7 +254,7 @@ static const Tensor* FindAliasedTensorForOutput(
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
absl::Span<const int> input_mapping,
const std::map<int, OptionalTensor>& resource_var_snapshots) {
const ResourceVarsSnapshot& resource_var_snapshots) {
if (MustAliasOutput(input_output_alias, output_num)) {
int xla_param = input_output_alias.GetAliasedParameter({output_num})
.value()
@ -281,8 +266,8 @@ static const Tensor* FindAliasedTensorForOutput(
// entry time.
if (input_tensor->dtype() == DT_RESOURCE) {
auto& v = resource_var_snapshots.at(missing_ctx_input_prefix + tf_param);
CHECK(v.present);
return &v.value;
CHECK(v.has_value());
return &v.value();
}
return input_tensor;
}
@ -305,9 +290,9 @@ static Tensor GetOrCreateTensorForOutput(
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
absl::Span<const int> input_mapping,
const std::map<int, OptionalTensor>& resource_var_snapshots,
DataType output_dtype, const TensorShape& output_shape,
se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
const ResourceVarsSnapshot& resource_var_snapshots, DataType output_dtype,
const TensorShape& output_shape, se::DeviceMemoryBase output_buffer,
Allocator* output_allocator) {
if (const Tensor* aliased_tensor = FindAliasedTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
input_mapping, resource_var_snapshots)) {
@ -370,13 +355,94 @@ static Status SetBufferForResourceVarTensorUnderAllocateXlaTensors(
return Status::OK();
}
// Sets output `output_num` for `ctx` provided it is known at compile time.
static Status SetOutputForConstant(
OpKernelContext* ctx, se::Stream* stream,
const XlaCompiler::CompilationResult* compilation_result, int output_num) {
CHECK(compilation_result->outputs[output_num].is_constant);
// Output is a constant.
const Tensor& const_tensor =
compilation_result->outputs[output_num].constant_value;
Tensor* output_tensor;
const size_t total_bytes = const_tensor.TotalBytes();
if (stream && total_bytes > 0) {
// Copy host -> device. (Empty tensors don't have backing buffers.)
// Manually allocate memory using an XlaTensorBuffer so we can allocate
// as much memory as the device requires (as given by
// GetByteSizeRequirement). This avoids XlaTransferManager having to
// reallocate the device buffer later.
VLOG(1) << "Constant output tensor on device";
TF_RETURN_IF_ERROR(
ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
Device* device = dynamic_cast<Device*>(ctx->device());
if (device == nullptr) {
return errors::Internal("DeviceBase was not a Device.");
}
ctx->op_device_context()->CopyCPUTensorToDevice(
&const_tensor, device, output_tensor,
[&](Status status) { TF_CHECK_OK(status); });
if (device->device_type() == DEVICE_GPU) {
// The GPUDeviceContext enqueues the host->device transfer in a
// separate stream from the main compute stream. We must ensure the
// compute stream is synchronized with the host->device transfer
// stream now otherwise we will create a race condition.
auto* gpu_device_context =
static_cast<GPUDeviceContext*>(ctx->op_device_context());
gpu_device_context->stream()->ThenWaitFor(
gpu_device_context->host_to_device_stream());
}
} else {
// No copy required.
ctx->set_output(output_num, const_tensor);
output_tensor = ctx->mutable_output(output_num);
}
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
xla_tensor->set_host_tensor(const_tensor);
}
return Status::OK();
}
// Creates a list of updated resource variables.
static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
int missing_ctx_input_prefix) {
std::vector<VariableInfo> variable_infos;
variable_infos.reserve(compilation_result->resource_updates.size());
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& write =
compilation_result->resource_updates[i];
int actual_input_index = write.input_index - missing_ctx_input_prefix;
if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
return errors::Internal("Invalid input index for variable write.");
}
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
Var* variable = nullptr;
const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(ctx, handle, &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
}));
variable_infos.emplace_back(actual_input_index, handle.name(), variable);
}
return variable_infos;
}
Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
ScopedShapedBuffer output, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
const std::map<int, OptionalTensor>& resource_var_snapshots) {
const ResourceVarsSnapshot& resource_var_snapshots) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
Allocator* allocator = ctx->device()->GetAllocator({});
// Computation output should always be a tuple.
if (VLOG_IS_ON(2)) {
@ -384,7 +450,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
VLOG(2) << "Result tuple shape (on device): "
<< output.on_device_shape().DebugString();
}
CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());
// If the on-host-shape isn't a tuple, create a new single-element tuple
// buffer with a nullptr root index table. This allows the code below to treat
@ -410,89 +476,70 @@ Status XlaComputationLaunchContext::PopulateOutputs(
stream->ThenRecordEvent(definition_event.get());
}
std::vector<TensorShape> output_tensor_shapes;
output_tensor_shapes.reserve(ctx->num_outputs());
if (output.on_host_shape().is_dynamic()) {
TF_ASSIGN_OR_RETURN(
auto transfer_manager,
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
xla::Shape output_host_shape = output.on_host_shape();
xla::Shape output_device_shape = output.on_device_shape();
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
stream, &output, &output_host_shape, &output_device_shape));
output.set_shapes(output_host_shape, output_device_shape);
for (int i = 0; i < ctx->num_outputs(); ++i) {
const xla::Shape& subshape =
xla::ShapeUtil::GetSubshape(output_host_shape, {i});
TensorShape shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
output_tensor_shapes.push_back(shape);
}
} else {
for (int i = 0; i < ctx->num_outputs(); ++i) {
output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
}
}
// Copy XLA results to the OpOutputList.
int output_num = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({});
if (kernel->outputs[i].is_constant) {
// Output is a constant.
const Tensor& const_tensor = kernel->outputs[i].constant_value;
Tensor* output_tensor;
const size_t total_bytes = const_tensor.TotalBytes();
if (stream && total_bytes > 0) {
// Copy host -> device. (Empty tensors don't have backing buffers.)
// Manually allocate memory using an XlaTensorBuffer so we can allocate
// as much memory as the device requires (as given by
// GetByteSizeRequirement). This avoids XlaTransferManager having to
// reallocate the device buffer later.
VLOG(1) << "Constant output tensor on device";
const TensorShape& shape = output_tensor_shapes[i];
const DataType& type = compilation_result->outputs[i].type;
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
<< DataTypeString(type);
if (type == DT_VARIANT) {
return errors::Unimplemented(
"Support for TensorList crossing the XLA/TF boundary "
"is not implemented");
}
TF_RETURN_IF_ERROR(
ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
Device* device = dynamic_cast<Device*>(ctx->device());
if (device == nullptr) {
return errors::Internal("DeviceBase was not a Device.");
}
ctx->op_device_context()->CopyCPUTensorToDevice(
&const_tensor, device, output_tensor,
[&](Status status) { TF_CHECK_OK(status); });
if (device->device_type() == DEVICE_GPU) {
// The GPUDeviceContext enqueues the host->device transfer in a
// separate stream from the main compute stream. We must ensure the
// compute stream is synchronized with the host->device transfer
// stream now otherwise we will create a race condition.
auto* gpu_device_context =
static_cast<GPUDeviceContext*>(ctx->op_device_context());
gpu_device_context->stream()->ThenWaitFor(
gpu_device_context->host_to_device_stream());
}
} else {
// No copy required.
ctx->set_output(i, const_tensor);
output_tensor = ctx->mutable_output(i);
}
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
xla_tensor->set_host_tensor(const_tensor);
}
if (compilation_result->outputs[i].is_constant) {
TF_RETURN_IF_ERROR(
SetOutputForConstant(ctx, stream, compilation_result, i));
} else if (type == DT_RESOURCE) {
int input_index =
compilation_result->outputs[i].input_index - missing_ctx_input_prefix;
TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
<< "Invalid input for outputs " << i << ": " << input_index;
ctx->set_output(i, ctx->input(input_index));
} else {
const TensorShape& shape = kernel->outputs[i].shape;
const DataType& type = kernel->outputs[i].type;
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
<< DataTypeString(type);
if (type == DT_RESOURCE) {
int input_index =
kernel->outputs[i].input_index - missing_ctx_input_prefix;
TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
<< "Invalid input for outputs " << i << ": " << input_index;
ctx->set_output(i, ctx->input(input_index));
} else {
if (MustAliasOutput(input_output_alias, output_num)) {
DCHECK(output.buffer({output_num}).is_null())
<< "Expected output buffer to be aliased, but it is not nil.";
}
if (allocate_xla_tensors_) {
TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
input_output_alias, output_num, ctx, i, shape, &output,
definition_event, stream, use_multiple_streams_));
} else {
if (type == DT_VARIANT) {
return errors::Unimplemented(
"Support for TensorList crossing the XLA/TF boundary "
"is not implemented");
}
if (allocate_xla_tensors_) {
TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
input_output_alias, output_num, ctx, i, shape, &output,
definition_event, stream, use_multiple_streams_));
se::DeviceMemoryBase buffer = output.buffer({output_num});
Tensor output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
kernel->input_mapping, resource_var_snapshots,
ctx->expected_output_dtype(i), shape, buffer, allocator);
output.set_buffer(se::OwningDeviceMemory(), {output_num});
ctx->set_output(i, output_tensor);
}
++output_num;
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
Tensor output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
compilation_result->input_mapping, resource_var_snapshots,
ctx->expected_output_dtype(i), shape, buffer, allocator);
output.set_buffer(se::OwningDeviceMemory(), {output_num});
ctx->set_output(i, output_tensor);
}
++output_num;
}
if (VLOG_IS_ON(3)) {
@ -502,34 +549,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
// Apply variable updates, if any.
VLOG(2) << "Applying variable updates";
std::vector<VariableInfo> variable_infos;
variable_infos.reserve(kernel->resource_updates.size());
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
int actual_input_index = write.input_index - missing_ctx_input_prefix;
if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
return errors::Internal("Invalid input index for variable write.");
}
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
Var* variable = nullptr;
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
ctx, HandleFromInput(ctx, actual_input_index), &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
}));
variable_infos.emplace_back(actual_input_index, variable);
}
TF_ASSIGN_OR_RETURN(
std::vector<VariableInfo> variable_infos,
GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix));
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({});
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& write =
compilation_result->resource_updates[i];
if (variable_infos[i].var()->tensor()->dtype() != write.type) {
return errors::Internal("Mismatched type in variable write");
}
@ -543,7 +570,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
output.set_buffer(se::OwningDeviceMemory(), {output_num});
Tensor output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
kernel->input_mapping, resource_var_snapshots, write.type,
compilation_result->input_mapping, resource_var_snapshots, write.type,
write.shape, buffer, allocator);
*variable_infos[i].var()->tensor() = output_tensor;
variable_infos[i].var()->is_initialized |= write.modified;
@ -555,12 +582,21 @@ Status XlaComputationLaunchContext::PopulateOutputs(
Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
std::vector<XlaCompiler::Argument>* args) {
args->resize(ctx->num_inputs());
absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
for (const VariableInfo& info : variable_args) {
CHECK(!info.var() || info.lock_held())
<< "Need to hold the lock on resource variables "
"before calling BuildXlaCompilerArguments";
variable_info_lookup.emplace(info.index(), &info);
}
for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) {
XlaCompiler::Argument& arg = (*args)[input_num];
if (constant_args.count(input_num) > 0) {
// Handles compile-time constants.
const Tensor& input = constant_args.at(input_num);
@ -569,7 +605,7 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
arg.type = input.dtype();
arg.shape = input.shape();
arg.constant_value = input;
} else if (variable_args.count(input_num) == 0) {
} else if (variable_info_lookup.count(input_num) == 0) {
// Handles the non-constant arguments.
const Tensor& input = ctx->input(input_num);
TF_RET_CHECK(input.dtype() != DT_RESOURCE);
@ -585,14 +621,14 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
// Handles resource variables.
const Tensor& input = ctx->input(input_num);
TF_RET_CHECK(input.dtype() == DT_RESOURCE);
const OptionalTensor& variable = variable_args.at(input_num);
arg.name = variable.name;
const VariableInfo& variable = *variable_info_lookup[input_num];
arg.name = std::string(variable.name());
arg.kind = XlaCompiler::Argument::kResource;
arg.resource_kind = XlaResource::kVariable;
if (variable.present) {
const Tensor& value = variable.value;
arg.type = value.dtype();
arg.shape = value.shape();
if (variable.var()) {
const Tensor* value = variable.var()->tensor();
arg.type = value->dtype();
arg.shape = value->shape();
arg.initialized = true;
} else {
// The values of uninitialized variables are not passed as inputs, since

View File

@ -34,36 +34,17 @@ limitations under the License.
namespace tensorflow {
// Struct that represents a possibly-absent Tensor.
struct OptionalTensor {
string name; // A descriptive name
bool present = false; // Is the tensor present?
Tensor value; // If present, what is the Tensor's value?
};
// Takes a snapshot of the values of resource variable arguments, whose indices
// are specified in the `variable_indices` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
// important that the shapes used for compilation match the true shapes of the
// buffers.
//
// We snapshot the entire set of resource variables as one atomic operation.
// This models Read->* dependencies between resource variable operations. See
// jit/resource_operation_safety_analysis for details.
//
// Returns a map of TensorFlow argument index to resource variable. If a
// resource variable is not initialized, the corresponding OptionalTensor
// will have its `present` field set to false.
Status SnapshotResourceVariables(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
std::map<int, OptionalTensor>* result);
// Snapshot of resource variables for a TF kernel invocation, mapping from
// parameter number to values at execution time. If the resource variable is not
// initialized, the value will not be present.
using ResourceVarsSnapshot = absl::flat_hash_map<int, absl::optional<Tensor>>;
// Information about the state of a variable passed as input to the _XlaCompile
// and _XlaRun operators. Unlocks the resource variable and decrements its
// refcount on destruction.
class VariableInfo {
public:
explicit VariableInfo(int index, Var* var);
explicit VariableInfo(int index, absl::string_view name, Var* var);
VariableInfo(VariableInfo&& other);
VariableInfo& operator=(VariableInfo&& other);
@ -79,6 +60,9 @@ class VariableInfo {
// "empty", i.e. it does not track a resource variable.
Var* var() const { return var_; }
// Returns the variable name.
absl::string_view name() const { return name_; }
// Returns true if the resource variable lock was successfully acquired by
// this thread.
bool lock_held() const { return lock_held_; }
@ -88,6 +72,7 @@ class VariableInfo {
private:
int index_;
std::string name_;
Var* var_;
// We can't use an optional<mutex_lock> here because it confuses the compiler's
@ -96,6 +81,20 @@ class VariableInfo {
bool lock_held_ = false;
};
// Takes a snapshot of the values of resource variable arguments, whose indices
// are specified in the `variable_indices` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
// important that the shapes used for compilation match the true shapes of the
// buffers.
//
// We snapshot the entire set of resource variables as one atomic operation.
// This models Read->* dependencies between resource variable operations. See
// jit/resource_operation_safety_analysis for details.
Status SnapshotResourceVariables(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
absl::Span<VariableInfo const> variable_infos,
ResourceVarsSnapshot* result);
// Acquires the mutexes for all the variables in `variables` using a
// deadlock-safe protocol (acquire the mutexes in increasing-address order).
//
@ -104,6 +103,13 @@ class VariableInfo {
Status LockVariables(absl::Span<VariableInfo> variables)
TF_EXCLUSIVE_LOCK_FUNCTION();
// Returns a vector of VariableInfo instances for the resource variable inputs
// to the kernel with context `ctx`. The input indices for the resource
// variable inputs are in `variable_indices`.
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
std::vector<VariableInfo>* result);
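Taken together, these declarations define the lookup-then-lock-then-snapshot protocol that the updated launch and compile ops follow. A condensed sketch (include paths assumed):

```c++
// Sketch: the lookup -> lock -> snapshot sequence used by the updated kernels.
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"  // assumed header for these declarations
#include "tensorflow/core/platform/errors.h"          // assumed home of TF_RETURN_IF_ERROR

tensorflow::Status SnapshotForLaunch(tensorflow::OpKernelContext* ctx,
                                     absl::Span<const int> variable_indices,
                                     tensorflow::ResourceVarsSnapshot* snapshot) {
  std::vector<tensorflow::VariableInfo> variable_infos;
  TF_RETURN_IF_ERROR(tensorflow::GetVariableInfosFromCtxInputs(
      ctx, variable_indices, &variable_infos));
  // Mutexes are acquired in increasing-address order, so this cannot deadlock
  // against another kernel locking an overlapping set of variables.
  TF_RETURN_IF_ERROR(tensorflow::LockVariables(absl::MakeSpan(variable_infos)));
  // Snapshot while the locks are held so compiled shapes match execution-time buffers.
  return tensorflow::SnapshotResourceVariables(ctx, variable_indices,
                                               variable_infos, snapshot);
}
```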
// Helper class to perform the marshalling of TensorFlow inputs and outputs to
// ShapedBuffers suitable for passing to an XLA computation.
class XlaComputationLaunchContext {
@ -123,9 +129,10 @@ class XlaComputationLaunchContext {
// Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
// op.
// Precondition: variables in `variable_args` are locked.
static Status BuildXlaCompilerArguments(
const std::map<int, Tensor>& constant_args,
const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
std::vector<XlaCompiler::Argument>* args);
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
@ -136,8 +143,8 @@ class XlaComputationLaunchContext {
// input_mapping must be greater than or equal to `missing_ctx_input_prefix`
// (in other words, no inputs actually required by the kernel can be missing).
void PopulateInputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
const std::map<int, OptionalTensor>& variables,
const XlaCompiler::CompilationResult* compilation_result,
const ResourceVarsSnapshot& variables,
int missing_ctx_input_prefix);
// Given the XLA output in `output`, populate all outputs of `ctx`. Also
@ -148,13 +155,14 @@ class XlaComputationLaunchContext {
// See jit/resource_operation_safety_analysis for details.
//
//
// Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
// missing and adjusts input indices accordingly.
// Assumes that the first `missing_ctx_input_prefix` inputs to the
// compilation_result are missing and adjusts input indices accordingly.
Status PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
const std::map<int, OptionalTensor>& resource_var_snapshots);
const ResourceVarsSnapshot& resource_var_snapshots);
// Return the argument list. Only valid after PopulateInputs() has been
// called.

View File

@ -27,10 +27,6 @@ namespace tensorflow {
return xla_tensor;
}
/*static*/ bool XlaTensor::RefCountIsOne(const Tensor& tensor) {
return tensor.RefCountIsOne();
}
/*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor(
const Tensor& tensor) {
const XlaTensor* xla_tensor = FromTensor(&tensor);

View File

@ -39,8 +39,6 @@ class XlaTensor {
// fails.
static XlaTensor* FromTensor(const Tensor* tensor);
static bool RefCountIsOne(const Tensor& tensor);
// Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in
// which case the returned value is shaped_buffer()->root_buffer(), or a
// normal Tensor in which case the returned value is
@ -57,7 +55,7 @@ class XlaTensor {
// manage the memory for these tensors a ShapedBuffer may be required.
// Return true if this XlaTensor contains a ShapedBuffer.
bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; }
bool has_shaped_buffer() const { return shaped_buffer_.has_value(); }
// Return the contained ShapedBuffer.
// REQUIRES: has_shaped_buffer()
const xla::ShapedBuffer& shaped_buffer() const {
@ -70,8 +68,7 @@ class XlaTensor {
}
// Mutates the XlaTensor to set the ShapedBuffer.
void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
shaped_buffer_ =
absl::make_unique<xla::ScopedShapedBuffer>(std::move(shaped_buffer));
shaped_buffer_ = std::move(shaped_buffer);
}
// Some tensors on the device may have known values on the host. We use these
@ -79,14 +76,12 @@ class XlaTensor {
// host value already.
// Return true if this XlaTensor contains a host tensor.
bool has_host_tensor() const { return host_tensor_ != nullptr; }
bool has_host_tensor() const { return host_tensor_.has_value(); }
// Return the contained host tensor.
// REQUIRES: has_host_tensor()
const Tensor& host_tensor() const { return *host_tensor_; }
// Sets the contained host tensor.
void set_host_tensor(const Tensor& tensor) {
host_tensor_.reset(new Tensor(tensor));
}
void set_host_tensor(const Tensor& tensor) { host_tensor_.emplace(tensor); }
// Adds synchronization events to 'stream' that wait for this tensor to be
// defined on 'stream'. Does nothing if the tensor is already defined on that
@ -113,9 +108,9 @@ class XlaTensor {
private:
// The optional contained ShapedBuffer.
std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_;
absl::optional<xla::ScopedShapedBuffer> shaped_buffer_;
// An optional host tensor value.
std::unique_ptr<Tensor> host_tensor_;
absl::optional<Tensor> host_tensor_;
// An optional event that is triggered when the tensor's content has been
// defined. If this event is nullptr, it is assumed that the tensor's content
// is always defined.

View File

@ -30,7 +30,7 @@ cc_library(
hdrs = ["op_or_arg_name_mapper.h"],
deps = [
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -42,7 +42,7 @@ cc_library(
":init_mlir",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirOptLib",
@ -86,7 +86,7 @@ cc_library(
hdrs = ["init_mlir.h"],
deps = [
"//tensorflow/core:lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@ -102,7 +102,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:core_cpu",
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
@ -155,7 +155,7 @@ tf_cc_binary(
"//tensorflow/core:tensorflow",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",

View File

@ -0,0 +1,265 @@
# MLIR CodeGen for XLA
<!--*
# Document freshness: For more information, see go/fresh-source.
freshness: { owner: 'timshen' reviewed: '2020-06-16' }
*-->
XLA operates on `HloInstruction` and performs many optimizations on this
representation, sharing a lot of these between targeted devices. At some point a
linear schedule is computed and a memory buffer is assigned to each value
statically. The device-specific codegen operates by traversing this sequence and
calling "emitters" to generate a representation suitable for the device (for
example a single LLVM function per XLA computation on CPU, or a sequence of
"thunks" encapsulating GPU operations and possibly generated PTX when targeting
GPU).
As a staging step, we're currently in the process of intercepting XLA right
after it completes the buffer-assignment phase and emitting an MLIR module in
the `lhlo` dialect instead. From there we perform the codegen using MLIR
components (mainly Linalg, affine, and the GPU dialect) depending on the device.
Below is the plan of record to incrementally migrate XLA/GPU by using `lhlo` as
the codegen input.
## Tasks
| | Host | Device
| ------------- | ------------------------ | ------------------------
| Input format | HloInstruction* (Task 1) | HloInstruction* (Task 1)
| Output format | xla::Thunk (Task 2) | LLVM IR (Task 3)
* **Task 1** changes both host and device input format from HloInstruction* to
LHLO.
* **Task 2** changes output format of host from thunks to "some landing pad
for host" (see below).
* **Task 3** migrates device output from LLVM IR to some form of MLIR. It's
    optional for this project; see the section "Migrating Device LLVM IR" for
    details.
This project prioritizes having end-to-end runnable models with LHLO-emitters
enabled as much as possible. This implies the following ordered list of
objectives, by priority:
* Make XLA/GPU runnable with LHLO emitters, with existing Thunks and emitters
unmodified.
* Eliminate the references to HloInstruction\* in LHLO, case by case:
* Switch a legacy emitter to an MLIR-based emitter (e.g. Linalg), or
* Mechanically translate the existing emitter to take MLIR representation
(migrate to Standard with GPU Dialect).
## Migrating Thunks (Task 2)
xla::gpu::Thunk is a data structure that:
* Can be called into from the host (xla::gpu::Thunk::ExecuteOnStream()).
* Carries various data in its subclasses.
* Interacts with BufferAllocation::Slice and StreamExecutor.
* Launches kernels.
* Calls into all runtime libraries.
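For orientation, here is a minimal, self-contained sketch of that shape. The
names only loosely mirror the real classes under
tensorflow/compiler/xla/service/gpu/; the stand-in `Status` and `Stream` types
are placeholders for the real TensorFlow and StreamExecutor types.

```c++
#include <memory>
#include <utility>
#include <vector>

// Stand-in types so the sketch is self-contained; the real code uses
// tensorflow::Status, se::Stream and BufferAllocation::Slice.
struct Status { bool ok = true; };
struct Stream {};

// Sketch of the Thunk shape: a host-callable unit of work that carries its
// own data and knows how to enqueue itself on a stream.
class Thunk {
 public:
  virtual ~Thunk() = default;
  virtual Status ExecuteOnStream(Stream* stream) = 0;
};

// Control-flow-style thunk: runs a sequence of child thunks on the host.
class SequentialThunk : public Thunk {
 public:
  explicit SequentialThunk(std::vector<std::unique_ptr<Thunk>> thunks)
      : thunks_(std::move(thunks)) {}
  Status ExecuteOnStream(Stream* stream) override {
    for (auto& thunk : thunks_) {
      Status s = thunk->ExecuteOnStream(stream);
      if (!s.ok) return s;
    }
    return Status{};
  }

 private:
  std::vector<std::unique_ptr<Thunk>> thunks_;
};
```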
The cost of this migration includes:
* Representing op-specific configuration data (e.g. convolution configs).
* Migrating op shape and operand shapes.
* Representing a tree of thunks (while, condition, etc).
The migration work is independent of the LHLO / emitter migration. Under limited
resources, it's prioritized behind the LHLO / emitter migration.
We have several choices on how to lower the host-side part from LHLO:
* TFRT
* (Pro) great CUDA and HIP wrappers for use.
* (Pro) easy to implement library calls (cuDNN, cuBLAS, cuFFT, etc), as
TFRT ops are interpreted by C++ code.
* (Con) host side is under development and not tested.
    * (Con) the JAX integration isn't clear from a runtime point of view.
* Jitted CPU code
    * (Pro) easy to lower to: create a few loops and conditions and it's
      done.
* (Con) GPUDialect doesn't yet model chains/streams/asynchronicity/device
allocation.
* (Con) CUDA / HIP runtime support is minimal (toolkit path, version,
dynamic loading, etc).
* Existing (interpreting) XLA runtime
Tentative conclusion: Use jitted CPU code during the transition, and optionally
adopt TFRT in the end.
## Migrating Device LLVM IR (Task 3)
An elemental emitter generates the target op by filling it in element by
element. Each output element depends on a set of elements from the operands. All
elements are described by combining the buffer with dynamic indices. This is
sufficient to describe almost all "math" ops, but for performance reasons only a
large subset of "math" ops are implemented directly in
(Cpu|Gpu)ElementalIrEmitter.
ElementalIrEmitter is unique in that:
* A large portion of the code is shared between XLA/GPU and CPU.
* It represents a large portion of ops seen in models, including all
element-wise ops.
* Most fusions solely depend on ElementalIrEmitter.
* It's structurally simple, as it describes a data dependency DAG between op
elements and operand elements.
* It's mostly portable and high-level (e.g. unlike GPU kReduce and GPU kCopy).
* Dynamic shape support is easy for at least element-wise ops.
Now, for every XLA op, whether elementally emitted or not, there are several
flavors of the end state:
1. Device code stays as LLVM IR.
1. Refactor the old emitter to be like LHLO -> MLIR LLVM Dialect:
* (Cost) Will be throw-away work if we want to ultimately migrate to
Standard.
* (Benefit) It is easy and mechanical. Can be done in a short period.
    * (Benefit) It doesn't provide more benefit than option 1.
1. Refactor old emitters to be like LHLO -> MLIR GPU + Standard + Loops:
* (Cost) Lifting existing emitters to Standard introduces some challenges.
Pointers and GEPs need to be converted to MemRefs and SubViews. Ensuring
amdgpu completeness is another one.
    * (Cost) XLA/GPU heavily relies on LLVM metadata (see the sketch after
      this list):
* `range` for block/thread indices.
* `align`, `dereferenceable`, `invariant.load`, `alias.scope`,
`noalias` for load/stores.
* `llvm.loop.unroll.disable`, `llvm.loop.unroll.full`,
`llvm.loop.vectorize.enable` for sequential loops.
* (Benefit) Can be long-term. More portable.
1. Refactor old emitters to be LHLO -> Linalg, and write new Linalg emitters
* (Cost) This is case by case. Compared to previous options, a new
implementation that matches XLA's performance needs to go through the
benchmark <-> optimize workflow, which can be a significant cost for
some ops.
    * (Benefit) unified stack; community support; portability; more
      optimization potential.
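To make the metadata cost above concrete, here is a minimal sketch of how such
metadata can be attached when emitting LLVM IR. The helper functions and the
[0, 1024) range are illustrative, not XLA's actual emitter code.

```c++
#include "llvm/ADT/APInt.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"

// Mark a load as invariant so LLVM may freely hoist or CSE it.
void AnnotateInvariantLoad(llvm::LoadInst* load, llvm::LLVMContext& ctx) {
  load->setMetadata(llvm::LLVMContext::MD_invariant_load,
                    llvm::MDNode::get(ctx, {}));
}

// Tell LLVM that a thread-index value lies in [0, 1024), which enables
// better index arithmetic and bounds reasoning.
void AnnotateThreadIndexRange(llvm::Instruction* thread_idx,
                              llvm::LLVMContext& ctx) {
  llvm::MDBuilder md(ctx);
  thread_idx->setMetadata(
      llvm::LLVMContext::MD_range,
      md.createRange(llvm::APInt(32, 0), llvm::APInt(32, 1024)));
}
```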
## Prioritization
While all three tasks mentioned above are parallelizable, under limited
resources they have to be serialized. The prioritization focuses on visible
results for completion of each task.
The prioritization is: Task 1 (LHLO for legacy emitters) > Task 2 (Thunks) >
Task 3 (MLIR emitters).
By the end of Task 1, users of XLA (e.g. the kernel generator) can generate
LHLO and execute it. The compilation format will not be serializable MLIR.
By the end of Task 2, LHLO lowers to proper, serializable MLIR. This enables
offline compilation.
By the end of Task 3, all XLA emitters are MLIR-based in their implementation.
## Detailed Design
### Step 1: (Task 1) Complete LHLO and Make Legacy Emitters Take LHLO
This step makes all existing XLA/GPU emitters interact with MLIR ops. This step
is pure refactoring and NFC.
This step is mostly mechanical, but it's worth noting the following
discrepancies between an unnested HloComputation and LHLO:
* Each HloInstruction has direct access to its operands (a data-flow DAG). In
  contrast, each LHLO op only has access to its operand buffers (a bipartite
  graph between ops and buffers). LHLO ops have to go through use-def chains to
  access their operand ops (see the sketch after this list).
* Unnested legacy emitters empirically almost never access their operands. The
only exception is kReduce.
* Unnested legacy emitters access BufferAssignment only for getting slices,
not for accessing aux data structures like dataflow\_analysis() or
alias\_analysis(). llvm\_ir builds its own alias\_analysis() based on slice
information.
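A sketch of what that access pattern looks like with the generic MLIR C++ API;
`FindProducer` is a hypothetical helper (real code would additionally consult
buffer aliasing information), not existing XLA code.

```c++
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"

// In LHLO an operand is a buffer, so the op that produced its contents is
// found by walking the buffer's users rather than by following an SSA def
// as HloInstruction::operand() would.
mlir::Operation* FindProducer(mlir::Value buffer, mlir::Operation* consumer) {
  for (mlir::Operation* user : buffer.getUsers()) {
    // Sketch only: assume any other user wrote the buffer; a real pass would
    // check write effects and aliasing.
    if (user != consumer) return user;
  }
  return nullptr;
}
```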
The conclusion is that LHLO should fit right in without major hassle.
### Step 2: (Optional) Profiling Support
**This step is only needed if we start to discard some of the XLA Thunk logic
(see the next step).**
Before actually turning on any MLIR-based emitters, we need profiling support
for them.
Currently XLA performs its own profiling by calling into StreamExecutor's timer.
Under the hood, the timer records two events, one before and one after the
kernel launch, and measures the elapsed time between them.
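For illustration, the same technique expressed against the raw CUDA runtime API
looks roughly like the following; XLA actually goes through StreamExecutor's
timer abstraction rather than calling CUDA directly.

```c++
#include <functional>
#include <cuda_runtime.h>

// Record an event before and after the enqueued work, then measure the
// elapsed time between the two events once the work has finished.
float TimeOnStreamMs(cudaStream_t stream, const std::function<void()>& launch) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start, stream);  // event before the kernel launch
  launch();                        // enqueue the kernel(s) on `stream`
  cudaEventRecord(stop, stream);   // event after the kernel launch
  cudaEventSynchronize(stop);      // wait for the work to complete
  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms;
}
```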
There are roughly three approaches to support profiling in MLIR:
* Run a profiler end-to-end
* Add a profile op for each op in LHLO, using an injected profiler.
The "end-to-end" approach is transparent to MLIR, but suffers the same problem
that makes XLA not use it in the first place: library calls collected by a
profiler (nvprof/...) can't easily relate to HLO ops. For example, cuDNN
launches multiple kernels for each HLO, and it's hard to tell which kernels
correspond to which HLO.
The "injected profiler" approach requires:
* LHLO to take a profiler as a parameter.
* inserting profile.start / profile.end before and after each op.
* a pass that lowers profile.{start,end} to a C++ implementation.
The exact profiling can't be easily done for MLIR-generated ops, since:
* MLIR doesn't have a timer, nor does it depend on TFRT / StreamExecutor.
* MLIR doesn't easily call into C functions with complicated parameters.
### Step 3: (Task 2) Migrating Thunks
This step migrates all host ops and library calls. This step will eliminate most
of the thunks and produce serializable MLIR instead.
There are roughly three kinds of thunks:
* KernelThunk, which launches a kernel.
* Control flow thunks, which have host control flow logic (conditional, while,
  for, sequence) and launch body kernels.
* Library thunks: cuDNN, cuBLAS, cuFFT, NCCL, etc.
The **bottom line** is to:
* Create a Thunk dialect that provides (de)serialize logic for all existing
C++-based Thunks.
* Change emitters to emit a graph of Thunk dialect.
**Optionally**, we can relieve some thunks of their C++ implementation.
KernelThunk can lower to the GPU LaunchKernelOp. Control flow thunks can
leverage the CFG Dialect for loops and conditions, combined with LaunchKernelOp.
This optional step requires profiling and stream support.
### Step 4: (Task 3) Migrated ElementalIrEmitter
Once profiling is ready, we can complete and tune all ElementalIrEmitter-based
emitters in MLIR. Then we turn them on by default, assuming that all of these
MLIR-based emitters use a single stream.
Notice that it's beneficial to migrate XLA/CPU's ElementalIrEmitter as well,
since they share a large portion of the code.
With all benchmarking and performance hunting done (TODO: define performance
parity), we turn on the new MLIR-based elemental emitter, and delete the legacy
ElementalIrEmitter.
This step also provides easy fusion transitions (nested ops) for the later
migration.
### Step 5: Multi-Stream Support or Drop
We can't delete
[some of the emitters](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/gpu/stream_assignment.cc#L140)
until we support multi-stream execution in MLIR, or until we drop the feature.
It's a relatively large amount of work in MLIR for a small amount of gain in
XLA. We should investigate the current users of multi-stream XLA/GPU, and try to
delete this feature if reasonable.
### Step 6: (Task 3) Migrated Device Ops
This step migrates all unnested ops, then we can delete all unnested emitters.
This calls for a rewrite/refactor of kCopy and kReduce. kReduce has already
received plenty of work, so the actual amount of work that still needs to be
done remains to be seen.

Some files were not shown because too many files have changed in this diff.