Merge remote-tracking branch 'upstream/master'

This commit is contained in:
ather 2020-12-28 22:29:10 -05:00
commit a22a29347c
27106 changed files with 4705276 additions and 891713 deletions

612
.bazelrc Normal file
View File

@ -0,0 +1,612 @@
# TensorFlow Bazel configuration file.
# This file tries to group and simplify build options for TensorFlow
#
# ----CONFIG OPTIONS----
# Android options:
# android:
# android_arm:
# android_arm64:
# android_x86:
# android_x86_64:
#
# iOS options:
# ios:
# ios_armv7:
# ios_arm64:
# ios_i386:
# ios_x86_64:
# ios_fat:
#
# Compiler options:
# cuda_clang: Use clang when building CUDA code.
# c++17: Build with C++17 options (links with libc++)
# c++1z: Build with C++17 options (links with libc++)
# c++17_gcc: Build with C++17 options (links with stdlibc++)
# c++1z_gcc: Build with C++17 options (links with stdlibc++)
# avx_linux: Build with avx instruction set on linux.
# avx2_linux: Build with avx2 instruction set on linux.
# native_arch_linux: Build with instruction sets available to the host machine on linux
# avx_win: Build with avx instruction set on windows
# avx2_win: Build with avx2 instruction set on windows
#
# Other build options:
# short_logs: Only log errors during build, skip warnings.
# verbose_logs: Show all compiler warnings during build.
# monolithic: Build all TF C++ code into a single shared object.
# dynamic_kernels: Try to link all kernels dynamically (experimental).
# libc++: Link against libc++ instead of stdlibc++
#
#
# TF version options;
# v1: Build TF V1 (without contrib)
# v2: Build TF v2
#
# Feature and Third party library support options:
# xla: Build TF with XLA
# tpu: Build TF with TPU support
# using_cuda: CUDA is available to build system.
# cuda: Build with full cuda support.
# rocm: Build with AMD GPU support (rocm).
# mkl: Enable full mkl support.
# tensorrt: Enable Tensorrt support.
# numa: Enable numa using hwloc.
# noaws: Disable AWS S3 storage support
# nogcp: Disable GCS support.
# nohdfs: Disable hadoop hdfs support.
# nonccl: Disable nccl support.
#
#
# Remote build execution options (only configured to work with TF team projects for now.)
# rbe: General RBE options shared by all flavors.
# rbe_linux: General RBE options used on all linux builds.
# rbe_win: General RBE options used on all windows builds.
#
# rbe_cpu_linux: RBE options to build with only CPU support.
# rbe_linux_cuda_nvcc_py*: RBE options to build with GPU support using nvcc.
#
# rbe_linux_py2: Linux Python 2 RBE config.
# rbe_linux_py3: Linux Python 3 RBE config
#
# rbe_win_py37: Windows Python 3.7 RBE config
# rbe_win_py38: Windows Python 3.8 RBE config
#
# tensorflow_testing_rbe_linux: RBE options to use RBE with tensorflow-testing project on linux
# tensorflow_testing_rbe_win: RBE options to use RBE with tensorflow-testing project on windows
#
# Embedded Linux options (experimental and only tested with TFLite build yet)
# elinux: General Embedded Linux options shared by all flavors.
# elinux_aarch64: Embedded Linux options for aarch64 (ARM64) CPU support.
# elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.
#
# Release build options (for all operating systems)
# release_common: Common options for all builds on all operating systems.
# release_windows_common: Common options for all builds on Windows.
# release_gpu_common: Common options for GPU builds on Linux and Windows.
# release_cpu_linux: Toolchain and CUDA options for Linux CPU builds.
# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds.
# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds.
# release_gpu_linux_cuda_10_1: Toolchain and CUDA options for CUDA 10.1 Linux GPU builds.
# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds.
# release_gpu_windows: Toolchain and CUDA options for Windows GPU builds.
# Allow builds using libc++ as a linker library
# This is mostly for OSSFuzz, so we also pass in the flags from environment to clean build file
build:libc++ --action_env=CC
build:libc++ --action_env=CXX
build:libc++ --action_env=CXXFLAGS=-stdlib=libc++
build:libc++ --action_env=PATH
build:libc++ --define force_libcpp=enabled
build:libc++ --linkopt -fuse-ld=lld
# Android configs. Bazel needs to have --cpu and --fat_apk_cpu both set to the
# target CPU to build transient dependencies correctly. See
# https://docs.bazel.build/versions/master/user-manual.html#flag--fat_apk_cpu
build:android --crosstool_top=//external:android/crosstool
build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
build:android_arm --config=android
build:android_arm --cpu=armeabi-v7a
build:android_arm --fat_apk_cpu=armeabi-v7a
build:android_arm64 --config=android
build:android_arm64 --cpu=arm64-v8a
build:android_arm64 --fat_apk_cpu=arm64-v8a
build:android_x86 --config=android
build:android_x86 --cpu=x86
build:android_x86 --fat_apk_cpu=x86
build:android_x86_64 --config=android
build:android_x86_64 --cpu=x86_64
build:android_x86_64 --fat_apk_cpu=x86_64
# Sets the default Apple platform to macOS.
build --apple_platform_type=macos
# iOS configs for each architecture and the fat binary builds.
build:ios --apple_platform_type=ios
build:ios --apple_bitcode=embedded --copt=-fembed-bitcode
build:ios --copt=-Wno-c++11-narrowing
build:ios_armv7 --config=ios
build:ios_armv7 --cpu=ios_armv7
build:ios_arm64 --config=ios
build:ios_arm64 --cpu=ios_arm64
build:ios_i386 --config=ios
build:ios_i386 --cpu=ios_i386
build:ios_x86_64 --config=ios
build:ios_x86_64 --cpu=ios_x86_64
build:ios_fat --config=ios
build:ios_fat --ios_multi_cpus=armv7,arm64,i386,x86_64
# Config to use a mostly-static build and disable modular op registration
# support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python).
# By default, TensorFlow will build with a dependence on
# //tensorflow:libtensorflow_framework.so.
build:monolithic --define framework_shared_object=false
# For projects which use TensorFlow as part of a Bazel build process, putting
# nothing in a bazelrc will default to a monolithic build. The following line
# opts in to modular op registration support by default.
build --define framework_shared_object=true
# For workaround https://github.com/bazelbuild/bazel/issues/8772 with Bazel >= 0.29.1
build --java_toolchain=//third_party/toolchains/java:tf_java_toolchain
build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
# Please note that MKL on MacOS or windows is still not supported.
# If you would like to use a local MKL instead of downloading, please set the
# environment variable "TF_MKL_ROOT" every time before build.
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_openmp=true
build:mkl -c opt
# config to build OneDNN backend with a user specified threadpool.
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_threadpool --define=build_with_mkl_opensource=true
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
build:mkl_threadpool -c opt
# Config setting to build with oneDNN and without the binary blob
build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_opensource_only --define=build_with_mkl_opensource=true
build:mkl_opensource_only --define=build_with_openmp=true
build:mkl_opensource_only -c opt
# Config setting to build with oneDNN for Arm.
build:mkl_aarch64 --define=build_with_mkl_aarch64=true --define=enable_mkl=true
build:mkl_aarch64 --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_aarch64 --define=build_with_mkl_opensource=true
build:mkl_aarch64 -c opt
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
build:using_cuda --action_env TF_NEED_CUDA=1
build:using_cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
# Enable the mlir generated GPU kernels only for cuda builds.
build --define=tensorflow_enable_mlir_generated_gpu_kernels=0
# This is a more specific option, so it takes precedence over the line above for cuda builds.
build:using_cuda --define=tensorflow_enable_mlir_generated_gpu_kernels=1
# This config refers to building CUDA op kernels with nvcc.
build:cuda --config=using_cuda
build:cuda --define=using_cuda_nvcc=true
# This config refers to building CUDA op kernels with clang.
build:cuda_clang --config=using_cuda
build:cuda_clang --define=using_cuda_clang=true
build:cuda_clang --define=using_clang=true
build:cuda_clang --action_env TF_CUDA_CLANG=1
# dbg config, as a shorthand for '--config=opt -c dbg'
build:dbg --config=opt -c dbg
# for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360
build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
# AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498
build:dbg --copt -DDEBUG_BUILD
# Config to build TPU backend
build:tpu --define=with_tpu_support=true
build:tensorrt --action_env TF_NEED_TENSORRT=1
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --action_env TF_NEED_ROCM=1
# Options extracted from configure script
build:numa --define=with_numa_support=true
# Options to disable default on features
build:noaws --define=no_aws_support=true
build:nogcp --define=no_gcp_support=true
build:nohdfs --define=no_hdfs_support=true
build:nonccl --define=no_nccl_support=true
build:stackdriver_support --define=stackdriver_support=true
build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true
build --spawn_strategy=standalone
build -c opt
# Make Bazel print out all options from rc files.
build --announce_rc
# Other build flags.
build --define=grpc_no_ares=true
# See https://github.com/bazelbuild/bazel/issues/7362 for information on what
# --incompatible_remove_legacy_whole_archive flag does.
# This flag is set to true in Bazel 1.0 and newer versions. We tried to migrate
# Tensorflow to the default, however test coverage wasn't enough to catch the
# errors.
# There is ongoing work on Bazel team's side to provide support for transitive
# shared libraries. As part of migrating to transitive shared libraries, we
# hope to provide a better mechanism for control over symbol exporting, and
# then tackle this issue again.
#
# TODO: Remove this line once TF doesn't depend on Bazel wrapping all library
# archives in -whole_archive -no_whole_archive.
build --noincompatible_remove_legacy_whole_archive
# These are bazel 2.0's incompatible flags. Tensorflow needs to use bazel 2.0.0
# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC:
# https://github.com/tensorflow/community/pull/179
build --noincompatible_prohibit_aapt1
# Modular TF build options
build:dynamic_kernels --define=dynamic_loaded_kernels=true
build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
# Build TF with C++ 17 features.
build:c++17 --cxxopt=-std=c++1z
build:c++17 --cxxopt=-stdlib=libc++
build:c++1z --config=c++17
build:c++17_gcc --cxxopt=-std=c++1z
build:c++1z_gcc --config=c++17_gcc
# Enable using platform specific build settings, except when cross-compiling for
# mobile platforms.
build --enable_platform_specific_config
build:android --noenable_platform_specific_config
build:ios --noenable_platform_specific_config
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
build:android --copt=-w
build:ios --copt=-w
build:linux --copt=-w
build:linux --host_copt=-w
build:macos --copt=-w
build:windows --copt=/W0
# Tensorflow uses M_* math constants that only get defined by MSVC headers if
# _USE_MATH_DEFINES is defined.
build:windows --copt=/D_USE_MATH_DEFINES
build:windows --host_copt=/D_USE_MATH_DEFINES
# Default paths for TF_SYSTEM_LIBS
build:linux --define=PREFIX=/usr
build:linux --define=LIBDIR=$(PREFIX)/lib
build:linux --define=INCLUDEDIR=$(PREFIX)/include
build:linux --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include
build:macos --define=PREFIX=/usr
build:macos --define=LIBDIR=$(PREFIX)/lib
build:macos --define=INCLUDEDIR=$(PREFIX)/include
build:macos --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include
# TF_SYSTEM_LIBS do not work on windows.
# By default, build TF in C++ 14 mode.
build:android --cxxopt=-std=c++14
build:android --host_cxxopt=-std=c++14
build:ios --cxxopt=-std=c++14
build:ios --host_cxxopt=-std=c++14
build:linux --cxxopt=-std=c++14
build:linux --host_cxxopt=-std=c++14
build:macos --cxxopt=-std=c++14
build:macos --host_cxxopt=-std=c++14
build:windows --cxxopt=/std:c++14
build:windows --host_cxxopt=/std:c++14
# On windows, we still link everything into a single DLL.
build:windows --config=monolithic
# On linux, we dynamically link small amount of kernels
build:linux --config=dynamic_kernels
# Make sure to include as little of windows.h as possible
build:windows --copt=-DWIN32_LEAN_AND_MEAN
build:windows --host_copt=-DWIN32_LEAN_AND_MEAN
build:windows --copt=-DNOGDI
build:windows --host_copt=-DNOGDI
# MSVC (Windows): Standards-conformant preprocessor mode
# See https://docs.microsoft.com/en-us/cpp/preprocessor/preprocessor-experimental-overview
build:windows --copt=/experimental:preprocessor
build:windows --host_copt=/experimental:preprocessor
# Misc build options we need for windows.
build:windows --linkopt=/DEBUG
build:windows --host_linkopt=/DEBUG
build:windows --linkopt=/OPT:REF
build:windows --host_linkopt=/OPT:REF
build:windows --linkopt=/OPT:ICF
build:windows --host_linkopt=/OPT:ICF
build:windows --experimental_strict_action_env=true
# Verbose failure logs when something goes wrong
build:windows --verbose_failures
# On windows, we never cross compile
build:windows --distinct_host_configuration=false
# Suppress all warning messages.
build:short_logs --output_filter=DONT_MATCH_ANYTHING
build:verbose_logs --output_filter=
build --config=short_logs
# Instruction set optimizations
# TODO(gunan): Create a feature in toolchains for avx/avx2 to
# avoid having to define linux/win separately.
build:avx_linux --copt=-mavx
build:avx_linux --host_copt=-mavx
build:avx2_linux --copt=-mavx2
build:native_arch_linux --copt=-march=native
build:avx_win --copt=/arch=AVX
build:avx2_win --copt=/arch=AVX2
# Options to build TensorFlow 1.x or 2.x.
build:v1 --define=tf_api_version=1
build:v2 --define=tf_api_version=2
build:v1 --action_env=TF2_BEHAVIOR=0
build:v2 --action_env=TF2_BEHAVIOR=1
build --config=v2
test --config=v2
# Enable XLA
build:xla --define=with_xla_support=true
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS
# Options when using remote execution
# WARNING: THESE OPTIONS WONT WORK IF YOU DO NOT HAVE PROPER AUTHENTICATION AND PERMISSIONS
# Flag to enable remote config
common --experimental_repo_remote_exec
build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
build:rbe --google_default_credentials
build:rbe --bes_backend=buildeventservice.googleapis.com
build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations"
build:rbe --bes_timeout=600s
build:rbe --define=EXECUTOR=remote
build:rbe --distinct_host_configuration=false
build:rbe --flaky_test_attempts=3
build:rbe --jobs=200
build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com
build:rbe --remote_timeout=3600
build:rbe --spawn_strategy=remote,worker,standalone,local
test:rbe --test_env=USER=anon
# Attempt to minimize the amount of data transfer between bazel and the remote
# workers:
build:rbe --remote_download_toplevel
build:rbe_linux --config=rbe
build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin"
build:rbe_linux --host_javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8
build:rbe_linux --javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8
build:rbe_linux --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
build:rbe_linux --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8
# Non-rbe settings we should include because we do not run configure
build:rbe_linux --config=xla
build:rbe_linux --config=avx_linux
build:rbe_linux --config=short_logs
# TODO(gunan): Check why we need this specified in rbe, but not in other builds.
build:rbe_linux --linkopt=-lrt
build:rbe_linux --host_linkopt=-lrt
build:rbe_linux --linkopt=-lm
build:rbe_linux --host_linkopt=-lm
build:rbe_cpu_linux --config=rbe_linux
build:rbe_cpu_linux --host_crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain"
build:rbe_cpu_linux --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8"
build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
build:rbe_cpu_linux --host_platform="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
build:rbe_cpu_linux --platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform"
build:rbe_linux_cuda_base --config=rbe_linux
build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1
build:rbe_linux_cuda_base --repo_env=TF_CUDA_VERSION=10
build:rbe_linux_cuda_base --repo_env=TF_CUDNN_VERSION=7
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1
build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda10.1_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda10.1_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda10.1_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.1_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.1_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda10.1_nvcc_py2.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda10.1_nvcc_py3.5 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda10.1_nvcc_py3.6 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda10.1_nvcc_py3.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda11.0_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda11.0_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda11.0_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
build:rbe_linux_cuda11.0_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
build:rbe_linux_cuda11.0_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform"
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda"
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_tensorrt"
build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_nccl"
build:rbe_linux_cuda11.0_nvcc_py2.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python2.7"
build:rbe_linux_cuda11.0_nvcc_py3.5 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.5"
build:rbe_linux_cuda11.0_nvcc_py3.6 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.6"
build:rbe_linux_cuda11.0_nvcc_py3.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.7"
build:rbe_linux_cuda11.0_nvcc_py3.8 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.8"
# Map default to CUDA 11 for PY35 and greater.
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda10.1_nvcc_py2.7
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda11.0_nvcc_py3.5
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda11.0_nvcc_py3.6
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda11.0_nvcc_py3.7
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda11.0_nvcc_py3.8
# Deprecated configs that people might still use.
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_nvcc_py36
build:rbe_gpu_linux --config=rbe_linux_cuda_nvcc
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_clang_base --extra_toolchains="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_clang_base --extra_execution_platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang_base --host_platform="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang_base --platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_clang_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_clang_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_clang_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_clang_base --define=using_cuda_clang=true
build:rbe_linux_cuda_clang_py27 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda_clang_py35 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda_clang_py36 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda_clang_py37 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_clang_py38 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
# ROCm
build:rbe_linux_rocm_base --config=rbe_linux
build:rbe_linux_rocm_base --repo_env=TF_NEED_ROCM=1
build:rbe_linux_rocm_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-rocm_config_rocm//crosstool:toolchain"
build:rbe_linux_rocm_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-rocm_config_rocm//crosstool:toolchain-linux-x86_64"
build:rbe_linux_rocm_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-rocm_config_platform//:platform"
build:rbe_linux_rocm_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-rocm_config_platform//:platform"
build:rbe_linux_rocm_base --platforms="@ubuntu18.04-gcc7_manylinux2010-rocm_config_platform//:platform"
build:rbe_linux_rocm_base --action_env=TF_ROCM_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_rocm"
build:rbe_linux_rocm_base --define=using_rocm_hipcc=true
build:rbe_linux_rocm_py2.7 --config=rbe_linux_rocm_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_python2.7"
build:rbe_linux_rocm_py3.5 --config=rbe_linux_rocm_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_python3.5"
build:rbe_linux_rocm_py3.6 --config=rbe_linux_rocm_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_python3.6"
build:rbe_linux_rocm_py3.7 --config=rbe_linux_rocm_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_python3.7"
build:rbe_linux_rocm_py3.8 --config=rbe_linux_rocm_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-rocm_config_python3.8"
# Linux CPU
build:rbe_linux_py2 --config=rbe_linux
build:rbe_linux_py2 --repo_env=PYTHON_BIN_PATH="/usr/bin/python2"
build:rbe_linux_py2 --python_path="/usr/bin/python2"
build:rbe_linux_py2 --repo_env=TF_PYTHON_CONFIG_REPO="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/py"
build:rbe_linux_py3 --config=rbe_linux
build:rbe_linux_py3 --python_path="/usr/bin/python3"
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-manylinux2010-py3_config_python"
build:rbe_win --config=rbe
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/tf_win_08062020:toolchain"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/tf_win_08062020:cc-toolchain-x64_windows"
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
build:rbe_win --host_platform="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
build:rbe_win --platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"
build:rbe_win --shell_executable=C:\\tools\\msys64\\usr\\bin\\bash.exe
# TODO(gunan): Remove once we use MSVC 2019 with latest patches.
build:rbe_win --define=override_eigen_strong_inline=true
build:rbe_win --jobs=100
build:rbe_win_py37 --config=rbe
build:rbe_win_py37 --repo_env=TF_PYTHON_CONFIG_REPO="@windows_py37_config_python"
build:rbe_win_py37 --python_path=C:\\Python37\\python.exe
build:rbe_win_py38 --config=rbe
build:rbe_win_py38 --repo_env=PYTHON_BIN_PATH=C:\\Python38\\python.exe
build:rbe_win_py38 --repo_env=PYTHON_LIB_PATH=C:\\Python38\\lib\\site-packages
build:rbe_win_py38 --repo_env=TF_PYTHON_CONFIG_REPO=@org_tensorflow//third_party/toolchains/preconfig/win_1803/py38
build:rbe_win_py38 --python_path=C:\\Python38\\python.exe
# These you may need to change for your own GCP project.
build:tensorflow_testing_rbe --project_id=tensorflow-testing
common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance
build:tensorflow_testing_rbe_linux --config=tensorflow_testing_rbe
build:tensorflow_testing_rbe_linux --config=rbe
build:tensorflow_testing_rbe_linux --config=rbe_linux
common:tensorflow_testing_rbe_win --remote_instance_name=projects/tensorflow-testing/instances/windows
build:tensorflow_testing_rbe_win --config=tensorflow_testing_rbe
# TFLite build configs for generic embedded Linux
build:elinux --crosstool_top=@local_config_embedded_arm//:toolchain
build:elinux --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
build:elinux_aarch64 --config=elinux
build:elinux_aarch64 --cpu=aarch64
build:elinux_armhf --config=elinux
build:elinux_armhf --cpu=armhf
# END TF REMOTE BUILD EXECUTION OPTIONS
# Default options should come above this line
# Options from ./configure
try-import %workspace%/.tf_configure.bazelrc
# Put user-specific options in .bazelrc.user
try-import %workspace%/.bazelrc.user
# Here are bazelrc configs for release builds
build:release_common --config=opt
build:release_common --config=v2
build:release_common --distinct_host_configuration=false
build:release_common --action_env TF_CONFIGURE_IOS="0"
build:release_cpu_linux --config=release_common
build:release_cpu_linux --config=avx_linux
# We use the same toolchain for CPU/GPU packages.
# Did not add this to the defaults in case this changes.
build:release_cpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain
build:release_cpu_macos --config=release_common
build:release_cpu_macos --config=avx_linux
build:release_gpu_common --config=release_common
build:release_gpu_common --config=cuda
build:release_gpu_common --config=tensorrt
build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.0"
build:release_gpu_common --action_env=TF_CUDA_VERSION="11"
build:release_gpu_common --action_env=TF_CUDNN_VERSION="8"
build:release_gpu_common --action_env=TF_NEED_TENSORRT="1"
build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80"
build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt"
build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib"
build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"
build:release_gpu_linux --config=release_gpu_common
build:release_gpu_linux --config=avx_linux
build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain
build:release_windows_common --config=release_common
build:release_windows_common --define=no_tensorflow_py_deps=true
build:release_windows_common --announce_rc
# First available in VS 16.4. Speeds Windows compile times by a lot. See
# https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
build:release_windows_common --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions
build:release_cpu_windows --config=release_windows_common
build:release_gpu_windows --config=release_windows_common
build:release_gpu_linux_cuda_10_1 --config=release_gpu_linux
build:release_gpu_linux_cuda_10_1 --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1"
build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDA_VERSION="10"
build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDNN_VERSION="7"

1
.bazelversion Normal file
View File

@ -0,0 +1 @@
3.7.2

42
.github/ISSUE_TEMPLATE/00-bug-issue.md vendored Normal file
View File

@ -0,0 +1,42 @@
---
name: Bug Issue
about: Use this template for reporting a bug
labels: 'type:bug'
---
<em>Please make sure that this is a bug. As per our
[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:bug_template</em>
**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with:
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
**Other info / logs** Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.

View File

@ -0,0 +1,30 @@
---
name: Build/Installation Issue
about: Use this template for build/installation issues
labels: 'type:build/install'
---
<em>Please make sure that this is a build/installation issue. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:build_template</em>
**System information**
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version:
- Python version:
- Installed using virtualenv? pip? conda?:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
**Describe the problem**
**Provide the exact sequence of commands / steps that you executed before running into the problem**
**Any other info / logs**
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

View File

@ -0,0 +1,60 @@
---
name: Documentation Issue
about: Use this template for documentation related issues
labels: 'type:docs'
---
Thank you for submitting a TensorFlow documentation issue. Per our GitHub
policy, we only address code/doc bugs, performance issues, feature requests, and
build/installation issues on GitHub.
The TensorFlow docs are open source! To get involved, read the documentation
contributor guide: https://www.tensorflow.org/community/contribute/docs
## URL(s) with the issue:
Please provide a link to the documentation entry, for example:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/MyMethod
## Description of issue (what needs changing):
### Clear description
For example, why should someone use this method? How is it useful?
### Correct links
Is the link to the source code correct?
### Parameters defined
Are all parameters defined and formatted correctly?
### Returns defined
Are return values defined?
### Raises listed and defined
Are the errors defined? For example,
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/feature_column/categorical_column_with_vocabulary_file#raises
### Usage example
Is there a usage example?
See the API guide: https://www.tensorflow.org/community/contribute/docs_ref
on how to write testable usage examples.
### Request visuals, if applicable
Are there currently visuals? If not, will it clarify the content?
### Submit a pull request?
Are you planning to also submit a pull request to fix the issue? See the docs
contributor guide: https://www.tensorflow.org/community/contribute/docs,
docs API guide: https://www.tensorflow.org/community/contribute/docs_ref and the
docs style guide: https://www.tensorflow.org/community/contribute/docs_style

View File

@ -0,0 +1,23 @@
---
name: Feature Request
about: Use this template for raising a feature request
labels: 'type:feature'
---
<em>Please make sure that this is a feature request. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md), we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. tag:feature_template</em>
**System information**
- TensorFlow version (you are using):
- Are you willing to contribute it (Yes/No):
**Describe the feature and the current behavior/state.**
**Will this change the current api? How?**
**Who will benefit with this feature?**
**Any Other info.**

View File

@ -0,0 +1,30 @@
---
name: TensorFlow Lite Op Request
about: Use this template for reporting Lite ops you are using or missing
labels: 'comp:lite'
---
**System information**
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- TensorFlow installed from (source or binary):
- TensorFlow version (or github SHA if from source):
**Provide the text output from tflite_convert**
```
# Copy and paste here
```
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
Also, please include a link to a GraphDef or the model if possible.
**Any other info / logs**
Include any logs or source code that would be helpful to diagnose the problem.
If including tracebacks, please include the full traceback. Large logs and files
should be attached.

View File

@ -0,0 +1,14 @@
---
name: Other Issues
about: Use this template for any other non-support related issues
labels: 'type:others'
---
This template is for miscellaneous issues not covered by the other issue categories.
For questions on how to work with TensorFlow, or support for problems that are not verified bugs in TensorFlow, please go to [StackOverflow](https://stackoverflow.com/questions/tagged/tensorflow).
If you are reporting a vulnerability, please use the [dedicated reporting process](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md).
For high-level discussions about TensorFlow, please post to discuss@tensorflow.org, for questions about the development or internal workings of TensorFlow, or if you would like to know how to contribute to TensorFlow, please post to developers@tensorflow.org.

View File

@ -0,0 +1,46 @@
---
name: TensorFlow Lite New Converter Issue
about: Use this template for reporting issues during model conversion to TFLite
labels: 'TFLiteConverter'
---
**System information**
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- TensorFlow installed from (source or binary):
- TensorFlow version (or github SHA if from source):
**Command used to run the converter or code if youre using the Python API**
If possible, please share a link to Colab/Jupyter/any notebook.
```
# Copy and paste here the exact command
```
**The output from the converter invocation**
```
# Copy and paste the output here.
```
**Also, please include a link to the saved model or GraphDef**
```
# Put link here or attach to the issue.
```
**Failure details**
If the conversion is successful, but the generated model is wrong,
state what is wrong:
- Producing wrong results and/or decrease in accuracy
- Producing correct results, but the model is slower than expected (model generated from old converter)
**RNN conversion support**
If converting TF RNN to TFLite fused RNN ops, please prefix [RNN] in the title.
**Any other info / logs**
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

View File

@ -0,0 +1,19 @@
---
name: TensorFlow Lite for Microcontrollers Issue
about: Use this template for reporting issues with TensorFlow Lite for microcontrollers
labels: 'comp:micro'
---
@tensorflow/micro
**System information**
- Host OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- TensorFlow installed from (source or binary):
- Tensorflow version (commit SHA if source):
- Target platform (e.g. Arm Mbed OS, Arduino Nano 33 etc.):
**Describe the problem**
**Please provide the exact sequence of commands/steps when you ran into the problem**

View File

@ -0,0 +1,42 @@
---
name: Performance Issue
about: Use this template for reporting a performance issue
labels: 'type:performance'
---
<em>Please make sure that this is an issue related to performance of TensorFlow.
As per our
[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub. tag:performance_template</em>
**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary):
- TensorFlow version (use command below):
- Python version:
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
You can collect some of this information using our environment capture
[script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh)
You can also obtain the TensorFlow version with:
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`
**Describe the current behavior**
**Describe the expected behavior**
**Standalone code to reproduce the issue**
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
**Other info / logs** Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.

110
.github/bot_config.yml vendored Normal file
View File

@ -0,0 +1,110 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# A list of assignees
assignees:
- amahendrakar
- ravikyram
- Saduf2019
# A list of assignees for compiler folder
compiler_assignees:
- joker-eph
# filesystem path
filesystem_path:
- tensorflow/c/experimental/filesystem
# security path
security_path:
- tensorflow/security
# words checklist
segfault_memory:
- segfault
- memory leaks
# assignees
filesystem_security_assignee:
- mihaimaruseac
tflite_micro_path:
- tensorflow/lite/micro
tflite_micro_comment: >
Thanks for contributing to TensorFlow Lite Micro.
To keep this process moving along, we'd like to make sure that you have completed the items on this list:
* Read the [contributing guidelines for TensorFlow Lite Micro](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/micro/CONTRIBUTING.md)
* Created a [TF Lite Micro Github issue](https://github.com/tensorflow/tensorflow/issues/new?labels=comp%3Amicro&template=70-tflite-micro-issue.md)
* Linked to the issue from the PR description
We would like to have a discussion on the Github issue first to determine the best path forward, and then proceed to the PR review.
# Cuda Comment
cuda_comment: >
From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries:
* For TF-GPU - See point 1
* For TF-CPU - See point 2
-----------------------------------------------------------------------------------------------
**1. Installing **TensorFlow-GPU** (TF) prebuilt binaries**
Make sure you are using compatible TF and CUDA versions.
Please refer following TF version and CUDA version compatibility table.
| TF | CUDA |
| :-------------: | :-------------: |
| 2.1.0 - 2.2.0 | 10.1 |
| 1.13.1 - 2.0 | 10.0 |
| 1.5.0 - 1.12.0 | 9.0 |
* If you have above configuration and using _**Windows**_ platform -
* Try adding the CUDA, CUPTI, and cuDNN installation directories to the %PATH% environment variable.
* Refer [windows setup guide](https://www.tensorflow.org/install/gpu#windows_setup).
* If you have above configuration and using _**Ubuntu/Linux**_ platform -
* Try adding the CUDA, CUPTI, and cuDNN installation directories to the $LD_LIBRARY_PATH environment variable.
* Refer [linux setup guide](https://www.tensorflow.org/install/gpu#linux_setup).
* If error still persists then, apparently your CPU model does not support AVX instruction sets.
* Refer [hardware requirements](https://www.tensorflow.org/install/pip#hardware-requirements).
-----------------------------------------------------------------------------------------------
**2. Installing **TensorFlow** (TF) CPU prebuilt binaries**
*TensorFlow release binaries version 1.6 and higher are prebuilt with AVX instruction sets.*
Therefore on any CPU that does not have these instruction sets, either CPU or GPU version of TF will fail to load.
Apparently, your CPU model does not support AVX instruction sets. You can still use TensorFlow with the alternatives given below:
* Try Google Colab to use TensorFlow.
* The easiest way to use TF will be to switch to [google colab](https://colab.sandbox.google.com/notebooks/welcome.ipynb#recent=true). You get pre-installed latest stable TF version. Also you can use ```pip install``` to install any other preferred TF version.
* It has an added advantage since you can you easily switch to different hardware accelerators (cpu, gpu, tpu) as per the task.
* All you need is a good internet connection and you are all set.
* Try to build TF from sources by changing CPU optimization flags.
*Please let us know if this helps.*
windows_comment: >
From the stack trace it looks like you are hitting windows path length limit.
* Try to disable path length limit on Windows 10.
* Refer [disable path length limit instructions guide.](https://mspoweruser.com/ntfs-260-character-windows-10/)
Please let us know if this helps.

39
.github/stale.yml vendored Normal file
View File

@ -0,0 +1,39 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#
# THIS IS A GENERATED DOCKERFILE.
#
# This file was assembled from multiple pieces, whose use is documented
# throughout. Please refer to the TensorFlow dockerfiles documentation
# for more information.
# Number of days of inactivity before an Issue or Pull Request becomes stale
daysUntilStale: 7
# Number of days of inactivity before a stale Issue or Pull Request is closed
daysUntilClose: 7
# Only issues or pull requests with all of these labels are checked if stale. Defaults to `[]` (disabled)
onlyLabels:
- stat:awaiting response
# Comment to post when marking as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you.
# Comment to post when removing the stale label. Set to `false` to disable
unmarkComment: false
closeComment: >
Closing as stale. Please reopen if you'd like to work on this further.
limitPerRun: 30
# Limit to only `issues` or `pulls`
only: issues

28
.github/workflows/update-nightly.yml vendored Normal file
View File

@ -0,0 +1,28 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
on:
workflow_dispatch: # Allow manual triggers
schedule:
- cron: 0 4 * * * # 4am UTC is 9pm PDT and 8pm PST
name: Set nightly branch to master HEAD
jobs:
master-to-nightly:
runs-on: ubuntu-latest
steps:
- uses: zofrex/mirror-branch@v1
name: Set nightly branch to master HEAD
with:
target-branch: 'nightly'

45
.gitignore vendored
View File

@ -1,16 +1,49 @@
.DS_Store
.ipynb_checkpoints
node_modules
/.bazelrc.user
/.tf_configure.bazelrc
/bazel-*
/third_party/py/numpy/numpy_include
/tools/bazel.rc
/bazel_pip
/tools/python_bin_path.sh
/tools/git/gen
/util/python/python_include
/util/python/python_lib
/tensorflow/tools/git/gen
/pip_test
/_python_build
*.pyc
__pycache__
*.swp
.vscode/
.vscode/
cmake_build/
tensorflow/contrib/cmake/_build/
.idea/**
/build/
[Bb]uild/
/tensorflow/core/util/version_info.cc
/tensorflow/python/framework/fast_tensor_util.cpp
/tensorflow/lite/gen/**
/tensorflow/lite/tools/make/downloads/**
/tensorflow/lite/tools/make/gen/**
/api_init_files_list.txt
/estimator_api_init_files_list.txt
*.whl
# Android
.gradle
.idea
*.iml
local.properties
gradleBuild
# iOS
*.pbxproj
*.xcworkspace
/*.podspec
/tensorflow/lite/**/coreml/**/BUILD
/tensorflow/lite/**/ios/BUILD
/tensorflow/lite/**/objc/BUILD
/tensorflow/lite/**/swift/BUILD
/tensorflow/lite/examples/ios/simple/data/*.tflite
/tensorflow/lite/examples/ios/simple/data/*.txt
Podfile.lock
Pods
xcuserdata

View File

@ -1,11 +0,0 @@
{
"maxReviewers": 2,
"numFilesToCheck": 10,
"userBlacklist": ["tensorflower-gardener"],
"requiredOrgs": ["tensorflow"],
"skipAlreadyAssignedPR": true,
"skipAlreadyMentionedPR": true,
"skipTitle": "Branch",
"delayed": true,
"delayedUntil": "10m"
}

1
.pylintrc Symbolic link
View File

@ -0,0 +1 @@
tensorflow/tools/ci_build/pylintrc

View File

@ -1,10 +0,0 @@
# TensorFlow Adopters
This page contains a list of people and organizations who are using TensorFlow. If you'd like to be included
here, please send a pull request which modifies this file.
We intend to use this list to contact you for surveys, and to find good candidates for invite-only events.
We will also point to this list if we are asked who uses TensorFlow.
We will not use any of the information here for promotions or to send other regular communications. You
should subscribe to discuss@tensorflow.org for such announcements.

View File

@ -7,4 +7,4 @@
# The email address is not required for organizations.
Google Inc.
Yuan Tang terrytangyuan@gmail.com
Yuan Tang <terrytangyuan@gmail.com>

32
BUILD
View File

@ -1,28 +1,8 @@
# Description:
# TensorFlow is an open source software library for numerical computation using
# data flow graphs.
package(
default_visibility = [
"//tensorflow:internal",
"//tensorflow_models:__subpackages__",
],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load(
"//tensorflow:tensorflow.bzl",
"cc_header_only_library",
)
cc_header_only_library(
name = "protobuf_headers",
includes = ["external/protobuf/src"],
visibility = ["//visibility:public"],
deps = [
"@protobuf//:protobuf",
exports_files(
[
"LICENSE",
"ACKNOWLEDGEMENTS",
"configure",
"configure.py",
],
)

17
CODEOWNERS Normal file
View File

@ -0,0 +1,17 @@
# Where component owners are known, add them here.
/tensorflow/c/eager @qqfish @kkimdev
/tensorflow/core/common_runtime/eager @qqfish @kkimdev
/tenosrflow/core/debug @caisq
/tensorflow/core/kernels/mkl/ @penpornk
/tensorflow/core/kernels/sparse/ @penpornk
/tensorflow/core/nccl/ @azaks2 @chsigg
/tensorflow/core/platform/windows/ @mihaimaruseac
/tensorflow/lite/experimental/micro @petewarden @advaitjain
/tensorflow/python/autograph/ @mdanatg @kkimdev
/tensorflow/python/debug @caisq
/tensorflow/python/eager @rohan100jain @kkimdev
/tensorflow/python/tools/api/generator/ @annarev
/tensorflow/tools/docs/ @markdaoust
/third_party/systemlibs/ @perfinion

85
CODE_OF_CONDUCT.md Normal file
View File

@ -0,0 +1,85 @@
# TensorFlow Code of Conduct
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and our
community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, gender identity and expression, level of
experience, nationality, personal appearance, race, religion, or sexual identity
and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment include:
* Using welcoming and inclusive language.
* Being respectful of differing viewpoints and experiences.
* Gracefully accepting constructive criticism.
* Focusing on what is best for the community.
* Showing empathy towards other community members.
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances.
* Trolling, insulting/derogatory comments, and personal or political attacks.
* Public or private harassment.
* Publishing others' private information, such as a physical or electronic
address, without explicit permission.
* Conduct which could reasonably be considered inappropriate for the forum in
which it occurs.
All TensorFlow forums and spaces are meant for professional interactions, and any behavior which could reasonably be considered inappropriate in a professional setting is unacceptable.
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.
## Scope
This Code of Conduct applies to all content on tensorflow.org, TensorFlows GitHub organization, or any other official TensorFlow web presence allowing for community interactions, as well as at all official TensorFlow events, whether offline or online.
The Code of Conduct also applies within project spaces and in public spaces whenever an individual is representing TensorFlow or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed or de facto representative at an online or offline event.
## Conflict Resolution
Conflicts in an open source project can take many forms, from someone having a bad day and using harsh and hurtful language in the issue queue, to more serious instances such as sexist/racist statements or threats of violence, and everything in between.
If the behavior is threatening or harassing, or for other reasons requires immediate escalation, please see below.
However, for the vast majority of issues, we aim to empower individuals to first resolve conflicts themselves, asking for help when needed, and only after that fails to escalate further. This approach gives people more control over the outcome of their dispute.
If you are experiencing or witnessing conflict, we ask you to use the following escalation strategy to address the conflict:
1. Address the perceived conflict directly with those involved, preferably in a
real-time medium.
2. If this fails, get a third party (e.g. a mutual friend, and/or someone with
background on the issue, but not involved in the conflict) to intercede.
3. If you are still unable to resolve the conflict, and you believe it rises to
harassment or another code of conduct violation, report it.
## Reporting Violations
Violations of the Code of Conduct can be reported to TensorFlows Project
Stewards, Edd Wilder-James (ewj@google.com) and Thea Lamkin
(thealamkin@google.com). The Project Steward will determine whether the Code of
Conduct was violated, and will issue an appropriate sanction, possibly including
a written warning or expulsion from the project, project sponsored spaces, or
project forums. We ask that you make a good-faith effort to resolve your
conflict via the conflict resolution policy before submitting a report.
Violations of the Code of Conduct can occur in any setting, even those unrelated to the project. We will only consider complaints about conduct that has occurred within one year of the report.
## Enforcement
If the Project Stewards receive a report alleging a violation of the Code of Conduct, the Project Stewards will notify the accused of the report, and provide them an opportunity to discuss the report before a sanction is issued. The Project Stewards will do their utmost to keep the reporter anonymous. If the act is ongoing (such as someone engaging in harassment), or involves a threat to anyone's safety (e.g. threats of violence), the Project Stewards may issue sanctions without notice.
## Attribution
This Code of Conduct is adapted from the Contributor Covenant, version 1.4, available at https://contributor-covenant.org/version/1/4, and includes some aspects of the Geek Feminism Code of Conduct and the Drupal Code of Conduct.

View File

@ -1,5 +1,16 @@
# Contributing guidelines
## Pull Request Checklist
Before sending your pull requests, make sure you followed this list.
- Read [contributing guidelines](CONTRIBUTING.md).
- Read [Code of Conduct](CODE_OF_CONDUCT.md).
- Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/).
- Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution).
- Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style).
- Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests).
## How to become a contributor and submit your own code
### Contributor License Agreements
@ -8,8 +19,8 @@ We'd love to accept your patches! Before we can take them, we have to jump a cou
Please fill out either the individual or corporate Contributor License Agreement (CLA).
* If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](http://code.google.com/legal/individual-cla-v1.0.html).
* If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](http://code.google.com/legal/corporate-cla-v1.0.html).
* If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](https://code.google.com/legal/individual-cla-v1.0.html).
* If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](https://code.google.com/legal/corporate-cla-v1.0.html).
Follow either of the two links above to access the appropriate CLA and instructions for how to sign and return it. Once we receive it, we'll be able to accept your pull requests.
@ -18,12 +29,181 @@ Follow either of the two links above to access the appropriate CLA and instructi
### Contributing code
If you have improvements to TensorFlow, send us your pull requests! For those
just getting started, Github has a [howto](https://help.github.com/articles/using-pull-requests/).
just getting started, Github has a
[how to](https://help.github.com/articles/using-pull-requests/).
If you want to contribute but you're not sure where to start, take a look at the
TensorFlow team members will be assigned to review your pull requests. Once the
pull requests are approved and pass continuous integration checks, a TensorFlow
team member will apply `ready to pull` label to your change. This means we are
working on getting your pull request submitted to our internal repository. After
the change has been submitted internally, your pull request will be merged
automatically on GitHub.
If you want to contribute, start working through the TensorFlow codebase,
navigate to the
[Github "issues" tab](https://github.com/tensorflow/tensorflow/issues) and start
looking through interesting issues. If you are not sure of where to start, then
start by trying one of the smaller/easier issues here i.e.
[issues with the "good first issue" label](https://github.com/tensorflow/tensorflow/labels/good%20first%20issue)
and then take a look at the
[issues with the "contributions welcome" label](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome).
These are issues that we believe are particularly well suited for outside
contributions, often because we probably won't get to them right now. If you
decide to start on an issue, leave a comment so that other people know that
you're working on it. If you want to help out, but not alone, use the issue
comment thread to coordinate.
### Contribution guidelines and standards
Before sending your pull request for
[review](https://github.com/tensorflow/tensorflow/pulls),
make sure your changes are consistent with the guidelines and follow the
TensorFlow coding style.
#### General guidelines and philosophy for contribution
* Include unit tests when you contribute new features, as they help to a)
prove that your code works correctly, and b) guard against future breaking
changes to lower the maintenance cost.
* Bug fixes also generally require unit tests, because the presence of bugs
usually indicates insufficient test coverage.
* Keep API compatibility in mind when you change code in core TensorFlow,
e.g., code in
[tensorflow/core](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core)
and
[tensorflow/python](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python).
TensorFlow has passed version 1.0 and hence cannot make
non-backward-compatible API changes without a major release. Reviewers of
your pull request will comment on any API compatibility issues.
* When you contribute a new feature to TensorFlow, the maintenance burden is
(by default) transferred to the TensorFlow team. This means that the benefit
of the contribution must be compared against the cost of maintaining the
feature.
* Full new features (e.g., a new op implementing a cutting-edge algorithm)
typically will live in
[tensorflow/addons](https://github.com/tensorflow/addons) to get some
airtime before a decision is made regarding whether they are to be migrated
to the core.
* As every PR requires several CPU/GPU hours of CI testing, we discourage
submitting PRs to fix one typo, one warning,etc. We recommend fixing the
same issue at the file level at least (e.g.: fix all typos in a file, fix
all compiler warning in a file, etc.)
* Tests should follow the
[testing best practices](https://www.tensorflow.org/community/contribute/tests)
guide.
#### License
Include a license at the top of new files.
* [C/C++ license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/op.cc#L1)
* [Python license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/nn.py#L1)
* [Java license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/Graph.java#L1)
* [Go license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/go/operation.go#L1)
* [Bash license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/ci_build/ci_sanity.sh#L2)
* [HTML license example](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/components/tf_backend/tf-backend.html#L2)
* [JavaScript/TypeScript license example](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/components/tf_backend/backend.ts#L1)
Bazel BUILD files also need to include a license section, e.g.,
[BUILD example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/BUILD#L61).
#### C++ coding style
Changes to TensorFlow C++ code should conform to
[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).
Use `clang-tidy` to check your C/C++ changes. To install `clang-tidy` on ubuntu:16.04, do:
```bash
apt-get install -y clang-tidy
```
You can check a C/C++ file by doing:
```bash
clang-format <my_cc_file> --style=google > /tmp/my_cc_file.cc
diff <my_cc_file> /tmp/my_cc_file.cc
```
#### Python coding style
Changes to TensorFlow Python code should conform to
[Google Python Style Guide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md)
Use `pylint` to check your Python changes. To install `pylint` and check a file
with `pylint` against TensorFlow's custom style definition:
```bash
pip install pylint
pylint --rcfile=tensorflow/tools/ci_build/pylintrc myfile.py
```
Note `pylint --rcfile=tensorflow/tools/ci_build/pylintrc` should run from the
top level tensorflow directory.
#### Coding style for other languages
* [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html)
* [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html)
* [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml)
* [Google Objective-C Style Guide](https://google.github.io/styleguide/objcguide.html)
#### Running sanity check
If you have Docker installed on your system, you can perform a sanity check on
your changes by running the command:
```bash
tensorflow/tools/ci_build/ci_build.sh CPU tensorflow/tools/ci_build/ci_sanity.sh
```
This will catch most license, Python coding style and BUILD file issues that
may exist in your changes.
#### Running unit tests
There are two ways to run TensorFlow unit tests.
1. Using tools and libraries installed directly on your system.
Refer to the
[CPU-only developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile)
and
[GPU developer Dockerfile](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile)
for the required packages. Alternatively, use the said
[Docker images](https://hub.docker.com/r/tensorflow/tensorflow/tags/), e.g.,
`tensorflow/tensorflow:devel` and `tensorflow/tensorflow:devel-gpu` for
development to avoid installing the packages directly on your system (in
which case remember to change directory from `/root` to `/tensorflow` once
you get into the running container so `bazel` can find the `tensorflow`
workspace).
Once you have the packages installed, you can run a specific unit test in
bazel by doing as follows:
If the tests are to be run on GPU, add CUDA paths to LD_LIBRARY_PATH and add
the `cuda` option flag
```bash
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
export flags="--config=opt --config=cuda -k"
```
For example, to run all tests under tensorflow/python, do:
```bash
bazel test ${flags} //tensorflow/python/...
```
2. Using [Docker](https://www.docker.com) and TensorFlow's CI scripts.
```bash
# Install Docker first, then this will build and run cpu tests
tensorflow/tools/ci_build/ci_build.sh CPU bazel test //tensorflow/...
```
See
[TensorFlow Builds](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/ci_build)
for details.

11
ISSUES.md Normal file
View File

@ -0,0 +1,11 @@
If you open a GitHub Issue, here is our policy:
1. It must be a bug/performance issue or a feature request or a build issue or
a documentation issue (for small doc fixes please send a PR instead).
1. Make sure the Issue Template is filled out.
1. The issue should be related to the repo it is created in.
**Here's why we have this policy:** We want to focus on the work that benefits
the whole community, e.g., fixing bugs and adding features. Individual support
should be sought on Stack Overflow or other non-GitHub channels. It helps us to
address bugs and feature requests in a timely manner.

View File

@ -1,36 +1,47 @@
NOTE: Only file GitHub issues for bugs and feature requests. All other topics will be closed.
Please go to Stack Overflow for help and support:
For general support from the community, see [StackOverflow](https://stackoverflow.com/questions/tagged/tensorflow).
To make bugs and feature requests more easy to find and organize, we close issues that are deemed
out of scope for GitHub Issues and point people to StackOverflow.
https://stackoverflow.com/questions/tagged/tensorflow
For bugs or installation issues, please provide the following information.
The more information you provide, the more easily we will be able to offer
help and advice.
If you open a GitHub issue, here is our policy:
### What related GitHub issues or StackOverflow threads have you found by searching the web for your problem?
1. It must be a bug, a feature request, or a significant problem with the
documentation (for small docs fixes please send a PR instead).
2. The form below must be filled out.
3. It shouldn't be a TensorBoard issue. Those go
[here](https://github.com/tensorflow/tensorboard/issues).
### Environment info
Operating System:
**Here's why we have that policy**: TensorFlow developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow.
Installed version of CUDA and cuDNN:
(please attach the output of `ls -l /path/to/cuda/lib/libcud*`):
------------------------
If installed from binary pip package, provide:
### System information
1. A link to the pip package you installed:
2. The output from `python -c "import tensorflow; print(tensorflow.__version__)"`.
- **Have I written custom code (as opposed to using a stock example script
provided in TensorFlow)**:
- **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**:
- **Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue
happens on a mobile device**:
- **TensorFlow installed from (source or binary)**:
- **TensorFlow version (use command below)**:
- **Python version**:
- **Bazel version (if compiling from source)**:
- **GCC/Compiler version (if compiling from source)**:
- **CUDA/cuDNN version**:
- **GPU model and memory**:
- **Exact command to reproduce**:
If installed from source, provide
You can collect some of this information using our environment capture script:
1. The commit hash (`git rev-parse HEAD`)
2. The output of `bazel version`
https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh
### If possible, provide a minimal reproducible example (We usually don't have time to read hundreds of lines of your code)
You can obtain the TensorFlow version with:
```bash
python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"
```
### What other attempted solutions have you tried?
### Describe the problem
Describe the problem clearly here. Be sure to convey here why it's a bug in TensorFlow or a feature request.
### Logs or other output that would be helpful
(If logs are large, please upload as attachment or provide link).
### Source code / logs
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. Try to provide a reproducible test case that is the bare minimum necessary to generate the problem.

View File

@ -1,4 +1,4 @@
Copyright 2017 The TensorFlow Authors. All rights reserved.
Copyright 2019 The TensorFlow Authors. All rights reserved.
Apache License
Version 2.0, January 2004
@ -188,7 +188,7 @@ Copyright 2017 The TensorFlow Authors. All rights reserved.
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2017, The TensorFlow Authors.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

200
README.md
View File

@ -1,67 +1,177 @@
<div align="center">
<img src="https://www.tensorflow.org/images/tf_logo_transp.png"><br><br>
<img src="https://www.tensorflow.org/images/tf_logo_social.png">
</div>
-----------------
| **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** |
|-----------------|---------------------|------------------|-------------------|---------------|
| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) |
[![Python](https://img.shields.io/pypi/pyversions/tensorflow.svg?style=plastic)](https://badge.fury.io/py/tensorflow)
[![PyPI](https://badge.fury.io/py/tensorflow.svg)](https://badge.fury.io/py/tensorflow)
**TensorFlow** is an open source software library for numerical computation using
data flow graphs. Nodes in the graph represent mathematical operations, while
the graph edges represent the multidimensional data arrays (tensors) that flow
between them. This flexible architecture lets you deploy computation to one
or more CPUs or GPUs in a desktop, server, or mobile device without rewriting
code. TensorFlow also includes TensorBoard, a data visualization toolkit.
**`Documentation`** |
------------------- |
[![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) |
TensorFlow was originally developed by researchers and engineers
working on the Google Brain team within Google's Machine Intelligence research
organization for the purposes of conducting machine learning and deep neural
networks research. The system is general enough to be applicable in a wide
variety of other domains, as well.
[TensorFlow](https://www.tensorflow.org/) is an end-to-end open source platform
for machine learning. It has a comprehensive, flexible ecosystem of
[tools](https://www.tensorflow.org/resources/tools),
[libraries](https://www.tensorflow.org/resources/libraries-extensions), and
[community](https://www.tensorflow.org/community) resources that lets
researchers push the state-of-the-art in ML and developers easily build and
deploy ML-powered applications.
**If you'd like to contribute to TensorFlow, be sure to review the [contribution
guidelines](CONTRIBUTING.md).**
TensorFlow was originally developed by researchers and engineers working on the
Google Brain team within Google's Machine Intelligence Research organization to
conduct machine learning and deep neural networks research. The system is
general enough to be applicable in a wide variety of other domains, as well.
**We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for
tracking requests and bugs, but please see
[Community](tensorflow/g3doc/resources/index.md#community) for general questions
and discussion.**
TensorFlow provides stable [Python](https://www.tensorflow.org/api_docs/python)
and [C++](https://www.tensorflow.org/api_docs/cc) APIs, as well as
non-guaranteed backward compatible API for
[other languages](https://www.tensorflow.org/api_docs).
## Installation
*See [Download and Setup](tensorflow/g3doc/get_started/os_setup.md) for instructions on how to install our release binaries or how to build from source.*
Keep up-to-date with release announcements and security updates by subscribing
to
[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
See all the [mailing lists](https://www.tensorflow.org/community/forums).
People who are a little more adventurous can also try our nightly binaries:
## Install
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.0.0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.0.0-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
See the [TensorFlow install guide](https://www.tensorflow.org/install) for the
[pip package](https://www.tensorflow.org/install/pip), to
[enable GPU support](https://www.tensorflow.org/install/gpu), use a
[Docker container](https://www.tensorflow.org/install/docker), and
[build from source](https://www.tensorflow.org/install/source).
To install the current release, which includes support for
[CUDA-enabled GPU cards](https://www.tensorflow.org/install/gpu) *(Ubuntu and
Windows)*:
```
$ pip install tensorflow
```
A smaller CPU-only package is also available:
```
$ pip install tensorflow-cpu
```
To update TensorFlow to the latest version, add `--upgrade` flag to the above
commands.
*Nightly binaries are available for testing using the
[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
[tf-nightly-cpu](https://pypi.python.org/pypi/tf-nightly-cpu) packages on PyPi.*
#### *Try your first TensorFlow program*
```shell
$ python
```
```python
>>> import tensorflow as tf
>>> tf.add(1, 2).numpy()
3
>>> hello = tf.constant('Hello, TensorFlow!')
>>> sess = tf.Session()
>>> sess.run(hello)
Hello, TensorFlow!
>>> a = tf.constant(10)
>>> b = tf.constant(32)
>>> sess.run(a+b)
42
>>>
>>> hello.numpy()
b'Hello, TensorFlow!'
```
##For more information
For more examples, see the
[TensorFlow tutorials](https://www.tensorflow.org/tutorials/).
* [TensorFlow website](http://tensorflow.org)
* [TensorFlow whitepaper](http://download.tensorflow.org/paper/whitepaper2015.pdf)
* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730)
## Contribution guidelines
The TensorFlow community has created amazing things with TensorFlow, please see the [resources section of tensorflow.org](https://www.tensorflow.org/versions/master/resources#community) for an incomplete list.
**If you want to contribute to TensorFlow, be sure to review the
[contribution guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's
[code of conduct](CODE_OF_CONDUCT.md). By participating, you are expected to
uphold this code.**
**We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for
tracking requests and bugs, please see
[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss)
for general questions and discussion, and please direct specific questions to
[Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).**
The TensorFlow project strives to abide by generally accepted best practices in
open-source software development:
[![Fuzzing Status](https://oss-fuzz-build-logs.storage.googleapis.com/badges/tensorflow.svg)](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:tensorflow)
[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486)
[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v1.4%20adopted-ff69b4.svg)](CODE_OF_CONDUCT.md)
## Continuous build status
### Official Builds
Build Type | Status | Artifacts
----------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-py3.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Linux XLA** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-xla.html) | TBA
**macOS** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [PyPI](https://pypi.org/project/tf-nightly/)
**Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [PyPI](https://pypi.org/project/tf-nightly-gpu/)
**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
**Libtensorflow MacOS CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
**Libtensorflow Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
**Libtensorflow Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
**Libtensorflow Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
**Libtensorflow Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/)
### Community Supported Builds
Build Type | Status | Artifacts
----------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux AMD ROCm GPU** Nightly | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/)
**Linux AMD ROCm GPU** Stable Release | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/)
**Linux s390x** Nightly | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
**Linux s390x CPU** Stable Release | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/badge/icon)](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/)
**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/)
**Linux aarch64 CPU** Nightly (Linaro) | [![Build Status](https://ci.linaro.org/jenkins/buildStatus/icon?job=ldcg-python-tensorflow-nightly)](https://ci.linaro.org/jenkins/job/ldcg-python-tensorflow-nightly/) | [Nightly](http://snapshots.linaro.org/ldcg/python/tensorflow-nightly/latest/)
**Linux aarch64 CPU** Stable Release (Linaro) | [![Build Status](https://ci.linaro.org/jenkins/buildStatus/icon?job=ldcg-python-tensorflow)](https://ci.linaro.org/jenkins/job/ldcg-python-tensorflow/) | Release [1.x & 2.x](http://snapshots.linaro.org/ldcg/python/tensorflow/latest/)
**Linux aarch64 CPU** Nightly (OpenLab)<br> Python 3.6 | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow)](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master)
**Linux aarch64 CPU** Stable Release (OpenLab) | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) | Release [1.15](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) / [2.x](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)
**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Stable Release | ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon) | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/)
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/)
### Community Supported Containers
Container Type | Status | Artifacts
----------------------------------------------------------------- | ------ | ---------
**TensorFlow aarch64 Neoverse-N1 CPU** Stable (Linaro)<br> Debian | Static | Release [2.3](https://hub.docker.com/r/linaro/tensorflow-arm-neoverse-n1)
## Resources
* [TensorFlow.org](https://www.tensorflow.org)
* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
* [TensorFlow Official Models](https://github.com/tensorflow/models/tree/master/official)
* [TensorFlow Examples](https://github.com/tensorflow/examples)
* [DeepLearning.AI TensorFlow Developer Professional Certificate](https://www.coursera.org/specializations/tensorflow-in-practice)
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
* [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2)
* [Intro to TensorFlow for A.I, M.L, and D.L from Coursera](https://www.coursera.org/learn/introduction-tensorflow)
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp)
* [TensorFlow Codelabs](https://codelabs.developers.google.com/?cat=TensorFlow)
* [TensorFlow Blog](https://blog.tensorflow.org)
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
* [TensorFlow Twitter](https://twitter.com/tensorflow)
* [TensorFlow YouTube](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
* [TensorFlow Roadmap](https://www.tensorflow.org/model_optimization/guide/roadmap)
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
* [TensorBoard Visualization Toolkit](https://github.com/tensorflow/tensorboard)
Learn more about the
[TensorFlow community](https://www.tensorflow.org/community) and how to
[contribute](https://www.tensorflow.org/community/contribute).
## License
[Apache License 2.0](LICENSE)

3908
RELEASE.md

File diff suppressed because it is too large Load Diff

248
SECURITY.md Normal file
View File

@ -0,0 +1,248 @@
# Using TensorFlow Securely
This document discusses how to safely deal with untrusted programs (models or
model parameters), and input data. Below, we also provide guidelines on how to
report vulnerabilities in TensorFlow.
## TensorFlow models are programs
TensorFlow's runtime system interprets and executes programs. What machine
learning practitioners term
[**models**](https://developers.google.com/machine-learning/glossary/#model) are
expressed as programs that TensorFlow executes. TensorFlow programs are encoded
as computation
[**graphs**](https://developers.google.com/machine-learning/glossary/#graph).
The model's parameters are often stored separately in **checkpoints**.
At runtime, TensorFlow executes the computation graph using the parameters
provided. Note that the behavior of the computation graph may change
depending on the parameters provided. TensorFlow itself is not a sandbox. When
executing the computation graph, TensorFlow may read and write files, send and
receive data over the network, and even spawn additional processes. All these
tasks are performed with the permissions of the TensorFlow process. Allowing
for this flexibility makes for a powerful machine learning platform,
but it has implications for security.
The computation graph may also accept **inputs**. Those inputs are the
data you supply to TensorFlow to train a model, or to use a model to run
inference on the data.
**TensorFlow models are programs, and need to be treated as such from a security
perspective.**
## Running untrusted models
As a general rule: **Always** execute untrusted models inside a sandbox (e.g.,
[nsjail](https://github.com/google/nsjail)).
There are several ways in which a model could become untrusted. Obviously, if an
untrusted party supplies TensorFlow kernels, arbitrary code may be executed.
The same is true if the untrusted party provides Python code, such as the
Python code that generates TensorFlow graphs.
Even if the untrusted party only supplies the serialized computation
graph (in form of a `GraphDef`, `SavedModel`, or equivalent on-disk format), the
set of computation primitives available to TensorFlow is powerful enough that
you should assume that the TensorFlow process effectively executes arbitrary
code. One common solution is to allow only a few safe Ops. While this is
possible in theory, we still recommend you sandbox the execution.
It depends on the computation graph whether a user provided checkpoint is safe.
It is easily possible to create computation graphs in which malicious
checkpoints can trigger unsafe behavior. For example, consider a graph that
contains a `tf.cond` depending on the value of a `tf.Variable`. One branch of
the `tf.cond` is harmless, but the other is unsafe. Since the `tf.Variable` is
stored in the checkpoint, whoever provides the checkpoint now has the ability to
trigger unsafe behavior, even though the graph is not under their control.
In other words, graphs can contain vulnerabilities of their own. To allow users
to provide checkpoints to a model you run on their behalf (e.g., in order to
compare model quality for a fixed model architecture), you must carefully audit
your model, and we recommend you run the TensorFlow process in a sandbox.
## Accepting untrusted Inputs
It is possible to write models that are secure in a sense that they can safely
process untrusted inputs assuming there are no bugs. There are two main reasons
to not rely on this: First, it is easy to write models which must not be exposed
to untrusted inputs, and second, there are bugs in any software system of
sufficient complexity. Letting users control inputs could allow them to trigger
bugs either in TensorFlow or in dependent libraries.
In general, it is good practice to isolate parts of any system which is exposed
to untrusted (e.g., user-provided) inputs in a sandbox.
A useful analogy to how any TensorFlow graph is executed is any interpreted
programming language, such as Python. While it is possible to write secure
Python code which can be exposed to user supplied inputs (by, e.g., carefully
quoting and sanitizing input strings, size-checking input blobs, etc.), it is
very easy to write Python programs which are insecure. Even secure Python code
could be rendered insecure by a bug in the Python interpreter, or in a bug in a
Python library used (e.g.,
[this one](https://www.cvedetails.com/cve/CVE-2017-12852/)).
## Running a TensorFlow server
TensorFlow is a platform for distributed computing, and as such there is a
TensorFlow server (`tf.train.Server`). **The TensorFlow server is meant for
internal communication only. It is not built for use in an untrusted network.**
For performance reasons, the default TensorFlow server does not include any
authorization protocol and sends messages unencrypted. It accepts connections
from anywhere, and executes the graphs it is sent without performing any checks.
Therefore, if you run a `tf.train.Server` in your network, anybody with
access to the network can execute what you should consider arbitrary code with
the privileges of the process running the `tf.train.Server`.
When running distributed TensorFlow, you must isolate the network in which the
cluster lives. Cloud providers provide instructions for setting up isolated
networks, which are sometimes branded as "virtual private cloud." Refer to the
instructions for
[GCP](https://cloud.google.com/compute/docs/networks-and-firewalls) and
[AWS](https://aws.amazon.com/vpc/)) for details.
Note that `tf.train.Server` is different from the server created by
`tensorflow/serving` (the default binary for which is called `ModelServer`).
By default, `ModelServer` also has no built-in mechanism for authentication.
Connecting it to an untrusted network allows anyone on this network to run the
graphs known to the `ModelServer`. This means that an attacker may run
graphs using untrusted inputs as described above, but they would not be able to
execute arbitrary graphs. It is possible to safely expose a `ModelServer`
directly to an untrusted network, **but only if the graphs it is configured to
use have been carefully audited to be safe**.
Similar to best practices for other servers, we recommend running any
`ModelServer` with appropriate privileges (i.e., using a separate user with
reduced permissions). In the spirit of defense in depth, we recommend
authenticating requests to any TensorFlow server connected to an untrusted
network, as well as sandboxing the server to minimize the adverse effects of
any breach.
## Vulnerabilities in TensorFlow
TensorFlow is a large and complex system. It also depends on a large set of
third party libraries (e.g., `numpy`, `libjpeg-turbo`, PNG parsers, `protobuf`).
It is possible that TensorFlow or its dependent libraries contain
vulnerabilities that would allow triggering unexpected or dangerous behavior
with specially crafted inputs.
### What is a vulnerability?
Given TensorFlow's flexibility, it is possible to specify computation graphs
which exhibit unexpected or unwanted behavior. The fact that TensorFlow models
can perform arbitrary computations means that they may read and write files,
communicate via the network, produce deadlocks and infinite loops, or run out
of memory. It is only when these behaviors are outside the specifications of the
operations involved that such behavior is a vulnerability.
A `FileWriter` writing a file is not unexpected behavior and therefore is not a
vulnerability in TensorFlow. A `MatMul` allowing arbitrary binary code execution
**is** a vulnerability.
This is more subtle from a system perspective. For example, it is easy to cause
a TensorFlow process to try to allocate more memory than available by specifying
a computation graph containing an ill-considered `tf.tile` operation. TensorFlow
should exit cleanly in this case (it would raise an exception in Python, or
return an error `Status` in C++). However, if the surrounding system is not
expecting the possibility, such behavior could be used in a denial of service
attack (or worse). Because TensorFlow behaves correctly, this is not a
vulnerability in TensorFlow (although it would be a vulnerability of this
hypothetical system).
As a general rule, it is incorrect behavior for TensorFlow to access memory it
does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to
such behaviors constitute a vulnerability.
One of the most critical parts of any system is input handling. If malicious
input can trigger side effects or incorrect behavior, this is a bug, and likely
a vulnerability.
### Reporting vulnerabilities
Please email reports about any security related issues you find to
`security@tensorflow.org`. This mail is delivered to a small security team. Your
email will be acknowledged within one business day, and you'll receive a more
detailed response to your email within 7 days indicating the next steps in
handling your report. For critical problems, you may encrypt your report (see
below).
Please use a descriptive subject line for your report email. After the initial
reply to your report, the security team will endeavor to keep you informed of
the progress being made towards a fix and announcement.
In addition, please include the following information along with your report:
* Your name and affiliation (if any).
* A description of the technical details of the vulnerabilities. It is very
important to let us know how we can reproduce your findings.
* An explanation who can exploit this vulnerability, and what they gain when
doing so -- write an attack scenario. This will help us evaluate your report
quickly, especially if the issue is complex.
* Whether this vulnerability public or known to third parties. If it is, please
provide details.
If you believe that an existing (public) issue is security-related, please send
an email to `security@tensorflow.org`. The email should include the issue ID and
a short description of why it should be handled according to this security
policy.
Once an issue is reported, TensorFlow uses the following disclosure process:
* When a report is received, we confirm the issue and determine its severity.
* If we know of specific third-party services or software based on TensorFlow
that require mitigation before publication, those projects will be notified.
* An advisory is prepared (but not published) which details the problem and
steps for mitigation.
* Wherever possible, fixes are prepared for the last minor release of the two
latest major releases, as well as the master branch. We will attempt to
commit these fixes as soon as possible, and as close together as
possible.
* Patch releases are published for all fixed released versions, a
notification is sent to discuss@tensorflow.org, and the advisory is published.
Past security advisories are listed below. We credit reporters for identifying
security issues, although we keep your name confidential if you request it.
#### Encryption key for `security@tensorflow.org`
If your disclosure is extremely sensitive, you may choose to encrypt your
report using the key below. Please only use this for critical security
reports.
```
-----BEGIN PGP PUBLIC KEY BLOCK-----
mQENBFpqdzwBCADTeAHLNEe9Vm77AxhmGP+CdjlY84O6DouOCDSq00zFYdIU/7aI
LjYwhEmDEvLnRCYeFGdIHVtW9YrVktqYE9HXVQC7nULU6U6cvkQbwHCdrjaDaylP
aJUXkNrrxibhx9YYdy465CfusAaZ0aM+T9DpcZg98SmsSml/HAiiY4mbg/yNVdPs
SEp/Ui4zdIBNNs6at2gGZrd4qWhdM0MqGJlehqdeUKRICE/mdedXwsWLM8AfEA0e
OeTVhZ+EtYCypiF4fVl/NsqJ/zhBJpCx/1FBI1Uf/lu2TE4eOS1FgmIqb2j4T+jY
e+4C8kGB405PAC0n50YpOrOs6k7fiQDjYmbNABEBAAG0LVRlbnNvckZsb3cgU2Vj
dXJpdHkgPHNlY3VyaXR5QHRlbnNvcmZsb3cub3JnPokBTgQTAQgAOBYhBEkvXzHm
gOJBnwP4Wxnef3wVoM2yBQJaanc8AhsDBQsJCAcCBhUKCQgLAgQWAgMBAh4BAheA
AAoJEBnef3wVoM2yNlkIAICqetv33MD9W6mPAXH3eon+KJoeHQHYOuwWfYkUF6CC
o+X2dlPqBSqMG3bFuTrrcwjr9w1V8HkNuzzOJvCm1CJVKaxMzPuXhBq5+DeT67+a
T/wK1L2R1bF0gs7Pp40W3np8iAFEh8sgqtxXvLGJLGDZ1Lnfdprg3HciqaVAiTum
HBFwszszZZ1wAnKJs5KVteFN7GSSng3qBcj0E0ql2nPGEqCVh+6RG/TU5C8gEsEf
3DX768M4okmFDKTzLNBm+l08kkBFt+P43rNK8dyC4PXk7yJa93SmS/dlK6DZ16Yw
2FS1StiZSVqygTW59rM5XNwdhKVXy2mf/RtNSr84gSi5AQ0EWmp3PAEIALInfBLR
N6fAUGPFj+K3za3PeD0fWDijlC9f4Ety/icwWPkOBdYVBn0atzI21thPRbfuUxfe
zr76xNNrtRRlbDSAChA1J5T86EflowcQor8dNC6fS+oHFCGeUjfEAm16P6mGTo0p
osdG2XnnTHOOEFbEUeWOwR/zT0QRaGGknoy2pc4doWcJptqJIdTl1K8xyBieik/b
nSoClqQdZJa4XA3H9G+F4NmoZGEguC5GGb2P9NHYAJ3MLHBHywZip8g9oojIwda+
OCLL4UPEZ89cl0EyhXM0nIAmGn3Chdjfu3ebF0SeuToGN8E1goUs3qSE77ZdzIsR
BzZSDFrgmZH+uP0AEQEAAYkBNgQYAQgAIBYhBEkvXzHmgOJBnwP4Wxnef3wVoM2y
BQJaanc8AhsMAAoJEBnef3wVoM2yX4wIALcYZbQhSEzCsTl56UHofze6C3QuFQIH
J4MIKrkTfwiHlCujv7GASGU2Vtis5YEyOoMidUVLlwnebE388MmaJYRm0fhYq6lP
A3vnOCcczy1tbo846bRdv012zdUA+wY+mOITdOoUjAhYulUR0kiA2UdLSfYzbWwy
7Obq96Jb/cPRxk8jKUu2rqC/KDrkFDtAtjdIHh6nbbQhFuaRuWntISZgpIJxd8Bt
Gwi0imUVd9m9wZGuTbDGi6YTNk0GPpX5OMF5hjtM/objzTihSw9UN+65Y/oSQM81
v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
=CDME
-----END PGP PUBLIC KEY BLOCK-----
```
### Known Vulnerabilities
For a list of known vulnerabilities and security advisories for TensorFlow,
[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/README.md).

529
WORKSPACE
View File

@ -1,523 +1,20 @@
workspace(name = "org_tensorflow")
http_archive(
name = "io_bazel_rules_closure",
sha256 = "60fc6977908f999b23ca65698c2bb70213403824a84f7904310b6000d78be9ce",
strip_prefix = "rules_closure-5ca1dab6df9ad02050f7ba4e816407f88690cf7d",
urls = [
"http://bazel-mirror.storage.googleapis.com/github.com/bazelbuild/rules_closure/archive/5ca1dab6df9ad02050f7ba4e816407f88690cf7d.tar.gz", # 2017-02-03
"https://github.com/bazelbuild/rules_closure/archive/5ca1dab6df9ad02050f7ba4e816407f88690cf7d.tar.gz",
],
)
load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")
closure_repositories()
load("//tensorflow:workspace.bzl", "check_version", "tf_workspace")
# We must check the bazel version before trying to parse any other BUILD files,
# in case the parsing of those build files depends on the bazel version we
# require here.
check_version("0.4.2")
# Uncomment and update the paths in these entries to build the Android demo.
#android_sdk_repository(
# name = "androidsdk",
# api_level = 23,
# build_tools_version = "23.0.1",
# # Replace with path to Android SDK on your system
# path = "<PATH_TO_SDK>",
#)
# Initialize the TensorFlow repository and all dependencies.
#
#android_ndk_repository(
# name="androidndk",
# path="<PATH_TO_NDK>",
# api_level=21)
# The cascade of load() statements and workspace() calls works around the
# restriction that load() statements need to be at the top of .bzl files.
# E.g. we can not retrieve a new repository with http_archive and then load()
# a macro from that repository in the same file.
load("@//tensorflow:workspace3.bzl", "workspace")
workspace()
load("@//tensorflow:workspace2.bzl", "workspace")
workspace()
load("@//tensorflow:workspace1.bzl", "workspace")
workspace()
load("@//tensorflow:workspace0.bzl", "workspace")
workspace()
# Please add all new TensorFlow dependencies in workspace.bzl.
tf_workspace()
new_http_archive(
name = "inception5h",
build_file = "models.BUILD",
url = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip",
sha256 = "d13569f6a98159de37e92e9c8ec4dae8f674fbf475f69fe6199b514f756d4364"
)
new_http_archive(
name = "mobile_multibox",
build_file = "models.BUILD",
url = "https://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip",
sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96"
)
new_http_archive(
name = "stylize",
build_file = "models.BUILD",
url = "https://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip",
sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa"
)
# TENSORBOARD_BOWER_AUTOGENERATED_BELOW_THIS_LINE_DO_NOT_EDIT
new_http_archive(
name = "d3",
build_file = "bower.BUILD",
url = "https://github.com/mbostock-bower/d3-bower/archive/v3.5.15.tar.gz",
strip_prefix = "d3-bower-3.5.15",
)
new_http_archive(
name = "dagre",
build_file = "bower.BUILD",
url = "https://github.com/cpettitt/dagre/archive/v0.7.4.tar.gz",
strip_prefix = "dagre-0.7.4",
)
new_http_archive(
name = "es6_promise",
build_file = "bower.BUILD",
url = "https://github.com/components/es6-promise/archive/v2.1.0.tar.gz",
strip_prefix = "es6-promise-2.1.0",
)
new_http_archive(
name = "font_roboto",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/font-roboto/archive/v1.0.1.tar.gz",
strip_prefix = "font-roboto-1.0.1",
)
new_http_archive(
name = "graphlib",
build_file = "bower.BUILD",
url = "https://github.com/cpettitt/graphlib/archive/v1.0.7.tar.gz",
strip_prefix = "graphlib-1.0.7",
)
new_http_archive(
name = "iron_a11y_announcer",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-a11y-announcer/archive/v1.0.5.tar.gz",
strip_prefix = "iron-a11y-announcer-1.0.5",
)
new_http_archive(
name = "iron_a11y_keys_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-a11y-keys-behavior/archive/v1.1.8.tar.gz",
strip_prefix = "iron-a11y-keys-behavior-1.1.8",
)
new_http_archive(
name = "iron_ajax",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-ajax/archive/v1.2.0.tar.gz",
strip_prefix = "iron-ajax-1.2.0",
)
new_http_archive(
name = "iron_autogrow_textarea",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-autogrow-textarea/archive/v1.0.12.tar.gz",
strip_prefix = "iron-autogrow-textarea-1.0.12",
)
new_http_archive(
name = "iron_behaviors",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-behaviors/archive/v1.0.17.tar.gz",
strip_prefix = "iron-behaviors-1.0.17",
)
new_http_archive(
name = "iron_checked_element_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-checked-element-behavior/archive/v1.0.4.tar.gz",
strip_prefix = "iron-checked-element-behavior-1.0.4",
)
new_http_archive(
name = "iron_collapse",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-collapse/archive/v1.0.8.tar.gz",
strip_prefix = "iron-collapse-1.0.8",
)
new_http_archive(
name = "iron_dropdown",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-dropdown/archive/v1.4.0.tar.gz",
strip_prefix = "iron-dropdown-1.4.0",
)
new_http_archive(
name = "iron_fit_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-fit-behavior/archive/v1.2.5.tar.gz",
strip_prefix = "iron-fit-behavior-1.2.5",
)
new_http_archive(
name = "iron_flex_layout",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-flex-layout/archive/v1.3.0.tar.gz",
strip_prefix = "iron-flex-layout-1.3.0",
)
new_http_archive(
name = "iron_form_element_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-form-element-behavior/archive/v1.0.6.tar.gz",
strip_prefix = "iron-form-element-behavior-1.0.6",
)
new_http_archive(
name = "iron_icon",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-icon/archive/v1.0.11.tar.gz",
strip_prefix = "iron-icon-1.0.11",
)
new_http_archive(
name = "iron_icons",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-icons/archive/v1.1.3.tar.gz",
strip_prefix = "iron-icons-1.1.3",
)
new_http_archive(
name = "iron_iconset_svg",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-iconset-svg/archive/v1.1.0.tar.gz",
strip_prefix = "iron-iconset-svg-1.1.0",
)
new_http_archive(
name = "iron_input",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-input/archive/1.0.10.tar.gz",
strip_prefix = "iron-input-1.0.10",
)
new_http_archive(
name = "iron_list",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-list/archive/v1.3.9.tar.gz",
strip_prefix = "iron-list-1.3.9",
)
new_http_archive(
name = "iron_menu_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-menu-behavior/archive/v1.1.10.tar.gz",
strip_prefix = "iron-menu-behavior-1.1.10",
)
new_http_archive(
name = "iron_meta",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-meta/archive/v1.1.1.tar.gz",
strip_prefix = "iron-meta-1.1.1",
)
new_http_archive(
name = "iron_overlay_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-overlay-behavior/archive/v1.10.1.tar.gz",
strip_prefix = "iron-overlay-behavior-1.10.1",
)
new_http_archive(
name = "iron_range_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-range-behavior/archive/v1.0.4.tar.gz",
strip_prefix = "iron-range-behavior-1.0.4",
)
new_http_archive(
name = "iron_resizable_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-resizable-behavior/archive/v1.0.3.tar.gz",
strip_prefix = "iron-resizable-behavior-1.0.3",
)
new_http_archive(
name = "iron_scroll_target_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-scroll-target-behavior/archive/v1.0.3.tar.gz",
strip_prefix = "iron-scroll-target-behavior-1.0.3",
)
new_http_archive(
name = "iron_selector",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-selector/archive/v1.5.2.tar.gz",
strip_prefix = "iron-selector-1.5.2",
)
new_http_archive(
name = "iron_validatable_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/iron-validatable-behavior/archive/v1.1.1.tar.gz",
strip_prefix = "iron-validatable-behavior-1.1.1",
)
new_http_archive(
name = "lodash",
build_file = "bower.BUILD",
url = "https://github.com/lodash/lodash/archive/3.8.0.tar.gz",
strip_prefix = "lodash-3.8.0",
)
new_http_archive(
name = "neon_animation",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/neon-animation/archive/v1.2.2.tar.gz",
strip_prefix = "neon-animation-1.2.2",
)
http_file(
name = "numericjs_numeric_min_js",
url = "https://cdnjs.cloudflare.com/ajax/libs/numeric/1.2.6/numeric.min.js",
)
new_http_archive(
name = "paper_behaviors",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-behaviors/archive/v1.0.12.tar.gz",
strip_prefix = "paper-behaviors-1.0.12",
)
new_http_archive(
name = "paper_button",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-button/archive/v1.0.11.tar.gz",
strip_prefix = "paper-button-1.0.11",
)
new_http_archive(
name = "paper_checkbox",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-checkbox/archive/v1.4.0.tar.gz",
strip_prefix = "paper-checkbox-1.4.0",
)
new_http_archive(
name = "paper_dialog",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-dialog/archive/v1.0.4.tar.gz",
strip_prefix = "paper-dialog-1.0.4",
)
new_http_archive(
name = "paper_dialog_behavior",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-dialog-behavior/archive/v1.2.5.tar.gz",
strip_prefix = "paper-dialog-behavior-1.2.5",
)
new_http_archive(
name = "paper_dialog_scrollable",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-dialog-scrollable/archive/1.1.5.tar.gz",
strip_prefix = "paper-dialog-scrollable-1.1.5",
)
new_http_archive(
name = "paper_dropdown_menu",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-dropdown-menu/archive/v1.4.0.tar.gz",
strip_prefix = "paper-dropdown-menu-1.4.0",
)
new_http_archive(
name = "paper_header_panel",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-header-panel/archive/v1.1.4.tar.gz",
strip_prefix = "paper-header-panel-1.1.4",
)
new_http_archive(
name = "paper_icon_button",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-icon-button/archive/v1.1.3.tar.gz",
strip_prefix = "paper-icon-button-1.1.3",
)
new_http_archive(
name = "paper_input",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-input/archive/v1.1.18.tar.gz",
strip_prefix = "paper-input-1.1.18",
)
new_http_archive(
name = "paper_item",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-item/archive/v1.1.4.tar.gz",
strip_prefix = "paper-item-1.1.4",
)
new_http_archive(
name = "paper_listbox",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-listbox/archive/v1.1.2.tar.gz",
strip_prefix = "paper-listbox-1.1.2",
)
new_http_archive(
name = "paper_material",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-material/archive/v1.0.6.tar.gz",
strip_prefix = "paper-material-1.0.6",
)
new_http_archive(
name = "paper_menu",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-menu/archive/v1.2.2.tar.gz",
strip_prefix = "paper-menu-1.2.2",
)
new_http_archive(
name = "paper_menu_button",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-menu-button/archive/v1.5.1.tar.gz",
strip_prefix = "paper-menu-button-1.5.1",
)
new_http_archive(
name = "paper_progress",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-progress/archive/v1.0.9.tar.gz",
strip_prefix = "paper-progress-1.0.9",
)
new_http_archive(
name = "paper_radio_button",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-radio-button/archive/v1.1.2.tar.gz",
strip_prefix = "paper-radio-button-1.1.2",
)
new_http_archive(
name = "paper_radio_group",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-radio-group/archive/v1.0.9.tar.gz",
strip_prefix = "paper-radio-group-1.0.9",
)
new_http_archive(
name = "paper_ripple",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-ripple/archive/v1.0.5.tar.gz",
strip_prefix = "paper-ripple-1.0.5",
)
new_http_archive(
name = "paper_slider",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-slider/archive/v1.0.10.tar.gz",
strip_prefix = "paper-slider-1.0.10",
)
new_http_archive(
name = "paper_spinner",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-spinner/archive/v1.1.1.tar.gz",
strip_prefix = "paper-spinner-1.1.1",
)
new_http_archive(
name = "paper_styles",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-styles/archive/v1.1.4.tar.gz",
strip_prefix = "paper-styles-1.1.4",
)
new_http_archive(
name = "paper_tabs",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-tabs/archive/v1.7.0.tar.gz",
strip_prefix = "paper-tabs-1.7.0",
)
new_http_archive(
name = "paper_toast",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-toast/archive/v1.3.0.tar.gz",
strip_prefix = "paper-toast-1.3.0",
)
new_http_archive(
name = "paper_toggle_button",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-toggle-button/archive/v1.2.0.tar.gz",
strip_prefix = "paper-toggle-button-1.2.0",
)
new_http_archive(
name = "paper_toolbar",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-toolbar/archive/v1.1.4.tar.gz",
strip_prefix = "paper-toolbar-1.1.4",
)
new_http_archive(
name = "paper_tooltip",
build_file = "bower.BUILD",
url = "https://github.com/polymerelements/paper-tooltip/archive/v1.1.2.tar.gz",
strip_prefix = "paper-tooltip-1.1.2",
)
new_http_archive(
name = "plottable",
build_file = "bower.BUILD",
url = "https://github.com/palantir/plottable/archive/v1.16.1.tar.gz",
strip_prefix = "plottable-1.16.1",
)
new_http_archive(
name = "polymer",
build_file = "bower.BUILD",
url = "https://github.com/polymer/polymer/archive/v1.7.0.tar.gz",
strip_prefix = "polymer-1.7.0",
)
new_http_archive(
name = "promise_polyfill",
build_file = "bower.BUILD",
url = "https://github.com/polymerlabs/promise-polyfill/archive/v1.0.0.tar.gz",
strip_prefix = "promise-polyfill-1.0.0",
)
http_file(
name = "three_js_three_min_js",
url = "https://raw.githubusercontent.com/mrdoob/three.js/r77/build/three.min.js",
)
http_file(
name = "three_js_orbitcontrols_js",
url = "https://raw.githubusercontent.com/mrdoob/three.js/r77/examples/js/controls/OrbitControls.js",
)
new_http_archive(
name = "web_animations_js",
build_file = "bower.BUILD",
url = "https://github.com/web-animations/web-animations-js/archive/2.2.1.tar.gz",
strip_prefix = "web-animations-js-2.2.1",
)
new_http_archive(
name = "webcomponentsjs",
build_file = "bower.BUILD",
url = "https://github.com/webcomponents/webcomponentsjs/archive/v0.7.22.tar.gz",
strip_prefix = "webcomponentsjs-0.7.22",
)
http_file(
name = "weblas_weblas_js",
url = "https://raw.githubusercontent.com/waylonflinn/weblas/v0.9.0/dist/weblas.js",
)

75
arm_compiler.BUILD Normal file
View File

@ -0,0 +1,75 @@
package(default_visibility = ["//visibility:public"])
filegroup(
name = "gcc",
srcs = glob(["bin/*-gcc"]),
)
filegroup(
name = "ar",
srcs = glob(["bin/*-ar"]),
)
filegroup(
name = "ld",
srcs = glob(["bin/*-ld"]),
)
filegroup(
name = "nm",
srcs = glob(["bin/*-nm"]),
)
filegroup(
name = "objcopy",
srcs = glob(["bin/*-objcopy"]),
)
filegroup(
name = "objdump",
srcs = glob(["bin/*-objdump"]),
)
filegroup(
name = "strip",
srcs = glob(["bin/*-strip"]),
)
filegroup(
name = "as",
srcs = glob(["bin/*-as"]),
)
filegroup(
name = "compiler_pieces",
srcs = glob([
"arm-linux-gnueabihf/**",
"libexec/**",
"lib/gcc/arm-linux-gnueabihf/**",
"include/**",
]),
)
filegroup(
name = "aarch64_compiler_pieces",
srcs = glob([
"aarch64-none-linux-gnu/**",
"libexec/**",
"lib/gcc/aarch64-none-linux-gnu/**",
"include/**",
]),
)
filegroup(
name = "compiler_components",
srcs = [
":ar",
":as",
":gcc",
":ld",
":nm",
":objcopy",
":objdump",
":strip",
],
)

View File

@ -1,645 +0,0 @@
# AUTOGENERATED FILE by tensorboard_bower_dependency_sync.py
package(default_visibility = ["//visibility:public"])
filegroup(
name = "d3",
srcs = [
"d3.js",
"d3.min.js",
"package.js",
],
)
filegroup(
name = "dagre",
srcs = [
"dist/dagre.core.js",
"dist/dagre.core.min.js",
],
)
filegroup(
name = "es6_promise",
srcs = [
"promise.js",
"promise.min.js",
],
)
filegroup(
name = "font_roboto",
srcs = ["roboto.html"],
)
filegroup(
name = "graphlib",
srcs = [
"dist/graphlib.core.js",
"dist/graphlib.core.min.js",
],
)
filegroup(
name = "iron_a11y_announcer",
srcs = [
"index.html",
"iron-a11y-announcer.html",
],
)
filegroup(
name = "iron_a11y_keys_behavior",
srcs = [
"index.html",
"iron-a11y-keys-behavior.html",
],
)
filegroup(
name = "iron_ajax",
srcs = [
"index.html",
"iron-ajax.html",
"iron-request.html",
],
)
filegroup(
name = "iron_autogrow_textarea",
srcs = [
"index.html",
"iron-autogrow-textarea.html",
],
)
filegroup(
name = "iron_behaviors",
srcs = [
"index.html",
"iron-button-state.html",
"iron-control-state.html",
],
)
filegroup(
name = "iron_checked_element_behavior",
srcs = [
"index.html",
"iron-checked-element-behavior.html",
],
)
filegroup(
name = "iron_collapse",
srcs = [
"index.html",
"iron-collapse.html",
],
)
filegroup(
name = "iron_dropdown",
srcs = [
"index.html",
"iron-dropdown.html",
"iron-dropdown-scroll-manager.html",
],
)
filegroup(
name = "iron_fit_behavior",
srcs = [
"index.html",
"iron-fit-behavior.html",
],
)
filegroup(
name = "iron_flex_layout",
srcs = [
"classes/iron-flex-layout.html",
"classes/iron-shadow-flex-layout.html",
"index.html",
"iron-flex-layout.html",
"iron-flex-layout-classes.html",
],
)
filegroup(
name = "iron_form_element_behavior",
srcs = [
"index.html",
"iron-form-element-behavior.html",
],
)
filegroup(
name = "iron_icon",
srcs = [
"index.html",
"iron-icon.html",
],
)
filegroup(
name = "iron_icons",
srcs = [
"av-icons.html",
"communication-icons.html",
"device-icons.html",
"editor-icons.html",
"hardware-icons.html",
"image-icons.html",
"index.html",
"iron-icons.html",
"maps-icons.html",
"notification-icons.html",
"places-icons.html",
"social-icons.html",
],
)
filegroup(
name = "iron_iconset_svg",
srcs = [
"index.html",
"iron-iconset-svg.html",
],
)
filegroup(
name = "iron_input",
srcs = [
"index.html",
"iron-input.html",
],
)
filegroup(
name = "iron_list",
srcs = [
"index.html",
"iron-list.html",
"test/smoke/avg-worst-case.html",
"test/smoke/dummy-data.html",
"test/smoke/index.html",
"test/smoke/physical-count.html",
],
)
filegroup(
name = "iron_menu_behavior",
srcs = [
"index.html",
"iron-menu-behavior.html",
"iron-menubar-behavior.html",
],
)
filegroup(
name = "iron_meta",
srcs = [
"index.html",
"iron-meta.html",
],
)
filegroup(
name = "iron_overlay_behavior",
srcs = [
"index.html",
"iron-focusables-helper.html",
"iron-overlay-backdrop.html",
"iron-overlay-behavior.html",
"iron-overlay-manager.html",
],
)
filegroup(
name = "iron_range_behavior",
srcs = [
"index.html",
"iron-range-behavior.html",
],
)
filegroup(
name = "iron_resizable_behavior",
srcs = [
"demo/src/x-app.html",
"index.html",
"iron-resizable-behavior.html",
],
)
filegroup(
name = "iron_scroll_target_behavior",
srcs = [
"index.html",
"iron-scroll-target-behavior.html",
],
)
filegroup(
name = "iron_selector",
srcs = [
"index.html",
"iron-multi-selectable.html",
"iron-selectable.html",
"iron-selection.html",
"iron-selector.html",
],
)
filegroup(
name = "iron_validatable_behavior",
srcs = [
"index.html",
"iron-validatable-behavior.html",
],
)
filegroup(
name = "lodash",
srcs = [
"lodash.js",
"lodash.min.js",
],
)
filegroup(
name = "neon_animation",
srcs = [
"animations/cascaded-animation.html",
"animations/fade-in-animation.html",
"animations/fade-out-animation.html",
"animations/hero-animation.html",
"animations/opaque-animation.html",
"animations/reverse-ripple-animation.html",
"animations/ripple-animation.html",
"animations/scale-down-animation.html",
"animations/scale-up-animation.html",
"animations/slide-down-animation.html",
"animations/slide-from-bottom-animation.html",
"animations/slide-from-left-animation.html",
"animations/slide-from-right-animation.html",
"animations/slide-from-top-animation.html",
"animations/slide-left-animation.html",
"animations/slide-right-animation.html",
"animations/slide-up-animation.html",
"animations/transform-animation.html",
"demo/card/index.html",
"demo/card/x-card.html",
"demo/card/x-cards-list.html",
"demo/declarative/index.html",
"demo/doc/index.html",
"demo/doc/my-animatable.html",
"demo/doc/my-dialog.html",
"demo/dropdown/animated-dropdown.html",
"demo/dropdown/index.html",
"demo/grid/animated-grid.html",
"demo/grid/fullsize-page-with-card.html",
"demo/grid/index.html",
"demo/list/full-view.html",
"demo/list/index.html",
"demo/list/list-demo.html",
"demo/list/list-view.html",
"demo/load/animated-grid.html",
"demo/load/full-page.html",
"demo/load/index.html",
"demo/reprojection/animated-grid.html",
"demo/reprojection/fullsize-page-with-card.html",
"demo/reprojection/index.html",
"demo/reprojection/reprojected-pages.html",
"demo/tiles/circles-page.html",
"demo/tiles/index.html",
"demo/tiles/squares-page.html",
"index.html",
"neon-animatable.html",
"neon-animatable-behavior.html",
"neon-animated-pages.html",
"neon-animation.html",
"neon-animation-behavior.html",
"neon-animation-runner-behavior.html",
"neon-animations.html",
"neon-shared-element-animatable-behavior.html",
"neon-shared-element-animation-behavior.html",
"web-animations.html",
],
)
filegroup(
name = "paper_behaviors",
srcs = [
"index.html",
"paper-button-behavior.html",
"paper-checked-element-behavior.html",
"paper-inky-focus-behavior.html",
"paper-ripple-behavior.html",
],
)
filegroup(
name = "paper_button",
srcs = [
"index.html",
"paper-button.html",
],
)
filegroup(
name = "paper_checkbox",
srcs = [
"index.html",
"paper-checkbox.html",
],
)
filegroup(
name = "paper_dialog",
srcs = [
"index.html",
"paper-dialog.html",
],
)
filegroup(
name = "paper_dialog_behavior",
srcs = [
"index.html",
"paper-dialog-behavior.html",
"paper-dialog-common.css",
"paper-dialog-shared-styles.html",
],
)
filegroup(
name = "paper_dialog_scrollable",
srcs = [
"index.html",
"paper-dialog-scrollable.html",
],
)
filegroup(
name = "paper_dropdown_menu",
srcs = [
"index.html",
"paper-dropdown-menu.html",
"paper-dropdown-menu-icons.html",
"paper-dropdown-menu-light.html",
"paper-dropdown-menu-shared-styles.html",
],
)
filegroup(
name = "paper_header_panel",
srcs = [
"index.html",
"paper-header-panel.html",
],
)
filegroup(
name = "paper_icon_button",
srcs = [
"index.html",
"paper-icon-button.html",
"paper-icon-button-light.html",
],
)
filegroup(
name = "paper_input",
srcs = [
"all-imports.html",
"index.html",
"paper-input.html",
"paper-input-addon-behavior.html",
"paper-input-behavior.html",
"paper-input-char-counter.html",
"paper-input-container.html",
"paper-input-error.html",
"paper-textarea.html",
],
)
filegroup(
name = "paper_item",
srcs = [
"all-imports.html",
"index.html",
"paper-icon-item.html",
"paper-item.html",
"paper-item-behavior.html",
"paper-item-body.html",
"paper-item-shared-styles.html",
],
)
filegroup(
name = "paper_listbox",
srcs = [
"index.html",
"paper-listbox.html",
],
)
filegroup(
name = "paper_material",
srcs = [
"index.html",
"paper-material.html",
"paper-material-shared-styles.html",
],
)
filegroup(
name = "paper_menu",
srcs = [
"index.html",
"paper-menu.html",
"paper-menu-shared-styles.html",
"paper-submenu.html",
],
)
filegroup(
name = "paper_menu_button",
srcs = [
"index.html",
"paper-menu-button.html",
"paper-menu-button-animations.html",
],
)
filegroup(
name = "paper_progress",
srcs = [
"index.html",
"paper-progress.html",
],
)
filegroup(
name = "paper_radio_button",
srcs = [
"index.html",
"paper-radio-button.html",
],
)
filegroup(
name = "paper_radio_group",
srcs = [
"index.html",
"paper-radio-group.html",
],
)
filegroup(
name = "paper_ripple",
srcs = [
"index.html",
"paper-ripple.html",
],
)
filegroup(
name = "paper_slider",
srcs = [
"index.html",
"paper-slider.html",
],
)
filegroup(
name = "paper_spinner",
srcs = [
"index.html",
"paper-spinner.html",
"paper-spinner-behavior.html",
"paper-spinner-lite.html",
"paper-spinner-styles.html",
],
)
filegroup(
name = "paper_styles",
srcs = [
"classes/global.html",
"classes/shadow.html",
"classes/shadow-layout.html",
"classes/typography.html",
"color.html",
"default-theme.html",
"demo.css",
"demo-pages.html",
"index.html",
"paper-styles.html",
"paper-styles-classes.html",
"shadow.html",
"typography.html",
],
)
filegroup(
name = "paper_tabs",
srcs = [
"index.html",
"paper-tab.html",
"paper-tabs.html",
"paper-tabs-icons.html",
],
)
filegroup(
name = "paper_toast",
srcs = [
"index.html",
"paper-toast.html",
],
)
filegroup(
name = "paper_toggle_button",
srcs = [
"index.html",
"paper-toggle-button.html",
],
)
filegroup(
name = "paper_toolbar",
srcs = [
"index.html",
"paper-toolbar.html",
],
)
filegroup(
name = "paper_tooltip",
srcs = [
"index.html",
"paper-tooltip.html",
],
)
filegroup(
name = "plottable",
srcs = [
"plottable.css",
"plottable.js",
"plottable.min.js",
],
)
filegroup(
name = "polymer",
srcs = [
"polymer.html",
"polymer-micro.html",
"polymer-mini.html",
],
)
filegroup(
name = "promise_polyfill",
srcs = [
"Gruntfile.js",
"Promise.js",
"Promise.min.js",
"Promise-Statics.js",
"promise-polyfill.html",
"promise-polyfill-lite.html",
],
)
filegroup(
name = "web_animations_js",
srcs = [
"web-animations.html",
"web-animations.min.js",
"web-animations-next.min.js",
"web-animations-next-lite.min.js",
],
)
filegroup(
name = "webcomponentsjs",
srcs = [
"CustomElements.js",
"CustomElements.min.js",
"HTMLImports.js",
"HTMLImports.min.js",
"MutationObserver.js",
"MutationObserver.min.js",
"ShadowDOM.js",
"ShadowDOM.min.js",
"webcomponents.js",
"webcomponents.min.js",
"webcomponents-lite.js",
"webcomponents-lite.min.js",
],
)

583
configure vendored
View File

@ -3,584 +3,13 @@
set -e
set -o pipefail
# Find out the absolute path to where ./configure resides
pushd `dirname $0` > /dev/null
SOURCE_BASE_DIR=`pwd -P`
popd > /dev/null
PLATFORM="$(uname -s | tr 'A-Z' 'a-z')"
function is_linux() {
if [[ "${PLATFORM}" == "linux" ]]; then
true
else
false
fi
}
function is_macos() {
if [[ "${PLATFORM}" == "darwin" ]]; then
true
else
false
fi
}
function is_windows() {
# On windows, the shell script is actually running in msys
if [[ "${PLATFORM}" =~ msys_nt*|mingw*|cygwin*|uwin* ]]; then
true
else
false
fi
}
function bazel_clean_and_fetch() {
# bazel clean --expunge currently doesn't work on Windows
# TODO(pcloudy): Re-enable it after bazel clean --expunge is fixed.
if ! is_windows; then
bazel clean --expunge
fi
bazel fetch "//tensorflow/... -//tensorflow/contrib/nccl/... \
-//tensorflow/examples/android/..."
}
# Delete any leftover BUILD files from the Makefile build, which would interfere
# with Bazel parsing.
MAKEFILE_DOWNLOAD_DIR=tensorflow/contrib/makefile/downloads
if [ -d "${MAKEFILE_DOWNLOAD_DIR}" ]; then
find ${MAKEFILE_DOWNLOAD_DIR} -type f -name '*BUILD' -delete
if [ -z "$PYTHON_BIN_PATH" ]; then
PYTHON_BIN_PATH=$(which python3 || which python || true)
fi
## Set up python-related environment settings
while true; do
fromuser=""
if [ -z "$PYTHON_BIN_PATH" ]; then
default_python_bin_path=$(which python || which python3 || true)
read -p "Please specify the location of python. [Default is $default_python_bin_path]: " PYTHON_BIN_PATH
fromuser="1"
if [ -z "$PYTHON_BIN_PATH" ]; then
PYTHON_BIN_PATH=$default_python_bin_path
fi
fi
if [ -e "$PYTHON_BIN_PATH" ]; then
break
fi
echo "Invalid python path. ${PYTHON_BIN_PATH} cannot be found" 1>&2
if [ -z "$fromuser" ]; then
exit 1
fi
PYTHON_BIN_PATH=""
# Retry
done
## Set up MKL related environment settings
if false; then # Disable building with MKL for now
while [ "$TF_NEED_MKL" == "" ]; do
fromuser=""
read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT
fromuser="1"
case $INPUT in
[Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;;
[Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
"" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
OSNAME=`uname -s`
if [ "$TF_NEED_MKL" == "1" ]; then # TF_NEED_MKL
DST=`dirname $0`
ARCHIVE_BASENAME=mklml_lnx_2017.0.2.20170110.tgz
GITHUB_RELEASE_TAG=v0.3
MKLURL="https://github.com/01org/mkl-dnn/releases/download/$GITHUB_RELEASE_TAG/$ARCHIVE_BASENAME"
if ! [ -e "$DST/third_party/mkl/$ARCHIVE_BASENAME" ]; then
wget --no-check-certificate -P $DST/third_party/mkl/ $MKLURL
fi
tar -xzf $DST/third_party/mkl/$ARCHIVE_BASENAME -C $DST/third_party/mkl/
extracted_dir_name="${ARCHIVE_BASENAME%.*}"
MKL_INSTALL_PATH=$DST/third_party/mkl/$extracted_dir_name
MKL_INSTALL_PATH=`${PYTHON_BIN_PATH} -c "import os; print(os.path.realpath(os.path.expanduser('${MKL_INSTALL_PATH}')))"`
if [ "$OSNAME" == "Linux" ]; then
# Full MKL configuration
MKL_RT_LIB_PATH="lib/intel64/libmkl_rt.so" #${TF_MKL_EXT}#TODO version?
MKL_RT_OMP_LIB_PATH="../compiler/lib/intel64/libiomp5.so" #TODO VERSION?
# MKL-ML configuration
MKL_ML_LIB_PATH="lib/libmklml_intel.so" #${TF_MKL_EXT}#TODO version?
MKL_ML_OMP_LIB_PATH="lib/libiomp5.so" #TODO VERSION?
elif [ "$OSNAME" == "Darwin" ]; then
echo "Darwin is unsupported yet";
exit 1
fi
if [ -e "$MKL_INSTALL_PATH/${MKL_ML_LIB_PATH}" ]; then
ln -sf $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} third_party/mkl/
ln -sf $MKL_INSTALL_PATH/${MKL_ML_OMP_LIB_PATH} third_party/mkl/
ln -sf $MKL_INSTALL_PATH/include third_party/mkl/
ln -sf $MKL_INSTALL_PATH/include third_party/eigen3/mkl_include
else
echo "ERROR: $MKL_INSTALL_PATH/${MKL_ML_LIB_PATH} does not exist";
exit 1
fi
if [ -z "$fromuser" ]; then
exit 1
fi
cat > third_party/mkl/mkl.config <<EOF
# MKL_INSTALL_PATH refers to the location of MKL root folder. The MKL header and library
# files can be either in this directory, or under include/ and lib64/
MKL_INSTALL_PATH=$MKL_INSTALL_PATH
EOF
fi # TF_NEED_MKL
################## MKL
fi # Disable building with MKL for now
## Set up architecture-dependent optimization flags.
if [ -z "$CC_OPT_FLAGS" ]; then
default_cc_opt_flags="-march=native"
read -p "Please specify optimization flags to use during compilation when bazel option "\
"\"--config=opt\" is specified [Default is $default_cc_opt_flags]: " CC_OPT_FLAGS
if [ -z "$CC_OPT_FLAGS" ]; then
CC_OPT_FLAGS=$default_cc_opt_flags
fi
fi
if is_windows; then
TF_NEED_GCP=0
TF_NEED_HDFS=0
TF_NEED_JEMALLOC=0
TF_NEED_OPENCL=0
fi
if is_linux; then
while [ "$TF_NEED_JEMALLOC" == "" ]; do
read -p "Do you wish to use jemalloc as the malloc implementation? [Y/n] "\
INPUT
case $INPUT in
[Yy]* ) echo "jemalloc enabled"; TF_NEED_JEMALLOC=1;;
[Nn]* ) echo "jemalloc disabled"; TF_NEED_JEMALLOC=0;;
"" ) echo "jemalloc enabled"; TF_NEED_JEMALLOC=1;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
else
TF_NEED_JEMALLOC=0
fi
if [ "$TF_NEED_JEMALLOC" == "1" ]; then
sed -i -e "s/WITH_JEMALLOC = False/WITH_JEMALLOC = True/" tensorflow/core/platform/default/build_config.bzl
else
sed -i -e "s/WITH_JEMALLOC = True/WITH_JEMALLOC = False/" tensorflow/core/platform/default/build_config.bzl
fi
while [ "$TF_NEED_GCP" == "" ]; do
read -p "Do you wish to build TensorFlow with "\
"Google Cloud Platform support? [y/N] " INPUT
case $INPUT in
[Yy]* ) echo "Google Cloud Platform support will be enabled for "\
"TensorFlow"; TF_NEED_GCP=1;;
[Nn]* ) echo "No Google Cloud Platform support will be enabled for "\
"TensorFlow"; TF_NEED_GCP=0;;
"" ) echo "No Google Cloud Platform support will be enabled for "\
"TensorFlow"; TF_NEED_GCP=0;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
if [ "$TF_NEED_GCP" == "1" ]; then
## Verify that libcurl header files are available.
# Only check Linux, since on MacOS the header files are installed with XCode.
if is_linux && [[ ! -f "/usr/include/curl/curl.h" ]]; then
echo "ERROR: It appears that the development version of libcurl is not "\
"available. Please install the libcurl3-dev package."
exit 1
fi
# Update Bazel build configuration.
sed -i -e "s/WITH_GCP_SUPPORT = False/WITH_GCP_SUPPORT = True/" tensorflow/core/platform/default/build_config.bzl
else
# Update Bazel build configuration.
sed -i -e "s/WITH_GCP_SUPPORT = True/WITH_GCP_SUPPORT = False/" tensorflow/core/platform/default/build_config.bzl
fi
while [ "$TF_NEED_HDFS" == "" ]; do
read -p "Do you wish to build TensorFlow with "\
"Hadoop File System support? [y/N] " INPUT
case $INPUT in
[Yy]* ) echo "Hadoop File System support will be enabled for "\
"TensorFlow"; TF_NEED_HDFS=1;;
[Nn]* ) echo "No Hadoop File System support will be enabled for "\
"TensorFlow"; TF_NEED_HDFS=0;;
"" ) echo "No Hadoop File System support will be enabled for "\
"TensorFlow"; TF_NEED_HDFS=0;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
if [ "$TF_NEED_HDFS" == "1" ]; then
# Update Bazel build configuration.
sed -i -e "s/WITH_HDFS_SUPPORT = False/WITH_HDFS_SUPPORT = True/" tensorflow/core/platform/default/build_config.bzl
else
# Update Bazel build configuration.
sed -i -e "s/WITH_HDFS_SUPPORT = True/WITH_HDFS_SUPPORT = False/" tensorflow/core/platform/default/build_config.bzl
fi
## Enable XLA.
while [ "$TF_ENABLE_XLA" == "" ]; do
read -p "Do you wish to build TensorFlow with the XLA just-in-time compiler (experimental)? [y/N] " INPUT
case $INPUT in
[Yy]* ) echo "XLA JIT support will be enabled for TensorFlow"; TF_ENABLE_XLA=1;;
[Nn]* ) echo "No XLA JIT support will be enabled for TensorFlow"; TF_ENABLE_XLA=0;;
"" ) echo "No XLA support will be enabled for TensorFlow"; TF_ENABLE_XLA=0;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
if [ "$TF_ENABLE_XLA" == "1" ]; then
# Update Bazel build configuration.
sed -i -e "s/^WITH_XLA_SUPPORT = [FT].*/WITH_XLA_SUPPORT = True/" tensorflow/core/platform/default/build_config_root.bzl
else
# Update Bazel build configuration.
sed -i -e "s/^WITH_XLA_SUPPORT = [FT].*/WITH_XLA_SUPPORT = False/" tensorflow/core/platform/default/build_config_root.bzl
fi
# Invoke python_config and set up symlinks to python includes
./util/python/python_config.sh --setup "$PYTHON_BIN_PATH"
# Append CC optimization flags to bazel.rc
echo >> tools/bazel.rc
for opt in $CC_OPT_FLAGS; do
echo "build:opt --cxxopt=$opt --copt=$opt" >> tools/bazel.rc
done
# Run the gen_git_source to create links where bazel can track dependencies for
# git hash propagation
GEN_GIT_SOURCE=tensorflow/tools/git/gen_git_source.py
chmod a+x ${GEN_GIT_SOURCE}
"${PYTHON_BIN_PATH}" ${GEN_GIT_SOURCE} --configure "${SOURCE_BASE_DIR}"
## Set up SYCL-related environment settings
while [ "$TF_NEED_OPENCL" == "" ]; do
read -p "Do you wish to build TensorFlow with OpenCL support? [y/N] " INPUT
case $INPUT in
[Yy]* ) echo "OpenCL support will be enabled for TensorFlow"; TF_NEED_OPENCL=1;;
[Nn]* ) echo "No OpenCL support will be enabled for TensorFlow"; TF_NEED_OPENCL=0;;
"" ) echo "No OpenCL support will be enabled for TensorFlow"; TF_NEED_OPENCL=0;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
## Set up Cuda-related environment settings
while [ "$TF_NEED_CUDA" == "" ]; do
read -p "Do you wish to build TensorFlow with CUDA support? [y/N] " INPUT
case $INPUT in
[Yy]* ) echo "CUDA support will be enabled for TensorFlow"; TF_NEED_CUDA=1;;
[Nn]* ) echo "No CUDA support will be enabled for TensorFlow"; TF_NEED_CUDA=0;;
"" ) echo "No CUDA support will be enabled for TensorFlow"; TF_NEED_CUDA=0;;
* ) echo "Invalid selection: " $INPUT;;
esac
done
export TF_NEED_CUDA
export TF_NEED_OPENCL
if [[ "$TF_NEED_CUDA" == "0" ]] && [[ "$TF_NEED_OPENCL" == "0" ]]; then
echo "Configuration finished"
bazel_clean_and_fetch
exit
fi
if [ "$TF_NEED_CUDA" == "1" ]; then
# Set up which gcc nvcc should use as the host compiler
# No need to set this on Windows
while ! is_windows && true; do
fromuser=""
if [ -z "$GCC_HOST_COMPILER_PATH" ]; then
default_gcc_host_compiler_path=$(which gcc || true)
read -p "Please specify which gcc should be used by nvcc as the host compiler. [Default is $default_gcc_host_compiler_path]: " GCC_HOST_COMPILER_PATH
fromuser="1"
if [ -z "$GCC_HOST_COMPILER_PATH" ]; then
GCC_HOST_COMPILER_PATH="$default_gcc_host_compiler_path"
fi
fi
if [ -e "$GCC_HOST_COMPILER_PATH" ]; then
export GCC_HOST_COMPILER_PATH
break
fi
echo "Invalid gcc path. ${GCC_HOST_COMPILER_PATH} cannot be found" 1>&2
if [ -z "$fromuser" ]; then
exit 1
fi
GCC_HOST_COMPILER_PATH=""
# Retry
done
# Find out where the CUDA toolkit is installed
while true; do
# Configure the Cuda SDK version to use.
if [ -z "$TF_CUDA_VERSION" ]; then
read -p "Please specify the CUDA SDK version you want to use, e.g. 7.0. [Leave empty to use system default]: " TF_CUDA_VERSION
fi
fromuser=""
if [ -z "$CUDA_TOOLKIT_PATH" ]; then
default_cuda_path=/usr/local/cuda
if is_windows; then
if [ -z "$CUDA_PATH" ]; then
default_cuda_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v8.0"
else
default_cuda_path="$(cygpath -m "$CUDA_PATH")"
fi
fi
read -p "Please specify the location where CUDA $TF_CUDA_VERSION toolkit is installed. Refer to README.md for more details. [Default is $default_cuda_path]: " CUDA_TOOLKIT_PATH
fromuser="1"
if [ -z "$CUDA_TOOLKIT_PATH" ]; then
CUDA_TOOLKIT_PATH="$default_cuda_path"
fi
fi
if [[ -z "$TF_CUDA_VERSION" ]]; then
TF_CUDA_EXT=""
else
TF_CUDA_EXT=".$TF_CUDA_VERSION"
fi
if is_windows; then
CUDA_RT_LIB_PATH="lib/x64/cudart.lib"
elif is_linux; then
CUDA_RT_LIB_PATH="lib64/libcudart.so${TF_CUDA_EXT}"
elif is_macos; then
CUDA_RT_LIB_PATH="lib/libcudart${TF_CUDA_EXT}.dylib"
fi
if [ -e "${CUDA_TOOLKIT_PATH}/${CUDA_RT_LIB_PATH}" ]; then
export CUDA_TOOLKIT_PATH
export TF_CUDA_VERSION
break
fi
echo "Invalid path to CUDA $TF_CUDA_VERSION toolkit. ${CUDA_TOOLKIT_PATH}/${CUDA_RT_LIB_PATH} cannot be found"
if [ -z "$fromuser" ]; then
exit 1
fi
# Retry
TF_CUDA_VERSION=""
CUDA_TOOLKIT_PATH=""
done
# Find out where the cuDNN library is installed
while true; do
# Configure the Cudnn version to use.
if [ -z "$TF_CUDNN_VERSION" ]; then
read -p "Please specify the Cudnn version you want to use. [Leave empty to use system default]: " TF_CUDNN_VERSION
fi
fromuser=""
if [ -z "$CUDNN_INSTALL_PATH" ]; then
default_cudnn_path=${CUDA_TOOLKIT_PATH}
read -p "Please specify the location where cuDNN $TF_CUDNN_VERSION library is installed. Refer to README.md for more details. [Default is $default_cudnn_path]: " CUDNN_INSTALL_PATH
fromuser="1"
if [ -z "$CUDNN_INSTALL_PATH" ]; then
CUDNN_INSTALL_PATH=$default_cudnn_path
fi
# Result returned from "read" will be used unexpanded. That make "~" unuseable.
# Going through one more level of expansion to handle that.
CUDNN_INSTALL_PATH=`"${PYTHON_BIN_PATH}" -c "import os; print(os.path.realpath(os.path.expanduser('${CUDNN_INSTALL_PATH}')))"`
fi
if [[ -z "$TF_CUDNN_VERSION" ]]; then
TF_CUDNN_EXT=""
else
TF_CUDNN_EXT=".$TF_CUDNN_VERSION"
fi
if is_windows; then
CUDA_DNN_LIB_PATH="lib/x64/cudnn.lib"
CUDA_DNN_LIB_ALT_PATH="lib/x64/cudnn.lib"
elif is_linux; then
CUDA_DNN_LIB_PATH="lib64/libcudnn.so${TF_CUDNN_EXT}"
CUDA_DNN_LIB_ALT_PATH="libcudnn.so${TF_CUDNN_EXT}"
elif is_macos; then
CUDA_DNN_LIB_PATH="lib/libcudnn${TF_CUDNN_EXT}.dylib"
CUDA_DNN_LIB_ALT_PATH="libcudnn${TF_CUDNN_EXT}.dylib"
fi
if [ -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_ALT_PATH}" -o -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_PATH}" ]; then
export TF_CUDNN_VERSION
export CUDNN_INSTALL_PATH
break
fi
if is_linux; then
if ! type ldconfig > /dev/null 2>&1; then
LDCONFIG_BIN=/sbin/ldconfig
else
LDCONFIG_BIN=ldconfig
fi
CUDNN_PATH_FROM_LDCONFIG="$($LDCONFIG_BIN -p | sed -n 's/.*libcudnn.so .* => \(.*\)/\1/p')"
if [ -e "${CUDNN_PATH_FROM_LDCONFIG}${TF_CUDNN_EXT}" ]; then
export TF_CUDNN_VERSION
export CUDNN_INSTALL_PATH="$(dirname ${CUDNN_PATH_FROM_LDCONFIG})"
break
fi
fi
echo "Invalid path to cuDNN ${CUDNN_VERSION} toolkit. Neither of the following two files can be found:"
echo "${CUDNN_INSTALL_PATH}/${CUDA_DNN_LIB_PATH}"
echo "${CUDNN_INSTALL_PATH}/${CUDA_DNN_LIB_ALT_PATH}"
if is_linux; then
echo "${CUDNN_PATH_FROM_LDCONFIG}${TF_CUDNN_EXT}"
fi
if [ -z "$fromuser" ]; then
exit 1
fi
# Retry
TF_CUDNN_VERSION=""
CUDNN_INSTALL_PATH=""
done
# Configure the compute capabilities that TensorFlow builds for.
# Since Cuda toolkit is not backward-compatible, this is not guaranteed to work.
while true; do
fromuser=""
default_cuda_compute_capabilities="3.5,5.2"
if [ -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then
cat << EOF
Please specify a list of comma-separated Cuda compute capabilities you want to build with.
You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus.
Please note that each additional compute capability significantly increases your build time and binary size.
EOF
read -p "[Default is: \"3.5,5.2\"]: " TF_CUDA_COMPUTE_CAPABILITIES
fromuser=1
fi
if [ -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then
TF_CUDA_COMPUTE_CAPABILITIES=$default_cuda_compute_capabilities
fi
# Check whether all capabilities from the input is valid
COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES//,/ }
ALL_VALID=1
for CAPABILITY in $COMPUTE_CAPABILITIES; do
if [[ ! "$CAPABILITY" =~ [0-9]+.[0-9]+ ]]; then
echo "Invalid compute capability: " $CAPABILITY
ALL_VALID=0
break
fi
done
if [ "$ALL_VALID" == "0" ]; then
if [ -z "$fromuser" ]; then
exit 1
fi
else
export TF_CUDA_COMPUTE_CAPABILITIES
break
fi
TF_CUDA_COMPUTE_CAPABILITIES=""
done
if is_windows; then
# The following three variables are needed for MSVC toolchain configuration in Bazel
export CUDA_PATH="$CUDA_TOOLKIT_PATH"
export CUDA_COMPUTE_CAPABILITIES="$TF_CUDA_COMPUTE_CAPABILITIES"
export NO_WHOLE_ARCHIVE_OPTION=1
# Set GCC_HOST_COMPILER_PATH to keep cuda_configure.bzl happy
export GCC_HOST_COMPILER_PATH="/usr/bin/dummy_compiler"
fi
# end of if "$TF_NEED_CUDA" == "1"
fi
# OpenCL configuration
if [ "$TF_NEED_OPENCL" == "1" ]; then
# Determine which C++ compiler should be used as the host compiler
while true; do
fromuser=""
if [ -z "$HOST_CXX_COMPILER" ]; then
default_cxx_host_compiler=$(which clang++-3.6 || true)
read -p "Please specify which C++ compiler should be used as the host C++ compiler. [Default is $default_cxx_host_compiler]: " HOST_CXX_COMPILER
fromuser="1"
if [ -z "$HOST_CXX_COMPILER" ]; then
HOST_CXX_COMPILER=$default_cxx_host_compiler
fi
fi
if [ -e "$HOST_CXX_COMPILER" ]; then
export HOST_CXX_COMPILER
break
fi
echo "Invalid C++ compiler path. ${HOST_CXX_COMPILER} cannot be found" 1>&2
if [ -z "$fromuser" ]; then
exit 1
fi
HOST_CXX_COMPILER=""
# Retry
done
# Determine which C compiler should be used as the host compiler
while true; do
fromuser=""
if [ -z "$HOST_C_COMPILER" ]; then
default_c_host_compiler=$(which clang-3.6 || true)
read -p "Please specify which C compiler should be used as the host C compiler. [Default is $default_c_host_compiler]: " HOST_C_COMPILER
fromuser="1"
if [ -z "$HOST_C_COMPILER" ]; then
HOST_C_COMPILER=$default_c_host_compiler
fi
fi
if [ -e "$HOST_C_COMPILER" ]; then
export HOST_C_COMPILER
break
fi
echo "Invalid C compiler path. ${HOST_C_COMPILER} cannot be found" 1>&2
if [ -z "$fromuser" ]; then
exit 1
fi
HOST_C_COMPILER=""
# Retry
done
while true; do
# Configure the OPENCL version to use.
TF_OPENCL_VERSION="1.2"
# Point to ComputeCpp root
if [ -z "$COMPUTECPP_TOOLKIT_PATH" ]; then
default_computecpp_toolkit_path=/usr/local/computecpp
read -p "Please specify the location where ComputeCpp for SYCL $TF_OPENCL_VERSION is installed. [Default is $default_computecpp_toolkit_path]: " COMPUTECPP_TOOLKIT_PATH
fromuser="1"
if [ -z "$COMPUTECPP_TOOLKIT_PATH" ]; then
COMPUTECPP_TOOLKIT_PATH=$default_computecpp_toolkit_path
fi
fi
if is_linux; then
SYCL_RT_LIB_PATH="lib/libComputeCpp.so"
fi
if [ -e "${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH}" ]; then
export COMPUTECPP_TOOLKIT_PATH
break
fi
echo "Invalid SYCL $TF_OPENCL_VERSION library path. ${COMPUTECPP_TOOLKIT_PATH}/${SYCL_RT_LIB_PATH} cannot be found"
if [ -z "$fromuser" ]; then
exit 1
fi
# Retry
TF_OPENCL_VERSION=""
COMPUTECPP_TOOLKIT_PATH=""
done
# end of if "$TF_NEED_OPENCL" == "1"
fi
bazel_clean_and_fetch
# Set all env variables
CONFIGURE_DIR=$(dirname "$0")
"$PYTHON_BIN_PATH" "${CONFIGURE_DIR}/configure.py" "$@"
echo "Configuration finished"

20
configure.cmd Normal file
View File

@ -0,0 +1,20 @@
:: Copyright 2019 The TensorFlow Authors. All Rights Reserved.
::
:: Licensed under the Apache License, Version 2.0 (the "License");
:: you may not use this file except in compliance with the License.
:: You may obtain a copy of the License at
::
:: http://www.apache.org/licenses/LICENSE-2.0
::
:: Unless required by applicable law or agreed to in writing, software
:: distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
:: WARRANTIES OR CONDITIONS OF ANY KIND< either express or implied. See the
:: License for the specific language governing permissions and limitations under
:: the License.
@echo off
set configure_dir=%~dp0
set configure_dir=%configure_dir:~0,-1%
python "%configure_dir%\configure.py" %* || ( exit /b )
echo Configuration finished

1473
configure.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -20,20 +20,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.python import *
# pylint: enable=wildcard-import
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
# Lazily import the `tf.contrib` module. This avoids loading all of the
# dependencies of `tf.contrib` at `import tensorflow` time.
class _LazyContribLoader(object):
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
from tensorflow.python.platform import app # pylint: disable=g-import-not-at-top
app.flags = flags
def __getattr__(self, item):
global contrib
# Replace the lazy loader with the imported module itself.
import importlib # pylint: disable=g-import-not-at-top
contrib = importlib.import_module('tensorflow.contrib')
return getattr(contrib, item)
del absolute_import
del division
del print_function
contrib = _LazyContribLoader()
# These symbols appear because we import the python package which
# in turn imports from tensorflow.core and tensorflow.python. They
# must come from this module. So python adds these symbols for the
# resolution to succeed.
# pylint: disable=undefined-variable
del python
del core
# pylint: enable=undefined-variable

View File

@ -0,0 +1,183 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Top-level module of TensorFlow. By convention, we refer to this module as
`tf` instead of `tensorflow`, following the common practice of importing
TensorFlow via the command `import tensorflow as tf`.
The primary function of this module is to import all of the public TensorFlow
interfaces into a single place. The interfaces themselves are located in
sub-modules, as described below.
Note that the file `__init__.py` in the TensorFlow source code tree is actually
only a placeholder to enable test cases to run. The TensorFlow build replaces
this file with a file generated from [`api_template.__init__.py`](https://www.github.com/tensorflow/tensorflow/blob/master/tensorflow/api_template.__init__.py)
"""
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
import distutils as _distutils
import inspect as _inspect
import logging as _logging
import os as _os
import site as _site
import six as _six
import sys as _sys
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# Make sure code inside the TensorFlow codebase can use tf2.enabled() at import.
_os.environ['TF2_BEHAVIOR'] = '1'
from tensorflow.python import tf2 as _tf2
_tf2.enable()
# API IMPORTS PLACEHOLDER
# WRAPPER_PLACEHOLDER
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
# We're using bitwise, but there's nothing special about that.
_API_MODULE = _sys.modules[__name__].bitwise
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
_current_module = _sys.modules[__name__]
if not hasattr(_current_module, '__path__'):
__path__ = [_tf_api_dir]
elif _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
# Hook external TensorFlow modules.
# Import compat before trying to import summary from tensorboard, so that
# reexport_tf_summary can get compat from sys.modules. Only needed if using
# lazy loading.
_current_module.compat.v2 # pylint: disable=pointless-statement
try:
from tensorboard.summary._tf import summary
_current_module.__path__ = (
[_module_util.get_parent_dir(summary)] + _current_module.__path__)
setattr(_current_module, "summary", summary)
except ImportError:
_logging.warning(
"Limited tf.summary API due to missing TensorBoard installation.")
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v2.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)
try:
from .python.keras.api._v2 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
setattr(_current_module, "keras", keras)
except ImportError:
pass
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v2 import estimator
# pylint: enable=g-import-not-at-top
# Enable TF2 behaviors
from tensorflow.python.compat import v2_compat as _compat # pylint: disable=g-import-not-at-top
_compat.enable_v2_behavior()
_major_api_version = 2
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
# running under pip.
# TODO(gunan): Enable setting an environment variable to define arbitrary plugin
# directories.
# TODO(gunan): Find a better location for this code snippet.
from tensorflow.python.framework import load_library as _ll
from tensorflow.python.lib.io import file_io as _fi
# Get sitepackages directories for the python installation.
_site_packages_dirs = []
if _site.ENABLE_USER_SITE and _site.USER_SITE is not None:
_site_packages_dirs += [_site.USER_SITE]
_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
if 'getsitepackages' in dir(_site):
_site_packages_dirs += _site.getsitepackages()
if 'sysconfig' in dir(_distutils):
_site_packages_dirs += [_distutils.sysconfig.get_python_lib()]
_site_packages_dirs = list(set(_site_packages_dirs))
# Find the location of this exact file.
_current_file_location = _inspect.getfile(_inspect.currentframe())
def _running_from_pip_package():
return any(
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
if _running_from_pip_package():
# TODO(gunan): Add sanity checks to loaded modules here.
for _s in _site_packages_dirs:
# Load first party dynamic kernels.
_main_dir = _os.path.join(_s, 'tensorflow/core/kernels')
if _os.path.exists(_main_dir):
_ll.load_library(_main_dir)
# Load third party dynamic kernels.
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _os.path.exists(_plugin_dir):
_ll.load_library(_plugin_dir)
# Load Pluggable Device Library
_ll.load_pluggable_device_library(_plugin_dir)
# Add module aliases
if hasattr(_current_module, 'keras'):
losses = keras.losses
metrics = keras.metrics
optimizers = keras.optimizers
initializers = keras.initializers
setattr(_current_module, "losses", losses)
setattr(_current_module, "metrics", metrics)
setattr(_current_module, "optimizers", optimizers)
setattr(_current_module, "initializers", initializers)
# pylint: enable=undefined-variable
# Delete modules that should be hidden from dir().
# Don't fail if these modules are not available.
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
# does not have 'python', 'core' directories. Then, it will be copied
# to tensorflow/ which does have these two directories.
# pylint: disable=undefined-variable
try:
del python
except NameError:
pass
try:
del core
except NameError:
pass
try:
del compiler
except NameError:
pass
# __all__ PLACEHOLDER

View File

@ -0,0 +1,182 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in all of the public TensorFlow interface into this module."""
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
import distutils as _distutils
import inspect as _inspect
import os as _os
import site as _site
import six as _six
import sys as _sys
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python.tools import module_util as _module_util
from tensorflow.python.platform import tf_logging as _logging
from tensorflow.python.util.lazy_loader import LazyLoader as _LazyLoader
# API IMPORTS PLACEHOLDER
# WRAPPER_PLACEHOLDER
if "dev" in __version__: # pylint: disable=undefined-variable
_logging.warning("""
TensorFlow's `tf-nightly` package will soon be updated to TensorFlow 2.0.
Please upgrade your code to TensorFlow 2.0:
* https://www.tensorflow.org/guide/migrate
Or install the latest stable TensorFlow 1.X release:
* `pip install -U "tensorflow==1.*"`
Otherwise your code may be broken by the change.
""")
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
# We're using bitwise, but there's nothing special about that.
_API_MODULE = _sys.modules[__name__].bitwise # pylint: disable=undefined-variable
_current_module = _sys.modules[__name__]
_tf_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
if not hasattr(_current_module, '__path__'):
__path__ = [_tf_api_dir]
elif _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
# Hook external TensorFlow modules.
# Import compat before trying to import summary from tensorboard, so that
# reexport_tf_summary can get compat from sys.modules. Only needed if using
# lazy loading.
_current_module.compat.v2 # pylint: disable=pointless-statement
# Lazy-load estimator.
_estimator_module = "tensorflow_estimator.python.estimator.api._v1.estimator"
estimator = _LazyLoader("estimator", globals(), _estimator_module)
_module_dir = _module_util.get_parent_dir_for_name(_estimator_module)
if _module_dir:
_current_module.__path__ = [_module_dir] + _current_module.__path__
setattr(_current_module, "estimator", estimator)
try:
from .python.keras.api._v1 import keras
_current_module.__path__ = (
[_module_util.get_parent_dir(keras)] + _current_module.__path__)
setattr(_current_module, "keras", keras)
except ImportError:
pass
# Explicitly import lazy-loaded modules to support autocompletion.
# pylint: disable=g-import-not-at-top
if not _six.PY2:
import typing as _typing
if _typing.TYPE_CHECKING:
from tensorflow_estimator.python.estimator.api._v1 import estimator
# pylint: enable=g-import-not-at-top
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
_CONTRIB_WARNING = """
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
* https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
* https://github.com/tensorflow/addons
* https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.
"""
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib',
_CONTRIB_WARNING)
del LazyLoader
# The templated code that replaces the placeholder above sometimes
# sets the __all__ variable. If it does, we have to be sure to add
# "contrib".
if '__all__' in vars():
vars()['__all__'].append('contrib')
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
# The 'app' module will be imported as part of the placeholder section above.
_current_module.app.flags = flags # pylint: disable=undefined-variable
setattr(_current_module, "flags", flags)
_major_api_version = 1
# Load all plugin libraries from site-packages/tensorflow-plugins if we are
# running under pip.
# TODO(gunan): Enable setting an environment variable to define arbitrary plugin
# directories.
# TODO(gunan): Find a better location for this code snippet.
from tensorflow.python.framework import load_library as _ll
from tensorflow.python.lib.io import file_io as _fi
# Get sitepackages directories for the python installation.
_site_packages_dirs = []
_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE]
_site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p]
if 'getsitepackages' in dir(_site):
_site_packages_dirs += _site.getsitepackages()
if 'sysconfig' in dir(_distutils):
_site_packages_dirs += [_distutils.sysconfig.get_python_lib()]
_site_packages_dirs = list(set(_site_packages_dirs))
# Find the location of this exact file.
_current_file_location = _inspect.getfile(_inspect.currentframe())
def _running_from_pip_package():
return any(
_current_file_location.startswith(dir_) for dir_ in _site_packages_dirs)
if _running_from_pip_package():
# TODO(gunan): Add sanity checks to loaded modules here.
for _s in _site_packages_dirs:
# Load first party dynamic kernels.
_main_dir = _os.path.join(_s, 'tensorflow/core/kernels')
if _os.path.exists(_main_dir):
_ll.load_library(_main_dir)
# Load third party dynamic kernels.
_plugin_dir = _os.path.join(_s, 'tensorflow-plugins')
if _os.path.exists(_plugin_dir):
_ll.load_library(_plugin_dir)
# Load Pluggable Device Library
_ll.load_pluggable_device_library(_plugin_dir)
# Delete modules that should be hidden from dir().
# Don't fail if these modules are not available.
# For e.g. this file will be originally placed under tensorflow/_api/v1 which
# does not have 'python', 'core' directories. Then, it will be copied
# to tensorflow/ which does have these two directories.
# pylint: disable=undefined-variable
try:
del python
except NameError:
pass
try:
del core
except NameError:
pass
try:
del compiler
except NameError:
pass
# __all__ PLACEHOLDER

View File

@ -1,20 +1,24 @@
# Description:
# C API for TensorFlow, for use by client language bindings.
licenses(["notice"]) # Apache 2.0
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
"tf_copts",
"tf_cuda_library",
"tf_custom_op_library",
"tf_kernel_library",
)
# For platform specific build config
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_kernel_tests_linkstatic",
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "filegroup")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
package(
licenses = ["notice"], # Apache 2.0
)
# -----------------------------------------------------------------------------
@ -22,37 +26,451 @@ load(
filegroup(
name = "headers",
srcs = ["c_api.h"],
srcs = [
"c_api.h",
"c_api_experimental.h",
"c_api_macros.h",
"tensor_interface.h",
"tf_attrtype.h",
"tf_datatype.h",
"tf_file_statistics.h",
"tf_status.h",
"tf_tensor.h",
"tf_tstring.h",
"//tensorflow/core/platform:ctstring",
],
visibility = ["//tensorflow:__subpackages__"],
)
filegroup(
name = "srcs",
srcs = glob(
[
"*.cc",
"*.h",
],
exclude = [
"c_api_experimental.cc",
"c_api_experimental.h",
"python_api.cc",
"python_api.h",
"*test*",
],
) + [
"//tensorflow/core/platform:ctstring",
"//tensorflow/cc:srcs_no_runtime",
"//tensorflow/core/distributed_runtime:server_lib.h",
],
visibility = ["//visibility:public"],
)
cc_library(
name = "pywrap_required_hdrs",
textual_hdrs = [
"c_api_internal.h",
"c_api_macros.h",
"conversion_macros.h",
"python_api.h",
"tensor_interface.h",
"tf_status_helper.h",
"tf_status_internal.h",
"tf_tensor_internal.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__subpackages__",
],
)
tf_cuda_library(
name = "c_api_internal",
hdrs = [
"c_api.h",
"c_api_internal.h",
"c_api_macros.h",
"tf_datatype.h",
"tf_tensor.h",
"tf_tstring.h",
],
visibility = [
"//tensorflow:internal",
"//tensorflow/c:__subpackages__",
],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//tensorflow:chromiumos": [
":tf_attrtype",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/platform:platform",
],
"//conditions:default": [
":tf_attrtype",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/platform:platform",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core/distributed_runtime:server_lib",
],
}) + [
":tf_status_internal",
":tf_tensor_internal",
],
)
filegroup(
name = "pywrap_tf_session_hdrs",
srcs = [
"python_api.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "tf_attrtype",
hdrs = ["tf_attrtype.h"],
visibility = ["//visibility:public"],
)
cc_library(
name = "c_api_macros",
hdrs = ["c_api_macros.h"],
copts = tf_copts(),
visibility = ["//visibility:public"],
)
tf_cuda_library(
name = "c_api",
srcs = ["c_api.cc"],
hdrs = ["c_api.h"],
hdrs = [
"c_api.h",
"tf_attrtype.h",
"tf_datatype.h",
"tf_file_statistics.h",
"tf_status.h",
"tf_tensor.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = select({
deps = [
":c_api_no_xla",
":c_api_internal",
":tf_attrtype",
":tf_status_internal",
":tf_file_statistics",
":tf_tensor_internal",
] + select({
"//tensorflow:with_xla_support": [
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/jit",
],
"//conditions:default": [],
}),
)
tf_cuda_library(
name = "c_api_no_xla",
srcs = [
"c_api.cc",
"c_api_function.cc",
],
hdrs = [
"c_api.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c:__subpackages__",
"//tensorflow/python:__subpackages__",
"//third_party/llvm/llvm-project:__subpackages__",
],
deps = [
":c_api_internal",
":tf_attrtype",
":tf_datatype",
":tf_status_internal",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/cc/saved_model:loader",
":logging",
":tf_status",
":tf_tensor",
"@com_google_absl//absl/strings",
"//tensorflow/c/experimental/filesystem:modular_filesystem",
"//tensorflow/cc/saved_model:loader_lite",
"//tensorflow/cc:gradients",
"//tensorflow/cc:ops",
"//tensorflow/cc:grad_ops",
"//tensorflow/cc:scope_internal",
"//tensorflow/cc:while_loop",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:op_gen_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/kernels:logging_ops",
"//tensorflow/compiler/mlir/tfr:node_expansion_pass",
"//tensorflow/compiler/mlir/tfr:graph_decompose_pass",
],
}),
alwayslink = 1,
)
cc_library(
name = "logging",
srcs = ["logging.cc"],
hdrs = ["logging.h"],
visibility = ["//visibility:public"],
deps = [
":c_api_macros",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:stringprintf",
],
)
tf_cuda_library(
name = "tf_status_internal",
hdrs = [
"tf_status.h",
"tf_status_internal.h",
],
visibility = [
"//tensorflow/c:__subpackages__",
"//tensorflow/compiler/mlir/tensorflow/c:__subpackages__",
],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:lib",
],
}),
)
cc_library(
name = "tf_shape",
srcs = ["tf_shape.cc"],
hdrs = ["tf_shape.h"],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":c_api_macros",
":tf_shape_internal",
"//tensorflow/core:framework",
],
)
cc_library(
name = "tf_shape_internal",
hdrs = ["tf_shape_internal.h"],
copts = tf_copts(),
visibility = ["//tensorflow:internal"],
deps = [
":conversion_macros",
"//tensorflow/core:framework",
],
)
cc_library(
name = "tf_status",
srcs = ["tf_status.cc"],
hdrs = ["tf_status.h"],
visibility = ["//visibility:public"],
deps = [
":tf_status_internal",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:lib",
],
}),
)
cc_library(
name = "tf_status_headers",
hdrs = ["tf_status.h"],
visibility = ["//visibility:public"],
)
cc_library(
name = "tf_file_statistics",
hdrs = ["tf_file_statistics.h"],
visibility = ["//visibility:public"],
)
cc_library(
name = "tensor_interface",
hdrs = ["tensor_interface.h"],
visibility = ["//tensorflow:internal"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
}),
)
cc_library(
name = "tf_datatype",
srcs = ["tf_datatype.cc"],
hdrs = ["tf_datatype.h"],
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:framework",
],
}),
alwayslink = 1,
)
cc_library(
name = "tf_tensor",
srcs = ["tf_tensor.cc"],
hdrs = ["tf_tensor.h"],
visibility = ["//visibility:public"],
deps = [
":c_api_macros",
":tensor_interface",
":tf_datatype",
":tf_status",
":tf_status_helper",
":tf_tensor_internal",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:casts",
],
}),
)
tf_cuda_library(
name = "tf_tensor_internal",
hdrs = [
"tf_tensor.h",
"tf_tensor_internal.h",
],
visibility = ["//tensorflow:internal"],
deps = [
":c_api_macros",
":tensor_interface",
":tf_datatype",
":tf_status",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:casts",
],
}),
)
tf_cuda_library(
name = "c_api_experimental",
srcs = [
"c_api_experimental.cc",
],
hdrs = [
"c_api_experimental.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":c_api",
":c_api_internal",
":checkpoint_reader",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/eager:tfe_op_internal",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:get_compiler_ir",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform",
"//tensorflow/core/platform:blocking_counter",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
exports_files(
[
"version_script.lds",
"exported_symbols.lds",
],
visibility = ["//visibility:public"],
)
filegroup(
name = "checkpoint_reader_hdrs",
srcs = [
"checkpoint_reader.h",
"tf_status_helper.h",
],
visibility = ["//tensorflow:__subpackages__"],
)
tf_cuda_library(
name = "tf_status_helper",
srcs = ["tf_status_helper.cc"],
hdrs = ["tf_status_helper.h"],
visibility = ["//visibility:public"],
deps = [
":c_api",
":tf_status",
":tf_status_internal",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs
],
"//conditions:default": [
"//tensorflow/core:lib",
],
}),
)
tf_cc_test(
name = "tf_status_helper_test",
srcs = ["tf_status_helper_test.cc"],
deps = [
":tf_status_helper",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
@ -69,60 +487,369 @@ tf_cuda_library(
],
)
tf_cuda_library(
name = "env",
srcs = [
"env.cc",
],
hdrs = [
"env.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
],
}) + [
":c_api",
":tf_status_helper",
":c_api_internal",
":tf_file_statistics",
"//tensorflow/core:lib",
],
)
cc_library(
name = "kernels_hdrs",
hdrs = ["kernels.h"],
visibility = ["//tensorflow:internal"],
deps = [
":c_api_internal",
":tf_datatype",
":tf_status",
":tf_tensor",
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
],
)
tf_cuda_library(
name = "kernels",
srcs = [
"kernels.cc",
],
hdrs = [
"kernels.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":tf_status",
":tf_status_helper",
":tf_tensor_internal",
] + select({
"//tensorflow:android": [
":c_api_internal",
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":c_api_internal",
":tf_tensor",
"//tensorflow/stream_executor:stream",
"//tensorflow/core:framework",
"//tensorflow/core:framework_lite",
"//tensorflow/c/experimental/stream_executor:stream_executor",
"//tensorflow/c/experimental/stream_executor:stream_executor_internal",
],
}),
)
tf_cuda_library(
name = "ops",
srcs = [
"ops.cc",
],
hdrs = [
"ops.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":tf_datatype",
":tf_status",
":tf_status_helper",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
],
}),
alwayslink = 1,
)
cc_library(
name = "ops_hdrs",
hdrs = ["ops.h"],
visibility = ["//tensorflow:internal"],
deps = [
":tf_datatype",
":tf_status",
],
)
# -----------------------------------------------------------------------------
# Tests
tf_cuda_library(
name = "c_test_util",
testonly = 1,
srcs = ["c_test_util.cc"],
hdrs = ["c_test_util.h"],
visibility = [
"//learning/brain:__subpackages__",
"//tensorflow:__subpackages__",
],
deps = [
":c_api",
":c_api_experimental",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:session_options",
"//tensorflow/core:test",
],
)
tf_cc_test(
name = "c_test",
srcs = ["c_test.c"],
extra_copts = ["-std=c11"],
deps = [
":c_api",
":c_api_experimental",
":env",
":kernels",
],
)
tf_cuda_cc_test(
name = "c_api_test",
size = "small",
srcs = ["c_api_test.cc"],
data = [
":test_op.so",
":test_op1.so",
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
],
linkopts = select({
"//tensorflow:darwin": ["-headerpad_max_install_names"],
"//tensorflow:macos": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
tags = [
"no_windows", # TODO(b/155444728)
"noasan",
],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":c_api",
":c_test_util",
":test_op_kernel",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:grad_ops",
"//tensorflow/cc/saved_model:signature_constants",
"//tensorflow/cc/saved_model:tag_constants",
"//tensorflow/compiler/jit",
"//tensorflow/core:array_ops_op_lib",
"//tensorflow/core:bitwise_ops_op_lib",
"//tensorflow/core:control_flow_ops_op_lib",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:direct_session",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:functional_ops_op_lib",
"//tensorflow/core:lib",
"//tensorflow/core:proto_text",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:nn_ops_op_lib",
"//tensorflow/core:no_op_op_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:sendrecv_ops_op_lib",
"//tensorflow/core:spectral_ops_op_lib",
"//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:math",
"//third_party/eigen3",
"//tensorflow/core/platform:resource_loader",
],
)
tf_cc_test(
name = "c_api_experimental_test",
size = "medium",
srcs = ["c_api_experimental_test.cc"],
data = [
"testdata/tf_record",
"//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so",
],
linkopts = select({
"//tensorflow:macos": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
tags = ["notsan"], # b/149031034
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":c_api",
":c_api_experimental",
":c_api_internal",
":c_test_util",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:resource_loader",
"@com_google_absl//absl/types:optional",
],
)
tf_cc_test(
name = "c_api_function_test",
size = "small",
srcs = ["c_api_function_test.cc"],
deps = [
":c_api",
":c_api_internal",
":c_test_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "while_loop_test",
size = "small",
srcs = ["while_loop_test.cc"],
deps = [
":c_api",
":c_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_custom_op_library(
name = "test_op.so",
name = "test_op1.so",
srcs = ["test_op1.cc"],
)
tf_kernel_library(
name = "test_op_kernel",
srcs = ["test_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
alwayslink = 1,
)
tf_cuda_cc_test(
name = "env_test",
size = "small",
srcs = ["env_test.cc"],
linkopts = select({
"//tensorflow:macos": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":c_api",
":env",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cuda_cc_test(
name = "kernels_test",
size = "small",
srcs = ["kernels_test.cc"],
linkopts = select({
"//tensorflow:macos": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
tags = ["no_cuda_on_cpu_tap"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":c_api",
":kernels",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/kernels:ops_testutil",
"//third_party/eigen3",
"@com_google_absl//absl/container:inlined_vector",
],
)
tf_cc_test(
name = "ops_test",
size = "small",
srcs = ["ops_test.cc"],
linkopts = select({
"//conditions:default": [],
}),
tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":c_api",
":ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"@com_google_absl//absl/strings",
],
)
# -----------------------------------------------------------------------------
# Google-internal targets.
# Python API target
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
tf_cuda_library(
name = "python_api",
srcs = ["python_api.cc"],
hdrs = ["python_api.h"],
visibility = ["//tensorflow/python:__pkg__"],
deps = [
":c_api",
":c_api_internal",
# TODO(b/74620627): remove when _USE_C_SHAPES is removed
"//tensorflow/python:cpp_shape_inference_proto_cc",
],
alwayslink = 1,
)
cc_library(
name = "conversion_macros",
hdrs = [
"conversion_macros.h",
],
visibility = ["//tensorflow:__subpackages__"],
)

7
tensorflow/c/README.md Normal file
View File

@ -0,0 +1,7 @@
# TensorFlow C API
- See [www.tensorflow.org/install/lang_c](https://www.tensorflow.org/install/lang_c)
- Nightly builds:
- [Linux CPU-only](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow-cpu-linux-x86_64.tar.gz)
- [Linux GPU](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow-gpu-linux-x86_64.tar.gz)
- [MacOS CPU-only](https://storage.googleapis.com/tensorflow-nightly/github/tensorflow/lib_package/libtensorflow-cpu-darwin-x86_64.tar.gz)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,792 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/c_api_experimental.h"
#include "absl/strings/substitute.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
using tensorflow::FunctionDef;
using tensorflow::Node;
using tensorflow::NodeBuilder;
using tensorflow::Status;
using tensorflow::errors::InvalidArgument;
namespace {
typedef std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>
UniqueFuncPtr;
}
// struct TF_Operation { tensorflow::Node node; };
static TF_Operation* ToTF_Operation(Node* node) {
return static_cast<TF_Operation*>(static_cast<void*>(node));
}
void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
tensorflow::ConfigProto& config = options->options.config;
auto* optimizer_options =
config.mutable_graph_options()->mutable_optimizer_options();
if (enable) {
optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
// These XLA flags are needed to trigger XLA properly from C (more generally
// non-Python) clients. If this API is called again with `enable` set to
// false, it is safe to keep these flag values as is.
tensorflow::MarkForCompilationPassFlags* flags =
tensorflow::GetMarkForCompilationPassFlags();
flags->tf_xla_cpu_global_jit = true;
flags->tf_xla_min_cluster_size = 1;
} else {
optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
}
}
unsigned char TF_SetXlaEnableLazyCompilation(unsigned char enable) {
tensorflow::BuildXlaOpsPassFlags* flags =
tensorflow::GetBuildXlaOpsPassFlags();
bool original = flags->tf_xla_enable_lazy_compilation;
flags->tf_xla_enable_lazy_compilation = enable;
return original;
}
unsigned char TF_SetTfXlaCpuGlobalJit(unsigned char enable) {
tensorflow::MarkForCompilationPassFlags* flags =
tensorflow::GetMarkForCompilationPassFlags();
bool original = flags->tf_xla_cpu_global_jit;
flags->tf_xla_cpu_global_jit = static_cast<bool>(enable);
return static_cast<unsigned char>(original);
}
void TF_SetXlaAutoJitMode(const char* mode) {
tensorflow::SetXlaAutoJitFlagFromFlagString(mode);
}
unsigned char TF_GetXlaConstantFoldingDisabled() {
return static_cast<unsigned char>(
tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding);
}
void TF_SetXlaConstantFoldingDisabled(unsigned char should_enable) {
tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding =
static_cast<bool>(should_enable);
}
void TF_SetXlaMinClusterSize(int size) {
tensorflow::MarkForCompilationPassFlags* flags =
tensorflow::GetMarkForCompilationPassFlags();
flags->tf_xla_min_cluster_size = size;
}
TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
unsigned char gpu_memory_allow_growth,
unsigned int num_cpu_devices) {
tensorflow::ConfigProto config;
auto* optimizer_options =
config.mutable_graph_options()->mutable_optimizer_options();
if (enable_xla_compilation) {
optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
// These XLA flags are needed to trigger XLA properly from C (more generally
// non-Python) clients. If this API is called again with `enable` set to
// false, it is safe to keep these flag values as is.
tensorflow::MarkForCompilationPassFlags* flags =
tensorflow::GetMarkForCompilationPassFlags();
flags->tf_xla_cpu_global_jit = true;
flags->tf_xla_min_cluster_size = 1;
} else {
optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
}
auto* gpu_options = config.mutable_gpu_options();
gpu_options->set_allow_growth(gpu_memory_allow_growth);
(*config.mutable_device_count())["CPU"] = num_cpu_devices;
// TODO(b/113217601): This is needed for EagerContext::runner_ to use a
// threadpool, so that we avoid the possibility of running the runner_ in the
// threadpool of GPU event mgr, as that can trigger more callbacks to be
// scheduled on that same threadpool, causing a deadlock in cases where the
// caller of event_mgr->ThenExecute() blocks on the completion of the callback
// (as in the case of ConstOp kernel creation on GPU, which involves copying a
// CPU tensor to GPU).
// Setting a larger thread pool does not help with the Swift caller, as we use
// a different TFE context for each thread of execution (for running graph
// functions, and their send/recvs corountines).
config.set_inter_op_parallelism_threads(1);
TF_Buffer* ret = TF_NewBuffer();
TF_CHECK_OK(MessageToBuffer(config, ret));
return ret;
}
TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) {
tensorflow::RunOptions options;
if (enable_full_trace) {
options.set_trace_level(tensorflow::RunOptions::FULL_TRACE);
} else {
options.set_trace_level(tensorflow::RunOptions::NO_TRACE);
}
TF_Buffer* ret = TF_NewBuffer();
TF_CHECK_OK(MessageToBuffer(options, ret));
return ret;
}
const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
tensorflow::mutex_lock c(graph->mu);
const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();
*len = debug_str.size();
char* ret = static_cast<char*>(malloc(*len + 1));
memcpy(ret, debug_str.c_str(), *len + 1);
return ret;
}
char* TF_FunctionDebugString(TF_Function* func, size_t* len) {
const auto& debug_str = DebugString(func->fdef);
*len = debug_str.size();
char* ret = static_cast<char*>(malloc(*len + 1));
memcpy(ret, debug_str.c_str(), *len + 1);
return ret;
}
// On success, returns a set of TF_Function instances from `text_proto` of
// GraphDef type. These functions must be deleted by calling TF_DeleteFunction.
//
// If `mutate_proto_func` is non-NULL, run it over each FunctionDef proto,
// before creating a TF_Function out of the possibly mutated proto.
static std::vector<UniqueFuncPtr> CreateFunctionsFromTextProto(
const char* text_proto,
std::function<void(FunctionDef*)>* mutate_proto_func, TF_Status* status) {
tensorflow::GraphDef gdef;
if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &gdef)) {
status->status = tensorflow::errors::Internal(
"Invalid text proto for GraphDef: ", text_proto);
return {};
}
const auto& fdef_lib = gdef.library();
if (fdef_lib.gradient_size() > 0) {
status->status = tensorflow::errors::Internal(
"GradientDef is not supported in reading Dataset related functions: ",
text_proto);
return {};
}
std::vector<UniqueFuncPtr> ret;
for (const FunctionDef& fdef : fdef_lib.function()) {
// Make a copy so that we can mutate it.
FunctionDef fdef_to_load = fdef;
if (mutate_proto_func) {
(*mutate_proto_func)(&fdef_to_load);
}
VLOG(1) << "Adding func to graph: " << fdef_to_load.DebugString();
std::vector<char> binary_proto_buf(fdef_to_load.ByteSizeLong());
fdef_to_load.SerializeToArray(binary_proto_buf.data(),
binary_proto_buf.size());
TF_Function* func = TF_FunctionImportFunctionDef(
binary_proto_buf.data(), binary_proto_buf.size(), status);
if (!status->status.ok()) return {};
ret.push_back(UniqueFuncPtr(func, TF_DeleteFunction));
}
return ret;
}
TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
TF_Status* status) {
assert(session);
{
tensorflow::mutex_lock c(session->graph->mu);
VLOG(1) << "Dequeuing named tensor with id " << tensor_id
<< ", with input graph: "
<< session->graph->graph.ToGraphDefDebug().DebugString();
}
TF_Operation* dequeue_op = TF_GraphOperationByName(
session->graph,
tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str());
if (dequeue_op == nullptr) {
status->status = tensorflow::errors::Internal(
"Unable to find the dequeue node in the TF graph.");
return nullptr;
}
VLOG(1) << "Running the dequeue op";
TF_Output output{dequeue_op, 0};
TF_Tensor* ret;
TF_SessionRun(session, /*run_options*/ nullptr,
// input related parameters
/*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
// output related parameters
/*outputs*/ &output, /*output_values*/ &ret,
/*noutputs*/ 1,
/*targets*/ nullptr, /*ntargets*/ 0,
/*run_metadata*/ nullptr, status);
if (VLOG_IS_ON(1) && status->status.ok()) {
tensorflow::Tensor tensor;
if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) {
VLOG(1) << "Dequeued tensor content: " << tensor.DebugString();
}
}
return ret;
}
void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
TF_Tensor* tensor, TF_Status* status) {
assert(session);
{
tensorflow::mutex_lock c(session->graph->mu);
if (VLOG_IS_ON(1)) {
VLOG(1) << "Enqueuing named tensor with id " << tensor_id
<< ", with input graph: "
<< session->graph->graph.ToGraphDefDebug().DebugString();
tensorflow::Tensor internal_tensor;
if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
VLOG(1) << "Enqueu'ing tensor content: "
<< internal_tensor.DebugString();
}
}
}
TF_Operation* enqueue_op = TF_GraphOperationByName(
session->graph,
tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
if (enqueue_op == nullptr) {
status->status = tensorflow::errors::Internal(
"Unable to find the enqueue node in the TF graph.");
return;
}
TF_Operation* placeholder_op = TF_GraphOperationByName(
session->graph,
tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
if (placeholder_op == nullptr) {
status->status = tensorflow::errors::Internal(
"Unable to find the placeholder node as input to enqueue in the TF "
"graph.");
return;
}
VLOG(1) << "Running the enqueue op";
TF_Output input{placeholder_op, 0};
TF_SessionRun(session, /*run_options*/ nullptr,
// input related parameters
/*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
// output related parameters
/*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
/*targets*/ &enqueue_op, /*ntargets*/ 1,
/*run_metadata*/ nullptr, status);
VLOG(1) << "Enqueuing is done.";
}
TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
tensorflow::ServerDef server_def;
if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto,
&server_def)) {
status->status = tensorflow::errors::Internal(
"Invalid text proto for ServerDef: ", text_proto);
return nullptr;
}
status->status = tensorflow::Status();
TF_Buffer* ret = TF_NewBuffer();
TF_CHECK_OK(MessageToBuffer(server_def, ret));
return ret;
}
void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
status->status = tensorflow::errors::Internal(errMsg);
}
struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
using tensorflow::checkpoint::CheckpointReader::CheckpointReader;
std::vector<std::string> variable_list;
};
TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
TF_Status* status) {
TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
if (!status->status.ok()) {
TF_DeleteCheckpointReader(reader);
return nullptr;
}
const auto& m = reader->GetVariableToDataTypeMap();
for (auto it = m.begin(); it != m.end(); ++it)
reader->variable_list.push_back(it->first);
std::sort(reader->variable_list.begin(), reader->variable_list.end());
return reader;
}
void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; }
int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader,
const char* name) {
return reader->HasTensor(name);
}
const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader,
int index) {
return reader->variable_list[index].c_str();
}
int TF_CheckpointReaderSize(TF_CheckpointReader* reader) {
return reader->variable_list.size();
}
TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader,
const char* name) {
const auto& m = reader->GetVariableToDataTypeMap();
return static_cast<TF_DataType>(m.at(name));
}
TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
const char* name, TF_Status* status) {
std::unique_ptr<tensorflow::Tensor> tensor;
reader->GetTensor(name, &tensor, status);
if (!status->status.ok()) return nullptr;
return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
}
void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
const char* name, int64_t* dims,
int num_dims, TF_Status* status) {
const auto& shape = reader->GetVariableToShapeMap().at(name);
int rank = shape.dims();
if (num_dims != rank) {
status->status = InvalidArgument("Expected rank is ", num_dims,
" but actual rank is ", rank);
return;
}
for (int i = 0; i < num_dims; i++) {
dims[i] = shape.dim_size(i);
}
}
int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader,
const char* name) {
const auto& m = reader->GetVariableToShapeMap();
return m.at(name).dims();
}
// This builder is used in the eager API to build a NodeDef.
struct TF_AttrBuilder : public tensorflow::AttrBuilder {
using tensorflow::AttrBuilder::AttrBuilder;
// The string buffers to make sure that any `attr_name` we pass into
// `builder->Set()` will outlive the subsequent
// `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on the same `builder`.
std::set<std::string> attr_names;
};
TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) {
return new TF_AttrBuilder(op_name);
}
void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; }
void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
TF_DataType value) {
auto iter = builder->attr_names.insert(attr_name).first;
builder->Set(*iter, static_cast<tensorflow::DataType>(value));
}
void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
const TF_DataType* values, int num_values) {
auto iter = builder->attr_names.insert(attr_name).first;
builder->Set(*iter, tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
reinterpret_cast<const tensorflow::DataType*>(values),
num_values));
}
void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
const char* device_type,
TF_Status* status) {
status->status = tensorflow::FindKernelDef(
tensorflow::DeviceType(device_type), builder->BuildNodeDef(),
/* def = */ nullptr, /* kernel_class_name = */ nullptr);
}
const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
TF_Status* status) {
const tensorflow::OpDef* op_def = nullptr;
status->status =
tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def);
if (!status->status.ok()) return nullptr;
if (input_index >= op_def->input_arg_size() || input_index < 0) {
status->status = tensorflow::errors::InvalidArgument(
input_index, " out of range for ", op_name);
return nullptr;
}
const tensorflow::OpDef_ArgDef& input_arg = op_def->input_arg()[input_index];
if (input_arg.number_attr().empty()) {
status->status = tensorflow::errors::NotFound(
op_name, " does not have number_attr() defined.");
return nullptr;
}
// The returned string is owned by OpRegistry, so liveness is not a concern.
return input_arg.number_attr().c_str();
}
int TF_OpIsStateful(const char* op_type, TF_Status* status) {
const tensorflow::OpRegistrationData* op_reg_data;
status->status =
tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data);
if (!status->status.ok()) {
return 0;
}
return op_reg_data->op_def.is_stateful();
}
void TF_InitMain(const char* usage, int* argc, char*** argv) {
tensorflow::port::InitMain(usage, argc, argv);
}
int TF_PickUnusedPortOrDie() {
return tensorflow::internal::PickUnusedPortOrDie();
}
TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
void* data, size_t len,
TF_Status* status) {
auto dtype = static_cast<tensorflow::DataType>(data_type);
DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));
tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
status->status = tensorflow::Status::OK();
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
}
namespace {
tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
TFE_Context* ctx) {
// We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
// server object (which currently CHECK-fails) and we miss the error, instead,
// we log the error, and then return to allow the user to see the error
// message.
#define LOG_AND_RETURN_IF_ERROR(...) \
do { \
const ::tensorflow::Status _status = (__VA_ARGS__); \
if (TF_PREDICT_FALSE(!_status.ok())) { \
LOG(ERROR) << _status.error_message(); \
return _status; \
} \
} while (0);
// New server created for new server_def. Unused if updating server_def.
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
if (grpc_server == nullptr) {
std::unique_ptr<tensorflow::ServerInterface> new_server;
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
if (grpc_server == nullptr) {
LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
"Currently, TFE_NewContext only supports tensorflow::GrpcServer."));
}
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
std::move(new_server), grpc_server->worker_env()->device_mgr,
grpc_server->worker_env()->collective_executor_mgr.get()));
} else {
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
/*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
grpc_server->worker_env()->collective_executor_mgr.get()));
}
return tensorflow::Status::OK();
#undef LOG_AND_RETURN_IF_ERROR
}
} // namespace
// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
const void* proto,
size_t proto_len,
TF_Status* status) {
tensorflow::ServerDef server_def;
if (!server_def.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
"Invalid tensorflow.ServerDef protocol buffer");
return;
}
status->status = EnableCollectiveOps(server_def, ctx);
}
TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto collective_executor_handle = context->GetCollectiveExecutorHandle();
collective_executor_handle->get()->StartAbort(status->status);
}
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto collective_executor_handle = context->GetCollectiveExecutorHandle();
tensorflow::Notification done;
collective_executor_handle->get()->remote_access()->CheckPeerHealth(
task, timeout_in_ms, [&done, status](const Status& s) {
status->status = s;
done.Notify();
});
done.WaitForNotification();
}
TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
result->num_items = num_items;
result->items = (num_items == 0) ? nullptr : new TF_ShapeAndType[num_items]();
return result;
}
void TF_ShapeAndTypeListSetShape(TF_ShapeAndTypeList* shape_list, int index,
const int64_t* dims, int num_dims) {
DCHECK(index >= 0 && index < shape_list->num_items);
TF_ShapeAndType& shape = shape_list->items[index];
DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
DCHECK(num_dims >= 0) << "Number of dimensions cannot be negative!";
shape.num_dims = num_dims;
shape.dims = new int64_t[num_dims];
memcpy(shape.dims, dims, sizeof(int64_t) * num_dims);
}
void TF_ShapeAndTypeListSetUnknownShape(TF_ShapeAndTypeList* shape_list,
int index) {
DCHECK(index >= 0 && index < shape_list->num_items);
TF_ShapeAndType& shape = shape_list->items[index];
DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
shape.num_dims = -1;
shape.dims = nullptr;
}
void TF_ShapeAndTypeListSetDtype(TF_ShapeAndTypeList* shape_list, int index,
TF_DataType dtype) {
DCHECK(index >= 0 && index < shape_list->num_items);
TF_ShapeAndType& shape_and_type = shape_list->items[index];
shape_and_type.dtype = dtype;
}
void TF_DeleteShapeAndTypeList(TF_ShapeAndTypeList* shape_list) {
if (shape_list == nullptr) return;
for (size_t i = 0; i < shape_list->num_items; ++i) {
delete[] shape_list->items[i].dims;
}
delete[] shape_list->items;
delete shape_list;
}
void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
int num_items) {
if (shape_list_array == nullptr) return;
for (int i = 0; i < num_items; ++i) {
TF_DeleteShapeAndTypeList(shape_list_array[i]);
}
delete[] shape_list_array;
}
namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
// Helpers for loadding a TensorFlow PluggableDevice plugin (a .so file).
Status LoadPluggableDeviceLibrary(const char* library_filename, void** result);
} // namespace tensorflow
void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
TF_Tensor** input_tensors,
TF_ShapeAndTypeList* input_tensors_as_shapes,
TF_ShapeAndTypeList** input_resource_shapes_and_types,
TF_ShapeAndTypeList** output_shapes,
TF_ShapeAndTypeList*** output_resource_shapes_and_types,
TF_Status* status) {
using tensorflow::NodeDef;
using tensorflow::OpRegistrationData;
using tensorflow::Tensor;
using tensorflow::shape_inference::DimensionHandle;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeAndType;
using tensorflow::shape_inference::ShapeHandle;
const int num_inputs = input_shapes->num_items;
NodeDef node_def;
tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op);
node_def.set_name(op->Name());
node_def.set_op(op->Name());
for (int i = 0; i < num_inputs; ++i) {
node_def.add_input("dummy_input");
}
OperationFromInterface(op)->Attrs().FillAttrValueMap(node_def.mutable_attr());
const tensorflow::OpRegistrationData* op_reg_data;
status->status =
tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
if (!status->status.ok()) return;
// Initialize a input_tensor vector with `nullptr` values.
std::vector<const Tensor*> input_tensors_vector(num_inputs, nullptr);
// A vector to keep track of newly created `tf::Tensor` objects.
std::vector<Tensor> all_input_tensors;
// Update the vector with information from `input_tensors` if provided.
if (input_tensors != nullptr) {
// Note that we take the address of the elements in `all_input_tensors`
// below. Allocate enough space so that no reallocation happens, which will
// make the pointers invalid.
all_input_tensors.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
if (input_tensors[i] == nullptr) continue;
all_input_tensors.emplace_back();
Tensor& input_tensor = all_input_tensors.back();
status->status = TF_TensorToTensor(input_tensors[i], &input_tensor);
if (!status->status.ok()) return;
input_tensors_vector[i] = &input_tensor;
}
}
// Create an inference context with dummy values, which will be updated later.
InferenceContext c(TF_GRAPH_DEF_VERSION, node_def, op_reg_data->op_def,
std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
{},
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());
// Set input_shapes.
for (int i = 0; i < num_inputs; ++i) {
std::vector<DimensionHandle> dims;
const TF_ShapeAndType& input_shape = input_shapes->items[i];
if (input_shape.num_dims == InferenceContext::kUnknownRank) {
c.SetInput(i, c.UnknownShape());
continue;
}
for (int j = 0; j < input_shape.num_dims; ++j) {
dims.push_back(c.MakeDim(input_shape.dims[j]));
}
c.SetInput(i, c.MakeShape(dims));
}
// TODO(bgogul): Handle input_tensors_as_shapes.
// TODO(bgogul): Handle input_resource_shapes_and_types.
status->status = c.construction_status();
if (!status->status.ok()) return;
if (op_reg_data->shape_inference_fn == nullptr) {
status->status =
InvalidArgument("No shape inference function exists for op '",
node_def.op(), "', did you forget to define it?");
return;
}
status->status = c.Run(op_reg_data->shape_inference_fn);
if (!status->status.ok()) return;
// Set output_shapes.
TF_ShapeAndTypeList* output_shapes_result =
TF_NewShapeAndTypeList(c.num_outputs());
for (int i = 0; i < c.num_outputs(); ++i) {
ShapeHandle shape_handle = c.output(i);
TF_ShapeAndType& shape = output_shapes_result->items[i];
shape.num_dims = c.Rank(shape_handle);
if (shape.num_dims == InferenceContext::kUnknownRank) {
shape.dims = nullptr;
continue;
}
shape.dims = new int64_t[shape.num_dims];
for (size_t j = 0; j < shape.num_dims; ++j) {
shape.dims[j] = c.Value(c.Dim(shape_handle, j));
}
}
if (output_shapes != nullptr) *output_shapes = output_shapes_result;
// TODO(bgogul): Set output_resource_shapes_and_types.
}
void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
TF_ImportGraphDefOptions* opts, unsigned char enable) {
opts->opts.validate_colocation_constraints = enable;
}
// Load a Pluggable Device library.
// On success, returns the handle to library in result and return OK from the
// function. Otherwise return nullptr in result and error Status from the
// function.
//
// If `library_filename` has already been loaded, we return a cached handle.
// Device and Kernels/Ops are registered as globals when a library is loaded
// for the first time.
TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename,
TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
status->status = tensorflow::errors::Unimplemented(
"PluggableDevice plugin functionality is not supported on mobile");
return nullptr;
#else
TF_Library* lib_handle = new TF_Library;
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
static std::unordered_map<std::string, void*>* loaded_libs =
new std::unordered_map<std::string, void*>();
tensorflow::Env* env = tensorflow::Env::Default();
{
tensorflow::mutex_lock lock(mu);
auto it = loaded_libs->find(library_filename);
if (it != loaded_libs->end()) {
lib_handle->lib_handle = it->second;
} else {
status->status =
env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle);
if (!status->status.ok()) {
delete lib_handle;
return nullptr;
}
}
return lib_handle;
}
#endif
}
void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) {
delete lib_handle;
}

View File

@ -0,0 +1,332 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_C_API_EXPERIMENTAL_H_
#define TENSORFLOW_C_C_API_EXPERIMENTAL_H_
#include <stddef.h>
#include <stdint.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
// --------------------------------------------------------------------------
// Experimental C API for TensorFlow.
//
// The API here is subject to changes in the future.
// --------------------------------------------------------------------------
// Macro to control visibility of exported symbols in the shared library (.so,
// .dylib, .dll).
// This duplicates the TF_EXPORT macro definition in
// tensorflow/core/platform/macros.h in order to keep this .h file independent
// of any other includes.$a
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // SWIG
#ifdef __cplusplus
extern "C" {
#endif
// When `enable` is true, set
// tensorflow.ConfigProto.OptimizerOptions.global_jit_level to ON_1, and also
// set XLA flag values to prepare for XLA compilation. Otherwise set
// global_jit_level to OFF.
//
// This and the next API are syntax sugar over TF_SetConfig(), and is used by
// clients that cannot read/write the tensorflow.ConfigProto proto.
// TODO: Migrate to TF_CreateConfig() below.
TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options,
unsigned char enable);
// Set XLA's internal BuildXlaOpsPassFlags.tf_xla_enable_lazy_compilation to the
// value of 'enabled'. Also returns the original value of that flag.
//
// Use in tests to allow XLA to fallback to TF classic. This has global effect.
TF_CAPI_EXPORT unsigned char TF_SetXlaEnableLazyCompilation(
unsigned char enable);
TF_CAPI_EXPORT unsigned char TF_SetTfXlaCpuGlobalJit(unsigned char enable);
// Sets XLA's auto jit mode according to the specified string, which is parsed
// as if passed in XLA_FLAGS. This has global effect.
TF_CAPI_EXPORT void TF_SetXlaAutoJitMode(const char* mode);
// Sets XLA's minimum cluster size. This has global effect.
TF_CAPI_EXPORT void TF_SetXlaMinClusterSize(int size);
// Gets/Sets TF/XLA flag for whether(true) or not(false) to disable constant
// folding. This is for testing to ensure that XLA is being tested rather than
// Tensorflow's CPU implementation through constant folding.
TF_CAPI_EXPORT unsigned char TF_GetXlaConstantFoldingDisabled();
TF_CAPI_EXPORT void TF_SetXlaConstantFoldingDisabled(
unsigned char should_enable);
// Create a serialized tensorflow.ConfigProto proto, where:
//
// a) ConfigProto.optimizer_options.global_jit_level is set to ON_1 if
// `enable_xla_compilation` is non-zero, and OFF otherwise.
// b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`.
// c) ConfigProto.device_count is set to `num_cpu_devices`.
TF_CAPI_EXPORT extern TF_Buffer* TF_CreateConfig(
unsigned char enable_xla_compilation, unsigned char gpu_memory_allow_growth,
unsigned int num_cpu_devices);
// Create a serialized tensorflow.RunOptions proto, where RunOptions.trace_level
// is set to FULL_TRACE if `enable_full_trace` is non-zero, and NO_TRACE
// otherwise.
TF_CAPI_EXPORT extern TF_Buffer* TF_CreateRunOptions(
unsigned char enable_full_trace);
// Returns the graph content in a human-readable format, with length set in
// `len`. The format is subject to change in the future.
// The returned string is heap-allocated, and caller should call free() on it.
TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph,
size_t* len);
// Returns the function content in a human-readable format, with length set in
// `len`. The format is subject to change in the future.
// The returned string is heap-allocated, and caller should call free() on it.
//
// Do not return const char*, because some foreign language binding
// (e.g. swift) cannot then call free() on the returned pointer.
TF_CAPI_EXPORT extern char* TF_FunctionDebugString(TF_Function* func,
size_t* len);
// On success, dequeues a tensor from a TF-managed FifoQueue given by
// `tensor_id`, associated with `session`. There must be a graph node named
// "fifo_queue_dequeue_<tensor_id>", to be executed by this API call.
// Caller must call TF_DeleteTensor() over the returned tensor. If the queue is
// empty, this call is blocked.
//
// Tensors are enqueued via the corresponding TF enqueue op.
// TODO(hongm): Add support for `timeout_ms`.
TF_CAPI_EXPORT extern TF_Tensor* TF_DequeueNamedTensor(TF_Session* session,
int tensor_id,
TF_Status* status);
// On success, enqueues `tensor` into a TF-managed FifoQueue given by
// `tensor_id`, associated with `session`. There must be a graph node named
// "fifo_queue_enqueue_<tensor_id>", to be executed by this API call. It reads
// from a placeholder node "arg_tensor_enqueue_<tensor_id>".
//
// `tensor` is still owned by the caller. This call will be blocked if the queue
// has reached its capacity, and will be unblocked when the queued tensors again
// drop below the capacity due to dequeuing.
//
// Tensors are dequeued via the corresponding TF dequeue op.
// TODO(hongm): Add support for `timeout_ms`.
TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
int tensor_id,
TF_Tensor* tensor,
TF_Status* status);
// Create a serialized tensorflow.ServerDef proto.
TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status);
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
const char* errMsg);
// TF_NewCheckpointReader() return the CheckpointReader that can be use to
// investigate or load the variable from the checkpoint file
typedef struct TF_CheckpointReader TF_CheckpointReader;
TF_CAPI_EXPORT extern TF_CheckpointReader* TF_NewCheckpointReader(
const char* filename, TF_Status* status);
TF_CAPI_EXPORT extern void TF_DeleteCheckpointReader(
TF_CheckpointReader* reader);
TF_CAPI_EXPORT extern int TF_CheckpointReaderHasTensor(
TF_CheckpointReader* reader, const char* name);
// Get the variable name at the given index
TF_CAPI_EXPORT extern const char* TF_CheckpointReaderGetVariable(
TF_CheckpointReader* reader, int index);
// Get the number of variable in the checkpoint
TF_CAPI_EXPORT extern int TF_CheckpointReaderSize(TF_CheckpointReader* reader);
// Get the DataType of a variable
TF_CAPI_EXPORT extern TF_DataType TF_CheckpointReaderGetVariableDataType(
TF_CheckpointReader* reader, const char* name);
// Read the shape of a variable and write to `dims`
TF_CAPI_EXPORT extern void TF_CheckpointReaderGetVariableShape(
TF_CheckpointReader* reader, const char* name, int64_t* dims, int num_dims,
TF_Status* status);
// Get the number of dimension of a variable
TF_CAPI_EXPORT extern int TF_CheckpointReaderGetVariableNumDims(
TF_CheckpointReader* reader, const char* name);
// Load the weight of a variable
TF_CAPI_EXPORT extern TF_Tensor* TF_CheckpointReaderGetTensor(
TF_CheckpointReader* reader, const char* name, TF_Status* status);
// TF_NewAttrBuilder() returns an object that you can set attributes on as
// though it were an op. This allows querying properties of that op for
// type-checking purposes like if the op will run on a particular device type.
typedef struct TF_AttrBuilder TF_AttrBuilder;
TF_CAPI_EXPORT extern TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name);
TF_CAPI_EXPORT extern void TF_DeleteAttrBuilder(TF_AttrBuilder* builder);
TF_CAPI_EXPORT extern void TF_AttrBuilderSetType(TF_AttrBuilder* builder,
const char* attr_name,
TF_DataType value);
TF_CAPI_EXPORT extern void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder,
const char* attr_name,
const TF_DataType* values,
int num_values);
// Checks the tensorflow::NodeDef built via the methods above to see if it can
// run on device_type.
TF_CAPI_EXPORT extern void TF_AttrBuilderCheckCanRunOnDevice(
TF_AttrBuilder* builder, const char* device_type, TF_Status* status);
// For argument number input_index, fetch the corresponding number_attr that
// needs to be updated with the argument length of the input list.
// Returns nullptr if there is any problem like op_name is not found, or the
// argument does not support this attribute type.
TF_CAPI_EXPORT extern const char* TF_GetNumberAttrForOpListInput(
const char* op_name, int input_index, TF_Status* status);
// Returns 1 if the op is stateful, 0 otherwise. The return value is undefined
// if the status is not ok.
TF_CAPI_EXPORT extern int TF_OpIsStateful(const char* op_type,
TF_Status* status);
// Platform specific initialization routine. Very few platforms actually require
// this to be called.
TF_CAPI_EXPORT void TF_InitMain(const char* usage, int* argc, char*** argv);
// Platform-specific implementation to return an unused port. (This should used
// in tests only.)
TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(void);
// Fast path method that makes constructing a single scalar tensor require less
// overhead and copies.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromScalar(
TF_DataType data_type, void* data, size_t len, TF_Status* status);
// Specify the server_def that enables collective ops.
// This is different to the above function in that it doesn't create remote
// contexts, and remotely executing ops is not possible. It just enables
// communication for collective ops.
TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
const void* proto,
size_t proto_len,
TF_Status* status);
// Aborts all ongoing collectives with the specified status. After abortion,
// subsequent collectives will error with this status immediately. To reset the
// collectives, create a new EagerContext.
//
// This is intended to be used when a peer failure is detected.
TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
TF_Status* status);
// Checks the health of collective ops peers. Explicit health check is needed in
// multi worker collective ops to detect failures in the cluster. If a peer is
// down, collective ops may hang.
TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth(
TFE_Context* ctx, const char* task, int64_t timeout_in_ms,
TF_Status* status);
// Information about the shape of a Tensor and its type.
struct TF_ShapeAndType {
// Number of dimensions. -1 indicates unknown rank.
int num_dims;
// Array of dimensions. -1 indicates unknown dim.
int64_t* dims;
// The data type. May be 0 to denote unknown type.
TF_DataType dtype;
};
typedef struct TF_ShapeAndType TF_ShapeAndType;
// A list of TF_ShapeAndType elements..
struct TF_ShapeAndTypeList {
int num_items;
TF_ShapeAndType* items;
};
typedef struct TF_ShapeAndTypeList TF_ShapeAndTypeList;
// API for manipulating TF_ShapeAndTypeList objects.
//
TF_CAPI_EXPORT extern TF_ShapeAndTypeList* TF_NewShapeAndTypeList(
int num_shapes);
TF_CAPI_EXPORT extern void TF_ShapeAndTypeListSetShape(
TF_ShapeAndTypeList* shape_list, int index, const int64_t* dims,
int num_dims);
TF_CAPI_EXPORT extern void TF_ShapeAndTypeListSetUnknownShape(
TF_ShapeAndTypeList* shape_list, int index);
TF_CAPI_EXPORT extern void TF_ShapeAndTypeListSetDtype(
TF_ShapeAndTypeList* shape_list, int index, TF_DataType dtype);
TF_CAPI_EXPORT extern void TF_DeleteShapeAndTypeList(
TF_ShapeAndTypeList* shape_list);
TF_CAPI_EXPORT extern void TF_DeleteShapeAndTypeListArray(
TF_ShapeAndTypeList** shape_list_array, int num_items);
// Infer shapes for the given `op`. The arguments mimic the arguments of the
// `shape_inference::InferenceContext` constructor. Note the following:
// - The inputs of the `op` are not used for shape inference. So, it is
// OK to not have the inputs properly set in `op`. See `input_tensors`
// if you want shape inference to consider the input tensors of the
// op for shape inference.
// - The types need not be set in `input_shapes` as it is not used.
// - The number of `input_tensors` should be the same as the number of items
// in `input_shapes`.
//
// The results are returned in `output_shapes` and
// `output_resource_shapes_and_types`. The caller is responsible for freeing the
// memory in these buffers by calling `TF_DeleteShapeAndTypeList`.
TF_CAPI_EXPORT extern void TFE_InferShapes(
TFE_Op* op, TF_ShapeAndTypeList* input_shapes, TF_Tensor** input_tensors,
TF_ShapeAndTypeList* input_tensor_as_shapes,
TF_ShapeAndTypeList** input_resource_shapes_and_types,
TF_ShapeAndTypeList** output_shapes,
TF_ShapeAndTypeList*** output_resource_shapes_and_types, TF_Status* status);
TF_CAPI_EXPORT extern void
TF_ImportGraphDefOptionsSetValidateColocationConstraints(
TF_ImportGraphDefOptions* opts, unsigned char enable);
// Load the library specified by library_filename and register the pluggable
// device and related kernels present in that library. This function is not
// supported on embedded on mobile and embedded platforms and will fail if
// called.
//
// Pass "library_filename" to a platform-specific mechanism for dynamically
// loading a library. The rules for determining the exact location of the
// library are platform-specific and are not documented here.
//
// On success, returns the newly created library handle and places OK in status.
// The caller owns the library handle.
//
// On failure, returns nullptr and places an error status in status.
TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary(
const char* library_filename, TF_Status* status);
// Frees the memory associated with the library handle.
// Does NOT unload the library.
TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle(
TF_Library* lib_handle);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_C_API_EXPERIMENTAL_H_

View File

@ -0,0 +1,256 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/c_api_experimental.h"
#include "absl/types/optional.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace tensorflow {
namespace {
TEST(CAPI_EXPERIMENTAL, GetServerDefTest) {
const string expected_text_proto(R"(cluster {
job {
name: "worker"
tasks {
key: 0
value: "tpuserver:0"
}
tasks {
key: 1
value: "localhost:1"
}
}
}
job_name: "worker"
task_index: 1
protocol: "grpc"
)");
TF_Status* status = TF_NewStatus();
TF_Buffer* result = TFE_GetServerDef(expected_text_proto.c_str(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK);
ServerDef actual;
ASSERT_TRUE(actual.ParseFromArray(result->data, result->length));
string actual_text_proto;
tensorflow::protobuf::TextFormat::PrintToString(actual, &actual_text_proto);
EXPECT_EQ(expected_text_proto, actual_text_proto);
const string malformed_text_proto(R"(cluster {
job {
name: "worker")");
TF_Buffer* null_result =
TFE_GetServerDef(malformed_text_proto.c_str(), status);
EXPECT_NE(TF_GetCode(status), TF_OK);
EXPECT_TRUE(absl::StrContains(TF_Message(status),
"Invalid text proto for ServerDef"));
EXPECT_EQ(null_result, nullptr);
// Cleanup
TF_DeleteBuffer(result);
TF_DeleteStatus(status);
}
TEST(CAPI_EXPERIMENTAL, IsStateful) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
int assign = TF_OpIsStateful("AssignAddVariableOp", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
EXPECT_EQ(assign, 1);
int id = TF_OpIsStateful("Identity", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
EXPECT_EQ(id, 0);
}
class ShapeInferenceTest : public ::testing::Test {
protected:
ShapeInferenceTest()
: status_(TF_NewStatus()), tfe_context_options_(TFE_NewContextOptions()) {
tfe_context_ = TFE_NewContext(tfe_context_options_, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
}
~ShapeInferenceTest() override {
TFE_DeleteContextOptions(tfe_context_options_);
TFE_DeleteContext(tfe_context_);
TF_DeleteStatus(status_);
}
// Checks the expected result of shape inference for the given `op`.
void CheckOutputShapes(
TFE_Op* op,
const std::vector<absl::optional<std::vector<int64_t>>>& input_shapes_vec,
const std::vector<TF_Tensor*>& input_tensors,
const absl::optional<std::vector<int64_t>>& expected_shape) {
// Create input_shapes.
TF_ShapeAndTypeList* input_shapes =
TF_NewShapeAndTypeList(input_shapes_vec.size());
for (size_t i = 0; i < input_shapes_vec.size(); ++i) {
const auto& input_shape = input_shapes_vec[i];
if (input_shape.has_value()) {
TF_ShapeAndTypeListSetShape(input_shapes, i, input_shape->data(),
input_shape->size());
} else {
TF_ShapeAndTypeListSetUnknownShape(input_shapes, i);
}
}
TF_ShapeAndTypeList* output_shapes;
TFE_InferShapes(op, input_shapes,
input_tensors.empty()
? nullptr
: const_cast<TF_Tensor**>(input_tensors.data()),
/*input_tensors_as_shapes*/ nullptr,
/*input_resource_shapes_and_types*/ nullptr, &output_shapes,
/*output_resource_shapes_and_types*/ nullptr, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
CHECK_EQ(output_shapes->num_items, 1);
int num_dims = output_shapes->items[0].num_dims;
int64_t* dims = output_shapes->items[0].dims;
if (!expected_shape.has_value()) {
EXPECT_EQ(num_dims, -1);
EXPECT_EQ(dims, nullptr);
return;
}
EXPECT_EQ(num_dims, expected_shape->size());
for (size_t i = 0; i < num_dims; ++i) {
EXPECT_EQ(dims[i], (*expected_shape)[i]);
}
TF_DeleteShapeAndTypeList(input_shapes);
TF_DeleteShapeAndTypeList(output_shapes);
}
absl::optional<std::vector<int64_t>> make_shape(
std::vector<int64_t>&& dims) const {
return absl::make_optional(dims);
}
absl::optional<std::vector<int64_t>> unknown_shape() const {
return absl::nullopt;
}
static constexpr int64_t kUnknownDim =
shape_inference::InferenceContext::kUnknownDim;
TF_Status* status_;
TFE_ContextOptions* tfe_context_options_;
TFE_Context* tfe_context_;
};
TEST_F(ShapeInferenceTest, InfersShapesFromInputShapes) {
TFE_Op* matmul_op;
matmul_op = TFE_NewOp(tfe_context_, "MatMul", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
// Infer shape when everything is known.
CheckOutputShapes(matmul_op,
/*input_shapes*/ {make_shape({3, 2}), make_shape({2, 4})},
/*input_tensors*/ {},
/*expected_shape*/ make_shape({3, 4}));
// Infer shape when second operand has unknown shape.
CheckOutputShapes(matmul_op,
/*input_shapes*/ {make_shape({3, 2}), unknown_shape()},
/*input_tensors*/ {},
/*expected_shape*/ make_shape({3, kUnknownDim}));
// Infer shape when some dimensions are unknown.
CheckOutputShapes(
matmul_op,
/*input_shapes*/ {make_shape({kUnknownDim, 2}), make_shape({2, 4})},
/*input_tensors*/ {},
/*expected_shape*/ make_shape({kUnknownDim, 4}));
// Infer shape when everything is unknown.
CheckOutputShapes(matmul_op,
/*input_shapes*/ {unknown_shape(), unknown_shape()},
/*input_tensors*/ {},
/*expected_shape*/ make_shape({kUnknownDim, kUnknownDim}));
TFE_DeleteOp(matmul_op);
// TODO(bgogul): Add some death tests where status is not OK.
}
TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) {
// Prepare some tensors for shape.
TF_Tensor* tensor_1X6 = Int32Tensor({1, 6});
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TF_Tensor* tensor_1X1X6 = Int32Tensor({1, 1, 6});
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_Op* reshape_op = TFE_NewOp(tfe_context_, "Reshape", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_OpSetAttrType(reshape_op, "T", TF_FLOAT);
TFE_OpSetAttrType(reshape_op, "Tshape", TF_INT32);
CheckOutputShapes(reshape_op,
/* input_shapes*/ {unknown_shape(), unknown_shape()},
/* input_tensors*/ {nullptr, tensor_1X6},
/*expected_shape*/ make_shape({1, 6}));
TFE_DeleteOp(reshape_op);
reshape_op = nullptr;
TFE_Op* fill_op = TFE_NewOp(tfe_context_, "Fill", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
TFE_OpSetAttrType(fill_op, "T", TF_FLOAT);
TFE_OpSetAttrType(fill_op, "Tshape", TF_INT32);
float five = 5.0;
TFE_TensorHandle* scalar = TestScalarTensorHandle(tfe_context_, five);
TF_Tensor* scalarTensor = TFE_TensorHandleResolve(scalar, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
CheckOutputShapes(fill_op,
/* input_shapes*/ {unknown_shape(), unknown_shape()},
/* input_tensors*/ {tensor_1X1X6, scalarTensor},
/*expected_shape*/ make_shape({1, 1, 6}));
TFE_DeleteOp(fill_op);
fill_op = nullptr;
TFE_DeleteTensorHandle(scalar);
TF_DeleteTensor(scalarTensor);
TF_DeleteTensor(tensor_1X1X6);
TF_DeleteTensor(tensor_1X6);
}
TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) {
#if !defined(TENSORFLOW_NO_SHARED_OBJECTS)
// Load the library.
TF_Status* status = TF_NewStatus();
string lib_path =
tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath(
"tensorflow", "c", "experimental", "stream_executor", "test",
"test_pluggable_device.so"));
TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status);
TF_Code code = TF_GetCode(status);
string status_msg(TF_Message(status));
TF_DeleteStatus(status);
ASSERT_EQ(TF_OK, code) << status_msg;
TF_DeletePluggableDeviceLibraryHandle(lib);
#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS)
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,332 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include "absl/strings/match.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/base64.h"
#include "tensorflow/core/platform/strcat.h"
using tensorflow::errors::InvalidArgument;
namespace tensorflow {
namespace {
Status ValidateNonRefOutput(const Node* node, int idx) {
const DataType& dt = node->output_type(idx);
return IsRefType(dt)
? InvalidArgument("Output ", idx, " of node '", node->name(),
"' has a reference type ", DataTypeString(dt))
: Status::OK();
}
// Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and
// does various checks while doing so. `input_nodes` will contain the same
// information as input_tensors just in a different structure to make
// following processing easier. TODO(iga): Simplify this nested structure.
Status ProcessInputs(
const TF_Graph* fn_body, const char* fn_name, int ninputs,
const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
std::unordered_map<const Node*, std::vector<int>>* input_nodes)
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
input_tensors->reserve(ninputs);
for (int i = 0; i < ninputs; ++i) {
Node* node = inputs[i].oper ? &inputs[i].oper->node : nullptr;
int idx = inputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
fn_body->graph.IsValidOutputTensor(node, idx),
"Encountered while processing input ", i, " into function '", fn_name,
"'");
TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
"Encountered while processing input ", i,
" into function '", fn_name, "'");
input_tensors->emplace_back(node, idx);
const auto& iter = input_nodes->find(node);
if (iter == input_nodes->end()) {
input_nodes->insert({node, {idx}});
} else {
auto& indices = iter->second;
if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
return InvalidArgument("TF_Output ", node->name(), ":", idx,
" appears more than once in the input list");
}
indices.push_back(idx);
}
}
return Status::OK();
}
// Converts `noutputs` and `outputs` into `outputs_tensors` and does various
// checks while doing so.
Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
int noutputs, const TF_Output* outputs,
std::vector<OutputTensor>* output_tensors)
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
output_tensors->reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
Node* node = outputs[i].oper ? &outputs[i].oper->node : nullptr;
int idx = outputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
fn_body->graph.IsValidOutputTensor(node, idx),
"Encountered while processing output ", i, " from function '", fn_name,
"'");
TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
"Encountered while creating function '",
fn_name, "'");
output_tensors->emplace_back(node, idx);
}
return Status::OK();
}
// Populates `body_nodes` with the nodes that will become function's body.
// Performs various checks.
Status ComputeBodyNodes(
const TF_Graph* fn_body, const char* fn_name, int num_opers,
const TF_Operation* const* opers,
const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
std::vector<const Node*>* body_nodes)
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
if (num_opers == -1) {
for (const Node* node : fn_body->graph.op_nodes()) {
const auto& iter = input_nodes.find(node);
if (iter == input_nodes.end()) {
// This node is not referenced in inputs. Add it to the body.
body_nodes->push_back(node);
} else {
// This node is referenced in inputs. Currently, we place an
// artificial restriction and require that when num_opers=-1, such
// nodes must have a single output.
if (node->num_outputs() != 1) {
return InvalidArgument(
"When `num_opers` is set to -1, nodes referenced in `inputs` "
"must have a single output. Node ",
node->name(), " has ", node->num_outputs(),
" outputs. Encountered while creating function '", fn_name, "'");
}
}
}
} else {
body_nodes->reserve(num_opers);
for (int i = 0; i < num_opers; ++i) {
const Node* node = &opers[i]->node;
body_nodes->push_back(node);
}
}
return Status::OK();
}
} // namespace
} // namespace tensorflow
using tensorflow::Node;
using tensorflow::string;
TF_Function* TF_GraphToFunctionWithControlOutputs(
const TF_Graph* fn_body, const char* fn_name,
unsigned char append_hash_to_fn_name, int num_opers,
const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
int noutputs, const TF_Output* outputs, const char* const* output_names,
int ncontrol_outputs, const TF_Operation* const* control_outputs,
const char* const* control_output_names, const TF_FunctionOptions* opts,
const char* description, TF_Status* status) {
tensorflow::mutex_lock l(fn_body->mu);
// Process inputs.
std::vector<tensorflow::OutputTensor> input_tensors;
std::unordered_map<const Node*, std::vector<int>> input_nodes;
status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
&input_tensors, &input_nodes);
if (TF_GetCode(status) != TF_OK) return nullptr;
// Process outputs.
std::vector<tensorflow::OutputTensor> output_tensors;
status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
outputs, &output_tensors);
if (TF_GetCode(status) != TF_OK) return nullptr;
// Process output names.
std::vector<string> output_names_vec;
if (output_names) {
output_names_vec.reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
output_names_vec.push_back(string(output_names[i]));
}
}
// Process control output names.
std::vector<string> control_output_names_vec;
if (control_output_names) {
control_output_names_vec.reserve(ncontrol_outputs);
for (int i = 0; i < ncontrol_outputs; ++i) {
control_output_names_vec.push_back(string(output_names[i]));
}
}
// Compute body nodes.
std::vector<const Node*> body_nodes;
status->status = tensorflow::ComputeBodyNodes(
fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
if (TF_GetCode(status) != TF_OK) return nullptr;
// Compute body nodes.
std::vector<const Node*> control_output_nodes;
for (int i = 0; i < ncontrol_outputs; ++i) {
control_output_nodes.push_back(&control_outputs[i]->node);
}
// Do the actual function creation.
TF_Function* tf_function = new TF_Function();
DCHECK(append_hash_to_fn_name <= 1);
status->status = tensorflow::GraphToFunctionDef(
fn_body->graph, fn_name, append_hash_to_fn_name != 0,
/*set_stateful_from_nodes=*/true,
/*copy_placeholder_attrs_from_nodes=*/true, body_nodes, input_tensors,
output_tensors, output_names_vec, control_output_nodes,
control_output_names_vec, description, &tf_function->fdef);
if (TF_GetCode(status) != TF_OK) {
TF_DeleteFunction(tf_function);
return nullptr;
}
for (const Node* n : fn_body->graph.nodes()) {
tf_function->stack_traces[n->name()] = n->GetStackTrace();
}
return tf_function;
}
TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
unsigned char append_hash_to_fn_name,
int num_opers, const TF_Operation* const* opers,
int ninputs, const TF_Output* inputs,
int noutputs, const TF_Output* outputs,
const char* const* output_names,
const TF_FunctionOptions* opts,
const char* description, TF_Status* status) {
return TF_GraphToFunctionWithControlOutputs(
fn_body, fn_name, append_hash_to_fn_name, num_opers, opers, ninputs,
inputs, noutputs, outputs, output_names, 0, nullptr, nullptr, opts,
description, status);
}
const char* TF_FunctionName(TF_Function* func) {
return func->fdef.signature().name().c_str();
}
void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func,
const TF_Function* grad, TF_Status* status) {
if (func == nullptr) {
status->status = InvalidArgument(
"'func' argument to TF_GraphCopyFunction cannot be null");
return;
}
// TODO(iga): Add AddFunctionDef() and AddGradientDef() methods to graph
// to avoid the extra copy here.
tensorflow::FunctionDefLibrary fdef_lib;
*fdef_lib.add_function() = func->fdef;
if (grad) {
*fdef_lib.add_function() = grad->fdef;
tensorflow::GradientDef* gdef = fdef_lib.add_gradient();
gdef->set_function_name(func->fdef.signature().name());
gdef->set_gradient_func(grad->fdef.signature().name());
}
tensorflow::mutex_lock l(g->mu);
status->status = g->graph.AddFunctionLibrary(fdef_lib);
}
int TF_GraphNumFunctions(TF_Graph* g) {
tensorflow::mutex_lock l(g->mu);
return g->graph.flib_def().num_functions();
}
int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func,
TF_Status* status) {
tensorflow::FunctionDefLibrary lib;
{
tensorflow::mutex_lock l(g->mu);
lib = g->graph.flib_def().ToProto();
}
const auto len = std::min(max_func, static_cast<int>(lib.function_size()));
for (int i = 0; i < len; ++i) {
TF_Function* func = new TF_Function();
func->fdef = lib.function(i);
funcs[i] = func;
}
status->status = tensorflow::Status::OK();
return len;
}
void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
TF_Status* status) {
status->status = MessageToBuffer(func->fdef, output_func_def);
}
TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len,
TF_Status* status) {
TF_Function* func = new TF_Function();
if (!func->fdef.ParseFromArray(proto, proto_len)) {
status->status = InvalidArgument(
"Invalid FunctionDef given to TF_FunctionImportFunctionDef");
TF_DeleteFunction(func);
return nullptr;
}
status->status = tensorflow::Status::OK();
return func;
}
void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name,
const void* proto, size_t proto_len,
TF_Status* status) {
tensorflow::AttrValue attr_value;
if (!attr_value.ParseFromArray(proto, proto_len)) {
status->status = InvalidArgument(
"Unparseable AttrValue proto passed to "
"TF_FunctionSetAttrValueProto");
return;
}
(*func->fdef.mutable_attr())[string(attr_name)] = attr_value;
status->status = tensorflow::Status::OK();
}
void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name,
TF_Buffer* output_attr_value,
TF_Status* status) {
const auto& it = func->fdef.attr().find(attr_name);
if (it == func->fdef.attr().end()) {
status->status =
InvalidArgument("Function '", func->fdef.signature().name(),
"' has no attr named '", attr_name, "'.");
return;
}
status->status = MessageToBuffer(it->second, output_attr_value);
}
void TF_DeleteFunction(TF_Function* func) { delete func; }

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,219 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_C_API_INTERNAL_H_
#define TENSORFLOW_C_C_API_INTERNAL_H_
#include "tensorflow/c/c_api.h"
#include <list>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/c/tf_tensor_internal.h"
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/core/framework/op_gen_lib.h"
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
class Device;
class DeviceMgr;
class ServerInterface;
} // namespace tensorflow
// Internal structures used by the C API. These are likely to change and should
// not be depended on.
struct TF_SessionOptions {
tensorflow::SessionOptions options;
};
struct TF_DeprecatedSession {
tensorflow::Session* session;
};
struct TF_Library {
void* lib_handle;
TF_Buffer op_list;
};
struct TF_Graph {
TF_Graph();
mutable tensorflow::mutex mu;
tensorflow::Graph graph TF_GUARDED_BY(mu);
// Runs shape inference.
tensorflow::ShapeRefiner refiner TF_GUARDED_BY(mu);
// Maps from name of an operation to the Node* in 'graph'.
std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
TF_GUARDED_BY(mu);
// The keys of this map are all the active sessions using this graph. Each
// value records whether the graph has been mutated since the corresponding
// session has been run (this is detected in RecordMutation function). If the
// string is empty, no mutation has occurred. Otherwise the string is a
// description of the mutation suitable for returning to the user.
//
// Sessions are added to this map in TF_NewSession, and removed in
// TF_DeleteSession.
// TF_Graph may only / must be deleted when
// sessions.size() == 0 && delete_requested == true
//
// TODO(b/74949947): mutations currently trigger a warning instead of a bad
// status, this should be reverted when possible.
tensorflow::gtl::FlatMap<TF_Session*, tensorflow::string> sessions
TF_GUARDED_BY(mu);
bool delete_requested TF_GUARDED_BY(mu); // set true by TF_DeleteGraph
// Used to link graphs contained in TF_WhileParams to the parent graph that
// will eventually contain the full while loop.
TF_Graph* parent;
TF_Output* parent_inputs;
};
struct TF_OperationDescription {
TF_OperationDescription(TF_Graph* g, const char* op_type,
const char* node_name)
: node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {}
tensorflow::NodeBuilder node_builder;
TF_Graph* graph;
std::set<tensorflow::string> colocation_constraints;
};
struct TF_Operation {
tensorflow::Node node;
};
struct TF_Session {
TF_Session(tensorflow::Session* s, TF_Graph* g);
tensorflow::Session* session;
TF_Graph* const graph;
tensorflow::mutex mu TF_ACQUIRED_AFTER(TF_Graph::mu);
int last_num_graph_nodes;
// If true, TF_SessionRun and similar methods will call
// ExtendSessionGraphHelper before running the graph (this is the default
// public behavior). Can be set to false if the caller needs to call
// ExtendSessionGraphHelper manually.
std::atomic<bool> extend_before_run;
};
struct TF_ImportGraphDefOptions {
tensorflow::ImportGraphDefOptions opts;
// Backing memory for TensorId fields in opts.
// TODO(skyewm): it'd be better if ImportGraphDefOptions owned this.
std::list<tensorflow::string> tensor_id_data;
};
struct TF_ImportGraphDefResults {
std::vector<TF_Output> return_tensors;
std::vector<TF_Operation*> return_nodes;
std::vector<const char*> missing_unused_key_names;
std::vector<int> missing_unused_key_indexes;
// Backing memory for missing_unused_key_names values.
std::list<tensorflow::string> missing_unused_key_names_data;
};
struct TF_DeviceList {
std::vector<tensorflow::DeviceAttributes> response;
};
struct TF_Function {
tensorflow::FunctionDef fdef;
tensorflow::StackTracesMap stack_traces;
};
struct TF_ApiDefMap {
explicit TF_ApiDefMap(const tensorflow::OpList& op_list)
:
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
api_def_map(op_list),
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
update_docs_called(false) {
}
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
tensorflow::ApiDefMap api_def_map TF_GUARDED_BY(lock);
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
bool update_docs_called TF_GUARDED_BY(lock);
tensorflow::mutex lock;
};
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
struct TF_Server {
TF_Server(std::unique_ptr<tensorflow::ServerInterface> server);
const tensorflow::string target;
std::unique_ptr<tensorflow::ServerInterface> server;
};
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
namespace tensorflow {
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
TF_Buffer* out);
// Set the shapes and types of the output's handle.
//
// The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must
// all be equal to `num_shapes_and_types`. If `ranks[i] != -1`, (i.e., if the
// rank is known), then it must be equal to the length of `shapes[i]`; if
// `ranks[i] == 1`, then `shapes[i]` may be nullptr.
//
// TODO(akshayka): Implement a corresponding getter method.
void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
int num_shapes_and_types,
const int64_t** shapes,
const int* ranks,
const TF_DataType* types,
TF_Status* status);
void RecordMutation(TF_Graph* graph, const TF_Operation& op,
const char* mutation_type)
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu);
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status)
TF_LOCKS_EXCLUDED(session->graph->mu, session->mu);
std::string getTF_OutputDebugString(TF_Output node);
} // end namespace tensorflow
#endif // TENSORFLOW_C_C_API_INTERNAL_H_

View File

@ -0,0 +1,46 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_C_API_MACROS_H_
#define TENSORFLOW_C_C_API_MACROS_H_
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // SWIG
// TF_Bool is the C API typedef for unsigned char, while TF_BOOL is
// the datatype for boolean tensors.
#ifndef TF_Bool
#define TF_Bool unsigned char
#endif // TF_Bool
// Macro used to calculate struct size for maintaining ABI stability across
// different struct implementations.
#ifndef TF_OFFSET_OF_END
#define TF_OFFSET_OF_END(TYPE, MEMBER) \
(offsetof(TYPE, MEMBER) + sizeof(((TYPE *)0)->MEMBER))
#endif // TF_OFFSET_OF_END
#endif // TENSORFLOW_C_C_API_MACROS_H_

File diff suppressed because it is too large Load Diff

89
tensorflow/c/c_test.c Normal file
View File

@ -0,0 +1,89 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <limits.h>
#include <memory.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <unistd.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/kernels.h"
// A create function. This will never actually get called in this test, it's
// just nice to know that it compiles.
void* create(TF_OpKernelConstruction* ctx) {
TF_DataType type;
TF_Status* s = TF_NewStatus();
TF_OpKernelConstruction_GetAttrType(ctx, "foobar", &type, s);
TF_DeleteStatus(s);
return NULL;
}
// A compute function. This will never actually get called in this test, it's
// just nice to know that it compiles.
void compute(void* kernel, TF_OpKernelContext* ctx) {
TF_Tensor* input;
TF_Status* s = TF_NewStatus();
TF_GetInput(ctx, 0, &input, s);
TF_DeleteTensor(input);
TF_DeleteStatus(s);
}
// Exercises tensorflow's C API.
int main(int argc, char** argv) {
TF_InitMain(argv[0], &argc, &argv);
struct TF_StringStream* s = TF_GetLocalTempDirectories();
const char* path;
if (!TF_StringStreamNext(s, &path)) {
fprintf(stderr, "TF_GetLocalTempDirectories returned no results\n");
return 1;
}
char file_name[100];
time_t t = time(NULL);
snprintf(file_name, sizeof(file_name), "test-%d-%ld.txt", getpid(), t);
size_t length = 2 + strlen(path) + strlen(file_name);
char* full_path = malloc(length);
snprintf(full_path, length, "%s/%s", path, file_name);
TF_WritableFileHandle* h;
TF_Status* status = TF_NewStatus();
TF_NewWritableFile(full_path, &h, status);
if (TF_GetCode(status) != TF_OK) {
fprintf(stderr, "TF_NewWritableFile failed: %s\n", TF_Message(status));
return 1;
}
fprintf(stderr, "wrote %s\n", full_path);
free(full_path);
TF_CloseWritableFile(h, status);
if (TF_GetCode(status) != TF_OK) {
fprintf(stderr, "TF_CloseWritableFile failed: %s\n", TF_Message(status));
}
TF_StringStreamDone(s);
TF_KernelBuilder* b =
TF_NewKernelBuilder("SomeOp", "SomeDevice", &create, &compute, NULL);
TF_RegisterKernelBuilder("someKernel", b, status);
TF_DeleteStatus(status);
return 0;
}

549
tensorflow/c/c_test_util.cc Normal file
View File

@ -0,0 +1,549 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/public/session_options.h"
using tensorflow::GraphDef;
using tensorflow::NodeDef;
static void BoolDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<bool*>(data);
}
static void Int32Deallocator(void* data, size_t, void* arg) {
delete[] static_cast<int32_t*>(data);
}
static void DoubleDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<double*>(data);
}
static void FloatDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<float*>(data);
}
TF_Tensor* BoolTensor(bool v) {
const int num_bytes = sizeof(bool);
bool* values = new bool[1];
values[0] = v;
return TF_NewTensor(TF_BOOL, nullptr, 0, values, num_bytes, &BoolDeallocator,
nullptr);
}
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
int64_t num_values = 1;
for (int i = 0; i < num_dims; ++i) {
num_values *= dims[i];
}
TF_Tensor* t =
TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values);
memcpy(TF_TensorData(t), values, sizeof(char) * num_values);
return t;
}
TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
const int32_t* values) {
int64_t num_values = 1;
for (int i = 0; i < num_dims; ++i) {
num_values *= dims[i];
}
TF_Tensor* t =
TF_AllocateTensor(TF_INT32, dims, num_dims, sizeof(int32_t) * num_values);
memcpy(TF_TensorData(t), values, sizeof(int32_t) * num_values);
return t;
}
TF_Tensor* Int32Tensor(const std::vector<int32_t>& values) {
int64_t dims = values.size();
return Int32Tensor(&dims, 1, values.data());
}
TF_Tensor* Int32Tensor(int32_t v) {
const int num_bytes = sizeof(int32_t);
int32_t* values = new int32_t[1];
values[0] = v;
return TF_NewTensor(TF_INT32, nullptr, 0, values, num_bytes,
&Int32Deallocator, nullptr);
}
TF_Tensor* DoubleTensor(double v) {
const int num_bytes = sizeof(double);
double* values = new double[1];
values[0] = v;
return TF_NewTensor(TF_DOUBLE, nullptr, 0, values, num_bytes,
&DoubleDeallocator, nullptr);
}
TF_Tensor* FloatTensor(float v) {
const int num_bytes = sizeof(float);
float* values = new float[1];
values[0] = v;
return TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes,
&FloatDeallocator, nullptr);
}
// All the *Helper methods are used as a workaround for the restrictions that
// one cannot call ASSERT_* methods in non-void-returning functions (when
// exceptions are disabled during compilation)
void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
TF_DataType dtype, const std::vector<int64_t>& dims,
TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
TF_SetAttrType(desc, "dtype", dtype);
if (!dims.empty()) {
TF_SetAttrShape(desc, "shape", dims.data(), dims.size());
}
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name,
TF_DataType dtype, const std::vector<int64_t>& dims) {
TF_Operation* op;
PlaceholderHelper(graph, s, name, dtype, dims, &op);
return op;
}
void ConstHelper(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name,
TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
TF_SetAttrTensor(desc, "value", t, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_SetAttrType(desc, "dtype", TF_TensorType(t));
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
const char* name) {
TF_Operation* op;
ConstHelper(t, graph, s, name, &op);
return op;
}
TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
const char* name) {
unique_tensor_ptr tensor(BoolTensor(v), TF_DeleteTensor);
return Const(tensor.get(), graph, s, name);
}
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
const char* name) {
unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);
return Const(tensor.get(), graph, s, name);
}
TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s,
const char* name) {
unique_tensor_ptr tensor(DoubleTensor(v), TF_DeleteTensor);
return Const(tensor.get(), graph, s, name);
}
TF_Operation* ScalarConst(float v, TF_Graph* graph, TF_Status* s,
const char* name) {
unique_tensor_ptr tensor(FloatTensor(v), TF_DeleteTensor);
return Const(tensor.get(), graph, s, name);
}
void AddOpHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name, TF_Operation** op,
bool check) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
TF_AddInputList(desc, add_inputs, 2);
*op = TF_FinishOperation(desc, s);
if (check) {
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
}
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name) {
TF_Operation* op;
AddOpHelper(l, r, graph, s, name, &op, true);
return op;
}
TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name) {
TF_Operation* op;
AddOpHelper(l, r, graph, s, name, &op, false);
return op;
}
TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
TF_Graph* graph, TF_Operation* ctrl_op,
TF_Status* s, const char* name) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
TF_AddInputList(desc, add_inputs, 2);
TF_AddControlInput(desc, ctrl_op);
return TF_FinishOperation(desc, s);
}
// If `op_device` is non-empty, set the created op on that device.
void BinaryOpHelper(const char* op_name, TF_Operation* l, TF_Operation* r,
TF_Graph* graph, TF_Status* s, const char* name,
TF_Operation** op, const string& op_device, bool check) {
TF_OperationDescription* desc = TF_NewOperation(graph, op_name, name);
if (!op_device.empty()) {
TF_SetDevice(desc, op_device.c_str());
}
TF_AddInput(desc, {l, 0});
TF_AddInput(desc, {r, 0});
*op = TF_FinishOperation(desc, s);
if (check) {
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
}
TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
const string& op_device, TF_Status* s,
const char* name) {
TF_Operation* op;
BinaryOpHelper("Min", l, r, graph, s, name, &op, op_device, true);
return op;
}
TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name) {
return MinWithDevice(l, r, graph, /*op_device=*/"", s, name);
}
TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name) {
TF_Operation* op;
BinaryOpHelper("Mul", l, r, graph, s, name, &op, "", true);
return op;
}
TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
const char* name) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
TF_Output inputs[2] = {l, r};
TF_AddInputList(desc, inputs, 2);
return TF_FinishOperation(desc, s);
}
void NegHelper(TF_Operation* n, TF_Graph* graph, TF_Status* s, const char* name,
TF_Operation** op) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", name);
TF_Output neg_input = {n, 0};
TF_AddInput(desc, neg_input);
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s,
const char* name) {
TF_Operation* op;
NegHelper(n, graph, s, name, &op);
return op;
}
TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
TF_Status* s) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Less", "less_than");
TF_AddInput(desc, l);
TF_AddInput(desc, r);
return TF_FinishOperation(desc, s);
}
TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype,
TF_Graph* graph, TF_Status* s) {
TF_OperationDescription* desc =
TF_NewOperation(graph, "RandomUniform", "random_uniform");
TF_AddInput(desc, {shape, 0});
TF_SetAttrType(desc, "dtype", dtype);
return TF_FinishOperation(desc, s);
}
void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s,
const char* name, TF_Operation** op) {
TF_Operation* zero = ScalarConst(
0, graph, s, ::tensorflow::strings::StrCat(name, "_const0").c_str());
TF_OperationDescription* desc = TF_NewOperation(graph, "Split", name);
TF_AddInput(desc, {zero, 0});
TF_AddInput(desc, {input, 0});
TF_SetAttrInt(desc, "num_split", 3);
TF_SetAttrType(desc, "T", TF_INT32);
// Set device to CPU since there is no version of split for int32 on GPU
// TODO(iga): Convert all these helpers and tests to use floats because
// they are usually available on GPUs. After doing this, remove TF_SetDevice
// call in c_api_function_test.cc
TF_SetDevice(desc, "/cpu:0");
*op = TF_FinishOperation(desc, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
ASSERT_NE(*op, nullptr);
}
TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
const char* name) {
TF_Operation* op;
Split3Helper(input, graph, s, name, &op);
return op;
}
bool IsPlaceholder(const tensorflow::NodeDef& node_def) {
if (node_def.op() != "Placeholder" || node_def.name() != "feed") {
return false;
}
bool found_dtype = false;
bool found_shape = false;
for (const auto& attr : node_def.attr()) {
if (attr.first == "dtype") {
if (attr.second.type() == tensorflow::DT_INT32) {
found_dtype = true;
} else {
return false;
}
} else if (attr.first == "shape") {
found_shape = true;
}
}
return found_dtype && found_shape;
}
bool IsScalarConst(const tensorflow::NodeDef& node_def, int v) {
if (node_def.op() != "Const" || node_def.name() != "scalar") {
return false;
}
bool found_dtype = false;
bool found_value = false;
for (const auto& attr : node_def.attr()) {
if (attr.first == "dtype") {
if (attr.second.type() == tensorflow::DT_INT32) {
found_dtype = true;
} else {
return false;
}
} else if (attr.first == "value") {
if (attr.second.has_tensor() &&
attr.second.tensor().int_val_size() == 1 &&
attr.second.tensor().int_val(0) == v) {
found_value = true;
} else {
return false;
}
}
}
return found_dtype && found_value;
}
bool IsAddN(const tensorflow::NodeDef& node_def, int n) {
if (node_def.op() != "AddN" || node_def.name() != "add" ||
node_def.input_size() != n) {
return false;
}
bool found_t = false;
bool found_n = false;
for (const auto& attr : node_def.attr()) {
if (attr.first == "T") {
if (attr.second.type() == tensorflow::DT_INT32) {
found_t = true;
} else {
return false;
}
} else if (attr.first == "N") {
if (attr.second.i() == n) {
found_n = true;
} else {
return false;
}
}
}
return found_t && found_n;
}
bool IsNeg(const tensorflow::NodeDef& node_def, const string& input) {
return node_def.op() == "Neg" && node_def.name() == "neg" &&
node_def.input_size() == 1 && node_def.input(0) == input;
}
bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def) {
TF_Status* s = TF_NewStatus();
TF_Buffer* buffer = TF_NewBuffer();
TF_GraphToGraphDef(graph, buffer, s);
bool ret = TF_GetCode(s) == TF_OK;
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length);
TF_DeleteBuffer(buffer);
TF_DeleteStatus(s);
return ret;
}
bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def) {
TF_Status* s = TF_NewStatus();
TF_Buffer* buffer = TF_NewBuffer();
TF_OperationToNodeDef(oper, buffer, s);
bool ret = TF_GetCode(s) == TF_OK;
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
if (ret) ret = node_def->ParseFromArray(buffer->data, buffer->length);
TF_DeleteBuffer(buffer);
TF_DeleteStatus(s);
return ret;
}
bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def) {
TF_Status* s = TF_NewStatus();
TF_Buffer* buffer = TF_NewBuffer();
TF_FunctionToFunctionDef(func, buffer, s);
bool ret = TF_GetCode(s) == TF_OK;
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
if (ret) ret = func_def->ParseFromArray(buffer->data, buffer->length);
TF_DeleteBuffer(buffer);
TF_DeleteStatus(s);
return ret;
}
bool GetAttrValue(TF_Operation* oper, const char* attr_name,
tensorflow::AttrValue* attr_value, TF_Status* s) {
TF_Buffer* buffer = TF_NewBuffer();
TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
bool ret = TF_GetCode(s) == TF_OK;
if (ret) ret = attr_value->ParseFromArray(buffer->data, buffer->length);
TF_DeleteBuffer(buffer);
return ret;
}
std::vector<std::pair<string, string>> GetGradDefs(
const tensorflow::GraphDef& graph_def) {
std::vector<std::pair<string, string>> grads;
for (const tensorflow::GradientDef& grad : graph_def.library().gradient()) {
grads.emplace_back(grad.function_name(), grad.gradient_func());
}
std::sort(grads.begin(), grads.end());
return grads;
}
std::vector<string> GetFuncNames(const tensorflow::GraphDef& graph_def) {
std::vector<string> names;
for (const tensorflow::FunctionDef& func : graph_def.library().function()) {
names.push_back(func.signature().name());
}
std::sort(names.begin(), names.end());
return names;
}
CSession::CSession(TF_Graph* graph, TF_Status* s, bool use_XLA) {
TF_SessionOptions* opts = TF_NewSessionOptions();
TF_EnableXLACompilation(opts, use_XLA);
session_ = TF_NewSession(graph, opts, s);
TF_DeleteSessionOptions(opts);
}
CSession::CSession(TF_Session* session) : session_(session) {}
CSession::~CSession() {
TF_Status* s = TF_NewStatus();
CloseAndDelete(s);
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteStatus(s);
}
void CSession::SetInputs(
std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs) {
DeleteInputValues();
inputs_.clear();
for (const auto& p : inputs) {
inputs_.emplace_back(TF_Output{p.first, 0});
input_values_.emplace_back(p.second);
}
}
void CSession::SetOutputs(std::initializer_list<TF_Operation*> outputs) {
ResetOutputValues();
outputs_.clear();
for (TF_Operation* o : outputs) {
outputs_.emplace_back(TF_Output{o, 0});
}
output_values_.resize(outputs_.size());
}
void CSession::SetOutputs(const std::vector<TF_Output>& outputs) {
ResetOutputValues();
outputs_ = outputs;
output_values_.resize(outputs_.size());
}
void CSession::SetTargets(std::initializer_list<TF_Operation*> targets) {
targets_.clear();
for (TF_Operation* t : targets) {
targets_.emplace_back(t);
}
}
void CSession::Run(TF_Status* s) {
if (inputs_.size() != input_values_.size()) {
ADD_FAILURE() << "Call SetInputs() before Run()";
return;
}
ResetOutputValues();
output_values_.resize(outputs_.size(), nullptr);
const TF_Output* inputs_ptr = inputs_.empty() ? nullptr : &inputs_[0];
TF_Tensor* const* input_values_ptr =
input_values_.empty() ? nullptr : &input_values_[0];
const TF_Output* outputs_ptr = outputs_.empty() ? nullptr : &outputs_[0];
TF_Tensor** output_values_ptr =
output_values_.empty() ? nullptr : &output_values_[0];
TF_Operation* const* targets_ptr = targets_.empty() ? nullptr : &targets_[0];
TF_SessionRun(session_, nullptr, inputs_ptr, input_values_ptr, inputs_.size(),
outputs_ptr, output_values_ptr, outputs_.size(), targets_ptr,
targets_.size(), nullptr, s);
DeleteInputValues();
}
void CSession::CloseAndDelete(TF_Status* s) {
DeleteInputValues();
ResetOutputValues();
if (session_ != nullptr) {
TF_CloseSession(session_, s);
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteSession(session_, s);
session_ = nullptr;
}
}
void CSession::DeleteInputValues() {
for (size_t i = 0; i < input_values_.size(); ++i) {
TF_DeleteTensor(input_values_[i]);
}
input_values_.clear();
}
void CSession::ResetOutputValues() {
for (size_t i = 0; i < output_values_.size(); ++i) {
if (output_values_[i] != nullptr) TF_DeleteTensor(output_values_[i]);
}
output_values_.clear();
}

165
tensorflow/c/c_test_util.h Normal file
View File

@ -0,0 +1,165 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_C_TEST_UTIL_H_
#define TENSORFLOW_C_C_TEST_UTIL_H_
#include "tensorflow/c/c_api.h"
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/test.h"
using ::tensorflow::string;
typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
unique_tensor_ptr;
TF_Tensor* BoolTensor(int32_t v);
// Create a tensor with values of type TF_INT8 provided by `values`.
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values);
// Create a tensor with values of type TF_INT32 provided by `values`.
TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims,
const int32_t* values);
// Create 1 dimensional tensor with values from `values`
TF_Tensor* Int32Tensor(const std::vector<int32_t>& values);
TF_Tensor* Int32Tensor(int32_t v);
TF_Tensor* DoubleTensor(double v);
TF_Tensor* FloatTensor(float v);
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
const char* name = "feed",
TF_DataType dtype = TF_INT32,
const std::vector<int64_t>& dims = {});
TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
const char* name = "const");
TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s,
const char* name = "scalar");
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
const char* name = "scalar");
TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s,
const char* name = "scalar");
TF_Operation* ScalarConst(float v, TF_Graph* graph, TF_Status* s,
const char* name = "scalar");
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name = "add");
TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name = "add");
TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r,
TF_Graph* graph, TF_Operation* ctrl_op,
TF_Status* s, const char* name = "add");
TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
const char* name = "add");
TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name = "min");
TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name = "mul");
// If `op_device` is non-empty, set the created op on that device.
TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
const string& op_device, TF_Status* s,
const char* name = "min");
TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s,
const char* name = "neg");
TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s);
TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype,
TF_Graph* graph, TF_Status* s);
// Split `input` along the first dimension into 3 tensors
TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s,
const char* name = "split3");
bool IsPlaceholder(const tensorflow::NodeDef& node_def);
bool IsScalarConst(const tensorflow::NodeDef& node_def, int v);
bool IsAddN(const tensorflow::NodeDef& node_def, int n);
bool IsNeg(const tensorflow::NodeDef& node_def, const string& input);
bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);
bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def);
bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def);
bool GetAttrValue(TF_Operation* oper, const char* attr_name,
tensorflow::AttrValue* attr_value, TF_Status* s);
// Returns a sorted vector of std::pair<function_name, gradient_func> from
// graph_def.library().gradient()
std::vector<std::pair<string, string>> GetGradDefs(
const tensorflow::GraphDef& graph_def);
// Returns a sorted vector of names contained in `grad_def`
std::vector<string> GetFuncNames(const tensorflow::GraphDef& graph_def);
class CSession {
public:
CSession(TF_Graph* graph, TF_Status* s, bool use_XLA = false);
explicit CSession(TF_Session* session);
~CSession();
void SetInputs(std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs);
void SetOutputs(std::initializer_list<TF_Operation*> outputs);
void SetOutputs(const std::vector<TF_Output>& outputs);
void SetTargets(std::initializer_list<TF_Operation*> targets);
void Run(TF_Status* s);
void CloseAndDelete(TF_Status* s);
TF_Tensor* output_tensor(int i) { return output_values_[i]; }
TF_Session* mutable_session() { return session_; }
private:
void DeleteInputValues();
void ResetOutputValues();
TF_Session* session_;
std::vector<TF_Output> inputs_;
std::vector<TF_Tensor*> input_values_;
std::vector<TF_Output> outputs_;
std::vector<TF_Tensor*> output_values_;
std::vector<TF_Operation*> targets_;
};
#endif // TENSORFLOW_C_C_TEST_UTIL_H_

View File

@ -16,51 +16,51 @@ limitations under the License.
#include "tensorflow/c/checkpoint_reader.h"
#include <unordered_set>
#include <utility>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
namespace tensorflow {
namespace checkpoint {
class TensorSliceReader;
CheckpointReader::CheckpointReader(const string& filename,
TF_Status* out_status)
: reader_(nullptr), v2_reader_(nullptr), var_to_shape_map_ptr_(nullptr) {
CheckpointReader::CheckpointReader(const string& filename, TF_Status* status)
: reader_(nullptr),
v2_reader_(nullptr),
var_to_shape_map_(nullptr),
var_to_data_type_map_(nullptr) {
// Depending on whether this is a V2 ckpt, initializes "reader_" or
// "v2_reader_".
std::vector<string> v2_path;
if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() &&
!v2_path.empty()) {
v2_reader_ =
new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */);
v2_reader_.reset(
new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
if (!v2_reader_->status().ok()) {
Set_TF_Status_from_Status(out_status, v2_reader_->status());
Set_TF_Status_from_Status(status, v2_reader_->status());
return;
}
var_to_shape_map_ptr_ = BuildV2VarToShapeMap();
auto result = BuildV2VarMaps();
var_to_shape_map_.swap(result.first);
var_to_data_type_map_.swap(result.second);
} else {
reader_ = new TensorSliceReader(filename);
reader_.reset(new TensorSliceReader(filename));
if (!reader_->status().ok()) {
Set_TF_Status_from_Status(out_status, reader_->status());
Set_TF_Status_from_Status(status, reader_->status());
return;
}
var_to_shape_map_ptr_ =
new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap());
var_to_shape_map_.reset(
new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap()));
var_to_data_type_map_.reset(new TensorSliceReader::VarToDataTypeMap(
reader_->GetVariableToDataTypeMap()));
}
}
CheckpointReader::~CheckpointReader() {
delete var_to_shape_map_ptr_;
delete reader_;
delete v2_reader_;
}
bool CheckpointReader::HasTensor(const string& name) const {
if (reader_ != nullptr) {
return reader_->HasTensor(name, nullptr, nullptr);
@ -70,8 +70,14 @@ bool CheckpointReader::HasTensor(const string& name) const {
const TensorSliceReader::VarToShapeMap&
CheckpointReader::GetVariableToShapeMap() const {
CHECK(var_to_shape_map_ptr_);
return *var_to_shape_map_ptr_;
CHECK(var_to_shape_map_);
return *var_to_shape_map_;
}
const TensorSliceReader::VarToDataTypeMap&
CheckpointReader::GetVariableToDataTypeMap() const {
CHECK(var_to_data_type_map_);
return *var_to_data_type_map_;
}
const string CheckpointReader::DebugString() const {
@ -100,7 +106,9 @@ void CheckpointReader::GetTensor(
}
}
TensorSliceReader::VarToShapeMap* CheckpointReader::BuildV2VarToShapeMap() {
std::pair<std::unique_ptr<TensorSliceReader::VarToShapeMap>,
std::unique_ptr<TensorSliceReader::VarToDataTypeMap>>
CheckpointReader::BuildV2VarMaps() {
CHECK(v2_reader_ != nullptr);
CHECK(v2_reader_->status().ok());
@ -116,25 +124,30 @@ TensorSliceReader::VarToShapeMap* CheckpointReader::BuildV2VarToShapeMap() {
const auto& slice_proto = entry.slices(i);
CHECK(filtered_keys
.insert(EncodeTensorNameSlice(
v2_reader_->key().ToString() /* full var's name */,
string(v2_reader_->key()) /* full var's name */,
TensorSlice(slice_proto)))
.second);
}
}
// Second pass: adds the entries, ignoring the filtered keys.
TensorSliceReader::VarToShapeMap* var_to_shape_map =
new TensorSliceReader::VarToShapeMap;
std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map(
new TensorSliceReader::VarToShapeMap);
std::unique_ptr<TensorSliceReader::VarToDataTypeMap> var_to_data_type_map(
new TensorSliceReader::VarToDataTypeMap);
v2_reader_->Seek(kHeaderEntryKey);
for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) {
if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue;
if (filtered_keys.count(string(v2_reader_->key())) > 0) continue;
CHECK(entry.ParseFromArray(v2_reader_->value().data(),
v2_reader_->value().size()))
<< entry.InitializationErrorString();
(*var_to_shape_map)[v2_reader_->key().ToString()] =
TensorShape(entry.shape());
string key(v2_reader_->key());
(*var_to_shape_map)[key] = TensorShape(entry.shape());
(*var_to_data_type_map)[key] = DataType(entry.dtype());
}
return var_to_shape_map; // Owned by caller.
// The returned pointers are owned by the caller.
return std::make_pair(std::move(var_to_shape_map),
std::move(var_to_data_type_map));
}
} // namespace checkpoint

View File

@ -13,18 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_CHECKPOINT_READER_H
#define TENSORFLOW_C_CHECKPOINT_READER_H
#ifndef TENSORFLOW_C_CHECKPOINT_READER_H_
#define TENSORFLOW_C_CHECKPOINT_READER_H_
#include <memory>
#include <string>
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
namespace tensorflow {
namespace checkpoint {
class TensorSliceReader;
@ -37,16 +39,19 @@ class TensorSliceReader;
// variables.
class CheckpointReader {
public:
CheckpointReader(const string& filepattern, TF_Status* out_status);
~CheckpointReader();
CheckpointReader(const string& filename, TF_Status* status);
bool HasTensor(const string& name) const;
const string DebugString() const;
// Returns a map from variable names to its shape. Slices of a partitioned
// Returns a map from variable names to their shapes. Slices of a partitioned
// tensor are combined into a single entry.
const TensorSliceReader::VarToShapeMap& GetVariableToShapeMap() const;
// Returns a map from variable names to their data types. Slices of a
// partitioned tensor are combined into a single entry.
const TensorSliceReader::VarToDataTypeMap& GetVariableToDataTypeMap() const;
// Attempts to look up the tensor named "name" and stores the found result in
// "out_tensor".
void GetTensor(const string& name,
@ -54,14 +59,19 @@ class CheckpointReader {
TF_Status* out_status) const;
private:
// Uses "v2_reader_" to build a "var name -> shape" map; owned by caller.
// Uses "v2_reader_" to build "var name -> shape" and "var name -> data type"
// maps; both owned by caller.
// REQUIRES: "v2_reader_ != nullptr && v2_reader_.status().ok()".
TensorSliceReader::VarToShapeMap* BuildV2VarToShapeMap();
std::pair<std::unique_ptr<TensorSliceReader::VarToShapeMap>,
std::unique_ptr<TensorSliceReader::VarToDataTypeMap> >
BuildV2VarMaps();
// Invariant: exactly one of "reader_" and "v2_reader_" is non-nullptr.
TensorSliceReader* reader_; // Owned.
BundleReader* v2_reader_; // Owned.
TensorSliceReader::VarToShapeMap* var_to_shape_map_ptr_; // Owned.
// Invariant: exactly one of "reader_" and "v2_reader_" is non-null.
std::unique_ptr<TensorSliceReader> reader_;
std::unique_ptr<BundleReader> v2_reader_;
std::unique_ptr<TensorSliceReader::VarToShapeMap> var_to_shape_map_;
std::unique_ptr<TensorSliceReader::VarToDataTypeMap> var_to_data_type_map_;
TF_DISALLOW_COPY_AND_ASSIGN(CheckpointReader);
};
@ -69,4 +79,4 @@ class CheckpointReader {
} // namespace checkpoint
} // namespace tensorflow
#endif // TENSORFLOW_C_CHECKPOINT_READER_H
#endif // TENSORFLOW_C_CHECKPOINT_READER_H_

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_
#define TENSORFLOW_C_CONVERSION_MACROS_H_
#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \
inline cpp_impl *unwrap(wrapper *w) { \
return reinterpret_cast<cpp_impl *>(w); \
} \
\
inline const cpp_impl *unwrap(const wrapper *w) { \
return reinterpret_cast<const cpp_impl *>(w); \
} \
\
inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast<wrapper *>(i); } \
inline const wrapper *wrap(const cpp_impl *i) { \
return reinterpret_cast<const wrapper *>(i); \
}
#endif // TENSORFLOW_C_CONVERSION_MACROS_H_

1134
tensorflow/c/eager/BUILD Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,82 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
#include <memory>
#include "tensorflow/c/eager/abstract_function.h"
#include "tensorflow/c/eager/abstract_operation.h"
namespace tensorflow {
// Abstract interface to a context.
//
// This serves as a factory for creating `AbstractOperation`s and for
// registering traced functions.
// Operations creation within a context can only be executed in that context
// (for now at least).
// Implementations of the context may contain some state e.g. an execution
// environment, a traced representation etc.
class AbstractContext {
protected:
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape, kOpHandler };
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
virtual ~AbstractContext() {}
public:
AbstractContextKind getKind() const { return kind_; }
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage it's own
// lifetime through ref counting. Thus clients MUST call Release() in order to
// destroy an instance of this class.
virtual void Release() = 0;
// Creates an operation builder and ties it to this context.
// The returned object can be used for setting operation's attributes,
// adding inputs and finally executing (immediately or lazily as in tracing)
// it in this context.
virtual AbstractOperation* CreateOperation() = 0;
// Registers a function with this context, after this the function is
// available to be called/referenced by its name in this context.
virtual Status RegisterFunction(AbstractFunction*) = 0;
// Remove a function. 'func' argument is the name of a previously added
// FunctionDef. The name is in fdef.signature.name.
virtual Status RemoveFunction(const string& func) = 0;
private:
const AbstractContextKind kind_;
};
namespace internal {
struct AbstractContextDeleter {
void operator()(AbstractContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractContextPtr =
std::unique_ptr<AbstractContext, internal::AbstractContextDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_

View File

@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
// A traced function: this hides the complexity of converting the serialized
// representation between various supported formats e.g. FunctionDef and Mlir
// function.
class AbstractFunction {
protected:
enum AbstractFunctionKind { kGraph, kMlir };
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
public:
// Returns which subclass is this instance of.
AbstractFunctionKind getKind() const { return kind_; }
virtual ~AbstractFunction() = default;
// Returns the AbstractFunction as a FunctionDef.
virtual Status GetFunctionDef(FunctionDef**) = 0;
private:
const AbstractFunctionKind kind_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_

View File

@ -0,0 +1,138 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
// Abstract interface to an operation.
// This interface allows building and executing an operation in either
// tracing or immediate execution mode.
class AbstractOperation {
protected:
enum AbstractOperationKind {
kGraph,
kMlir,
kEager,
kTfrt,
kTape,
kOpHandler
};
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
virtual ~AbstractOperation() {}
public:
AbstractOperationKind getKind() const { return kind_; }
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage it's own
// lifetime through ref counting. Thus this must be allocated on the heap and
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
virtual const string& Name() const = 0;
// Returns the operation's device name.
//
// The value returned may be different from the one set by SetDeviceName, but
// it will be compatible with it: the name will be updated by device placement
// logic to refer to the specific device chosen.
//
// Example: If one calls `op->SetDeviceName("/device:GPU")`, the value
// returned by DeviceName should be "/device:GPU:*" until a particular GPU is
// chosen for the operation by the device placement logic in the
// executor. After that, the value returned by DeviceName will be a full
// device name such as "/job:localhost/replica:0/task:0/device:GPU:1".
virtual const string& DeviceName() const = 0;
// Sets the operation device name.
//
// The given `name` must be parseable by DeviceNameUtils::ParseFullName, and
// the result will be used as a constraint for device placement. See the
// documentation for DeviceName for more details.
//
// The value will override the previous value - that is, no "merging" of
// existing and given constraints will be performed.
virtual Status SetDeviceName(const char* name) = 0;
virtual Status AddInput(AbstractTensorHandle* input) = 0;
virtual Status AddInputList(
absl::Span<AbstractTensorHandle* const> inputs) = 0;
virtual Status Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) = 0;
virtual Status SetAttrString(const char* attr_name, const char* data,
size_t length) = 0;
virtual Status SetAttrInt(const char* attr_name, int64_t value) = 0;
virtual Status SetAttrFloat(const char* attr_name, float value) = 0;
virtual Status SetAttrBool(const char* attr_name, bool value) = 0;
virtual Status SetAttrType(const char* attr_name, DataType value) = 0;
virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) = 0;
virtual Status SetAttrFunction(const char* attr_name,
const AbstractOperation* value) = 0;
virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) = 0;
virtual Status SetAttrTensor(const char* attr_name,
AbstractTensorInterface* tensor) = 0;
virtual Status SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths, int num_values) = 0;
virtual Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) = 0;
virtual Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) = 0;
virtual Status SetAttrTypeList(const char* attr_name, const DataType* values,
int num_values) = 0;
virtual Status SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) = 0;
virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) = 0;
virtual Status SetAttrFunctionList(
const char* attr_name, absl::Span<const AbstractOperation*> values) = 0;
private:
const AbstractOperationKind kind_;
};
namespace internal {
struct AbstractOperationDeleter {
void operator()(AbstractOperation* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractOperationPtr =
std::unique_ptr<AbstractOperation, internal::AbstractOperationDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/abstract_tensor_handle.h"
namespace tensorflow {
std::string AbstractTensorHandle::DebugString() const {
PartialTensorShape shape;
Status s = Shape(&shape);
std::string shape_string;
if (!s.ok()) {
shape_string = "<error computing shape>";
} else {
shape_string = shape.DebugString();
}
return absl::StrCat("TensorHandle(shape=", shape_string,
", dtype=", DataType_Name(DataType()), ")");
}
} // namespace tensorflow

View File

@ -0,0 +1,67 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
#include <memory>
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
// Abstract interface to a Tensor handle in either tracing or immediate
// execution mode.
class AbstractTensorHandle : public core::RefCounted {
protected:
enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt, kCustomDevice };
explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
virtual ~AbstractTensorHandle() {}
public:
// Returns tensor dtype.
virtual tensorflow::DataType DataType() const = 0;
// Returns tensor shape. If tensor has unknown rank, shape remains untouched.
virtual tensorflow::Status Shape(
tensorflow::PartialTensorShape* shape) const = 0;
// The default debug string includes a shape and dtype. Implementations are
// free to override it with something more informative.
virtual std::string DebugString() const;
AbstractTensorHandleKind getKind() const { return kind_; }
private:
const AbstractTensorHandleKind kind_;
};
namespace internal {
struct AbstractTensorHandleDeleter {
void operator()(AbstractTensorHandle* p) const {
if (p != nullptr) {
p->Unref();
}
}
};
} // namespace internal
using AbstractTensorHandlePtr =
std::unique_ptr<AbstractTensorHandle,
internal::AbstractTensorHandleDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_

1047
tensorflow/c/eager/c_api.cc Normal file

File diff suppressed because it is too large Load Diff

466
tensorflow/c/eager/c_api.h Normal file
View File

@ -0,0 +1,466 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_H_
#define TENSORFLOW_C_EAGER_C_API_H_
// C API extensions to experiment with eager execution of kernels.
// WARNING: Unlike tensorflow/c/c_api.h, the API here is not guaranteed to be
// stable and can change without notice.
#include "tensorflow/c/c_api.h"
// Macro to control visibility of exported symbols in the shared library (.so,
// .dylib, .dll).
// This duplicates the TF_EXPORT macro definition in
// tensorflow/core/platform/macros.h in order to keep this .h file independent
// of any other includes.$a
#ifdef SWIG
#define TF_CAPI_EXPORT
#else
#if defined(_WIN32)
#ifdef TF_COMPILE_LIBRARY
#define TF_CAPI_EXPORT __declspec(dllexport)
#else
#define TF_CAPI_EXPORT __declspec(dllimport)
#endif // TF_COMPILE_LIBRARY
#else
#define TF_CAPI_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // SWIG
#ifdef __cplusplus
extern "C" {
#endif
typedef struct TFE_ContextOptions TFE_ContextOptions;
// Return a new options object.
TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(void);
// Set the config in TF_ContextOptions.options.
// config should be a serialized tensorflow.ConfigProto proto.
// If config was not parsed successfully as a ConfigProto, record the
// error information in *status.
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig(
TFE_ContextOptions* options, const void* proto, size_t proto_len,
TF_Status* status);
// Controls how to act when we try to run an operation on a given device but
// some input tensors are not on that device.
// LINT.IfChange
// Note: Keep in sync with internal copy of enum in eager/context.h.
typedef enum TFE_ContextDevicePlacementPolicy {
// Running operations with input tensors on the wrong device will fail.
TFE_DEVICE_PLACEMENT_EXPLICIT = 0,
// Copy the tensor to the right device but log a warning.
TFE_DEVICE_PLACEMENT_WARN = 1,
// Silently copy the tensor, which has a performance cost since the operation
// will be blocked till the copy completes. This is the default placement
// policy.
TFE_DEVICE_PLACEMENT_SILENT = 2,
// Placement policy which silently copies int32 tensors but not other dtypes.
TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
} TFE_ContextDevicePlacementPolicy;
// LINT.ThenChange(//tensorflow/c/eager/immediate_execution_context.h)
// Sets the default execution mode (sync/async). Note that this can be
// overridden per thread using TFE_ContextSetExecutorForThread.
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*,
unsigned char enable);
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
// Destroy an options object.
TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
// "Context" under which operations/functions are executed. It encapsulates
// things like the available devices, resource manager etc.
// TFE_Context must outlive all tensor handles created using it. In other
// words, TFE_DeleteContext() must be called after all tensor handles have
// been deleted (with TFE_DeleteTensorHandle).
//
// TODO(ashankar): Merge with TF_Session?
typedef struct TFE_Context TFE_Context;
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(
const TFE_ContextOptions* opts, TF_Status* status);
TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx);
TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
TF_Status* status);
// Clears the internal caches in the TFE context. Useful when reseeding random
// ops.
TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx);
// Sets a thread-local device placement policy. After this call, other calls to
// TFE_Execute in the same thread will use the device policy specified here
// instead of the device policy used to construct the context. This has no
// effect on the device policy used by other program threads.
TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy);
// Returns the device placement policy to be used by this context in the current
// thread.
TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy
TFE_ContextGetDevicePlacementPolicy(TFE_Context* ctx);
// A tensorflow.ServerDef specifies remote workers (in addition to the current
// workers name). Operations created on this context can then be executed on
// any of these remote workers by setting an appropriate device.
//
// If the following is set, all servers identified by the
// ServerDef must be up when the context is created.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
int keep_alive_secs,
const void* proto,
size_t proto_len,
TF_Status* status);
// A handle to a tensor on a device.
//
// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
// type etc. Unlike a TF_Tensor, a TFE_TensorHandle may refer to such tensors
// placed in memory of different devices or remote address spaces.
typedef struct TFE_TensorHandle TFE_TensorHandle;
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t,
TF_Status* status);
// Indicates that the caller will not be using `h` any more.
TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h);
TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h,
TF_Status* status);
TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h,
TF_Status* status);
// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
int dim_index,
TF_Status* status);
// Returns the device of the operation that produced `h`. If `h` was produced by
// a copy, returns the destination device of the copy. Note that the returned
// device name is not always the device holding the tensor handle's memory. If
// you want the latter, use TFE_TensorHandleBackingDeviceName. This function
// will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
TFE_TensorHandle* h, TF_Status* status);
// Returns the name of the device in whose memory `h` resides.
//
// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleBackingDeviceName(
TFE_TensorHandle* h, TF_Status* status);
// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor
// with `h`. On success, `status` is set to OK. On failure, `status` reflects
// the error and a nullptr is returned.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
TFE_TensorHandle* h, TF_Status* status);
// This function will block till the operation that produces `h` has
// completed. The memory returned might alias the internal memory used by
// TensorFlow. Hence, callers should not mutate this memory (for example by
// modifying the memory region pointed to by TF_TensorData() on the returned
// TF_Tensor).
TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h,
TF_Status* status);
// Create a new TFE_TensorHandle with the same contents as 'h' but placed
// in the memory of the device name 'device_name'.
// If source and destination are the same device, then this creates a new handle
// that shares the underlying buffer. Otherwise, it currently requires at least
// one of the source or destination devices to be CPU (i.e., for the source or
// destination tensor to be placed in host memory).
// If async execution is enabled, the copy may be enqueued and the call will
// return "non-ready" handle. Else, this function returns after the copy has
// been done.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(
TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name,
TF_Status* status);
// Debugging/Profiling information for TFE_TensorHandle
//
// TFE_TensorDebugInfo contains information useful for debugging and
// profiling tensors.
typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo;
// Retrieves TFE_TensorDebugInfo for `handle`.
// If TFE_TensorHandleTensorDebugInfo succeeds, `status` is set to OK and caller
// is responsible for deleting returned TFE_TensorDebugInfo.
// If TFE_TensorHandleTensorDebugInfo fails, `status` is set to appropriate
// error and nullptr is returned. This function can block till the operation
// that produces `handle` has completed.
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* h, TF_Status* status);
// Deletes `debug_info`.
TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
TFE_TensorDebugInfo* debug_info);
// Returns the number of dimensions used to represent the tensor on its device.
// The number of dimensions used to represent the tensor on device can be
// different from the number returned by TFE_TensorHandleNumDims.
// The return value was current at the time of TFE_TensorDebugInfo creation.
TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(
TFE_TensorDebugInfo* debug_info);
// Returns the number of elements in dimension `dim_index`.
// Tensor representation on device can be transposed from its representation
// on host. The data contained in dimension `dim_index` on device
// can correspond to the data contained in another dimension in on-host
// representation. The dimensions are indexed using the standard TensorFlow
// major-to-minor order (slowest varying dimension first),
// not the XLA's minor-to-major order.
// On-device dimensions can be padded. TFE_TensorDebugInfoOnDeviceDim returns
// the number of elements in a dimension after padding.
// The return value was current at the time of TFE_TensorDebugInfo creation.
TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim(
TFE_TensorDebugInfo* debug_info, int dim_index);
// Description of the TensorFlow op to execute.
//
// Assumes that the provided 'ctx' outlives the returned TFE_Op, i.e.,
// TFE_DeleteOp() is called before TFE_DeleteContext().
//
// Very similar to TF_OperationDescription with some differences:
// (1) TF_Output or TFE_TensorHandle* as arguments to TF_AddInput,
// TF_AddInputList
// (2) TF_ColocateWith, TF_AddControlInput etc. do not make sense.
// (3) Implementation detail: Avoid use of NodeBuilder/NodeDefBuilder since
// the additional sanity checks there seem unnecessary;
typedef struct TFE_Op TFE_Op;
TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx,
const char* op_or_function_name,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op);
// Returns the op or function name `op` will execute.
//
// The returned string remains valid throughout the lifetime of 'op'.
TF_CAPI_EXPORT extern const char* TFE_OpGetName(const TFE_Op* op,
TF_Status* status);
TF_CAPI_EXPORT extern TFE_Context* TFE_OpGetContext(const TFE_Op* op,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name,
TF_Status* status);
// The returned string remains valid throughout the lifetime of 'op'.
TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(const TFE_Op* op,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_OpAddInputList(TFE_Op* op,
TFE_TensorHandle** inputs,
int num_inputs,
TF_Status* status);
// Fetches the current number of inputs attached to `op`.
//
// Does not use the operation's definition to determine how many inputs should
// be attached. It is intended for use with TFE_OpGetFlatInput to inspect an
// already-finalized operation.
//
// Note that TFE_OpGetFlatInputCount and TFE_OpGetFlatInput operate on a flat
// sequence of inputs, unlike TFE_OpGetInputLength (for getting the length of a
// particular named input list, which may only be part of the op's inputs).
TF_CAPI_EXPORT extern int TFE_OpGetFlatInputCount(const TFE_Op* op,
TF_Status* status);
// Returns a borrowed reference to one of `op`'s inputs. Use
// `TFE_TensorHandleCopySharingTensor` to make a new reference.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op,
int index,
TF_Status* status);
TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op,
const char* attr_name,
unsigned char* is_list,
TF_Status* status);
// Get an attribute type given an op name; a fusion of TFE_NewOp and
// TFE_OpGetAttrType for use from Python without the overhead of the individual
// calls and memory management of TFE_Op.
TF_CAPI_EXPORT extern TF_AttrType TFE_OpNameGetAttrType(
TFE_Context* ctx, const char* op_or_function_name, const char* attr_name,
unsigned char* is_list, TF_Status* status);
TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op,
const char* attr_name,
const void* value,
size_t length);
TF_CAPI_EXPORT extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name,
int64_t value);
TF_CAPI_EXPORT extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name,
float value);
TF_CAPI_EXPORT extern void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name,
unsigned char value);
TF_CAPI_EXPORT extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name,
TF_DataType value);
// If the number of dimensions is unknown, `num_dims` must be set to
// -1 and `dims` can be null. If a dimension is unknown, the
// corresponding entry in the `dims` array must be -1.
TF_CAPI_EXPORT extern void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name,
const int64_t* dims,
const int num_dims,
TF_Status* out_status);
// Sets the attribute attr_name to be a function specified by 'function'.
//
// TODO(ashankar,iga): Add this functionality to the C API for graph
// construction. Perhaps we want an AttrValueMap equivalent in the C API?
TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op,
const char* attr_name,
const TFE_Op* value);
TF_CAPI_EXPORT void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
const char* data, size_t length);
TF_CAPI_EXPORT extern void TFE_OpSetAttrTensor(TFE_Op* op,
const char* attr_name,
TF_Tensor* tensor,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op,
const char* attr_name,
const void* const* values,
const size_t* lengths,
int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrIntList(TFE_Op* op,
const char* attr_name,
const int64_t* values,
int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrFloatList(TFE_Op* op,
const char* attr_name,
const float* values,
int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrBoolList(TFE_Op* op,
const char* attr_name,
const unsigned char* values,
int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrTypeList(TFE_Op* op,
const char* attr_name,
const TF_DataType* values,
int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrShapeList(
TFE_Op* op, const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values, TF_Status* out_status);
TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op,
const char* attr_name,
const TFE_Op** value,
int num_values);
// Returns the length (number of tensors) of the input argument `input_name`
// found in the provided `op`.
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
const char* input_name,
TF_Status* status);
// Returns the length (number of tensors) of the output argument `output_name`
// found in the provided `op`.
TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
const char* output_name,
TF_Status* status);
// Execute the operation defined by 'op' and return handles to computed
// tensors in `retvals`.
//
// 'retvals' must point to a pre-allocated array of TFE_TensorHandle* and
// '*num_retvals' should be set to the size of this array. It is an error if
// the size of 'retvals' is less than the number of outputs. This call sets
// *num_retvals to the number of outputs.
//
// If async execution is enabled, the call may simply enqueue the execution
// and return "non-ready" handles in `retvals`. Note that any handles contained
// in 'op' should not be mutated till the kernel execution actually finishes.
//
// For sync execution, if any of the inputs to `op` are not ready, this call
// will block till they become ready and then return when the kernel execution
// is done.
// TODO(agarwal): change num_retvals to int from int*.
TF_CAPI_EXPORT extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals,
int* num_retvals, TF_Status* status);
// Add a function (serialized FunctionDef protocol buffer) to ctx so
// that it can be invoked using TFE_Execute.
TF_CAPI_EXPORT extern void TFE_ContextAddFunctionDef(
TFE_Context* ctx, const char* serialized_function_def, size_t size,
TF_Status* status);
// Adds a function (created from TF_GraphToFunction or
// TF_FunctionImportFunctionDef) to the context, allowing it to be executed with
// TFE_Execute by creating an op with the same name as the function.
TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx,
TF_Function* function,
TF_Status* status);
// Removes a function from the context. Once removed, you can no longer
// TFE_Execute it or TFE_Execute any TFE_Op which has it as an attribute or any
// other function which calls it as an attribute.
TF_CAPI_EXPORT extern void TFE_ContextRemoveFunction(TFE_Context* ctx,
const char* name,
TF_Status* status);
// Checks whether a function is registered under `name`.
TF_CAPI_EXPORT unsigned char TFE_ContextHasFunction(TFE_Context* ctx,
const char* name);
// Enables tracing of RunMetadata on the ops executed from this context.
TF_CAPI_EXPORT extern void TFE_ContextEnableRunMetadata(TFE_Context* ctx);
// Disables tracing of RunMetadata on the ops executed from this context.
TF_CAPI_EXPORT extern void TFE_ContextDisableRunMetadata(TFE_Context* ctx);
// Populates the passed-in buffer with a serialized RunMetadata protocol buffer
// containing any run metadata information accumulated so far and clears this
// information.
// If async mode is enabled, this call blocks till all currently pending ops are
// done.
TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx,
TF_Buffer* buf,
TF_Status* status);
// Some TF ops need a step container to be set to limit the lifetime of some
// resources (mostly TensorArray and Stack, used in while loop gradients in
// graph mode). Calling this on a context tells it to start a step.
TF_CAPI_EXPORT extern void TFE_ContextStartStep(TFE_Context* ctx);
// Ends a step. When there is no active step (that is, every started step has
// been ended) step containers will be cleared. Note: it is not safe to call
// TFE_ContextEndStep while ops which rely on the step container may be running.
TF_CAPI_EXPORT extern void TFE_ContextEndStep(TFE_Context* ctx);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#ifdef __cplusplus
// A workaround to ease conversion to and from numpy objects and
// TFE_TensorHandle's.
//
// TODO(ashankar): Figure out an alternative scheme that precludes the need for
// these API-boundary breaking methods.
namespace tensorflow {
class Tensor;
} // namespace tensorflow
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
TF_Status* status);
#endif
#endif // TENSORFLOW_C_EAGER_C_API_H_

View File

@ -0,0 +1,479 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace {
using ::tensorflow::string;
void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->at(task_index) =
tensorflow::strings::StrCat("localhost:", port);
}
void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle,
const std::vector<float>& expected_values) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t = TFE_TensorHandleResolve(handle, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
std::unique_ptr<float[]> actual_values(new float[expected_values.size()]);
EXPECT_EQ(sizeof(float) * expected_values.size(), TF_TensorByteSize(t));
memcpy(actual_values.get(), TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
for (int i = 0; i < expected_values.size(); i++) {
EXPECT_EQ(expected_values[i], actual_values[i])
<< "Mismatch in expected values at (zero-based) index " << i;
}
}
void CheckRemoteMatMulExecutesOK(TFE_Context* ctx,
const char* remote_device_name,
const char* local_device_name) {
TF_Status* status = TF_NewStatus();
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h0_task0);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* retval_task0 =
TFE_TensorHandleCopyToDevice(retvals[0], ctx, local_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CheckTFE_TensorHandleHasFloats(retval_task0, {7, 10, 15, 22});
TFE_DeleteTensorHandle(retval_task0);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
TF_DeleteStatus(status);
}
// Read the value of variable `var` and save it into `out_value`.
void ReadVariable(TFE_Context* ctx, TFE_TensorHandle* var,
TFE_TensorHandle** out_value) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
int num_retvals = 1;
TFE_Execute(op, out_value, &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(op);
TF_DeleteStatus(status);
}
void TestRemoteExecuteChangeServerDef(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
const char local_device_name[] =
"/job:localhost/replica:0/task:0/device:CPU:0";
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
// Update the server def with a new set of names (worker instead of
// localhost).
tensorflow::ServerDef updated_server_def = GetServerDef("worker", 2);
serialized = updated_server_def.SerializeAsString();
updated_server_def.set_task_index(1);
tensorflow::Status s = tensorflow::GrpcServer::Create(
updated_server_def, tensorflow::Env::Default(), &worker_server);
ASSERT_TRUE(s.ok()) << s.error_message();
ASSERT_TRUE(worker_server->Start().ok());
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Create a new tensor_handle.
TFE_TensorHandle* h0_task0_new = TestMatrixTensorHandle(ctx);
// Check that copying it to the old remote device (named localhost) fails.
TFE_TensorHandleCopyToDevice(h0_task0_new, ctx, remote_device_name, status);
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Copying and executing on the new remote device works.
const char new_remote_device_name[] =
"/job:worker/replica:0/task:1/device:CPU:0";
const char new_local_device_name[] =
"/job:worker/replica:0/task:0/device:CPU:0";
auto* h0_task1_new = TFE_TensorHandleCopyToDevice(
h0_task0_new, ctx, new_remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(h0_task0_new);
TFE_DeleteTensorHandle(h0_task1_new);
CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name,
new_local_device_name);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteChangeServerDef) {
TestRemoteExecuteChangeServerDef(false);
}
TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
TestRemoteExecuteChangeServerDef(true);
}
void TestRemoteExecuteUpdateServerDef(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char local_device_name[] =
"/job:localhost/replica:0/task:0/device:CPU:0";
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteUpdateServerDef) {
TestRemoteExecuteUpdateServerDef(false);
}
TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
TestRemoteExecuteUpdateServerDef(true);
}
void TestRemoteExecuteUpdateServerDefResourceAccess(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char dev0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
const char dev1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
EXPECT_NE(var_handle0, nullptr);
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
EXPECT_NE(var_handle1, nullptr);
TFE_TensorHandle* value_handle = nullptr;
ReadVariable(ctx, var_handle1, &value_handle);
CheckTFE_TensorHandleHasFloats(value_handle, {2});
TFE_DeleteTensorHandle(value_handle);
// Start a new worker to replace task:1
ReplaceTaskInServerDef(&server_def, 1);
server_def.set_task_index(1);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
// Update server def to replace the remote device with the device info on the
// new worker (different incarnation ID).
server_def.set_task_index(0);
string serialized_update = server_def.SerializeAsString();
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
serialized_update.size(), status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// The device of var_handle0 is local device which is the same before and
// after cluster update. Remove resource with valid device should succeed.
TFE_Op* op = TFE_NewOp(ctx, "DestroyResourceOp", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, var_handle0, status);
TFE_OpSetDevice(op, dev0_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
int num_retvals = 0;
TFE_Execute(op, nullptr, &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(op);
// The device of var_handle1 is remote device, which was replaced during
// cluster update. Removing resource with invalid device should fail
// gracefully (i.e., with error status) instead of crashing with segfaults.
op = TFE_NewOp(ctx, "DestroyResourceOp", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, var_handle1, status);
TFE_OpSetDevice(op, dev1_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
num_retvals = 0;
TFE_Execute(op, nullptr, &num_retvals, status);
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(op);
TFE_DeleteTensorHandle(var_handle0);
TFE_DeleteTensorHandle(var_handle1);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccess) {
TestRemoteExecuteUpdateServerDefResourceAccess(false);
}
TEST(CAPI, TestRemoteExecuteUpdateServerDefResourceAccessAsync) {
TestRemoteExecuteUpdateServerDefResourceAccess(true);
}
void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
// Fail fast on GetStatus requests so we can get errors instead of timeout
// when updating cluster with non-exsitent worker
tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const char local_device_name[] =
"/job:localhost/replica:0/task:0/device:CPU:0";
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
// Adding a non-existent remote worker to cluster def. This should cause the
// UpdateServerDef call to fail.
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->mutable_job(0);
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{2, tensorflow::strings::StrCat("localhost:", port)});
server_def.set_task_index(0);
string serialized_update = server_def.SerializeAsString();
TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
serialized_update.size(), status);
EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Even after the prevoiusly failed cluster update, another update and op
// execution should work fine as long as the provided server_def is valid.
TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
tensorflow::unsetenv("GRPC_FAIL_FAST");
}
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailures) {
TestRemoteExecuteUpdateServerDefWithFailures(false);
}
TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
TestRemoteExecuteUpdateServerDefWithFailures(true);
}
void TestConnectToCluster(bool keep_localhost_for_first_connect) {
// Fail fast on GetStatus requests so we can get errors instead of timeout
// when updating cluster with non-exsitent worker
tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
const string first_name =
keep_localhost_for_first_connect ? "localhost" : "abc";
tensorflow::ServerDef server_def = GetServerDef(first_name, 1);
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
const string dev0_name = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
EXPECT_NE(var_handle0, nullptr);
tensorflow::Status status2;
EXPECT_EQ(tensorflow::unwrap(var_handle0)->DeviceName(&status2), dev0_name);
// Rename local device
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const string dev1_name =
absl::StrCat("/job:", first_name, "/replica:0/task:0/device:CPU:0");
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
EXPECT_NE(var_handle1, nullptr);
EXPECT_EQ(tensorflow::unwrap(var_handle1)->DeviceName(&status2), dev1_name);
// Another renaming of local device
const string second_name = "def";
server_def.set_job_name(second_name);
server_def.mutable_cluster()->mutable_job(0)->set_name(second_name);
(*server_def.mutable_cluster()->mutable_job(0)->mutable_tasks())[0] =
absl::StrCat(second_name, ":",
tensorflow::testing::PickUnusedPortOrDie());
serialized = server_def.SerializeAsString();
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const string dev2_name = "/job:def/replica:0/task:0/device:CPU:0";
TFE_TensorHandle* var_handle2 = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle2, nullptr);
EXPECT_EQ(tensorflow::unwrap(var_handle2)->DeviceName(&status2), dev2_name);
TFE_DeleteTensorHandle(var_handle0);
TFE_DeleteTensorHandle(var_handle1);
TFE_DeleteTensorHandle(var_handle2);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
tensorflow::unsetenv("GRPC_FAIL_FAST");
}
TEST(CAPI, ConnectToClusterLocalhostFirst) { TestConnectToCluster(false); }
TEST(CAPI, ConnectToClusterRenameFirst) { TestConnectToCluster(true); }
} // namespace

View File

@ -0,0 +1,87 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/platform/status.h"
using tensorflow::string;
namespace {
std::vector<tensorflow::int64> TensorShapeAsVector(
const tensorflow::TensorHandle& handle, tensorflow::Status* status) {
std::vector<tensorflow::int64> shape;
int rank = -1;
*status = handle.NumDims(&rank);
if (!status->ok()) {
return shape;
}
shape.reserve(rank);
for (int i = 0; i < rank; ++i) {
tensorflow::int64 dim;
*status = handle.Dim(i, &dim);
if (!status->ok()) {
return shape;
}
shape.push_back(dim);
}
return shape;
}
} // namespace
extern "C" {
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* h, TF_Status* status) {
tensorflow::TensorHandle* handle =
TensorHandleFromInterface(tensorflow::unwrap(h));
const tensorflow::Tensor* tensor;
status->status = handle->Tensor(&tensor);
if (!status->status.ok()) {
return nullptr;
}
std::vector<tensorflow::int64> dev_dims =
TensorShapeAsVector(*handle, &status->status);
if (!status->status.ok()) {
return nullptr;
}
return new TFE_TensorDebugInfo(dev_dims);
}
TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo(
TFE_TensorDebugInfo* debug_info) {
delete debug_info;
}
TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims(
TFE_TensorDebugInfo* debug_info) {
return debug_info->dev_dims.size();
}
TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim(
TFE_TensorDebugInfo* debug_info, int dim_index) {
return debug_info->dev_dims[dim_index];
}
} // extern "C"

View File

@ -0,0 +1,62 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api.h"
#include <string.h>
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
TEST(CApiDebug, ScalarCPU) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* h = TestScalarTensorHandle(ctx, 1.0f);
TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(0, TFE_TensorDebugInfoOnDeviceNumDims(debug_info));
TFE_DeleteTensorDebugInfo(debug_info);
TFE_DeleteTensorHandle(h);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CApiDebug, 2DCPU) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* h = TestMatrixTensorHandle3X2(ctx);
TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(2, TFE_TensorDebugInfoOnDeviceNumDims(debug_info));
// Shape is the same for CPU tensors.
EXPECT_EQ(3, TFE_TensorDebugInfoOnDeviceDim(debug_info, 0));
EXPECT_EQ(2, TFE_TensorDebugInfoOnDeviceDim(debug_info, 1));
TFE_DeleteTensorDebugInfo(debug_info);
TFE_DeleteTensorHandle(h);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}

View File

@ -0,0 +1,624 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <regex> // NOLINT
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace {
using ::tensorflow::string;
// Add the values of three variables on three different tasks.
string AddVariablesFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'AddVariablesFunction'"
" input_arg {"
" name: 'var'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'sum'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read1'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read2'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add1'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read1:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add2'"
" op: 'Add'"
" input: 'add1:z:0'"
" input: 'read2:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'sum'"
" value: 'add2:z:0'"
" }",
&def));
return def.SerializeAsString();
}
void TestFunctionWithPackedInput(const bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
// Create one variable per task.
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task1_name);
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task2_name);
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task0_name);
// Add a sync point in order to make sure that variables have been initialized
// before the function execution starts.
TFE_ContextAsyncWait(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Pack 3 variable handles into one TFE_TensorHandle.
// When remote is false, function device is placed on task0. Handle types are
// REMOTE, REMOTE, LOCAL on task0. When remote is true, function device is
// placed on task1, Handle types are LOCAL, REMOTE, LOCAL on task1.
int num_replicas = 3;
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
TFE_TensorHandle* packed_handle =
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
const string composite_device_name =
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
composite_device_name);
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
composite_device_name);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// Register and run a function which returns the sum of 3 variables.
const string function_def = AddVariablesFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, packed_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
if (remote) {
TFE_OpSetDevice(func, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(packed_handle);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(sum, 6.0);
TFE_DeleteTensorHandle(h0);
TFE_DeleteTensorHandle(h1);
TFE_DeleteTensorHandle(h2);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, TestLocalFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/false);
}
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/true);
}
string VariableAddFunctionSignature() {
return " signature {"
" name: 'VariableAddFunction'"
" input_arg {"
" name: 'var0'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'var0_value'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read0:value:0'"
" device: '/job:localhost/task:1/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'identity'"
" op: 'Identity'"
" input: 'add:z:0'"
" device: '/job:localhost/task:0/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'var0_value'"
" value: 'identity:output:0'"
" }";
}
string VariableAddFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
VariableAddFunctionSignature(), &def));
return def.SerializeAsString();
}
// A graph optimization pass that would fail when triggered for more than once.
class GraphErrorInjectionPass : public tensorflow::GraphOptimizationPass {
public:
static bool enabled_;
GraphErrorInjectionPass() {}
tensorflow::Status Run(
const tensorflow::GraphOptimizationPassOptions& options) override {
if (!enabled_) {
return tensorflow::Status::OK();
}
if (first_call_) {
first_call_ = false;
return tensorflow::Status::OK();
}
return tensorflow::errors::Internal("Graph pass runs for more than once!");
}
private:
bool first_call_ = true;
};
// After the graph pass is registered, it takes effect globally and can affect
// other test cases. Define a static variable to switch it on and off.
bool GraphErrorInjectionPass::enabled_ = false;
// Test to ensure that a registered graph optimization pass is only executed
// once (i.e., on the main function side) in running distributed functions.
// This test creates a cluster with two workers, create a variable on the
// second worker, and run a distributed function (VariableAddFunction) whose ops
// span the local and remote workers. If the graph optimization pass is executed
// on both the main function side and the component function side, an error will
// be thrown in the registered graph optimization pass.
TEST(CAPI, DistributedFunctionGraphPassOnlyOnce) {
// Register graph pass that will raise error if called more than once.
tensorflow::optimization_registration::OptimizationPassRegistration
register_test_pass(tensorflow::OptimizationPassRegistry::PRE_PLACEMENT, 0,
std::make_unique<GraphErrorInjectionPass>(),
"error_injector");
GraphErrorInjectionPass::enabled_ = true;
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle, nullptr);
TFE_ContextAsyncWait(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const string function_def = VariableAddFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, var_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
ASSERT_EQ(sum, 4.0);
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(var_handle);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
// Disable the test graph pass so it does not affect other test cases.
GraphErrorInjectionPass::enabled_ = false;
}
string VariableAddFunctionWithGraphError() {
string signature = VariableAddFunctionSignature();
// Replace the node 'read0' with 'read0_maybe_with_graph_error', so that the
// error injecting pass can identify and introduce graph pass errors.
signature = std::regex_replace(signature, std::regex("read0"),
"read0_maybe_with_graph_error");
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(signature, &def));
return def.SerializeAsString();
}
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
public:
FunctionErrorInjectionPass(string error_node, string error_device)
: error_node_(error_node), error_device_(error_device) {}
tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
const tensorflow::ConfigProto& config_proto,
std::unique_ptr<tensorflow::Graph>* graph,
tensorflow::FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) override {
// Inject failure to function instantiation if finding a node that contains
// the given node name (error_node_) and requested device (error_device_).
for (const auto node : graph->get()->nodes()) {
if (node->name().find(error_node_) != string::npos &&
node->requested_device() == error_device_) {
return tensorflow::errors::Internal("Injected graph pass error.");
}
}
return tensorflow::Status::OK();
}
private:
const string error_node_;
const string error_device_;
};
void TestDistributedFunctionCancellation(bool inject_error) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
if (inject_error) {
// Inject a function optimization pass failure when it sees the
// 'read0_maybe_with_graph_error' op having a requested device `dev2_name`.
// During execution:
// * task:0 processes main function `VariableAddFunctionWithGraphError`
// and places the 'read0_maybe_with_graph_error' op on task:2
// * task:0 partitions the main function with a subgraph containing
// 'read0_maybe_with_graph_error' sent to task:2
// * task:2 graph pass reports an error when it sees
// 'read0_maybe_with_graph_error' with dev2_name
tensorflow::function_optimization_registration::
FunctionOptimizationPassRegistration register_test_pass(
std::make_unique<FunctionErrorInjectionPass>(
"read0_maybe_with_graph_error", dev2_name));
}
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle, nullptr);
TFE_ContextAsyncWait(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const string function_def = inject_error ? VariableAddFunctionWithGraphError()
: VariableAddFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, var_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
if (inject_error) {
ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
} else {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
ASSERT_EQ(sum, 4.0);
}
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(var_handle);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, DistributedFunctionNoError) {
TestDistributedFunctionCancellation(false);
}
// TODO(b/170399182): Update test once an alternative to using the function
// optimization hook is in place.
TEST(CAPI, DISABLED_DistributedFunctionCancelledOnError) {
TestDistributedFunctionCancellation(true);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Use large matrices so that RPCs don't return before we get a chance
// to call TFE_DeleteContext.
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* h1_task1 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h0_task1);
TFE_DeleteTensorHandle(h1_task1);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
TFE_DeleteContext(ctx);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
}
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
}
} // namespace

View File

@ -0,0 +1,656 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_experimental.h"
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/strcat.h"
using tensorflow::string;
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
tensorflow::ImmediateExecutionOperation* op =
tensorflow::unwrap(op_to_reset);
op->Clear();
status->status = op->Reset(op_or_function_name, raw_device_name);
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"op_to_reset should not be nullptr");
}
}
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
}
void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
}
uint64_t TFE_GetContextId(TFE_Context* ctx) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
return context->GetContextId();
}
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
int64_t value) {
cell->cell.IncrementBy(value);
}
int64_t TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell* cell) {
return cell->cell.value();
}
TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(const char* name,
TF_Status* status,
const char* description) {
auto* result = new TFE_MonitoringCounter0({name, description});
Set_TF_Status_from_Status(status, result->counter->GetStatus());
if (!result->counter->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0* counter) {
delete counter;
}
TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
TFE_MonitoringCounter0* counter) {
return static_cast<TFE_MonitoringCounterCell*>(
static_cast<void*>(counter->counter->GetCell()));
}
TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(const char* name,
TF_Status* status,
const char* description,
const char* label1) {
auto* result = new TFE_MonitoringCounter1({name, description, label1});
Set_TF_Status_from_Status(status, result->counter->GetStatus());
if (!result->counter->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1* counter) {
delete counter;
}
TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
TFE_MonitoringCounter1* counter, const char* label1) {
return static_cast<TFE_MonitoringCounterCell*>(
static_cast<void*>(counter->counter->GetCell(label1)));
}
TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(const char* name,
TF_Status* status,
const char* description,
const char* label1,
const char* label2) {
auto* result =
new TFE_MonitoringCounter2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->counter->GetStatus());
if (!result->counter->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2* counter) {
delete counter;
}
TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
TFE_MonitoringCounter2* counter, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringCounterCell*>(
static_cast<void*>(counter->counter->GetCell(label1, label2)));
}
void TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell* cell,
int64_t value) {
cell->cell.Set(value);
}
int64_t TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell* cell) {
return cell->cell.value();
}
TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(const char* name,
TF_Status* status,
const char* description) {
auto* result = new TFE_MonitoringIntGauge0({name, description});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0* gauge) {
delete gauge;
}
TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge0(
TFE_MonitoringIntGauge0* gauge) {
return static_cast<TFE_MonitoringIntGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell()));
}
TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(const char* name,
TF_Status* status,
const char* description,
const char* label1) {
auto* result = new TFE_MonitoringIntGauge1({name, description, label1});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1* gauge) {
delete gauge;
}
TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge1(
TFE_MonitoringIntGauge1* gauge, const char* label1) {
return static_cast<TFE_MonitoringIntGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1)));
}
TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(const char* name,
TF_Status* status,
const char* description,
const char* label1,
const char* label2) {
auto* result =
new TFE_MonitoringIntGauge2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2* gauge) {
delete gauge;
}
TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge2(
TFE_MonitoringIntGauge2* gauge, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringIntGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
}
void TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell* cell,
const char* value) {
cell->cell.Set({value});
}
const void TFE_MonitoringStringGaugeCellValue(
TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf) {
tensorflow::string value = cell->cell.value();
void* data = tensorflow::port::Malloc(value.length());
value.copy(static_cast<char*>(data), value.length(), 0);
buf->data = data;
buf->length = value.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}
TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
const char* name, TF_Status* status, const char* description) {
auto* result = new TFE_MonitoringStringGauge0({name, description});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0* gauge) {
delete gauge;
}
TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge0(
TFE_MonitoringStringGauge0* gauge) {
return static_cast<TFE_MonitoringStringGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell()));
}
TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
const char* name, TF_Status* status, const char* description,
const char* label1) {
auto* result = new TFE_MonitoringStringGauge1({name, description, label1});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1* gauge) {
delete gauge;
}
TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge1(
TFE_MonitoringStringGauge1* gauge, const char* label1) {
return static_cast<TFE_MonitoringStringGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1)));
}
TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
const char* name, TF_Status* status, const char* description,
const char* label1, const char* label2) {
auto* result =
new TFE_MonitoringStringGauge2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2* gauge) {
delete gauge;
}
TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge2(
TFE_MonitoringStringGauge2* gauge, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringStringGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
}
void TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell* cell,
bool value) {
cell->cell.Set(value);
}
bool TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell* cell) {
return cell->cell.value();
}
TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(const char* name,
TF_Status* status,
const char* description) {
auto* result = new TFE_MonitoringBoolGauge0({name, description});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0* gauge) {
delete gauge;
}
TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge0(
TFE_MonitoringBoolGauge0* gauge) {
return static_cast<TFE_MonitoringBoolGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell()));
}
TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(const char* name,
TF_Status* status,
const char* description,
const char* label1) {
auto* result = new TFE_MonitoringBoolGauge1({name, description, label1});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1* gauge) {
delete gauge;
}
TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge1(
TFE_MonitoringBoolGauge1* gauge, const char* label1) {
return static_cast<TFE_MonitoringBoolGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1)));
}
TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(const char* name,
TF_Status* status,
const char* description,
const char* label1,
const char* label2) {
auto* result =
new TFE_MonitoringBoolGauge2({name, description, label1, label2});
Set_TF_Status_from_Status(status, result->gauge->GetStatus());
if (!result->gauge->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2* gauge) {
delete gauge;
}
TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge2(
TFE_MonitoringBoolGauge2* gauge, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringBoolGaugeCell*>(
static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
}
void TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell* cell,
double value) {
cell->cell.Add(value);
}
void TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell* cell,
TF_Buffer* buf) {
string content;
cell->cell.value().SerializeToString(&content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
buf->data = data;
buf->length = content.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}
TFE_MonitoringBuckets* TFE_MonitoringNewExponentialBuckets(double scale,
double growth_factor,
int bucket_count) {
return new TFE_MonitoringBuckets([scale, growth_factor, bucket_count]() {
return tensorflow::monitoring::Buckets::Exponential(scale, growth_factor,
bucket_count);
});
}
void TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets* buckets) {
delete buckets;
}
TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
const char* description) {
auto* result = new TFE_MonitoringSampler0(
{name, buckets->create_buckets(), description});
Set_TF_Status_from_Status(status, result->sampler->GetStatus());
if (!result->sampler->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0* sampler) {
delete sampler;
}
TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
TFE_MonitoringSampler0* sampler) {
return static_cast<TFE_MonitoringSamplerCell*>(
static_cast<void*>(sampler->sampler->GetCell()));
}
TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
const char* description, const char* label1) {
auto* result = new TFE_MonitoringSampler1(
{name, buckets->create_buckets(), description, label1});
Set_TF_Status_from_Status(status, result->sampler->GetStatus());
if (!result->sampler->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1* sampler) {
delete sampler;
}
TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
TFE_MonitoringSampler1* sampler, const char* label1) {
return static_cast<TFE_MonitoringSamplerCell*>(
static_cast<void*>(sampler->sampler->GetCell(label1)));
}
TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
const char* description, const char* label1, const char* label2) {
auto* result = new TFE_MonitoringSampler2(
{name, buckets->create_buckets(), description, label1, label2});
Set_TF_Status_from_Status(status, result->sampler->GetStatus());
if (!result->sampler->GetStatus().ok()) {
delete result;
return nullptr;
}
return result;
}
void TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2* sampler) {
delete sampler;
}
TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
TFE_MonitoringSampler2* sampler, const char* label1, const char* label2) {
return static_cast<TFE_MonitoringSamplerCell*>(
static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
}
void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
bool lazy_copy) {
options->lazy_remote_inputs_copy = lazy_copy;
}
void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) {
options->use_tfrt = use_tfrt;
}
TFE_CancellationManager* TFE_NewCancellationManager() {
return new TFE_CancellationManager;
}
void TFE_CancellationManagerStartCancel(
TFE_CancellationManager* cancellation_manager) {
cancellation_manager->cancellation_manager.StartCancel();
}
bool TFE_CancellationManagerIsCancelled(
TFE_CancellationManager* cancellation_manager) {
return cancellation_manager->cancellation_manager.IsCancelled();
}
void TFE_DeleteCancellationManager(
TFE_CancellationManager* cancellation_manager) {
delete cancellation_manager;
}
void TFE_OpSetCancellationManager(TFE_Op* op,
TFE_CancellationManager* cancellation_manager,
TF_Status* status) {
tensorflow::EagerOperation* operation =
tensorflow::OperationFromInterface(tensorflow::unwrap(op));
operation->SetCancellationManager(
&cancellation_manager->cancellation_manager);
status->status = tensorflow::Status::OK();
}
TFE_Executor* TFE_NewExecutor(bool is_async) {
return new TFE_Executor(is_async);
}
void TFE_DeleteExecutor(TFE_Executor* executor) { delete executor; }
bool TFE_ExecutorIsAsync(TFE_Executor* executor) {
return executor->executor()->Async();
}
void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor* executor,
TF_Status* status) {
status->status = executor->executor()->WaitForAllPendingNodes();
}
void TFE_ExecutorClearError(TFE_Executor* executor) {
executor->executor()->ClearError();
}
void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
}
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
}
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
tensorflow::unwrap(ctx)->HostCPUParsedName());
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
buf->data = data;
buf->length = str.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) {
auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound(
"Unable to find FunctionDef with name: ", function_name);
return;
}
string str = function_def->SerializeAsString();
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
buf->data = data;
buf->length = str.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
status->status = tensorflow::Status::OK();
}
TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
const int64_t* dims, int num_dims,
TF_Status* status) {
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
if (ctx == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid Context");
return nullptr;
}
tensorflow::AbstractTensorInterface* t =
tensorflow::unwrap(ctx)->CreateTensor(
static_cast<tensorflow::DataType>(dtype), dimvec);
if (t == nullptr) {
status->status =
tensorflow::errors::InvalidArgument("Unsupported dtype: ", dtype);
return nullptr;
}
return new TF_Tensor{t};
}
TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
TF_Status* status) {
return tensorflow::wrap(
tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
}
TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
TFE_TensorHandle** handles,
int* num_handles,
TF_Status* status) {
std::vector<tensorflow::TensorHandle*> tensor_handles;
tensor_handles.reserve(*num_handles);
for (int i = 0; i < *num_handles; ++i) {
tensor_handles.push_back(
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(handles[i])));
}
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::TensorHandle* handle = nullptr;
status->status = tensorflow::TensorHandle::CreatePackedHandle(
std::move(tensor_handles), context, &handle);
return tensorflow::wrap(handle);
}
void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) {
tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
}
void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
TF_Status* status) {
tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
}
const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
return tensorflow::unwrap(h)->DeviceType(&status->status);
}
int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return -1;
}
return tensorflow::unwrap(h)->DeviceId(&status->status);
}

View File

@ -0,0 +1,568 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_
#define TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This
// is for performance optimization by reusing an exiting unused op rather than
// creating a new op every time. If `raw_device_name` is `NULL` or empty, it
// does not set the device name. If it's not `NULL`, then it attempts to parse
// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
// than separately calling it because if the existing op has the same
// `raw_device_name`, it skips parsing and just leave as it is.
TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
const char* op_or_function_name,
const char* raw_device_name,
TF_Status* status);
// Enables only graph collection in RunMetadata on the functions executed from
// this context.
TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
// Disables only graph collection in RunMetadata on the functions executed from
// this context.
TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx);
// TODO(fishx): Move these monitoring APIs into a separate file.
// -----------------------------------------------------------------------------
// Monitoring Counter APIs.
// These APIs de-templated monitoring Counter for swig.
typedef struct TFE_MonitoringCounterCell TFE_MonitoringCounterCell;
// Atomically increments the value of the cell. The value must be non-negative.
TF_CAPI_EXPORT extern void TFE_MonitoringCounterCellIncrementBy(
TFE_MonitoringCounterCell* cell, int64_t value);
// Retrieves the current value of the cell.
TF_CAPI_EXPORT extern int64_t TFE_MonitoringCounterCellValue(
TFE_MonitoringCounterCell* cell);
// APIs for Counter without label.
typedef struct TFE_MonitoringCounter0 TFE_MonitoringCounter0;
// Returns a new Counter metric object. The caller should manage lifetime of
// the object. Using duplicate metric name will crash the program with fatal
// error.
TF_CAPI_EXPORT extern TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(
const char* name, TF_Status* status, const char* description);
// Deletes the Counter object.
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter0(
TFE_MonitoringCounter0* counter);
// Retrieves the cell from the Counter object. The Counter object will manage
// lifetime of the cell.
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
TFE_MonitoringCounter0* counter);
// APIs for Counter with 1 label.
typedef struct TFE_MonitoringCounter1 TFE_MonitoringCounter1;
TF_CAPI_EXPORT extern TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(
const char* name, TF_Status* status, const char* description,
const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter1(
TFE_MonitoringCounter1* counter);
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
TFE_MonitoringCounter1* counter, const char* label1);
// APIs for Counter with 2 labels.
typedef struct TFE_MonitoringCounter2 TFE_MonitoringCounter2;
TF_CAPI_EXPORT extern TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(
const char* name, TF_Status* status, const char* description,
const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter2(
TFE_MonitoringCounter2* counter);
TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
TFE_MonitoringCounter2* counter, const char* label1, const char* label2);
// -----------------------------------------------------------------------------
// Monitoring Gauge APIs.
// These APIs de-templated monitoring Gauge for swig.
typedef struct TFE_MonitoringIntGaugeCell TFE_MonitoringIntGaugeCell;
// Atomically set the value of the cell.
TF_CAPI_EXPORT extern void TFE_MonitoringIntGaugeCellSet(
TFE_MonitoringIntGaugeCell* cell, int64_t value);
// Retrieves the current value of the cell.
TF_CAPI_EXPORT extern int64_t TFE_MonitoringIntGaugeCellValue(
TFE_MonitoringIntGaugeCell* cell);
// APIs for Int Gauge without label.
typedef struct TFE_MonitoringIntGauge0 TFE_MonitoringIntGauge0;
TF_CAPI_EXPORT extern TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(
const char* name, TF_Status* out_status, const char* description);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge0(
TFE_MonitoringIntGauge0* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
TFE_MonitoringGetCellIntGauge0(TFE_MonitoringIntGauge0* gauge);
// APIs for Int Gauge with 1 label.
typedef struct TFE_MonitoringIntGauge1 TFE_MonitoringIntGauge1;
TF_CAPI_EXPORT extern TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(
const char* name, TF_Status* out_status, const char* description,
const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge1(
TFE_MonitoringIntGauge1* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
TFE_MonitoringGetCellIntGauge1(TFE_MonitoringIntGauge1* gauge,
const char* label1);
// APIs for Int Gauge with 2 label.
typedef struct TFE_MonitoringIntGauge2 TFE_MonitoringIntGauge2;
TF_CAPI_EXPORT extern TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(
const char* name, TF_Status* out_status, const char* description,
const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge2(
TFE_MonitoringIntGauge2* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2* gauge,
const char* label1, const char* label2);
typedef struct TFE_MonitoringStringGaugeCell TFE_MonitoringStringGaugeCell;
TF_CAPI_EXPORT extern void TFE_MonitoringStringGaugeCellSet(
TFE_MonitoringStringGaugeCell* cell, const char* value);
// Retrieves the string value and saves it in buffer.
TF_CAPI_EXPORT extern const void TFE_MonitoringStringGaugeCellValue(
TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf);
// APIs for String Gauge without label.
typedef struct TFE_MonitoringStringGauge0 TFE_MonitoringStringGauge0;
TF_CAPI_EXPORT extern TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
const char* name, TF_Status* out_status, const char* description);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge0(
TFE_MonitoringStringGauge0* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
TFE_MonitoringGetCellStringGauge0(TFE_MonitoringStringGauge0* gauge);
// APIs for String Gauge with 1 label.
typedef struct TFE_MonitoringStringGauge1 TFE_MonitoringStringGauge1;
TF_CAPI_EXPORT extern TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
const char* name, TF_Status* out_status, const char* description,
const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge1(
TFE_MonitoringStringGauge1* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
TFE_MonitoringGetCellStringGauge1(TFE_MonitoringStringGauge1* gauge,
const char* label1);
// APIs for String Gauge with 2 label.
typedef struct TFE_MonitoringStringGauge2 TFE_MonitoringStringGauge2;
TF_CAPI_EXPORT extern TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
const char* name, TF_Status* out_status, const char* description,
const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge2(
TFE_MonitoringStringGauge2* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
TFE_MonitoringGetCellStringGauge2(TFE_MonitoringStringGauge2* gauge,
const char* label1, const char* label2);
typedef struct TFE_MonitoringBoolGaugeCell TFE_MonitoringBoolGaugeCell;
TF_CAPI_EXPORT extern void TFE_MonitoringBoolGaugeCellSet(
TFE_MonitoringBoolGaugeCell* cell, bool value);
TF_CAPI_EXPORT extern bool TFE_MonitoringBoolGaugeCellValue(
TFE_MonitoringBoolGaugeCell* cell);
// APIs for Bool Gauge without label.
typedef struct TFE_MonitoringBoolGauge0 TFE_MonitoringBoolGauge0;
TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(
const char* name, TF_Status* out_status, const char* description);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge0(
TFE_MonitoringBoolGauge0* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
TFE_MonitoringGetCellBoolGauge0(TFE_MonitoringBoolGauge0* gauge);
// APIs for Bool Gauge with 1 label.
typedef struct TFE_MonitoringBoolGauge1 TFE_MonitoringBoolGauge1;
TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(
const char* name, TF_Status* out_status, const char* description,
const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge1(
TFE_MonitoringBoolGauge1* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
TFE_MonitoringGetCellBoolGauge1(TFE_MonitoringBoolGauge1* gauge,
const char* label1);
// APIs for Bool Gauge with 2 label.
typedef struct TFE_MonitoringBoolGauge2 TFE_MonitoringBoolGauge2;
TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(
const char* name, TF_Status* out_status, const char* description,
const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge2(
TFE_MonitoringBoolGauge2* gauge);
TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
TFE_MonitoringGetCellBoolGauge2(TFE_MonitoringBoolGauge2* gauge,
const char* label1, const char* label2);
// -----------------------------------------------------------------------------
// Monitoring Sampler APIs.
// These APIs de-templated monitoring Sampler for swig.
typedef struct TFE_MonitoringSamplerCell TFE_MonitoringSamplerCell;
// Atomically add the value of the cell.
TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellAdd(
TFE_MonitoringSamplerCell* cell, double value);
// Retrieves the current value of the cell. The return value is a HistogramProto
// saved in buffer.
TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellValue(
TFE_MonitoringSamplerCell* cell, TF_Buffer* buf);
// APIs for sampler buckets
typedef struct TFE_MonitoringBuckets TFE_MonitoringBuckets;
TF_CAPI_EXPORT extern TFE_MonitoringBuckets*
TFE_MonitoringNewExponentialBuckets(double scale, double growth_factor,
int bucket_count);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBuckets(
TFE_MonitoringBuckets* buckets);
// APIs for Sampler without label.
typedef struct TFE_MonitoringSampler0 TFE_MonitoringSampler0;
TF_CAPI_EXPORT extern TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
const char* description);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler0(
TFE_MonitoringSampler0* sampler);
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
TFE_MonitoringSampler0* sampler);
// APIs for Sampler with 1 label.
typedef struct TFE_MonitoringSampler1 TFE_MonitoringSampler1;
TF_CAPI_EXPORT extern TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
const char* description, const char* label1);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler1(
TFE_MonitoringSampler1* sampler);
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
TFE_MonitoringSampler1* sampler, const char* label1);
// APIs for Sampler with 2 label.
typedef struct TFE_MonitoringSampler2 TFE_MonitoringSampler2;
TF_CAPI_EXPORT extern TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
const char* description, const char* label1, const char* label2);
TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2(
TFE_MonitoringSampler2* sampler);
TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
TFE_MonitoringSampler2* sampler, const char* label1, const char* label2);
// Sets whether to copy the remote inputs of a function lazily.
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
TFE_ContextOptions*, bool lazy_copy);
// Sets whether to use TFRT
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*,
bool use_tfrt);
// Returns the context_id from the EagerContext which is used by the
// EagerService to maintain consistency between client and worker. The
// context_id is initialized with a dummy value and is later set when the worker
// is initialized (either locally or remotely). The context_id can change during
// the process lifetime although this should cause the worker to be
// reinitialized (e.g. cleared caches) as well.
TF_CAPI_EXPORT extern uint64_t TFE_GetContextId(TFE_Context* ctx);
// -----------------------------------------------------------------------------
// Cancellation APIs.
typedef struct TFE_CancellationManager TFE_CancellationManager;
TF_CAPI_EXPORT extern TFE_CancellationManager* TFE_NewCancellationManager();
TF_CAPI_EXPORT extern bool TFE_CancellationManagerIsCancelled(
TFE_CancellationManager*);
TF_CAPI_EXPORT extern void TFE_CancellationManagerStartCancel(
TFE_CancellationManager*);
TF_CAPI_EXPORT extern void TFE_DeleteCancellationManager(
TFE_CancellationManager*);
// Associates the given `cancellation_manager` with `op`, so that invoking
// `TFE_CancellationManagerStartCancel(cancellation_manager)` will cancel the
// execution of `op`.
typedef struct TFE_CancellationManager TFE_CancellationManager;
TF_CAPI_EXPORT extern void TFE_OpSetCancellationManager(
TFE_Op* op, TFE_CancellationManager* cancellation_manager,
TF_Status* status);
// -----------------------------------------------------------------------------
// Eager Executor APIs.
typedef struct TFE_Executor TFE_Executor;
// Creates a new eager Executor. Nodes in one executor are guaranteed to be
// executed in sequence. Assigning nodes to different executors allows executing
// nodes in parallel.
TF_CAPI_EXPORT extern TFE_Executor* TFE_NewExecutor(bool is_async);
// Deletes the eager Executor without waiting for enqueued nodes. Please call
// TFE_ExecutorWaitForAllPendingNodes before calling this API if you want to
// make sure all nodes are finished.
TF_CAPI_EXPORT extern void TFE_DeleteExecutor(TFE_Executor*);
// Returns true if the executor is in async mode.
TF_CAPI_EXPORT extern bool TFE_ExecutorIsAsync(TFE_Executor*);
// Causes the calling thread to block till all ops dispatched in this executor
// have been executed. Note that "execution" here refers to kernel execution /
// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
// that lower level device queues (like GPU streams) have been flushed.
//
// This call may not block for execution of ops enqueued concurrently with this
// call.
TF_CAPI_EXPORT extern void TFE_ExecutorWaitForAllPendingNodes(
TFE_Executor*, TF_Status* status);
// When an error happens, any pending operations are discarded and newly issued
// ops return an error. This call clears the error state and re-enables
// execution of newly issued ops.
//
// Note that outputs of discarded ops remain in a corrupt state and should not
// be used for future calls.
// TODO(agarwal): mark the affected handles and raise errors if they are used.
TF_CAPI_EXPORT extern void TFE_ExecutorClearError(TFE_Executor*);
// Sets a custom Executor for current thread. All nodes created by this thread
// will be added to this Executor. It will override current executor.
TF_CAPI_EXPORT extern void TFE_ContextSetExecutorForThread(TFE_Context*,
TFE_Executor*);
// Returns the Executor for current thread.
TF_CAPI_EXPORT extern TFE_Executor* TFE_ContextGetExecutorForThread(
TFE_Context*);
// -----------------------------------------------------------------------------
// Dynamic cluster API.
// Update an existing context with a new set of servers defined in a ServerDef
// proto. Servers can be added to and removed from the list of remote workers
// in the context. New set of servers identified by the ServerDef must be up
// when the context is updated.
//
// This API is for experimental usage and may be subject to change.
TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
int keep_alive_secs,
const void* proto,
size_t proto_len,
TF_Status* status);
// Checks whether a remote worker is alive or not. This will return true even if
// the context doesn't exist on the remote worker.
TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
const char* worker_name,
TF_Status* status);
// Sync pending nodes in local executors (including the context default executor
// and thread executors) and streaming requests to remote executors, and get the
// combined status.
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
TF_Status* status);
// This function will block till the operation that produces `h` has
// completed. This is only valid on local TFE_TensorHandles. The pointer
// returned will be on the device in which the TFE_TensorHandle resides (so e.g.
// for a GPU tensor this will return a pointer to GPU memory). The pointer is
// only guaranteed to be valid until TFE_DeleteTensorHandle is called on this
// TensorHandle. Only supports POD data types.
TF_CAPI_EXPORT extern void* TFE_TensorHandleDevicePointer(TFE_TensorHandle*,
TF_Status*);
// This function will block till the operation that produces `h` has
// completed. This is only valid on local TFE_TensorHandles. Returns the size in
// bytes of the memory pointed to by the device pointer returned above.
TF_CAPI_EXPORT extern size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle*,
TF_Status*);
// Creates a new TensorHandle from memory residing in device_name. Takes
// ownership of the memory, and will call deleter to release it after TF
// no longer needs it or in case of error.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
TFE_Context* ctx, const char* device_name, TF_DataType, const int64_t* dims,
int num_dims, void* data, size_t len,
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status);
// Retrieves the address space (i.e. job, replia, task) of the local host and
// saves it in the buffer.
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
TF_Buffer* buf);
// APIs for generically dealing with op attributes (e.g. when forwarding them
// through custom device implementations).
//
// TODO(allenl): Currently these are black boxes, but we should have some way to
// inspect values. This would let people e.g. copy over most attributes and then
// modify some based on their values.
// A reference to an op's name -> attribute mapping
typedef struct TFE_OpAttrs TFE_OpAttrs;
// Fetch a reference to `op`'s attributes. The returned reference is only valid
// while `op` is alive.
TF_CAPI_EXPORT extern const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op);
// Add attributes in `attrs` to `op`.
//
// Does not overwrite or update existing attributes, but adds new ones.
TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
// Serialize `attrs` as a tensorflow::NameAttrList protocol buffer (into `buf`),
// containing the op name and a map of its attributes.
TF_CAPI_EXPORT extern void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs,
TF_Buffer* buf,
TF_Status* status);
// Set an op's attribute from a serialized AttrValue protocol buffer.
//
// Analogous to TF_SetAttrValueProto for building graph operations.
TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
const char* attr_name,
const void* proto,
size_t proto_len,
TF_Status* status);
// TODO(b/166642410): It would be nice, for custom devices and for other users,
// to have a non-string representation of devices (TF_Device) extracted from
// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc.
#define TFE_CUSTOM_DEVICE_VERSION 3
// Struct to be filled in
typedef struct TFE_CustomDevice {
int version = TFE_CUSTOM_DEVICE_VERSION;
// Method to copy a tensor to the custom device.
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status,
void* device_info) = nullptr;
// Method to copy a tensor from the custom device to a target device.
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info);
// Method to execute an operation.
//
// Arguments provide enough information to reconstruct the original `TFE_Op`,
// or construct a transformed version, by inspecting the passed `op`.
//
// TFE_OpGetDevice(op) records the original placement of the operation. It may
// be an empty string if no device was explicitly requested, but will
// otherwise be the name of this custom device. Ops are placed onto a custom
// device if any of their inputs are on that custom device, but custom devices
// are free to set a bad status in order to require explicit placement.
void (*execute)(const TFE_Op* op, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
// Method to delete a device.
void (*delete_device)(void* device_info);
} TFE_CustomDevice;
// Registers a custom device for use with eager execution.
//
// Eager operations may be placed on this device, e.g. `with
// tf.device("CUSTOM"):` from Python if `device_name` for this call is
// "/job:localhost/replica:0/task:0/device:CUSTOM:0".
//
// The custom device defines copy operations for moving TensorHandles on and
// off, and an execution operation for named operations. Often execution will
// simply wrap op execution on one or more physical devices.
//
// device_info is an opaque caller-defined type stored with the custom device
// which is passed to the functions referenced in the TFE_CustomDevice struct
// `device` (execute, delete_device, etc.). It can for example contain the
// names of wrapped devices.
//
// There are currently no graph semantics implemented for registered custom
// devices, so executing tf.functions which contain operations placed on custom
// devices will fail.
//
// `device_name` must not name an existing physical or custom device. It must
// follow the format:
//
// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
//
// If the device is successfully registered, `status` is set to TF_OK. Otherwise
// the device is not usable. In case of a bad status, `device.delete_device` is
// still called on `device_info` (i.e. the caller does not retain ownership).
//
// This API is highly experimental, and in particular is expected to change when
// it starts supporting operations with attributes and when tf.function support
// is added.
TF_CAPI_EXPORT extern void TFE_RegisterCustomDevice(TFE_Context* ctx,
TFE_CustomDevice device,
const char* device_name,
void* device_info,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
const char* function_name,
TF_Buffer* buf,
TF_Status* status);
// Allocate and return a new Tensor on the host.
//
// The caller must set the Tensor values by writing them to the pointer returned
// by TF_TensorData with length TF_TensorByteSize.
TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx,
TF_DataType dtype,
const int64_t* dims,
int num_dims,
TF_Status* status);
// Given a Tensor, wrap it with a TensorHandle
//
// Similar to TFE_NewTensorHandle, but includes a pointer to the TFE_Context.
// The context should be identical to that of the Tensor.
TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor(
TFE_Context* ctx, TF_Tensor* t, TF_Status* status);
// Create a packed TensorHandle with the given list of TensorHandles.
// If `handles` are on the same device, assign the same device to the packed
// handle; if `handles` are on different deivces, assign a CompositeDevice to
// it.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
TF_Status* status);
// Configure soft device placement policy for the eager executor. Note this
// policy is applied to any subsequent op executions.
TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
unsigned char enable,
TF_Status* status);
// Configure device placement policy logging for the eager executor. Note this
// policy is applied to any subsequent op executions.
TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
unsigned char enable,
TF_Status* status);
// Returns the device type of the operation that produced `h`.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType(
TFE_TensorHandle* h, TF_Status* status);
// Returns the device ID of the operation that produced `h`.
TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h,
TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_

View File

@ -0,0 +1,519 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_experimental.h"
#include <string.h>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/cc/profiler/profiler.h"
#include "tensorflow/core/lib/monitoring/collection_registry.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
using tensorflow::string;
namespace tensorflow {
namespace {
static bool HasSubstr(absl::string_view base, absl::string_view substr) {
bool ok = absl::StrContains(base, substr);
EXPECT_TRUE(ok) << base << ", expected substring " << substr;
return ok;
}
TEST(CAPI, MonitoringCounter0) {
TF_Status* status = TF_NewStatus();
auto* counter =
TFE_MonitoringNewCounter0("test/counter", status, "description");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
auto* cell = TFE_MonitoringGetCellCounter0(counter);
TFE_MonitoringCounterCellIncrementBy(cell, 1);
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell), 1);
auto* collection_registry = monitoring::CollectionRegistry::Default();
monitoring::CollectionRegistry::CollectMetricsOptions options;
std::unique_ptr<monitoring::CollectedMetrics> metrics =
collection_registry->CollectMetrics(options);
EXPECT_EQ("test/counter",
metrics->point_set_map.at("test/counter")->metric_name);
EXPECT_EQ(
1, metrics->point_set_map.at("test/counter")->points.at(0)->int64_value);
TFE_MonitoringCounterCellIncrementBy(cell, 5);
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell), 6);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(
6, metrics->point_set_map.at("test/counter")->points.at(0)->int64_value);
TFE_MonitoringDeleteCounter0(counter);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(metrics->point_set_map.end(),
metrics->point_set_map.find("test/counter"));
}
TEST(CAPI, MonitoringCounterMultiple) {
TF_Status* status = TF_NewStatus();
auto* counter1 = TFE_MonitoringNewCounter1("test/counter1", status,
"description", "label1");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell1 = TFE_MonitoringGetCellCounter1(counter1, "test");
TFE_MonitoringCounterCellIncrementBy(cell1, 1);
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell1), 1);
auto* counter2 = TFE_MonitoringNewCounter2("test/counter2", status,
"description", "label1", "label2");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
auto* cell2 = TFE_MonitoringGetCellCounter2(counter2, "foo", "bar");
TFE_MonitoringCounterCellIncrementBy(cell2, 2);
EXPECT_EQ(TFE_MonitoringCounterCellValue(cell2), 2);
TFE_MonitoringDeleteCounter1(counter1);
TFE_MonitoringDeleteCounter2(counter2);
}
TEST(CAPI, MonitoringGauge0) {
TF_Status* status = TF_NewStatus();
auto* gauge = TFE_MonitoringNewIntGauge0("test/gauge", status, "test");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell = TFE_MonitoringGetCellIntGauge0(gauge);
TFE_MonitoringIntGaugeCellSet(cell, 1);
EXPECT_EQ(TFE_MonitoringIntGaugeCellValue(cell), 1);
auto* collection_registry = monitoring::CollectionRegistry::Default();
monitoring::CollectionRegistry::CollectMetricsOptions options;
std::unique_ptr<monitoring::CollectedMetrics> metrics =
collection_registry->CollectMetrics(options);
EXPECT_EQ("test/gauge", metrics->point_set_map.at("test/gauge")->metric_name);
EXPECT_EQ(1,
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
TFE_MonitoringIntGaugeCellSet(cell, 5);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(5,
metrics->point_set_map.at("test/gauge")->points.at(0)->int64_value);
TFE_MonitoringDeleteIntGauge0(gauge);
TF_DeleteStatus(status);
}
TEST(CAPI, MonitoringMultipleGauge) {
TF_Status* status = TF_NewStatus();
auto* gauge1 =
TFE_MonitoringNewBoolGauge1("test/gauge1", status, "test", "label1");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell1 = TFE_MonitoringGetCellBoolGauge1(gauge1, "foo");
TFE_MonitoringBoolGaugeCellSet(cell1, true);
EXPECT_TRUE(TFE_MonitoringBoolGaugeCellValue(cell1));
TFE_MonitoringDeleteBoolGauge1(gauge1);
auto* gauge2 = TFE_MonitoringNewStringGauge2("test/gauge2", status, "test",
"label1", "label2");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell2 = TFE_MonitoringGetCellStringGauge2(gauge2, "foo", "bar");
TFE_MonitoringStringGaugeCellSet(cell2, "str");
auto* buf = new TF_Buffer;
TFE_MonitoringStringGaugeCellValue(cell2, buf);
string data(static_cast<const char*>(buf->data), buf->length);
TF_DeleteBuffer(buf);
EXPECT_EQ(data, "str");
TFE_MonitoringDeleteStringGauge2(gauge2);
TF_DeleteStatus(status);
}
TEST(CAPI, MonitoringSampler0) {
TF_Status* status = TF_NewStatus();
auto* buckets = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 2);
auto* sampler =
TFE_MonitoringNewSampler0("test/sampler", buckets, status, "test");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell = TFE_MonitoringGetCellSampler0(sampler);
TFE_MonitoringSamplerCellAdd(cell, 1.0);
auto* collection_registry = monitoring::CollectionRegistry::Default();
monitoring::CollectionRegistry::CollectMetricsOptions options;
std::unique_ptr<monitoring::CollectedMetrics> metrics =
collection_registry->CollectMetrics(options);
EXPECT_EQ("test/sampler",
metrics->point_set_map.at("test/sampler")->metric_name);
EXPECT_EQ(1.0, metrics->point_set_map.at("test/sampler")
->points.at(0)
->histogram_value.sum());
TFE_MonitoringSamplerCellAdd(cell, 5.0);
metrics = collection_registry->CollectMetrics(options);
EXPECT_EQ(6.0, metrics->point_set_map.at("test/sampler")
->points.at(0)
->histogram_value.sum());
TFE_MonitoringDeleteBuckets(buckets);
TFE_MonitoringDeleteSampler0(sampler);
TF_DeleteStatus(status);
}
TEST(CAPI, MonitoringMultipleSampler) {
TF_Status* status = TF_NewStatus();
auto* buckets = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 2);
auto* sampler1 = TFE_MonitoringNewSampler1("test/sampler1", buckets, status,
"test", "label1");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell1 = TFE_MonitoringGetCellSampler1(sampler1, "foo");
TFE_MonitoringSamplerCellAdd(cell1, 1.0);
TFE_MonitoringSamplerCellAdd(cell1, 2.0);
TF_Buffer* result1 = TF_NewBuffer();
TFE_MonitoringSamplerCellValue(cell1, result1);
tensorflow::HistogramProto histogram1;
EXPECT_TRUE(histogram1.ParseFromString(
{reinterpret_cast<const char*>(result1->data), result1->length}));
EXPECT_EQ(histogram1.sum(), 3.0);
TF_DeleteBuffer(result1);
TFE_MonitoringDeleteSampler1(sampler1);
auto* sampler2 = TFE_MonitoringNewSampler2("test/sampler2", buckets, status,
"test", "label1", "label2");
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* cell2 = TFE_MonitoringGetCellSampler2(sampler2, "foo", "bar");
TFE_MonitoringSamplerCellAdd(cell2, 2.0);
TFE_MonitoringSamplerCellAdd(cell2, 3.0);
TF_Buffer* result2 = TF_NewBuffer();
TFE_MonitoringSamplerCellValue(cell2, result2);
tensorflow::HistogramProto histogram2;
EXPECT_TRUE(histogram2.ParseFromString(
{reinterpret_cast<const char*>(result2->data), result2->length}));
EXPECT_EQ(histogram2.sum(), 5.0);
TF_DeleteBuffer(result2);
TFE_MonitoringDeleteSampler2(sampler2);
TFE_MonitoringDeleteBuckets(buckets);
TF_DeleteStatus(status);
}
TEST(CAPI, CancellationManager) {
TFE_CancellationManager* c_mgr = TFE_NewCancellationManager();
EXPECT_FALSE(TFE_CancellationManagerIsCancelled(c_mgr));
TFE_CancellationManagerStartCancel(c_mgr);
EXPECT_TRUE(TFE_CancellationManagerIsCancelled(c_mgr));
TFE_DeleteCancellationManager(c_mgr);
}
TEST(CAPI, ExecutorContextDestructionOrder) {
TF_Status* status = TF_NewStatus();
{
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
TFE_ContextSetExecutorForThread(ctx, executor);
TFE_DeleteContext(ctx);
TFE_DeleteExecutor(executor);
}
{
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
TFE_ContextSetExecutorForThread(ctx, executor);
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}
TF_DeleteStatus(status);
}
TEST(CAPI, Function_ident_CPU) {
// First create a simple identity function.
TF_Graph* function_graph = TF_NewGraph();
TF_OperationDescription* arg_descr =
TF_NewOperation(function_graph, "Placeholder", "arg");
TF_SetAttrType(arg_descr, "dtype", TF_INT32);
TF_Status* status = TF_NewStatus();
TF_Operation* arg = TF_FinishOperation(arg_descr, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_OperationDescription* id_descr =
TF_NewOperation(function_graph, "Identity", "id");
TF_SetAttrType(id_descr, "T", TF_INT32);
TF_AddInput(id_descr, {arg, 0});
TF_Operation* id = TF_FinishOperation(id_descr, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_Output input{arg, 0};
TF_Output output{id, 0};
TF_Function* fn =
TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1,
&output, nullptr, nullptr, "test", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteGraph(function_graph);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextAddFunction(ctx, fn, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteFunction(fn);
for (bool async : {false, true, false}) {
TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
TFE_Executor* executor = TFE_NewExecutor(async);
TFE_ContextSetExecutorForThread(ctx, executor);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* t =
TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32));
*reinterpret_cast<tensorflow::int32*>(TF_TensorData(t)) = 42;
TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteTensor(t);
TFE_Op* op = TFE_NewOp(ctx, "ident", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_OpAddInput(op, h, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
std::vector<TFE_TensorHandle*> result;
result.push_back(nullptr);
int num_retvals = 1;
TFE_Execute(op, result.data(), &num_retvals, status);
TFE_DeleteOp(op);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
ASSERT_EQ(num_retvals, 1);
TF_Tensor* r = TFE_TensorHandleResolve(result[0], status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
EXPECT_EQ(*reinterpret_cast<tensorflow::int32*>(TF_TensorData(r)), 42);
TFE_ContextSetExecutorForThread(ctx, old_executor);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
TFE_DeleteExecutor(old_executor);
TFE_DeleteTensorHandle(h);
TF_DeleteTensor(r);
TFE_DeleteTensorHandle(result[0]);
}
TFE_ContextRemoveFunction(ctx, "ident", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContext(ctx);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TF_DeleteStatus(status);
}
void Executor_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
TFE_Executor* executor = TFE_NewExecutor(async);
TFE_ContextSetExecutorForThread(ctx, executor);
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
int num_retvals = 2;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(1, num_retvals);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
TFE_ContextSetExecutorForThread(ctx, old_executor);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
TFE_DeleteExecutor(old_executor);
TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TF_DeleteStatus(status);
}
TEST(CAPI, Executor_MatMul_CPU) { Executor_MatMul_CPU(false); }
TEST(CAPI, Executor_MatMul_CPUAsync) { Executor_MatMul_CPU(true); }
void Deleter(void* data, size_t unused, void* tensor_handle) {
TFE_DeleteTensorHandle(static_cast<TFE_TensorHandle*>(tensor_handle));
}
TEST(CAPI, TensorHandleOnDeviceMemory) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle(ctx);
TF_Tensor* m_data = TFE_TensorHandleResolve(m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float* m_float = static_cast<float*>(TF_TensorData(m_data));
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
int num_devices = TF_DeviceListCount(devices);
for (int d = 0; d < num_devices; ++d) {
const char* name = TF_DeviceListName(devices, d, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* copy = TFE_TensorHandleCopyToDevice(m, ctx, name, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
void* data = TFE_TensorHandleDevicePointer(copy, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
size_t size = TFE_TensorHandleDeviceMemorySize(copy, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
int64_t dims[] = {2, 2};
TFE_TensorHandle* copy_aliased = TFE_NewTensorHandleFromDeviceMemory(
ctx, name, TF_FLOAT, dims, 2, data, size, &Deleter, copy, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* on_host =
TFE_TensorHandleCopyToDevice(copy_aliased, ctx, "CPU:0", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* resolved = TFE_TensorHandleResolve(on_host, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
const float* resolved_data =
static_cast<const float*>(TF_TensorData(resolved));
EXPECT_EQ(0, memcmp(m_float, resolved_data, 4 * sizeof(float)));
TF_DeleteTensor(resolved);
TFE_DeleteTensorHandle(copy_aliased); // Note that this will delete copy.
TFE_DeleteTensorHandle(on_host);
}
TF_DeleteDeviceList(devices);
TF_DeleteTensor(m_data);
TFE_DeleteTensorHandle(m);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
TEST(CAPI, TensorHandleNullptr) {
TFE_TensorHandle* h = nullptr;
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
const char* device_type = TFE_TensorHandleDeviceType(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(device_type, nullptr);
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
TF_SetStatus(status.get(), TF_OK, "");
int device_id = TFE_TensorHandleDeviceID(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(device_id, -1);
ASSERT_EQ("Invalid handle", string(TF_Message(status.get())));
}
TEST(CAPI, TensorHandleDevices) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(ctx);
const char* device_type = TFE_TensorHandleDeviceType(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
int device_id = TFE_TensorHandleDeviceID(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(0, device_id) << device_id;
// Disable the test if no GPU is present.
string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_Op* shape_op = ShapeOp(ctx, hgpu);
TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
device_type = TFE_TensorHandleDeviceType(retvals[0], status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_type, "GPU")) << device_type;
device_id = TFE_TensorHandleDeviceID(retvals[0], status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(0, device_id) << device_id;
TFE_DeleteOp(shape_op);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
}
TFE_DeleteTensorHandle(hcpu);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}
TEST(CAPI, TensorHandleDefaults) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* h_default = TestMatrixTensorHandle(ctx);
const char* device_type = TFE_TensorHandleDeviceType(h_default, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_type, "CPU")) << device_type;
int device_id = TFE_TensorHandleDeviceID(h_default, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(0, device_id) << device_id;
TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice(
h_default, ctx, "/device:CPU:0", status.get());
const char* device_type_cpu = TFE_TensorHandleDeviceType(h_cpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_type_cpu, "CPU")) << device_type_cpu;
int device_id_cpu = TFE_TensorHandleDeviceID(h_cpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(0, device_id_cpu) << device_id_cpu;
TFE_DeleteTensorHandle(h_default);
TFE_DeleteTensorHandle(h_cpu);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,41 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
#define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_executor_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_monitoring_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_op_attrs_internal.h" // IWYU pragma: export
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h" // IWYU pragma: export
// TODO(b/154564140): Move this to its own header. This requires splitting
// c_api_experimental.h
struct TFE_ContextOptions {
TF_SessionOptions session_options;
// true if async execution is enabled.
bool async = false;
TFE_ContextDevicePlacementPolicy device_placement_policy{
TFE_DEVICE_PLACEMENT_SILENT};
// If true, lazily copy the remote inputs of a function to the target devices.
bool lazy_remote_inputs_copy = true;
// If true, use TFRT backend
bool use_tfrt = false;
};
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_

View File

@ -0,0 +1,74 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_remote_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace {
void TestRemoteExecuteSilentCopiesFunc(bool async, bool remote,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs = false,
bool has_packed_input = false) {
return TestRemoteExecuteSilentCopies(async, remote, /*func=*/true,
heavy_load_on_streaming_rpc,
remote_func_outputs, has_packed_input);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
// A remote input may be not ready when we start running a function. Test that
// the function execution should wait until the remote input is ready.
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesRemoteAsyncPackedInputFuncOrdering) {
// A remote input (packed) may be not ready when we start running a function.
// Test that the function execution should wait until the remote input is
// ready.
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/true,
/*remote_func_outputs*/ true,
/*has_packed_input=*/true);
}
} // namespace

View File

@ -0,0 +1,140 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_remote_test_util.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace {
using ::tensorflow::string;
void TestRemoteExecute(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* h1_task1 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h0_task1);
TFE_DeleteTensorHandle(h1_task1);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
void TestRemoteExecuteSilentCopiesOp(bool async, bool remote,
bool remote_func_outputs = false) {
return TestRemoteExecuteSilentCopies(async, remote, /*func=*/false,
/*heavy_load_on_streaming_rpc=*/false,
remote_func_outputs);
}
TEST(CAPI, RemoteExecuteSilentCopies) {
TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/false);
}
} // namespace

View File

@ -0,0 +1,236 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_remote_test_util.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
using ::tensorflow::string;
string MatMulFunction(const string& matmul_device) {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
absl::StrCat(" signature {"
" name: 'MatMulFunction'"
" input_arg {"
" name: 'a'"
" type: DT_FLOAT"
" }"
" input_arg {"
" name: 'b'"
" type: DT_FLOAT"
" }"
" output_arg {"
" name: 'm'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'matmul'"
" op: 'MatMul'"
" input: 'a'"
" input: 'b'"
" device: '",
matmul_device, "'",
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'm'"
" value: 'matmul:product'"
" }"),
&def));
return def.SerializeAsString();
}
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs,
bool has_packed_input) {
CHECK(!has_packed_input || func);
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
std::vector<TFE_TensorHandle*> handles_task0;
if (heavy_load_on_streaming_rpc) {
// Send 50 tensor copy requests to simulate that there have been some RPC
// requests been enqueued.
for (int i = 0; i < 50; ++i) {
handles_task0.push_back(TestMatrixTensorHandle(ctx));
}
}
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
std::vector<TFE_TensorHandle*> handles_task2;
for (auto* h_task0 : handles_task0) {
handles_task2.push_back(
TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
auto* h1_task2 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* packed_handle = nullptr;
if (has_packed_input) {
int num_replicas = 1;
std::vector<TFE_TensorHandle*> packed_handles = {h1_task2};
packed_handle = TFE_CreatePackedTensorHandle(ctx, packed_handles.data(),
&num_replicas, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
TFE_Op* matmul = nullptr;
if (func) {
const string matmul_device = remote_func_outputs ? task2_name : "";
string function_def = MatMulFunction(matmul_device);
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h0_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, has_packed_input ? packed_handle : h1_task2, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else {
// Handles are on task0 (local), and task2, but op is on task1.
matmul = MatMulOp(ctx, h0_task0, h1_task2);
}
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else if (!async) {
// Set the local device to CPU to easily validate mirroring
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
}
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!remote && !async && !remote_func_outputs) {
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
}
if (remote_func_outputs) {
const string backing_device =
TFE_TensorHandleBackingDeviceName(retvals[0], status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(backing_device, task2_name);
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retval_task0);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
if (packed_handle) {
TFE_DeleteTensorHandle(packed_handle);
}
TFE_DeleteTensorHandle(h1_task2);
TFE_DeleteTensorHandle(retvals[0]);
for (auto* h : handles_task0) {
TFE_DeleteTensorHandle(h);
}
for (auto* h : handles_task2) {
TFE_DeleteTensorHandle(h);
}
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
if (func) {
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
}
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}

View File

@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
#define TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
// Run a function containing a MatMul op and check its output.
// If heavy_load_on_streaming_rpc is true, send some rpc requests before the one
// which creates a remote input, to simulate a scenario that the remote input
// is not ready when we start running an op or a function.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs = false,
bool has_packed_input = false);
#endif // TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,377 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
using tensorflow::string;
using tensorflow::tstring;
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value) {
float data[] = {value};
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
const tensorflow::tstring& value) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_STRING, nullptr, 0, status);
tstring* data = static_cast<tstring*>(TF_TensorData(t));
*data = value;
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value) {
int data[] = {value};
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_INT32, nullptr, 0, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value) {
bool data[] = {value};
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_BOOL, nullptr, 0, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx) {
int64_t dims[] = {2, 2};
double data[] = {1.0, 2.0, 3.0, 4.0};
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_DOUBLE, &dims[0],
sizeof(dims) / sizeof(int64_t), status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx) {
int64_t dims[] = {2, 2};
float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
sizeof(dims) / sizeof(int64_t), status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
float data[], int64_t dims[],
int num_dims) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[],
int64_t dims[], int num_dims) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[],
int64_t dims[], int num_dims) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) {
constexpr int64_t dims[] = {100, 100};
constexpr int num_elements = dims[0] * dims[1];
float data[num_elements];
for (int i = 0; i < num_elements; ++i) {
data[i] = 1.0f;
}
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
sizeof(dims) / sizeof(int64_t), status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx) {
int64_t dims[] = {3, 2};
double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
sizeof(dims) / sizeof(int64_t), status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx) {
int64_t dims[] = {3, 2};
float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0],
sizeof(dims) / sizeof(int64_t), status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
const tensorflow::string& device_name) {
TF_Status* status = TF_NewStatus();
// Create the variable handle.
TFE_Op* op = TFE_NewOp(ctx, "VarHandleOp", status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op, "shape", {}, 0, status);
TFE_OpSetAttrString(op, "container", "localhost", 0);
TFE_OpSetAttrString(op, "shared_name", "", 0);
if (!device_name.empty()) {
TFE_OpSetDevice(op, device_name.c_str(), status);
}
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
TFE_Execute(op, &var_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_DeleteOp(op);
if (TF_GetCode(status) != TF_OK) return nullptr;
CHECK_EQ(1, num_retvals);
// Assign 'value' to it.
op = TFE_NewOp(ctx, "AssignVariableOp", status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpAddInput(op, var_handle, status);
// Convert 'value' to a TF_Tensor then a TFE_TensorHandle.
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> t(
TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(value)), TF_DeleteTensor);
memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get()));
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
value_handle(TFE_NewTensorHandle(t.get(), status),
TFE_DeleteTensorHandle);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op, value_handle.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
num_retvals = 0;
TFE_Execute(op, nullptr, &num_retvals, status);
TFE_DeleteOp(op);
if (TF_GetCode(status) != TF_OK) return nullptr;
CHECK_EQ(0, num_retvals);
TF_DeleteStatus(status);
return var_handle;
}
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "AddV2", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, a, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, b, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
return op;
}
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "MatMul", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, a, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, b, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
return op;
}
TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "Identity", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, a, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
return op;
}
TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "Shape", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, a, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
return op;
}
TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx) {
int64_t dims[] = {1};
int data[] = {1};
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0],
sizeof(dims) / sizeof(int64_t), status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
TFE_TensorHandle* axis) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "Min", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, input, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, axis, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrBool(op, "keep_dims", 1);
TFE_OpSetAttrType(op, "Tidx", TF_INT32);
TF_DeleteStatus(status);
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input));
return op;
}
bool GetDeviceName(TFE_Context* ctx, string* device_name,
const char* device_type) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
const int num_devices = TF_DeviceListCount(devices);
for (int i = 0; i < num_devices; ++i) {
const string dev_type(TF_DeviceListType(devices, i, status.get()));
CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
const string dev_name(TF_DeviceListName(devices, i, status.get()));
CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
if (dev_type == device_type) {
*device_name = dev_name;
LOG(INFO) << "Found " << device_type << " device " << *device_name;
TF_DeleteDeviceList(devices);
return true;
}
}
TF_DeleteDeviceList(devices);
return false;
}
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost:", port)});
}
return server_def;
}
tensorflow::ServerDef GetServerDef(int num_tasks) {
return GetServerDef("localhost", num_tasks);
}

View File

@ -0,0 +1,102 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
// Return a tensor handle containing a float scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value);
// Return a tensor handle containing a int scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value);
// Return a tensor handle containing a bool scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value);
// Return a tensor handle containing a tstring scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx,
const tensorflow::tstring& value);
// Return a tensor handle containing a 2x2 matrix of doubles
TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx);
// Return a tensor handle containing a 2x2 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx);
// Return a tensor handle containing 2D matrix containing given data and
// dimensions
TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
float data[], int64_t dims[],
int num_dims);
// Get a Matrix TensorHandle with given float values and dimensions
TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[],
int64_t dims[], int num_dims);
// Get a Matrix TensorHandle with given int values and dimensions
TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[],
int64_t dims[], int num_dims);
// Return a tensor handle containing a 100x100 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx);
// Return a tensor handle containing a 3x2 matrix of doubles
TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx);
// Return a tensor handle containing a 3x2 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx);
// Return a variable handle referring to a variable with the given initial value
// on the given device.
TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value,
const tensorflow::string& device_name = "");
// Return an add op multiplying `a` by `b`.
TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
// Return a matmul op multiplying `a` by `b`.
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
// Return an identity op.
TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a);
// Return a shape op fetching the shape of `a`.
TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a);
// Return an 1-D INT32 tensor containing a single value 1.
TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx);
// Return an op taking minimum of `input` long `axis` dimension.
TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
TFE_TensorHandle* axis);
// If there is a device of type `device_type`, returns true
// and sets 'device_name' accordingly.
// `device_type` must be either "GPU" or "TPU".
bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name,
const char* device_type);
// Create a ServerDef with the given `job_name` and add `num_tasks` tasks in it.
tensorflow::ServerDef GetServerDef(const tensorflow::string& job_name,
int num_tasks);
// Create a ServerDef with job name "localhost" and add `num_tasks` tasks in it.
tensorflow::ServerDef GetServerDef(int num_tasks);
#endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_

View File

@ -0,0 +1,246 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::string;
namespace tensorflow {
namespace tracing {
typedef absl::flat_hash_map<std::string, tracing::FactoryFunction> FactoriesMap;
static FactoriesMap& GetFactories() {
static FactoriesMap* factories = new FactoriesMap;
return *factories;
}
static tracing::FactoryFunction default_factory;
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
assert((!GetFactories().count(name)) ||
(GetFactories()[name] == factory) &&
"Duplicate tracing factory registration");
GetFactories()[name] = factory;
}
Status SetDefaultTracingEngine(const char* name) {
auto entry = GetFactories().find(name);
if (entry != GetFactories().end()) {
default_factory = GetFactories().find(name)->second;
return Status::OK();
}
string msg = absl::StrCat(
"No tracing engine factory has been registered with the key '", name,
"' (available: ");
// Ensure deterministic (sorted) order in the error message
std::set<string> factories_sorted;
for (const auto& factory : GetFactories())
factories_sorted.insert(factory.first);
const char* comma = "";
for (const string& factory : factories_sorted) {
msg += comma + factory;
comma = ", ";
}
msg += ")";
return errors::InvalidArgument(msg.c_str());
}
static TracingContext* CreateTracingExecutionContext(const char* fn_name,
TF_Status* s) {
if (default_factory) {
return default_factory(fn_name, s);
}
Set_TF_Status_from_Status(
s, errors::FailedPrecondition("default_factory is nullptr"));
return nullptr;
}
} // end namespace tracing
} // end namespace tensorflow
// =============================================================================
// Public C API entry points
//
// These are only the generic entry points for the C API. This file does not
// have any visibility into the graph/eager implementation and is only providing
// C bindings to the abstract classes defined in the
// c_api_unified_experimental_internal.h header.
//
// =============================================================================
using tensorflow::AbstractFunction;
using tensorflow::AbstractTensorHandle;
using tensorflow::DataType;
using tensorflow::dyn_cast;
using tensorflow::OutputList;
using tensorflow::Status;
using tensorflow::unwrap;
using tensorflow::wrap;
using tensorflow::tracing::CreateTracingExecutionContext;
using tensorflow::tracing::SetDefaultTracingEngine;
using tensorflow::tracing::TracingContext;
using tensorflow::tracing::TracingOperation;
using tensorflow::tracing::TracingTensorHandle;
void TF_SetTracingImplementation(const char* name, TF_Status* s) {
Set_TF_Status_from_Status(s, SetDefaultTracingEngine(name));
}
// Creates a new TensorFlow function, it is an execution context attached to a
// given tracing context.
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
return wrap(CreateTracingExecutionContext(fn_name, s));
}
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
TF_OutputList* outputs, TF_Status* s) {
AbstractFunction* func;
TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(ctx));
if (!tracing_ctx) {
Set_TF_Status_from_Status(
s, tensorflow::errors::InvalidArgument(
"Only TracingContext can be converted into a function."));
return nullptr;
}
Set_TF_Status_from_Status(s, tracing_ctx->Finalize(unwrap(outputs), &func));
TF_DeleteExecutionContext(ctx);
return wrap(func);
}
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
TF_DataType dtype, TF_Shape shape,
TF_Status* s) {
DCHECK_GE(shape.num_dims, -1);
TracingTensorHandle* t;
TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func));
if (!tracing_ctx) {
Set_TF_Status_from_Status(
s, tensorflow::errors::InvalidArgument(
"TF_AddFunctionParameter must be called on a TracingContext."));
return nullptr;
}
tensorflow::PartialTensorShape partial_shape;
if (shape.num_dims != -1) {
DCHECK(shape.dim_sizes != nullptr);
Status status = tensorflow::PartialTensorShape::MakePartialShape(
reinterpret_cast<tensorflow::int64*>(shape.dim_sizes), shape.num_dims,
&partial_shape);
if (!status.ok()) {
Set_TF_Status_from_Status(s, status);
return nullptr;
}
}
Set_TF_Status_from_Status(
s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), partial_shape,
&t));
return wrap(t);
}
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { unwrap(c)->Release(); }
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
return wrap((unwrap(c)->CreateOperation()));
}
void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); }
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Unref(); }
TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
TF_Status* s) {
unwrap(o)->expected_num_outputs = num_outputs;
unwrap(o)->outputs.clear();
unwrap(o)->outputs.resize(num_outputs);
}
int TF_OutputListNumOutputs(TF_OutputList* o) {
return unwrap(o)->outputs.size();
}
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
return wrap(unwrap(o)->outputs[i]);
}
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
TF_Status* s) {
unwrap(o)->outputs.push_back(unwrap(tensor));
}
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s) {
Set_TF_Status_from_Status(s, unwrap(op)->Reset(op_type,
/*raw_device_name=*/nullptr));
}
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s) {
TracingOperation* tracing_op = dyn_cast<TracingOperation>(unwrap(op));
if (!tracing_op) {
Set_TF_Status_from_Status(
s, tensorflow::errors::InvalidArgument(
"TF_AbstractOpSetOpName must be called on a TracingOperation."));
return;
}
Set_TF_Status_from_Status(s, tracing_op->SetOpName(op_name));
}
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
TF_DataType value, TF_Status* s) {
Status status =
unwrap(op)->SetAttrType(attr_name, static_cast<DataType>(value));
TF_SetStatus(s, static_cast<TF_Code>(status.code()),
status.error_message().c_str());
}
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_Status* s) {
for (int i = 0; i < num_inputs; i++) {
Set_TF_Status_from_Status(s, unwrap(op)->AddInput(unwrap(inputs[i])));
if (TF_GetCode(s) != TF_OK) {
return;
}
}
int num_outputs = unwrap(o)->expected_num_outputs;
Set_TF_Status_from_Status(
s, unwrap(op)->Execute(
absl::MakeSpan(reinterpret_cast<AbstractTensorHandle**>(
unwrap(o)->outputs.data()),
unwrap(o)->outputs.size()),
&num_outputs));
}
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
delete unwrap(func);
}
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
TF_AbstractFunction* func,
TF_Status* s) {
Set_TF_Status_from_Status(s, unwrap(ctx)->RegisterFunction(unwrap(func)));
}

View File

@ -0,0 +1,153 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#ifdef __cplusplus
extern "C" {
#endif
// =============================================================================
// Unified Execution APIs for Eager and tracing backends.
// =============================================================================
// -----------------------------------------------------------------------------
// Core APIs
// -----------------------------------------------------------------------------
// A TF_ExecutionContext stores knowledge about how to execute an operation.
// E.g. it could know whether we're in eager mode or in graph mode, keeps track
// of gradient tapes, etc.
typedef struct TF_ExecutionContext TF_ExecutionContext;
// A TF_AbstractTensor is an input to an operation. E.g. it could be a union
// type of eager and graph tensors. It is also the result of executing an
// operation.
typedef struct TF_AbstractTensor TF_AbstractTensor;
// A TF_AbstractOp is the metadata we need to execute an operation. E.g. this
// could contain the op type and other attributes.
typedef struct TF_AbstractOp TF_AbstractOp;
// Stores a function representation that can be used for execution or for
// setting functional attributes of other composite ops e.g. control flow.
typedef struct TF_AbstractFunction TF_AbstractFunction;
// This allows the client to swap the implementation of the tracing engine.
// Any future call to TF_CreateFunction will use the implementation defined
// here.
void TF_SetTracingImplementation(const char* name, TF_Status*);
// Creates a new TensorFlow function. A Function is an execution context, and as
// such it can trace operations through TF_ExecuteOperation. After completing
// tracing, a function can be obtained by TF_FinalizeFunction.
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* status);
// Creates a context for eager execution of operations.
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
TF_Status* s);
void TF_DeleteExecutionContext(TF_ExecutionContext*);
// Represents a (partially-defined) shape.
typedef struct TF_Shape {
int num_dims; // Must be >= -1; -1 represents unknown rank.
int64_t* dim_sizes;
} TF_Shape;
// Add a new parameter to a TensorFlow Function.
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
TF_DataType dtype, TF_Shape shape,
TF_Status* s);
// Create an operation suitable to use with the provided context. The operation
// requires its type (e.g. "AddV2") to be set independently.
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
void TF_DeleteAbstractOp(TF_AbstractOp*);
// TODO(srbs): Add APIs for specifying attrs etc.
// `op_type` must outlive `op`.
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s);
// `op_name` must outlive `op`.
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s);
// `attr_name` must outlive `op`.
void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name,
TF_DataType value, TF_Status* s);
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
// TF_OutputList holds the list of TF_AbstractTensor that results from executing
// an operation, or provided to create a function.
// When executing an operation in an eager context, the expected number of
// outputs must be set beforehand with `TF_OutputListSetNumOutputs`.
typedef struct TF_OutputList TF_OutputList;
TF_OutputList* TF_NewOutputList();
void TF_DeleteOutputList(TF_OutputList* o);
// Prepare tracing to the expected number of output for an operation.
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status*);
// Return the number of outputs in the list.
int TF_OutputListNumOutputs(TF_OutputList* o);
// Return the `i`th output in the list.
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
// Append a tensor at the end of the output list, growing its size by one.
void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor,
TF_Status*);
// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
// capture some inputs and then add a node in the graph. The output tensors are
// returned through the provided TF_OutputList.
// Any active tape will observe the effects of this execution.
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_Status* s);
// Creates a new TF_AbstractFunction from the current tracing states in the
// context. The provided `ctx` is consumed by this API call and deleted.
// The returned TF_AbstractFunction must be deleted by the client,
// TODO(aminim): clarify the contract on the state of the context after this
// call.
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
TF_OutputList*, TF_Status*);
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
// Register the function with the given context. This is particularly useful for
// making a function available to an eager context.
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*,
TF_AbstractFunction*, TF_Status*);
// -----------------------------------------------------------------------------
// APIs specific to Eager modes
// -----------------------------------------------------------------------------
// Temporary APIs till we figure out how to create scalar valued Eager
// tensors and how to get value out of eager abstract tensors.
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
TF_Status* s);
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s);
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*,
TF_Status* s);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_

View File

@ -0,0 +1,82 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/strcat.h"
// =============================================================================
// Public C API entry points
// These are only the entry points specific to the Eager API.
// =============================================================================
using tensorflow::AbstractContext;
using tensorflow::AbstractTensorHandle;
using tensorflow::dyn_cast;
using tensorflow::ImmediateExecutionContext;
using tensorflow::ImmediateExecutionTensorHandle;
using tensorflow::string;
using tensorflow::unwrap;
using tensorflow::wrap;
using tensorflow::strings::StrCat;
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options,
TF_Status* s) {
TFE_Context* c_ctx = TFE_NewContext(options, s);
if (TF_GetCode(s) != TF_OK) {
return nullptr;
}
return wrap(static_cast<AbstractContext*>(unwrap(c_ctx)));
}
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
TF_Status* s) {
return wrap(static_cast<AbstractTensorHandle*>(unwrap(t)));
}
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s) {
auto handle = dyn_cast<ImmediateExecutionTensorHandle>(unwrap(at));
if (!handle) {
string msg =
StrCat("Not an eager tensor handle.", reinterpret_cast<uintptr_t>(at));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return wrap(handle);
}
TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx,
TF_Status* s) {
auto imm_ctx = dyn_cast<ImmediateExecutionContext>(unwrap(ctx));
if (!imm_ctx) {
string msg =
StrCat("Not an eager context.", reinterpret_cast<uintptr_t>(ctx));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return wrap(imm_ctx);
}

View File

@ -0,0 +1,447 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "absl/strings/str_cat.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/types.h"
using tensorflow::dyn_cast;
using tensorflow::string;
using tensorflow::gtl::ArraySlice;
namespace tensorflow {
namespace tracing {
namespace graph {
class GraphContext;
class GraphOperation;
class GraphTensor;
auto& kUnknownDim = shape_inference::InferenceContext::kUnknownDim;
auto& kUnknownRank = shape_inference::InferenceContext::kUnknownRank;
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
// into the list of outputs for the operation.
class GraphTensor : public TracingTensorHandle {
public:
explicit GraphTensor(TF_Output output, TF_Graph* graph)
: TracingTensorHandle(kGraph), output_(output), graph_(graph) {}
tensorflow::DataType DataType() const override {
return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
}
tensorflow::Status Shape(
tensorflow::PartialTensorShape* shape) const override {
DCHECK(shape != nullptr);
TF_Status status;
int num_dims = TF_GraphGetTensorNumDims(graph_, output_, &status);
DCHECK_GE(num_dims, -1);
TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
if (num_dims == kUnknownRank) {
return Status::OK();
}
std::vector<int64> dims(num_dims, kUnknownDim);
TF_GraphGetTensorShape(graph_, output_,
reinterpret_cast<int64_t*>(dims.data()), num_dims,
&status);
TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape));
return Status::OK();
}
TF_Output output_;
// For LLVM style RTTI.
static bool classof(const AbstractTensorHandle* ptr) {
return ptr->getKind() == kGraph;
}
private:
TF_Graph* graph_; // For shape inference.
};
// GraphOperation wraps and populates a TF_OperationDescription.
class GraphOperation : public TracingOperation {
public:
explicit GraphOperation(TF_Graph* g) : TracingOperation(kGraph), g_(g) {}
void Release() override { delete this; }
Status Reset(const char* op, const char* raw_device_name) override {
if (op_) {
return errors::FailedPrecondition("Reset called on already built op.");
}
if (raw_device_name) {
device_name_ = raw_device_name;
}
op_type_ = op;
return Status::OK();
}
Status SetOpName(const char* const op_name) override {
if (op_) {
return errors::FailedPrecondition(
"SetOpName called on already built op.");
}
if (op_type_.empty()) {
return errors::FailedPrecondition(
"GraphOperation::Reset must be called before calling SetOpName.");
}
// TODO(b/145674566): We use Graph::NewName to get a unique name here but
// this may not be consistent with python's naming policy.
mutex_lock l(g_->mu);
op_.reset(new TF_OperationDescription(g_, op_type_.c_str(),
g_->graph.NewName(op_name).c_str()));
return Status::OK();
}
const string& Name() const override { return op_type_; }
const string& DeviceName() const override { return device_name_; }
Status SetDeviceName(const char* name) override {
// TODO(srbs): Implement this.
device_name_ = name;
return Status::OK();
}
Status AddInput(AbstractTensorHandle* input) override {
GraphTensor* t = dyn_cast<GraphTensor>(input);
if (!t) {
return tensorflow::errors::InvalidArgument(
"Unable to cast input to GraphTensor");
}
TF_AddInput(op_.get(), t->output_);
return Status::OK();
}
Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override {
std::vector<TF_Output> tf_outputs(inputs.size());
for (int i = 0; i < inputs.size(); i++) {
GraphTensor* t = dyn_cast<GraphTensor>(inputs[i]);
if (!t) {
return tensorflow::errors::InvalidArgument(
"Unable to cast input to GraphTensor");
}
tf_outputs[i] = t->output_;
}
TF_AddInputList(op_.get(), tf_outputs.data(), tf_outputs.size());
return Status::OK();
}
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) override {
auto* tf_opdesc = op_.release();
if (tf_opdesc == nullptr) {
return errors::InvalidArgument("AbstractOp is incomplete.");
}
TF_Status* s = TF_NewStatus();
auto* operation = TF_FinishOperation(tf_opdesc, s);
TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
TF_DeleteStatus(s);
*num_retvals = TF_OperationNumOutputs(operation);
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new GraphTensor({operation, i}, g_);
}
return Status::OK();
}
Status SetAttrString(const char* attr_name, const char* data,
size_t length) override {
tensorflow::StringPiece s(data, length);
op_->node_builder.Attr(attr_name, s);
return Status::OK();
}
Status SetAttrInt(const char* attr_name, int64_t value) override {
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
op_->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
return Status::OK();
}
Status SetAttrFloat(const char* attr_name, float value) override {
op_->node_builder.Attr(attr_name, value);
return Status::OK();
}
Status SetAttrBool(const char* attr_name, bool value) override {
op_->node_builder.Attr(attr_name, value);
return Status::OK();
}
Status SetAttrType(const char* const attr_name, DataType value) override {
if (!op_) {
return Status(
error::Code::FAILED_PRECONDITION,
"op_type and op_name must be specified before specifying attrs.");
}
op_->node_builder.Attr(attr_name, value);
return Status::OK();
}
Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) override {
PartialTensorShape shape;
if (num_dims >= 0) {
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
}
op_->node_builder.Attr(attr_name, shape);
return Status::OK();
}
Status SetAttrFunction(const char* attr_name,
const AbstractOperation* value) override {
return tensorflow::errors::Unimplemented(
"SetAttrFunction has not been implemented yet.");
}
Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) override {
tensorflow::NameAttrList func_name;
func_name.set_name(string(value, value + length));
op_->node_builder.Attr(attr_name, func_name);
return Status::OK();
}
Status SetAttrTensor(const char* attr_name,
AbstractTensorInterface* tensor) override {
return tensorflow::errors::Unimplemented(
"SetAttrTensor has not been implemented yet.");
}
Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override {
if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
op_->colocation_constraints.clear();
for (int i = 0; i < num_values; ++i) {
op_->colocation_constraints.emplace(static_cast<const char*>(values[i]),
lengths[i]);
}
} else {
std::vector<tensorflow::StringPiece> v;
v.reserve(num_values);
for (int i = 0; i < num_values; ++i) {
v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
}
op_->node_builder.Attr(attr_name, v);
}
return Status::OK();
}
Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) override {
op_->node_builder.Attr(attr_name,
ArraySlice<const float>(values, num_values));
return Status::OK();
}
Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) override {
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
op_->node_builder.Attr(
attr_name,
ArraySlice<const tensorflow::int64>(
reinterpret_cast<const tensorflow::int64*>(values), num_values));
return Status::OK();
}
Status SetAttrTypeList(const char* attr_name, const DataType* values,
int num_values) override {
op_->node_builder.Attr(attr_name,
ArraySlice<const DataType>(values, num_values));
return Status::OK();
}
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
int num_values) override {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
op_->node_builder.Attr(attr_name,
ArraySlice<const bool>(b.get(), num_values));
return Status::OK();
}
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) override {
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_values);
for (int i = 0; i < num_values; ++i) {
if (num_dims[i] < 0) {
shapes.emplace_back();
} else {
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
shapes.emplace_back(ArraySlice<tensorflow::int64>(
reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
}
}
op_->node_builder.Attr(attr_name, shapes);
return Status::OK();
}
Status SetAttrFunctionList(
const char* attr_name,
absl::Span<const AbstractOperation*> values) override {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionList has not been implemented yet.");
}
// For LLVM style RTTI.
static bool classof(const AbstractOperation* ptr) {
return ptr->getKind() == kGraph;
}
~GraphOperation() override {}
private:
friend class GraphContext; // For access to op_.
TF_Graph* g_;
std::unique_ptr<TF_OperationDescription> op_;
// Hold `op_type` and `op_name` till both are available since we need both
// to build a graph operation.
string op_type_;
const char* op_name_ = nullptr;
// TODO(srbs): Use this.
string device_name_;
};
// GraphFunction is a thin wrapper over a TF_Function.
struct GraphFunction : public AbstractFunction {
TF_Function* func = nullptr;
GraphFunction() : AbstractFunction(kGraph) {}
explicit GraphFunction(TF_Function* func)
: AbstractFunction(kGraph), func(func) {}
~GraphFunction() override {
if (func) TF_DeleteFunction(func);
}
Status GetFunctionDef(FunctionDef** fdef) override {
*fdef = &func->fdef;
return Status::OK();
}
// For LLVM style RTTI.
static bool classof(const AbstractFunction* ptr) {
return ptr->getKind() == kGraph;
}
};
// GraphContext wraps a TF_Graph modeling a single function and manages the
// "execution" of operation, i.e. adding them to the function.
class GraphContext : public TracingContext {
public:
explicit GraphContext(const char* name)
: TracingContext(kGraph),
graph_(new TF_Graph(), TF_DeleteGraph),
name_(name) {}
void Release() override { delete this; }
TracingOperation* CreateOperation() override {
return new GraphOperation(graph_.get());
}
Status AddParameter(DataType dtype, const PartialTensorShape& shape,
TracingTensorHandle** output) override {
TracingOperationPtr operation(CreateOperation());
TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
TF_RETURN_IF_ERROR(
operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
if (!shape.unknown_rank()) {
TF_RETURN_IF_ERROR(operation->SetAttrShape(
"shape", reinterpret_cast<int64_t*>(shape.dim_sizes().data()),
shape.dims()));
}
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(num_outputs);
TF_RETURN_IF_ERROR(operation->Execute(
absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
if (num_outputs != 1) {
return errors::Internal("Expected 1 output but found ", num_outputs);
}
auto* t = dyn_cast<GraphTensor>(outputs[0]);
if (!t) {
return tensorflow::errors::InvalidArgument(
"Unable to cast input to GraphTensor");
}
inputs_.push_back(t->output_);
*output = tensorflow::down_cast<TracingTensorHandle*>(outputs[0]);
return Status::OK();
}
Status Finalize(OutputList* outputs, AbstractFunction** f) override {
std::unique_ptr<GraphFunction> func(new GraphFunction);
std::vector<TF_Output> graph_outputs;
graph_outputs.reserve(outputs->outputs.size());
for (auto* abstract_output : outputs->outputs) {
GraphTensor* output = dyn_cast<GraphTensor>(abstract_output);
if (!output) {
return errors::Unimplemented(
"Returning a non-graph tensor from a function has not "
"been implemented yet.");
}
graph_outputs.push_back(output->output_);
}
auto s = TF_NewStatus();
func->func = TF_GraphToFunction(graph_.get(), name_.data(), 0, -1, nullptr,
inputs_.size(), inputs_.data(),
graph_outputs.size(), graph_outputs.data(),
nullptr, nullptr, name_.data(), s);
TF_RETURN_IF_ERROR(StatusFromTF_Status(s));
TF_DeleteStatus(s);
*f = func.release();
return Status::OK();
}
Status RegisterFunction(AbstractFunction* func) override {
return errors::Unimplemented(
"Registering graph functions has not been implemented yet.");
}
Status RemoveFunction(const string& func) override {
return errors::Unimplemented(
"GraphContext::RemoveFunction has not been implemented yet.");
}
// For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kGraph;
}
private:
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
std::vector<TF_Output> inputs_;
string name_;
};
static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) {
return new GraphContext(name);
}
// Register the tracing implemented in this file as the default tracing engine.
static bool register_tracing = [] {
RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
SetDefaultTracingEngine("graphdef").IgnoreError();
return true;
}();
} // namespace graph
} // namespace tracing
} // namespace tensorflow

View File

@ -0,0 +1,137 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_
#include <vector>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// Represents the results of the execution of an operation.
struct OutputList {
std::vector<AbstractTensorHandle*> outputs;
int expected_num_outputs = -1;
};
namespace tracing {
// =============================================================================
// Implementation detail for the unified execution APIs for Eager and tracing
// backends (graph/MLIR).
//
// This defines a set of abstract classes that are intended to provide the
// functionality of the opaque C types exposed in the public APIs defined in the
// `c_api_unified_experimental.h` header.
// =============================================================================
// Represents either a MlirTensor or a GraphTensor.
// This base class does not expose any public methods other than to distinguish
// which subclass it actually is. The user is responsible to use the right
// type of AbstractTensor in their context (do not pass an MlirTensor to a
// GraphContext and vice-versa).
class TracingTensorHandle : public AbstractTensorHandle {
protected:
explicit TracingTensorHandle(AbstractTensorHandleKind kind)
: AbstractTensorHandle(kind) {}
public:
// For LLVM style RTTI.
static bool classof(const AbstractTensorHandle* ptr) {
return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
}
};
// An abstract operation describes an operation by its type, name, and
// attributes. It can be "executed" by the context with some input tensors.
// It is allowed to reusing the same abstract operation for multiple execution
// on a given context, with the same or different input tensors.
class TracingOperation : public AbstractOperation {
protected:
explicit TracingOperation(AbstractOperationKind kind)
: AbstractOperation(kind) {}
public:
// Sets the name of the operation: this is an optional identifier that is
// not intended to carry semantics and preserved/propagated without
// guarantees.
virtual Status SetOpName(const char* op_name) = 0;
// For LLVM style RTTI.
static bool classof(const AbstractOperation* ptr) {
return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
}
};
namespace internal {
struct TracingOperationDeleter {
void operator()(TracingOperation* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using TracingOperationPtr =
std::unique_ptr<TracingOperation, internal::TracingOperationDeleter>;
// This holds the context for the execution: dispatching operations either to an
// MLIR implementation or to a graph implementation.
class TracingContext : public AbstractContext {
protected:
explicit TracingContext(AbstractContextKind kind) : AbstractContext(kind) {}
public:
// Add a function parameter and return the corresponding tensor.
virtual Status AddParameter(DataType dtype, const PartialTensorShape& shape,
TracingTensorHandle**) = 0;
// Finalize this context and make a function out of it. The context is in a
// invalid state after this call and must be destroyed.
virtual Status Finalize(OutputList* outputs, AbstractFunction**) = 0;
// For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kGraph || ptr->getKind() == kMlir;
}
};
typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
Status SetDefaultTracingEngine(const char* name);
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
FactoryFunction factory);
} // namespace tracing
DEFINE_CONVERSION_FUNCTIONS(AbstractContext, TF_ExecutionContext)
DEFINE_CONVERSION_FUNCTIONS(AbstractTensorHandle, TF_AbstractTensor)
DEFINE_CONVERSION_FUNCTIONS(AbstractFunction, TF_AbstractFunction)
DEFINE_CONVERSION_FUNCTIONS(AbstractOperation, TF_AbstractOp)
DEFINE_CONVERSION_FUNCTIONS(OutputList, TF_OutputList)
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,383 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A simple logging device to test custom device registration.
#include <memory>
#include "absl/strings/match.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/custom_device_testutil.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* context = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context, name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle(context);
ASSERT_FALSE(arrived);
TFE_TensorHandle* hdevice =
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
ASSERT_TRUE(arrived);
ASSERT_FALSE(executed);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
MatMulOp(context, hcpu, hdevice), TFE_DeleteOp);
TFE_OpSetDevice(matmul.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retval;
int num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
TFE_DeleteTensorHandle(retval);
TFE_DeleteTensorHandle(hcpu);
TFE_DeleteTensorHandle(hdevice);
TFE_DeleteContext(context);
}
TEST(CUSTOM_DEVICE, ResetOperation) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts, status.get()), TFE_DeleteContext);
TFE_DeleteContextOptions(opts);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* custom_device_name =
"/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), custom_device_name,
/*strict_scope_placement=*/true, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> reused_op(
TFE_NewOp(context.get(), "Identity", status.get()), TFE_DeleteOp);
TFE_OpReset(reused_op.get(), "Identity", custom_device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
tensorflow::string(custom_device_name));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpReset(reused_op.get(), "Identity",
"/job:localhost/replica:0/task:0/device:CPU:0", status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(tensorflow::string(TFE_OpGetDevice(reused_op.get(), status.get())),
tensorflow::string("/job:localhost/replica:0/task:0/device:CPU:0"));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, MakeVariable) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle placed on the custom device.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
TFE_OpSetAttrString(op.get(), "container", "", 0);
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
executed = false;
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
// Assign to the variable, copying to the custom device.
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
TestScalarTensorHandle(context.get(), 111.f), TFE_DeleteTensorHandle);
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpAddInput(op.get(), one.get(), status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
// Read the variable's value.
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetDevice(op.get(), name, status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
auto value_cleaner = tensorflow::gtl::MakeCleanup(
[var_value]() { TFE_DeleteTensorHandle(var_value); });
ASSERT_EQ(tensorflow::string(name),
tensorflow::string(
TFE_TensorHandleBackingDeviceName(var_value, status.get())));
TFE_TensorHandle* var_value_unpacked =
UnpackTensorHandle(var_value, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> resolved_value(
TFE_TensorHandleResolve(var_value_unpacked, status.get()),
TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_EQ(111., *static_cast<float*>(TF_TensorData(resolved_value.get())));
// Free the backing buffer for the variable.
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, AccessVariableOnCustomDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/false,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle placed on the custom device.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op.get(), "shape", {}, 0, status.get());
TFE_OpSetAttrString(op.get(), "container", "", 0);
TFE_OpSetAttrString(op.get(), "shared_name", "", 0);
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
executed = false;
TFE_Execute(op.get(), &var_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
auto handle_cleaner = tensorflow::gtl::MakeCleanup(
[var_handle]() { TFE_DeleteTensorHandle(var_handle); });
// Assign to the variable, copying to the custom device.
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> one(
TestScalarTensorHandle(context.get(), 111.f), TFE_DeleteTensorHandle);
op.reset(TFE_NewOp(context.get(), "AssignVariableOp", status.get()));
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpAddInput(op.get(), one.get(), status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
executed = false;
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
// Read the variable's value.
op.reset(TFE_NewOp(context.get(), "ReadVariableOp", status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpAddInput(op.get(), var_handle, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(op.get(), "dtype", TF_FLOAT);
executed = false;
num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
ASSERT_EQ(
tensorflow::string(name),
tensorflow::string(TFE_TensorHandleDeviceName(var_value, status.get())));
TFE_DeleteTensorHandle(var_value);
// Free the backing buffer for the variable.
op.reset(TFE_NewOp(context.get(), "DestroyResourceOp", status.get()));
TFE_OpAddInput(op.get(), var_handle, status.get());
TFE_OpSetDevice(op.get(), name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(CUSTOM_DEVICE, InputBasedPlacement) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const char* custom0 = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
const char* custom1 = "/job:localhost/replica:0/task:0/device:CUSTOM:1";
bool arrived = false;
bool executed = false;
RegisterLoggingDevice(context.get(), custom0,
/*strict_scope_placement=*/false, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
RegisterLoggingDevice(context.get(), custom1,
/*strict_scope_placement=*/true, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcpu(
TestMatrixTensorHandle(context.get()), TFE_DeleteTensorHandle);
ASSERT_FALSE(arrived);
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom0(
TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom0,
status.get()),
TFE_DeleteTensorHandle);
ASSERT_TRUE(arrived);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
arrived = false;
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> hcustom1(
TFE_TensorHandleCopyToDevice(hcpu.get(), context.get(), custom1,
status.get()),
TFE_DeleteTensorHandle);
ASSERT_TRUE(arrived);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Base case: two CPU inputs executes fine.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> matmul(
MatMulOp(context.get(), hcpu.get(), hcpu.get()), TFE_DeleteOp);
TFE_TensorHandle* retval;
int num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_DeleteTensorHandle(retval);
// Custom device: inputs in same custom device works.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom0.get()));
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ASSERT_TRUE(executed);
TFE_DeleteTensorHandle(retval);
// Custom device: inputs in different custom devices fails.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcustom1.get()));
num_retvals = 1;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
ASSERT_NE(TF_OK, TF_GetCode(status.get()));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom0));
ASSERT_TRUE(absl::StrContains(TF_Message(status.get()), custom1));
// Custom device: mix of custom/physical places the op on the custom device.
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
EXPECT_TRUE(executed);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_DeleteTensorHandle(retval);
// Explicit placement still forces the op onto the requested device
matmul.reset(MatMulOp(context.get(), hcustom0.get(), hcpu.get()));
TFE_OpSetDevice(matmul.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
EXPECT_FALSE(executed);
ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
// Custom devices can refuse to do type-based dispatch (as hcustom1 is
// configured to do)
matmul.reset(MatMulOp(context.get(), hcustom1.get(), hcpu.get()));
num_retvals = 1;
executed = false;
TFE_Execute(matmul.get(), &retval, &num_retvals, status.get());
EXPECT_FALSE(executed);
ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK);
}
TEST(CUSTOM_DEVICE, InvalidRegistrationError) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
bool executed = false;
RegisterLoggingDevice(context.get(), "/device:CUSTOM:0",
/*strict_scope_placement=*/true, &arrived, &executed,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
<< TF_Message(status.get());
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
RegisterLoggingDevice(context.get(), name, /*strict_scope_placement=*/true,
&arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
RegisterLoggingDevice(
context.get(), "/job:localhost/replica:0/task:0/device:CPU:0",
/*strict_scope_placement=*/true, &arrived, &executed, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_ALREADY_EXISTS)
<< TF_Message(status.get());
}

View File

@ -0,0 +1,194 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/test.h"
namespace {
struct LoggingDevice {
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
// Set to true whenever an operation is executed
bool* executed_flag;
// If true, only explicit op placements are accepted. If false, uses
// type-based dispatch.
bool strict_scope_placement;
};
struct LoggedTensor {
TFE_TensorHandle* tensor;
LoggedTensor() = delete;
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<LoggedTensor*>(data);
}
TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* context, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}
TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
status);
}
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a logging device.");
return nullptr;
}
void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
const char* requested_placement = TFE_OpGetDevice(original_op, s);
if (TF_GetCode(s) != TF_OK) return;
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
if (dev->strict_scope_placement && *requested_placement == '\0') {
TF_SetStatus(s, TF_INTERNAL,
"Ops must be placed on the device explicitly, or their inputs "
"first copied to other devices.");
return;
}
TFE_Context* context = TFE_OpGetContext(original_op, s);
if (TF_GetCode(s) != TF_OK) return;
const char* operation_name = TFE_OpGetName(original_op, s);
if (TF_GetCode(s) != TF_OK) return;
const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
TFE_Op* op(TFE_NewOp(context, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddAttrs(op, attributes);
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
if (TF_GetCode(s) != TF_OK) return;
int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
if (TF_GetCode(s) != TF_OK) return;
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, j, s);
if (TF_GetCode(s) != TF_OK) return;
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(input, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddInput(op, t->tensor, s);
} else {
TFE_OpAddInput(op, input, s);
}
if (TF_GetCode(s) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
TFE_Execute(op, op_outputs.data(), num_outputs, s);
TFE_DeleteOp(op);
if (TF_GetCode(s) != TF_OK) return;
std::vector<TFE_TensorHandle*> unwrapped_outputs;
for (auto* handle : op_outputs) {
unwrapped_outputs.push_back(handle);
}
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
std::move(logged_tensor), s);
}
*(dev->executed_flag) = true;
}
void DeleteLoggingDevice(void* device_info) {
delete reinterpret_cast<LoggingDevice*>(device_info);
}
} // namespace
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool strict_scope_placement, bool* arrived_flag,
bool* executed_flag, TF_Status* status) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->arrived_flag = arrived_flag;
device->executed_flag = executed_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
device->strict_scope_placement = strict_scope_placement;
TFE_RegisterCustomDevice(context, custom_device, name, device, status);
}
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
TF_Status* status) {
return reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(logged_tensor_handle, status))
->tensor;
}
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info) {
TFE_CustomDevice* custom_device = new TFE_CustomDevice;
custom_device->copy_tensor_to_device = &CopyToLoggingDevice;
custom_device->copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device->delete_device = &DeleteLoggingDevice;
custom_device->execute = &LoggingDeviceExecute;
*device = custom_device;
LoggingDevice* logging_device = new LoggingDevice;
logging_device->arrived_flag = arrived_flag;
logging_device->executed_flag = executed_flag;
logging_device->device_name = name;
logging_device->underlying_device =
"/job:localhost/replica:0/task:0/device:CPU:0";
logging_device->strict_scope_placement = true;
*device_info = reinterpret_cast<void*>(logging_device);
}

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool strict_scope_placement, bool* arrived_flag,
bool* executed_flag, TF_Status* status);
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
bool* executed_flag, TFE_CustomDevice** device,
void** device_info);
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
TF_Status* status);
#endif // TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_

View File

@ -0,0 +1,348 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/dlpack.h"
#include "include/dlpack/dlpack.h" // from @dlpack
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace {
// Managing context for the DLManagedTensor, will manage the lifetime of
// DLManagedTensor. When calling DLManagedTensor::deleter, it will notify the
// original framework of destruction, and this context will be deleted also.
struct TfDlManagedTensorCtx {
TensorReference reference;
std::vector<int64_t> shape;
std::vector<int64_t> strides;
DLManagedTensor tensor;
explicit TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {}
};
// Gets tensor from eager tensor handle.
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
if (handle->Type() != TensorHandle::LOCAL) {
status->status = tensorflow::errors::InvalidArgument(
"DLPack doesn't support ", handle->TypeString(), " tensor");
return nullptr;
}
const tensorflow::Tensor* tensor;
status->status = handle->Tensor(&tensor);
if (!status->status.ok()) {
return nullptr;
}
return tensor;
}
// Deleter for DLManagedTensor
void DLManagedTensorDeleter(DLManagedTensor* arg) {
TfDlManagedTensorCtx* owner =
static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
owner->reference.Unref();
delete owner;
}
// Converts TF_DATAType to DLPack data type.
DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
DLDataType dtype;
dtype.lanes = 1;
dtype.bits = TF_DataTypeSize(data_type) * 8;
switch (data_type) {
case TF_DataType::TF_HALF:
case TF_DataType::TF_FLOAT:
case TF_DataType::TF_DOUBLE:
dtype.code = DLDataTypeCode::kDLFloat;
break;
case TF_DataType::TF_INT8:
case TF_DataType::TF_INT16:
case TF_DataType::TF_INT32:
case TF_DataType::TF_INT64:
dtype.code = DLDataTypeCode::kDLInt;
break;
case TF_DataType::TF_BOOL:
case TF_DataType::TF_UINT8:
case TF_DataType::TF_UINT16:
case TF_DataType::TF_UINT32:
case TF_DataType::TF_UINT64:
dtype.code = DLDataTypeCode::kDLUInt;
break;
case TF_DataType::TF_BFLOAT16:
dtype.code = DLDataTypeCode::kDLBfloat;
break;
default:
status->status = tensorflow::errors::InvalidArgument(
DataType_Name(static_cast<DataType>(data_type)),
" is not supported by dlpack");
break;
}
return dtype;
}
// Gets DLPack's DLContext from eager tensor handle.
DLContext GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
DLContext ctx;
const char* device_name =
tensorflow::unwrap(h)->BackingDeviceName(&status->status);
DeviceNameUtils::ParsedName parsed_name;
tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
std::string device_type = parsed_name.type;
int device_id = 0;
if (parsed_name.has_id) {
device_id = parsed_name.id;
}
ctx.device_id = device_id;
if (device_type == "CPU") {
ctx.device_type = DLDeviceType::kDLCPU;
} else if (device_type == "GPU") {
ctx.device_type = DLDeviceType::kDLGPU;
} else {
status->status = tensorflow::errors::InvalidArgument(
"Unsupported Device Type for dlpack");
}
return ctx;
}
// Converts DLContext to TF device name.
absl::optional<std::string> DeviceNameFromDlContext(const DLContext& ctx,
TF_Status* status) {
switch (ctx.device_type) {
case DLDeviceType::kDLCPU:
return "CPU:0";
case DLDeviceType::kDLGPU:
return absl::StrCat("GPU:", ctx.device_id);
default:
return absl::nullopt;
}
}
// Converts DLPack data type to TF_DATATYPE.
Status TfDataTypeFormDlDataType(const DLDataType& dtype,
TF_DataType* tf_dtype) {
switch (dtype.code) {
case DLDataTypeCode::kDLUInt:
switch (dtype.bits) {
case 8:
*tf_dtype = TF_DataType::TF_UINT8;
return Status::OK();
case 16:
*tf_dtype = TF_DataType::TF_UINT16;
return Status::OK();
case 32:
*tf_dtype = TF_DataType::TF_UINT32;
return Status::OK();
case 64:
*tf_dtype = TF_DataType::TF_UINT64;
return Status::OK();
default:
return tensorflow::errors::InvalidArgument("Unsupported UInt bits: ",
dtype.bits);
}
return Status::OK();
case DLDataTypeCode::kDLInt:
switch (dtype.bits) {
case 8:
*tf_dtype = TF_DataType::TF_INT8;
return Status::OK();
case 16:
*tf_dtype = TF_DataType::TF_INT16;
return Status::OK();
case 32:
*tf_dtype = TF_DataType::TF_INT32;
return Status::OK();
case 64:
*tf_dtype = TF_DataType::TF_INT64;
return Status::OK();
default:
return tensorflow::errors::InvalidArgument("Unsupported Int bits: ",
dtype.bits);
}
return Status::OK();
case DLDataTypeCode::kDLFloat:
switch (dtype.bits) {
case 16:
*tf_dtype = TF_DataType::TF_HALF;
return Status::OK();
case 32:
*tf_dtype = TF_DataType::TF_FLOAT;
return Status::OK();
case 64:
*tf_dtype = TF_DataType::TF_DOUBLE;
return Status::OK();
default:
return tensorflow::errors::InvalidArgument("Unsupported Float bits: ",
dtype.bits);
}
break;
case DLDataTypeCode::kDLBfloat:
switch (dtype.bits) {
case 16:
*tf_dtype = TF_DataType::TF_BFLOAT16;
return Status::OK();
default:
return tensorflow::errors::InvalidArgument(
"Unsupported BFloat bits: ", dtype.bits);
}
break;
default:
return tensorflow::errors::InvalidArgument("Unsupported Type Codes: ",
dtype.code);
}
}
// Wraps the deleter function of DLManagedTensor to match the function signature
// TFE_NewTensorHandleFromDeviceMemory.
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
TFE_CallDLManagedTensorDeleter(dlmt_vptr);
}
// Checks whether the stride array matches the layout of compact, row-majored
// data.
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
int ndim) {
if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
return false;
}
for (int i = ndim - 2; i >= 0; --i) {
if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
return false;
}
}
return true;
}
} // namespace
void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr);
if (dlMTensor->deleter != nullptr) {
dlMTensor->deleter(dlMTensor);
}
}
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
auto tf_dlm_context = GetDlContext(h, status);
if (!status->status.ok()) {
return nullptr;
}
auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status);
if (!status->status.ok()) {
return nullptr;
}
const Tensor* tensor = GetTensorFromHandle(h, status);
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
auto tf_dlm_type = GetDlDataType(data_type, status);
if (!status->status.ok()) {
return nullptr;
}
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
tf_dlm_tensor_ctx->reference = tensor_ref;
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
dlm_tensor->deleter = &DLManagedTensorDeleter;
dlm_tensor->dl_tensor.ctx = tf_dlm_context;
int ndim = tensor->dims();
dlm_tensor->dl_tensor.ndim = ndim;
dlm_tensor->dl_tensor.data = tf_dlm_data;
dlm_tensor->dl_tensor.dtype = tf_dlm_type;
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
shape_arr->resize(ndim);
stride_arr->resize(ndim, 1);
for (int i = 0; i < ndim; i++) {
(*shape_arr)[i] = tensor->dim_size(i);
}
for (int i = ndim - 2; i >= 0; --i) {
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
}
dlm_tensor->dl_tensor.shape = shape_arr->data();
// There are two ways to represent compact row-major data
// 1) nullptr indicates tensor is compact and row-majored.
// 2) fill in the strides array as the real case for compact row-major data.
// Here we choose option 2, since some frameworks didn't handle the strides
// argument properly.
dlm_tensor->dl_tensor.strides = stride_arr->data();
dlm_tensor->dl_tensor.byte_offset =
0; // TF doesn't handle the strides and byte_offsets here
return static_cast<void*>(dlm_tensor);
}
TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status,
TFE_Context* ctx) {
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
DLTensor* dl_tensor = &dlmt->dl_tensor;
absl::optional<std::string> device_name =
DeviceNameFromDlContext(dl_tensor->ctx, status);
if (!device_name.has_value()) {
status->status =
tensorflow::errors::InvalidArgument("Unsupported Device Type");
return nullptr;
}
TF_DataType dtype;
Status s = TfDataTypeFormDlDataType(dl_tensor->dtype, &dtype);
if (!s.ok()) {
status->status = std::move(s);
return nullptr;
}
int num_dims = dl_tensor->ndim;
const int64_t* dims = dl_tensor->shape;
void* data = dl_tensor->data;
size_t total_bytes = dl_tensor->dtype.bits / 8;
for (int i = 0; i < num_dims; i++) {
total_bytes *= dims[i];
}
if (dl_tensor->strides != nullptr &&
!IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
num_dims)) {
status->status = tensorflow::errors::InvalidArgument(
"Invalid strides array from DLPack");
return nullptr;
}
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
total_bytes, &DeallocatorWrapperFunc, dlmt, status);
return handle;
}
} // namespace tensorflow

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_DLPACK_H_
#define TENSORFLOW_C_EAGER_DLPACK_H_
#include "tensorflow/c/eager/c_api.h"
namespace tensorflow {
// PyCapsule name for DLPack Tensor
const char* const kDlTensorCapsuleName = "dltensor";
// Converts eager tensor handle to DLPack (DLManagedTensor*), and return the
// void* for further PyCapsule construction.
TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h,
TF_Status* status);
// Converts DLPack (DLManagedTensor*) to eager tensor handle.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm,
TF_Status* status,
TFE_Context* ctx);
// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule.
TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr);
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_DLPACK_H_

View File

@ -0,0 +1,204 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradient_checker.h"
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/tf_tensor.h"
namespace tensorflow {
namespace gradients {
using namespace std;
// ================== Helper functions =================
// Fills data with values [start,end) with given step size.
void Range(vector<int>* data, int start, int end, int step = 1) {
for (int i = start; i < end; i += step) {
(*data)[i] = i;
}
}
// Fills out_dims with the dimensions of the given tensor.
void GetDims(const TF_Tensor* t, int64_t* out_dims) {
int num_dims = TF_NumDims(t);
for (int i = 0; i < num_dims; i++) {
out_dims[i] = TF_Dim(t, i);
}
}
// Runs model as is if output is a scalar,
// else sums the output tensor before returning.
Status RunAndMaybeSum(AbstractContext* ctx, Model forward,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
bool use_function) {
std::vector<AbstractTensorHandle*> model_outputs(1);
// Run the model.
TF_RETURN_IF_ERROR(RunModel(forward, ctx, inputs,
absl::MakeSpan(model_outputs), use_function));
AbstractTensorHandlePtr model_out(model_outputs[0]);
TF_Tensor* model_out_tensor;
TF_RETURN_IF_ERROR(GetValue(model_out.get(), &model_out_tensor));
int num_dims_out = TF_NumDims(model_out_tensor);
TF_DeleteTensor(model_out_tensor);
// If the output is a scalar, then return the scalar output
if (num_dims_out == 0) {
outputs[0] = model_out.release();
return Status::OK();
}
// Else, reduce sum the output to get a scalar
// Will sum all dimensions, so get a Tensor containing [0,...,num_dims_out-1].
AbstractTensorHandlePtr sum_dims;
{
vector<int> vals(num_dims_out);
int64_t vals_shape[] = {num_dims_out};
Range(&vals, 0, num_dims_out);
AbstractTensorHandle* sum_dims_raw = nullptr;
TF_RETURN_IF_ERROR(TestTensorHandleWithDimsInt(ctx, vals.data(), vals_shape,
1, &sum_dims_raw));
sum_dims.reset(sum_dims_raw);
}
// Reduce sum the output on all dimensions.
TF_RETURN_IF_ERROR(
ops::Sum(ctx, {model_out.get(), sum_dims.get()}, outputs, "sum_output"));
return Status::OK();
}
// ========================= End Helper Functions==============================
Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
absl::Span<AbstractTensorHandle* const> inputs,
int input_index, bool use_function,
AbstractTensorHandle** numerical_grad) {
vector<AbstractTensorHandle*> theta_inputs(inputs.size());
for (int i{}; i < inputs.size(); ++i) {
theta_inputs[i] = inputs[i];
}
AbstractTensorHandle* theta =
theta_inputs[input_index]; // parameter we are grad checking
// Convert from AbstractTensor to TF_Tensor.
TF_Tensor* theta_tensor;
TF_RETURN_IF_ERROR(GetValue(theta, &theta_tensor));
// Get number of elements and fill data.
int num_elems = TF_TensorElementCount(theta_tensor);
vector<float> theta_data(num_elems);
memcpy(theta_data.data(), TF_TensorData(theta_tensor),
TF_TensorByteSize(theta_tensor));
// Initialize space for the numerical gradient.
vector<float> dtheta_approx(num_elems);
// Get theta shape and store in theta_dims.
int num_dims = TF_NumDims(theta_tensor);
vector<int64_t> theta_dims(num_dims);
GetDims(theta_tensor, theta_dims.data());
// Initialize auxilary data structures.
vector<float> thetaPlus_data(num_elems);
vector<float> thetaMinus_data(num_elems);
std::vector<AbstractTensorHandle*> f_outputs(1);
// Numerical Grad Check
for (int i = 0; i < num_elems; i++) {
// Get relative epsilon value
float epsilon = theta_data[i] == 0 ? 1e-4 : std::abs(theta_data[i] * 1e-4);
AbstractTensorHandlePtr two_eps;
{
AbstractTensorHandle* two_eps_raw = nullptr;
TF_RETURN_IF_ERROR(
TestScalarTensorHandle(ctx, 2 * epsilon, &two_eps_raw));
two_eps.reset(two_eps_raw);
}
// Initialize theta[i] + epsilon.
memcpy(thetaPlus_data.data(), TF_TensorData(theta_tensor),
TF_TensorByteSize(theta_tensor));
thetaPlus_data[i] += epsilon;
AbstractTensorHandlePtr thetaPlus;
{
AbstractTensorHandle* thetaPlus_raw = nullptr;
TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
ctx, thetaPlus_data.data(), theta_dims.data(), num_dims,
&thetaPlus_raw));
thetaPlus.reset(thetaPlus_raw);
}
// Initialize theta[i] - epsilon.
memcpy(&thetaMinus_data[0], TF_TensorData(theta_tensor),
TF_TensorByteSize(theta_tensor));
thetaMinus_data[i] -= epsilon;
AbstractTensorHandlePtr thetaMinus;
{
AbstractTensorHandle* thetaMinus_raw = nullptr;
TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
ctx, thetaMinus_data.data(), theta_dims.data(), num_dims,
&thetaMinus_raw));
thetaMinus.reset(thetaMinus_raw);
}
// Get f(theta + eps):
theta_inputs[input_index] = thetaPlus.get();
TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, theta_inputs,
absl::MakeSpan(f_outputs), use_function));
AbstractTensorHandlePtr fPlus(f_outputs[0]);
// Get f(theta - eps):
theta_inputs[input_index] = thetaMinus.get();
TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, theta_inputs,
absl::MakeSpan(f_outputs), use_function));
AbstractTensorHandlePtr fMinus(f_outputs[0]);
// Take Difference of both estimates: (f(theta + eps) - f(theta - eps)).
TF_RETURN_IF_ERROR(ops::Sub(ctx, {fPlus.get(), fMinus.get()},
absl::MakeSpan(f_outputs), "sub_top"));
AbstractTensorHandlePtr fDiff(f_outputs[0]);
// Calculate using the difference quotient definition:
// (f(theta + eps) - f(theta - eps)) / (2 * eps).
TF_RETURN_IF_ERROR(ops::Div(ctx, {fDiff.get(), two_eps.get()},
absl::MakeSpan(f_outputs), "diff_quotient"));
AbstractTensorHandlePtr diff_quotient(f_outputs[0]);
TF_Tensor* grad_tensor;
TF_RETURN_IF_ERROR(GetValue(diff_quotient.get(), &grad_tensor));
float grad_data[1];
memcpy(&grad_data[0], TF_TensorData(grad_tensor),
TF_TensorByteSize(grad_tensor));
TF_DeleteTensor(grad_tensor);
dtheta_approx[i] = grad_data[0];
}
// Populate *numerical_grad with the data from dtheta_approx.
TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad));
TF_DeleteTensor(theta_tensor);
return Status::OK();
}
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,46 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_
#define TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
namespace tensorflow {
namespace gradients {
/* Returns numerical grad inside `dtheta_approx` given `forward` model and
* parameter specified by `input_index`.
*
* I.e. if y = <output of the forward model> and w = inputs[input_index],
* this will calculate dy/dw numerically.
*
* `use_function` indicates whether to use graph mode(true) or eager(false).
*
* `numerical_grad` is the pointer to the AbstractTensorHandle* which will
* hold the numerical gradient data at the end of the function.
*/
Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
absl::Span<AbstractTensorHandle* const> inputs,
int input_index, bool use_function,
AbstractTensorHandle** numerical_grad);
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_

View File

@ -0,0 +1,175 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradient_checker.h"
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
using tensorflow::TF_StatusPtr;
void CompareNumericalAndManualGradients(
Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs, int input_index,
float* expected_grad, int num_grad, bool use_function,
double abs_error = 1e-2) {
Status s;
AbstractTensorHandlePtr numerical_grad;
{
AbstractTensorHandle* numerical_grad_raw;
s = CalcNumericalGrad(ctx, model, inputs, input_index, use_function,
&numerical_grad_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
numerical_grad.reset(numerical_grad_raw);
}
TF_Tensor* numerical_tensor;
s = GetValue(numerical_grad.get(), &numerical_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto num_elem_numerical = TF_TensorElementCount(numerical_tensor);
ASSERT_EQ(num_elem_numerical, num_grad);
float* dnumerical = new float[num_elem_numerical]{0};
memcpy(&dnumerical[0], TF_TensorData(numerical_tensor),
TF_TensorByteSize(numerical_tensor));
for (int j = 0; j < num_grad; j++) {
ASSERT_NEAR(dnumerical[j], expected_grad[j], abs_error);
}
delete[] dnumerical;
TF_DeleteTensor(numerical_tensor);
}
Status MatMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
return ops::MatMul(ctx, inputs, outputs, "MatMul",
/*transpose_a=*/false,
/*transpose_b=*/false);
}
Status MulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
return ops::Mul(ctx, inputs, outputs, "Mul");
}
// TODO(vnvo2409): Add more tests from `python/ops/gradient_checker_v2_test.py`.
// These tests should not be confused with `[*]_grad_test` which compare the
// result of `gradient_checker` and `[*]_grad`. The tests here test the
// functionality of `gradient_checker` by comparing the result with expected
// manual user-provided gradients.
class GradientCheckerTest
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx_.reset(ctx_raw);
}
}
AbstractContextPtr ctx_;
public:
bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
bool UseFunction() const { return std::get<2>(GetParam()); }
};
TEST_P(GradientCheckerTest, TestMatMul) {
float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t A_dims[] = {2, 2};
AbstractTensorHandlePtr A;
{
AbstractTensorHandle* A_raw;
Status s =
TestTensorHandleWithDimsFloat(ctx_.get(), A_vals, A_dims, 2, &A_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
A.reset(A_raw);
}
float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
int64_t B_dims[] = {2, 2};
AbstractTensorHandlePtr B;
{
AbstractTensorHandle* B_raw;
Status s =
TestTensorHandleWithDimsFloat(ctx_.get(), B_vals, B_dims, 2, &B_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
B.reset(B_raw);
}
float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndManualGradients(
MatMulModel, ctx_.get(), {A.get(), B.get()}, 0, expected_dA, 4,
UseFunction()));
}
TEST_P(GradientCheckerTest, TestMul) {
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx_.get(), 7.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}
float expected_dx[1] = {7.0f};
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndManualGradients(
MulModel, ctx_.get(), {x.get(), y.get()}, 0, expected_dx, 1,
UseFunction()));
}
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, GradientCheckerTest,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*use_function*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, GradientCheckerTest,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*use_function*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,489 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace gradients {
namespace {
// TODO(b/172558015): Using the pointer address as the identifier for the tensor
// may lead to collisions. Introduce another way to get a unique id for this
// tensor.
int64 ToId(const AbstractTensorHandle* t) {
return static_cast<int64>(reinterpret_cast<uintptr_t>(t));
}
Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t,
AbstractTensorHandle** result) {
AbstractOperationPtr op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("ZerosLike", ToId(t)).c_str()));
}
TF_RETURN_IF_ERROR(op->AddInput(t));
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(num_outputs);
TF_RETURN_IF_ERROR(
op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
*result = outputs[0];
return Status::OK();
}
} // namespace
Status GradientRegistry::Register(
const string& op_name, GradientFunctionFactory gradient_function_factory) {
auto iter = registry_.find(op_name);
if (iter != registry_.end()) {
const string error_msg = "Gradient already exists for op: " + op_name + ".";
return errors::AlreadyExists(error_msg);
}
registry_.insert({op_name, gradient_function_factory});
return Status::OK();
}
Status GradientRegistry::Lookup(
const ForwardOperation& op,
std::unique_ptr<GradientFunction>* gradient_function) const {
auto iter = registry_.find(op.op_name);
if (iter == registry_.end()) {
const string error_msg = "No gradient defined for op: " + op.op_name + ".";
return errors::NotFound(error_msg);
}
gradient_function->reset(iter->second(op));
return Status::OK();
}
TapeTensor::TapeTensor(AbstractTensorHandle* handle) : handle_(handle) {
handle_->Ref();
}
TapeTensor::TapeTensor(const TapeTensor& other) {
handle_ = other.handle_;
handle_->Ref();
}
TapeTensor::~TapeTensor() { handle_->Unref(); }
tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }
tensorflow::DataType TapeTensor::GetDType() const {
return handle_->DataType();
}
AbstractTensorHandle* TapeTensor::GetHandle() const { return handle_; }
AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; }
class TapeVSpace
: public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
public:
explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
~TapeVSpace() override {}
// Returns the number of elements in the gradient tensor.
int64 NumElements(AbstractTensorHandle* tensor) const override;
// Consumes references to the tensors in the gradient_tensors list and returns
// a tensor with the result.
AbstractTensorHandle* AggregateGradients(
gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const override;
// Calls the passed-in backward function.
// op_type is the op's name provided in RecordOperation.
Status CallBackwardFunction(
const string& op_type, GradientFunction* gradient_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
absl::Span<AbstractTensorHandle*> result) const override;
// Builds a tensor filled with ones with the same shape and dtype as `t`.
Status BuildOnesLike(const TapeTensor& t,
AbstractTensorHandle** result) const override;
// Looks up the ID of a Gradient.
int64 TensorId(AbstractTensorHandle* tensor) const override;
// Converts a Gradient to a TapeTensor.
TapeTensor TapeTensorFromGradient(AbstractTensorHandle* g) const override;
void MarkAsResult(AbstractTensorHandle* gradient) const override;
void DeleteGradient(AbstractTensorHandle* gradient) const override;
private:
// The context where the aggregation op `Add` is to be created.
AbstractContext* ctx_;
};
// Returns the number of elements in the gradient tensor.
int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
// TODO(srbs): It seems like this is used only for performance optimization
// and not for correctness. The only downside of keeping this 1 seems to be
// that the gradient accumulation is unbounded and we will never
// aggressively aggregate accumulated gradients to recover memory.
// Revisit and fix.
return 1;
}
// Consumes references to the tensors in the gradient_tensors list and returns
// a tensor with the result.
AbstractTensorHandle* TapeVSpace::AggregateGradients(
gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const {
if (gradient_tensors.size() == 1) {
return gradient_tensors[0];
}
AbstractOperationPtr op(ctx_->CreateOperation());
Status s = op->Reset("AddN", /*raw_device_name=*/nullptr);
if (!s.ok()) {
return nullptr;
}
s = op->AddInputList(gradient_tensors);
if (!s.ok()) {
return nullptr;
}
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(num_outputs);
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
if (!s.ok()) {
return nullptr;
}
return outputs[0];
}
// Calls the passed-in backward function.
// op_type is the op's name provided in RecordOperation.
Status TapeVSpace::CallBackwardFunction(
const string& op_type, GradientFunction* gradient_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
absl::Span<AbstractTensorHandle*> result) const {
if (gradient_function == nullptr) {
return errors::InvalidArgument(
"Provided null gradient_function for '", op_type, "'.\n",
"If the intent is to treat this op as non-differentiable consider "
"using RegisterNotDifferentiable or "
"NotDifferentiableGradientFunction.");
}
return gradient_function->Compute(ctx_, output_gradients, result);
}
Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
AbstractTensorHandle** result) const {
AbstractOperationPtr op(ctx_->CreateOperation());
TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("OnesLike", ToId(t.GetHandle())).c_str()));
}
TF_RETURN_IF_ERROR(op->AddInput(t.GetHandle()));
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(num_outputs);
TF_RETURN_IF_ERROR(
op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
*result = outputs[0];
return Status::OK();
}
// Looks up the ID of a Gradient.
int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {
return ToId(tensor);
}
// Converts a Gradient to a TapeTensor.
TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
return TapeTensor(g);
}
void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}
void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
gradient->Unref();
}
void Tape::Watch(const AbstractTensorHandle* t) {
GradientTape::Watch(ToId(t));
}
void Tape::RecordOperation(absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle* const> outputs,
GradientFunction* gradient_function,
const string& op_name) {
std::vector<int64> input_ids(inputs.size());
std::vector<tensorflow::DataType> input_dtypes(inputs.size());
for (int i = 0; i < inputs.size(); i++) {
input_ids[i] = ToId(inputs[i]);
input_dtypes[i] = inputs[i]->DataType();
}
std::vector<TapeTensor> tape_tensors;
for (auto t : outputs) {
tape_tensors.push_back(TapeTensor(t));
}
GradientTape::RecordOperation(
op_name, tape_tensors, input_ids, input_dtypes,
[gradient_function]() -> GradientFunction* { return gradient_function; },
[](GradientFunction* ptr) {
if (ptr) {
delete ptr;
}
});
}
bool Tape::ShouldRecord(
absl::Span<const AbstractTensorHandle* const> tensors) const {
std::vector<int64> tensor_ids(tensors.size());
std::vector<tensorflow::DataType> tensor_dtypes(tensors.size());
for (int i = 0; i < tensors.size(); i++) {
tensor_ids[i] = ToId(tensors[i]);
tensor_dtypes[i] = tensors[i]->DataType();
}
return GradientTape::ShouldRecord(tensor_ids, tensor_dtypes);
}
void Tape::DeleteTrace(const AbstractTensorHandle* t) {
GradientTape::DeleteTrace(ToId(t));
}
std::vector<int64> MakeTensorIDList(
absl::Span<AbstractTensorHandle* const> tensors) {
std::vector<int64> ids(tensors.size());
for (int i = 0; i < tensors.size(); i++) {
ids[i] = ToId(tensors[i]);
}
return ids;
}
Status Tape::ComputeGradient(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> targets,
absl::Span<AbstractTensorHandle* const> sources,
absl::Span<AbstractTensorHandle* const> output_gradients,
absl::Span<AbstractTensorHandle*> result) {
TapeVSpace vspace(ctx);
std::vector<int64> target_tensor_ids = MakeTensorIDList(targets);
std::vector<int64> source_tensor_ids = MakeTensorIDList(sources);
tensorflow::gtl::FlatSet<tensorflow::int64> sources_set(
source_tensor_ids.begin(), source_tensor_ids.end());
std::unordered_map<int64, TapeTensor> sources_that_are_targets;
for (int i = 0; i < target_tensor_ids.size(); ++i) {
int64 target_id = target_tensor_ids[i];
if (sources_set.find(target_id) != sources_set.end()) {
auto tensor = targets[i];
sources_that_are_targets.insert(
std::make_pair(target_id, TapeTensor(tensor)));
}
}
TF_RETURN_IF_ERROR(GradientTape::ComputeGradient(
vspace, target_tensor_ids, source_tensor_ids, sources_that_are_targets,
output_gradients, result, /*build_default_zeros_grads*/ false));
return Status::OK();
}
// Helper functions which delegate to `AbstractOperation`, update
// the state of the ForwardOperation and call the tape as appropriate.
// These APIs are mainly to facilitate testing and are subject to change.
namespace internal {
Status Reset(AbstractOperation* op_, const char* op,
const char* raw_device_name, ForwardOperation* forward_op_) {
forward_op_->op_name = op;
forward_op_->attrs.Reset(op);
return op_->Reset(op, raw_device_name);
}
Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input,
ForwardOperation* forward_op_) {
TF_RETURN_IF_ERROR(op_->AddInput(input));
forward_op_->inputs.push_back(input);
return Status::OK();
}
Status AddInputList(AbstractOperation* op_,
absl::Span<AbstractTensorHandle* const> inputs,
ForwardOperation* forward_op_) {
TF_RETURN_IF_ERROR(op_->AddInputList(inputs));
for (auto input : inputs) {
forward_op_->inputs.push_back(input);
}
return Status::OK();
}
Status SetAttrString(AbstractOperation* op_, const char* attr_name,
const char* data, size_t length,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, StringPiece(data, length));
return op_->SetAttrString(attr_name, data, length);
}
Status SetAttrInt(AbstractOperation* op_, const char* attr_name, int64_t value,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, static_cast<int64>(value));
return op_->SetAttrInt(attr_name, value);
}
Status SetAttrFloat(AbstractOperation* op_, const char* attr_name, float value,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, value);
return op_->SetAttrFloat(attr_name, value);
}
Status SetAttrBool(AbstractOperation* op_, const char* attr_name, bool value,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, value);
return op_->SetAttrBool(attr_name, value);
}
Status SetAttrType(AbstractOperation* op_, const char* attr_name,
DataType value, ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, value);
return op_->SetAttrType(attr_name, value);
}
Status SetAttrShape(AbstractOperation* op_, const char* attr_name,
const int64_t* dims, const int num_dims,
ForwardOperation* forward_op_) {
if (num_dims > TensorShape::MaxDimensions()) {
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
num_dims,
" dimensions which is over the limit of ",
TensorShape::MaxDimensions(), ".");
}
TensorShapeProto proto;
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
forward_op_->attrs.Set(attr_name, proto);
return op_->SetAttrShape(attr_name, dims, num_dims);
}
Status SetAttrFunction(AbstractOperation* op_, const char* attr_name,
const AbstractOperation* value,
ForwardOperation* forward_op_) {
return tensorflow::errors::Unimplemented(
"SetAttrFunction has not been implemented yet.");
}
Status SetAttrFunctionName(AbstractOperation* op_, const char* attr_name,
const char* value, size_t length,
ForwardOperation* forward_op_) {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionName has not been implemented "
"yet.");
}
Status SetAttrTensor(AbstractOperation* op_, const char* attr_name,
AbstractTensorInterface* tensor,
ForwardOperation* forward_op_) {
return tensorflow::errors::Unimplemented(
"SetAttrTensor has not been implemented yet.");
}
Status SetAttrStringList(AbstractOperation* op_, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values, ForwardOperation* forward_op_) {
std::vector<StringPiece> v(num_values);
for (int i = 0; i < num_values; ++i) {
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
}
forward_op_->attrs.Set(attr_name, v);
return op_->SetAttrStringList(attr_name, values, lengths, num_values);
}
Status SetAttrFloatList(AbstractOperation* op_, const char* attr_name,
const float* values, int num_values,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name,
gtl::ArraySlice<const float>(values, num_values));
return op_->SetAttrFloatList(attr_name, values, num_values);
}
Status SetAttrIntList(AbstractOperation* op_, const char* attr_name,
const int64_t* values, int num_values,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(
attr_name, gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
return op_->SetAttrIntList(attr_name, values, num_values);
}
Status SetAttrTypeList(AbstractOperation* op_, const char* attr_name,
const DataType* values, int num_values,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name,
gtl::ArraySlice<const DataType>(values, num_values));
return op_->SetAttrTypeList(attr_name, values, num_values);
}
Status SetAttrBoolList(AbstractOperation* op_, const char* attr_name,
const unsigned char* values, int num_values,
ForwardOperation* forward_op_) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
forward_op_->attrs.Set(attr_name,
gtl::ArraySlice<const bool>(b.get(), num_values));
return op_->SetAttrBoolList(attr_name, values, num_values);
}
Status SetAttrShapeList(AbstractOperation* op_, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, ForwardOperation* forward_op_) {
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];
if (num_dims_i > TensorShape::MaxDimensions()) {
return errors::InvalidArgument(
strings::StrCat("Value specified for `", attr_name, "` has ",
num_dims_i, " dimensions which is over the limit of ",
TensorShape::MaxDimensions(), "."));
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
forward_op_->attrs.Set(
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
return op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
Status SetAttrFunctionList(AbstractOperation* op_, const char* attr_name,
absl::Span<const AbstractOperation*> values,
ForwardOperation* forward_op_) {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionList has not been "
"implemented yet.");
}
Status Execute(AbstractOperation* op_, AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
ForwardOperation* forward_op_, Tape* tape,
const GradientRegistry& registry) {
TF_RETURN_IF_ERROR(op_->Execute(retvals, num_retvals));
for (int i = 0; i < *num_retvals; i++) {
// TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
forward_op_->outputs.push_back(retvals[i]);
}
// TODO(b/166669239): This is needed to support AttrBuilder::Get for string
// attributes. Number type attrs and DataType attrs work fine without this.
// Consider getting rid of this and making the behavior between number types
// and string consistent.
forward_op_->attrs.BuildNodeDef();
std::unique_ptr<GradientFunction> gradient_fn;
TF_RETURN_IF_ERROR(registry.Lookup(*forward_op_, &gradient_fn));
tape->RecordOperation(forward_op_->inputs, retvals, gradient_fn.release(),
op_->Name());
return Status::OK();
}
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,176 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_GRADIENTS_H_
#define TENSORFLOW_C_EAGER_GRADIENTS_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
namespace tensorflow {
namespace gradients {
// =============== Experimental C++ API for computing gradients ===============
// Sample gradient function:
//
// class AddGradientFunction : public GradientFunction {
// public:
// Status Compute(Context* ctx,
// absl::Span<AbstractTensorHandle* const> grad_inputs,
// absl::Span<AbstractTensorHandle*> grad_outputs) override {
// grad_outputs[0] = grad_inputs[0];
// grad_outputs[1] = grad_inputs[0];
// grad_outputs[0]->Ref();
// grad_outputs[1]->Ref();
// return Status::OK();
// }
// ~AddGradientFunction() override {}
// };
//
// GradientFunction* AddRegisterer(const ForwardOperation& op) {
// // More complex gradient functions can use inputs/attrs etc. from the
// // forward `op`.
// return new AddGradientFunction;
// }
//
// Status RegisterGradients(GradientRegistry* registry) {
// return registry->Register("Add", AddRegisterer);
// }
class GradientFunction {
public:
virtual Status Compute(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) = 0;
virtual ~GradientFunction() {}
};
// Metadata from the forward operation that is made available to the
// gradient registerer to instantiate a GradientFunction.
struct ForwardOperation {
public:
string op_name;
std::vector<AbstractTensorHandle*> inputs;
std::vector<AbstractTensorHandle*> outputs;
std::vector<int64> skip_input_indices;
AttrBuilder attrs;
};
using GradientFunctionFactory =
std::function<GradientFunction*(const ForwardOperation& op)>;
// Map from op name to a `GradientFunctionFactory`.
class GradientRegistry {
public:
Status Register(const string& op,
GradientFunctionFactory gradient_function_factory);
Status Lookup(const ForwardOperation& op,
std::unique_ptr<GradientFunction>* gradient_function) const;
private:
absl::flat_hash_map<string, GradientFunctionFactory> registry_;
};
// TODO(srbs): Figure out if we can avoid declaring this in the public header.
// Wrapper for a tensor output of an operation executing under a tape.
//
// `GetID` returns a unique id for the wrapped tensor which is used to maintain
// a map (`tensorflow::eager::TensorTape`) from the wrapped tensor to the id of
// the op that produced it (or -1 if this tensor was watched using
// `GradientTape::Watch`.) The op_id is simply a unique index assigned to each
// op executed under the tape. A separate map (`tensorflow::eager::OpTape`)
// maintains the map from `op_id` to a `OpTapeEntry` which stores the `op_type`,
// inputs and outputs and the gradient function These data structures combined
// allow us to trace the data dependencies between operations and hence compute
// gradients.
//
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
// of default zeros grads is handled by the `DefaultGradientFunction` registered
// for each op.
// TODO(srbs): We need to define `ZerosLike` here to keep the compiler happy.
// Figure out a way to avoid this.
// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr?
class TapeTensor {
public:
explicit TapeTensor(AbstractTensorHandle* handle);
TapeTensor(const TapeTensor& other);
~TapeTensor();
tensorflow::int64 GetID() const;
tensorflow::DataType GetDType() const;
AbstractTensorHandle* ZerosLike() const;
AbstractTensorHandle* GetHandle() const;
private:
AbstractTensorHandle* handle_;
};
// A tracing/immediate-execution agnostic tape.
//
// Gradient functions defined for this tape must support handling null incoming
// gradients.
class Tape : protected eager::GradientTape<AbstractTensorHandle,
GradientFunction, TapeTensor> {
public:
using GradientTape<AbstractTensorHandle, GradientFunction,
TapeTensor>::GradientTape;
// Returns whether the tape is persistent, i.e., whether the tape will hold
// onto its internal state after a call to `ComputeGradient`.
using GradientTape<AbstractTensorHandle, GradientFunction,
TapeTensor>::IsPersistent;
// Adds this tensor to the list of watched tensors.
//
// This is a no-op if the tensor is already being watched either from an
// earlier call to `GradientTape::Watch` or being an output of an op with
// watched inputs.
void Watch(const AbstractTensorHandle*);
// Records an operation with given inputs and outputs
// on the tape and marks all its outputs as watched if at
// least one input of the op is watched and has a trainable dtype.
// op_name is optional and is used for debugging only.
void RecordOperation(absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle* const> outputs,
GradientFunction* gradient_function,
const string& op_name = "");
// Returns whether any tensor in a list of tensors is being watched and has
// a trainable dtype.
bool ShouldRecord(
absl::Span<const AbstractTensorHandle* const> tensors) const;
// Unwatches this tensor on the tape. Mainly used for cleanup when deleting
// eager tensors.
void DeleteTrace(const AbstractTensorHandle*);
// Consumes the internal state of the tape (so cannot be called more than
// once unless the tape is persistent) and produces the gradient of the target
// tensors with respect to the source tensors. The output gradients are used
// if not empty and not null. The result is populated with one tensor per
// target element.
Status ComputeGradient(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> targets,
absl::Span<AbstractTensorHandle* const> sources,
absl::Span<AbstractTensorHandle* const> output_gradients,
absl::Span<AbstractTensorHandle*> result);
};
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_GRADIENTS_H_

View File

@ -0,0 +1,87 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
#define TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace gradients {
namespace internal {
// Helper functions which delegate to `AbstractOperation`, update
// the state of the ForwardOperation and call the tape as appropriate.
// These APIs are mainly to faciliate testing and are subject to change.
// Records the op name in the `ForwardOperation`.
Status Reset(AbstractOperation*, const char* op, const char* raw_device_name,
ForwardOperation*);
// Records the inputs in the `ForwardOperation`.
Status AddInput(AbstractOperation*, AbstractTensorHandle*, ForwardOperation*);
Status AddInputList(AbstractOperation*,
absl::Span<AbstractTensorHandle* const> inputs,
ForwardOperation*);
// Sets the attrs in the `ForwardOperation`.
Status SetAttrString(AbstractOperation*, const char* attr_name,
const char* data, size_t length, ForwardOperation*);
Status SetAttrInt(AbstractOperation*, const char* attr_name, int64_t value,
ForwardOperation*);
Status SetAttrFloat(AbstractOperation*, const char* attr_name, float value,
ForwardOperation*);
Status SetAttrBool(AbstractOperation*, const char* attr_name, bool value,
ForwardOperation*);
Status SetAttrType(AbstractOperation*, const char* attr_name, DataType value,
ForwardOperation*);
Status SetAttrShape(AbstractOperation*, const char* attr_name,
const int64_t* dims, const int num_dims, ForwardOperation*);
Status SetAttrFunction(AbstractOperation*, const char* attr_name,
const AbstractOperation* value, ForwardOperation*);
Status SetAttrFunctionName(AbstractOperation*, const char* attr_name,
const char* value, size_t length, ForwardOperation*);
Status SetAttrTensor(AbstractOperation*, const char* attr_name,
AbstractTensorInterface* tensor, ForwardOperation*);
Status SetAttrStringList(AbstractOperation*, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values, ForwardOperation*);
Status SetAttrFloatList(AbstractOperation*, const char* attr_name,
const float* values, int num_values, ForwardOperation*);
Status SetAttrIntList(AbstractOperation*, const char* attr_name,
const int64_t* values, int num_values, ForwardOperation*);
Status SetAttrTypeList(AbstractOperation*, const char* attr_name,
const DataType* values, int num_values,
ForwardOperation*);
Status SetAttrBoolList(AbstractOperation*, const char* attr_name,
const unsigned char* values, int num_values,
ForwardOperation*);
Status SetAttrShapeList(AbstractOperation*, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, ForwardOperation*);
Status SetAttrFunctionList(AbstractOperation*, const char* attr_name,
absl::Span<const AbstractOperation*> values,
ForwardOperation*);
// Make the call to `Tape::RecordOperation`.
Status Execute(AbstractOperation*, AbstractContext*,
absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
ForwardOperation*, Tape*, const GradientRegistry&);
} // namespace internal
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_

View File

@ -0,0 +1,891 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients.h"
#include <memory>
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
#include "tensorflow/c/experimental/gradients/array_grad.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/not_differentiable.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
using std::vector;
using tensorflow::TF_StatusPtr;
using tracing::TracingOperation;
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
}
};
Status RegisterGradients(GradientRegistry* registry) {
// TODO(srbs): Rename ops::Add to ops::AddV2 and AddRegister to
// AddV2Registerer.
TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Mul", MulRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Log1p", Log1pRegisterer));
TF_RETURN_IF_ERROR(registry->Register("DivNoNan", DivNoNanRegisterer));
TF_RETURN_IF_ERROR(RegisterNotDifferentiable(registry, "CheckNumerics"));
return Status::OK();
}
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(inputs[0]); // Watch x.
tape->Watch(inputs[1]); // Watch y.
std::vector<AbstractTensorHandle*> add_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs,
absl::MakeSpan(add_outputs),
"Add")); // Compute x+y.
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/add_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto add_output : add_outputs) {
add_output->Unref();
}
return Status::OK();
}
// Computes
// y = exp(inputs[0])
// return grad(y, {inputs[0]})
Status ExpGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(inputs[0]); // Watch x.
std::vector<AbstractTensorHandle*> exp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(
ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp"));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/exp_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto exp_output : exp_outputs) {
exp_output->Unref();
}
return Status::OK();
}
// Computes
// y = sqrt(inputs[0])
// return grad(y, {inputs[0]})
Status SqrtGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(inputs[0]); // Watch x.
std::vector<AbstractTensorHandle*> sqrt_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(
ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/sqrt_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto sqrt_output : sqrt_outputs) {
sqrt_output->Unref();
}
return Status::OK();
}
// Computes
// ignored, y = IdentityN(inputs[0], inputs[1])
// return grad(y, {inputs[0], inputs[1]})
// This should return [nullptr, 1].
Status IdentityNGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(inputs[0]);
tape->Watch(inputs[1]);
vector<AbstractTensorHandle*> identity_n_outputs(2);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(ops::IdentityN(
tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN"));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx,
/*targets=*/{identity_n_outputs[1]},
/*sources=*/{inputs[0], inputs[1]},
/*output_gradients=*/{}, outputs));
for (auto identity_n_output : identity_n_outputs) {
identity_n_output->Unref();
}
return Status::OK();
}
// Computes
// y = - inputs[0]
// return grad(y, {inputs[0]})
Status NegGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(inputs[0]);
std::vector<AbstractTensorHandle*> neg_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(
ops::Neg(tape_ctx.get(), inputs, absl::MakeSpan(neg_outputs), "Neg"));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/neg_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto neg_output : neg_outputs) {
neg_output->Unref();
}
return Status::OK();
}
// Computes
// y = inputs[0] - inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status SubGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(inputs[0]); // Watch x.
tape->Watch(inputs[1]); // Watch y.
std::vector<AbstractTensorHandle*> sub_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
absl::MakeSpan(sub_outputs),
"Sub")); // Compute x-y.
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/sub_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto sub_output : sub_outputs) {
sub_output->Unref();
}
return Status::OK();
}
// Computes
// y = inputs[0] * inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status MulGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
auto tape = new Tape(/*persistent=*/false);
tape->Watch(inputs[0]); // Watch x.
tape->Watch(inputs[1]); // Watch y.
std::vector<AbstractTensorHandle*> mul_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), inputs,
absl::MakeSpan(mul_outputs),
"Mul")); // Compute x*y.
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/mul_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto mul_output : mul_outputs) {
mul_output->Unref();
}
delete tape;
return Status::OK();
}
// Computes
// y = log(1 + inputs[0])
// return grad(y, {inputs[0]})
Status Log1pGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
auto tape = new Tape(/*persistent=*/false);
tape->Watch(inputs[0]); // Watch x.
std::vector<AbstractTensorHandle*> log1p_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Log1p(tape_ctx.get(), inputs,
absl::MakeSpan(log1p_outputs),
"Log1p")); // Compute log(1 + x).
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/log1p_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto log1p_output : log1p_outputs) {
log1p_output->Unref();
}
delete tape;
return Status::OK();
}
// Computes
// y = inputs[0] / inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status DivNoNanGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
auto tape = new Tape(/*persistent=*/false);
tape->Watch(inputs[0]); // Watch x.
tape->Watch(inputs[1]); // Watch y.
std::vector<AbstractTensorHandle*> div_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::DivNoNan(tape_ctx.get(), inputs,
absl::MakeSpan(div_outputs),
"DivNoNan")); // Compute x / y.
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/div_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto div_output : div_outputs) {
div_output->Unref();
}
delete tape;
return Status::OK();
}
Status getValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
return Status::OK();
}
TEST_P(CppGradients, TestAddGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// tape.watch(y)
// y = x + y
// outputs = tape.gradient(y, [x, y])
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(AddGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
}
TEST_P(CppGradients, TestExpGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
// Pseudo-code:
//
// tape.watch(x)
// y = exp(x)
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
Status s =
RunModel(ExpGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_NEAR(*result_value, 2.718, 0.001);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
TEST_P(CppGradients, TestSqrtGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
// Pseudo-code:
//
// tape.watch(x)
// y = sqrt(x)
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
Status s =
RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_NEAR(*result_value, 0.5, 0.001);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
TEST_P(CppGradients, TestIdentityNGrad) {
// Pseudo-code:
//
// tape.watch(x1)
// tape.watch(x2)
// unused, y = IdentityN([x1, x2])
// outputs = tape.gradient(y, [x1, x2])
// Expected: [nullptr, 1]
//
// This test is interesting because the current implementation of GradientTape
// would return [0, 1] whereas we use build_default_zeros_grads=false here
// so we get back [nullptr, 1].
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x1;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x1.reset(x_raw);
}
AbstractTensorHandlePtr x2;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x2.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(IdentityNGradModel, ctx.get(), {x1.get(), x2.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
EXPECT_EQ(outputs[0], nullptr);
TF_Tensor* result_tensor;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
TEST_P(CppGradients, TestNegGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// y = - x
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
s = RunModel(NegGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, -1.0);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
TEST_P(CppGradients, TestSubGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}
// Pseudo-code:
//
// tape.watch(x)
// tape.watch(y)
// y = x - y
// outputs = tape.gradient(y, [x, y])
std::vector<AbstractTensorHandle*> outputs(2);
Status s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, -1.0);
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
}
TEST_P(CppGradients, TestMulGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}
// Pseudo-code:
//
// tape.watch(x)
// tape.watch(y)
// y = x * y
// outputs = tape.gradient(y, [x, y])
std::vector<AbstractTensorHandle*> outputs(2);
Status s = RunModel(MulGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 2.0);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
}
TEST_P(CppGradients, TestLog1pGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
// Pseudo-code:
//
// tape.watch(x)
// y = log(1 + x)
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
Status s =
RunModel(Log1pGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_NEAR(*result_value, 0.5, 0.001);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
TEST_P(CppGradients, TestDivNoNanGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}
// Pseudo-code:
//
// tape.watch(x)
// tape.watch(y)
// y = x / y
// outputs = tape.gradient(y, [x, y])
std::vector<AbstractTensorHandle*> outputs(2);
Status s = RunModel(DivNoNanGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_NEAR(*result_value, 0.5, 0.001);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_NEAR(*result_value, -0.25, 0.001);
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
}
TEST_P(CppGradients, TestSetAttrString) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr t;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
t.reset(x_raw);
}
AbstractOperationPtr check_numerics_op(ctx->CreateOperation());
ForwardOperation forward_op;
Status s = Reset(check_numerics_op.get(), "CheckNumerics",
/*raw_device_name=*/nullptr, &forward_op);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
if (isa<TracingOperation>(check_numerics_op.get())) {
s = dyn_cast<TracingOperation>(check_numerics_op.get())
->SetOpName("check_numerics");
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}
s = AddInput(check_numerics_op.get(), t.get(), &forward_op);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
string message = "This is the way!";
s = SetAttrString(check_numerics_op.get(), "message", message.data(),
message.length(), &forward_op);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
int num_retvals = 1;
std::vector<AbstractTensorHandle*> outputs(1);
GradientRegistry registry;
s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto tape = std::make_unique<Tape>(/*persistent=*/false);
s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs),
&num_retvals, &forward_op, tape.get(), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
string read_message;
s = forward_op.attrs.Get("message", &read_message);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ASSERT_EQ(read_message, message);
}
Status RecordOperationWithNullGradientFunctionModel(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
Tape tape(/*persistent=*/false);
tape.Watch(inputs[0]);
std::vector<AbstractTensorHandle*> neg_outputs(1);
TF_RETURN_IF_ERROR(ops::Neg(ctx, inputs, absl::MakeSpan(neg_outputs), "Neg"));
tape.RecordOperation(inputs, neg_outputs, nullptr, "Neg");
return tape.ComputeGradient(ctx, /*targets=*/neg_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs);
}
TEST_P(CppGradients, TestRecordOperationWithNullGradientFunctionRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
std::vector<AbstractTensorHandle*> outputs(1);
Status s = RunModel(RecordOperationWithNullGradientFunctionModel, ctx.get(),
{x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_EQ(
"Provided null gradient_function for 'Neg'.\nIf the intent is to treat "
"this op as non-differentiable consider using RegisterNotDifferentiable "
"or NotDifferentiableGradientFunction.",
s.error_message());
ASSERT_EQ(nullptr, outputs[0]);
}
// TODO(b/164171226): Enable this test with tfrt after AddInputList is
// supported. It is needed for IdentityN.
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,320 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients_util.h"
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace gradients {
using namespace std;
Status ScalarTensorHandleHelper(TFE_Context* ctx, float value,
TFE_TensorHandle** result) {
float data[] = {value};
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status.get());
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
*result = th;
TF_DeleteTensor(t);
return StatusFromTF_Status(status.get());
}
Status TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[],
int64_t dims[], int num_dims,
TFE_TensorHandle** result) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status.get());
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
*result = th;
TF_DeleteTensor(t);
return StatusFromTF_Status(status.get());
}
Status TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
int64_t dims[], int num_dims,
TFE_TensorHandle** result) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status.get());
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
*result = th;
TF_DeleteTensor(t);
return StatusFromTF_Status(status.get());
}
// Get a scalar TensorHandle with given value
Status ScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager;
TF_RETURN_IF_ERROR(ScalarTensorHandleHelper(eager_ctx, value, &input_eager));
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get());
}
// Get a TensorHandle with given float values and dimensions
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager;
TF_RETURN_IF_ERROR(TensorHandleWithDimsFloatHelper(eager_ctx, data, dims,
num_dims, &input_eager));
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get());
}
// Get a TensorHandle with given int values and dimensions
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[],
int num_dims, AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager;
TF_RETURN_IF_ERROR(TensorHandleWithDimsIntHelper(eager_ctx, data, dims,
num_dims, &input_eager));
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get());
}
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
return StatusFromTF_Status(status.get());
}
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
float vals[], int64_t dims[],
int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
if (s.ok()) {
A.reset(a_raw);
}
return A;
}
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
int64_t dims[], int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
if (s.ok()) {
A.reset(a_raw);
}
return A;
}
AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
float val) {
AbstractTensorHandlePtr y;
AbstractTensorHandle* y_raw = nullptr;
Status s = ScalarTensorHandle(ctx, val, &y_raw);
if (s.ok()) {
y.reset(y_raw);
}
return y;
}
Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
vector<AbstractTensorHandle*>& weights,
AbstractTensorHandle* learning_rate) {
/* Update weights one by one using gradient update rule:
*
* w -= lr*grad[w]
*
* NOTE: assuming learning rate is positive
*/
int num_grads = grads.size();
vector<AbstractTensorHandle*> temp_outputs(1);
std::string update_str;
// Negate learning rate for gradient descent
TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate},
absl::MakeSpan(temp_outputs),
"neg_lr")); // Compute -lr
learning_rate = temp_outputs[0];
for (int i = 0; i < num_grads; i++) {
// Compute dW = -lr * grad(w[i])
update_str = "update_mul_" + std::to_string(i);
TF_RETURN_IF_ERROR(ops::Mul(ctx, {learning_rate, grads[i]},
absl::MakeSpan(temp_outputs),
update_str.c_str()));
AbstractTensorHandle* dW = temp_outputs[0];
// Compute temp = weights[i] + dW
update_str = "update_add_" + std::to_string(i);
TF_RETURN_IF_ERROR(ops::Add(ctx, {weights[i], dW},
absl::MakeSpan(temp_outputs),
update_str.c_str()));
// Update the weights
weights[i] = temp_outputs[0];
}
return Status::OK();
}
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
return unwrap(graph_ctx);
}
Status CreateParamsForInputs(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
vector<AbstractTensorHandle*>* params) {
tracing::TracingTensorHandle* handle = nullptr;
for (auto input : inputs) {
PartialTensorShape shape;
TF_RETURN_IF_ERROR(input->Shape(&shape));
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
input->DataType(), shape, &handle));
params->emplace_back(handle);
}
return Status::OK();
}
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
const GradientRegistry& registry) {
if (use_function) {
const char* fn_name = "test_fn";
std::unique_ptr<AbstractFunction> scoped_func;
// Returning null tensors from a tf.function is not supported, so we keep
// track of indices in the model's outputs are nullptr in this set.
// The FunctionDef only outputs the non-null tensors. We later pad the
// function op outputs to have nullptrs at the `null_indices`.
absl::flat_hash_set<int> null_indices;
{
AbstractContextPtr func_ctx(BuildFunction(fn_name));
vector<AbstractTensorHandle*> func_inputs;
func_inputs.reserve(inputs.size());
TF_RETURN_IF_ERROR(
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
vector<AbstractTensorHandle*> model_outputs;
model_outputs.resize(outputs.size());
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(model_outputs), registry));
for (auto func_input : func_inputs) {
func_input->Unref();
}
AbstractFunction* func = nullptr;
OutputList output_list;
output_list.expected_num_outputs = 0;
output_list.outputs.reserve(outputs.size());
for (int i = 0; i < model_outputs.size(); i++) {
if (model_outputs[i]) {
output_list.outputs.emplace_back(model_outputs[i]);
output_list.expected_num_outputs += 1;
} else {
null_indices.insert(i);
}
}
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func));
scoped_func.reset(func);
for (auto output : output_list.outputs) {
output->Unref();
}
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
}
AbstractOperationPtr fn_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
for (auto input : inputs) {
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
}
int retvals = outputs.size() - null_indices.size();
vector<AbstractTensorHandle*> fn_outputs(retvals);
TF_RETURN_IF_ERROR(fn_op->Execute(
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
&retvals));
int skipped_indices = 0;
for (int i = 0; i < outputs.size(); i++) {
if (!null_indices.contains(i)) {
outputs[i] = fn_outputs[i - skipped_indices];
} else {
skipped_indices += 1;
}
}
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
return Status::OK();
} else {
return model(ctx, inputs, outputs, registry);
}
}
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_DeleteContextOptions(opts);
return Status::OK();
}
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,88 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace gradients {
// Get a scalar TensorHandle with given value
Status ScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor);
// Get a TensorHandle with given float values and dimensions
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor);
// Get a TensorHandle with given int values and dimensions
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[],
int num_dims, AbstractTensorHandle** tensor);
// Places data from `t` into *result_tensor.
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor);
// Util function that wraps an AbstractTensorHandle* with given data and dims.
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
float vals[], int64_t dims[],
int num_dims);
// Util function that wraps an AbstractTensorHandle* with given data and dims.
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
int64_t dims[], int num_dims);
// Util function that wraps an AbstractTensorHandle* with given data.
AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
float val);
// Performs gradient update for each weight using given learning rate.
Status UpdateWeights(AbstractContext* ctx,
std::vector<AbstractTensorHandle*>& grads,
std::vector<AbstractTensorHandle*>& weights,
AbstractTensorHandle* learning_rate);
using Model = std::function<Status(
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
// Runs given model in either graph or eager mode depending on value of
// use_function.
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
const GradientRegistry& registry);
// Builds context and returns inside *ctx.
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
} // namespace gradients
} // namespace tensorflow

Some files were not shown because too many files have changed in this diff Show More