commit a28bac9244

Files changed: .bazelrc, .bazelversion, .github/bot_config.yml, .github/stale.yml,
README.md, RELEASE.md, WORKSPACE, configure.py, tensorflow/BUILD, tensorflow/c/BUILD,
tensorflow/c/c_api.h, tensorflow/c/c_api_experimental.cc, tensorflow/c/c_api_function.cc,
tensorflow/c/c_api_test.cc, tensorflow/c/tensor_interface.h, tensorflow/c/tf_tensor.cc,
tensorflow/c/tf_tensor.h, tensorflow/c/tf_tstring.h, tensorflow/c/eager/ (BUILD,
abstract_context.h, abstract_function.h, abstract_operation.h, abstract_tensor_handle.h,
c_api.cc, c_api_experimental.cc, c_api_experimental.h, c_api_experimental_test.cc,
c_api_unified_experimental.cc, c_api_unified_experimental.h,
c_api_unified_experimental_eager.cc, c_api_unified_experimental_graph.cc,
c_api_unified_experimental_internal.h, c_api_unified_experimental_test.cc,
immediate_execution_context.h, immediate_execution_operation.h,
immediate_execution_tensor_handle.h, env.cc, env.h, parallel_device/,
tfe_context_internal.h, tfe_op_internal.h, tfe_tensorhandle_internal.h),
tensorflow/c/experimental/filesystem/plugins/gcs,
tensorflow/c/experimental/saved_model, tensorflow/cc/ (gradients, saved_model),
tensorflow/compiler, and others.

--- a/.bazelrc
+++ b/.bazelrc
@@ -30,6 +30,7 @@
 # short_logs: Only log errors during build, skip warnings.
 # monolithic: Build all TF C++ code into a single shared object.
 # dynamic_kernels: Try to link all kernels dynamically (experimental).
+# libc++: Link against libc++ instead of stdlibc++
 #
 #
 # TF version options;
@@ -38,6 +39,7 @@
 #
 # Feature and Third party library support options:
 # xla: Build TF with XLA
+# tpu: Build TF with TPU support
 # using_cuda: CUDA is available to build system.
 # cuda: Build with full cuda support.
 # rocm: Build with AMD GPU support (rocm).
@@ -56,13 +58,12 @@
 #
 #
 # Remote build execution options (only configured to work with TF team projects for now.)
 # rbe: General RBE options shared by all flavors.
 # rbe_linux: General RBE options used on all linux builds.
 # rbe_win: General RBE options used on all windows builds.
 #
 # rbe_cpu_linux: RBE options to build with only CPU support.
-# rbe_linux_cuda_nvcc: RBE options to build with GPU support using nvcc.
-# rbe_gpu_linux: An alias for rbe_linux_cuda_nvcc
+# rbe_linux_cuda_nvcc_py*: RBE options to build with GPU support using nvcc.
 #
 # rbe_linux_py2: Linux Python 2 RBE config.
 # rbe_linux_py3: Linux Python 3 RBE config
@@ -79,6 +80,14 @@
 # elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.
 
+# Allow builds using libc++ as a linker library
+# This is mostly for OSSFuzz, so we also pass in the flags from environment to clean build file
+build:libc++ --action_env=CC
+build:libc++ --action_env=CXX
+build:libc++ --action_env=CXXFLAGS=-stdlib=libc++
+build:libc++ --action_env=PATH
+build:libc++ --define force_libcpp=enabled
+build:libc++ --linkopt -fuse-ld=lld
 
 # Android configs. Bazel needs to have --cpu and --fat_apk_cpu both set to the
 # target CPU to build transient dependencies correctly. See
@@ -171,6 +180,9 @@ build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
 # AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
 build:dbg --copt -DDEBUG_BUILD
 
+# Config to build TPU backend
+build:tpu --define=with_tpu_support=true
+
 build:tensorrt --action_env TF_NEED_TENSORRT=1
 
 build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
@@ -200,6 +212,8 @@ build:nogcp --define=no_gcp_support=true
 build:nohdfs --define=no_hdfs_support=true
 build:nonccl --define=no_nccl_support=true
 
+build:stackdriver_support --define=stackdriver_support=true
+
 build --define=use_fast_cpp_protos=true
 build --define=allow_oversize_protos=true
@@ -385,33 +399,48 @@ build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
 build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
 test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
 
-build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
-build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
-build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
-build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
-build:rbe_linux_cuda_nvcc --host_platform="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
-build:rbe_linux_cuda_nvcc --platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
-build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
-build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
-build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
-build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
-test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
+build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base
+build:rbe_linux_cuda10.1_nvcc_base --define=using_cuda_nvcc=true
+build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
+build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
+build:rbe_linux_cuda10.1_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda10.1_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda10.1_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
+build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
+build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
+build:rbe_linux_cuda10.1_nvcc_py2.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
+build:rbe_linux_cuda10.1_nvcc_py3.5 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
+build:rbe_linux_cuda10.1_nvcc_py3.6 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
+build:rbe_linux_cuda10.1_nvcc_py3.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
+build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
 
-build:rbe_linux_cuda_nvcc_base --config=rbe_linux_cuda_base
-build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
-build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
-build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
-build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
-build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
-build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
-build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
-build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
-build:rbe_linux_cuda_nvcc_base --define=using_cuda_nvcc=true
-build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
-build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
-build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
-build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
-build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
+build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base
+build:rbe_linux_cuda11.0_nvcc_base --define=using_cuda_nvcc=true
+build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
+build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64"
+build:rbe_linux_cuda11.0_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
+build:rbe_linux_cuda11.0_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
+build:rbe_linux_cuda11.0_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
+build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda"
+build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_tensorrt"
+build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_nccl"
+build:rbe_linux_cuda11.0_nvcc_py2.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python2.7"
+build:rbe_linux_cuda11.0_nvcc_py3.5 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.5"
+build:rbe_linux_cuda11.0_nvcc_py3.6 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.6"
+build:rbe_linux_cuda11.0_nvcc_py3.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.7"
+build:rbe_linux_cuda11.0_nvcc_py3.8 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.8"
+
+# Map default to CUDA 10.1.
+build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda10.1_nvcc_py2.7
+build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda10.1_nvcc_py3.5
+build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda10.1_nvcc_py3.6
+build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda10.1_nvcc_py3.7
+build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda10.1_nvcc_py3.8
+
+# Deprecated configs that people might still use.
+build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_nvcc_py36
+build:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
 
 build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
 build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
@@ -429,8 +458,23 @@ build:rbe_linux_cuda_clang_py36 --config=rbe_linux_cuda_clang_base --repo_env=TF
 build:rbe_linux_cuda_clang_py37 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
 build:rbe_linux_cuda_clang_py38 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
 
-common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
+# ROCm
+build:rbe_linux_rocm_base --config=rbe_linux
+build:rbe_linux_rocm_base --repo_env=TF_NEED_ROCM=1
+build:rbe_linux_rocm_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-rocm_config_rocm//crosstool:toolchain"
+build:rbe_linux_rocm_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-rocm_config_rocm//crosstool:toolchain-linux-x86_64"
+build:rbe_linux_rocm_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-rocm_config_platform//:platform"
+build:rbe_linux_rocm_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-rocm_config_platform//:platform"
+build:rbe_linux_rocm_base --platforms="@ubuntu18.04-gcc7_manylinux2010-rocm_config_platform//:platform"
+build:rbe_linux_rocm_base --action_env=TF_ROCM_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_rocm"
+build:rbe_linux_rocm_base --define=using_rocm_hipcc=true
+build:rbe_linux_rocm_py2.7 --config=rbe_linux_rocm_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_python2.7"
+build:rbe_linux_rocm_py3.5 --config=rbe_linux_rocm_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_python3.5"
+build:rbe_linux_rocm_py3.6 --config=rbe_linux_rocm_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_python3.6"
+build:rbe_linux_rocm_py3.7 --config=rbe_linux_rocm_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_python3.7"
+build:rbe_linux_rocm_py3.8 --config=rbe_linux_rocm_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_python3.8"
 
 # Linux CPU
 build:rbe_linux_py2 --config=rbe_linux
 build:rbe_linux_py2 --repo_env=PYTHON_BIN_PATH="/usr/bin/python2"
 build:rbe_linux_py2 --python_path="/usr/bin/python2"
@@ -441,8 +485,8 @@ build:rbe_linux_py3 --python_path="/usr/bin/python3"
 build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-manylinux2010-py3_config_python"
 
 build:rbe_win --config=rbe
-build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:toolchain"
-build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:cc-toolchain-x64_windows"
+build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/tf_win_08062020:toolchain"
+build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/tf_win_08062020:cc-toolchain-x64_windows"
 build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
 build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
 build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"

--- a/.bazelversion
+++ b/.bazelversion
@@ -1 +1 @@
-3.0.0
+3.1.0
--- a/.github/bot_config.yml
+++ b/.github/bot_config.yml
@@ -27,6 +27,19 @@ assignees:
 # A list of assignees for compiler folder
 compiler_assignees:
    - joker-eph
+# filesystem path
+filesystem_path:
+   - tensorflow/c/experimental/filesystem
+# security path
+security_path:
+   - tensorflow/security
+# words checklist
+segfault_memory:
+   - segfault
+   - memory leaks
+# assignees
+filesystem_security_assignee:
+   - mihaimaruseac
 # Cuda Comment
 cuda_comment: >
    From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:
--- a/.github/stale.yml
+++ b/.github/stale.yml
@@ -23,7 +23,7 @@
 daysUntilStale: 7
 # Number of days of inactivity before a stale Issue or Pull Request is closed
 daysUntilClose: 7
-# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
+# Only issues or pull requests with all of these labels are checked if stale. Defaults to `[]` (disabled)
 onlyLabels:
 - stat:awaiting response
 # Comment to post when marking as stale. Set to `false` to disable
--- a/README.md
+++ b/README.md
@@ -61,7 +61,6 @@ commands.
 *Nightly binaries are available for testing using the
 [tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
 [tf-nightly-cpu](https://pypi.python.org/pypi/tf-nightly-cpu) packages on PyPi.*
 
 #### *Try your first TensorFlow program*
 
 ```shell
@@ -96,6 +95,7 @@ for general questions and discussion, and please direct specific questions to
 The TensorFlow project strives to abide by generally accepted best practices in
 open-source software development:
 
+[](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:tensorflow)
 [](https://bestpractices.coreinfrastructure.org/projects/1486)
 [](CODE_OF_CONDUCT.md)
 
@@ -114,6 +114,12 @@ Build Type | Status
 **Android** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
 **Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
 **Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
+**Libtensorflow MacOS CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
+**Libtensorflow Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
+**Libtensorflow Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
+**Libtensorflow Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
+**Libtensorflow Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
 
 
 ### Community Supported Builds
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,3 +1,63 @@
+# Release 2.4.0
+
+<INSERT SMALL BLURB ABOUT RELEASE FOCUS AREA AND POTENTIAL TOOLCHAIN CHANGES>
+
+## Breaking Changes
+
+* <DOCUMENT BREAKING CHANGES HERE>
+* <THIS SECTION SHOULD CONTAIN API, ABI AND BEHAVIORAL BREAKING CHANGES>
+* The byte layout for string tensors across the C-API has been updated to match
+  TF Core/C++; i.e., a contiguous array of `tensorflow::tstring`/`TF_TString`s.
+* C-API functions `TF_StringDecode`, `TF_StringEncode`, and
+  `TF_StringEncodedSize` are no longer relevant and have been removed; see
+  core/platform/ctstring.h for string access/modification in C.
+
+## Known Caveats
+
+* <CAVEATS REGARDING THE RELEASE (BUT NOT BREAKING CHANGES). E.G. ADDING A NEW DEPENDENCY, BUMPING A DEPENDENCY NUMBER, LACK OF SUPPORT ON SOME PLATFORM, ETC>
+
+## Major Features and Improvements
+
+* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
+* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
+
+## Bug Fixes and Other Changes
+
+* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
+* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
+* <NOTES SHOULD BE GROUPED PER AREA>
+* TF Core:
+  * <ADD RELEASE NOTES HERE>
+* `tf.data`:
+  * Added optional `exclude_cols` parameter to CsvDataset. This parameter is
+    the complement of `select_cols`; at most one of these should be specified.
+* `tf.distribute`:
+  * <ADD RELEASE NOTES HERE>
+* `tf.keras`:
+  * <ADD RELEASE NOTES HERE>
+* `tf.function`/AutoGraph:
+  * <ADD RELEASE NOTES HERE>
+* `tf.lite`:
+  * <ADD RELEASE NOTES HERE>
+* `tf.random`:
+  * <ADD RELEASE NOTES HERE>
+* Math and Linear Algebra:
+  * <ADD RELEASE NOTES HERE>
+* TPU Enhancements:
+  * <ADD RELEASE NOTES HERE>
+* XLA Support:
+  * <ADD RELEASE NOTES HERE>
+* Tracing and Debugging:
+  * <ADD RELEASE NOTES HERE>
+* Other:
+  * <ADD RELEASE NOTES HERE>
+
+## Thanks to our Contributors
+
+This release contains contributions from many people at Google, as well as:
+
+stjohnso98, <NAME>, <HERE>, <USING>, <GITHUB>, <HANDLE>
+
 # Release 2.3.0
 
 ## Breaking Changes
@@ -11,6 +71,15 @@
   existing C++ kernel `ExtractGlimpse` does not change as well, so saved
   models will not be impacted.
 
+## Bug Fixes and Other Changes
+
+* `tf.keras`:
+  * Deprecated the `tf.keras.experimental.PeepholeLSTMCell` layer, which was
+    moved to `tensorflow_addons` as
+    `tensorflow_addons.rnn.PeepholeLSTMCell`. This experimental API is
+    expected to be removed from TF in the next public release (2.4).
+* Mutable tables now restore checkpointed values when loaded from SavedModel.
+
 # Release 2.1.1
 
 ## Bug Fixes and Other Changes
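Editor's note: the 2.4.0 breaking change above replaces the old offset/varint string encoding with a contiguous array of `TF_TString`. A minimal sketch of writing a scalar string tensor under the new layout, assuming a build against the TensorFlow C API headers touched by this commit; the tensor and deallocator handling mirrors the updated `StringTensor` test further down.

```cpp
#include <cstring>

#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/c/tf_tstring.h"

// Builds a scalar TF_STRING tensor whose payload is a single TF_TString.
// The caller keeps ownership of `tstr`, so the deallocator is a no-op and
// TF_TString_Dealloc(tstr) must be called after the tensor is deleted.
TF_Tensor* MakeScalarStringTensor(const char* s, TF_TString* tstr) {
  TF_TString_Init(tstr);
  TF_TString_Copy(tstr, s, strlen(s));
  auto deallocator = [](void* data, size_t len, void* arg) {};
  return TF_NewTensor(TF_STRING, /*dims=*/nullptr, /*num_dims=*/0, tstr,
                      sizeof(TF_TString), deallocator,
                      /*deallocator_arg=*/nullptr);
}
```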
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -36,18 +36,6 @@ load(
 
 bazel_toolchains_repositories()
 
-load(
-    "@io_bazel_rules_docker//repositories:repositories.bzl",
-    container_repositories = "repositories",
-)
-
-container_repositories()
-
-load("//third_party/toolchains/preconfig/generate:workspace.bzl",
-     "remote_config_workspace")
-
-remote_config_workspace()
-
 # Use `swift_rules_dependencies` to fetch the toolchains. With the
 # `git_repository` rules above, the following call will skip redefining them.
 load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")
--- a/configure.py
+++ b/configure.py
@@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
 _TF_WORKSPACE_ROOT = ''
 _TF_BAZELRC = ''
 _TF_CURRENT_BAZEL_VERSION = None
-_TF_MIN_BAZEL_VERSION = '2.0.0'
+_TF_MIN_BAZEL_VERSION = '3.1.0'
 _TF_MAX_BAZEL_VERSION = '3.99.0'
 
 NCCL_LIB_PATHS = [
@@ -480,7 +480,7 @@ def check_bazel_version(min_version, max_version):
   """
   if which('bazel') is None:
     print('Cannot find bazel. Please install bazel.')
-    sys.exit(0)
+    sys.exit(1)
 
   stderr = open(os.devnull, 'wb')
   curr_version = run_shell(['bazel', '--version'],
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -298,6 +298,13 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
+# Experimental features
+config_setting(
+    name = "stackdriver_support",
+    define_values = {"stackdriver_support": "true"},
+    visibility = ["//visibility:public"],
+)
+
 # Crosses between platforms and file system libraries not supported on those
 # platforms due to limitations in nested select() statements.
 config_setting(
@@ -460,6 +467,13 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
+# This flag enables experimental TPU support
+config_setting(
+    name = "with_tpu_support",
+    values = {"define": "with_tpu_support=true"},
+    visibility = ["//visibility:public"],
+)
+
 # Specifies via a config setting if this is a mobile build or not, makes
 # it easier to combine settings later.
 selects.config_setting_group(
@@ -531,6 +545,7 @@ package_group(
 
 # Packages that use composite tensors or dispatch.
 # TODO(b/154762408) Remove this package group once it's no longer needed.
+# If this is modified, then copy.bara.sky must also be modified.
 package_group(name = "composite_tensor_whitelist")
 
 # Packages that use private types symbols, until they are exported.
@@ -540,6 +555,11 @@ package_group(
     packages = ["//learning/deepmind/tensorflow/replicator/..."],
 )
 
+# Packages that use StructuredTensors.
+# TODO(b/159007891) Remove this package once StructuredTensor is exported.
+# If this is modified, then copy.bara.sky must also be modified.
+package_group(name = "structured_tensor_whitelist")
+
 filegroup(
     name = "intel_binary_blob",
     data = if_mkl_ml(
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -29,6 +29,8 @@ filegroup(
         "tf_file_statistics.h",
         "tf_status.h",
         "tf_tensor.h",
+        "tf_tstring.h",
+        "//tensorflow/core/platform:ctstring",
     ],
     visibility = ["//tensorflow:__subpackages__"],
 )
@@ -48,6 +50,7 @@ filegroup(
             "*test*",
         ],
     ) + [
+        "//tensorflow/core/platform:ctstring",
        "//tensorflow/cc:srcs_no_runtime",
        "//tensorflow/core/distributed_runtime:server_lib.h",
     ],
@@ -78,6 +81,7 @@ tf_cuda_library(
         "c_api_internal.h",
         "tf_datatype.h",
         "tf_tensor.h",
+        "tf_tstring.h",
     ],
     visibility = [
         "//tensorflow:internal",
@@ -173,6 +177,7 @@ tf_cuda_library(
     copts = tf_copts(),
     visibility = [
         "//tensorflow/c:__subpackages__",
        "//tensorflow/python:__subpackages__",
+        "//third_party/llvm/llvm-project:__subpackages__",
     ],
     deps = [
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/c/tf_tensor.h"
+#include "tensorflow/c/tf_tstring.h"
 
 // --------------------------------------------------------------------------
 // C API for TensorFlow.
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -624,7 +624,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
 
   const int num_inputs = input_shapes->num_items;
   NodeDef node_def;
-  tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op);
+  tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op);
   node_def.set_name(op->Name());
   node_def.set_op(op->Name());
   for (int i = 0; i < num_inputs; ++i) {
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -54,7 +54,7 @@ Status ProcessInputs(
     TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
   input_tensors->reserve(ninputs);
   for (int i = 0; i < ninputs; ++i) {
-    Node* node = &inputs[i].oper->node;
+    Node* node = inputs[i].oper ? &inputs[i].oper->node : nullptr;
     int idx = inputs[i].index;
 
     TF_RETURN_WITH_CONTEXT_IF_ERROR(
@@ -90,7 +90,7 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
     TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
   output_tensors->reserve(noutputs);
   for (int i = 0; i < noutputs; ++i) {
-    Node* node = &outputs[i].oper->node;
+    Node* node = outputs[i].oper ? &outputs[i].oper->node : nullptr;
     int idx = outputs[i].index;
     TF_RETURN_WITH_CONTEXT_IF_ERROR(
         fn_body->graph.IsValidOutputTensor(node, idx),
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -2286,14 +2286,15 @@ TEST_F(CApiAttributesTest, Tensor) {
 
 TEST_F(CApiAttributesTest, StringTensor) {
   // Create the string-Tensor "attribute" value.
-  char encoded[] = {
-      0, 0, 0, 0, 0, 0, 0, 0,  // array[uint64] offsets
-      1,                       // varint encoded string length
-      'A',
-  };
+  const char test_string[] =
+      "borkborkborkborkborkborkborkbork";  // >24bytes to force heap alloc
+  TF_TString tstr[1];
+  TF_TString_Init(&tstr[0]);
+  TF_TString_Copy(&tstr[0], test_string, sizeof(test_string) - 1);
 
   auto deallocator = [](void* data, size_t len, void* arg) {};
-  unique_tensor_ptr t_in(TF_NewTensor(TF_STRING, nullptr, 0, &encoded[0],
-                                      sizeof(encoded), deallocator, nullptr),
+  unique_tensor_ptr t_in(TF_NewTensor(TF_STRING, nullptr, 0, &tstr[0],
+                                      sizeof(tstr), deallocator, nullptr),
                          TF_DeleteTensor);
 
   // Create a TF_Operation with the attribute t_in
@@ -2312,9 +2313,17 @@ TEST_F(CApiAttributesTest, StringTensor) {
   EXPECT_EQ(TF_STRING, TF_TensorType(t_out));
   EXPECT_EQ(0, TF_NumDims(t_out));
   ASSERT_EQ(TF_TensorByteSize(t_in.get()), TF_TensorByteSize(t_out));
-  EXPECT_EQ(0, memcmp(TF_TensorData(t_in.get()), TF_TensorData(t_out),
-                      TF_TensorByteSize(t_out)));
+  TF_TString* t_in_tstr = static_cast<TF_TString*>(TF_TensorData(t_in.get()));
+  TF_TString* t_out_tstr = static_cast<TF_TString*>(TF_TensorData(t_out));
+  EXPECT_EQ(absl::string_view(test_string),
+            absl::string_view(TF_TString_GetDataPointer(t_out_tstr),
+                              TF_TString_GetSize(t_out_tstr)));
+  EXPECT_EQ(absl::string_view(TF_TString_GetDataPointer(t_in_tstr),
+                              TF_TString_GetSize(t_in_tstr)),
+            absl::string_view(TF_TString_GetDataPointer(t_out_tstr),
+                              TF_TString_GetSize(t_out_tstr)));
   TF_DeleteTensor(t_out);
+  TF_TString_Dealloc(&tstr[0]);
 }
 
 TEST_F(CApiAttributesTest, TensorList) {
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -38,9 +38,10 @@ tf_cuda_library(
             "//tensorflow/core:portable_tensorflow_lib_lite",
         ],
         "//conditions:default": [
-            ":context_interface",
-            ":operation_interface",
-            ":tensor_handle_interface",
+            ":immediate_execution_context",
+            ":immediate_execution_operation",
+            ":immediate_execution_tensor_handle",
+            ":abstract_tensor_handle",
             ":tfe_context_internal",
             ":tfe_cancellation_manager_internal",
             ":tfe_executor_internal",
@@ -101,13 +102,17 @@ tf_cuda_library(
 filegroup(
     name = "pywrap_required_hdrs",
     srcs = [
+        "abstract_context.h",
+        "abstract_function.h",
+        "abstract_operation.h",
+        "abstract_tensor_handle.h",
         "c_api_experimental.h",
         "c_api_internal.h",
         "c_api_unified_experimental.h",
-        "context_interface.h",
         "dlpack.h",
-        "operation_interface.h",
-        "tensor_handle_interface.h",
+        "immediate_execution_context.h",
+        "immediate_execution_operation.h",
+        "immediate_execution_tensor_handle.h",
         "tfe_cancellation_manager_internal.h",
         "tfe_executor_internal.h",
         "tfe_monitoring_internal.h",
@@ -153,9 +158,13 @@ cc_library(
         "//tensorflow:internal",
     ],
     deps = [
+        ":abstract_context",
+        ":abstract_operation",
+        ":abstract_tensor_handle",
         ":c_api",
         ":c_api_experimental",
+        "//tensorflow/c:c_api_internal",
         "//tensorflow/c:conversion_macros",
         "//tensorflow/c:tf_status",
         "//tensorflow/core/platform:casts",
         "//tensorflow/core/platform:types",
@@ -163,12 +172,22 @@ cc_library(
 )
 
 cc_library(
-    name = "tensor_handle_interface",
-    hdrs = ["tensor_handle_interface.h"],
+    name = "abstract_tensor_handle",
+    hdrs = ["abstract_tensor_handle.h"],
     visibility = [
         "//tensorflow:internal",
     ],
+    deps = [],
+)
+
+cc_library(
+    name = "immediate_execution_tensor_handle",
+    hdrs = ["immediate_execution_tensor_handle.h"],
+    visibility = [
+        "//tensorflow:internal",
+    ],
     deps = [
+        ":abstract_tensor_handle",
         "//tensorflow/c:tensor_interface",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -177,13 +196,13 @@ cc_library(
 )
 
 cc_library(
-    name = "operation_interface",
-    hdrs = ["operation_interface.h"],
+    name = "abstract_operation",
+    hdrs = ["abstract_operation.h"],
     visibility = [
         "//tensorflow:internal",
     ],
     deps = [
-        ":tensor_handle_interface",
+        ":abstract_tensor_handle",
         "//tensorflow/c:tensor_interface",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -193,16 +212,58 @@ cc_library(
 )
 
 cc_library(
-    name = "context_interface",
-    hdrs = ["context_interface.h"],
+    name = "immediate_execution_operation",
+    hdrs = ["immediate_execution_operation.h"],
     visibility = [
         "//tensorflow:internal",
     ],
     deps = [
-        ":operation_interface",
-        ":tensor_handle_interface",
+        ":abstract_operation",
+        ":abstract_tensor_handle",
+        ":immediate_execution_tensor_handle",
         "//tensorflow/c:tensor_interface",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+cc_library(
+    name = "abstract_context",
+    hdrs = ["abstract_context.h"],
+    visibility = [
+        "//tensorflow:internal",
+    ],
+    deps = [
+        ":abstract_function",
+        ":abstract_operation",
+    ],
+)
+
+cc_library(
+    name = "abstract_function",
+    hdrs = ["abstract_function.h"],
+    visibility = [
+        "//tensorflow:internal",
+    ],
+    deps = [
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:status",
+    ],
+)
+
+cc_library(
+    name = "immediate_execution_context",
+    hdrs = ["immediate_execution_context.h"],
+    visibility = [
+        "//tensorflow:internal",
+    ],
+    deps = [
+        ":abstract_context",
+        ":immediate_execution_operation",
+        ":immediate_execution_tensor_handle",
+        "//tensorflow/c:tensor_interface",
+        "//tensorflow/c/experimental/saved_model/core:saved_model_api",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
@@ -218,7 +279,7 @@ cc_library(
         "//tensorflow:internal",
     ],
     deps = [
-        ":context_interface",
+        ":immediate_execution_context",
         "//tensorflow/c:conversion_macros",
     ],
 )
@@ -278,7 +339,7 @@ cc_library(
         "//tensorflow:internal",
     ],
     deps = [
-        ":operation_interface",
+        ":immediate_execution_operation",
         "//tensorflow/c:conversion_macros",
     ],
 )
@@ -301,7 +362,7 @@ cc_library(
         "//tensorflow:internal",
     ],
     deps = [
-        ":tensor_handle_interface",
+        ":immediate_execution_tensor_handle",
         "//tensorflow/c:conversion_macros",
     ],
 )
@@ -481,6 +542,12 @@ tf_cuda_library(
         ":tfe_context_internal",
         ":tfe_op_internal",
         ":tfe_tensorhandle_internal",
+        ":abstract_operation",
+        ":abstract_context",
+        ":abstract_tensor_handle",
+        ":immediate_execution_tensor_handle",
+        ":immediate_execution_context",
+        "//tensorflow/core/lib/llvm_rtti",
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_api_internal",
         "//tensorflow/core:core_cpu",
@@ -499,6 +566,7 @@ tf_cuda_library(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "@com_google_absl//absl/types:variant",
+        "//tensorflow/c:conversion_macros",
     ],
 }) + select({
     "//tensorflow:with_xla_support": [
@@ -672,6 +740,10 @@ filegroup(
     ],
     exclude = [
         "c_api_experimental.cc",
+        "c_api_unified_experimental.cc",
+        "c_api_unified_experimental_eager.cc",
+        "c_api_unified_experimental_graph.cc",
+        "c_api_unified_experimental_internal.h",
         "*test*",
         "*dlpack*",
     ],
--- /dev/null
+++ b/tensorflow/c/eager/abstract_context.h
@@ -0,0 +1,82 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
+#define TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
+
+#include <memory>
+
+#include "tensorflow/c/eager/abstract_function.h"
+#include "tensorflow/c/eager/abstract_operation.h"
+
+namespace tensorflow {
+
+// Abstract interface to a context.
+//
+// This serves as a factory for creating `AbstractOperation`s and for
+// registering traced functions.
+// Operations created within a context can only be executed in that context
+// (for now at least).
+// Implementations of the context may contain some state, e.g. an execution
+// environment, a traced representation, etc.
+class AbstractContext {
+ protected:
+  enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt };
+  explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
+  virtual ~AbstractContext() {}
+
+ public:
+  AbstractContextKind getKind() const { return kind_; }
+
+  // Release any underlying resources, including the interface object.
+  //
+  // WARNING: The destructor of this class is marked as protected to disallow
+  // clients from directly destroying this object since it may manage its own
+  // lifetime through ref counting. Thus clients MUST call Release() in order
+  // to destroy an instance of this class.
+  virtual void Release() = 0;
+
+  // Creates an operation builder and ties it to this context.
+  // The returned object can be used for setting operation's attributes,
+  // adding inputs and finally executing (immediately or lazily as in tracing)
+  // it in this context.
+  virtual AbstractOperation* CreateOperation() = 0;
+
+  // Registers a function with this context; after this, the function is
+  // available to be called/referenced by its name in this context.
+  virtual Status RegisterFunction(AbstractFunction*) = 0;
+  // Remove a function. 'func' argument is the name of a previously added
+  // FunctionDef. The name is in fdef.signature.name.
+  virtual Status RemoveFunction(const string& func) = 0;
+
+ private:
+  const AbstractContextKind kind_;
+};
+
+namespace internal {
+struct AbstractContextDeleter {
+  void operator()(AbstractContext* p) const {
+    if (p != nullptr) {
+      p->Release();
+    }
+  }
+};
+}  // namespace internal
+
+using AbstractContextPtr =
+    std::unique_ptr<AbstractContext, internal::AbstractContextDeleter>;
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
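Editor's note: a minimal sketch of how a backend could plug into the new `AbstractContext` interface. `NoopContext` and its trivial method bodies are hypothetical and for illustration only; the real implementations are the eager, graph-tracing, and MLIR backends wired up elsewhere in this commit.

```cpp
#include "tensorflow/c/eager/abstract_context.h"

namespace tensorflow {

// Hypothetical backend: every pure virtual is stubbed out.
class NoopContext : public AbstractContext {
 public:
  NoopContext() : AbstractContext(kEager) {}
  // Release() is the only public way to destroy the object, since the
  // destructor is protected; here it simply deletes.
  void Release() override { delete this; }
  AbstractOperation* CreateOperation() override { return nullptr; }
  Status RegisterFunction(AbstractFunction*) override { return Status::OK(); }
  Status RemoveFunction(const string& func) override { return Status::OK(); }
};

}  // namespace tensorflow

// Usage: AbstractContextPtr routes destruction through Release().
// tensorflow::AbstractContextPtr ctx(new tensorflow::NoopContext);
```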
--- /dev/null
+++ b/tensorflow/c/eager/abstract_function.h
@@ -0,0 +1,45 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
+#define TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/platform/status.h"
+
+namespace tensorflow {
+
+// A traced function: this hides the complexity of converting the serialized
+// representation between various supported formats, e.g. FunctionDef and MLIR
+// function.
+class AbstractFunction {
+ protected:
+  enum AbstractFunctionKind { kGraph, kMlir };
+  explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
+
+ public:
+  // Returns which subclass this is an instance of.
+  AbstractFunctionKind getKind() const { return kind_; }
+  virtual ~AbstractFunction() = default;
+
+  // Returns the AbstractFunction as a FunctionDef.
+  virtual Status GetFunctionDef(FunctionDef**) = 0;
+
+ private:
+  const AbstractFunctionKind kind_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
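Editor's note: a sketch of the simplest possible `AbstractFunction` subclass, a hypothetical `kGraph`-flavored function that just owns a `FunctionDef` proto; names here are illustrative, not from the commit.

```cpp
#include <utility>

#include "tensorflow/c/eager/abstract_function.h"

namespace tensorflow {

// Hypothetical: wraps an owned FunctionDef, so GetFunctionDef needs no
// format conversion at all.
class FunctionDefFunction : public AbstractFunction {
 public:
  explicit FunctionDefFunction(FunctionDef def)
      : AbstractFunction(kGraph), def_(std::move(def)) {}

  // Hands out a pointer to the owned proto; lifetime is tied to `this`.
  Status GetFunctionDef(FunctionDef** def) override {
    *def = &def_;
    return Status::OK();
  }

 private:
  FunctionDef def_;
};

}  // namespace tensorflow
```

An MLIR-flavored (`kMlir`) subclass would instead convert its module to a `FunctionDef` on demand, which is exactly the complexity the header comment says this interface hides.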
--- a/tensorflow/c/eager/operation_interface.h
+++ b/tensorflow/c/eager/abstract_operation.h
@@ -12,24 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
-#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
+#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
+#define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
 
 #include <memory>
 
 #include "absl/types/span.h"
-#include "tensorflow/c/eager/tensor_handle_interface.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
 #include "tensorflow/c/tensor_interface.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/framework/op_def.pb.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/platform/status.h"
 
 struct TFE_Op;
 
 namespace tensorflow {
 
 // Abstract interface to an operation.
-class AbstractOperationInterface {
+// This interface allows building and executing an operation in either
+// tracing or immediate execution mode.
+class AbstractOperation {
+ protected:
+  enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt };
+  explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
+  virtual ~AbstractOperation() {}
 
  public:
+  AbstractOperationKind getKind() const { return kind_; }
+
   // Release any underlying resources, including the interface object.
   //
   // WARNING: The destructor of this class is marked as protected to disallow
@@ -38,7 +45,6 @@ class AbstractOperationInterface {
   // clients MUST call Release() in order to destroy an instance of this class.
   virtual void Release() = 0;
 
-  virtual void Clear() = 0;
   virtual Status Reset(const char* op, const char* raw_device_name) = 0;
 
   virtual const string& Name() const = 0;
@@ -66,12 +72,10 @@ class AbstractOperationInterface {
   // existing and given constraints will be performed.
   virtual Status SetDeviceName(const char* name) = 0;
 
-  virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;
-  virtual Status AddInputList(
-      absl::Span<AbstractTensorHandleInterface*> inputs) = 0;
-  virtual Status Execute(absl::Span<AbstractTensorHandleInterface*> retvals,
+  virtual Status AddInput(AbstractTensorHandle* input) = 0;
+  virtual Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) = 0;
+  virtual Status Execute(absl::Span<AbstractTensorHandle*> retvals,
                          int* num_retvals) = 0;
   virtual const tensorflow::OpDef* OpDef() const = 0;
 
   virtual Status SetAttrString(const char* attr_name, const char* data,
                                size_t length) = 0;
@@ -82,7 +86,7 @@ class AbstractOperationInterface {
   virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
                               const int num_dims) = 0;
   virtual Status SetAttrFunction(const char* attr_name,
-                                 const AbstractOperationInterface* value) = 0;
+                                 const AbstractOperation* value) = 0;
   virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
                                      size_t length) = 0;
   virtual Status SetAttrTensor(const char* attr_name,
@@ -102,19 +106,25 @@ class AbstractOperationInterface {
   virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
                                   const int* num_dims, int num_values) = 0;
   virtual Status SetAttrFunctionList(
-      const char* attr_name,
-      absl::Span<const AbstractOperationInterface*> values) = 0;
+      const char* attr_name, absl::Span<const AbstractOperation*> values) = 0;
 
   virtual Status InputLength(const char* input_name, int* length) = 0;
   virtual Status OutputLength(const char* output_name, int* length) = 0;
 
   // Experimental
   virtual Status SetUseXla(bool enable) = 0;
 
- protected:
-  virtual ~AbstractOperationInterface() {}
+ private:
+  const AbstractOperationKind kind_;
 };
 
+namespace internal {
+struct AbstractOperationDeleter {
+  void operator()(AbstractOperation* p) const {
+    if (p != nullptr) {
+      p->Release();
+    }
+  }
+};
+}  // namespace internal
+
+using AbstractOperationPtr =
+    std::unique_ptr<AbstractOperation, internal::AbstractOperationDeleter>;
+
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
+#endif  // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
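Editor's note: the renamed interface is backend-neutral, so one call sequence builds and runs an op whether the context is tracing or executing eagerly. A sketch of the intended flow under that assumption; `RunIdentity` is a hypothetical helper, and `ctx`/`x` must come from the same concrete backend.

```cpp
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

// Runs y = Identity(x) through the abstract interfaces. In a tracing
// context this records the op; in an eager context it executes it.
Status RunIdentity(AbstractContext* ctx, AbstractTensorHandle* x,
                   AbstractTensorHandle** y) {
  AbstractOperationPtr op(ctx->CreateOperation());  // released via deleter
  TF_RETURN_IF_ERROR(op->Reset("Identity", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(op->AddInput(x));
  int num_retvals = 1;
  return op->Execute(absl::MakeSpan(y, 1), &num_retvals);
}

}  // namespace tensorflow
```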
--- /dev/null
+++ b/tensorflow/c/eager/abstract_tensor_handle.h
@@ -0,0 +1,60 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
+#define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
+
+#include <memory>
+
+namespace tensorflow {
+
+// Abstract interface to a Tensor handle in either tracing or immediate
+// execution mode.
+class AbstractTensorHandle {
+ protected:
+  enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt };
+  explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
+  virtual ~AbstractTensorHandle() {}
+
+ public:
+  AbstractTensorHandleKind getKind() const { return kind_; }
+
+  // Release any underlying resources, including the interface object.
+  //
+  // WARNING: The destructor of this class is marked as protected to disallow
+  // clients from directly destroying this object since it may manage its own
+  // lifetime through ref counting. Thus this must be allocated on the heap
+  // and clients MUST call Release() in order to destroy an instance of this
+  // class.
+  virtual void Release() = 0;
+
+ private:
+  const AbstractTensorHandleKind kind_;
+};
+
+namespace internal {
+struct AbstractTensorHandleDeleter {
+  void operator()(AbstractTensorHandle* p) const {
+    if (p != nullptr) {
+      p->Release();
+    }
+  }
+};
+}  // namespace internal
+
+using AbstractTensorHandlePtr =
+    std::unique_ptr<AbstractTensorHandle,
+                    internal::AbstractTensorHandleDeleter>;
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
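Editor's note: the protected-destructor-plus-`Release()` idiom now repeats across all three abstract types in this commit. A self-contained sketch (stand-in types, not the TF classes) of why the custom deleter exists: a plain `delete` would not even compile, so the `unique_ptr` alias must route destruction through `Release()`.

```cpp
#include <iostream>
#include <memory>

// Stand-in for the protected-destructor + Release() idiom used above.
class RefCountedThing {
 public:
  void Release() {  // may be ref-counted internally; here it just deletes
    std::cout << "Release() called\n";
    delete this;    // legal: member functions can reach the protected dtor
  }

 protected:
  ~RefCountedThing() = default;  // a direct `delete p;` would not compile
};

struct ReleaseDeleter {
  void operator()(RefCountedThing* p) const {
    if (p != nullptr) p->Release();
  }
};

using ThingPtr = std::unique_ptr<RefCountedThing, ReleaseDeleter>;

int main() {
  ThingPtr t(new RefCountedThing);  // destruction goes through Release()
}
```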
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -21,6 +21,8 @@ limitations under the License.
 #include <string>
 #include <vector>
 
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
+
 // clang-format off
 #include "tensorflow/core/platform/platform.h"
 // clang-format on
@@ -31,8 +33,8 @@ limitations under the License.
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/c/eager/c_api_experimental.h"
 #include "tensorflow/c/eager/c_api_internal.h"
-#include "tensorflow/c/eager/operation_interface.h"
-#include "tensorflow/c/eager/tensor_handle_interface.h"
+#include "tensorflow/c/eager/immediate_execution_operation.h"
+#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
 #include "tensorflow/c/eager/tfe_context_internal.h"
 #include "tensorflow/c/eager/tfe_op_internal.h"
 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
@@ -713,8 +715,8 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
   status->status = tfrt::ListOpHandlerChains(
       opts->session_options.options, &op_handler_chains, &device_attributes);
   if (!status->status.ok()) return nullptr;
-  return tensorflow::wrap(
-      new tfrt::ContextInterface(op_handler_chains, device_attributes));
+  return tensorflow::wrap(new tfrt::ContextInterface(
+      op_handler_chains, device_attributes, opts->async));
 #else
   status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
   return nullptr;
@@ -1119,7 +1121,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
 
 TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                   TF_Status* status) {
-  tensorflow::AbstractOperationInterface* new_op =
+  tensorflow::ImmediateExecutionOperation* new_op =
       tensorflow::unwrap(ctx)->CreateOperation();
   status->status = new_op->Reset(op_or_function_name, nullptr);
   if (!status->status.ok()) {
@@ -1164,7 +1166,9 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
 void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
                         TF_Status* status) {
   status->status = tensorflow::unwrap(op)->AddInputList(
-      {tensorflow::unwrap(inputs), static_cast<size_t>(num_inputs)});
+      {reinterpret_cast<tensorflow::AbstractTensorHandle**>(
+           tensorflow::unwrap(inputs)),
+       static_cast<size_t>(num_inputs)});
 }
 
 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
@@ -1324,7 +1328,9 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
 void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
                                const TFE_Op** value, int num_values) {
   auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
-      attr_name, {tensorflow::unwrap(value), static_cast<size_t>(num_values)});
+      attr_name, {reinterpret_cast<const tensorflow::AbstractOperation**>(
+                      tensorflow::unwrap(value)),
+                  static_cast<size_t>(num_values)});
   if (!s.ok()) {
     LOG(WARNING) << "Unable to set attribute: " << attr_name;
   }
@@ -1368,7 +1374,10 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                  TF_Status* status) {
   status->status = tensorflow::unwrap(op)->Execute(
-      absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals);
+      absl::MakeSpan(reinterpret_cast<tensorflow::AbstractTensorHandle**>(
+                         tensorflow::unwrap(retvals)),
+                     *num_retvals),
+      num_retvals);
 }
 
 TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
@@ -1473,14 +1482,10 @@ const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
 }
 
 void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
-  tensorflow::AttrValueMap m;
-  tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
   tensorflow::EagerOperation* operation =
       OperationFromInterface(tensorflow::unwrap(op));
   tensorflow::AttrBuilder* destination = operation->MutableAttrs();
-  for (const auto& attribute : m) {
-    destination->Set(attribute.first, attribute.second);
-  }
+  destination->CopyAttributes(*tensorflow::unwrap(attrs));
 }
 
 void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
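Editor's note: the `reinterpret_cast`s added to `TFE_OpAddInputList`, `TFE_OpSetAttrFunctionList`, and `TFE_Execute` view an array of derived-interface pointers (e.g. `ImmediateExecutionTensorHandle**`) as an array of base pointers (`AbstractTensorHandle**`). `Derived**` does not implicitly convert to `Base**`, and with single non-virtual inheritance the two pointer representations coincide, which is what the cast relies on here (it formally skirts strict aliasing). A standalone sketch of the pattern with stand-in types:

```cpp
#include <cstdio>

struct Base {
  virtual ~Base() = default;
  virtual int Id() const = 0;
};

struct Derived : public Base {  // single inheritance: Base* == Derived* value
  int Id() const override { return 42; }
};

void Consume(Base** items, int n) {
  for (int i = 0; i < n; ++i) std::printf("%d\n", items[i]->Id());
}

int main() {
  Derived d0, d1;
  Derived* handles[] = {&d0, &d1};
  // Derived** is not implicitly convertible to Base**; since the Base
  // subobject sits at offset 0, reinterpret_cast gives a usable array view,
  // mirroring what the C API does for its handle arrays.
  Consume(reinterpret_cast<Base**>(handles), 2);
}
```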
--- a/tensorflow/c/eager/c_api_experimental.cc
+++ b/tensorflow/c/eager/c_api_experimental.cc
@@ -38,7 +38,7 @@ using tensorflow::string;
 void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
                  const char* raw_device_name, TF_Status* status) {
   if (op_to_reset) {
-    tensorflow::AbstractOperationInterface* op =
+    tensorflow::ImmediateExecutionOperation* op =
         tensorflow::unwrap(op_to_reset);
     op->Clear();
     status->status = op->Reset(op_or_function_name, raw_device_name);
@@ -60,6 +60,12 @@ void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
   context->SetShouldStoreGraphs(false);
 }
 
+uint64_t TFE_GetContextId(TFE_Context* ctx) {
+  tensorflow::EagerContext* context =
+      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
+  return context->GetContextId();
+}
+
 void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
                                           int64_t value) {
   cell->cell.IncrementBy(value);
@ -300,6 +300,14 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
|
||||
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*,
|
||||
bool use_tfrt);
|
||||
|
||||
// Returns the context_id from the EagerContext which is used by the
|
||||
// EagerService to maintain consistency between client and worker. The
|
||||
// context_id is initialized with a dummy value and is later set when the worker
|
||||
// is initialized (either locally or remotely). The context_id can change during
|
||||
// the process lifetime although this should cause the worker to be
|
||||
// reinitialized (e.g. cleared caches) as well.
|
||||
TF_CAPI_EXPORT extern uint64_t TFE_GetContextId(TFE_Context* ctx);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Cancellation APIs.
|
||||
|
||||
|
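A minimal sketch of calling the new TFE_GetContextId from client code, using only setup and teardown calls that appear elsewhere in this commit; error handling is trimmed and the print format is illustrative:

#include <cinttypes>
#include <cstdint>
#include <cstdio>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  if (TF_GetCode(status) == TF_OK) {
    // Starts as a dummy value and is reassigned whenever the worker is
    // (re)initialized, so treat it as a snapshot rather than a stable key.
    uint64_t id = TFE_GetContextId(ctx);
    std::printf("context_id: %" PRIu64 "\n", id);
    TFE_DeleteContext(ctx);
  }
  TF_DeleteStatus(status);
  return 0;
}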
tensorflow/c/eager/c_api_experimental_test.cc

@@ -212,6 +212,35 @@ TEST(CAPI, CancellationManager) {
   TFE_DeleteCancellationManager(c_mgr);
 }

+TEST(CAPI, ExecutorContextDestructionOrder) {
+  TF_Status* status = TF_NewStatus();
+
+  {
+    TFE_ContextOptions* opts = TFE_NewContextOptions();
+    TFE_Context* ctx = TFE_NewContext(opts, status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    TFE_DeleteContextOptions(opts);
+    TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
+    TFE_ContextSetExecutorForThread(ctx, executor);
+
+    TFE_DeleteContext(ctx);
+    TFE_DeleteExecutor(executor);
+  }
+
+  {
+    TFE_ContextOptions* opts = TFE_NewContextOptions();
+    TFE_Context* ctx = TFE_NewContext(opts, status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    TFE_DeleteContextOptions(opts);
+    TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
+    TFE_ContextSetExecutorForThread(ctx, executor);
+
+    TFE_DeleteExecutor(executor);
+    TFE_DeleteContext(ctx);
+  }
+  TF_DeleteStatus(status);
+}
+
 TEST(CAPI, Function_ident_CPU) {
   // First create a simple identity function.
   TF_Graph* function_graph = TF_NewGraph();
tensorflow/c/eager/c_api_unified_experimental.cc

@@ -22,15 +22,17 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/c/tf_status.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/types.h"

 using tensorflow::string;
-using tensorflow::internal::OutputList;
-using tensorflow::internal::unwrap;

 namespace tensorflow {
-namespace internal {
-typedef absl::flat_hash_map<std::string, FactoryFunction> FactoriesMap;
+namespace tracing {
+typedef absl::flat_hash_map<std::string, tracing::FactoryFunction> FactoriesMap;

 static FactoriesMap& GetFactories() {
   static FactoriesMap* factories = new FactoriesMap;

@@ -48,8 +50,8 @@ void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {

 void SetDefaultTracingEngine(const char* name) { default_factory = name; }

-static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
-                                                       TF_Status* s) {
+static TracingContext* CreateTracingExecutionContext(const char* fn_name,
+                                                     TF_Status* s) {
   auto entry = GetFactories().find(default_factory);
   if (entry != GetFactories().end()) return entry->second(fn_name, s);
   string msg = absl::StrCat(

@@ -70,7 +72,7 @@ static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
   return nullptr;
 }

-}  // end namespace internal
+}  // end namespace tracing
 }  // end namespace tensorflow

 // =============================================================================

@@ -83,43 +85,77 @@ static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
 //
 // =============================================================================

+using tensorflow::AbstractFunction;
+using tensorflow::AbstractTensorHandle;
+using tensorflow::DataType;
+using tensorflow::dyn_cast;
+using tensorflow::OutputList;
+using tensorflow::Status;
+using tensorflow::unwrap;
+using tensorflow::wrap;
+using tensorflow::tracing::CreateTracingExecutionContext;
+using tensorflow::tracing::SetDefaultTracingEngine;
+using tensorflow::tracing::TracingContext;
+using tensorflow::tracing::TracingOperation;
+using tensorflow::tracing::TracingTensorHandle;

 void TF_SetTracingImplementation(const char* name) {
-  tensorflow::internal::SetDefaultTracingEngine(name);
+  SetDefaultTracingEngine(name);
 }

 // Creates a new TensorFlow function, it is an execution context attached to a
 // given tracing context.
 TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
-  return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s));
+  return wrap(CreateTracingExecutionContext(fn_name, s));
 }
 TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
                                          TF_OutputList* outputs, TF_Status* s) {
-  auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s));
+  AbstractFunction* func;
+  TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(ctx));
+  if (!tracing_ctx) {
+    Set_TF_Status_from_Status(
+        s, tensorflow::errors::InvalidArgument(
+               "Only TracingContext can be converted into a function."));
+    return nullptr;
+  }
+  Set_TF_Status_from_Status(s, tracing_ctx->Finalize(unwrap(outputs), &func));
   TF_DeleteExecutionContext(ctx);
-  return func;
+  return wrap(func);
 }

 TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
                                            TF_DataType dtype, TF_Status* s) {
-  return wrap(unwrap(func)->AddParameter(dtype, s));
+  TracingTensorHandle* t;
+  TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func));
+  if (!tracing_ctx) {
+    Set_TF_Status_from_Status(
+        s, tensorflow::errors::InvalidArgument(
+               "TF_AddFunctionParameter must be called on a TracingContext."));
+    return nullptr;
+  }
+  Set_TF_Status_from_Status(
+      s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), &t));
+  return wrap(t);
 }
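Taken together, TF_CreateFunction, TF_AddFunctionParameter, and TF_FinalizeFunction give the tracing workflow its shape. A sketch assembled from the entry points above and the UnifiedCAPI tests further down; status checks are trimmed, and the "graphdef" tracer name is an assumption based on the graph tracing factory registered later in this commit:

#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"

// Trace a one-op function end to end and return the finalized function.
TF_AbstractFunction* TraceDoubleFn(TF_Status* s) {
  TF_SetTracingImplementation("graphdef");  // assumed registered tracer name
  TF_ExecutionContext* ctx = TF_CreateFunction("double_fn", s);

  TF_AbstractTensor* x = TF_AddFunctionParameter(ctx, TF_FLOAT, s);

  TF_AbstractOp* add = TF_NewAbstractOp(ctx);
  TF_AbstractOpSetOpType(add, "Add", s);
  TF_AbstractOpSetOpName(add, "double", s);  // tracing requires a node name

  TF_AbstractTensor* inputs[2] = {x, x};
  TF_OutputList* outputs = TF_NewOutputList();
  TF_OutputListSetNumOutputs(outputs, 1, s);
  TF_ExecuteOperation(add, 2, inputs, outputs, s);  // records a graph node

  // Consumes `ctx`; fails with InvalidArgument on a non-tracing context.
  TF_AbstractFunction* fn = TF_FinalizeFunction(ctx, outputs, s);
  TF_DeleteAbstractOp(add);
  TF_DeleteAbstractTensor(x);
  TF_DeleteOutputList(outputs);
  return fn;
}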
-void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
+void TF_DeleteExecutionContext(TF_ExecutionContext* c) { unwrap(c)->Release(); }

 TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
-  return wrap(unwrap(c)->CreateOperation());
+  return wrap((unwrap(c)->CreateOperation()));
 }

-void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete unwrap(op); }
+void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); }

-void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete unwrap(t); }
+void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Release(); }

 TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
 void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
 void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
                                 TF_Status* s) {
   unwrap(o)->expected_num_outputs = num_outputs;
+  unwrap(o)->outputs.clear();
+  unwrap(o)->outputs.resize(num_outputs);
 }
 int TF_OutputListNumOutputs(TF_OutputList* o) {
   return unwrap(o)->outputs.size();

@@ -134,24 +170,46 @@ void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,

 void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
                             TF_Status* s) {
-  unwrap(op)->SetOpType(op_type, s);
+  Set_TF_Status_from_Status(s, unwrap(op)->Reset(op_type,
+                                                 /*raw_device_name=*/nullptr));
 }

 void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
                             TF_Status* s) {
-  unwrap(op)->SetOpName(op_name, s);
+  TracingOperation* tracing_op = dyn_cast<TracingOperation>(unwrap(op));
+  if (!tracing_op) {
+    Set_TF_Status_from_Status(
+        s, tensorflow::errors::InvalidArgument(
+               "TF_AbstractOpSetOpName must be called on a TracingOperation."));
+    return;
+  }
+  Set_TF_Status_from_Status(s, tracing_op->SetOpName(op_name));
 }

 void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
                               TF_DataType value, TF_Status* s) {
-  unwrap(op)->SetAttrType(attr_name, value, s);
+  Status status =
+      unwrap(op)->SetAttrType(attr_name, static_cast<DataType>(value));
+  TF_SetStatus(s, static_cast<TF_Code>(status.code()),
+               status.error_message().c_str());
 }

 void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
                          TF_AbstractTensor* const* inputs, TF_OutputList* o,
-                         TF_ExecutionContext* ctx, TF_Status* s) {
-  unwrap(ctx)->ExecuteOperation(unwrap(op), num_inputs, &unwrap(*inputs),
-                                unwrap(o), s);
+                         TF_Status* s) {
+  for (int i = 0; i < num_inputs; i++) {
+    Set_TF_Status_from_Status(s, unwrap(op)->AddInput(unwrap(inputs[i])));
+    if (TF_GetCode(s) != TF_OK) {
+      return;
+    }
+  }
+  int num_outputs = unwrap(o)->expected_num_outputs;
+  Set_TF_Status_from_Status(
+      s, unwrap(op)->Execute(
+             absl::MakeSpan(reinterpret_cast<AbstractTensorHandle**>(
+                                unwrap(o)->outputs.data()),
+                            unwrap(o)->outputs.size()),
+             &num_outputs));
 }
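Note the new caller contract: TF_ExecuteOperation no longer takes a context argument, and it writes results into the pre-sized outputs vector, so TF_OutputListSetNumOutputs must be called first. A sketch of eager execution under the new signatures, modeled on TestBasicEager below (TestScalarTensorHandle is the test helper from c_api_test_util.h, not a public API; status checks trimmed):

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"

void RunEagerAdd() {
  TF_Status* s = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, s);
  TFE_DeleteContextOptions(opts);

  // The getter now reports a TF_Status instead of silently returning null.
  TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, s);
  TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);  // test helper
  TF_AbstractTensor* at = TF_CreateAbstractTensorFromEagerTensor(t, s);

  TF_AbstractOp* op = TF_NewAbstractOp(ctx);
  TF_AbstractOpSetOpType(op, "Add", s);

  TF_AbstractTensor* inputs[2] = {at, at};
  TF_OutputList* o = TF_NewOutputList();
  TF_OutputListSetNumOutputs(o, 1, s);   // pre-sizes o->outputs to 1
  TF_ExecuteOperation(op, 2, inputs, o, s);

  TF_DeleteAbstractOp(op);
  TF_DeleteAbstractTensor(at);
  TF_DeleteOutputList(o);
  TF_DeleteExecutionContext(ctx);
  TF_DeleteStatus(s);
}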
 void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {

@@ -161,5 +219,5 @@ void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
 void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
                                          TF_AbstractFunction* func,
                                          TF_Status* s) {
-  unwrap(ctx)->RegisterFunction(unwrap(func), s);
+  Set_TF_Status_from_Status(s, unwrap(ctx)->RegisterFunction(unwrap(func)));
 }

tensorflow/c/eager/c_api_unified_experimental.h

@@ -110,7 +110,7 @@ void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
 // Any active tape will observe the effects of this execution.
 void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
                          TF_AbstractTensor* const* inputs, TF_OutputList* o,
-                         TF_ExecutionContext* ctx, TF_Status* s);
+                         TF_Status* s);

 // Creates a new TF_AbstractFunction from the current tracing states in the
 // context. The provided `ctx` is consumed by this API call and deleted.

@@ -137,7 +137,8 @@ TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
                                                           TF_Status* s);
 TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
                                                   TF_Status* s);
-TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*);
+TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*,
+                                              TF_Status* s);

 #ifdef __cplusplus
 } /* end extern "C" */
tensorflow/c/eager/c_api_unified_experimental_eager.cc

@@ -15,180 +15,68 @@ limitations under the License.

 #include <vector>

+#include "tensorflow/c/eager/abstract_context.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/eager/c_api_unified_experimental.h"
 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
-#include "tensorflow/c/tf_datatype.h"
+#include "tensorflow/c/eager/immediate_execution_context.h"
+#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/c/eager/tfe_context_internal.h"
+#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
 #include "tensorflow/c/tf_status.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/strcat.h"
 #include "tensorflow/core/platform/types.h"

-using tensorflow::string;
-
-namespace tensorflow {
-namespace internal {
-
-// Simple wrapper over a TFE_TensorHandle
-struct EagerTensor : public AbstractTensor {
-  TFE_TensorHandle* t = nullptr;
-  EagerTensor() : AbstractTensor(kKind) {}
-  explicit EagerTensor(TFE_TensorHandle* t) : AbstractTensor(kKind), t(t) {}
-  ~EagerTensor() override { TFE_DeleteTensorHandle(t); }
-  static constexpr AbstractTensorKind kKind = kEagerTensor;
-};
-
-// Simple wrapper over a TFE_Op
-class EagerOp : public AbstractOp {
- public:
-  explicit EagerOp(TFE_Context* ctx) : AbstractOp(kKind), ctx_(ctx) {}
-  void SetOpType(const char* const op_type, TF_Status* s) override {
-    op_ = TFE_NewOp(ctx_, op_type, s);
-  }
-  void SetOpName(const char* const op_name, TF_Status* s) override {
-    // Name is ignored in eager mode.
-  }
-  void SetAttrType(const char* const attr_name, TF_DataType value,
-                   TF_Status* s) override {
-    if (op_ == nullptr) {
-      TF_SetStatus(s, TF_FAILED_PRECONDITION,
-                   "op_type must be specified before specifying attrs.");
-      return;
-    }
-    TFE_OpSetAttrType(op_, attr_name, value);
-  }
-
-  ~EagerOp() override { TFE_DeleteOp(op_); }
-  static constexpr AbstractOpKind kKind = kEagerOp;
-
- private:
-  friend class EagerContext;  // For access to op_.
-  TFE_Op* op_ = nullptr;
-  TFE_Context* ctx_;
-};
-
-// Wraps a TFE_Context and dispatch EagerOp with EagerTensor inputs.
-class EagerContext : public ExecutionContext {
- public:
-  EagerContext() : ExecutionContext(kKind) {}
-
-  void Build(TFE_ContextOptions* options, TF_Status* status) {
-    eager_ctx_ = TFE_NewContext(options, status);
-  }
-
-  AbstractOp* CreateOperation() override {
-    // TODO(srbs): Should the lifetime of this op be tied to the context.
-    return new EagerOp(eager_ctx_);
-  }
-
-  void ExecuteOperation(AbstractOp* op, int num_inputs,
-                        AbstractTensor* const* inputs, OutputList* o,
-                        TF_Status* s) override {
-    auto* eager_op = dyncast<EagerOp>(op);
-    if (eager_op == nullptr) {
-      TF_SetStatus(s, TF_INVALID_ARGUMENT,
-                   "Unable to cast AbstractOp to TF_EagerOp.");
-      return;
-    }
-    auto* tfe_op = eager_op->op_;
-    if (TF_GetCode(s) != TF_OK) return;
-    for (int i = 0; i < num_inputs; ++i) {
-      auto* eager_tensor = dyncast<const EagerTensor>(inputs[i]);
-      if (!eager_tensor) {
-        TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
-        return;
-      }
-      TFE_OpAddInput(tfe_op, eager_tensor->t, s);
-      if (TF_GetCode(s) != TF_OK) return;
-    }
-    if (o->expected_num_outputs == -1) {
-      string msg =
-          "The number of outputs must be provided in eager mode. Use "
-          "TF_OutputListSetNumOutputs.";
-      TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
-      return;
-    }
-    tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
-    int num_retvals = o->expected_num_outputs;
-    retvals.resize(num_retvals);
-    TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
-    if (TF_GetCode(s) != TF_OK) {
-      return;
-    }
-    o->outputs.clear();
-    o->outputs.reserve(num_retvals);
-    for (int i = 0; i < num_retvals; ++i) {
-      o->outputs.push_back(new EagerTensor(retvals[i]));
-    }
-  }
-
-  AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
-    TF_SetStatus(s, TF_INVALID_ARGUMENT,
-                 "Can't add function parameter on an eager context.");
-    return nullptr;
-  }
-  AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
-    TF_SetStatus(s, TF_INVALID_ARGUMENT,
-                 "Can't use finalize function on an eager context.");
-    return nullptr;
-  }
-
-  void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
-    auto* func = afunc->GetTfFunction(s);
-    if (!func) {
-      return;
-    }
-    TFE_ContextAddFunction(eager_ctx_, func, s);
-  }
-
-  ~EagerContext() override { TFE_DeleteContext(eager_ctx_); }
-
-  static constexpr ExecutionContextKind kKind = kEagerContext;
-
- private:
-  friend TFE_Context* ::TF_ExecutionContextGetTFEContext(
-      TF_ExecutionContext* ctx);
-  TFE_Context* eager_ctx_;
-};
-
-}  // namespace internal
-}  // namespace tensorflow
-
 // =============================================================================
 // Public C API entry points
 // These are only the entry points specific to the Eager API.
 // =============================================================================

-using tensorflow::internal::dyncast;
-using tensorflow::internal::unwrap;
+using tensorflow::AbstractContext;
+using tensorflow::AbstractTensorHandle;
+using tensorflow::dyn_cast;
+using tensorflow::ImmediateExecutionContext;
+using tensorflow::ImmediateExecutionTensorHandle;
+using tensorflow::string;
+using tensorflow::unwrap;
+using tensorflow::wrap;
+using tensorflow::strings::StrCat;

 TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options,
                                                  TF_Status* s) {
-  auto* ctx = new tensorflow::internal::EagerContext();
-  ctx->Build(options, s);
-  return wrap(ctx);
+  TFE_Context* c_ctx = TFE_NewContext(options, s);
+  if (TF_GetCode(s) != TF_OK) {
+    return nullptr;
+  }
+  return wrap(static_cast<AbstractContext*>(unwrap(c_ctx)));
 }

 TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
                                                           TF_Status* s) {
-  return wrap(new tensorflow::internal::EagerTensor(t));
+  return wrap(static_cast<AbstractTensorHandle*>(unwrap(t)));
 }

 TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
                                                   TF_Status* s) {
-  auto* eager_tensor = dyncast<tensorflow::internal::EagerTensor>(unwrap(at));
-  if (!eager_tensor) {
-    string msg = tensorflow::strings::StrCat("Not an eager tensor handle.",
-                                             reinterpret_cast<uintptr_t>(at));
+  auto handle = dyn_cast<ImmediateExecutionTensorHandle>(unwrap(at));
+  if (!handle) {
+    string msg =
+        StrCat("Not an eager tensor handle.", reinterpret_cast<uintptr_t>(at));
     TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
     return nullptr;
   }
-  return eager_tensor->t;
+  return wrap(handle);
 }

-TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) {
-  auto* eager_ctx = dyncast<tensorflow::internal::EagerContext>(unwrap(ctx));
-  if (!eager_ctx) return nullptr;
-  return eager_ctx->eager_ctx_;
+TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx,
+                                              TF_Status* s) {
+  auto imm_ctx = dyn_cast<ImmediateExecutionContext>(unwrap(ctx));
+  if (!imm_ctx) {
+    string msg =
+        StrCat("Not an eager context.", reinterpret_cast<uintptr_t>(ctx));
+    TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
+    return nullptr;
+  }
+  return wrap(imm_ctx);
 }
tensorflow/c/eager/c_api_unified_experimental_graph.cc

@@ -18,77 +18,198 @@ limitations under the License.

 #include "absl/strings/str_cat.h"
 #include "tensorflow/c/c_api.h"
+#include "tensorflow/c/eager/abstract_context.h"
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/c/eager/c_api_unified_experimental.h"
 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/c/tf_status.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/strcat.h"
 #include "tensorflow/core/platform/types.h"

+using tensorflow::dyn_cast;
 using tensorflow::string;

 namespace tensorflow {
-namespace internal {
+namespace tracing {
+namespace graph {

+class GraphContext;
+class GraphOperation;
+class GraphTensor;

 // GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
 // into the list of outputs for the operation.
-struct GraphTensor : public AbstractTensor {
-  TF_Output output{};
-  GraphContext* ctx = nullptr;
-  GraphTensor() : AbstractTensor(kKind) {}
-  GraphTensor(TF_Output output, GraphContext* ctx)
-      : AbstractTensor(kKind), output(output), ctx(ctx) {}
-  static constexpr AbstractTensorKind kKind = kGraphTensor;
+class GraphTensor : public TracingTensorHandle {
+ public:
+  explicit GraphTensor(TF_Output output)
+      : TracingTensorHandle(kGraph), output_(output) {}
+  void Release() override { delete this; }
+  TF_Output output_;
+
+  // For LLVM style RTTI.
+  static bool classof(const AbstractTensorHandle* ptr) {
+    return ptr->getKind() == kGraph;
+  }
 };

-// GraphOp wraps and populate a TF_OperationDescription.
-class GraphOp : public AbstractOp {
+// GraphOperation wraps and populates a TF_OperationDescription.
+class GraphOperation : public TracingOperation {
  public:
-  explicit GraphOp(TF_Graph* g) : AbstractOp(kKind), g_(g) {}
-  void SetOpType(const char* const op_type, TF_Status* s) override {
+  explicit GraphOperation(TF_Graph* g) : TracingOperation(kGraph), g_(g) {}
+  void Release() override { delete this; }
+  Status Reset(const char* op, const char* raw_device_name) override {
     if (op_) {
-      TF_SetStatus(
-          s, TF_FAILED_PRECONDITION,
-          strings::StrCat("SetOpType called on already built op.").c_str());
-      return;
+      return errors::FailedPrecondition("Reset called on already built op.");
     }
-    if (op_name_ != nullptr) {
-      op_.reset(TF_NewOperation(g_, op_type, op_name_));
-      op_name_ = nullptr;
-    } else {
-      op_type_ = op_type;
+    if (raw_device_name) {
+      device_name_ = raw_device_name;
     }
+    op_type_ = op;
+    return Status::OK();
   }
-  void SetOpName(const char* const op_name, TF_Status* s) override {
+  Status SetOpName(const char* const op_name) override {
     if (op_) {
-      TF_SetStatus(
-          s, TF_FAILED_PRECONDITION,
-          strings::StrCat("SetOpName called on already built op.").c_str());
-      return;
+      return errors::FailedPrecondition(
+          "SetOpName called on already built op.");
     }
-    if (op_type_ != nullptr) {
-      op_.reset(TF_NewOperation(g_, op_type_, op_name));
-      op_type_ = nullptr;
-    } else {
-      op_name_ = op_name;
+    if (op_type_.empty()) {
+      return errors::FailedPrecondition(
+          "GraphOperation::Reset must be called before calling SetOpName.");
     }
+    op_.reset(TF_NewOperation(g_, op_type_.c_str(), op_name));
+    return Status::OK();
   }
-  void SetAttrType(const char* const attr_name, TF_DataType value,
-                   TF_Status* s) override {
-    if (!op_) {
-      TF_SetStatus(
-          s, TF_FAILED_PRECONDITION,
-          "op_type and op_name must be specified before specifying attrs.");
-      return;
-    }
-    TF_SetAttrType(op_.get(), attr_name, value);
-  }
-  ~GraphOp() override {}
+  const string& Name() const override { return op_type_; }
+  const string& DeviceName() const override { return device_name_; }

-  static constexpr AbstractOpKind kKind = kGraphOp;
+  Status SetDeviceName(const char* name) override {
+    // TODO(srbs): Implement this.
+    device_name_ = name;
+    return Status::OK();
+  }
+
+  Status AddInput(AbstractTensorHandle* input) override {
+    GraphTensor* t = dyn_cast<GraphTensor>(input);
+    if (!t) {
+      return tensorflow::errors::InvalidArgument(
+          "Unable to cast input to GraphTensor");
+    }
+    TF_AddInput(op_.get(), t->output_);
+    return Status::OK();
+  }
+  Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) override {
+    return tensorflow::errors::Unimplemented(
+        "AddInputList has not been implemented yet.");
+  }
+  Status Execute(absl::Span<AbstractTensorHandle*> retvals,
+                 int* num_retvals) override {
+    auto* tf_opdesc = op_.release();
+    if (tf_opdesc == nullptr) {
+      return errors::InvalidArgument("AbstractOp is incomplete.");
+    }
+    TF_Status* s = TF_NewStatus();
+    auto* operation = TF_FinishOperation(tf_opdesc, s);
+    TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
+    TF_DeleteStatus(s);
+    *num_retvals = TF_OperationNumOutputs(operation);
+    for (int i = 0; i < *num_retvals; ++i) {
+      retvals[i] = new GraphTensor({operation, i});
+    }
+    return Status::OK();
+  }
+
+  Status SetAttrString(const char* attr_name, const char* data,
+                       size_t length) override {
+    return tensorflow::errors::Unimplemented(
+        "SetAttrString has not been implemented yet.");
+  }
+  Status SetAttrInt(const char* attr_name, int64_t value) override {
+    return tensorflow::errors::Unimplemented(
+        "SetAttrInt has not been implemented yet.");
+  }
+  Status SetAttrFloat(const char* attr_name, float value) override {
+    return tensorflow::errors::Unimplemented(
+        "SetAttrFloat has not been implemented yet.");
+  }
+  Status SetAttrBool(const char* attr_name, bool value) override {
+    return tensorflow::errors::Unimplemented(
+        "SetAttrBool has not been implemented yet.");
+  }
+  Status SetAttrType(const char* const attr_name, DataType value) override {
+    if (!op_) {
+      return Status(
+          error::Code::FAILED_PRECONDITION,
+          "op_type and op_name must be specified before specifying attrs.");
+    }
+    op_->node_builder.Attr(attr_name, value);
+    return Status::OK();
+  }
+  Status SetAttrShape(const char* attr_name, const int64_t* dims,
+                      const int num_dims) override {
+    return tensorflow::errors::Unimplemented(
+        "SetAttrShape has not been implemented yet.");
+  }
+  Status SetAttrFunction(const char* attr_name,
+                         const AbstractOperation* value) override {
+    return tensorflow::errors::Unimplemented(
+        "SetAttrFunction has not been implemented yet.");
+  }
+  Status SetAttrFunctionName(const char* attr_name, const char* value,
+                             size_t length) override {
+    return tensorflow::errors::Unimplemented(
+        "SetAttrFunctionName has not been implemented yet.");
+  }
+  Status SetAttrTensor(const char* attr_name,
+                       AbstractTensorInterface* tensor) override {
+    return tensorflow::errors::Unimplemented(
+        "SetAttrTensor has not been implemented yet.");
+  }
+  Status SetAttrStringList(const char* attr_name, const void* const* values,
+                           const size_t* lengths, int num_values) override {
+    return tensorflow::errors::Unimplemented(
+        "SetAttrStringList has not been implemented yet.");
+  }
+  Status SetAttrFloatList(const char* attr_name, const float* values,
+                          int num_values) override {
+    return tensorflow::errors::Unimplemented(
+        "SetAttrFloatList has not been implemented yet.");
+  }
+  Status SetAttrIntList(const char* attr_name, const int64_t* values,
+                        int num_values) override {
+    return tensorflow::errors::Unimplemented(
+        "SetAttrIntList has not been implemented yet.");
+  }
+  Status SetAttrTypeList(const char* attr_name, const DataType* values,
+                         int num_values) override {
+    return tensorflow::errors::Unimplemented(
+        "SetAttrTypeList has not been implemented yet.");
+  }
+  Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
+                         int num_values) override {
+    return tensorflow::errors::Unimplemented(
+        "SetAttrBoolList has not been implemented yet.");
+  }
+  Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
+                          const int* num_dims, int num_values) override {
+    return tensorflow::errors::Unimplemented(
+        "SetAttrShapeList has not been implemented yet.");
+  }
+  Status SetAttrFunctionList(
+      const char* attr_name,
+      absl::Span<const AbstractOperation*> values) override {
+    return tensorflow::errors::Unimplemented(
+        "SetAttrFunctionList has not been implemented yet.");
+  }
+  // For LLVM style RTTI.
+  static bool classof(const AbstractOperation* ptr) {
+    return ptr->getKind() == kGraph;
+  }
+  ~GraphOperation() override {}

  private:
   friend class GraphContext;  // For access to op_.

@@ -96,123 +217,109 @@ class GraphOp : public AbstractOp {
   std::unique_ptr<TF_OperationDescription> op_;
   // Hold `op_type` and `op_name` till both are available since we need both
   // to build a graph operation.
-  const char* op_type_ = nullptr;
+  string op_type_;
-  const char* op_name_ = nullptr;
-  // TODO(srbs): Use this.
   string device_name_;
 };
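The call-order protocol here matters: Reset() only records the op type (and optional device), the underlying TF_OperationDescription is created once SetOpName() supplies the node name, and Execute() finishes the node and wraps its outputs. A sketch of driving one node through that protocol, modeled on GraphContext::AddParameter below; the helper name and the choice of Identity (whose "T" attr TF_FinishOperation infers from the input) are illustrative assumptions:

// Internal-API sketch; assumes the classes declared above in this file.
Status AddIdentityNode(GraphContext* ctx, AbstractTensorHandle* in,
                       AbstractTensorHandle** out) {
  TracingOperation* op = ctx->CreateOperation();
  TF_RETURN_IF_ERROR(op->Reset("Identity", /*raw_device_name=*/nullptr));
  TF_RETURN_IF_ERROR(op->SetOpName("my_identity"));  // NodeDef built here
  TF_RETURN_IF_ERROR(op->AddInput(in));
  int num_retvals = 1;
  std::vector<AbstractTensorHandle*> retvals(num_retvals);
  TF_RETURN_IF_ERROR(op->Execute(absl::MakeSpan(retvals), &num_retvals));
  op->Release();
  *out = retvals[0];
  return Status::OK();
}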
 // GraphFunction is a thin wrapper over a TF_Function.
 struct GraphFunction : public AbstractFunction {
   TF_Function* func = nullptr;
-  GraphFunction() : AbstractFunction(kKind) {}
+  GraphFunction() : AbstractFunction(kGraph) {}
   explicit GraphFunction(TF_Function* func)
-      : AbstractFunction(kKind), func(func) {}
+      : AbstractFunction(kGraph), func(func) {}
   ~GraphFunction() override {
     if (func) TF_DeleteFunction(func);
   }

-  TF_Function* GetTfFunction(TF_Status* s) override { return func; }
+  Status GetFunctionDef(FunctionDef** fdef) override {
+    *fdef = &func->fdef;
+    return Status::OK();
+  }

-  static constexpr AbstractFunctionKind kKind = kGraphFunc;
+  // For LLVM style RTTI.
+  static bool classof(const AbstractFunction* ptr) {
+    return ptr->getKind() == kGraph;
+  }
 };

 // GraphContext wraps a TF_Graph modeling a single function and manages the
 // "execution" of operation, i.e. adding them to the function.
-class GraphContext : public ExecutionContext {
+class GraphContext : public TracingContext {
  public:
   explicit GraphContext(const char* name)
-      : ExecutionContext(kKind),
+      : TracingContext(kGraph),
         graph_(new TF_Graph(), TF_DeleteGraph),
         name_(name) {}

-  AbstractOp* CreateOperation() override {
-    // TODO(srbs): Should the lifetime of this op be tied to the context.
-    return new GraphOp(graph_.get());
+  void Release() override { delete this; }
+
+  TracingOperation* CreateOperation() override {
+    return new GraphOperation(graph_.get());
   }

-  void ExecuteOperation(AbstractOp* op, int num_inputs,
-                        AbstractTensor* const* inputs, OutputList* o,
-                        TF_Status* s) override {
-    auto* graph_op = dyncast<GraphOp>(op);
-    if (graph_op == nullptr) {
-      TF_SetStatus(s, TF_INVALID_ARGUMENT,
-                   "Unable to cast AbstractOp to TF_GraphOp.");
-      return;
-    }
-    auto* tf_opdesc = graph_op->op_.release();
-    if (tf_opdesc == nullptr) {
-      TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete.");
-      return;
-    }
-    for (int i = 0; i < num_inputs; ++i) {
-      auto* graph_tensor = dyncast<GraphTensor>(inputs[i]);
-      if (!graph_tensor) {
-        TF_SetStatus(s, TF_INVALID_ARGUMENT,
-                     "Capturing eager tensors is not supported yet.");
-        return;
-      } else {
-        if (graph_tensor->ctx != this) {
-          TF_SetStatus(
-              s, TF_INVALID_ARGUMENT,
-              "Capturing tensors from other graphs is not supported yet.");
-          return;
-        }
-        TF_AddInput(tf_opdesc, graph_tensor->output);
-      }
-    }
-    auto* operation = TF_FinishOperation(tf_opdesc, s);
-    // TF_FinishOperation deletes `tf_opdesc` so clear its reference.
-    graph_op->op_ = nullptr;
-    if (TF_GetCode(s) != TF_OK) return;
-    int num_outputs = TF_OperationNumOutputs(operation);
-    o->outputs.clear();
-    o->outputs.reserve(num_outputs);
-    for (int i = 0; i < num_outputs; ++i) {
-      o->outputs.push_back(new GraphTensor({operation, i}, this));
-    }
-  }
-
-  AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
-    TF_OperationDescription* opdesc =
-        TF_NewOperation(graph_.get(), "Placeholder",
-                        absl::StrCat("_input_", inputs_.size()).c_str());
-    TF_SetAttrType(opdesc, "dtype", dtype);
-    auto* operation = TF_FinishOperation(opdesc, s);
-    if (!s->status.ok()) return nullptr;
-
-    inputs_.push_back(TF_Output{operation, 0});
-    return new GraphTensor(inputs_.back(), this);
+  Status AddParameter(DataType dtype, TracingTensorHandle** output) override {
+    auto operation = CreateOperation();
+    TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
+    TF_RETURN_IF_ERROR(
+        operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
+    TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
+    int num_outputs = 1;
+    std::vector<AbstractTensorHandle*> outputs(num_outputs);
+    TF_RETURN_IF_ERROR(operation->Execute(
+        absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
+
+    if (num_outputs != 1) {
+      return errors::Internal("Expected 1 output but found ", num_outputs);
+    }
+    auto* t = dyn_cast<GraphTensor>(outputs[0]);
+    if (!t) {
+      return tensorflow::errors::InvalidArgument(
+          "Unable to cast input to GraphTensor");
+    }
+    inputs_.push_back(t->output_);
+    *output = tensorflow::down_cast<TracingTensorHandle*>(outputs[0]);
+    return Status::OK();
   }

-  AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
+  Status Finalize(OutputList* outputs, AbstractFunction** f) override {
     std::unique_ptr<GraphFunction> func(new GraphFunction);
     std::vector<TF_Output> graph_outputs;
     graph_outputs.reserve(outputs->outputs.size());
-    for (AbstractTensor* abstract_output : outputs->outputs) {
-      GraphTensor* output = dyncast<GraphTensor>(abstract_output);
+    for (auto* abstract_output : outputs->outputs) {
+      GraphTensor* output = dyn_cast<GraphTensor>(abstract_output);
       if (!output) {
-        TF_SetStatus(s, TF_UNIMPLEMENTED,
-                     "Returning a non-graph tensor from a function has not "
-                     "been implemented yet.");
-        return nullptr;
+        return errors::Unimplemented(
+            "Returning a non-graph tensor from a function has not "
+            "been implemented yet.");
       }
-      graph_outputs.push_back(output->output);
+      graph_outputs.push_back(output->output_);
     }

+    auto s = TF_NewStatus();
     func->func = TF_GraphToFunction(
         graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
         graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
-    if (TF_GetCode(s) != TF_OK) return nullptr;
-    return func.release();
+    TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
+    TF_DeleteStatus(s);
+    *f = func.release();
+    return Status::OK();
   }

-  void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
-    TF_SetStatus(s, TF_UNIMPLEMENTED,
-                 "Registering graph functions has not been implemented yet.");
+  Status RegisterFunction(AbstractFunction* func) override {
+    return errors::Unimplemented(
+        "Registering graph functions has not been implemented yet.");
   }

   ~GraphContext() override {}

-  static constexpr ExecutionContextKind kKind = kGraphContext;
+  Status RemoveFunction(const string& func) override {
+    return errors::Unimplemented(
+        "GraphContext::RemoveFunction has not been implemented yet.");
+  }
+  // For LLVM style RTTI.
+  static bool classof(const AbstractContext* ptr) {
+    return ptr->getKind() == kGraph;
+  }

  private:
   std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;

@@ -220,7 +327,7 @@ class GraphContext : public ExecutionContext {
   const char* name_;
 };

-static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) {
+static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
   return new GraphContext(name);
 }

@@ -231,5 +338,6 @@ static bool register_tracing = [] {
   return true;
 }();

-}  // namespace internal
+}  // namespace graph
+}  // namespace tracing
 }  // namespace tensorflow
tensorflow/c/eager/c_api_unified_experimental_internal.h

@@ -19,6 +19,10 @@ limitations under the License.
 #include <vector>

 #include "tensorflow/c/c_api.h"
+#include "tensorflow/c/conversion_macros.h"
+#include "tensorflow/c/eager/abstract_context.h"
+#include "tensorflow/c/eager/abstract_operation.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
 #include "tensorflow/c/eager/c_api_unified_experimental.h"
 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/c/tf_status.h"

@@ -26,7 +30,14 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"

 namespace tensorflow {
-namespace internal {

+// Represents the results of the execution of an operation.
+struct OutputList {
+  std::vector<AbstractTensorHandle*> outputs;
+  int expected_num_outputs = -1;
+};
+
+namespace tracing {

 // =============================================================================
 // Implementation detail for the unified execution APIs for Eager and tracing

@@ -37,165 +48,75 @@ namespace internal {
 // `c_api_unified_experimental.h` header.
 // =============================================================================

-// We can't depend on C++ rtti, but we still want to be able to have a safe
-// dynamic_cast to provide diagnostics to the user when the API is misused.
-// Instead we model RTTI by listing all the possible subclasses for each
-// abstract base. Each subclass initializes the base class with the right
-// `kind`, which allows an equivalent to `std::dynamic_cast` provided by this
-// utility.
-template <typename T, typename S>
-T* dyncast(S source) {
-  if (source->getKind() != T::kKind) {
-    return nullptr;
-  }
-  return tensorflow::down_cast<T*>(source);
-}
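The hand-rolled dyncast above is what this commit replaces with LLVM-style RTTI: each subclass supplies a static classof() predicate over the kind tag, and a generic dyn_cast consults it. A self-contained sketch of the idiom with invented names (TensorFlow takes dyn_cast from core/lib/llvm_rtti rather than defining it like this):

#include <cassert>

// Base class stores a kind tag set by each subclass constructor.
class Handle {
 public:
  enum Kind { kGraph, kMlir, kEager };
  explicit Handle(Kind kind) : kind_(kind) {}
  virtual ~Handle() = default;
  Kind getKind() const { return kind_; }

 private:
  const Kind kind_;
};

class GraphHandle : public Handle {
 public:
  GraphHandle() : Handle(kGraph) {}
  // For LLVM style RTTI.
  static bool classof(const Handle* ptr) { return ptr->getKind() == kGraph; }
};

// Generic dyn_cast: delegates the kind check to the target's classof().
template <typename To, typename From>
To* dyn_cast(From* source) {
  return To::classof(source) ? static_cast<To*>(source) : nullptr;
}

int main() {
  GraphHandle g;
  Handle* h = &g;
  assert(dyn_cast<GraphHandle>(h) != nullptr);  // kind matches, cast succeeds
  return 0;
}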
-// Represents either an EagerTensor or a GraphTensor.
+// Represents either a MlirTensor or a GraphTensor.
 // This base class does not expose any public methods other than to distinguish
 // which subclass it actually is. The user is responsible to use the right
-// type of AbstractTensor in their context (do not pass an EagerTensor to a
+// type of AbstractTensor in their context (do not pass an MlirTensor to a
 // GraphContext and vice-versa).
-class AbstractTensor {
+class TracingTensorHandle : public AbstractTensorHandle {
  protected:
-  enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor };
-  explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
+  explicit TracingTensorHandle(AbstractTensorHandleKind kind)
+      : AbstractTensorHandle(kind) {}

  public:
-  // Returns which subclass is this instance of.
-  AbstractTensorKind getKind() const { return kind_; }
-  virtual ~AbstractTensor() = default;
-
- private:
-  const AbstractTensorKind kind_;
-};
-
-// Represents the results of the execution of an operation.
-struct OutputList {
-  std::vector<AbstractTensor*> outputs;
-  int expected_num_outputs = -1;
-};
-
-// Holds the result of tracing a function.
-class AbstractFunction {
- protected:
-  enum AbstractFunctionKind { kGraphFunc };
-  explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
-
- public:
-  // Returns which subclass is this instance of.
-  AbstractFunctionKind getKind() const { return kind_; }
-  virtual ~AbstractFunction() = default;
-
-  // Temporary API till we figure the right abstraction for AbstractFunction.
-  // At the moment both Eager and Graph needs access to a "TF_Function" object.
-  virtual TF_Function* GetTfFunction(TF_Status* s) = 0;
-
- private:
-  const AbstractFunctionKind kind_;
+  // For LLVM style RTTI.
+  static bool classof(const AbstractTensorHandle* ptr) {
+    return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
+  }
 };

 // An abstract operation describes an operation by its type, name, and
 // attributes. It can be "executed" by the context with some input tensors.
 // It is allowed to reuse the same abstract operation for multiple executions
 // on a given context, with the same or different input tensors.
-class AbstractOp {
+class TracingOperation : public AbstractOperation {
  protected:
-  enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp };
-  explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
+  explicit TracingOperation(AbstractOperationKind kind)
+      : AbstractOperation(kind) {}

  public:
-  // Returns which subclass is this instance of.
-  AbstractOpKind getKind() const { return kind_; }
-  virtual ~AbstractOp() = default;
-
-  // Sets the type of the operation (for example `AddV2`).
-  virtual void SetOpType(const char* op_type, TF_Status* s) = 0;
-
   // Sets the name of the operation: this is an optional identifier that is
   // not intended to carry semantics and preserved/propagated without
   // guarantees.
-  virtual void SetOpName(const char* op_name, TF_Status* s) = 0;
+  virtual Status SetOpName(const char* op_name) = 0;

-  // Add a `TypeAttribute` on the operation.
-  virtual void SetAttrType(const char* attr_name, TF_DataType value,
-                           TF_Status* s) = 0;
-
- private:
-  const AbstractOpKind kind_;
+  // For LLVM style RTTI.
+  static bool classof(const AbstractOperation* ptr) {
+    return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
+  }
 };

 // This holds the context for the execution: dispatching operations either to an
-// eager implementation or to a graph implementation.
-struct ExecutionContext {
+// MLIR implementation or to a graph implementation.
+class TracingContext : public AbstractContext {
  protected:
-  enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext };
-  explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
+  explicit TracingContext(AbstractContextKind kind) : AbstractContext(kind) {}

  public:
-  // Returns which subclass is this instance of.
-  ExecutionContextKind getKind() const { return k; }
-  virtual ~ExecutionContext() = default;
-
-  // Executes the operation on the provided inputs and populates the OutputList
-  // with the results. The input tensors must match the current context.
-  // The effect of "executing" an operation depends on the context: in an Eager
-  // context it will dispatch it to the runtime for execution, while in a
-  // tracing context it will add the operation to the current function.
-  virtual void ExecuteOperation(AbstractOp* op, int num_inputs,
-                                AbstractTensor* const* inputs, OutputList* o,
-                                TF_Status* s) = 0;
-
-  // Creates an empty AbstractOperation suitable to use with this context.
-  virtual AbstractOp* CreateOperation() = 0;
-
   // Add a function parameter and return the corresponding tensor.
-  // This is only valid with an ExecutionContext obtained from a TracingContext,
-  // it'll always error out with an eager context.
-  virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0;
+  virtual Status AddParameter(DataType dtype, TracingTensorHandle**) = 0;

   // Finalize this context and make a function out of it. The context is in an
   // invalid state after this call and must be destroyed.
-  // This is only valid with an ExecutionContext obtained from a TracingContext,
-  // it'll always error out with an eager context.
-  virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0;
+  virtual Status Finalize(OutputList* outputs, AbstractFunction**) = 0;

-  // Registers a function with this context; after this the function is
-  // available to be called/referenced by its name in this context.
-  virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
-
- private:
-  const ExecutionContextKind k;
+  // For LLVM style RTTI.
+  static bool classof(const AbstractContext* ptr) {
+    return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
+  }
 };

-typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
+typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
 void SetDefaultTracingEngine(const char* name);
 void RegisterTracingEngineFactory(const ::tensorflow::string& name,
                                   FactoryFunction factory);
+}  // namespace tracing

-// Create utilities to wrap/unwrap: this converts from the C opaque types to
-// the C++ implementation, and back.
-#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS)                              \
-  static inline CPP_CLASS* const& unwrap(C_TYPEDEF* const& o) {             \
-    return reinterpret_cast<CPP_CLASS* const&>(o);                          \
-  }                                                                         \
-  static inline const CPP_CLASS* const& unwrap(const C_TYPEDEF* const& o) { \
-    return reinterpret_cast<const CPP_CLASS* const&>(o);                    \
-  }                                                                         \
-  static inline C_TYPEDEF* const& wrap(CPP_CLASS* const& o) {               \
-    return reinterpret_cast<C_TYPEDEF* const&>(o);                          \
-  }                                                                         \
-  static inline const C_TYPEDEF* const& wrap(const CPP_CLASS* const& o) {   \
-    return reinterpret_cast<const C_TYPEDEF* const&>(o);                    \
-  }
-
-MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext)
-MAKE_WRAP_UNWRAP(TF_AbstractFunction, AbstractFunction)
-MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor)
-MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp)
-MAKE_WRAP_UNWRAP(TF_OutputList, OutputList)
-
-}  // namespace internal
+DEFINE_CONVERSION_FUNCTIONS(AbstractContext, TF_ExecutionContext)
+DEFINE_CONVERSION_FUNCTIONS(AbstractTensorHandle, TF_AbstractTensor)
+DEFINE_CONVERSION_FUNCTIONS(AbstractFunction, TF_AbstractFunction)
+DEFINE_CONVERSION_FUNCTIONS(AbstractOperation, TF_AbstractOp)
+DEFINE_CONVERSION_FUNCTIONS(OutputList, TF_OutputList)
 }  // namespace tensorflow
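DEFINE_CONVERSION_FUNCTIONS plays the role the deleted MAKE_WRAP_UNWRAP macro played: pointer-identity wrap()/unwrap() between the opaque C typedef and the C++ class. A sketch of the shape of the generated functions, mirroring the deleted macro (the actual expansion in conversion_macros.h may differ in detail):

// Stand-in pair; the macro is instantiated once per (C++ class, C typedef).
struct AbstractThing;               // C++ implementation type (invented)
typedef struct TF_Thing TF_Thing;   // opaque C handle (invented)

static inline AbstractThing* unwrap(TF_Thing* o) {
  return reinterpret_cast<AbstractThing*>(o);
}
static inline TF_Thing* wrap(AbstractThing* o) {
  return reinterpret_cast<TF_Thing*>(o);
}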
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
@ -29,15 +30,19 @@ using tensorflow::string;
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class UnifiedCAPI : public ::testing::TestWithParam<const char*> {
|
||||
class UnifiedCAPI
|
||||
: public ::testing::TestWithParam<std::tuple<const char*, bool>> {
|
||||
protected:
|
||||
void SetUp() override { TF_SetTracingImplementation(GetParam()); }
|
||||
void SetUp() override {
|
||||
TF_SetTracingImplementation(std::get<0>(GetParam()));
|
||||
}
|
||||
};
|
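With the fixture now parameterized over a (tracing engine, use_tfrt) tuple, each test body reads the second element through TFE_ContextOptionsSetTfrt. A sketch of a matching suite instantiation; the concrete values are assumptions, since the commit's actual INSTANTIATE_TEST_SUITE_P call falls outside this excerpt:

// Hypothetical instantiation: tracer names crossed with the TFRT flag.
INSTANTIATE_TEST_SUITE_P(
    Tracing, UnifiedCAPI,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       ::testing::Values(false)));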
 TEST_P(UnifiedCAPI, TestBasicEager) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam()));
   TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TFE_DeleteContextOptions(opts);

@@ -45,7 +50,8 @@ TEST_P(UnifiedCAPI, TestBasicEager) {
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

   // Build an abstract input tensor.
-  TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
+  TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
   TF_AbstractTensor* at =
       TF_CreateAbstractTensorFromEagerTensor(t, status.get());

@@ -63,7 +69,7 @@ TEST_P(UnifiedCAPI, TestBasicEager) {
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

   // Execute.
-  TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
+  TF_ExecuteOperation(op, 2, inputs, o, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

   // Clean up operation and inputs.

@@ -109,9 +115,11 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
   // Build inputs and outputs.
   TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
   TF_OutputList* add_outputs = TF_NewOutputList();
+  TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

   // Execute.
-  TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
+  TF_ExecuteOperation(add_op, 2, inputs, add_outputs, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

   // Clean up operation and inputs.

@@ -123,6 +131,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {

   // Build eager context.
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam()));
   TF_ExecutionContext* eager_execution_ctx =
       TF_NewEagerExecutionContext(opts, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

@@ -137,16 +146,14 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {

   // Build an abstract input tensor.
   TFE_Context* eager_ctx =
-      TF_ExecutionContextGetTFEContext(eager_execution_ctx);
+      TF_ExecutionContextGetTFEContext(eager_execution_ctx, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
   TF_AbstractTensor* input_t =
       TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

-  TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, eager_execution_ctx,
-                      status.get());
+  TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

   ASSERT_EQ(1, TF_OutputListNumOutputs(add_outputs));

@@ -195,8 +202,10 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
   TF_AbstractTensor* inputs[2] = {arg0, arg1};
   TF_OutputList* add_outputs = TF_NewOutputList();
+  TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   // Trace the operation now (create a node in the graph).
-  TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
+  TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
   TF_DeleteAbstractOp(add_op);
   // Extract the resulting tensor.

@@ -215,8 +224,10 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
   TF_AbstractTensor* inputs[2] = {arg1, arg1};
   TF_OutputList* add_outputs = TF_NewOutputList();
+  TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   // Trace the operation now (create a node in the graph).
-  TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s);
+  TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
   TF_DeleteAbstractOp(add_op);
   // Extract the resulting tensor.

@@ -256,6 +267,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {

   // Build eager context.
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam()));
   TF_ExecutionContext* eager_execution_ctx =
       TF_NewEagerExecutionContext(opts, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

@@ -273,7 +285,8 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
   std::vector<TF_AbstractTensor*> func_args;
   {
     TFE_Context* eager_ctx =
-        TF_ExecutionContextGetTFEContext(eager_execution_ctx);
+        TF_ExecutionContextGetTFEContext(eager_execution_ctx, status.get());
+    ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
     TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f);
     func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
     ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);

@@ -286,7 +299,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
   TF_OutputListSetNumOutputs(func_outputs, 2, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
   TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs,
-                      eager_execution_ctx, s);
+                      s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
   TF_DeleteAbstractOp(fn_op);
   for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t);

@@ -314,20 +327,21 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
   TF_DeleteAbstractFunction(func);
 }

-TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
+TEST_P(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam()));
   TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TFE_DeleteContextOptions(opts);

-  TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get());
-  ASSERT_EQ(nullptr, func);
+  TF_AbstractFunction* f = TF_FinalizeFunction(ctx, nullptr, status.get());
+  ASSERT_EQ(nullptr, f);
   ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
 }

-TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
+TEST_P(UnifiedCAPI, TF_AbstractOpSetOpTypeAfterFinishingOpBuildingRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());

@@ -348,7 +362,7 @@ TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
   TF_DeleteExecutionContext(graph_ctx);
 }

-TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
+TEST_P(UnifiedCAPI, TF_AbstractOpSetOpNameAfterFinishingOpBuildingRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
   TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());

@@ -369,116 +383,44 @@ TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
   TF_DeleteExecutionContext(graph_ctx);
 }

-TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
-  // Build an Eager context.
+TEST_P(UnifiedCAPI, TF_AbstractTensorGetEagerTensorOnGraphTensorRaises) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
-  TFE_ContextOptions* opts = TFE_NewContextOptions();
-  TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TFE_DeleteContextOptions(opts);
-
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
-  // Build an Eager operation.
-  auto* op = TF_NewAbstractOp(ctx);
-  TF_AbstractOpSetOpType(op, "Add", status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
-  // Build an abstract input tensor.
-  TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx);
-  TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f);
-  TF_AbstractTensor* at =
-      TF_CreateAbstractTensorFromEagerTensor(t, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
-  // Build inputs and outputs.
-  TF_AbstractTensor* inputs[2] = {at, at};
-  TF_OutputList* o = TF_NewOutputList();
-  TF_OutputListSetNumOutputs(o, 1, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
-  // Build a Graph context.
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

-  // Execute eager op using graph context.
-  TF_ExecuteOperation(op, 2, inputs, o, graph_ctx, status.get());
-  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
-
-  // Clean up operation and inputs.
-  TF_DeleteAbstractOp(op);
-  TF_DeleteAbstractTensor(at);
-
-  TF_DeleteOutputList(o);
-  TF_DeleteExecutionContext(ctx);
-  TF_DeleteExecutionContext(graph_ctx);
-}
-
-TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
-  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
-      TF_NewStatus(), TF_DeleteStatus);
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-
   // Add a placeholder to the graph.
   auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
   TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
TF_OutputList* placeholder_outputs = TF_NewOutputList();
|
||||
|
||||
// Execute.
|
||||
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
|
||||
graph_ctx, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
|
||||
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
|
||||
|
||||
// Delete placeholder op.
|
||||
TF_DeleteAbstractOp(placeholder_op);
|
||||
|
||||
// Build an abstract operation.
|
||||
auto* add_op = TF_NewAbstractOp(graph_ctx);
|
||||
TF_AbstractOpSetOpType(add_op, "Add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TF_AbstractOpSetOpName(add_op, "my_add", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
// Build inputs and outputs.
|
||||
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
|
||||
TF_OutputList* add_outputs = TF_NewOutputList();
|
||||
|
||||
// Build eager context.
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TF_ExecutionContext* eager_execution_ctx =
|
||||
TF_NewEagerExecutionContext(opts, status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
// Execute.
|
||||
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, eager_execution_ctx,
|
||||
status.get());
|
||||
auto placeholder_t =
|
||||
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
|
||||
TF_AbstractTensorGetEagerTensor(placeholder_t, status.get());
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
|
||||
// Clean up operation and inputs.
|
||||
TF_DeleteAbstractTensor(placeholder_t);
|
||||
TF_DeleteAbstractOp(add_op);
|
||||
TF_DeleteOutputList(add_outputs);
|
||||
TF_DeleteOutputList(placeholder_outputs);
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||
}
|
||||
|
||||
TEST_P(UnifiedCAPI, TF_ExecutionContextGetTFEContextFromFunctionContextRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||
|
||||
TF_ExecutionContextGetTFEContext(graph_ctx, status.get());
|
||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||
|
||||
TF_DeleteExecutionContext(graph_ctx);
|
||||
}
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
|
||||
::testing::Values("graphdef", "mlir"));
|
||||
::testing::Combine(::testing::Values("graphdef",
|
||||
"mlir"),
|
||||
::testing::Values(true, false)));
|
||||
#else
|
||||
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
|
||||
::testing::Combine(::testing::Values("graphdef",
|
||||
"mlir"),
|
||||
::testing::Values(false)));
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
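The instantiation above switches the suite's parameter from a bare tracing-backend string to a (backend, use_tfrt) tuple, which is why the tests now read the TFRT flag with std::get<1>(GetParam()). A minimal sketch of the fixture this implies (the fixture declaration itself is outside this hunk, so its exact form here is an assumption):

class UnifiedCAPI
    : public ::testing::TestWithParam<std::tuple<const char*, bool>> {};
// std::get<0>(GetParam()) selects the tracing backend ("graphdef" or "mlir");
// std::get<1>(GetParam()) selects whether the TFRT runtime is used.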
@ -12,17 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_

#include <memory>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
@ -34,16 +36,8 @@ namespace tensorflow {
//
// A context is responsible for creating key objects such as Tensors,
// TensorHandles & Operations.
class AbstractContextInterface {
class ImmediateExecutionContext : public AbstractContext {
 public:
  // Release any underlying resources, including the interface object.
  //
  // WARNING: The destructor of this class is marked as protected to disallow
  // clients from directly destroying this object since it may manage its own
  // lifetime through ref counting. Thus clients MUST call Release() in order to
  // destroy an instance of this class.
  virtual void Release() = 0;

  // Optimized scalar creation functions
  virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0;
  virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
@ -74,21 +68,20 @@ class AbstractContextInterface {
                void* memory_releaser_arg) = 0;

  // Create a handle to wrap and manage a Tensor
  virtual AbstractTensorHandleInterface* CreateLocalHandle(
  virtual ImmediateExecutionTensorHandle* CreateLocalHandle(
      AbstractTensorInterface* t) = 0;
  // Copy the handle to another device.
  virtual AbstractTensorHandleInterface* CopyTensorHandleToDevice(
      AbstractTensorHandleInterface* handle, const char* device_name,
  virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
      ImmediateExecutionTensorHandle* handle, const char* device_name,
      Status* status) = 0;

  // Create an operation to perform op execution
  virtual AbstractOperationInterface* CreateOperation() = 0;
  ImmediateExecutionOperation* CreateOperation() override = 0;

  // Load a SavedModelAPI object from the given directory and tags
  virtual std::unique_ptr<SavedModelAPI> LoadSavedModelAPI(
      const std::string& directory,
      const absl::optional<std::unordered_set<std::string>>& tags,
      tensorflow::Status* status) = 0;
  // Returns whether the runtime is backed by TFRT or the legacy TF Eager
  // Runtime. This is necessary to decouple runtime-dependent
  // code that is layered on top of the runtime.
  virtual bool UsesTFRT() = 0;

  // List attributes of available devices
  virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
@ -108,14 +101,32 @@ class AbstractContextInterface {
  // be executed as an op. Return error if the function with the same name
  // already exists.
  virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
  // Remove a function. 'func' argument is the name of a previously added
  // FunctionDef. The name is in fdef.signature.name.
  virtual Status RemoveFunction(const string& func) = 0;

  // For LLVM style RTTI.
  static bool classof(const AbstractContext* ptr) {
    return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
  }

 protected:
  virtual ~AbstractContextInterface() {}
  explicit ImmediateExecutionContext(AbstractContextKind kind)
      : AbstractContext(kind) {}
  ~ImmediateExecutionContext() override {}
};

namespace internal {
struct ImmediateExecutionContextDeleter {
  void operator()(ImmediateExecutionContext* p) const {
    if (p != nullptr) {
      p->Release();
    }
  }
};
} // namespace internal

using ImmediateContextPtr =
    std::unique_ptr<ImmediateExecutionContext,
                    internal::ImmediateExecutionContextDeleter>;

} // namespace tensorflow

#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
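Two conventions recur in this header and the two that follow: an LLVM-style classof() hook for safe downcasting, and a custom deleter so smart pointers destroy the object through Release() rather than delete. A minimal sketch of how a caller typically consumes them (the dyn_cast helper below is illustrative, not part of the patch):

// Hypothetical helper: classof() is exactly the hook an LLVM-style
// dyn_cast needs to test a base pointer's dynamic kind without C++ RTTI.
template <typename To, typename From>
To* dyn_cast_sketch(From* ptr) {
  return To::classof(ptr) ? static_cast<To*>(ptr) : nullptr;
}

// ImmediateContextPtr invokes Release() on destruction, so implementations
// that manage their own lifetime via ref counting are never deleted directly:
// ImmediateContextPtr owned(dyn_cast_sketch<ImmediateExecutionContext>(abstract_ctx));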
74
tensorflow/c/eager/immediate_execution_operation.h
Normal file
@ -0,0 +1,74 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_

#include <memory>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/status.h"

struct TFE_Op;

namespace tensorflow {

// Abstract interface to an operation.
class ImmediateExecutionOperation : public AbstractOperation {
 public:
  virtual void Clear() = 0;

  virtual const tensorflow::OpDef* OpDef() const = 0;

  virtual Status InputLength(const char* input_name, int* length) = 0;
  virtual Status OutputLength(const char* output_name, int* length) = 0;

  // Experimental
  virtual Status SetUseXla(bool enable) = 0;

  // For LLVM style RTTI.
  static bool classof(const AbstractOperation* ptr) {
    return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
  }

 protected:
  explicit ImmediateExecutionOperation(AbstractOperationKind kind)
      : AbstractOperation(kind) {}
  ~ImmediateExecutionOperation() override {}
};

namespace internal {
struct ImmediateExecutionOperationDeleter {
  void operator()(ImmediateExecutionOperation* p) const {
    if (p != nullptr) {
      p->Release();
    }
  }
};
} // namespace internal

using ImmediateOpPtr =
    std::unique_ptr<ImmediateExecutionOperation,
                    internal::ImmediateExecutionOperationDeleter>;

} // namespace tensorflow

#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_

#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
@ -30,16 +31,8 @@ namespace tensorflow {
// files. The interface lists the common functionality that must be provided by
// any concrete implementation. However, in cases where the true concrete class
// is needed a static_cast can be applied.
class AbstractTensorHandleInterface {
class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
 public:
  // Release any underlying resources, including the interface object.
  //
  // WARNING: The destructor of this class is marked as protected to disallow
  // clients from directly destroying this object since it may manage its own
  // lifetime through ref counting. Thus this must be allocated on the heap and
  // clients MUST call Release() in order to destroy an instance of this class.
  virtual void Release() = 0;

  // Returns tensor dtype.
  virtual tensorflow::DataType DataType() const = 0;
  // Returns number of dimensions.
@ -57,12 +50,33 @@ class AbstractTensorHandleInterface {
  virtual AbstractTensorInterface* Resolve(Status* status) = 0;

  // Return a copy of the handle.
  virtual AbstractTensorHandleInterface* Copy() = 0;
  virtual ImmediateExecutionTensorHandle* Copy() = 0;

  // For LLVM style RTTI.
  static bool classof(const AbstractTensorHandle* ptr) {
    return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
  }

 protected:
  virtual ~AbstractTensorHandleInterface() {}
  explicit ImmediateExecutionTensorHandle(AbstractTensorHandleKind kind)
      : AbstractTensorHandle(kind) {}
  ~ImmediateExecutionTensorHandle() override {}
};

namespace internal {
struct ImmediateExecutionTensorHandleDeleter {
  void operator()(ImmediateExecutionTensorHandle* p) const {
    if (p != nullptr) {
      p->Release();
    }
  }
};
} // namespace internal

using ImmediateTensorHandlePtr =
    std::unique_ptr<ImmediateExecutionTensorHandle,
                    internal::ImmediateExecutionTensorHandleDeleter>;

} // namespace tensorflow

#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
@ -40,6 +40,9 @@ using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
using MaybeParallelTensorOwned =
    absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;

using MaybeParallelTensorUnowned =
    absl::variant<ParallelTensor*, TFE_TensorHandle*>;

// A ParallelDevice on its own is not registered with a TFE_Context, and so has
// no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
// name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
@ -141,9 +144,32 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
    result.emplace(std::move(result_content));
    return result;
  }
  std::vector<ParallelTensor*> parallel_inputs;
  std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
  parallel_inputs.reserve(inputs.size());
  implicitly_broadcast_tensors.reserve(inputs.size());  // not tight
  for (const auto& input : inputs) {
    if (absl::holds_alternative<TFE_TensorHandle*>(input)) {
      // Non-parallel tensors are implicitly broadcast, i.e. set as the input
      // to each parallel operation.
      //
      // TODO(allenl): There may be smarter ways to do this copy in some
      // cases, i.e. with a collective broadcast. We'll need to be careful
      // about things that are taken as inputs on the host or on their
      // existing device (for multi-device functions).
      std::unique_ptr<ParallelTensor> parallel_tensor(
          parallel_device.CopyToParallelDevice(
              context, absl::get<TFE_TensorHandle*>(input), status));
      if (TF_GetCode(status) != TF_OK) return result;
      parallel_inputs.push_back(parallel_tensor.get());
      implicitly_broadcast_tensors.emplace_back(std::move(parallel_tensor));
    } else {
      parallel_inputs.push_back(absl::get<ParallelTensor*>(input));
    }
  }
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
      maybe_parallel_results(
          parallel_device.Execute(context, std::move(inputs), operation_name,
          parallel_device.Execute(context, parallel_inputs, operation_name,
                                  attributes, expected_max_outputs, status));
  if (!maybe_parallel_results.has_value()) return result;
  std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace parallel_device {
@ -28,21 +30,216 @@ class OpDeleter {

using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;

// Creates a vector of `count` new executors (threads).
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
  std::vector<ExecutorPtr> executors;
  executors.reserve(count);
  for (int i = 0; i < count; ++i) {
    executors.emplace_back(TFE_NewExecutor(true /* is_async */));
class StatusDeleter {
 public:
  void operator()(TF_Status* to_delete) const { TF_DeleteStatus(to_delete); }
};

using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;

class ExecutorDeleter {
 public:
  void operator()(TFE_Executor* to_delete) const {
    TFE_DeleteExecutor(to_delete);
  }
  return executors;
}
};

using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;

} // namespace

// Allows a single op at a time to be launched without blocking.
//
// DeviceThread itself is thread-safe, in that StartExecute will block if there
// is a pending execution. Since StartExecute is equivalent to grabbing a lock,
// multiple DeviceThreads should always be accessed in the same order to avoid
// deadlocks.
class DeviceThread {
 public:
  // Starts a background thread waiting for `StartExecute`.
  explicit DeviceThread(const std::string& device)
      : status_(TF_NewStatus()),
        device_(device),
        // If the context's default executor is set to async, re-using that in
        // each thread would cause collectives to deadlock. For consistency we
        // create a new sync executor for every thread.
        //
        // TODO(allenl): We should have an async API that works with the
        // parallel device.
        executor_(TFE_NewExecutor(/*is_async=*/false)),
        op_(nullptr),
        thread_(tensorflow::Env::Default()->StartThread(
            tensorflow::ThreadOptions(), "parallel_device_execute",
            std::bind(&DeviceThread::Run, this))) {}
  ~DeviceThread();

  // Requests that the worker thread execute the specified operation. Blocks
  // until the previously pending operation (a StartExecute without a Join) has
  // finished, if any.
  void StartExecute(TFE_Context* context, const char* operation_name,
                    std::vector<TFE_TensorHandle*> inputs,
                    const TFE_OpAttrs* attributes, int expected_max_outputs);
  // Block until the previous `StartExecute` operation has executed. Forwards
  // the status from `TFE_Execute` and returns outputs if the status is OK.
  std::vector<TensorHandlePtr> Join(TF_Status* status);

 private:
  void Run();

  void Execute(TFE_Context* context, const char* operation_name,
               std::vector<TFE_TensorHandle*> inputs,
               const TFE_OpAttrs* attributes, int expected_max_outputs,
               std::vector<TensorHandlePtr>* outputs, TF_Status* status) const
      TF_EXCLUSIVE_LOCKS_REQUIRED(execution_mutex_);

  enum class ExecutionState {
    kReadyToExecute,
    kHasResult,
    kIdle,
    kShuttingDown,
  };

  tensorflow::mutex execution_mutex_;
  ExecutionState execution_state_ TF_GUARDED_BY(execution_mutex_) =
      ExecutionState::kIdle;
  // Tells the worker thread that there is new work.
  tensorflow::condition_variable start_execute_;
  // The worker thread notifies that work has finished.
  tensorflow::condition_variable finished_execute_;
  // Notifies a StartExecute that the previous Join has finished.
  tensorflow::condition_variable finished_join_;

  // Temporary state between `StartExecute` and `Join`.
  // Inputs
  TFE_Context* context_ TF_GUARDED_BY(execution_mutex_);
  const char* operation_name_ TF_GUARDED_BY(execution_mutex_);
  std::vector<TFE_TensorHandle*> op_inputs_ TF_GUARDED_BY(execution_mutex_);
  const TFE_OpAttrs* attributes_ TF_GUARDED_BY(execution_mutex_);
  int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
  // Outputs
  std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
  StatusPtr status_ TF_GUARDED_BY(execution_mutex_);

  const std::string device_;
  ExecutorPtr executor_ TF_GUARDED_BY(execution_mutex_);
  mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
  std::unique_ptr<Thread> thread_;
};

DeviceThread::~DeviceThread() {
  {
    tensorflow::mutex_lock l(execution_mutex_);
    execution_state_ = ExecutionState::kShuttingDown;
  }
  start_execute_.notify_one();
}

void DeviceThread::Run() {
  while (true) {
    {
      tensorflow::mutex_lock l(execution_mutex_);
      while (execution_state_ == ExecutionState::kIdle ||
             execution_state_ == ExecutionState::kHasResult) {
        start_execute_.wait(l);
      }
      if (execution_state_ == ExecutionState::kShuttingDown) {
        return;
      } else if (execution_state_ == ExecutionState::kReadyToExecute) {
        // op_outputs_ may have been std::moved
        op_outputs_ = std::vector<TensorHandlePtr>();
        Execute(context_, operation_name_, std::move(op_inputs_), attributes_,
                expected_max_outputs_, &op_outputs_, status_.get());
        execution_state_ = ExecutionState::kHasResult;
      }
    }
    finished_execute_.notify_one();
  }
}

void DeviceThread::StartExecute(TFE_Context* context,
                                const char* operation_name,
                                std::vector<TFE_TensorHandle*> inputs,
                                const TFE_OpAttrs* attributes,
                                int expected_max_outputs) {
  {
    tensorflow::mutex_lock l(execution_mutex_);
    while (execution_state_ != ExecutionState::kIdle) {
      // If there's already a pending execution, wait until Join finishes before
      // starting on the next operation.
      finished_join_.wait(l);
    }
    context_ = context;
    operation_name_ = operation_name;
    op_inputs_ = inputs;
    attributes_ = attributes;
    expected_max_outputs_ = expected_max_outputs;
    execution_state_ = ExecutionState::kReadyToExecute;
  }
  start_execute_.notify_one();
}

std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
  std::vector<TensorHandlePtr> result;
  {
    tensorflow::mutex_lock l(execution_mutex_);
    while (execution_state_ != ExecutionState::kHasResult) {
      finished_execute_.wait(l);
    }
    if (TF_GetCode(status_.get()) != TF_OK) {
      TF_SetStatus(status, TF_GetCode(status_.get()),
                   TF_Message(status_.get()));
    }
    execution_state_ = ExecutionState::kIdle;
    result = std::move(op_outputs_);
  }
  finished_join_.notify_one();
  return result;
}

void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
                           std::vector<TFE_TensorHandle*> inputs,
                           const TFE_OpAttrs* attributes,
                           int expected_max_outputs,
                           std::vector<TensorHandlePtr>* outputs,
                           TF_Status* status) const {
  if (op_ == nullptr) {
    TFE_ContextSetExecutorForThread(context, executor_.get());
    op_.reset(TFE_NewOp(context, operation_name, status));
    if (TF_GetCode(status) != TF_OK) return;
    TFE_OpSetDevice(op_.get(), device_.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return;
  } else {
    TFE_OpReset(op_.get(), operation_name, device_.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return;
  }
  TFE_OpAddAttrs(op_.get(), attributes);
  for (int input_index = 0; input_index < inputs.size(); ++input_index) {
    TFE_OpAddInput(op_.get(), inputs[input_index], status);
    if (TF_GetCode(status) != TF_OK) return;
  }
  std::vector<TFE_TensorHandle*> unwrapped_results(expected_max_outputs);
  int real_num_outputs = expected_max_outputs;
  if (TF_GetCode(status) != TF_OK) return;
  TFE_Execute(op_.get(), unwrapped_results.data(), &real_num_outputs, status);
  if (TF_GetCode(status) != TF_OK) return;
  unwrapped_results.resize(real_num_outputs);
  outputs->reserve(real_num_outputs);
  for (TFE_TensorHandle* unwrapped_result : unwrapped_results) {
    outputs->emplace_back(unwrapped_result);
  }
}

ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
    : underlying_devices_(devices),
      executors_(MakeExecutors(underlying_devices_.size())) {}
    : underlying_devices_(devices) {
  device_threads_.reserve(devices.size());
  for (int device_index = 0; device_index < devices.size(); ++device_index) {
    device_threads_.emplace_back(
        new DeviceThread(devices[device_index].c_str()));
  }
}

// Necessary for a unique_ptr to a forward-declared type.
ParallelDevice::~ParallelDevice() = default;

std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
    TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
@ -65,14 +262,14 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
  components.reserve(underlying_devices_.size());
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    int64_t* device_id = new int64_t;
    int32_t* device_id = new int32_t;
    *device_id = device_index;
    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
        TF_NewTensor(
            TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
            sizeof(int64_t),
            TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_id,
            sizeof(int32_t),
            [](void* data, size_t, void* arg) {
              delete reinterpret_cast<int64_t*>(data);
              delete reinterpret_cast<int32_t*>(data);
            },
            nullptr),
        TF_DeleteTensor);
@ -86,7 +283,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
    TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT32);
    TFE_TensorHandle* device_handle;
    int num_outputs = 1;
    TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
@ -100,7 +297,7 @@ std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(

absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
                        std::vector<MaybeParallelTensorUnowned> inputs,
                        const std::vector<ParallelTensor*>& inputs,
                        const char* operation_name,
                        const TFE_OpAttrs* attributes, int expected_max_outputs,
                        TF_Status* status) const {
@ -108,88 +305,34 @@ ParallelDevice::Execute(TFE_Context* context,
  // Compute per-device per-output tensors
  std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
  per_device_output_tensors.reserve(underlying_devices_.size());
  // TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
  // setting the thread-local executor like this.
  TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
  auto reset_executor =
      tensorflow::gtl::MakeCleanup([context, previous_executor]() {
        TFE_ContextSetExecutorForThread(context, previous_executor);
        TFE_DeleteExecutor(previous_executor);
      });
  int first_op_output_count;
  int first_op_output_count = 0;
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    TFE_Executor* executor = executors_[device_index].get();
    // Note that the `reset_executor` cleanup sets the thread's executor back to
    // the value before this function ran.
    TFE_ContextSetExecutorForThread(context, executor);
    OpPtr op(TFE_NewOp(context, operation_name, status));
    if (TF_GetCode(status) != TF_OK) return result;
    TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
                    status);
    TFE_OpAddAttrs(op.get(), attributes);
    DeviceThread* device_thread = device_threads_[device_index].get();
    std::vector<TFE_TensorHandle*> device_inputs;
    device_inputs.reserve(inputs.size());
    for (int input_index = 0; input_index < inputs.size(); ++input_index) {
      if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
        // Non-parallel tensors are implicitly broadcast, i.e. set as the input
        // to each parallel operation.
        //
        // TODO(allenl): There may be smarter ways to do this copy in some
        // cases, i.e. with a collective broadcast. We'll need to be careful
        // about things that are taken as inputs on the host or on their
        // existing device (for multi-device functions).
        TFE_OpAddInput(op.get(),
                       absl::get<TFE_TensorHandle*>(inputs[input_index]),
                       status);
        if (TF_GetCode(status) != TF_OK) return result;
      } else {
        // Parallel tensors are divided between operations by device.
        TFE_OpAddInput(op.get(),
                       absl::get<ParallelTensor*>(inputs[input_index])
                           ->tensor(device_index),
                       status);
        if (TF_GetCode(status) != TF_OK) return result;
      }
      // Parallel tensors are divided between operations by device.
      device_inputs.push_back(inputs[input_index]->tensor(device_index));
    }
    std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
    int real_num_outputs = expected_max_outputs;
    // For nested devices, the inner device sees the async executor we've
    // set. Inner parallel devices will just overwrite this with their own and
    // then set it back to ours before returning. This means parallel devices
    // which consist of several aliased parallel devices would hypothetically
    // deadlock if the outer parallel device ran one collective with a group
    // size equal to the total number of aliased physical devices. Currently
    // physical devices cannot participate in a single collective reduction
    // multiple times, so this would fail earlier.
    //
    // TODO(allenl): Keep a map from outer executor to list of inner executors
    // rather than a single list of executors so aliased nested parallel devices
    // don't re-use an executor.
    TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
    device_thread->StartExecute(context, operation_name,
                                std::move(device_inputs), attributes,
                                expected_max_outputs);
  }
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    DeviceThread* device_thread = device_threads_[device_index].get();
    per_device_output_tensors.push_back(device_thread->Join(status));
    if (TF_GetCode(status) != TF_OK) return result;
    if (device_index == 0) {
      first_op_output_count = real_num_outputs;
      first_op_output_count = per_device_output_tensors.rbegin()->size();
    } else {
      if (real_num_outputs != first_op_output_count) {
      if (per_device_output_tensors.rbegin()->size() != first_op_output_count) {
        TF_SetStatus(status, TF_INTERNAL,
                     "Parallel ops produced different numbers of tensors.");
        return result;
      }
    }
    if (TF_GetCode(status) != TF_OK) return result;
    std::vector<TensorHandlePtr> this_outputs;
    this_outputs.reserve(real_num_outputs);
    for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
      this_outputs.emplace_back(op_outputs[output_num]);
    }
    per_device_output_tensors.push_back(std::move(this_outputs));
  }
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    TFE_Executor* executor = executors_[device_index].get();
    // TODO(b/157523095): Syncing the executor here shouldn't be
    // necessary. Currently async+remote is missing cross-executor
    // coordination.
    TFE_ExecutorWaitForAllPendingNodes(executor, status);
    if (TF_GetCode(status) != TF_OK) return result;
  }
  // For each output of the original operation, pack the per-device
  // TensorHandles we've computed into a single parallel TensorHandle.
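The rewritten Execute above fans work out to one DeviceThread per component device and then joins them in device order. A condensed sketch of that control flow, with names taken from the code above and the argument lists elided:

// Fan out: launch one op per device without blocking.
for (int i = 0; i < underlying_devices_.size(); ++i) {
  device_threads_[i]->StartExecute(context, operation_name,
                                   /*inputs=*/..., attributes,
                                   expected_max_outputs);
}
// Fan in: Join in the same order the threads were started, which is what
// keeps two concurrent Execute calls from interleaving and deadlocking.
for (int i = 0; i < underlying_devices_.size(); ++i) {
  per_device_output_tensors.push_back(device_threads_[i]->Join(status));
}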
@ -41,19 +41,8 @@ class TensorHandleDeleter {

using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;

class ExecutorDeleter {
 public:
  void operator()(TFE_Executor* to_delete) const {
    TFE_DeleteExecutor(to_delete);
  }
};

using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;

class ParallelTensor;

using MaybeParallelTensorUnowned =
    absl::variant<ParallelTensor*, TFE_TensorHandle*>;
class DeviceThread;

// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.
@ -61,6 +50,8 @@ class ParallelDevice {
 public:
  explicit ParallelDevice(const std::vector<std::string>& devices);

  ~ParallelDevice();

  // Helper to copy a tensor handle from another device once for each component
  // of the ParallelDevice.
  //
@ -79,10 +70,9 @@ class ParallelDevice {

  // Takes a description of a single operation being executed on the
  // ParallelDevice, and in turn runs one operation per component device with
  // its corresponding inputs from the input ParallelTensors (or
  // implicitly-mirrored tensors on other devices). Wraps the resulting
  // per-device and per-output TFE_TensorHandles into one ParallelTensor per
  // output of the original operation.
  // its corresponding inputs from the input ParallelTensors. Wraps the
  // resulting per-device and per-output TFE_TensorHandles into one
  // ParallelTensor per output of the original operation.
  //
  // Attributes are forwarded to executed operations unmodified.
  //
@ -90,7 +80,7 @@ class ParallelDevice {
  // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
  // if sanity checks on dtypes/metadata fail.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
      TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
      TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
      const char* operation_name, const TFE_OpAttrs* attributes,
      int expected_max_outputs, TF_Status* status) const;

@ -98,9 +88,19 @@ class ParallelDevice {
  // A sequence of device names, indicating which devices replicated operations
  // are forwarded to.
  const std::vector<std::string> underlying_devices_;
  // A sequence of TFE_Executors, one per device, for executing operations in
  // A sequence of thread wrappers, one per device, for executing operations in
  // parallel.
  const std::vector<ExecutorPtr> executors_;
  //
  // Conceptually this is a thread pool with one thread per device. It requires
  // less synchronization than a thread pool would for this task, since Execute
  // acquires each thread in order (and so only one Execute will schedule
  // blocking collective operations at a time), and avoids some dynamic
  // allocation/scheduling.
  //
  // TODO(allenl): Keep a map from outer thread to list of inner threads rather
  // than a single list of threads so aliased nested parallel devices don't
  // re-use a thread.
  std::vector<std::unique_ptr<DeviceThread>> device_threads_;
};

// Contains a tuple of tensors, one on each of the `underlying_devices_` of the
@ -407,11 +407,12 @@ TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
  return TensorHandlePtr(result_handle);
}

TEST(PARALLEL_DEVICE, TestCollective) {
void TestCollective(bool async) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  TFE_ContextOptionsSetAsync(opts.get(), async);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
@ -454,6 +455,12 @@ TEST(PARALLEL_DEVICE, TestCollective) {
  ExpectScalarEq<float>(result_components[1].get(), 3.);
}

TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }

// Note that ops on the parallel device currently don't execute
// asynchronously. The test is just that we don't get deadlocks.
TEST(PARALLEL_DEVICE, TestCollectiveAsync) { TestCollective(/*async=*/true); }

void RegisterCollectiveMulFunction(TFE_Context* context,
                                   const char* function_name, int group_size,
                                   TF_Status* status) {
@ -296,8 +296,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
  TFE_DeleteTensorHandle(result_handle);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  ExpectScalarEq<int64_t>(components[0].get(), 0);
  ExpectScalarEq<int64_t>(components[1].get(), 1);
  ExpectScalarEq<int32_t>(components[0].get(), 0);
  ExpectScalarEq<int32_t>(components[1].get(), 1);
  std::string first_device =
      TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
  ASSERT_EQ(underlying_devices[0], first_device);
@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_

#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/eager/immediate_execution_context.h"

// Wraps a pointer to a context implementation.
//
@ -28,7 +28,7 @@ typedef struct TFE_Context TFE_Context;

namespace tensorflow {

DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionContext, TFE_Context);

} // namespace tensorflow

@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_

#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"

// Wraps a pointer to an operation implementation.
//
@ -28,8 +28,8 @@ typedef struct TFE_Op TFE_Op;

namespace tensorflow {

DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation, TFE_Op);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation*, TFE_Op*);

} // namespace tensorflow

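DEFINE_CONVERSION_FUNCTIONS is what lets the C API move between the opaque C handle and the C++ implementation type, so these three internal headers only need their typedefs retargeted. A sketch of the declarations the macro conceptually produces (this is an inference from its use here; see tensorflow/c/conversion_macros.h for the actual expansion):

// TFE_Op* wrap(tensorflow::ImmediateExecutionOperation* op);
// tensorflow::ImmediateExecutionOperation* unwrap(TFE_Op* op);
// ...with the pointer-to-pointer variant declared above presumably covering
// arrays of operations converted in bulk.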
@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_

#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"

// Wraps a pointer to a tensor handle implementation.
//
@ -28,9 +28,9 @@ typedef struct TFE_TensorHandle TFE_TensorHandle;

namespace tensorflow {

DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface,
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle,
                            TFE_TensorHandle);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*,
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle*,
                            TFE_TensorHandle*);

} // namespace tensorflow

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/types.h"

struct TF_StringStream {
@ -146,6 +147,10 @@ TF_StringStream* TF_GetLocalTempDirectories() {
  return list;
}

char* TF_GetTempFileName(const char* extension) {
  return strdup(::tensorflow::io::GetTempFilename(extension).c_str());
}

TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void) {
  return ::tensorflow::Env::Default()->NowNanos();
}
@ -152,6 +152,10 @@ TF_CAPI_EXPORT extern TF_StringStream* TF_GetChildren(const char* filename,
// The caller is responsible for freeing the list (see TF_StringStreamDone).
TF_CAPI_EXPORT extern TF_StringStream* TF_GetLocalTempDirectories(void);

// Creates a temporary file name with an extension.
// The caller is responsible for freeing the returned pointer.
TF_CAPI_EXPORT extern char* TF_GetTempFileName(const char* extension);

// Returns the number of nanoseconds since the Unix epoch.
TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void);

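Because the implementation returns strdup'd memory, the C-side contract for TF_GetTempFileName is malloc/free. A minimal usage sketch:

#include <stdlib.h>
#include "tensorflow/c/env.h"

char* tmp_name = TF_GetTempFileName(".bin");  // e.g. a path under /tmp
// ... open and write the file ...
free(tmp_name);  // strdup'd by the implementation, so free(), not delete[]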
@ -1,5 +1,5 @@
# Experimental gcs filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test")

package(
    licenses = ["notice"],  # Apache 2.0
@ -19,12 +19,45 @@ tf_cc_shared_object(
cc_library(
    name = "gcs_filesystem_impl",
    srcs = ["gcs_filesystem.cc"],
    hdrs = ["gcs_filesystem.h"],
    copts = select({
        "//conditions:default": [],
        "//tensorflow:windows": get_win_copts(),
    }),
    deps = [
        ":gcs_helper",
        "//tensorflow/c:env",
        "//tensorflow/c:tf_status",
        "//tensorflow/c/experimental/filesystem:filesystem_interface",
        "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "gcs_helper",
    srcs = ["gcs_helper.cc"],
    hdrs = ["gcs_helper.h"],
    linkstatic = 1,
    deps = [
        "//tensorflow/c:env",
    ],
)

tf_cc_test(
    name = "gcs_filesystem_test",
    srcs = [
        "gcs_filesystem_test.cc",
    ],
    tags = [
        "manual",
        "notap",
    ],
    deps = [
        ":gcs_filesystem_impl",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/core/platform:stacktrace_handler",
        "//tensorflow/core/platform:test",
        "@com_google_absl//absl/strings",
    ],
)
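The new test target is tagged "manual" and "notap", so it is excluded from automatic runs and has to be invoked explicitly, e.g. `bazel test //tensorflow/c/experimental/filesystem/plugins/gcs:gcs_filesystem_test` -- presumably because it talks to a real GCS backend, so exact flags and credentials will depend on the environment.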
@ -12,31 +12,237 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"

#include <stdlib.h>
#include <string.h>

#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "absl/strings/numbers.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include "tensorflow/c/tf_status.h"

// Implementation of a filesystem for GCS environments.
// This filesystem will support `gs://` URI schemes.
namespace gcs = google::cloud::storage;

// How to upload new data when Flush() is called multiple times.
// By default the entire file is reuploaded.
constexpr char kAppendMode[] = "GCS_APPEND_MODE";
// If GCS_APPEND_MODE=compose then instead the new data is uploaded to a
// temporary object and composed with the original object. This is disabled by
// default as the multiple API calls required add a risk of stranding temporary
// objects.
constexpr char kComposeAppend[] = "compose";

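A hedged usage sketch of the append-mode switch: Init() (further below) reads this environment variable once at filesystem construction, so it has to be set before the plugin is loaded.

// POSIX example; Init() compares getenv("GCS_APPEND_MODE") against "compose".
setenv("GCS_APPEND_MODE", "compose", /*overwrite=*/1);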
// We can cast `google::cloud::StatusCode` to `TF_Code` because they have the
// same integer values. See
// https://github.com/googleapis/google-cloud-cpp/blob/6c09cbfa0160bc046e5509b4dd2ab4b872648b4a/google/cloud/status.h#L32-L52
static inline void TF_SetStatusFromGCSStatus(
    const google::cloud::Status& gcs_status, TF_Status* status) {
  TF_SetStatus(status, static_cast<TF_Code>(gcs_status.code()),
               gcs_status.message().c_str());
}

static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }

void ParseGCSPath(const std::string& fname, bool object_empty_ok,
                  std::string* bucket, std::string* object, TF_Status* status) {
  size_t scheme_end = fname.find("://") + 2;
  if (fname.substr(0, scheme_end + 1) != "gs://") {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "GCS path doesn't start with 'gs://'.");
    return;
  }

  size_t bucket_end = fname.find("/", scheme_end + 1);
  if (bucket_end == std::string::npos) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "GCS path doesn't contain a bucket name.");
    return;
  }

  *bucket = fname.substr(scheme_end + 1, bucket_end - scheme_end - 1);
  *object = fname.substr(bucket_end + 1);

  if (object->empty() && !object_empty_ok) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "GCS path doesn't contain an object name.");
  }
}

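An illustration of ParseGCSPath's contract, using a made-up path:

std::string bucket, object;
TF_Status* s = TF_NewStatus();
ParseGCSPath("gs://my-bucket/dir/data.txt", /*object_empty_ok=*/false,
             &bucket, &object, s);
// On success: bucket == "my-bucket", object == "dir/data.txt".
// "gs://my-bucket/" (trailing slash, empty object) fails with
// TF_INVALID_ARGUMENT unless object_empty_ok is true.
TF_DeleteStatus(s);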
// SECTION 1. Implementation for `TF_RandomAccessFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_random_access_file {
|
||||
typedef struct GCSFile {
|
||||
const std::string bucket;
|
||||
const std::string object;
|
||||
gcs::Client* gcs_client; // not owned
|
||||
} GCSFile;
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
void Cleanup(TF_RandomAccessFile* file) {
|
||||
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
|
||||
delete gcs_file;
|
||||
}
|
||||
|
||||
// TODO(vnvo2409): Adding cache.
|
||||
// `google-cloud-cpp` is working on a feature that we may want to use.
|
||||
// See https://github.com/googleapis/google-cloud-cpp/issues/4013.
|
||||
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
|
||||
char* buffer, TF_Status* status) {
|
||||
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
|
||||
auto stream = gcs_file->gcs_client->ReadObject(
|
||||
gcs_file->bucket, gcs_file->object, gcs::ReadRange(offset, offset + n));
|
||||
TF_SetStatusFromGCSStatus(stream.status(), status);
|
||||
if ((TF_GetCode(status) != TF_OK) &&
|
||||
(TF_GetCode(status) != TF_OUT_OF_RANGE)) {
|
||||
return -1;
|
||||
}
|
||||
int64_t read;
|
||||
if (!absl::SimpleAtoi(stream.headers().find("content-length")->second,
|
||||
&read)) {
|
||||
TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
|
||||
return -1;
|
||||
}
|
||||
if (read != n) {
|
||||
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
|
||||
}
|
||||
stream.read(buffer, read);
|
||||
return read;
|
||||
}
|
||||
|
||||
} // namespace tf_random_access_file
|
||||
|
||||
// SECTION 2. Implementation for `TF_WritableFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_writable_file {
|
||||
typedef struct GCSFile {
|
||||
const std::string bucket;
|
||||
const std::string object;
|
||||
gcs::Client* gcs_client; // not owned
|
||||
TempFile outfile;
|
||||
bool sync_need;
|
||||
// `offset` tells us how many bytes of this file are already uploaded to
|
||||
// server. If `offset == -1`, we always upload the entire temporary file.
|
||||
int64_t offset;
|
||||
} GCSFile;
|
||||
|
||||
// TODO(vnvo2409): Implement later
|
||||
static void SyncImpl(const std::string& bucket, const std::string& object,
|
||||
int64_t* offset, TempFile* outfile,
|
||||
gcs::Client* gcs_client, TF_Status* status) {
|
||||
outfile->flush();
|
||||
// `*offset == 0` means this file does not exist on the server.
|
||||
if (*offset == -1 || *offset == 0) {
|
||||
// UploadFile will automatically switch to resumable upload based on Client
|
||||
// configuration.
|
||||
auto metadata = gcs_client->UploadFile(outfile->getName(), bucket, object);
|
||||
if (!metadata) {
|
||||
TF_SetStatusFromGCSStatus(metadata.status(), status);
|
||||
return;
|
||||
}
|
||||
if (*offset == 0) {
|
||||
if (!outfile->truncate()) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Could not truncate internal temporary file.");
|
||||
return;
|
||||
}
|
||||
*offset = static_cast<int64_t>(metadata->size());
|
||||
}
|
||||
outfile->clear();
|
||||
outfile->seekp(std::ios::end);
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
} else {
|
||||
std::string temporary_object =
|
||||
gcs::CreateRandomPrefixName("tf_writable_file_gcs");
|
||||
auto metadata =
|
||||
gcs_client->UploadFile(outfile->getName(), bucket, temporary_object);
|
||||
if (!metadata) {
|
||||
TF_SetStatusFromGCSStatus(metadata.status(), status);
|
||||
return;
|
||||
}
|
||||
const std::vector<gcs::ComposeSourceObject> source_objects = {
|
||||
{object, {}, {}}, {temporary_object, {}, {}}};
|
||||
metadata = gcs_client->ComposeObject(bucket, source_objects, object);
|
||||
if (!metadata) {
|
||||
TF_SetStatusFromGCSStatus(metadata.status(), status);
|
||||
return;
|
||||
}
|
||||
// We truncate the data that are already uploaded.
|
||||
if (!outfile->truncate()) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Could not truncate internal temporary file.");
|
||||
return;
|
||||
}
|
||||
*offset = static_cast<int64_t>(metadata->size());
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
}

void Cleanup(TF_WritableFile* file) {
  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
  delete gcs_file;
}

void Append(const TF_WritableFile* file, const char* buffer, size_t n,
            TF_Status* status) {
  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
  if (!gcs_file->outfile.is_open()) {
    TF_SetStatus(status, TF_FAILED_PRECONDITION,
                 "The internal temporary file is not writable.");
    return;
  }
  gcs_file->sync_need = true;
  gcs_file->outfile.write(buffer, n);
  if (!gcs_file->outfile)
    TF_SetStatus(status, TF_INTERNAL,
                 "Could not append to the internal temporary file.");
  else
    TF_SetStatus(status, TF_OK, "");
}

int64_t Tell(const TF_WritableFile* file, TF_Status* status) {
  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
  int64_t position = int64_t(gcs_file->outfile.tellp());
  if (position == -1)
    TF_SetStatus(status, TF_INTERNAL,
                 "tellp on the internal temporary file failed");
  else
    TF_SetStatus(status, TF_OK, "");
  return position == -1
             ? -1
             : position + (gcs_file->offset == -1 ? 0 : gcs_file->offset);
}

void Flush(const TF_WritableFile* file, TF_Status* status) {
  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
  if (gcs_file->sync_need) {
    if (!gcs_file->outfile) {
      TF_SetStatus(status, TF_INTERNAL,
                   "Could not append to the internal temporary file.");
      return;
    }
    SyncImpl(gcs_file->bucket, gcs_file->object, &gcs_file->offset,
             &gcs_file->outfile, gcs_file->gcs_client, status);
    if (TF_GetCode(status) != TF_OK) return;
    gcs_file->sync_need = false;
  } else {
    TF_SetStatus(status, TF_OK, "");
  }
}

void Sync(const TF_WritableFile* file, TF_Status* status) {
  Flush(file, status);
}

void Close(const TF_WritableFile* file, TF_Status* status) {
  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
  if (gcs_file->sync_need) {
    Flush(file, status);
  }
  gcs_file->outfile.close();
}

} // namespace tf_writable_file

@ -51,8 +257,109 @@ namespace tf_read_only_memory_region {

// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_gcs_filesystem {
typedef struct GCSFile {
  gcs::Client gcs_client;  // owned
  bool compose;
} GCSFile;

// TODO(vnvo2409): Add lazy-loading and customizing parameters.
void Init(TF_Filesystem* filesystem, TF_Status* status) {
  google::cloud::StatusOr<gcs::Client> client =
      gcs::Client::CreateDefaultClient();
  if (!client) {
    TF_SetStatusFromGCSStatus(client.status(), status);
    return;
  }

  const char* append_mode = std::getenv(kAppendMode);
  bool compose =
      (append_mode != nullptr) && (!strcmp(kAppendMode, append_mode));

  filesystem->plugin_filesystem =
      new GCSFile({std::move(client.value()), compose});
  TF_SetStatus(status, TF_OK, "");
}

void Cleanup(TF_Filesystem* filesystem) {
  auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
  delete gcs_file;
}

// TODO(vnvo2409): Implement later
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
                         TF_RandomAccessFile* file, TF_Status* status) {
  std::string bucket, object;
  ParseGCSPath(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;

  auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
  file->plugin_file = new tf_random_access_file::GCSFile(
      {std::move(bucket), std::move(object), &gcs_file->gcs_client});
  TF_SetStatus(status, TF_OK, "");
}

void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
                     TF_WritableFile* file, TF_Status* status) {
  std::string bucket, object;
  ParseGCSPath(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;

  auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
  char* temp_file_name = TF_GetTempFileName("");
  file->plugin_file = new tf_writable_file::GCSFile(
      {std::move(bucket), std::move(object), &gcs_file->gcs_client,
       TempFile(temp_file_name, std::ios::binary | std::ios::out), true,
       (gcs_file->compose ? 0 : -1)});
  // We are responsible for freeing the pointer returned by TF_GetTempFileName
  free(temp_file_name);
  TF_SetStatus(status, TF_OK, "");
}

void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
                       TF_WritableFile* file, TF_Status* status) {
  std::string bucket, object;
  ParseGCSPath(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;

  auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
  char* temp_file_name_c_str = TF_GetTempFileName("");
  std::string temp_file_name(temp_file_name_c_str);  // To prevent a memory leak
  free(temp_file_name_c_str);

  if (!gcs_file->compose) {
    auto gcs_status =
        gcs_file->gcs_client.DownloadToFile(bucket, object, temp_file_name);
    TF_SetStatusFromGCSStatus(gcs_status, status);
    auto status_code = TF_GetCode(status);
    if (status_code != TF_OK && status_code != TF_NOT_FOUND) return;
    // If this file does not exist on the server, we will need to sync it.
    bool sync_need = (status_code == TF_NOT_FOUND);
    file->plugin_file = new tf_writable_file::GCSFile(
        {std::move(bucket), std::move(object), &gcs_file->gcs_client,
         TempFile(temp_file_name, std::ios::binary | std::ios::app), sync_need,
         -1});
  } else {
    // If compose is true, we do not download anything.
    // Instead, we only check whether this file exists on the server.
    auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
    TF_SetStatusFromGCSStatus(metadata.status(), status);
    if (TF_GetCode(status) == TF_OK) {
      file->plugin_file = new tf_writable_file::GCSFile(
          {std::move(bucket), std::move(object), &gcs_file->gcs_client,
           TempFile(temp_file_name, std::ios::binary | std::ios::trunc), false,
           static_cast<int64_t>(metadata->size())});
    } else if (TF_GetCode(status) == TF_NOT_FOUND) {
      file->plugin_file = new tf_writable_file::GCSFile(
          {std::move(bucket), std::move(object), &gcs_file->gcs_client,
           TempFile(temp_file_name, std::ios::binary | std::ios::trunc), true,
           0});
    } else {
      return;
    }
  }

  TF_SetStatus(status, TF_OK, "");
}

} // namespace tf_gcs_filesystem
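
For orientation, a minimal sketch (not part of the commit) of how a caller
could exercise the writable-file path above; it assumes the declarations in
this file are visible and that `path` is a gs://bucket/object URI:

// Hypothetical smoke test: write "hello" to GCS through the plugin layer.
static void WriteHelloToGCS(const char* path) {
  TF_Status* status = TF_NewStatus();
  TF_Filesystem filesystem;
  tf_gcs_filesystem::Init(&filesystem, status);
  if (TF_GetCode(status) == TF_OK) {
    TF_WritableFile file;
    tf_gcs_filesystem::NewWritableFile(&filesystem, path, &file, status);
    if (TF_GetCode(status) == TF_OK) {
      tf_writable_file::Append(&file, "hello", 5, status);
      if (TF_GetCode(status) == TF_OK) tf_writable_file::Sync(&file, status);
      tf_writable_file::Close(&file, status);
      tf_writable_file::Cleanup(&file);
    }
    tf_gcs_filesystem::Cleanup(&filesystem);
  }
  TF_DeleteStatus(status);
}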

@ -60,6 +367,25 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
                                        const char* uri) {
  TF_SetFilesystemVersionMetadata(ops);
  ops->scheme = strdup(uri);

  ops->random_access_file_ops = static_cast<TF_RandomAccessFileOps*>(
      plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE));
  ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup;
  ops->random_access_file_ops->read = tf_random_access_file::Read;

  ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
      plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
  ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;

  ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
      plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
  ops->filesystem_ops->init = tf_gcs_filesystem::Init;
  ops->filesystem_ops->cleanup = tf_gcs_filesystem::Cleanup;
  ops->filesystem_ops->new_random_access_file =
      tf_gcs_filesystem::NewRandomAccessFile;
  ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
  ops->filesystem_ops->new_appendable_file =
      tf_gcs_filesystem::NewAppendableFile;
}

void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
@ -69,4 +395,4 @@ void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
  info->ops = static_cast<TF_FilesystemPluginOps*>(
      plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
  ProvideFilesystemSupportFor(&info->ops[0], "gs");
}
}

@ -0,0 +1,42 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_

#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"

void ParseGCSPath(const std::string& fname, bool object_empty_ok,
                  std::string* bucket, std::string* object, TF_Status* status);

namespace tf_random_access_file {
void Cleanup(TF_RandomAccessFile* file);
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
             char* buffer, TF_Status* status);
} // namespace tf_random_access_file

namespace tf_gcs_filesystem {
void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
                         TF_RandomAccessFile* file, TF_Status* status);
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
                     TF_WritableFile* file, TF_Status* status);
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
                       TF_WritableFile* file, TF_Status* status);
} // namespace tf_gcs_filesystem

#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_

@ -0,0 +1,199 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"

#include <random>

#include "absl/strings/string_view.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/stacktrace_handler.h"
#include "tensorflow/core/platform/test.h"

#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x)

static const char* content = "abcdefghijklmnopqrstuvwxyz1234567890";
// We will work with content_view instead of content.
static const absl::string_view content_view = content;

namespace gcs = google::cloud::storage;

namespace tensorflow {
namespace {

class GCSFilesystemTest : public ::testing::Test {
 public:
  void SetUp() override {
    root_dir_ = io::JoinPath(
        tmp_dir_,
        ::testing::UnitTest::GetInstance()->current_test_info()->name());
    status_ = TF_NewStatus();
    filesystem_ = new TF_Filesystem;
    tf_gcs_filesystem::Init(filesystem_, status_);
    ASSERT_TF_OK(status_) << "Could not initialize filesystem. "
                          << TF_Message(status_);
  }
  void TearDown() override {
    TF_DeleteStatus(status_);
    tf_gcs_filesystem::Cleanup(filesystem_);
    delete filesystem_;
  }

  static bool InitializeTmpDir() {
    // This env should be something like `gs://bucket/path`
    const char* test_dir = getenv("GCS_TEST_TMPDIR");
    if (test_dir != nullptr) {
      std::string bucket, object;
      TF_Status* status = TF_NewStatus();
      ParseGCSPath(test_dir, true, &bucket, &object, status);
      if (TF_GetCode(status) != TF_OK) {
        TF_DeleteStatus(status);
        return false;
      }
      TF_DeleteStatus(status);

      // We add a random value to `test_dir` to ensure that two consecutive
      // runs are unlikely to clash.
      std::random_device rd;
      std::mt19937 gen(rd());
      std::uniform_int_distribution<> distribution;
      std::string rng_val = std::to_string(distribution(gen));
      tmp_dir_ = io::JoinPath(string(test_dir), rng_val);
      return true;
    } else {
      return false;
    }
  }

  std::string GetURIForPath(absl::string_view path) {
    const std::string translated_name =
        tensorflow::io::JoinPath(root_dir_, path);
    return translated_name;
  }

 protected:
  TF_Filesystem* filesystem_;
  TF_Status* status_;

 private:
  std::string root_dir_;
  static std::string tmp_dir_;
};
std::string GCSFilesystemTest::tmp_dir_;

::testing::AssertionResult WriteToServer(const std::string& path, size_t length,
                                         gcs::Client* gcs_client,
                                         TF_Status* status) {
  std::string bucket, object;
  ParseGCSPath(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) {
    return ::testing::AssertionFailure() << TF_Message(status);
  }

  auto writer = gcs_client->WriteObject(bucket, object);
  writer.write(content, length);
  writer.Close();
  if (writer.metadata()) {
    return ::testing::AssertionSuccess();
  } else {
    return ::testing::AssertionFailure()
           << writer.metadata().status().message();
  }
}

::testing::AssertionResult CompareSubString(int64_t offset, size_t n,
                                            absl::string_view result,
                                            size_t read) {
  // `result` isn't a null-terminated string, so we have to wrap it inside a
  // `string_view`.
  if (n == read && content_view.substr(offset, n) ==
                       absl::string_view(result).substr(0, read)) {
    return ::testing::AssertionSuccess();
  } else {
    return ::testing::AssertionFailure()
           << "Result: " << absl::string_view(result).substr(0, read)
           << " Read: " << read;
  }
}

TEST_F(GCSFilesystemTest, ParseGCSPath) {
  std::string bucket, object;
  ParseGCSPath("gs://bucket/path/to/object", false, &bucket, &object, status_);
  ASSERT_TF_OK(status_);
  ASSERT_EQ(bucket, "bucket");
  ASSERT_EQ(object, "path/to/object");

  ParseGCSPath("gs://bucket/", true, &bucket, &object, status_);
  ASSERT_TF_OK(status_);
  ASSERT_EQ(bucket, "bucket");

  ParseGCSPath("bucket/path/to/object", false, &bucket, &object, status_);
  ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);

  // The bucket name must end with "/"
  ParseGCSPath("gs://bucket", true, &bucket, &object, status_);
  ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);

  ParseGCSPath("gs://bucket/", false, &bucket, &object, status_);
  ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);
}

TEST_F(GCSFilesystemTest, RandomAccessFile) {
  std::string filepath = GetURIForPath("a_file");
  TF_RandomAccessFile* file = new TF_RandomAccessFile;
  tf_gcs_filesystem::NewRandomAccessFile(filesystem_, filepath.c_str(), file,
                                         status_);
  ASSERT_TF_OK(status_);
  char* result = new char[content_view.length()];
  int64_t read = tf_random_access_file::Read(file, 0, 1, result, status_);
  ASSERT_EQ(read, -1) << "Read: " << read;
  ASSERT_EQ(TF_GetCode(status_), TF_NOT_FOUND) << TF_Message(status_);
  TF_SetStatus(status_, TF_OK, "");

  auto gcs_client = static_cast<gcs::Client*>(filesystem_->plugin_filesystem);
  ASSERT_TRUE(
      WriteToServer(filepath, content_view.length(), gcs_client, status_));

  read = tf_random_access_file::Read(file, 0, content_view.length(), result,
                                     status_);
  ASSERT_TF_OK(status_);
  ASSERT_TRUE(CompareSubString(0, content_view.length(), result, read));

  read = tf_random_access_file::Read(file, 0, 4, result, status_);
  ASSERT_TF_OK(status_);
  ASSERT_TRUE(CompareSubString(0, 4, result, read));

  read = tf_random_access_file::Read(file, content_view.length() - 2, 4, result,
                                     status_);
  ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_);
  ASSERT_TRUE(CompareSubString(content_view.length() - 2, 2, result, read));

  delete[] result;
  tf_random_access_file::Cleanup(file);
  delete file;
}

} // namespace
} // namespace tensorflow

GTEST_API_ int main(int argc, char** argv) {
  tensorflow::testing::InstallStacktraceHandler();
  if (!tensorflow::GCSFilesystemTest::InitializeTmpDir()) {
    std::cerr << "Could not read GCS_TEST_TMPDIR env";
    return -1;
  }
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"

#include <stdio.h>

#include <fstream>
#include <string>
#include <utility>

TempFile::TempFile(const std::string& temp_file_name, std::ios::openmode mode)
    : std::fstream(temp_file_name, mode), name_(temp_file_name) {}

TempFile::TempFile(TempFile&& rhs)
    : std::fstream(std::move(rhs)), name_(std::move(rhs.name_)) {}

TempFile::~TempFile() {
  std::fstream::close();
  std::remove(name_.c_str());
}

const std::string TempFile::getName() const { return name_; }

bool TempFile::truncate() {
  std::fstream::close();
  std::fstream::open(name_, std::ios::binary | std::ios::out);
  return std::fstream::is_open();
}
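
A short usage sketch (not part of the commit), relying only on the TempFile
API above plus the std::fstream interface it inherits; the path is a
hypothetical example:

#include <iostream>

void TempFileDemo() {
  TempFile tmp("/tmp/tf_gcs_helper_demo", std::ios::binary | std::ios::out);
  tmp << "scratch data";  // TempFile is an std::fstream, so << works.
  tmp.flush();
  std::cout << "buffered in " << tmp.getName() << std::endl;
  // truncate() reopens the file in out mode, discarding prior contents.
  if (!tmp.truncate()) std::cerr << "could not truncate" << std::endl;
}  // The destructor closes the stream and removes the file from disk.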

@ -12,21 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <cstdlib>
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_

// This is a demo fuzzer to test that the entire framework functions correctly.
// Once we start moving the existing fuzzers to this framework we will delete
// this.
// TODO(mihaimaruseac): Delete this when no longer needed
void DemoFuzzer(const uint8_t* data, size_t size) {
  // Trigger a small bug that should be found by the fuzzer quite quickly
  if (size > 10 && size % 3 == 2)
    if (data[0] > data[1])
      if (data[5] % data[2] == data[3]) abort();
}
#include <fstream>
#include <string>

extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
  DemoFuzzer(data, size);
  return 0;
}
class TempFile : public std::fstream {
 public:
  // We should specify an openmode each time we construct a TempFile.
  TempFile(const std::string& temp_file_name, std::ios::openmode mode);
  TempFile(TempFile&& rhs);
  ~TempFile() override;
  const std::string getName() const;
  bool truncate();

 private:
  const std::string name_;
};

#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_

@ -3,6 +3,10 @@
# Targets in this directory are pure C++ "Classes" underlying the C API types
# under tf/c/experimental/saved_model/public/. They are subject to change and
# have visibility limited to Tensorflow's implementation only.
load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_test",
)

package(
    default_visibility = [
@ -23,8 +27,8 @@ cc_library(
    ],
    deps = [
        ":function_metadata",
        "//tensorflow/c/eager:operation_interface",
        "//tensorflow/c/eager:tensor_handle_interface",
        "//tensorflow/c/eager:immediate_execution_operation",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/core:protos_all_cc",
    ],
)
@ -47,6 +51,46 @@ cc_library(
    ],
)

cc_library(
    name = "saved_model_utils",
    srcs = [
        "saved_model_utils.cc",
    ],
    hdrs = [
        "saved_model_utils.h",
    ],
    deps = [
        "//tensorflow/c:tf_tensor_internal",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/experimental/saved_model/core/revived_types:constant",
        "//tensorflow/c/experimental/saved_model/core/revived_types:variable",
        "//tensorflow/core:protos_all_cc",
    ],
)

cc_library(
    name = "test_utils",
    testonly = True,
    srcs = [
        "test_utils.cc",
    ],
    hdrs = [
        "test_utils.h",
    ],
    deps = [
        "//tensorflow/c:tensor_interface",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core/common_runtime:core_cpu_lib",
        "//tensorflow/core/common_runtime/eager:context",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "tf_saved_model_impl",
    srcs = [
@ -57,6 +101,7 @@ cc_library(
        ":concrete_function",
        ":saved_model_api",
        "//tensorflow/core:lib",
        "//tensorflow/core/common_runtime/eager:context",
        "@com_google_absl//absl/types:optional",
    ],
)
@ -83,3 +128,47 @@ filegroup(
    ],
    visibility = ["//tensorflow/core:__pkg__"],
)

tf_cc_test(
    name = "constant_loading_test",
    srcs = [
        "constant_loading_test.cc",
    ],
    deps = [
        ":saved_model_utils",
        ":test_utils",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/c/experimental/saved_model/core/revived_types:constant",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/common_runtime:core_cpu_lib",
        "//tensorflow/core/common_runtime/eager:context",
        "//tensorflow/core/common_runtime/eager:core",
    ],
)

tf_cc_test(
    name = "saved_variable_loading_test",
    srcs = [
        "saved_variable_loading_test.cc",
    ],
    deps = [
        ":saved_model_utils",
        ":test_utils",
        "//tensorflow/c:tensor_interface",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/c/experimental/saved_model/core/revived_types:constant",
        "//tensorflow/core:all_kernels",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/common_runtime:core_cpu_lib",
        "//tensorflow/core/common_runtime/eager:context",
        "//tensorflow/core/common_runtime/eager:core",
    ],
)

@ -15,12 +15,12 @@ limitations under the License.

#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"

#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"

namespace tensorflow {

const std::vector<tensorflow::AbstractTensorHandleInterface*>&
const std::vector<tensorflow::ImmediateExecutionTensorHandle*>&
ConcreteFunction::GetCaptures() const {
  return captures_;
}

@ -18,8 +18,8 @@ limitations under the License.

#include <vector>

#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/core/framework/function.pb.h"

@ -38,15 +38,15 @@ class ConcreteFunction {
  virtual ~ConcreteFunction() = 0;

  // This method returns the "Call" Op used to execute the function.
  virtual AbstractOperationInterface* GetCallOp() = 0;
  virtual ImmediateExecutionOperation* GetCallOp() = 0;

  const std::vector<tensorflow::AbstractTensorHandleInterface*>& GetCaptures()
  const std::vector<tensorflow::ImmediateExecutionTensorHandle*>& GetCaptures()
      const;
  const FunctionMetadata& GetFunctionMetadata() const;

 private:
  FunctionMetadata metadata_;
  std::vector<tensorflow::AbstractTensorHandleInterface*> captures_;
  std::vector<tensorflow::ImmediateExecutionTensorHandle*> captures_;
  FunctionDef* function_;
};

@ -0,0 +1,111 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <vector>

#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

class ConstantTest : public ::testing::TestWithParam<
                         std::tuple<DataType, std::vector<int64>, bool>> {
 public:
  ConstantTest()
      : device_mgr_(testing::CreateTestingDeviceMgr()),
        ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}

  EagerContext* context() { return ctx_.get(); }

 private:
  std::unique_ptr<StaticDeviceMgr> device_mgr_;
  EagerContextPtr ctx_;
};

// Basic sanity check that roundtripping a Tensor->Tensorproto->Constant
// preserves values.
TEST_P(ConstantTest, CreateConstantSuccessful) {
  // Get test parameters
  auto& test_params = GetParam();
  DataType dtype = std::get<0>(test_params);
  TensorShape shape(std::get<1>(test_params));
  bool tensorproto_use_tensor_content = std::get<2>(test_params);

  // Construct a Tensor with the given dtype + shape
  Tensor expected(dtype, shape);
  testing::FillNumericTensorBuffer(expected.dtype(), expected.NumElements(),
                                   expected.data(), 42);

  // Serialize it to a Tensorproto
  TensorProto proto;
  if (tensorproto_use_tensor_content) {
    expected.AsProtoTensorContent(&proto);
  } else {
    expected.AsProtoField(&proto);
  }

  // Revival should succeed w/o errors
  std::unique_ptr<Constant> revived;
  TF_EXPECT_OK(internal::TensorProtoToConstant(context(), proto, &revived));

  // The revived tensorhandle should have the exact same dtype, shape, +
  // approx equivalent data to the original.
  ImmediateExecutionTensorHandle* handle = revived->handle();
  Status status;
  AbstractTensorPtr revived_tensor(handle->Resolve(&status));
  TF_EXPECT_OK(status) << "Failed to convert tensorhandle to tensor";
  EXPECT_EQ(revived_tensor->Type(), expected.dtype());
  EXPECT_EQ(revived_tensor->NumElements(), expected.NumElements());
  EXPECT_EQ(revived_tensor->NumDims(), expected.dims());
  for (int i = 0; i < expected.dims(); ++i) {
    EXPECT_EQ(revived_tensor->Dim(i), expected.dim_size(i));
  }

  testing::CheckBufferDataIsEqual(expected.dtype(), expected.NumElements(),
                                  revived_tensor->Data(), expected.data());
}

// Test against combinations of tensors that are
// 1. Varying dtypes
// 2. Varying shapes
// 3. TensorProto serialized using tensor_content vs repeated type
INSTANTIATE_TEST_SUITE_P(
    ConstantIntegerDtypesTest, ConstantTest,
    ::testing::Combine(
        ::testing::ValuesIn(testing::DataTypeSetToVector(kDataTypeIsInteger)),
        ::testing::ValuesIn(testing::InterestingShapes()),
        ::testing::Values(false, true)));

INSTANTIATE_TEST_SUITE_P(
    ConstantFloatingDtypesTest, ConstantTest,
    ::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
                       ::testing::ValuesIn(testing::InterestingShapes()),
                       ::testing::Values(false, true)));

} // namespace
} // namespace tensorflow

tensorflow/c/experimental/saved_model/core/ops/BUILD
@ -0,0 +1,60 @@
# This package contains written convenience helpers for Eager Operations
# used by SavedModel. Once we autogenerate C++ Eager Op wrappers, we can remove these.
load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_test",
)

package(
    default_visibility = [
        # Restricting visibility for now
        "//tensorflow/c/experimental/saved_model/core:__subpackages__",
        "//tensorflow/c/experimental/saved_model/internal:__subpackages__",
    ],
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "variable_ops",
    srcs = [
        "variable_ops.cc",
    ],
    hdrs = [
        "variable_ops.h",
    ],
    deps = [
        "//tensorflow/c/eager:abstract_operation",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_operation",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/types:span",
    ],
)

tf_cc_test(
    name = "variable_ops_test",
    srcs = [
        "variable_ops_test.cc",
    ],
    deps = [
        ":variable_ops",
        "//tensorflow/c:tensor_interface",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/core:all_kernels",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/common_runtime:core_cpu_lib",
        "//tensorflow/core/common_runtime/eager:context",
        "//tensorflow/core/common_runtime/eager:core",
    ],
)

tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc
@ -0,0 +1,118 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace internal {

static const char kNoSharingResourceID[] =
    "cd2c89b7-88b7-44c8-ad83-06c2a9158347";

Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
                                           DataType dtype, TensorShape shape,
                                           ImmediateTensorHandlePtr* handle) {
  ImmediateOpPtr varhandle_op(ctx->CreateOperation());

  TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr));
  TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype));

  // Note that if shape is unknown rank, shape.dim_sizes() will be empty, and
  // shape.dims() will be -1.
  gtl::InlinedVector<int64, 4> dim_sizes = shape.dim_sizes();
  TF_RETURN_IF_ERROR(varhandle_op->SetAttrShape(
      "shape", reinterpret_cast<const int64_t*>(dim_sizes.data()),
      shape.dims()));
  TF_RETURN_IF_ERROR(varhandle_op->SetAttrString("container", "", 0));
  TF_RETURN_IF_ERROR(varhandle_op->SetAttrString(
      "shared_name", kNoSharingResourceID, strlen(kNoSharingResourceID)));

  AbstractTensorHandle* var_handle = nullptr;
  int num_retvals = 1;
  TF_RETURN_IF_ERROR(varhandle_op->Execute(
      absl::MakeSpan(&var_handle, num_retvals), &num_retvals));
  AbstractTensorHandlePtr owned_var_handle(var_handle);
  if (!tensorflow::isa<ImmediateExecutionTensorHandle>(
          owned_var_handle.get())) {
    return errors::Internal("Unexpected tensor handle kind.");
  }
  handle->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(
      owned_var_handle.release()));
  return Status();
}

Status AssignVariable(ImmediateExecutionContext* ctx,
                      ImmediateExecutionTensorHandle* variable_handle,
                      DataType dtype, ImmediateExecutionTensorHandle* value) {
  ImmediateOpPtr assign_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr));
  TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype));
  TF_RETURN_IF_ERROR(assign_op->AddInput(variable_handle));
  TF_RETURN_IF_ERROR(assign_op->AddInput(value));

  int num_retvals = 0;
  TF_RETURN_IF_ERROR(assign_op->Execute({}, &num_retvals));
  return Status();
}

Status ReadVariable(ImmediateExecutionContext* ctx,
                    ImmediateExecutionTensorHandle* variable_handle,
                    DataType dtype, ImmediateTensorHandlePtr* output) {
  ImmediateOpPtr read_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr));
  TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype));
  TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle));

  AbstractTensorHandle* value = nullptr;
  int num_retvals = 1;
  TF_RETURN_IF_ERROR(
      read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals));
  AbstractTensorHandlePtr owned_value(value);
  if (!tensorflow::isa<ImmediateExecutionTensorHandle>(owned_value.get())) {
    return errors::Internal("Unexpected tensor handle kind.");
  }
  output->reset(
      reinterpret_cast<ImmediateExecutionTensorHandle*>(owned_value.release()));
  return Status();
}

Status DestroyResource(ImmediateExecutionContext* ctx,
                       ImmediateExecutionTensorHandle* handle) {
  ImmediateOpPtr destroy_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(destroy_op->Reset("DestroyResourceOp", nullptr));
  TF_RETURN_IF_ERROR(destroy_op->SetAttrBool("ignore_lookup_error", true));
  TF_RETURN_IF_ERROR(destroy_op->AddInput(handle));

  int num_retvals = 0;
  TF_RETURN_IF_ERROR(destroy_op->Execute({}, &num_retvals));
  return Status();
}

} // namespace internal
} // namespace tensorflow
@ -0,0 +1,61 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
namespace internal {

// Executes a VarHandleOp using `ctx`, and fills `handle` with the DT_RESOURCE
// TensorHandle associated with the variable. This is equivalent to creating an
// uninitialized TF2 tf.Variable.
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
                                           DataType dtype, TensorShape shape,
                                           ImmediateTensorHandlePtr* handle);

// Executes an AssignVariableOp using `ctx`, assigning the variable associated
// with `variable_handle` with `value`. `dtype` must be the datatype of the
// underlying variable for `variable_handle`. Note that it is illegal to assign
// a variable to a Tensor with a different dtype than what the variable was
// created with.
Status AssignVariable(ImmediateExecutionContext* ctx,
                      ImmediateExecutionTensorHandle* variable_handle,
                      DataType dtype, ImmediateExecutionTensorHandle* value);

// Executes a ReadVariableOp using `ctx`. This reads the underlying variable
// value of `variable_handle` and copies the value to `output`. `dtype` must be
// the dtype of the variable associated with `variable_handle`.
Status ReadVariable(ImmediateExecutionContext* ctx,
                    ImmediateExecutionTensorHandle* variable_handle,
                    DataType dtype, ImmediateTensorHandlePtr* output);

// Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to
// the cleanup that occurs in a tf.Variable's EagerResourceDeleter:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L289-L290
Status DestroyResource(ImmediateExecutionContext* ctx,
                       ImmediateExecutionTensorHandle* handle);

} // namespace internal
} // namespace tensorflow

#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
@ -0,0 +1,106 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"

#include <memory>

#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

ImmediateTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
                                                  float value) {
  AbstractTensorPtr tensor(context->CreateFloatScalar(value));
  ImmediateTensorHandlePtr handle(context->CreateLocalHandle(tensor.get()));
  return handle;
}

class VariableOpsTest : public ::testing::Test {
 public:
  VariableOpsTest()
      : device_mgr_(std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
            "CPU", {}, "/job:localhost/replica:0/task:0"))),
        ctx_(new EagerContext(
            SessionOptions(),
            tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
            tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
            /* async= */ false,
            /* lazy_copy_function_remote_inputs= */ false, device_mgr_.get(),
            /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
            /* custom_kernel_creator= */ nullptr,
            /* cluster_flr= */ nullptr)) {}

  EagerContext* context() { return ctx_.get(); }

 private:
  std::unique_ptr<StaticDeviceMgr> device_mgr_;
  EagerContextPtr ctx_;
};

// Sanity check for variable creation
TEST_F(VariableOpsTest, CreateVariableSuccessful) {
  // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
  ImmediateTensorHandlePtr handle;
  TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
      context(), DT_FLOAT, {}, &handle));
  // The created TensorHandle should be a DT_Resource
  EXPECT_EQ(handle->DataType(), DT_RESOURCE);
}

// Sanity check for variable destruction
TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
  // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
  ImmediateTensorHandlePtr handle;
  TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
      context(), DT_FLOAT, {}, &handle));

  // Destroy the variable
  TF_EXPECT_OK(internal::DestroyResource(context(), handle.get()));
}

// Sanity check for handle assignment and reading
TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) {
  // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
  ImmediateTensorHandlePtr variable;
  TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
      context(), DT_FLOAT, {}, &variable));

  // Create a scalar float TensorHandle with value 42, and assign it to
  // the variable.
  ImmediateTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0);
  TF_EXPECT_OK(internal::AssignVariable(context(), variable.get(), DT_FLOAT,
                                        my_value.get()));

  // Read back the value from the variable, and check that it is 42.
  ImmediateTensorHandlePtr read_value_handle;
  TF_EXPECT_OK(internal::ReadVariable(context(), variable.get(), DT_FLOAT,
                                      &read_value_handle));
  Status status;
  AbstractTensorPtr read_value(read_value_handle->Resolve(&status));
  TF_EXPECT_OK(status);
  EXPECT_FLOAT_EQ(42.0, *static_cast<float*>(read_value->Data()));
}

} // namespace
} // namespace tensorflow
@ -0,0 +1,60 @@
# This package contains classes corresponding to Revived SavedObjectGraph types
# used by SavedModel. See https://cs.opensource.google/tensorflow/tensorflow/+/c575e2ba93c442121d98d3f125d83fed1339924d:tensorflow/core/protobuf/saved_object_graph.proto;l=56-62
package(
    default_visibility = [
        # Restricting visibility for now
        "//tensorflow/c/experimental/saved_model/core:__pkg__",
    ],
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "constant",
    srcs = [
        "constant.cc",
    ],
    hdrs = [
        "constant.h",
    ],
    deps = [
        ":tensorhandle_convertible",
        "//tensorflow/c:tensor_interface",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/common_runtime/eager:tensor_handle",
    ],
)

cc_library(
    name = "variable",
    srcs = [
        "variable.cc",
    ],
    hdrs = [
        "variable.h",
    ],
    deps = [
        ":tensorhandle_convertible",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/c/experimental/saved_model/core/ops:variable_ops",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/common_runtime/eager:context",
        "@com_google_absl//absl/types:optional",
    ],
)

cc_library(
    name = "tensorhandle_convertible",
    hdrs = [
        "tensorhandle_convertible.h",
    ],
    deps = [
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
    ],
)
@ -0,0 +1,46 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"

#include <memory>

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

Constant::Constant(ImmediateTensorHandlePtr handle)
    : TensorHandleConvertible(std::move(handle)) {}

Status Constant::Create(ImmediateExecutionContext* ctx,
                        AbstractTensorInterface* tensor,
                        std::unique_ptr<Constant>* output) {
  ImmediateExecutionTensorHandle* handle = ctx->CreateLocalHandle(tensor);
  if (handle == nullptr) {
    return errors::Internal("Failed to convert tensor to tensorhandle");
  }
  output->reset(new Constant(ImmediateTensorHandlePtr(handle)));
  return Status();
}

} // namespace tensorflow
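
A hedged sketch (not part of the commit) of calling Constant::Create directly;
`ctx` is assumed to be a live context, and CreateFloatScalar is the same
context method the variable-ops test above uses:

// Build a Constant wrapping a scalar 42.0f tensor.
Status MakeScalarConstant(ImmediateExecutionContext* ctx,
                          std::unique_ptr<Constant>* out) {
  AbstractTensorPtr tensor(ctx->CreateFloatScalar(42.0f));
  if (tensor == nullptr) {
    return errors::Internal("Failed to create scalar tensor");
  }
  return Constant::Create(ctx, tensor.get(), out);
}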

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_

#include <memory>

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/tensor.pb.h"

namespace tensorflow {

// This class corresponds to python's tf.constant, which is effectively a
// TensorHandle explicitly initialized to some value.
// For now this doesn't do much beyond wrap Context's CreateLocalHandle method,
// and offer a subclass of TensorHandleConvertible. Note that, similar to
// Python's eager-mode logic, we bypass calling the "Const" op:
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/framework/constant_op.py#L301
class Constant : public TensorHandleConvertible {
 public:
  static Status Create(ImmediateExecutionContext* ctx,
                       AbstractTensorInterface* tensor,
                       std::unique_ptr<Constant>* output);

  // Constant is movable, but not copyable.
  Constant(Constant&& other) = default;
  Constant& operator=(Constant&& other) = default;

  ~Constant() override = default;

 private:
  explicit Constant(ImmediateTensorHandlePtr handle);
  Constant(const Constant&) = delete;
  Constant& operator=(const Constant&) = delete;
};

} // namespace tensorflow

#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_

#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"

namespace tensorflow {

// A common interface for objects that can be converted to a TensorHandle.
// Examples of objects that implement this include Variables, Constants, Assets,
// etc. This is used to convert captured objects into a ConcreteFunction's
// captured TensorHandles:
// https://github.com/tensorflow/tensorflow/blob/676a68963ea4b64fe479b9cede06aa8f5b290ab8/tensorflow/python/saved_model/load.py#L229-L240
class TensorHandleConvertible {
 public:
  explicit TensorHandleConvertible(ImmediateTensorHandlePtr handle)
      : handle_(std::move(handle)) {}

  ImmediateExecutionTensorHandle* handle() { return handle_.get(); }

  // TensorHandleConvertible is movable, but not copyable.
  TensorHandleConvertible(TensorHandleConvertible&& other) = default;
  TensorHandleConvertible& operator=(TensorHandleConvertible&& other) = default;

  virtual ~TensorHandleConvertible() = default;

 protected:
  TensorHandleConvertible(const TensorHandleConvertible&) = delete;
  TensorHandleConvertible& operator=(const TensorHandleConvertible&) = delete;
  ImmediateTensorHandlePtr handle_;
};

} // namespace tensorflow

#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"

#include <memory>

#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

Variable::Variable(ImmediateExecutionContext* ctx, DataType dtype,
                   TensorShape shape, absl::optional<std::string> name,
                   ImmediateTensorHandlePtr handle)
    : TensorHandleConvertible(std::move(handle)),
      name_(name.has_value() ? *name : "Variable"),
      dtype_(dtype),
      shape_(shape),
      ctx_(ctx) {}

Variable::~Variable() {
  // If the handle is null (perhaps because the variable was std::moved from),
  // then we don't have to do anything.
  if (handle_ == nullptr) {
    return;
  }

  Status status = internal::DestroyResource(ctx_, handle_.get());
  if (!status.ok()) {
    LOG(ERROR) << "Error destroying variable: " << name_
               << " due to: " << status;
  }
}

DataType Variable::dtype() { return dtype_; }

TensorShape Variable::shape() { return shape_; }

Status Variable::Assign(ImmediateExecutionTensorHandle* handle) {
  return internal::AssignVariable(ctx_, handle_.get(), dtype_, handle);
}

Status Variable::ReadValue(ImmediateTensorHandlePtr* out) {
  return internal::ReadVariable(ctx_, handle_.get(), dtype_, out);
}

Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx,
                                     DataType dtype, TensorShape shape,
                                     absl::optional<std::string> name,
                                     std::unique_ptr<Variable>* output) {
  ImmediateTensorHandlePtr handle;
  TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
      ctx, dtype, shape, &handle));

  output->reset(
      new Variable(ctx, dtype, shape, std::move(name), std::move(handle)));
  return Status();
}

} // namespace tensorflow
|
@ -0,0 +1,76 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_

#include <memory>

#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"

namespace tensorflow {

class Variable : public TensorHandleConvertible {
 public:
  // Creates an uninitialized resource variable. Note that a caller must
  // call "assign" to associate a value with the variable.
  static Status CreateUninitialized(ImmediateExecutionContext* ctx,
                                    DataType dtype, TensorShape shape,
                                    absl::optional<std::string> name,
                                    std::unique_ptr<Variable>* output);

  // The dtype of the underlying variable.
  DataType dtype();

  // The shape of the underlying variable.
  TensorShape shape();

  // Updates the variable's contents with `handle`.
  Status Assign(ImmediateExecutionTensorHandle* handle);

  // Reads the value of the variable and stores it in `out`.
  Status ReadValue(ImmediateTensorHandlePtr* out);

  // Variable is movable, but not copyable.
  Variable(Variable&& other) = default;
  Variable& operator=(Variable&& other) = default;

  ~Variable() override;

 private:
  Variable(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape,
           absl::optional<std::string> name, ImmediateTensorHandlePtr handle);
  Variable(const Variable& variable) = delete;
  Variable& operator=(const Variable&) = delete;

  std::string name_;
  DataType dtype_;
  TensorShape shape_;

  // ctx_ must outlive Variable.
  ImmediateExecutionContext* ctx_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_VARIABLE_H_
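Taken together, Variable follows a create / assign / read lifecycle, with resource cleanup in the destructor. A minimal usage sketch, assuming `ctx` is a valid ImmediateExecutionContext* and `value` is an ImmediateExecutionTensorHandle* of matching dtype and shape (both are stand-ins here):

// Create an uninitialized scalar float variable.
std::unique_ptr<tensorflow::Variable> var;
tensorflow::Status s = tensorflow::Variable::CreateUninitialized(
    ctx, tensorflow::DT_FLOAT, tensorflow::TensorShape({}), absl::nullopt,
    &var);
if (!s.ok()) { /* handle error */ }

// A value must be assigned before the first read.
s = var->Assign(value);
if (!s.ok()) { /* handle error */ }

// Read the value back out as a new tensor handle.
tensorflow::ImmediateTensorHandlePtr out;
s = var->ReadValue(&out);

// When `var` is destroyed, the underlying resource is destroyed via
// internal::DestroyResource (see variable.cc above).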
@ -0,0 +1,58 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"

#include <memory>

#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"

namespace tensorflow {
namespace internal {

Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
                             const TensorProto& proto,
                             std::unique_ptr<Constant>* output) {
  tensorflow::Tensor tensor;
  bool parse_result = tensor.FromProto(proto);
  if (!parse_result) {
    return errors::Internal("Failed to parse tensor from tensorproto");
  }

  TensorInterface tensor_interface(std::move(tensor));
  return Constant::Create(ctx, &tensor_interface, output);
}

// This follows the python variable restoration logic:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407
Status LoadSavedVariable(ImmediateExecutionContext* ctx,
                         const SavedVariable& variable,
                         std::unique_ptr<Variable>* output) {
  const std::string& name = variable.name();
  tensorflow::TensorShape shape(variable.shape());
  tensorflow::DataType dtype = variable.dtype();

  TF_RETURN_IF_ERROR(
      Variable::CreateUninitialized(ctx, dtype, shape, name, output));

  return Status();
}

}  // namespace internal
}  // namespace tensorflow
@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_

// Some internal utility functions for the SavedModelAPI, factored out into a
// separately unit-testable header.

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"

namespace tensorflow {
namespace internal {

// Loads a TensorProto into a tensorflow::Constant. This is similar to the
// constant loading logic in python:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L437
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
                             const TensorProto& proto,
                             std::unique_ptr<Constant>* output);

// Creates a tensorflow::Variable from a SavedVariable. This is similar to the
// logic in:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407
// Note that the caller **must assign a value** to the loaded variable.
Status LoadSavedVariable(ImmediateExecutionContext* ctx,
                         const SavedVariable& variable,
                         std::unique_ptr<Variable>* output);

}  // namespace internal
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
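As a usage sketch of TensorProtoToConstant: build a small TensorProto, then revive it into a Constant whose handle can be captured by a ConcreteFunction. This assumes a valid ImmediateExecutionContext* ctx (a stand-in here):

tensorflow::TensorProto proto;
proto.set_dtype(tensorflow::DT_FLOAT);
proto.mutable_tensor_shape();  // default-initialized shape: a scalar
proto.add_float_val(1.0f);

std::unique_ptr<tensorflow::Constant> constant;
tensorflow::Status s =
    tensorflow::internal::TensorProtoToConstant(ctx, proto, &constant);
if (s.ok()) {
  // The revived constant exposes its value as a tensor handle.
  tensorflow::ImmediateExecutionTensorHandle* handle = constant->handle();
  (void)handle;
}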
@ -0,0 +1,122 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <vector>

#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"

namespace tensorflow {
namespace {

class SavedVariableLoadingTest
    : public ::testing::TestWithParam<
          std::tuple<DataType, std::vector<int64>>> {
 public:
  SavedVariableLoadingTest()
      : device_mgr_(testing::CreateTestingDeviceMgr()),
        ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}

  EagerContext* context() { return ctx_.get(); }

 private:
  std::unique_ptr<StaticDeviceMgr> device_mgr_;
  EagerContextPtr ctx_;
};

// Sanity check that constructing a tensorflow::Variable from a SavedVariable
// 1. does not cause an error
// 2. preserves dtype and shape.
TEST_P(SavedVariableLoadingTest, LoadSavedVariableSuccessful) {
  auto& test_params = GetParam();
  DataType dtype = std::get<0>(test_params);
  TensorShape shape(std::get<1>(test_params));

  SavedVariable saved_variable;
  saved_variable.set_dtype(dtype);
  shape.AsProto(saved_variable.mutable_shape());

  std::unique_ptr<Variable> var;
  TF_EXPECT_OK(internal::LoadSavedVariable(context(), saved_variable, &var));
  EXPECT_EQ(var->dtype(), dtype);
  EXPECT_EQ(var->shape(), shape);
}

// Assigning and reading values should yield consistent results.
TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) {
  auto& test_params = GetParam();
  DataType dtype = std::get<0>(test_params);
  std::vector<int64> shape_vector = std::get<1>(test_params);
  TensorShape shape(shape_vector);

  // Create the variable.
  Status status;
  std::unique_ptr<Variable> var;
  TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape,
                                             absl::nullopt, &var));

  // Create a TensorHandle.
  ImmediateTensorHandlePtr expected_handle =
      testing::CreateTensorHandle(context(), dtype, shape_vector, 42);
  AbstractTensorPtr expected_tensor(expected_handle->Resolve(&status));
  TF_EXPECT_OK(status) << status.error_message();

  // Assign the TensorHandle to the variable.
  TF_EXPECT_OK(var->Assign(expected_handle.get()));

  // Read back the value from the variable.
  ImmediateTensorHandlePtr output_handle;
  TF_EXPECT_OK(var->ReadValue(&output_handle));
  AbstractTensorPtr output_tensor(output_handle->Resolve(&status));
  TF_EXPECT_OK(status) << status.error_message();

  // Check that output_tensor == expected_tensor.
  EXPECT_EQ(output_tensor->Type(), expected_tensor->Type());
  EXPECT_EQ(output_tensor->NumElements(), expected_tensor->NumElements());
  testing::CheckBufferDataIsEqual(
      output_tensor->Type(), output_tensor->NumElements(),
      output_tensor->Data(), expected_tensor->Data());
}

// Test against combinations of SavedVariables of
// 1. Varying dtypes
// 2. Varying shapes
INSTANTIATE_TEST_SUITE_P(
    SavedVariableIntegerDtypesTest, SavedVariableLoadingTest,
    ::testing::Combine(
        ::testing::ValuesIn(testing::DataTypeSetToVector(kDataTypeIsInteger)),
        ::testing::ValuesIn(testing::InterestingShapes())));

INSTANTIATE_TEST_SUITE_P(
    SavedVariableFloatingDtypesTest, SavedVariableLoadingTest,
    ::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
                       ::testing::ValuesIn(testing::InterestingShapes())));

}  // namespace
}  // namespace tensorflow
143
tensorflow/c/experimental/saved_model/core/test_utils.cc
Normal file
@ -0,0 +1,143 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/core/test_utils.h"

#include <memory>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace testing {

std::unique_ptr<StaticDeviceMgr> CreateTestingDeviceMgr() {
  return std::make_unique<StaticDeviceMgr>(
      DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:0"));
}

EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr) {
  return EagerContextPtr(new EagerContext(
      SessionOptions(),
      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
      tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
      /* async= */ false,
      /* lazy_copy_function_remote_inputs= */ false, device_mgr,
      /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
      /* custom_kernel_creator= */ nullptr,
      /* cluster_flr= */ nullptr));
}

std::vector<DataType> DataTypeSetToVector(DataTypeSet set) {
  std::vector<DataType> result;
  result.reserve(set.size());
  for (DataType dt : set) {
    result.push_back(dt);
  }
  return result;
}

std::vector<std::vector<int64>> InterestingShapes() {
  std::vector<std::vector<int64>> interesting_shapes;
  interesting_shapes.push_back({});             // Scalar
  interesting_shapes.push_back({10});           // 1D Vector
  interesting_shapes.push_back({3, 3});         // 2D Matrix
  interesting_shapes.push_back({1, 4, 6, 10});  // Higher Dimension Tensor
  return interesting_shapes;
}

ImmediateTensorHandlePtr CreateTensorHandle(ImmediateExecutionContext* ctx,
                                            DataType dtype,
                                            absl::Span<const int64> shape,
                                            int8 value) {
  AbstractTensorPtr tensor(ctx->CreateTensor(dtype, shape));
  CHECK_NE(tensor.get(), nullptr)
      << "Tensor creation failed for tensor of dtype: "
      << DataTypeString(dtype);
  CHECK_EQ(tensor->Type(), dtype);
  for (int i = 0; i < shape.size(); ++i) {
    CHECK_EQ(tensor->Dim(i), shape[i]);
  }
  FillNumericTensorBuffer(tensor->Type(), tensor->NumElements(), tensor->Data(),
                          value);
  ImmediateTensorHandlePtr handle(ctx->CreateLocalHandle(tensor.get()));
  CHECK_NE(handle.get(), nullptr);
  return handle;
}

void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer,
                             int8 value) {
  switch (dtype) {
#define CASE(type)                                   \
  case DataTypeToEnum<type>::value: {                \
    type* typed_buffer = static_cast<type*>(buffer); \
    for (size_t i = 0; i < num_elements; ++i) {      \
      typed_buffer[i] = value;                       \
    }                                                \
    break;                                           \
  }
    TF_CALL_INTEGRAL_TYPES(CASE);
    TF_CALL_double(CASE);
    TF_CALL_float(CASE);
#undef CASE
    default:
      CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
      break;
  }
}

// Checks that the underlying data is equal for the buffers of two numeric
// tensors. Note: the caller must ensure that the dtypes and sizes of the
// underlying buffers are the same before calling this.
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
                            void* b) {
  switch (dtype) {
#define CASE(type)                               \
  case DataTypeToEnum<type>::value: {            \
    type* typed_a = static_cast<type*>(a);       \
    type* typed_b = static_cast<type*>(b);       \
    for (int64 i = 0; i < num_elements; ++i) {   \
      if (DataTypeIsFloating(dtype)) {           \
        EXPECT_FLOAT_EQ(typed_a[i], typed_b[i]); \
      } else {                                   \
        EXPECT_EQ(typed_a[i], typed_b[i]);       \
      }                                          \
    }                                            \
    break;                                       \
  }
    TF_CALL_INTEGRAL_TYPES(CASE);
    TF_CALL_double(CASE);
    TF_CALL_float(CASE);
#undef CASE
    default:
      CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
  }
}

}  // namespace testing
}  // namespace tensorflow
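FillNumericTensorBuffer and CheckBufferDataIsEqual share a dispatch idiom: a CASE macro is instantiated once per supported dtype via the TF_CALL_* macros, with DataTypeToEnum mapping each C++ type to its enum value. A self-contained sketch of the same idiom outside TensorFlow (the Dtype enum and TypeToDtype traits are hypothetical stand-ins):

#include <cstddef>
#include <cstdint>

enum class Dtype { kInt32, kFloat };

// Maps a C++ type to its Dtype tag, analogous to DataTypeToEnum above.
template <typename T>
struct TypeToDtype;
template <>
struct TypeToDtype<int32_t> { static constexpr Dtype value = Dtype::kInt32; };
template <>
struct TypeToDtype<float> { static constexpr Dtype value = Dtype::kFloat; };

// Fills `buffer`, interpreted according to `dtype`, with `value`.
void Fill(Dtype dtype, size_t n, void* buffer, int8_t value) {
  switch (dtype) {
#define CASE(type)                                   \
  case TypeToDtype<type>::value: {                   \
    type* typed = static_cast<type*>(buffer);        \
    for (size_t i = 0; i < n; ++i) typed[i] = value; \
    break;                                           \
  }
    CASE(int32_t)
    CASE(float)
#undef CASE
  }
}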
75
tensorflow/c/experimental/saved_model/core/test_utils.h
Normal file
@ -0,0 +1,75 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_

#include <memory>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace testing {

// Creates a DeviceMgr suitable for local tests.
std::unique_ptr<StaticDeviceMgr> CreateTestingDeviceMgr();

// Creates an EagerContext suitable for local tests. Does not take ownership
// of `device_mgr`.
EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr);

// Converts a tensorflow::DataTypeSet to std::vector<DataType>.
// This is useful for tests using GTest's ::testing::ValuesIn, since
// DataTypeSet doesn't fulfill all the constraints of an STL-like iterable.
std::vector<DataType> DataTypeSetToVector(DataTypeSet set);

// Returns a vector of shapes intended to be "interesting" test cases.
// Currently, this returns scalar, 1D vector, 2D matrix, and 4D tensor shapes.
std::vector<std::vector<int64>> InterestingShapes();

// Returns a TensorHandle of `dtype` and `shape`, filled with `value`.
// `dtype` must be an integer dtype, float, or double.
// If a TensorHandle cannot be created successfully, this function will
// CHECK fail. This should only be used for testing purposes.
ImmediateTensorHandlePtr CreateTensorHandle(ImmediateExecutionContext* ctx,
                                            DataType dtype,
                                            absl::Span<const int64> shape,
                                            int8 value);

// Fills a numeric tensor's buffer with `value`.
// `dtype` must be an integer dtype, float, or double.
void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer,
                             int8 value);

// Checks that the underlying data is equal for the buffers of two numeric
// tensors. Note: the caller must ensure that the dtypes and sizes of the
// underlying buffers are the same before calling this.
// `dtype` must be an integer dtype, float, or double.
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
                            void* b);

}  // namespace testing
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_
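A short sketch of how these helpers compose inside a test body (values arbitrary; assumes the gtest and TensorFlow test headers above are available):

// Stand up a local device manager and eager context.
std::unique_ptr<tensorflow::StaticDeviceMgr> device_mgr =
    tensorflow::testing::CreateTestingDeviceMgr();
tensorflow::EagerContextPtr ctx =
    tensorflow::testing::CreateTestingEagerContext(device_mgr.get());

// Build a 2x2 int32 handle filled with 7, then resolve it to a tensor.
tensorflow::ImmediateTensorHandlePtr handle =
    tensorflow::testing::CreateTensorHandle(ctx.get(), tensorflow::DT_INT32,
                                            /*shape=*/{2, 2}, /*value=*/7);
tensorflow::Status status;
tensorflow::AbstractTensorPtr tensor(handle->Resolve(&status));
TF_EXPECT_OK(status);
EXPECT_EQ(tensor->NumElements(), 4);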
@ -21,6 +21,7 @@ limitations under the License.

#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
@ -51,7 +52,7 @@ std::vector<ConcreteFunction*> TFSavedModelAPIImpl::ListFunctions() {
Status TFSavedModelAPIImpl::Load(
    const std::string& directory,
    const absl::optional<std::unordered_set<std::string>>& tags,
    TFSavedModelAPIImpl* out) {
    EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out) {
  // TODO(bmzhao): Add support for loading a TFSavedModelImpl.
  return errors::Unimplemented(
      "TFSavedModelAPIImpl loading is unimplemented currently");
@ -23,14 +23,13 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

class TFSavedModelAPIImpl : public SavedModelAPI {
 public:
  TFSavedModelAPIImpl() = default;

  Status GetFunction(const std::string& function_path,
                     ConcreteFunction** function) override;

@ -40,13 +39,14 @@ class TFSavedModelAPIImpl : public SavedModelAPI {
  static Status Load(
      const std::string& directory,
      const absl::optional<std::unordered_set<std::string>>& tags,
      TFSavedModelAPIImpl* out);
      EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out);

  std::vector<ConcreteFunction*> ListFunctions() override;

  ~TFSavedModelAPIImpl() override = default;

 private:
  TFSavedModelAPIImpl() = default;
  std::vector<ConcreteFunction> functions_;
};

@ -144,7 +144,9 @@ cc_library(
        "//tensorflow/c:tf_status_internal",
        "//tensorflow/c/eager:tfe_context_internal",
        "//tensorflow/c/experimental/saved_model/core:saved_model_api",
        "//tensorflow/c/experimental/saved_model/core:tf_saved_model_impl",
        "//tensorflow/core:lib",
        "//tensorflow/core/common_runtime/eager:context",
        "@com_google_absl//absl/types:optional",
    ],
)
@ -176,7 +178,7 @@ cc_library(
        ":tensorhandle_list_type",
        "//tensorflow/c:c_api_macros",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:tensor_handle_interface",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/c/eager:tfe_tensorhandle_internal",
    ],
)
@ -188,7 +190,7 @@ cc_library(
    ],
    deps = [
        "//tensorflow/c:conversion_macros",
        "//tensorflow/c/eager:tensor_handle_interface",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
    ],
)

@ -22,11 +22,15 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

extern "C" {
@ -34,10 +38,21 @@ extern "C" {
TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
                                 TF_Status* status) {
  std::string saved_model_dir(dirname);
  std::unique_ptr<tensorflow::SavedModelAPI> result;

  if (tensorflow::unwrap(ctx)->UsesTFRT()) {
    status->status = tensorflow::errors::Unimplemented(
        "TFRT SavedModel implementation will be added in the future");
  } else {
    std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
    status->status = tensorflow::TFSavedModelAPIImpl::Load(
        dirname, absl::nullopt,
        tensorflow::down_cast<tensorflow::EagerContext*>(
            tensorflow::unwrap(ctx)),
        &saved_model);
    result = std::move(saved_model);
  }

  std::unique_ptr<tensorflow::SavedModelAPI> result =
      tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
                                                 &status->status);
  if (!status->status.ok()) {
    return nullptr;
  }
@ -54,9 +69,20 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
    tagset.insert(std::string(tags[i]));
  }

  std::unique_ptr<tensorflow::SavedModelAPI> result =
      tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset),
                                                 &status->status);
  std::unique_ptr<tensorflow::SavedModelAPI> result;
  if (tensorflow::unwrap(ctx)->UsesTFRT()) {
    status->status = tensorflow::errors::Unimplemented(
        "TFRT SavedModel implementation will be added in the future");
  } else {
    std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
    status->status = tensorflow::TFSavedModelAPIImpl::Load(
        dirname, tagset,
        tensorflow::down_cast<tensorflow::EagerContext*>(
            tensorflow::unwrap(ctx)),
        &saved_model);
    result = std::move(saved_model);
  }

  if (!status->status.ok()) {
    return nullptr;
  }
@ -17,7 +17,7 @@ limitations under the License.

#include <stddef.h>

#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"

@ -19,7 +19,7 @@ limitations under the License.
#include <vector>

#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"

// Internal structures used by the SavedModel C API. These are likely to
// change and should not be depended on.
@ -29,7 +29,7 @@ typedef struct TF_TensorHandleList TF_TensorHandleList;
namespace tensorflow {

DEFINE_CONVERSION_FUNCTIONS(
    std::vector<tensorflow::AbstractTensorHandleInterface*>,
    std::vector<tensorflow::ImmediateExecutionTensorHandle*>,
    TF_TensorHandleList)

}  // namespace tensorflow

@ -54,6 +54,20 @@ class AbstractTensorInterface {
  virtual ~AbstractTensorInterface() {}
};

namespace internal {
struct AbstractTensorInterfaceDeleter {
  void operator()(AbstractTensorInterface* p) const {
    if (p != nullptr) {
      p->Release();
    }
  }
};
}  // namespace internal

using AbstractTensorPtr =
    std::unique_ptr<AbstractTensorInterface,
                    internal::AbstractTensorInterfaceDeleter>;

}  // namespace tensorflow

#endif  // TENSORFLOW_C_TENSOR_INTERFACE_H_
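AbstractTensorPtr is the usual unique_ptr-with-custom-deleter idiom for types whose lifetime ends with a Release() call rather than delete. The same idiom in isolation (Widget is a hypothetical stand-in):

#include <memory>

// Hypothetical type that must be released, not deleted, by its users.
struct Widget {
  void Release() { delete this; }  // real types might decrement a refcount
};

struct WidgetDeleter {
  void operator()(Widget* p) const {
    if (p != nullptr) {
      p->Release();  // never plain `delete p`
    }
  }
};

using WidgetPtr = std::unique_ptr<Widget, WidgetDeleter>;

int main() {
  WidgetPtr w(new Widget);
  // Release() runs automatically when `w` goes out of scope.
  return 0;
}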
@ -228,57 +228,6 @@ Status TensorInterface::BitcastFrom(const TensorInterface& from, DataType type,
}  // namespace tensorflow

// --------------------------------------------------------------------------
void StringEncode(const char* src, size_t src_len, char* dst) {
  dst = tensorflow::core::EncodeVarint64(dst, src_len);
  memcpy(dst, src, src_len);
}

size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
                       size_t dst_len, TF_Status* status) {
  const size_t sz = TF_StringEncodedSize(src_len);
  if (sz < src_len) {
    Set_TF_Status_from_Status(
        status, InvalidArgument("src string is too large to encode"));
    return 0;
  }
  if (dst_len < sz) {
    Set_TF_Status_from_Status(
        status,
        InvalidArgument("dst_len (", dst_len, ") too small to encode a ",
                        src_len, "-byte string"));
    return 0;
  }
  StringEncode(src, src_len, dst);
  return sz;
}

static Status TF_StringDecode_Impl(const char* src, size_t src_len,
                                   const char** dst, size_t* dst_len) {
  tensorflow::uint64 len64 = 0;
  const char* p = tensorflow::core::GetVarint64Ptr(src, src + src_len, &len64);
  if (p == nullptr) {
    return InvalidArgument("invalid string encoding or truncated src buffer");
  }
  if (len64 > std::numeric_limits<size_t>::max()) {
    return InvalidArgument("encoded string is ", len64,
                           "-bytes, which is too large for this architecture");
  }
  *dst = p;
  *dst_len = static_cast<size_t>(len64);
  return Status::OK();
}

size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
                       size_t* dst_len, TF_Status* status) {
  Set_TF_Status_from_Status(status,
                            TF_StringDecode_Impl(src, src_len, dst, dst_len));
  if (TF_GetCode(status) != TF_OK) return 0;
  return static_cast<size_t>(*dst - src) + *dst_len;
}

size_t TF_StringEncodedSize(size_t len) {
  return static_cast<size_t>(tensorflow::core::VarintLength(len)) + len;
}

static void DeleteArray(void* data, size_t size, void* arg) {
  DCHECK_EQ(data, arg);
@ -334,58 +283,12 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
    std::memcpy(TF_TensorData(t), str.c_str(), str.size());
    return t;
  }
  if (src.dtype() != tensorflow::DT_STRING) {
    Tensor tensor;
    if (!tensor.CopyFrom(src, src.shape())) {
      return nullptr;
    }
    return new TF_Tensor{new tensorflow::TensorInterface(tensor)};
  }
  // DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
  // encoded sequence of strings.

  // Compute bytes needed for encoding.
  size_t size = 0;
  const auto& srcarray = src.flat<tstring>();
  for (int i = 0; i < srcarray.size(); ++i) {
    const string& s = srcarray(i);
    // uint64 starting_offset, TF_StringEncode-d string.
    size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size());
  }

  // Encode all strings.
  char* base = new char[size];
  char* data_start = base + sizeof(tensorflow::uint64) * srcarray.size();
  char* dst = data_start;  // Where next string is encoded.
  size_t dst_len = size - static_cast<size_t>(data_start - base);
  tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
  for (int i = 0; i < srcarray.size(); ++i) {
    *offsets = (dst - data_start);
    offsets++;
    const string& s = srcarray(i);
    const size_t consumed = TF_StringEncodedSize(s.size());
    StringEncode(s.data(), s.size(), dst);
    dst += consumed;
    dst_len -= consumed;
  }
  if (dst != base + size) {
    *status = InvalidArgument(
        "invalid string tensor encoding (decoded ", (dst - base),
        " bytes, but the tensor is encoded in ", size, " bytes");
    delete[] base;
  Tensor tensor;
  if (!tensor.CopyFrom(src, src.shape())) {
    return nullptr;
  }

  auto dims = src.shape().dim_sizes();
  std::vector<tensorflow::int64> dimvec(dims.size());
  for (size_t i = 0; i < dims.size(); ++i) {
    dimvec[i] = dims[i];
  }
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  return TF_NewTensor(TF_STRING,
                      reinterpret_cast<const int64_t*>(dimvec.data()),
                      dimvec.size(), base, size, DeleteArray, base);
  return new TF_Tensor{new tensorflow::TensorInterface(tensor)};
}

Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
@ -409,39 +312,7 @@ Status TensorInterface::ToTensor(tensorflow::Tensor* dst) const {
    }
    return Status::OK();
  }
  if (tensor_.dtype() != DT_STRING) {
    *dst = tensor_;
    return Status::OK();
  }
  // TF_STRING tensors require copying since Tensor class expects a sequence of
  // string objects.
  const tensorflow::int64 num_elements = tensor_.NumElements();
  const char* input = reinterpret_cast<const char*>(Data());
  const size_t src_size = ByteSize();
  if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
      num_elements) {
    return InvalidArgument(
        "Malformed TF_STRING tensor; too short to hold number of elements");
  }
  const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
  const char* limit = input + src_size;

  *dst = tensorflow::Tensor(tensor_.dtype(), tensor_.shape());
  auto dstarray = dst->flat<tstring>();
  for (tensorflow::int64 i = 0; i < num_elements; ++i) {
    tensorflow::uint64 offset =
        reinterpret_cast<const tensorflow::uint64*>(input)[i];
    if (static_cast<ptrdiff_t>(offset) >= (limit - data_start)) {
      return InvalidArgument("Malformed TF_STRING tensor; element ", i,
                             " out of range");
    }
    size_t len;
    const char* p;
    const char* srcp = data_start + offset;
    Status status = TF_StringDecode_Impl(srcp, limit - srcp, &p, &len);
    if (!status.ok()) return status;
    dstarray(i).assign(p, len);
  }
  *dst = tensor_;
  return Status::OK();
}

@ -148,34 +148,6 @@ TF_CAPI_EXPORT extern void TF_TensorBitcastFrom(const TF_Tensor* from,
                                                int num_new_dims,
                                                TF_Status* status);

// --------------------------------------------------------------------------
// Encode the string `src` (`src_len` bytes long) into `dst` in the format
// required by TF_STRING tensors. Does not write to memory more than `dst_len`
// bytes beyond `*dst`. `dst_len` should be at least
// TF_StringEncodedSize(src_len).
//
// On success returns the size in bytes of the encoded string.
// Returns an error into `status` otherwise.
TF_CAPI_EXPORT extern size_t TF_StringEncode(const char* src, size_t src_len,
                                             char* dst, size_t dst_len,
                                             TF_Status* status);

// Decode a string encoded using TF_StringEncode.
//
// On success, sets `*dst` to the start of the decoded string and `*dst_len` to
// its length. Returns the number of bytes starting at `src` consumed while
// decoding. `*dst` points to memory within the encoded buffer. On failure,
// `*dst` and `*dst_len` are undefined and an error is set in `status`.
//
// Does not read memory more than `src_len` bytes beyond `src`.
TF_CAPI_EXPORT extern size_t TF_StringDecode(const char* src, size_t src_len,
                                             const char** dst, size_t* dst_len,
                                             TF_Status* status);

// Return the size in bytes required to encode a string `len` bytes long into a
// TF_STRING tensor.
TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len);

// Returns bool iff this tensor is aligned.
TF_CAPI_EXPORT extern bool TF_TensorIsAligned(const TF_Tensor*);

20
tensorflow/c/tf_tstring.h
Normal file
@ -0,0 +1,20 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_TF_TSTRING_H_
#define TENSORFLOW_C_TF_TSTRING_H_

#include "tensorflow/core/platform/ctstring.h"

#endif  // TENSORFLOW_C_TF_TSTRING_H_
@ -259,9 +259,6 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) {
  RunTest(x, x_init_value, y, y_shape);
}

// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, MaxPool3DGradHelper) {
  TensorShape x_shape({1, 3, 3, 3, 1});
  TensorShape y_shape({1, 1, 1, 1, 1});
@ -274,7 +271,6 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) {
  SetRandomValuesForMaxPooling<float>(&x_init_value);
  RunTest(x, x_init_value, y, y_shape);
}
#endif

TEST_F(NNGradTest, AvgPoolGradHelper) {
  TensorShape x_shape({1, 2, 2, 1});
@ -287,9 +283,6 @@ TEST_F(NNGradTest, AvgPoolGradHelper) {
  RunTest(x, x_shape, y, y_shape);
}

// TODO(rocm):
// Re-enable this test once 3D pooling is supported on ROCm platform
#ifndef TENSORFLOW_USE_ROCM
TEST_F(NNGradTest, AvgPool3DGradHelper) {
  TensorShape x_shape({1, 3, 3, 3, 1});
  TensorShape y_shape({1, 1, 1, 1, 1});
@ -300,7 +293,6 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
  auto y = AvgPool3D(scope_, x, ksize, strides, "SAME");
  RunTest(x, x_shape, y, y_shape);
}
#endif

TEST_F(NNGradTest, LRN) {
  TensorShape x_shape({1, 1, 2, 1});

@ -106,6 +106,7 @@ cc_library(
    hdrs = ["loader.h"],
    deps = [
        ":constants",
        ":loader_util",
        ":reader",
    ] + if_not_mobile([
        "//tensorflow/core:core_cpu",
@ -132,6 +133,17 @@ cc_library(
    ],
)

cc_library(
    name = "loader_util",
    srcs = ["loader_util.cc"],
    hdrs = ["loader_util.h"],
    deps = [":constants"] + if_not_mobile([
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
    ]),
)

tf_cc_test(
    name = "bundle_v2_test",
    srcs = ["bundle_v2_test.cc"],

@ -10,6 +10,9 @@ tf_cc_test(
    srcs = [
        "saved_model_api_test.cc",
    ],
    data = [
        "//tensorflow/cc/saved_model:saved_model_half_plus_two",
    ],
    deps = [
        "//tensorflow/cc/experimental/base/public:runtime",
        "//tensorflow/cc/experimental/base/public:runtime_builder",

@ -90,7 +90,7 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
  // That unblocks writing other tests that require a TF_SavedModel*,
  // like loading a ConcreteFunction. This test at least checks that the
  // C API builds and can be minimally run.
  EXPECT_EQ(status.code(), TF_UNIMPLEMENTED);
  EXPECT_EQ(status.code(), TF_UNIMPLEMENTED) << status.message();
}

INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,

@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_set>

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
@ -191,41 +191,6 @@ Status RunInitOp(const RunOptions& run_options, const string& export_dir,
  return Status::OK();
}

// A SavedModel may store the name of the initialization op to run in the
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
                 string* init_op_name) {
  const auto& sig_def_map = meta_graph_def.signature_def();
  const auto& init_op_sig_it =
      meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
  if (init_op_sig_it != sig_def_map.end()) {
    *init_op_name = init_op_sig_it->second.outputs()
                        .find(kSavedModelInitOpSignatureKey)
                        ->second.name();
    return Status::OK();
  }

  const auto& collection_def_map = meta_graph_def.collection_def();
  string init_op_collection_key;
  if (collection_def_map.find(kSavedModelMainOpKey) !=
      collection_def_map.end()) {
    init_op_collection_key = kSavedModelMainOpKey;
  } else {
    init_op_collection_key = kSavedModelLegacyInitOpKey;
  }

  const auto init_op_it = collection_def_map.find(init_op_collection_key);
  if (init_op_it != collection_def_map.end()) {
    if (init_op_it->second.node_list().value_size() != 1) {
      return errors::FailedPrecondition(
          strings::StrCat("Expected exactly one main op in : ", export_dir));
    }
    *init_op_name = init_op_it->second.node_list().value(0);
  }
  return Status::OK();
}

Status RunRestore(const RunOptions& run_options, const string& export_dir,
                  const StringPiece restore_op_name,
                  const StringPiece variable_filename_const_op_name,
@ -263,32 +228,6 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
                  nullptr /* outputs */, &run_metadata, session);
}

Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
                        std::vector<AssetFileDef>* asset_file_defs) {
  // With SavedModel v2, we write asset file def into metagraph instead of
  // collection, so read from metagraph first.
  if (meta_graph_def.asset_file_def_size() > 0) {
    for (const auto& asset : meta_graph_def.asset_file_def()) {
      asset_file_defs->push_back(asset);
    }
    return Status::OK();
  }
  // Fall back to read from collection to be backward compatible with v1.
  const auto& collection_def_map = meta_graph_def.collection_def();
  const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
  if (assets_it == collection_def_map.end()) {
    return Status::OK();
  }
  const auto& any_assets = assets_it->second.any_list().value();
  for (const auto& any_asset : any_assets) {
    AssetFileDef asset_file_def;
    TF_RETURN_IF_ERROR(
        ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
    asset_file_defs->push_back(asset_file_def);
  }
  return Status::OK();
}

Status ReadSavedModelDebugInfoIfPresent(
    const string& export_dir,
    std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
@ -322,7 +261,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,

  std::vector<AssetFileDef> asset_file_defs;
  TF_RETURN_IF_ERROR(
      GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
      internal::GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
  TF_RETURN_IF_ERROR(
      RunRestore(run_options, export_dir,
                 bundle->meta_graph_def.saver_def().restore_op_name(),
@ -336,7 +275,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
  const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
  string init_op_name;
  TF_RETURN_IF_ERROR(
      GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
      internal::GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
  TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
                               asset_file_defs, bundle->session.get(),
                               init_op_name));
90
tensorflow/cc/saved_model/loader_util.cc
Normal file
@ -0,0 +1,90 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/saved_model/loader_util.h"

#include <vector>

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf_internal.h"

namespace tensorflow {
namespace internal {

// A SavedModel may store the name of the initialization op to run in the
// SignatureDef (v2) or a collection (v1). If an init_op collection exists,
// then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
                 string* init_op_name) {
  const auto& sig_def_map = meta_graph_def.signature_def();
  const auto& init_op_sig_it =
      meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
  if (init_op_sig_it != sig_def_map.end()) {
    *init_op_name = init_op_sig_it->second.outputs()
                        .find(kSavedModelInitOpSignatureKey)
                        ->second.name();
    return Status::OK();
  }

  const auto& collection_def_map = meta_graph_def.collection_def();
  string init_op_collection_key;
  if (collection_def_map.find(kSavedModelMainOpKey) !=
      collection_def_map.end()) {
    init_op_collection_key = kSavedModelMainOpKey;
  } else {
    init_op_collection_key = kSavedModelLegacyInitOpKey;
  }

  const auto init_op_it = collection_def_map.find(init_op_collection_key);
  if (init_op_it != collection_def_map.end()) {
    if (init_op_it->second.node_list().value_size() != 1) {
      return errors::FailedPrecondition(
          strings::StrCat("Expected exactly one main op in : ", export_dir));
    }
    *init_op_name = init_op_it->second.node_list().value(0);
  }
  return Status::OK();
}

Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
                        std::vector<AssetFileDef>* asset_file_defs) {
  // With SavedModel v2, we write asset file def into metagraph instead of
  // collection, so read from metagraph first.
  if (meta_graph_def.asset_file_def_size() > 0) {
    for (const auto& asset : meta_graph_def.asset_file_def()) {
      asset_file_defs->push_back(asset);
    }
    return Status::OK();
  }
  // Fall back to read from collection to be backward compatible with v1.
  const auto& collection_def_map = meta_graph_def.collection_def();
  const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
  if (assets_it == collection_def_map.end()) {
    return Status::OK();
  }
  const auto& any_assets = assets_it->second.any_list().value();
  for (const auto& any_asset : any_assets) {
    AssetFileDef asset_file_def;
    TF_RETURN_IF_ERROR(
        ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
    asset_file_defs->push_back(asset_file_def);
  }
  return Status::OK();
}

}  // namespace internal
}  // namespace tensorflow
39
tensorflow/cc/saved_model/loader_util.h
Normal file
@ -0,0 +1,39 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
#define TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_

#include <string>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

namespace tensorflow {
namespace internal {

// A SavedModel may store the name of the initialization op to run in the
// SignatureDef (v2) or a collection (v1). If an init_op collection exists,
// then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
                 string* init_op_name);

Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
                        std::vector<AssetFileDef>* asset_file_defs);

}  // namespace internal
}  // namespace tensorflow

#endif  // TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
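For orientation, a minimal caller sketch (not part of this commit; RunInitOpIfPresent and the surrounding wiring are hypothetical) of how a loader might use the two helpers declared above. GetInitOp leaves *init_op_name empty when no init op is recorded, so a caller should check before running it.

    #include <string>
    #include <vector>

    #include "tensorflow/cc/saved_model/loader_util.h"
    #include "tensorflow/core/lib/core/errors.h"

    // Hypothetical helper: fetch the init op name and asset defs for a
    // MetaGraphDef that was already read from `export_dir`.
    tensorflow::Status RunInitOpIfPresent(
        const std::string& export_dir,
        const tensorflow::MetaGraphDef& meta_graph_def) {
      std::string init_op_name;
      TF_RETURN_IF_ERROR(tensorflow::internal::GetInitOp(
          export_dir, meta_graph_def, &init_op_name));
      std::vector<tensorflow::AssetFileDef> asset_file_defs;
      TF_RETURN_IF_ERROR(tensorflow::internal::GetAssetFileDefs(
          meta_graph_def, &asset_file_defs));
      // If init_op_name is non-empty, a real loader would now feed the asset
      // paths and run the op in a Session.
      return tensorflow::Status::OK();
    }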
@ -67,13 +67,14 @@ cc_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "@llvm-project//llvm:arm_target",  # fixdeps: keep
        "@llvm-project//llvm:powerpc_target",  # fixdeps: keep
        "@llvm-project//llvm:target_base",
        "@llvm-project//llvm:x86_target",  # fixdeps: keep
        "@llvm-project//llvm:ARMCodeGen",  # fixdeps: keep
        "@llvm-project//llvm:PowerPCCodeGen",  # fixdeps: keep
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:Target",
        "@llvm-project//llvm:X86CodeGen",  # fixdeps: keep
        "//tensorflow/core:regexp_internal",
    ] + if_llvm_aarch64_available([
        "@llvm-project//llvm:aarch64_target",  # fixdeps: keep
        "@llvm-project//llvm:AArch64CodeGen",  # fixdeps: keep
    ]),
)

@ -94,8 +95,8 @@ tf_cc_test(
        "//tensorflow/core:test_main",
        "//tensorflow/core/platform:resource_loader",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",  # fixdeps: keep
        "@llvm-project//llvm:x86_target",  # fixdeps: keep
        "@llvm-project//llvm:Support",  # fixdeps: keep
        "@llvm-project//llvm:X86CodeGen",  # fixdeps: keep
    ],
)

@ -109,12 +110,12 @@ cc_library(
    name = "llvm_targets",
    visibility = ["//tensorflow/python:__pkg__"],
    deps = [
        "@llvm-project//llvm:arm_target",  # fixdeps: keep
        "@llvm-project//llvm:powerpc_target",  # fixdeps: keep
        "@llvm-project//llvm:target_base",
        "@llvm-project//llvm:x86_target",  # fixdeps: keep
        "@llvm-project//llvm:ARMCodeGen",  # fixdeps: keep
        "@llvm-project//llvm:PowerPCCodeGen",  # fixdeps: keep
        "@llvm-project//llvm:Target",
        "@llvm-project//llvm:X86CodeGen",  # fixdeps: keep
    ] + if_llvm_aarch64_available([
        "@llvm-project//llvm:aarch64_target",  # fixdeps: keep
        "@llvm-project//llvm:AArch64CodeGen",  # fixdeps: keep
    ]),
)

@ -286,9 +287,9 @@ cc_library(
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:ir",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:target_base",
        "@llvm-project//llvm:Core",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:Target",
    ],
)

@ -22,6 +22,7 @@ limitations under the License.

#include "absl/base/call_once.h"
#include "llvm-c/Target.h"
#include "llvm/Support/ManagedStatic.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/aot/quantize.h"

@ -1,5 +1,5 @@
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s -dump-input-on-failure
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s -dump-input-on-failure
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s

# Checks the error message produced by tfcompile with mlir_component
# Checks that source debug information is used in the output error message and

@ -4,6 +4,7 @@ traces: {
  value: {
    file_line_cols: {
      line: 1
      col: 1
    }
  }
}
@ -12,9 +13,11 @@ traces: {
  value: {
    file_line_cols: {
      line: 3
      col: 1
    }
    file_line_cols: {
      line: 4
      col: 1
    }
  }
}
@ -23,6 +26,7 @@ traces: {
  value: {
    file_line_cols: {
      line: 2
      col: 1
    }
  }
}

@ -33,6 +33,7 @@ MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;
MlirCommonFlags* mlir_flags;

std::vector<Flag>* flag_list;
absl::once_flag flags_init;
@ -166,6 +167,9 @@ void AllocateAndParseFlags() {
  jitter_flags = new IntroduceFloatingPointJitterPassFlags;
  jitter_flags->jitter_amount = 1e-5;

  mlir_flags = new MlirCommonFlags;
  mlir_flags->tf_mlir_enable_mlir_bridge = false;

  auto setter_for_jitter_tensor_names = [](string sequence) {
    jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
    return true;
@ -211,7 +215,11 @@ void AllocateAndParseFlags() {
       Flag("tf_introduce_floating_point_jitter_amount",
            &jitter_flags->jitter_amount,
            "The amount of jitter to introduce. This amount is added to each "
            "element in the tensors named in `tensor_names.")});
            "element in the tensors named in `tensor_names."),

       Flag("tf_mlir_enable_mlir_bridge",
            &mlir_flags->tf_mlir_enable_mlir_bridge,
            "Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});

  AppendMarkForCompilationPassFlagsInternal(flag_list);
  xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
@ -250,6 +258,11 @@ GetIntroduceFloatingPointJitterPassFlags() {
  return *jitter_flags;
}

MlirCommonFlags* GetMlirCommonFlags() {
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return mlir_flags;
}

void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
  absl::call_once(flags_init, &AllocateAndParseFlags);
  AppendMarkForCompilationPassFlagsInternal(flag_list);

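Since AllocateAndParseFlags() reads the TF_XLA_FLAGS environment variable (via ParseFlagsFromEnvAndDieIfUnknown above), the new bridge flag is visible to any caller through the accessor added in this commit. A small hypothetical caller sketch (IsMlirBridgeEnabled is not in this commit):

    #include "tensorflow/compiler/jit/flags.h"

    bool IsMlirBridgeEnabled() {
      // GetMlirCommonFlags() lazily parses TF_XLA_FLAGS on first use, so with
      // e.g. TF_XLA_FLAGS="--tf_mlir_enable_mlir_bridge=true" in the process
      // environment this returns true.
      return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
    }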
@ -133,6 +133,11 @@ struct IntroduceFloatingPointJitterPassFlags {
  std::vector<string> tensor_names;
};

// Flags for common MLIR configurations.
struct MlirCommonFlags {
  bool tf_mlir_enable_mlir_bridge;
};

// Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
@ -148,6 +153,8 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags();

MlirCommonFlags* GetMlirCommonFlags();

// Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
//

@ -108,8 +108,7 @@ class XlaExecutableClosure {
  explicit XlaExecutableClosure(
      xla::LocalClient* client, xla::LocalExecutable* executable,
      const XlaCompiler::CompilationResult* compilation_result,
      std::map<int, OptionalTensor> resource_var_snapshots,
      int num_constant_args)
      ResourceVarsSnapshot resource_var_snapshots, int num_constant_args)
      : client_(client),
        executable_(executable),
        compilation_result_(compilation_result),
@ -124,7 +123,7 @@ class XlaExecutableClosure {
  const XlaCompiler::CompilationResult* compilation_result() const {
    return compilation_result_;
  }
  const std::map<int, OptionalTensor>& resource_var_snapshots() const {
  const ResourceVarsSnapshot& resource_var_snapshots() const {
    return resource_var_snapshots_;
  }
  int num_constant_args() const { return num_constant_args_; }
@ -133,7 +132,7 @@ class XlaExecutableClosure {
  xla::LocalClient* client_;
  xla::LocalExecutable* executable_;
  const XlaCompiler::CompilationResult* compilation_result_;
  std::map<int, OptionalTensor> resource_var_snapshots_;
  ResourceVarsSnapshot resource_var_snapshots_;
  int num_constant_args_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
@ -276,10 +275,10 @@ static Status BuildCompilationCache(OpKernelContext* ctx,

static Status CompileToLocalExecutable(
    OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
    const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
    const XlaPlatformInfo& platform_info,
    absl::Span<VariableInfo const> variable_infos,
    absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
    std::map<int, OptionalTensor>* variables,
    const XlaCompiler::CompilationResult** kernel,
    const XlaCompiler::CompilationResult** compilation_result,
    xla::LocalExecutable** executable) {
  // We store information about the JIT-compiled XLA computation
  // in the ResourceMgr.
@ -299,7 +298,6 @@ static Status CompileToLocalExecutable(
  // this is more obviously correct.)
  core::ScopedUnref cache_ref(cache);

  TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables));
  *client = static_cast<xla::LocalClient*>(cache->client());

  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
@ -337,11 +335,11 @@ static Status CompileToLocalExecutable(

  std::vector<XlaCompiler::Argument> args;
  TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
      constant_args, *variables, ctx, &args));
      constant_args, variable_infos, ctx, &args));
  return cache->Compile(options, function, args, compile_options,
                        lazy ? XlaCompilationCache::CompileMode::kLazy
                             : XlaCompilationCache::CompileMode::kStrict,
                        kernel, executable);
                        compilation_result, executable);
}

void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
@ -349,16 +347,22 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
          << Canonicalize(function_.name(), AttrSlice(&function_.attr()));

  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* kernel;
  const XlaCompiler::CompilationResult* compilation_result;
  xla::LocalExecutable* executable;
  std::map<int, OptionalTensor> variables;

  ResourceVarsSnapshot variables;
  {
    std::vector<VariableInfo> variable_infos;
    OP_REQUIRES_OK(
        ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
    OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
    Status s = CompileToLocalExecutable(
        ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
        resources_, constants_, /*lazy=*/false, &client, &variables, &kernel,
        &executable);
        variable_infos, constants_, /*lazy=*/false, &client,
        &compilation_result, &executable);
    OP_REQUIRES_OK(ctx, s);
    OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
                                                  variable_infos, &variables));
  }

  se::Stream* stream =
@ -373,7 +377,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
      client, allocator,
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      platform_info_.UseMultipleStreams());
  launch_context.PopulateInputs(ctx, kernel, variables,
  launch_context.PopulateInputs(ctx, compilation_result, variables,
                                /*missing_ctx_input_prefix=*/0);

  // Execute the computation.
@ -413,7 +417,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
      executable->executable()->module().input_output_alias_config();
  OP_REQUIRES_OK(
      ctx, launch_context.PopulateOutputs(
               ctx, kernel, run_result.ConsumeValueOrDie(),
               ctx, compilation_result, run_result.ConsumeValueOrDie(),
               /*missing_ctx_input_prefix=*/0, input_output_alias, variables));
  VLOG(1) << "Done";
}

@ -494,7 +498,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* kernel;
  xla::LocalExecutable* executable;
  std::map<int, OptionalTensor> variables;
  ResourceVarsSnapshot variables;

  bool cannot_compile_cluster;
  {
@ -506,9 +510,16 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
      cannot_compile_cluster) {
    executable = nullptr;
  } else {
    std::vector<VariableInfo> variable_infos;
    OP_REQUIRES_OK(
        ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
    OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
    Status status = CompileToLocalExecutable(
        ctx, function_, has_ref_vars_, platform_info_, resources_, constants_,
        /*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
        ctx, function_, has_ref_vars_, platform_info_, variable_infos,
        constants_,
        /*lazy=*/!must_compile_, &client, &kernel, &executable);
    OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
                                                  variable_infos, &variables));
    if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
      OP_REQUIRES_OK(ctx, status);
    }

@ -1837,7 +1837,7 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
       "ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse",
       "ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV",
       "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
       "Tile", "Transpose", "InvertPermutation", "Unpack"}}};
       "Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex"}}};
  // clang-format on
  return result;
}

@ -28,32 +28,23 @@ limitations under the License.

namespace tensorflow {

namespace {
std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
  std::map<int, OptionalTensor> variables;
  for (int64 i = 0; i < ctx->num_inputs(); ++i) {
// Returns argument indices corresponding to the resource variable inputs of
// kernel context `ctx`.
static std::vector<int> GetResourceVariableIndices(OpKernelContext* ctx) {
  std::vector<int> out;
  for (int64 i = 0; i < ctx->num_inputs(); i++) {
    if (ctx->input(i).dtype() == DT_RESOURCE) {
      core::RefCountPtr<Var> variable;
      ResourceHandle handle = HandleFromInput(ctx, i);
      OptionalTensor& optional = variables[i];
      optional.name = handle.name();
      if (LookupResource(ctx, handle, &variable).ok()) {
        tf_shared_lock lock(*variable->mu());
        optional.present = true;
        optional.value = *variable->tensor();
      }
      out.push_back(i);
    }
  }
  return variables;
  return out;
}
}  // namespace

Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
                                 const XlaDevice::Metadata& metadata,
                                 const XlaCompiler::CompilationResult* result,
                                 xla::LocalExecutable* executable) {
  std::map<int, OptionalTensor> variables = GetVariables(ctx);

                                 xla::LocalExecutable* executable,
                                 const ResourceVarsSnapshot& variable_args) {
  xla::LocalClient* client = metadata.client();

  // Builds an XLA allocator for the device.
@ -62,7 +53,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
      /*allocate_xla_tensors=*/true,
      /*use_multiple_streams=*/metadata.UseMultipleStreams());

  launch_context.PopulateInputs(ctx, result, variables,
  launch_context.PopulateInputs(ctx, result, variable_args,
                                /*missing_ctx_input_prefix=*/0);

  se::Stream* stream =
@ -87,7 +78,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
      executable->executable()->module().input_output_alias_config();
  TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
      ctx, result, run_result.ConsumeValueOrDie(),
      /*missing_ctx_input_prefix=*/0, input_output_alias, variables));
      /*missing_ctx_input_prefix=*/0, input_output_alias, variable_args));
  return Status::OK();
}

@ -115,7 +106,7 @@ Status XlaCompileOnDemandOp::ShouldArgumentBeConstant(
Status XlaCompileOnDemandOp::Compile(
    OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
    const XlaCompiler::CompilationResult** result,
    xla::LocalExecutable** executable) {
    ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable) {
  std::map<int, Tensor> constant_arguments;
  for (int64 i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& device_tensor = ctx->input(i);
@ -190,12 +181,18 @@ Status XlaCompileOnDemandOp::Compile(
  // rather than a one-element tuple.
  compile_options.always_return_tuple = false;

  std::map<int, OptionalTensor> variable_args = GetVariables(ctx);

  std::vector<int> variables_indices = GetResourceVariableIndices(ctx);
  std::vector<XlaCompiler::Argument> args;

  TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
      constant_arguments, variable_args, ctx, &args));
  {
    std::vector<VariableInfo> variable_infos;
    TF_RETURN_IF_ERROR(
        GetVariableInfosFromCtxInputs(ctx, variables_indices, &variable_infos));
    TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
    TF_RETURN_IF_ERROR(SnapshotResourceVariables(
        ctx, variables_indices, variable_infos, variable_args));
    TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
        constant_arguments, variable_infos, ctx, &args));
  }

  return cache->CompileSingleOp(options, args, ctx, compile_options, result,
                                executable);
@ -206,8 +203,10 @@ void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
  xla::LocalExecutable* executable;
  const XlaDevice::Metadata* metadata;
  OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
  OP_REQUIRES_OK(ctx, Compile(ctx, *metadata, &result, &executable));
  OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable));
  ResourceVarsSnapshot variable_args;
  OP_REQUIRES_OK(ctx,
                 Compile(ctx, *metadata, &result, &variable_args, &executable));
  OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable, variable_args));
}

}  // namespace tensorflow

@ -20,6 +20,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/framework/function.h"
@ -47,10 +48,12 @@ class XlaCompileOnDemandOp : public OpKernel {
                 bool* result);
  Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
                 const XlaCompiler::CompilationResult** result,
                 ResourceVarsSnapshot* variable_args,
                 xla::LocalExecutable** executable);
  Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
             const XlaCompiler::CompilationResult* result,
             xla::LocalExecutable* executable);
             xla::LocalExecutable* executable,
             const ResourceVarsSnapshot& variable_args);
};

}  // namespace tensorflow

@ -52,7 +52,8 @@ const char kPossibleNonVariableResourceHintMessage[] =
    "resource inputs to XLA.";
}  // anonymous namespace

VariableInfo::VariableInfo(int index, Var* var) : index_(index), var_(var) {}
VariableInfo::VariableInfo(int index, absl::string_view name, Var* var)
    : index_(index), name_(name), var_(var) {}
VariableInfo::VariableInfo(VariableInfo&& other)
    : index_(other.index_), var_(other.var_), lock_held_(other.lock_held_) {
  other.index_ = -1;
@ -87,16 +88,15 @@ VariableInfo::~VariableInfo() {
// Returns a vector of VariableInfo instances for the resource variable inputs
// to the kernel with context `ctx`. The input indices for the resource
// variable inputs are in `variable_indices`.
static Status GetVariableInfosFromCtxInputs(
    OpKernelContext* ctx, absl::Span<const int> variable_indices,
    std::vector<VariableInfo>* result) {
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
                                     absl::Span<const int> variable_indices,
                                     std::vector<VariableInfo>* result) {
  std::vector<const ResourceHandle*> resource_handles;
  absl::c_transform(
      variable_indices, std::back_inserter(resource_handles),
      [&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });

  std::vector<core::RefCountPtr<Var>> variables;

  Status s = LookupResources(ctx, resource_handles, &variables);
  if (!s.ok()) {
    errors::AppendToMessage(&s, kPossibleNonVariableResourceHintMessage);
@ -109,7 +109,9 @@ static Status GetVariableInfosFromCtxInputs(
    // *Release* the variable because we're going to unref it later in
    // ~VariableInfo.
    Var* variable = variables[i].release();
    result->emplace_back(variable_indices[i], variable);
    int input_idx = variable_indices[i];
    std::string var_name = HandleFromInput(ctx, input_idx).name();
    result->emplace_back(input_idx, var_name, variable);
  }

  return Status::OK();
@ -162,21 +164,12 @@ Status LockVariables(absl::Span<VariableInfo> variables)

Status SnapshotResourceVariables(OpKernelContext* ctx,
                                 absl::Span<const int> variable_indices,
                                 std::map<int, OptionalTensor>* result) {
  std::vector<VariableInfo> variable_infos;
  TF_RETURN_IF_ERROR(
      GetVariableInfosFromCtxInputs(ctx, variable_indices, &variable_infos));
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));

                                 absl::Span<VariableInfo const> variable_infos,
                                 ResourceVarsSnapshot* result) {
  for (int i = 0; i < variable_indices.size(); i++) {
    if (variable_infos[i].var()) {
      OptionalTensor& tensor = (*result)[variable_indices[i]];
      tensor.name = HandleFromInput(ctx, variable_indices[i]).name();
      tensor.present = true;
      tensor.value = *variable_infos[i].var()->tensor();
    } else {
      (*result)[variable_indices[i]] = OptionalTensor();
    }
    Var* var = variable_infos[i].var();
    (*result)[variable_indices[i]] =
        var ? absl::make_optional(*var->tensor()) : absl::nullopt;
  }
  return Status::OK();
}
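After this change a snapshot entry is an absl::optional&lt;Tensor&gt; rather than an OptionalTensor struct. A small hypothetical consumer sketch (LogSnapshot is not part of this commit) of how such a snapshot is read:

    #include "absl/types/optional.h"
    #include "tensorflow/compiler/jit/xla_launch_util.h"
    #include "tensorflow/core/platform/logging.h"

    // ResourceVarsSnapshot maps input index -> absl::optional<Tensor>; an
    // empty optional marks an uninitialized resource variable.
    void LogSnapshot(const tensorflow::ResourceVarsSnapshot& snapshot) {
      for (const auto& entry : snapshot) {
        const absl::optional<tensorflow::Tensor>& maybe = entry.second;
        if (maybe.has_value()) {
          VLOG(2) << "arg " << entry.first << " shape "
                  << maybe->shape().DebugString();
        } else {
          VLOG(2) << "arg " << entry.first << " is uninitialized";
        }
      }
    }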
@ -195,53 +188,45 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
}

void XlaComputationLaunchContext::PopulateInputs(
    OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
    const std::map<int, OptionalTensor>& variables,
    int missing_ctx_input_prefix) {
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    const ResourceVarsSnapshot& variables, int missing_ctx_input_prefix) {
  // Build ShapedBuffers that point directly to the Tensor buffers.
  arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
  arg_ptrs_ =
      std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());

  // Pass remaining parameters.
  const Tensor* t;
  for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
    int arg_num = kernel->input_mapping[i];
    DCHECK_GE(arg_num, missing_ctx_input_prefix);
    const xla::Shape& shape = kernel->xla_input_shapes[i];
    if (variables.count(arg_num)) {
      t = &(variables.at(arg_num).value);
      CHECK(t);
    } else {
      t = &(ctx->input(arg_num - missing_ctx_input_prefix));
    }
  xla::TransferManager* transfer_manager =
      client_->backend().transfer_manager();
  for (int i = 0; i < compilation_result->xla_input_shapes.size(); ++i) {
    int arg_num = compilation_result->input_mapping[i];
    CHECK_GE(arg_num, missing_ctx_input_prefix);
    const xla::Shape& shape = compilation_result->xla_input_shapes[i];
    const Tensor* t = variables.count(arg_num)
                          ? &(variables.at(arg_num).value())
                          : &(ctx->input(arg_num - missing_ctx_input_prefix));
    CHECK(t);

    if (use_multiple_streams_) {
      CHECK(stream) << "Must have a stream available when using XLA tensors!";
      CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
          << "Must have a stream available when using XLA tensors!";
      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor);
      xla_tensor->WaitForDefinitionEventOnStream(stream);
      xla_tensor->WaitForDefinitionEventOnStream(
          ctx->op_device_context()->stream());
    }

    const xla::Shape on_device_shape =
        client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
    if (on_device_shape.IsTuple()) {
      const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
      arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
    } else {
      CHECK(xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape,
                                                           on_device_shape))
          << "On-device shape "
          << xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
          << " not the same as on-host shape "
          << xla::ShapeUtil::HumanStringWithLayout(shape);
    if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(
            shape, transfer_manager->HostShapeToDeviceShape(shape))) {
      se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
      arg_buffers_.emplace_back(
          /*on_host_shape=*/shape, /*on_device_shape=*/shape,
          client_->platform(), client_->default_device_ordinal());
      arg_buffers_.back().set_buffer(dmem, /*index=*/{});
      arg_ptrs_[i] = &arg_buffers_.back();
    } else {
      const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
      arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
    }
  }
}
@ -269,7 +254,7 @@ static const Tensor* FindAliasedTensorForOutput(
    int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const int> input_mapping,
    const std::map<int, OptionalTensor>& resource_var_snapshots) {
    const ResourceVarsSnapshot& resource_var_snapshots) {
  if (MustAliasOutput(input_output_alias, output_num)) {
    int xla_param = input_output_alias.GetAliasedParameter({output_num})
                        .value()
@ -281,8 +266,8 @@ static const Tensor* FindAliasedTensorForOutput(
    // entry time.
    if (input_tensor->dtype() == DT_RESOURCE) {
      auto& v = resource_var_snapshots.at(missing_ctx_input_prefix + tf_param);
      CHECK(v.present);
      return &v.value;
      CHECK(v.has_value());
      return &v.value();
    }
    return input_tensor;
  }
@ -305,9 +290,9 @@ static Tensor GetOrCreateTensorForOutput(
    int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const int> input_mapping,
    const std::map<int, OptionalTensor>& resource_var_snapshots,
    DataType output_dtype, const TensorShape& output_shape,
    se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
    const ResourceVarsSnapshot& resource_var_snapshots, DataType output_dtype,
    const TensorShape& output_shape, se::DeviceMemoryBase output_buffer,
    Allocator* output_allocator) {
  if (const Tensor* aliased_tensor = FindAliasedTensorForOutput(
          output_num, ctx, missing_ctx_input_prefix, input_output_alias,
          input_mapping, resource_var_snapshots)) {
@ -370,13 +355,94 @@ static Status SetBufferForResourceVarTensorUnderAllocateXlaTensors(
  return Status::OK();
}

// Sets output `output_num` for `ctx` provided it is known at a compile time.
static Status SetOutputForConstant(
    OpKernelContext* ctx, se::Stream* stream,
    const XlaCompiler::CompilationResult* compilation_result, int output_num) {
  CHECK(compilation_result->outputs[output_num].is_constant);
  // Output is a constant.
  const Tensor& const_tensor =
      compilation_result->outputs[output_num].constant_value;
  Tensor* output_tensor;
  const size_t total_bytes = const_tensor.TotalBytes();
  if (stream && total_bytes > 0) {
    // Copy host -> device. (Empty tensors don't have backing buffers.)
    // Manually allocate memory using an XlaTensorBuffer so we can allocate
    // as much memory as the device requires (as given by
    // GetByteSizeRequirement). This avoids XlaTransferManager having to
    // reallocate the device buffer later.
    VLOG(1) << "Constant output tensor on device";

    TF_RETURN_IF_ERROR(
        ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
    Device* device = dynamic_cast<Device*>(ctx->device());
    if (device == nullptr) {
      return errors::Internal("DeviceBase was not a Device.");
    }
    ctx->op_device_context()->CopyCPUTensorToDevice(
        &const_tensor, device, output_tensor,
        [&](Status status) { TF_CHECK_OK(status); });

    if (device->device_type() == DEVICE_GPU) {
      // The GPUDeviceContext enqueues the host->device transfer in a
      // separate stream from the main compute stream. We must ensure the
      // compute stream is synchronized with the host->device transfer
      // stream now otherwise we will create a race condition.
      auto* gpu_device_context =
          static_cast<GPUDeviceContext*>(ctx->op_device_context());
      gpu_device_context->stream()->ThenWaitFor(
          gpu_device_context->host_to_device_stream());
    }
  } else {
    // No copy required.
    ctx->set_output(output_num, const_tensor);
    output_tensor = ctx->mutable_output(output_num);
  }
  if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
    xla_tensor->set_host_tensor(const_tensor);
  }
  return Status::OK();
}

// Creates a list of updated resource variables.
static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    int missing_ctx_input_prefix) {
  std::vector<VariableInfo> variable_infos;
  variable_infos.reserve(compilation_result->resource_updates.size());

  for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& write =
        compilation_result->resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
      return errors::Internal("Invalid input index for variable write.");
    }

    // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
    // not a Tensor.
    Var* variable = nullptr;
    const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
    TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(ctx, handle, &variable,
                                                   [&write](Var** ptr) {
                                                     *ptr = new Var(write.type);
                                                     return Status::OK();
                                                   }));
    variable_infos.emplace_back(actual_input_index, handle.name(), variable);
  }
  return variable_infos;
}

Status XlaComputationLaunchContext::PopulateOutputs(
    OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    ScopedShapedBuffer output, int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    const std::map<int, OptionalTensor>& resource_var_snapshots) {
    const ResourceVarsSnapshot& resource_var_snapshots) {
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  Allocator* allocator = ctx->device()->GetAllocator({});

  // Computation output should always be a tuple.
  if (VLOG_IS_ON(2)) {
@ -384,7 +450,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
    VLOG(2) << "Result tuple shape (on device): "
            << output.on_device_shape().DebugString();
  }
  CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
  CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());

  // If the on-host-shape isn't a tuple, create a new single-element tuple
  // buffer with a nullptr root index table. This allows the code below to treat
@ -410,89 +476,70 @@ Status XlaComputationLaunchContext::PopulateOutputs(
    stream->ThenRecordEvent(definition_event.get());
  }

  std::vector<TensorShape> output_tensor_shapes;
  output_tensor_shapes.reserve(ctx->num_outputs());
  if (output.on_host_shape().is_dynamic()) {
    TF_ASSIGN_OR_RETURN(
        auto transfer_manager,
        xla::TransferManager::GetForPlatform(stream->parent()->platform()));

    xla::Shape output_host_shape = output.on_host_shape();
    xla::Shape output_device_shape = output.on_device_shape();
    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
        stream, &output, &output_host_shape, &output_device_shape));

    output.set_shapes(output_host_shape, output_device_shape);
    for (int i = 0; i < ctx->num_outputs(); ++i) {
      const xla::Shape& subshape =
          xla::ShapeUtil::GetSubshape(output_host_shape, {i});
      TensorShape shape;
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
      output_tensor_shapes.push_back(shape);
    }
  } else {
    for (int i = 0; i < ctx->num_outputs(); ++i) {
      output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
    }
  }

  // Copy XLA results to the OpOutputList.
  int output_num = 0;
  for (int i = 0; i < ctx->num_outputs(); ++i) {
    Allocator* allocator = ctx->device()->GetAllocator({});
    if (kernel->outputs[i].is_constant) {
      // Output is a constant.
      const Tensor& const_tensor = kernel->outputs[i].constant_value;
      Tensor* output_tensor;
      const size_t total_bytes = const_tensor.TotalBytes();
      if (stream && total_bytes > 0) {
        // Copy host -> device. (Empty tensors don't have backing buffers.)
        // Manually allocate memory using an XlaTensorBuffer so we can allocate
        // as much memory as the device requires (as given by
        // GetByteSizeRequirement). This avoids XlaTransferManager having to
        // reallocate the device buffer later.
        VLOG(1) << "Constant output tensor on device";
    const TensorShape& shape = output_tensor_shapes[i];
    const DataType& type = compilation_result->outputs[i].type;
    VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
            << DataTypeString(type);
    if (type == DT_VARIANT) {
      return errors::Unimplemented(
          "Support for TensorList crossing the XLA/TF boundary "
          "is not implemented");
    }

        TF_RETURN_IF_ERROR(
            ctx->allocate_output(i, const_tensor.shape(), &output_tensor));

        Device* device = dynamic_cast<Device*>(ctx->device());
        if (device == nullptr) {
          return errors::Internal("DeviceBase was not a Device.");
        }
        ctx->op_device_context()->CopyCPUTensorToDevice(
            &const_tensor, device, output_tensor,
            [&](Status status) { TF_CHECK_OK(status); });

        if (device->device_type() == DEVICE_GPU) {
          // The GPUDeviceContext enqueues the host->device transfer in a
          // separate stream from the main compute stream. We must ensure the
          // compute stream is synchronized with the host->device transfer
          // stream now otherwise we will create a race condition.
          auto* gpu_device_context =
              static_cast<GPUDeviceContext*>(ctx->op_device_context());
          gpu_device_context->stream()->ThenWaitFor(
              gpu_device_context->host_to_device_stream());
        }
      } else {
        // No copy required.
        ctx->set_output(i, const_tensor);
        output_tensor = ctx->mutable_output(i);
      }
      if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
        xla_tensor->set_host_tensor(const_tensor);
      }
    if (compilation_result->outputs[i].is_constant) {
      TF_RETURN_IF_ERROR(
          SetOutputForConstant(ctx, stream, compilation_result, i));
    } else if (type == DT_RESOURCE) {
      int input_index =
          compilation_result->outputs[i].input_index - missing_ctx_input_prefix;
      TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
          << "Invalid input for outputs " << i << ": " << input_index;
      ctx->set_output(i, ctx->input(input_index));
    } else {
      const TensorShape& shape = kernel->outputs[i].shape;
      const DataType& type = kernel->outputs[i].type;
      VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
              << DataTypeString(type);
      if (type == DT_RESOURCE) {
        int input_index =
            kernel->outputs[i].input_index - missing_ctx_input_prefix;
        TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
            << "Invalid input for outputs " << i << ": " << input_index;
        ctx->set_output(i, ctx->input(input_index));
      } else {
        if (MustAliasOutput(input_output_alias, output_num)) {
          DCHECK(output.buffer({output_num}).is_null())
              << "Expected output buffer to be aliased, but it is not nil.";
        }
        if (allocate_xla_tensors_) {
          TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
              input_output_alias, output_num, ctx, i, shape, &output,
              definition_event, stream, use_multiple_streams_));
        } else {
        if (type == DT_VARIANT) {
          return errors::Unimplemented(
              "Support for TensorList crossing the XLA/TF boundary "
              "is not implemented");
        }
        if (allocate_xla_tensors_) {
          TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
              input_output_alias, output_num, ctx, i, shape, &output,
              definition_event, stream, use_multiple_streams_));

          se::DeviceMemoryBase buffer = output.buffer({output_num});
          Tensor output_tensor = GetOrCreateTensorForOutput(
              output_num, ctx, missing_ctx_input_prefix, input_output_alias,
              kernel->input_mapping, resource_var_snapshots,
              ctx->expected_output_dtype(i), shape, buffer, allocator);
          output.set_buffer(se::OwningDeviceMemory(), {output_num});
          ctx->set_output(i, output_tensor);
        }
        ++output_num;
      } else {
        se::DeviceMemoryBase buffer = output.buffer({output_num});
        Tensor output_tensor = GetOrCreateTensorForOutput(
            output_num, ctx, missing_ctx_input_prefix, input_output_alias,
            compilation_result->input_mapping, resource_var_snapshots,
            ctx->expected_output_dtype(i), shape, buffer, allocator);
        output.set_buffer(se::OwningDeviceMemory(), {output_num});
        ctx->set_output(i, output_tensor);
      }
      ++output_num;
    }

    if (VLOG_IS_ON(3)) {
@ -502,34 +549,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(

  // Apply variable updates, if any.
  VLOG(2) << "Applying variable updates";
  std::vector<VariableInfo> variable_infos;
  variable_infos.reserve(kernel->resource_updates.size());

  for (int i = 0; i < kernel->resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
      return errors::Internal("Invalid input index for variable write.");
    }

    // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
    // not a Tensor.
    Var* variable = nullptr;
    TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
        ctx, HandleFromInput(ctx, actual_input_index), &variable,
        [&write](Var** ptr) {
          *ptr = new Var(write.type);
          return Status::OK();
        }));
    variable_infos.emplace_back(actual_input_index, variable);
  }

  TF_ASSIGN_OR_RETURN(
      std::vector<VariableInfo> variable_infos,
      GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix));
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));

  for (int i = 0; i < kernel->resource_updates.size(); ++i) {
    Allocator* allocator = ctx->device()->GetAllocator({});
    const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];

  for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& write =
        compilation_result->resource_updates[i];
    if (variable_infos[i].var()->tensor()->dtype() != write.type) {
      return errors::Internal("Mismatched type in variable write");
    }
@ -543,7 +570,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
    output.set_buffer(se::OwningDeviceMemory(), {output_num});
    Tensor output_tensor = GetOrCreateTensorForOutput(
        output_num, ctx, missing_ctx_input_prefix, input_output_alias,
        kernel->input_mapping, resource_var_snapshots, write.type,
        compilation_result->input_mapping, resource_var_snapshots, write.type,
        write.shape, buffer, allocator);
    *variable_infos[i].var()->tensor() = output_tensor;
    variable_infos[i].var()->is_initialized |= write.modified;
@ -555,12 +582,21 @@ Status XlaComputationLaunchContext::PopulateOutputs(

Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
    const std::map<int, Tensor>& constant_args,
    const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
    absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
    std::vector<XlaCompiler::Argument>* args) {
  args->resize(ctx->num_inputs());

  absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
  for (const VariableInfo& info : variable_args) {
    CHECK(!info.var() || info.lock_held())
        << "Need to hold the lock on resource variables "
           "before calling BuildXlaCompilerArguments";
    variable_info_lookup.emplace(info.index(), &info);
  }

  for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) {
    XlaCompiler::Argument& arg = (*args)[input_num];

    if (constant_args.count(input_num) > 0) {
      // Handles compile-time constants.
      const Tensor& input = constant_args.at(input_num);
@ -569,7 +605,7 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
      arg.type = input.dtype();
      arg.shape = input.shape();
      arg.constant_value = input;
    } else if (variable_args.count(input_num) == 0) {
    } else if (variable_info_lookup.count(input_num) == 0) {
      // Handles the non-constant arguments.
      const Tensor& input = ctx->input(input_num);
      TF_RET_CHECK(input.dtype() != DT_RESOURCE);
@ -585,14 +621,14 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
      // Handles resource variables.
      const Tensor& input = ctx->input(input_num);
      TF_RET_CHECK(input.dtype() == DT_RESOURCE);
      const OptionalTensor& variable = variable_args.at(input_num);
      arg.name = variable.name;
      const VariableInfo& variable = *variable_info_lookup[input_num];
      arg.name = std::string(variable.name());
      arg.kind = XlaCompiler::Argument::kResource;
      arg.resource_kind = XlaResource::kVariable;
      if (variable.present) {
        const Tensor& value = variable.value;
        arg.type = value.dtype();
        arg.shape = value.shape();
      if (variable.var()) {
        const Tensor* value = variable.var()->tensor();
        arg.type = value->dtype();
        arg.shape = value->shape();
        arg.initialized = true;
      } else {
        // The values of uninitialized variables are not passed as inputs, since
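The new precondition is that variables are looked up and locked before BuildXlaCompilerArguments is called; a hypothetical call-site sketch (PrepareArguments is not in this commit) mirroring the pattern the launch ops above now follow, assuming the tensorflow namespace:

    #include "tensorflow/compiler/jit/xla_launch_util.h"

    // Look up and lock the resource variables first, then snapshot them and
    // build the compiler arguments while the locks are held.
    Status PrepareArguments(OpKernelContext* ctx,
                            absl::Span<const int> variable_indices,
                            const std::map<int, Tensor>& constant_args,
                            std::vector<XlaCompiler::Argument>* args,
                            ResourceVarsSnapshot* snapshot) {
      std::vector<VariableInfo> variable_infos;
      TF_RETURN_IF_ERROR(
          GetVariableInfosFromCtxInputs(ctx, variable_indices, &variable_infos));
      TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
      TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, variable_indices,
                                                   variable_infos, snapshot));
      return XlaComputationLaunchContext::BuildXlaCompilerArguments(
          constant_args, variable_infos, ctx, args);
    }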
@ -34,36 +34,17 @@ limitations under the License.

namespace tensorflow {

// Struct that represents a possibly-absent Tensor.
struct OptionalTensor {
  string name;           // A descriptive name
  bool present = false;  // Is the tensor present?
  Tensor value;          // If present, what is the Tensor's value?
};

// Takes a snapshot of the values of resource variable arguments, whose indices
// are specified in `variable_indices` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
// important that the shapes used for compilation match the true shapes of the
// buffers.
//
// We snapshot the entire set of resource variables as one atomic operation.
// This models Read->* dependencies between resource variable operations. See
// jit/resource_operation_safety_analysis for details.
//
// Returns a map of TensorFlow argument index to resource variable. If a
// resource variable is not initialized, the corresponding OptionalTensor
// will have its `present` field set to false.
Status SnapshotResourceVariables(OpKernelContext* ctx,
                                 absl::Span<const int> variable_indices,
                                 std::map<int, OptionalTensor>* result);
// Snapshot of resource variables for a TF kernel invocation, mapping from
// parameter number to values at execution time. If the resource variable is not
// initialized, the value will not be present.
using ResourceVarsSnapshot = absl::flat_hash_map<int, absl::optional<Tensor>>;

// Information about the state of a variable passed as input to the _XlaCompile
// and _XlaRun operators. Unlocks the resource variable and decrements its
// refcount on destruction.
class VariableInfo {
 public:
  explicit VariableInfo(int index, Var* var);
  explicit VariableInfo(int index, absl::string_view name, Var* var);
  VariableInfo(VariableInfo&& other);

  VariableInfo& operator=(VariableInfo&& other);
@ -79,6 +60,9 @@ class VariableInfo {
  // "empty", i.e. it does not track a resource variable.
  Var* var() const { return var_; }

  // Returns the variable name.
  absl::string_view name() const { return name_; }

  // Returns true if the resource variable lock was successfully acquired by
  // this thread.
  bool lock_held() const { return lock_held_; }
@ -88,6 +72,7 @@ class VariableInfo {

 private:
  int index_;
  std::string name_;
  Var* var_;

  // We can't use a optional<mutex_lock> here because it confuses the compiler's
@ -96,6 +81,20 @@ class VariableInfo {
  bool lock_held_ = false;
};

// Takes a snapshot of the values of resource variable arguments, whose indices
// are specified in `variable_indices` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
// important that the shapes used for compilation match the true shapes of the
// buffers.
//
// We snapshot the entire set of resource variables as one atomic operation.
// This models Read->* dependencies between resource variable operations. See
// jit/resource_operation_safety_analysis for details.
Status SnapshotResourceVariables(OpKernelContext* ctx,
                                 absl::Span<const int> variable_indices,
                                 absl::Span<VariableInfo const> variable_infos,
                                 ResourceVarsSnapshot* result);

// Acquires the mutexes for all the variables in `variables` using a
// deadlock-safe protocol (acquire the mutexes in increasing-address order).
//
@ -104,6 +103,13 @@ class VariableInfo {
Status LockVariables(absl::Span<VariableInfo> variables)
    TF_EXCLUSIVE_LOCK_FUNCTION();

// Returns a vector of VariableInfo instances for the resource variable inputs
// to the kernel with context `ctx`. The input indices for the resource
// variable inputs are in `variable_indices`.
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
                                     absl::Span<const int> variable_indices,
                                     std::vector<VariableInfo>* result);

// Helper class to perform the marshalling of TensorFlow inputs and outputs to
// ShapedBuffers suitable for passing to an XLA computation.
class XlaComputationLaunchContext {
@ -123,9 +129,10 @@ class XlaComputationLaunchContext {

  // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
  // op.
  // Precondition: variables in `variable_args` are locked.
  static Status BuildXlaCompilerArguments(
      const std::map<int, Tensor>& constant_args,
      const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
      absl::Span<VariableInfo const> variable_args, OpKernelContext* ctx,
      std::vector<XlaCompiler::Argument>* args);

  // Add all inputs within `ctx` as XLA arguments (returned by arguments()).
@ -136,8 +143,8 @@ class XlaComputationLaunchContext {
  // input_mapping must be greater than or equal to `missing_ctx_input_prefix`
  // (in other words, no inputs actually required by the kernel can be missing).
  void PopulateInputs(OpKernelContext* ctx,
                      const XlaCompiler::CompilationResult* kernel,
                      const std::map<int, OptionalTensor>& variables,
                      const XlaCompiler::CompilationResult* compilation_result,
                      const ResourceVarsSnapshot& variables,
                      int missing_ctx_input_prefix);

  // Given the XLA output in `output`, populate all outputs of `ctx`. Also
@ -148,13 +155,14 @@ class XlaComputationLaunchContext {
  // See jit/resource_operation_safety_analysis for details.
  //
  //
  // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
  // missing and adjusts input indices accordingly.
  // Assumes that the first `missing_ctx_input_prefix` inputs to the
  // compilation_result are missing and adjusts input indices accordingly.
  Status PopulateOutputs(
      OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
      OpKernelContext* ctx,
      const XlaCompiler::CompilationResult* compilation_result,
      xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
      const xla::HloInputOutputAliasConfig& input_output_alias,
      const std::map<int, OptionalTensor>& resource_var_snapshots);
      const ResourceVarsSnapshot& resource_var_snapshots);

  // Return the argument list. Only valid after PopulateInputs() has been
  // called.
@ -27,10 +27,6 @@ namespace tensorflow {
  return xla_tensor;
}

/*static*/ bool XlaTensor::RefCountIsOne(const Tensor& tensor) {
  return tensor.RefCountIsOne();
}

/*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor(
    const Tensor& tensor) {
  const XlaTensor* xla_tensor = FromTensor(&tensor);
@ -39,8 +39,6 @@ class XlaTensor {
  // fails.
  static XlaTensor* FromTensor(const Tensor* tensor);

  static bool RefCountIsOne(const Tensor& tensor);

  // Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in
  // which case the returned value is shaped_buffer()->root_buffer(), or a
  // normal Tensor in which case the returned value is
@ -57,7 +55,7 @@ class XlaTensor {
  // manage the memory for these tensors a ShapedBuffer may be required.

  // Return true if this XlaTensor contains a ShapedBuffer.
  bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; }
  bool has_shaped_buffer() const { return shaped_buffer_.has_value(); }
  // Return the contained ShapedBuffer.
  // REQUIRES: has_shaped_buffer()
  const xla::ShapedBuffer& shaped_buffer() const {
@ -70,8 +68,7 @@ class XlaTensor {
  }
  // Mutates the XlaTensor to set the ShapedBuffer.
  void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
    shaped_buffer_ =
        absl::make_unique<xla::ScopedShapedBuffer>(std::move(shaped_buffer));
    shaped_buffer_ = std::move(shaped_buffer);
  }

  // Some tensors on the device may have known values on the host. We use these
@ -79,14 +76,12 @@ class XlaTensor {
  // host value already.

  // Return true if this XlaTensor contains a host tensor.
  bool has_host_tensor() const { return host_tensor_ != nullptr; }
  bool has_host_tensor() const { return host_tensor_.has_value(); }
  // Return the contained host tensor.
  // REQUIRES: has_host_tensor()
  const Tensor& host_tensor() const { return *host_tensor_; }
  // Sets the contained host tensor.
  void set_host_tensor(const Tensor& tensor) {
    host_tensor_.reset(new Tensor(tensor));
  }
  void set_host_tensor(const Tensor& tensor) { host_tensor_.emplace(tensor); }

  // Adds synchronization events to 'stream' that wait for this tensor to be
  // defined on 'stream'. Does nothing if the tensor is already defined on that
@ -113,9 +108,9 @@ class XlaTensor {

 private:
  // The optional contained ShapedBuffer.
  std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_;
  absl::optional<xla::ScopedShapedBuffer> shaped_buffer_;
  // An optional host tensor value.
  std::unique_ptr<Tensor> host_tensor_;
  absl::optional<Tensor> host_tensor_;
  // An optional event that is triggered when the tensor's content has been
  // defined. If this event is nullptr, it is assumed that the tensor's content
  // is always defined.
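The switch from std::unique_ptr to absl::optional above removes a heap allocation per XlaTensor: the value now lives inline in the owning object. As a generic illustration of the pattern (this Holder type is not TF code):

    #include <string>
    #include <utility>

    #include "absl/types/optional.h"

    // Minimal sketch of the unique_ptr -> optional migration: same
    // "maybe-present" semantics, but set() constructs in place (no `new`).
    struct Holder {
      bool has_value() const { return value_.has_value(); }
      void set(std::string v) { value_.emplace(std::move(v)); }
      const std::string& get() const { return *value_; }  // REQUIRES: has_value()

     private:
      absl::optional<std::string> value_;
    };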
@ -30,7 +30,7 @@ cc_library(
    hdrs = ["op_or_arg_name_mapper.h"],
    deps = [
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
    ],
)
@ -42,7 +42,7 @@ cc_library(
        ":init_mlir",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform:logging",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MlirOptLib",
@ -86,7 +86,7 @@ cc_library(
    hdrs = ["init_mlir.h"],
    deps = [
        "//tensorflow/core:lib",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
    ],
)

@ -102,7 +102,7 @@ cc_library(
        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
        "//tensorflow/core:core_cpu",
        "@com_google_absl//absl/container:flat_hash_set",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Shape",
        "@llvm-project//mlir:StandardOps",
@ -155,7 +155,7 @@ tf_cc_binary(
        "//tensorflow/core:tensorflow",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:support",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",

265
tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md
Normal file
@ -0,0 +1,265 @@
# MLIR CodeGen for XLA

<!--*
# Document freshness: For more information, see go/fresh-source.
freshness: { owner: 'timshen' reviewed: '2020-06-16' }
*-->

XLA operates on `HloInstruction` and performs many optimizations on this
representation, sharing many of these between targeted devices. At some point a
linear schedule is computed and a memory buffer is assigned to each value
statically. The device-specific codegen operates by traversing this sequence
and calling "emitters" to generate a representation suitable for the device
(for example a single LLVM function per XLA computation on CPU, or a sequence
of "thunks" encapsulating GPU operations and possibly generated PTX when
targeting GPU).

As a staging step, we're currently in the process of intercepting the pipeline
right after XLA completes the buffer-assignment phase and emitting an MLIR
module in the `lhlo` dialect instead. From there we perform the codegen using
MLIR components (mainly Linalg, affine, and the GPU dialect) depending on the
device.

Below is the plan of record to incrementally migrate XLA/GPU by using `lhlo` as
the codegen input.
## Tasks

|               | Host                     | Device
| ------------- | ------------------------ | ------------------------
| Input format  | HloInstruction* (Task 1) | HloInstruction* (Task 1)
| Output format | xla::Thunk (Task 2)      | LLVM IR (Task 3)
* **Task 1** changes both host and device input format from HloInstruction* to
  LHLO.
* **Task 2** changes the host output format from thunks to "some landing pad
  for host" (see below).
* **Task 3** migrates device output from LLVM IR to some form of MLIR. It's
  optional to this project; see the section "Migrating Device LLVM IR" for
  details.

This project prioritizes having end-to-end runnable models with LHLO-emitters
enabled as much as possible. This implies the following list of objectives,
ordered by priority:

* Make XLA/GPU runnable with LHLO emitters, with existing Thunks and emitters
  unmodified.
* Eliminate the references to HloInstruction\* in LHLO, case by case:
    * Switch a legacy emitter to an MLIR-based emitter (e.g. Linalg), or
    * Mechanically translate the existing emitter to take MLIR representation
      (migrate to Standard with GPU Dialect).
## Migrating Thunks (Task 2)

xla::gpu::Thunk is a data structure that (see the sketch below):

* Can be called into from the host (xla::gpu::Thunk::ExecuteOnStream()).
* Carries various data in its subclasses.
* Interacts with BufferAllocation::Slice and StreamExecutor.
* Launches kernels.
* Calls into all runtime libraries.
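A rough C++ sketch of that shape; the `Status` and `ExecuteParams` types below
are illustrative stand-ins, not the actual XLA declarations:

```c++
#include <memory>
#include <vector>

struct Status { bool ok = true; };        // stand-in for xla::Status
struct ExecuteParams { /* stream, buffer allocations, ... */ };

// Host-callable unit of work; subclasses carry op-specific data.
class Thunk {
 public:
  virtual ~Thunk() = default;
  virtual Status ExecuteOnStream(const ExecuteParams& params) = 0;
};

class KernelThunk : public Thunk {
 public:
  Status ExecuteOnStream(const ExecuteParams& params) override {
    // Launch the compiled kernel on the stream carried by `params` (elided).
    return Status{};
  }
};

// The GPU executable is, roughly, a schedule of such thunks run in order.
Status RunAll(const std::vector<std::unique_ptr<Thunk>>& thunks,
              const ExecuteParams& params) {
  for (const auto& t : thunks) {
    Status s = t->ExecuteOnStream(params);
    if (!s.ok) return s;
  }
  return Status{};
}
```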
The cost of migrating it includes:

* Representing op-specific configuration data (e.g. convolution configs).
* Migrating op shape and operand shapes.
* Representing a tree of thunks (while, condition, etc).

The migration work is independent from the LHLO / emitter migration. Under
limited resources, it's prioritized behind the LHLO / emitter migration.
We have several choices on how to lower the host-side part from LHLO:

* TFRT
    * (Pro) great CUDA and HIP wrappers for use.
    * (Pro) easy to implement library calls (cuDNN, cuBLAS, cuFFT, etc), as
      TFRT ops are interpreted by C++ code.
    * (Con) the host side is under development and not tested.
    * (Con) the JAX integration isn't clear from a runtime point of view.
* Jitted CPU code
    * (Pro) great lowerability: create a few loops and conditions and it's
      done.
    * (Con) GPUDialect doesn't yet model chains/streams/asynchronicity/device
      allocation.
    * (Con) CUDA / HIP runtime support is minimal (toolkit path, version,
      dynamic loading, etc).
* Existing (interpreting) XLA runtime

Tentative conclusion: use jitted CPU code during the transition, and optionally
adopt TFRT in the end.
## Migrating Device LLVM IR (Task 3)

An elemental emitter generates the target op by filling it in element by
element. Each output element depends on a set of elements from the operands.
All elements are described by combining the buffer with dynamic indices. This
is sufficient to describe almost all "math" ops, but for performance reasons
only a large subset of "math" ops are implemented directly in
(Cpu|Gpu)ElementalIrEmitter.
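A toy illustration (assumption: not the XLA API) of elemental emission: every
output element is a pure function of a few operand elements at the same
dynamic index.

```c++
#include <cstddef>
#include <vector>

// Elementwise add as an "element generator": out[i] depends only on
// lhs[i] and rhs[i], so the loop structure is independent of the op.
std::vector<float> EmitElementwiseAdd(const std::vector<float>& lhs,
                                      const std::vector<float>& rhs) {
  std::vector<float> out(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) {
    out[i] = lhs[i] + rhs[i];  // per-element body; swap in any "math" op
  }
  return out;
}
```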
ElementalIrEmitter is unique in that:

* A large portion of the code is shared between XLA/GPU and CPU.
* It represents a large portion of ops seen in models, including all
  element-wise ops.
* Most fusions solely depend on ElementalIrEmitter.
* It's structurally simple, as it describes a data dependency DAG between op
  elements and operand elements.
* It's mostly portable and high-level (e.g. unlike GPU kReduce and GPU kCopy).
* Dynamic shape support is easy for at least element-wise ops.
Now, for all ops, elementally-emitted or not, there are several flavors of the
end state of each XLA op:

1. Device code stays as LLVM IR.
2. Refactor the old emitter to be like LHLO -> MLIR LLVM Dialect:
    * (Cost) Will be throw-away work if we want to ultimately migrate to
      Standard.
    * (Benefit) It is easy and mechanical. Can be done in a short period.
    * (Cost) It provides no additional benefit over option 1.
3. Refactor old emitters to be like LHLO -> MLIR GPU + Standard + Loops:
    * (Cost) Lifting existing emitters to Standard introduces some challenges.
      Pointers and GEPs need to be converted to MemRefs and SubViews. Ensuring
      amdgpu completeness is another one.
    * (Cost) XLA/GPU heavily relies on LLVM metadata:
        * `range` for block/thread indices.
        * `align`, `dereferenceable`, `invariant.load`, `alias.scope`,
          `noalias` for loads/stores.
        * `llvm.loop.unroll.disable`, `llvm.loop.unroll.full`,
          `llvm.loop.vectorize.enable` for sequential loops.
    * (Benefit) Can be long-term. More portable.
4. Refactor old emitters to be LHLO -> Linalg, and write new Linalg emitters:
    * (Cost) This is case by case. Compared to the previous options, a new
      implementation that matches XLA's performance needs to go through the
      benchmark <-> optimize workflow, which can be a significant cost for
      some ops.
    * (Benefit) Unified stack; community support; portability; more
      optimization potential.
## Prioritization

While all three tasks mentioned above are parallelizable, under limited
resources they have to be serialized. The prioritization focuses on visible
results for the completion of each task.

The prioritization is: Task 1 (LHLO for legacy emitters) > Task 2 (Thunks) >
Task 3 (MLIR emitters).

By the end of Task 1, users of XLA can generate an LHLO (e.g. kernel generator)
and execute it. The compilation format will not be serializable MLIR.

By the end of Task 2, LHLO lowers to proper, serializable MLIR. This enables
offline compilation.

By the end of Task 3, all XLA emitters are MLIR-based in their implementation.
## Detailed Design

### Step 1: (Task 1) Complete LHLO and Make Legacy Emitters Take LHLO

This step makes all existing XLA/GPU emitters interact with MLIR ops. This step
is pure refactoring and NFC.

This step is mostly mechanical, but it's worth noticing the following
discrepancies between an unnested HloComputation and LHLO (a toy model of the
difference is sketched after this list):

* Each HloInstruction has direct access to its operands (a data-flow DAG). By
  contrast, each LHLO op only has access to its operand buffers (a bipartite
  graph between ops and buffers). LHLO ops have to go through use-def chains
  to access their operand ops.
* Unnested legacy emitters empirically almost never access their operands. The
  only exception is kReduce.
* Unnested legacy emitters access BufferAssignment only for getting slices,
  not for accessing aux data structures like dataflow\_analysis() or
  alias\_analysis(). llvm\_ir builds its own alias\_analysis() based on slice
  information.
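A toy model (assumption: illustrative only, not MLIR or XLA types) of the two
graph shapes. In HLO, an op reaches its producer directly; in LHLO, an op holds
only buffers, so finding a producer means walking the buffer's users:

```c++
#include <vector>

struct LhloOp;

struct Buffer {
  std::vector<LhloOp*> users;  // every op that reads or writes this buffer
};

struct LhloOp {
  std::vector<Buffer*> operands;  // buffers only; no direct op->op edges
  Buffer* result = nullptr;       // the buffer this op writes
};

// The use-def walk an LHLO-based emitter must do where an HLO-based emitter
// would simply call instr->operand(i).
LhloOp* FindProducer(Buffer* buf) {
  for (LhloOp* op : buf->users) {
    if (op->result == buf) return op;  // the writer is the producer
  }
  return nullptr;  // e.g. a function-argument buffer with no defining op
}
```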
The conclusion is that LHLO should fit right in without major hassle.

### Step 2: (Optional) Profiling Support

**This step is only needed if we start to discard some of the XLA Thunk logic
(see the next step).**

Before actually turning on any MLIR-based emitters, we need profiling for
MLIR-based emitters.
Currently XLA performs its own profiling by calling into StreamExecutor's
timer. Under the hood, the timer inserts two events before and after a kernel
launch, and measures the sync time between these two events.
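For illustration, the same event-based timing pattern expressed directly
against the CUDA runtime API (the real code goes through StreamExecutor; the
`launch` callback is a placeholder for the kernel being measured):

```c++
#include <cuda_runtime.h>

// Measure the GPU time of a kernel enqueued by `launch`, in milliseconds.
float TimeKernelMs(cudaStream_t stream, void (*launch)(cudaStream_t)) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start, stream);  // event before the kernel launch
  launch(stream);                  // enqueue the kernel being measured
  cudaEventRecord(stop, stream);   // event after the kernel launch
  cudaEventSynchronize(stop);      // wait until the stream reaches `stop`
  float ms = 0.f;
  cudaEventElapsedTime(&ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms;
}
```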
There are roughly two approaches to support profiling in MLIR:

* Run a profiler end-to-end.
* Add a profile op for each op in LHLO, using an injected profiler.
The "end-to-end" approach is transparent to MLIR, but suffers the same problem
|
||||
that makes XLA not use it in the first place: library calls collected by a
|
||||
profiler (nvprof/...) can't easily relate to HLO ops. For example, cuDNN
|
||||
launches multiple kernels for each HLO, and it's hard to tell which kernels
|
||||
correspond to which HLO.
|
||||
|
||||
The "injected profiler" approach requires:
|
||||
|
||||
* LHLO to take a profiler as a parameter.
|
||||
* inserting profile.start / profile.end before and after each op.
|
||||
* a pass from that lowers profile.{start,end} to a C++ implementation.
|
||||
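A sketch of what lowered profile.{start,end} calls could look like on the C++
side: a scoped guard bracketing each op's execution. `Profiler` here is a
hypothetical injected interface, not an existing XLA or MLIR class:

```c++
// Injected profiler interface; the lowering pass would emit calls into it.
class Profiler {
 public:
  virtual ~Profiler() = default;
  virtual void Start(const char* op_name) = 0;  // lowered from profile.start
  virtual void End(const char* op_name) = 0;    // lowered from profile.end
};

// RAII guard: constructing it is profile.start, destroying it is profile.end.
class ScopedProfile {
 public:
  ScopedProfile(Profiler* p, const char* op) : p_(p), op_(op) {
    p_->Start(op_);
  }
  ~ScopedProfile() { p_->End(op_); }

 private:
  Profiler* p_;
  const char* op_;
};
```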
Exact profiling can't easily be done for MLIR-generated ops, since:

* MLIR doesn't have a timer, nor does it depend on TFRT / StreamExecutor.
* MLIR doesn't easily call into C functions with complicated parameters.
### Step 3: (Task 2) Migrating Thunks

This step migrates all host ops and library calls. This step will eliminate
most of the thunks and produce serializable MLIR instead.

There are roughly three kinds of thunks:

* KernelThunk, which launches a kernel.
* Control flow thunks, which have host control flow logic (conditional, while,
  for, sequence) and launch body kernels.
* Library thunks: cuDNN, cuBLAS, cuFFT, NCCL, etc.

The **bottom line** is to:

* Create a Thunk dialect that provides (de)serialize logic for all existing
  C++-based Thunks.
* Change emitters to emit a graph of the Thunk dialect.

**Optionally**, we can relieve some thunks of their C++ implementation.
KernelThunk can lower to the GPU LaunchKernelOp. Control flow thunks can
leverage the CFG Dialect for loops and conditions, combined with
LaunchKernelOp. This optional step requires profiling and stream support.
### Step 4: (Task 3) Migrated ElementalIrEmitter

Once profiling is ready, we can complete and tune all ElementalIrEmitter-based
emitters in MLIR. Then we turn them on by default, assuming that all of these
MLIR-based emitters use a single stream.

Notice that it's beneficial to migrate XLA/CPU's ElementalIrEmitter as well,
since the CPU and GPU versions share a large portion of the code.

With all benchmarking and performance hunting done (TODO: define performance
parity), we turn on the new MLIR-based elemental emitter, and delete the
legacy ElementalIrEmitter.

This step also provides easy fusion transitions (nested ops) for the later
migration.
### Step 5: Multi-Stream Support or Drop

We can't delete
[some of the emitters](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/gpu/stream_assignment.cc#L140)
until we support multi-stream in MLIR, or we drop the feature. It's a
relatively large amount of work in MLIR for a small amount of gain for XLA. We
should investigate the current users of multi-stream XLA/GPU, and try to
delete this feature if reasonable.
### Step 6: (Task 3) Migrated Device Ops

This step migrates all unnested ops; then we can delete all unnested emitters.

This calls for a rewrite/refactor of kCopy and kReduce. kReduce has already
received plenty of work, so the actual amount of work that remains to be done
is yet to be seen.