merge with master

This commit is contained in:
Daniel Nguyen 2020-08-13 23:35:48 +00:00
commit 19251497c8
1745 changed files with 69313 additions and 24074 deletions

View File

@ -18,8 +18,10 @@
#
# Compiler options:
# cuda_clang: Use clang when building CUDA code.
# c++17: Build with C++17 options
# c++1z: Build with C++17 options
# c++17: Build with C++17 options (links with libc++)
# c++1z: Build with C++17 options (links with libc++)
# c++17_gcc: Build with C++17 options (links with stdlibc++)
# c++1z_gcc: Build with C++17 options (links with stdlibc++)
# avx_linux: Build with avx instruction set on linux.
# avx2_linux: Build with avx2 instruction set on linux.
# native_arch_linux: Build with instruction sets available to the host machine on linux
@ -165,14 +167,29 @@ build:mkl -c opt
# config to build OneDNN backend with a user specified threadpool.
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_threadpool --define=build_with_mkl_dnn_v1_only=true
build:mkl_threadpool --define=build_with_mkl_opensource=true
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
build:mkl_threadpool -c opt
# Config setting to build with oneDNN and without the binary blob
build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_opensource_only --define=build_with_mkl_dnn_v1_only=true
build:mkl_opensource_only --define=build_with_mkl_opensource=true
build:mkl_opensource_only -c opt
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
build:using_cuda --action_env TF_NEED_CUDA=1
build:using_cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
# Enable the mlir generated GPU kernels only for cuda builds.
build --define=tensorflow_enable_mlir_generated_gpu_kernels=0
# This is a more specific option, so it takes precedence over the line above for cuda builds.
build:using_cuda --define=tensorflow_enable_mlir_generated_gpu_kernels=1
# This config refers to building CUDA op kernels with nvcc.
build:cuda --config=using_cuda
build:cuda --define=using_cuda_nvcc=true
@ -263,6 +280,8 @@ build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
build:c++17 --cxxopt=-std=c++1z
build:c++17 --cxxopt=-stdlib=libc++
build:c++1z --config=c++17
build:c++17_gcc --cxxopt=-std=c++1z
build:c++1z_gcc --config=c++17_gcc
# Enable using platform specific build settings, except when cross-compiling for
# mobile platforms.
@ -353,7 +372,6 @@ build --config=v2
test --config=v2
# Enable XLA
build:xla --action_env=TF_ENABLE_XLA=1
build:xla --define=with_xla_support=true
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS

View File

@ -123,20 +123,21 @@ Build Type | Status
### Community Supported Builds
Build Type | Status | Artifacts
----------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux AMD ROCm GPU** Nightly | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/)
**Linux AMD ROCm GPU** Stable Release | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/)
**Linux s390x** Nightly | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
**Linux s390x CPU** Stable Release | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/badge/icon)](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/)
**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/)
**Linux aarch64 CPU** Nightly <br> Python 3.6 | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow)](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master)
**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Stable Release | ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon) | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/)
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/)
Build Type | Status | Artifacts
----------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------
**Linux AMD ROCm GPU** Nightly | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/)
**Linux AMD ROCm GPU** Stable Release | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/)
**Linux s390x** Nightly | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/)
**Linux s390x CPU** Stable Release | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/badge/icon)](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/)
**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/)
**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/)
**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/)
**Linux aarch64 CPU** Nightly <br> Python 3.6 | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow)](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master)
**Linux aarch64 CPU** Stable Release | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) | Release [1.15](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) / [2.x](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)
**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Stable Release | ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon) | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/)
**Red Hat® Enterprise Linux® 7.6 CPU & GPU** <br> Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/)
## Resources

View File

@ -22,10 +22,21 @@
* Code that uses `tf.map_fn`/`tf.cond`/`tf.while_loop`/control flow as op layers and happens to work before TF 2.4. These will explicitly be unsupported now. Converting these ops to Functional API op layers was unreliable before TF 2.4, and prone to erroring incomprehensibly or being silently buggy.
* Code that directly asserts on a Keras symbolic value in cases where ops like `tf.rank` used to return a static or symbolic value depending on if the input had a fully static shape or not. Now these ops always return symbolic values.
* Code already susceptible to leaking tensors outside of graphs becomes slightly more likely to do so now.
* Code that tries directly getting gradients with respect to symbolic Keras inputs/outputs. Use GradientTape on the actual Tensors passed to the already-constructed model instead.
* Code that requires very tricky shape manipulation via converted op layers in order to work, where the Keras symbolic shape inference proves insufficient.
* Code that tries manually walking a `tf.keras.Model` layer by layer and assumes layers only ever have one positional argument. This assumption doesn't hold true before TF 2.4 either, but is more likely to cause issues know.
* Code that manually enters `keras.backend.get_graph()` before building a functional model. This is no longer needed.
* Start enforcing input shape assumptions when calling Functional API Keras
models. This may potentially break some users, in case there is a mismatch
between the shape used when creating `Input` objects in a Functional model,
and the shape of the data passed to that model. You can fix this mismatch by
either calling the model with correctly-shaped data, or by relaxing `Input`
shape assumptions (note that you can pass shapes with `None` entries for axes
that are meant to be dynamic). You can also disable the input checking
entirely by setting `model.input_spec = None`.
* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
`TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
removed).
## Known Caveats
@ -65,6 +76,11 @@
`tf.data.experimental.service.from_dataset_id` APIs to enable one process
to register a dataset with the tf.data service, and another process to
consume data from the dataset.
* Added support for tf.data service dispatcher fault tolerance. To enable
fault tolerance, configure a `work_dir` when running your dispatcher
server and set `dispatcher_fault_tolerance=True`. The dispatcher will
store its state to `work_dir`, so that on restart it can continue from its
previous state after restart.
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
the complement of `select_cols`; at most one of these should be specified.
* We have implemented an optimization which reorders data-discarding
@ -74,9 +90,11 @@
option.
* `tf.image`:
* Added deterministic `tf.image.stateless_random_*` functions for each
`tf.image.random_*` function. Given the same seed, the stateless functions
produce the same results independent of how many times the function is
called, and independent of global seed settings.
`tf.image.random_*` function. Added a new op
`stateless_sample_distorted_bounding_box` which is a determinstic
version of `sample_distorted_bounding_box` op. Given the same seed, these
stateless functions/ops produce the same results independent of how many
times the function is called, and independent of global seed settings.
* `tf.distribute`:
* <ADD RELEASE NOTES HERE>
* `tf.keras`:
@ -88,24 +106,53 @@
* Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
as an alternative to accepting a `callable` loss.
* Added `beta` hyperparameter to FTRL optimizer classes (Keras and others)
to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf).
* Added `mobilenet_v3` to keras application model.
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
customization of how gradients are aggregated across devices, as well as
`gradients_transformers` to allow for custom gradient transformations
(such as gradient clipping).
* `tf.function` / AutoGraph:
* Added `experimental_follow_type_hints` argument for `tf.function`. When
True, the function may use type annotations to optimize the tracing
performance.
* Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
* AutoGraph now allows creating new symbols inside a TensorFLow loop, if
the values of these symbols at an iteration does not depend on the previous
iteration. These types of loops must run at least one iteration, and will
raise a runtime error otherwise.
Example:
```
for batch in data:
outputs = train_step(batch)
tf.print('final outputs', outputs)
```
See tensorflow/python/autograph/g3doc/reference/limitations.md for more
info.
* `tf.lite`:
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
string to be joined is empty.
* `TFLiteConverter`:
* Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. This allows users to modify the model input and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting to float type (`tf.float32`).
* Deprecate `Interpreter::UseNNAPI(bool)` C++ API
* Prefer using `NnApiDelegate()` and related delegate configuration methods directly.
* Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair.
* <ADD RELEASE NOTES HERE>
* `tf.random`:
* <ADD RELEASE NOTES HERE>
* Math and Linear Algebra:
* <ADD RELEASE NOTES HERE>
* TPU Enhancements:
* Added support for the `beta` parameter of the FTRL optimizer for TPU
embeddings. Users of other TensorFlow platforms can implement equivalent
behavior by adjusting the `l2` parameter.
* <ADD RELEASE NOTES HERE>
* XLA Support:
* xla.experimental.compile is deprecated, use
`tf.function(experimental_compile=True)` instead
* <ADD RELEASE NOTES HERE>
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>

View File

@ -16,5 +16,5 @@
set configure_dir=%~dp0
set configure_dir=%configure_dir:~0,-1%
python %configure_dir%\configure.py %* || ( exit /b )
python "%configure_dir%\configure.py" %* || ( exit /b )
echo Configuration finished

View File

@ -260,6 +260,36 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "armeabi",
values = {"cpu": "armeabi"},
visibility = ["//visibility:public"],
)
config_setting(
name = "armeabi-v7a",
values = {"cpu": "armeabi-v7a"},
visibility = ["//visibility:public"],
)
config_setting(
name = "arm64-v8a",
values = {"cpu": "arm64-v8a"},
visibility = ["//visibility:public"],
)
selects.config_setting_group(
name = "arm_any",
match_any = [
":arm",
":armeabi",
":armeabi-v7a",
":arm64-v8a",
":linux_aarch64",
":linux_armhf",
],
)
config_setting(
name = "freebsd",
values = {"cpu": "freebsd"},

View File

@ -137,7 +137,7 @@ if _running_from_pip_package():
# TODO(gunan): Add sanity checks to loaded modules here.
for _s in _site_packages_dirs:
# Load first party dynamic kernels.
_main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels')
_main_dir = _os.path.join(_s, 'tensorflow/core/kernels')
if _fi.file_exists(_main_dir):
_ll.load_library(_main_dir)

View File

@ -147,7 +147,7 @@ if _running_from_pip_package():
# TODO(gunan): Add sanity checks to loaded modules here.
for _s in _site_packages_dirs:
# Load first party dynamic kernels.
_main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels')
_main_dir = _os.path.join(_s, 'tensorflow/core/kernels')
if _fi.file_exists(_main_dir):
_ll.load_library(_main_dir)

View File

@ -23,6 +23,7 @@ filegroup(
srcs = [
"c_api.h",
"c_api_experimental.h",
"c_api_macros.h",
"tensor_interface.h",
"tf_attrtype.h",
"tf_datatype.h",
@ -61,6 +62,7 @@ filegroup(
name = "pywrap_required_hdrs",
srcs = [
"c_api_internal.h",
"c_api_macros.h",
"conversion_macros.h",
"python_api.h",
"tensor_interface.h",
@ -79,6 +81,7 @@ tf_cuda_library(
hdrs = [
"c_api.h",
"c_api_internal.h",
"c_api_macros.h",
"tf_datatype.h",
"tf_tensor.h",
"tf_tstring.h",
@ -213,6 +216,17 @@ tf_cuda_library(
alwayslink = 1,
)
cc_library(
name = "logging",
srcs = ["logging.cc"],
hdrs = ["logging.h"],
deps = [
":c_api_macros",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:stringprintf",
],
)
tf_cuda_library(
name = "tf_status_internal",
hdrs = [
@ -299,6 +313,7 @@ cc_library(
hdrs = ["tf_tensor.h"],
visibility = ["//visibility:public"],
deps = [
":c_api_macros",
":tensor_interface",
":tf_datatype",
":tf_status",
@ -325,6 +340,7 @@ tf_cuda_library(
],
visibility = ["//tensorflow:internal"],
deps = [
":c_api_macros",
":tensor_interface",
":tf_datatype",
":tf_status",

View File

@ -213,7 +213,6 @@ void TF_Reset(const TF_SessionOptions* opt, const char** containers,
namespace tensorflow {
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
TF_Buffer* out) {
if (out->data != nullptr) {
@ -306,8 +305,8 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
}
// Helpers for loading a TensorFlow plugin (a .so file).
Status LoadLibrary(const char* library_filename, void** result,
const void** buf, size_t* len);
Status LoadDynamicLibrary(const char* library_filename, void** result,
const void** buf, size_t* len);
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
@ -552,7 +551,7 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle,
TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
TF_Library* lib_handle = new TF_Library;
status->status = tensorflow::LoadLibrary(
status->status = tensorflow::LoadDynamicLibrary(
library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
&lib_handle->op_list.length);
if (!status->status.ok()) {

View File

@ -30,4 +30,17 @@ limitations under the License.
#endif // _WIN32
#endif // SWIG
// TF_Bool is the C API typedef for unsigned char, while TF_BOOL is
// the datatype for boolean tensors.
#ifndef TF_Bool
#define TF_Bool unsigned char
#endif // TF_Bool
// Macro used to calculate struct size for maintaining ABI stability across
// different struct implementations.
#ifndef TF_OFFSET_OF_END
#define TF_OFFSET_OF_END(TYPE, MEMBER) \
(offsetof(TYPE, MEMBER) + sizeof(((TYPE *)0)->MEMBER))
#endif // TF_OFFSET_OF_END
#endif // TENSORFLOW_C_C_API_MACROS_H_

View File

@ -240,6 +240,7 @@ tf_cuda_cc_test(
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:array_grad",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/cc/profiler",
@ -249,6 +250,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
@ -508,6 +510,27 @@ tf_cuda_cc_test(
],
)
tf_cuda_library(
name = "c_api_remote_test_util",
testonly = 1,
srcs = ["c_api_remote_test_util.cc"],
hdrs = ["c_api_remote_test_util.h"],
visibility = ["//tensorflow:__subpackages__"],
deps = [
":c_api",
":c_api_internal",
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/strings",
],
)
tf_cuda_cc_test(
name = "c_api_remote_test",
size = "small",
@ -524,6 +547,7 @@ tf_cuda_cc_test(
":c_api",
":c_api_experimental",
":c_api_internal",
":c_api_remote_test_util",
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util",
@ -540,6 +564,25 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "c_api_remote_function_test",
size = "small",
srcs = [
"c_api_remote_function_test.cc",
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = [
"no_windows",
],
deps = [
":c_api_remote_test_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cuda_cc_test(
name = "c_api_distributed_test",
size = "small",

View File

@ -724,7 +724,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (opts->use_tfrt) {
#ifdef PLATFORM_GOOGLE
return tensorflow::wrap(new tfrt::ContextInterface(opts->async));
return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr;

View File

@ -518,7 +518,8 @@ void TestDistributedFunctionCancellation(bool inject_error) {
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle, nullptr);
const string function_def = VariableAddFunctionWithGraphError();
const string function_def = inject_error ? VariableAddFunctionWithGraphError()
: VariableAddFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

View File

@ -0,0 +1,56 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_remote_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace {
void TestRemoteExecuteSilentCopiesFunc(bool async, bool remote,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs = false) {
return TestRemoteExecuteSilentCopies(async, remote, /*func=*/true,
heavy_load_on_streaming_rpc,
remote_func_outputs);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
// TODO(b/162618595): Enable this test once we remove the check of remote
// outputs in ProcessFunctionLibraryRuntime.
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
// A remote input may be not ready when we start running a function. Test that
// the function execution should wait until the remote input is ready.
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/true);
}
} // namespace

View File

@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_remote_test_util.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
@ -115,225 +117,24 @@ void TestRemoteExecute(bool async) {
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
string MatMulFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'MatMulFunction'"
" input_arg {"
" name: 'a'"
" type: DT_FLOAT"
" }"
" input_arg {"
" name: 'b'"
" type: DT_FLOAT"
" }"
" output_arg {"
" name: 'm'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'matmul'"
" op: 'MatMul'"
" input: 'a'"
" input: 'b'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'm'"
" value: 'matmul:product'"
" }",
&def));
return def.SerializeAsString();
}
// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one
// which creates a remote remote input, to simulate a scenario that the remote
// input is not ready when we start running an op or a function.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
std::vector<TFE_TensorHandle*> handles_task0;
if (heavy_load_on_streaming_rpc) {
// Send 50 tensor copy requests to simulate that there have been some RPC
// requests been enqueued.
for (int i = 0; i < 50; ++i) {
handles_task0.push_back(TestMatrixTensorHandle(ctx));
}
}
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
std::vector<TFE_TensorHandle*> handles_task2;
for (auto* h_task0 : handles_task0) {
handles_task2.push_back(
TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
auto* h1_task2 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* matmul = nullptr;
if (func) {
string function_def = MatMulFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h0_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h1_task2, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else {
// Handles are on task0 (local), and task2, but op is on task1.
matmul = MatMulOp(ctx, h0_task0, h1_task2);
}
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else if (!async) {
// Set the local device to CPU to easily validate mirroring
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
}
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!remote && !async) {
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retval_task0);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h1_task2);
TFE_DeleteTensorHandle(retvals[0]);
for (auto* h : handles_task0) {
TFE_DeleteTensorHandle(h);
}
for (auto* h : handles_task2) {
TFE_DeleteTensorHandle(h);
}
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
if (func) {
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
}
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
void TestRemoteExecuteSilentCopiesOp(bool async, bool remote,
bool remote_func_outputs = false) {
return TestRemoteExecuteSilentCopies(async, remote, /*func=*/false,
/*heavy_load_on_streaming_rpc=*/false,
remote_func_outputs);
}
TEST(CAPI, RemoteExecuteSilentCopies) {
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/true,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/false);
TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
// A remote input may be not ready when we start running a function. Test that
// the function execution should wait until the remote input is ready.
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/true);
TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/false);
}
} // namespace

View File

@ -0,0 +1,215 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_remote_test_util.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
using ::tensorflow::string;
string MatMulFunction(const string& matmul_device) {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
absl::StrCat(" signature {"
" name: 'MatMulFunction'"
" input_arg {"
" name: 'a'"
" type: DT_FLOAT"
" }"
" input_arg {"
" name: 'b'"
" type: DT_FLOAT"
" }"
" output_arg {"
" name: 'm'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'matmul'"
" op: 'MatMul'"
" input: 'a'"
" input: 'b'"
" device: '",
matmul_device, "'",
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'm'"
" value: 'matmul:product'"
" }"),
&def));
return def.SerializeAsString();
}
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
std::vector<TFE_TensorHandle*> handles_task0;
if (heavy_load_on_streaming_rpc) {
// Send 50 tensor copy requests to simulate that there have been some RPC
// requests been enqueued.
for (int i = 0; i < 50; ++i) {
handles_task0.push_back(TestMatrixTensorHandle(ctx));
}
}
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
std::vector<TFE_TensorHandle*> handles_task2;
for (auto* h_task0 : handles_task0) {
handles_task2.push_back(
TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
auto* h1_task2 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* matmul = nullptr;
if (func) {
const string matmul_device = remote_func_outputs ? task2_name : "";
string function_def = MatMulFunction(matmul_device);
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h0_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h1_task2, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else {
// Handles are on task0 (local), and task2, but op is on task1.
matmul = MatMulOp(ctx, h0_task0, h1_task2);
}
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else if (!async) {
// Set the local device to CPU to easily validate mirroring
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
}
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!remote && !async && !remote_func_outputs) {
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retval_task0);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h1_task2);
TFE_DeleteTensorHandle(retvals[0]);
for (auto* h : handles_task0) {
TFE_DeleteTensorHandle(h);
}
for (auto* h : handles_task2) {
TFE_DeleteTensorHandle(h);
}
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
if (func) {
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
}
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
#define TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
// Run a function containing a MatMul op and check its output.
// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one
// which creates a remote remote input, to simulate a scenario that the remote
// input is not ready when we start running an op or a function.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs = false);
#endif // TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_

View File

@ -30,6 +30,9 @@ using tensorflow::string;
namespace tensorflow {
namespace {
// The tests are parameterized on:
// - a string representing the tracing implementation: "mlir" or "graphdef".
// - a boolean that when true enables TFRT as the execution engine.
class UnifiedCAPI
: public ::testing::TestWithParam<std::tuple<const char*, bool>> {
protected:
@ -983,6 +986,10 @@ TEST_P(UnifiedCAPI, TF_ExecutionContextGetTFEContextFromFunctionContextRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
// The above tests are run for a combination of:
// - graphdef and MLIR tracing engine
// - Using TFRT as an execution runtime (true == enable TFRT)
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Combine(::testing::Values("graphdef",

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/eager/gradients.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
@ -23,25 +24,97 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
Status GradientRegistry::Register(const string& op_name,
GradientFunctionFactory factory) {
namespace {
Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t,
AbstractTensorHandle** result) {
AbstractOperationPtr op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(op->Reset("ZerosLike", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("ZerosLike", ToId(t)).c_str()));
}
TF_RETURN_IF_ERROR(op->AddInput(t));
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(num_outputs);
TF_RETURN_IF_ERROR(
op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs));
*result = outputs[0];
return Status::OK();
}
} // namespace
class IncomingGradientsImpl : public IncomingGradients {
public:
explicit IncomingGradientsImpl(
absl::Span<AbstractTensorHandle* const> grad_inputs, Context* ctx,
DefaultGradientFunction* default_gradients)
: grad_inputs_(grad_inputs),
ctx_(ctx),
default_gradients_(default_gradients) {}
AbstractTensorHandle* operator[](int i) const override {
return default_gradients_->get(ctx_, grad_inputs_, i);
}
size_t size() const override { return grad_inputs_.size(); }
private:
absl::Span<AbstractTensorHandle* const> grad_inputs_;
Context* ctx_;
DefaultGradientFunction* default_gradients_;
};
AllZerosDefaultGradients::AllZerosDefaultGradients(const ForwardOperation& op)
: outputs_(op.outputs) {
for (auto output : outputs_) {
output->Ref();
}
}
AbstractTensorHandle* AllZerosDefaultGradients::get(
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs, int i) {
if (grad_inputs[i]) {
return grad_inputs[i];
}
if (cached_default_grads_[i]) {
return cached_default_grads_[i].get();
}
AbstractTensorHandle* result = nullptr;
Status s = ZerosLike(ctx->ctx, outputs_[i], &result);
if (!s.ok()) {
if (result) {
result->Unref();
}
VLOG(1) << "Failed to create ZerosLike for index " << i;
return nullptr;
}
cached_default_grads_[i].reset(result);
return result;
}
PassThroughDefaultGradients::PassThroughDefaultGradients(
const ForwardOperation& op) {}
AbstractTensorHandle* PassThroughDefaultGradients::get(
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs, int i) {
return grad_inputs[i];
}
Status GradientRegistry::Register(
const string& op_name, BackwardFunctionFactory backward_function_factory) {
auto iter = registry_.find(op_name);
if (iter != registry_.end()) {
const string error_msg = "Gradient already exists for op: " + op_name + ".";
return errors::AlreadyExists(error_msg);
}
registry_.insert({op_name, factory});
registry_.insert({op_name, backward_function_factory});
return Status::OK();
}
Status GradientRegistry::Lookup(
const ForwardOperation& op,
std::unique_ptr<GradientFunction>* grad_fn) const {
std::unique_ptr<BackwardFunction>* backward_function) const {
auto iter = registry_.find(op.op_name);
if (iter == registry_.end()) {
const string error_msg = "No gradient defined for op: " + op.op_name + ".";
return errors::NotFound(error_msg);
}
grad_fn->reset(iter->second(op));
backward_function->reset(iter->second(op));
return Status::OK();
}
@ -92,33 +165,8 @@ AbstractTensorHandle* TapeTensor::OnesLike() const {
}
return outputs[0];
}
AbstractTensorHandle* TapeTensor::ZerosLike() const {
AbstractOperationPtr op(ctx_->CreateOperation());
// TODO(srbs): Consider adding a TF_RETURN_NULLPTR_IF_ERROR.
Status s = op->Reset("ZerosLike", /*raw_device_name=*/nullptr);
if (!s.ok()) {
return nullptr;
}
if (isa<tracing::TracingOperation>(op.get())) {
s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("OnesLike", ToId(handle_)).c_str());
if (!s.ok()) {
return nullptr;
}
}
s = op->AddInput(handle_);
if (!s.ok()) {
return nullptr;
}
int num_outputs = 1;
// TODO(srbs): Figure out who is in charge of releasing this.
std::vector<AbstractTensorHandle*> outputs(num_outputs);
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
if (!s.ok()) {
return nullptr;
}
return outputs[0];
}
AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; }
// Returns the number of elements in the gradient tensor.
int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
@ -159,13 +207,16 @@ AbstractTensorHandle* TapeVSpace::AggregateGradients(
// Calls the passed-in backward function.
Status TapeVSpace::CallBackwardFunction(
GradientFunction* backward_function,
BackwardFunction* backward_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const {
if (backward_function == nullptr) return Status::OK();
Context ctx = {ctx_};
return backward_function->Compute(&ctx, output_gradients, result);
IncomingGradientsImpl incoming_gradients(
output_gradients, &ctx, backward_function->GetDefaultGradientFunction());
return backward_function->GetGradientFunction()->Compute(
&ctx, incoming_gradients, result);
}
// Looks up the ID of a Gradient.
@ -363,21 +414,25 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
input_ids[i] = ToId(forward_op_->inputs[i]);
input_dtypes[i] = forward_op_->inputs[i]->DataType();
}
for (int i = 0; i < *num_retvals; i++) {
// TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
forward_op_->outputs.push_back(retvals[i]);
}
std::vector<TapeTensor> tape_tensors;
for (auto t : retvals) {
tape_tensors.push_back(TapeTensor(t, ctx));
}
tape->RecordOperation(
op_->Name(), tape_tensors, input_ids, input_dtypes,
[registry, forward_op_]() -> GradientFunction* {
std::unique_ptr<GradientFunction> grad_fn;
Status s = registry.Lookup(*forward_op_, &grad_fn);
[registry, forward_op_]() -> BackwardFunction* {
std::unique_ptr<BackwardFunction> backward_fn;
Status s = registry.Lookup(*forward_op_, &backward_fn);
if (!s.ok()) {
return nullptr;
}
return grad_fn.release();
return backward_fn.release();
},
[](GradientFunction* ptr) {
[](BackwardFunction* ptr) {
if (ptr) {
delete ptr;
}

View File

@ -55,18 +55,25 @@ struct Context {
public:
AbstractContext* ctx;
};
class IncomingGradients {
public:
virtual AbstractTensorHandle* operator[](int i) const = 0;
virtual size_t size() const = 0;
virtual ~IncomingGradients() {}
};
class GradientFunction {
public:
// TODO(srbs): How we support CompositeTensors e.g. IndexedSlices in
// `grad_inputs`.
virtual Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
virtual Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
virtual ~GradientFunction() {}
};
// Metadata from the forward operation that is made available to the
// gradient registerer to instantiate a GradientFunction.
// gradient registerer to instantiate a BackwardFunction.
struct ForwardOperation {
public:
string op_name;
@ -76,18 +83,86 @@ struct ForwardOperation {
AbstractContext* ctx;
};
using GradientFunctionFactory =
std::function<GradientFunction*(const ForwardOperation& op)>;
// Map from op name to a `GradientFunctionFactory`.
class GradientRegistry {
// Interface for building default zeros gradients for op outputs which are
// missing incoming gradients. Custom implementations of this can be used to
// control which of the forward op's output tensors/their metadata needs to
// be kept around in memory to build the default zeros grad.
//
// Some common helper implementations are provided below.
class DefaultGradientFunction {
public:
Status Register(const string& op, GradientFunctionFactory factory);
Status Lookup(const ForwardOperation& op,
std::unique_ptr<GradientFunction>* grad_fn) const;
virtual AbstractTensorHandle* get(
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs,
int i) = 0;
virtual ~DefaultGradientFunction() {}
};
// Returns zeros for any `nullptr` in `grad_inputs`.
//
// This may require keeping track of all of forward op's output
// tensors and hence may incur a higher memory footprint. Use sparingly.
//
// Multiple calls to `AllZerosDefaultGradients::get` return the same tensor
// handle.
//
// The destructor of this class `Unref`'s any cached tensor handles so users of
// those tensor handles should `Ref` them in order to keep them alive if needed.
class AllZerosDefaultGradients : public DefaultGradientFunction {
public:
explicit AllZerosDefaultGradients(const ForwardOperation& op);
AbstractTensorHandle* get(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
int i) override;
private:
absl::flat_hash_map<string, GradientFunctionFactory> registry_;
// TODO(srbs): We do not always need to keep the tensors around. In immediate
// execution mode we just need to store the shape and dtype. During tracing
// we may need to keep the tensor around if the shape is not full defined.
std::vector<AbstractTensorHandle*> outputs_;
std::vector<AbstractTensorHandlePtr> cached_default_grads_;
};
// Passes through `grad_inputs` as-is. The `GradientFunction`
// will be expected to deal with nullptr in `grad_inputs` if any.
class PassThroughDefaultGradients : public DefaultGradientFunction {
public:
explicit PassThroughDefaultGradients(const ForwardOperation& op);
AbstractTensorHandle* get(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
int i) override;
};
// A `BackwardFunction` wraps a `GradientFunction` and a
// `DefaultGradientFunction`. Both are owned by this class' instance.
class BackwardFunction {
public:
BackwardFunction(GradientFunction* gradient_function,
DefaultGradientFunction* default_gradients)
: gradient_function_(gradient_function),
default_gradients_(default_gradients) {}
GradientFunction* GetGradientFunction() { return gradient_function_.get(); }
DefaultGradientFunction* GetDefaultGradientFunction() {
return default_gradients_.get();
}
private:
std::unique_ptr<GradientFunction> gradient_function_;
std::unique_ptr<DefaultGradientFunction> default_gradients_;
};
using BackwardFunctionFactory =
std::function<BackwardFunction*(const ForwardOperation& op)>;
// Map from op name to a `BackwardFunctionFactory`.
class GradientRegistry {
public:
Status Register(const string& op,
BackwardFunctionFactory backward_function_factory);
Status Lookup(const ForwardOperation& op,
std::unique_ptr<BackwardFunction>* backward_function) const;
private:
absl::flat_hash_map<string, BackwardFunctionFactory> registry_;
};
// Returns a unique id for the tensor which is used by the tape to build
@ -106,9 +181,16 @@ int64 ToId(AbstractTensorHandle* t);
// allow us to trace the data dependencies between operations and hence compute
// gradients.
//
// This also implements `ZerosLike` and `OnesLike` to create the default
// This also implements `OnesLike` to create the default
// incoming gradients for tensors which do not already have an incoming
// gradient.
//
// `ZerosLike` is not expected to be called and returns a nullptr. The creation
// of default zeros grads is handled by the `DefaultGradientFunction` registered
// for each op.
// TODO(srbs): We need to define `ZerosLike` here to keep the compiler happy.
// Figure out a way to avoid this.
// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr?
class TapeTensor {
public:
TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx);
@ -123,7 +205,7 @@ class TapeTensor {
private:
AbstractTensorHandle* handle_;
// The context where OnesLike and ZerosLike ops are to be created.
// The context where OnesLike ops are to be created.
AbstractContext* ctx_;
};
@ -132,7 +214,7 @@ class TapeTensor {
// gradient and for performing gradient aggregation.
// See `tensorflow::eager::VSpace` for more details.
class TapeVSpace
: public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
: public eager::VSpace<AbstractTensorHandle, BackwardFunction, TapeTensor> {
public:
explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
~TapeVSpace() override {}
@ -147,7 +229,7 @@ class TapeVSpace
// Calls the passed-in backward function.
Status CallBackwardFunction(
GradientFunction* backward_function,
BackwardFunction* backward_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const override;
@ -168,8 +250,14 @@ class TapeVSpace
};
// A tracing/immediate-execution agnostic tape.
//
// Gradient functions defined for this library support handling null incoming
// gradients. `Tape::ComputeGradient` should be called with
// `build_default_zeros_grads=false`. Calling with
// `build_default_zeros_grads=true` (the default) is equivalent but just results
// in extra work because `TapeTensor::ZerosLike` returns a `nullptr` anyway.
using Tape = tensorflow::eager::GradientTape<AbstractTensorHandle,
GradientFunction, TapeTensor>;
BackwardFunction, TapeTensor>;
} // namespace gradients
} // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
@ -23,6 +24,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/array_grad.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
@ -35,6 +37,8 @@ namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
using std::vector;
using tracing::TracingOperation;
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
@ -45,7 +49,10 @@ class CppGradients
};
Status RegisterGradients(GradientRegistry* registry) {
return registry->Register("Add", AddRegisterer);
TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
return Status::OK();
}
// Computes `inputs[0] + inputs[1]` and records it on the tape.
@ -58,9 +65,9 @@ Status Add(AbstractContext* ctx, Tape* tape,
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
if (isa<tracing::TracingOperation>(add_op.get())) {
if (isa<TracingOperation>(add_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName("my_add"));
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
}
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
@ -69,6 +76,46 @@ Status Add(AbstractContext* ctx, Tape* tape,
registry);
}
// Computes `exp(inputs[0])` and records it on the tape.
Status Exp(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr exp_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(exp_op.get(), "Exp", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(exp_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(exp_op.get())->SetOpName("my_exp"));
}
TF_RETURN_IF_ERROR(AddInput(exp_op.get(), inputs[0], &forward_op));
int num_retvals = 1;
return Execute(exp_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `IdentityN(inputs)` and records it on the tape.
Status IdentityN(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr identity_n_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(Reset(identity_n_op.get(), "IdentityN",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(identity_n_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<TracingOperation>(identity_n_op.get())
->SetOpName("my_identity_n"));
}
TF_RETURN_IF_ERROR(AddInputList(identity_n_op.get(), inputs, &forward_op));
int num_retvals = outputs.size();
return Execute(identity_n_op.get(), ctx, outputs, &num_retvals, &forward_op,
tape, registry);
}
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
@ -91,7 +138,8 @@ Status AddGradModel(AbstractContext* ctx,
vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads));
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto add_output : add_outputs) {
add_output->Unref();
}
@ -101,6 +149,71 @@ Status AddGradModel(AbstractContext* ctx,
return Status::OK();
}
// Computes
// y = exp(inputs[0])
// return grad(y, {inputs[0]})
Status ExpGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
std::vector<AbstractTensorHandle*> exp_outputs(1);
TF_RETURN_IF_ERROR(Exp(ctx, tape, inputs, absl::MakeSpan(exp_outputs),
registry)); // Compute x+y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(exp_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto exp_output : exp_outputs) {
exp_output->Unref();
}
outputs[0] = out_grads[0];
delete tape;
return Status::OK();
}
// Computes
// ignored, y = IdentityN(inputs[0], inputs[1])
// return grad(y, {inputs[0], inputs[1]})
// This should return [nullptr, 1].
Status IdentityNGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0]));
tape->Watch(ToId(inputs[1]));
vector<AbstractTensorHandle*> identity_n_outputs(2);
TF_RETURN_IF_ERROR(IdentityN(ctx, tape, inputs,
absl::MakeSpan(identity_n_outputs), registry));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(identity_n_outputs[1])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto identity_n_output : identity_n_outputs) {
identity_n_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -132,26 +245,42 @@ Status RunModel(Model model, AbstractContext* ctx,
if (use_function) {
const char* fn_name = "test_fn";
std::unique_ptr<AbstractFunction> scoped_func;
// Returning null tensors from a tf.function is not supported, so we keep
// track of indices in the model's outputs are nullptr in this set.
// The FunctionDef only outputs the non-null tensors. We later pad the
// function op outputs to have nullptrs at the `null_indices`.
absl::flat_hash_set<int> null_indices;
{
AbstractContextPtr func_ctx(BuildFunction(fn_name));
std::vector<AbstractTensorHandle*> func_inputs;
func_inputs.reserve(inputs.size());
TF_RETURN_IF_ERROR(
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
OutputList output_list;
output_list.expected_num_outputs = outputs.size();
output_list.outputs.resize(outputs.size());
vector<AbstractTensorHandle*> model_outputs;
model_outputs.resize(outputs.size());
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(output_list.outputs), registry));
absl::MakeSpan(model_outputs), registry));
for (auto func_input : func_inputs) {
func_input->Unref();
}
AbstractFunction* func = nullptr;
OutputList output_list;
output_list.expected_num_outputs = 0;
output_list.outputs.reserve(outputs.size());
for (int i = 0; i < model_outputs.size(); i++) {
if (model_outputs[i]) {
output_list.outputs.emplace_back(model_outputs[i]);
output_list.expected_num_outputs += 1;
} else {
null_indices.insert(i);
}
}
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func));
scoped_func.reset(func);
output_list.outputs[0]->Unref();
output_list.outputs[1]->Unref();
for (auto output : output_list.outputs) {
output->Unref();
}
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
}
@ -160,8 +289,19 @@ Status RunModel(Model model, AbstractContext* ctx,
for (auto input : inputs) {
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
}
int retvals = outputs.size();
TF_RETURN_IF_ERROR(fn_op->Execute(outputs, &retvals));
int retvals = outputs.size() - null_indices.size();
vector<AbstractTensorHandle*> fn_outputs(retvals);
TF_RETURN_IF_ERROR(fn_op->Execute(
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
&retvals));
int skipped_indices = 0;
for (int i = 0; i < outputs.size(); i++) {
if (!null_indices.contains(i)) {
outputs[i] = fn_outputs[i - skipped_indices];
} else {
skipped_indices += 1;
}
}
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
return Status::OK();
} else {
@ -264,13 +404,118 @@ TEST_P(CppGradients, TestAddGrad) {
TF_DeleteTensor(result_tensor);
}
TEST_P(CppGradients, TestExpGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// y = exp(x)
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
s = RunModel(ExpGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_NEAR(*result_value, 2.718, 0.001);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
TEST_P(CppGradients, TestIdentityNGrad) {
// Pseudo-code:
//
// tape.watch(x1)
// tape.watch(x2)
// unused, y = IdentityN([x1, x2])
// outputs = tape.gradient(y, [x1, x2])
// Expected: [nullptr, 1]
//
// This test is interesting because the current implementation of GradientTape
// would return [0, 1] whereas we use build_default_zeros_grads=false here
// so we get back [nullptr, 1].
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x1;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x1.reset(x_raw);
}
AbstractTensorHandlePtr x2;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x2.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(IdentityNGradModel, ctx.get(), {x1.get(), x2.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
EXPECT_EQ(outputs[0], nullptr);
TF_Tensor* result_tensor;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
// TODO(b/160888630): Enable this test with mlir after AddInputList is
// supported. It is needed for AddN op which is used for gradient aggregation.
// supported. It is needed for IdentityN.
// TODO(b/164171226): Enable this test with tfrt after AddInputList is
// supported. It is needed for IdentityN.
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(true, false),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(

View File

@ -146,13 +146,16 @@ class GradientTape {
// once) and produces the gradient of the target tensors with respect to the
// source tensors. The output gradients are used if not empty and not
// null. The result is populated with one tensor per target element.
// When running backward functions, builds zeros-like tensors for
// incoming grads which are nullptrs, unless `build_default_zeros_grads`
// is set to false.
Status ComputeGradient(
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
const gtl::ArraySlice<int64> target_tensor_ids,
const gtl::ArraySlice<int64> source_tensor_ids,
const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result);
std::vector<Gradient*>* result, bool build_default_zeros_grads = true);
bool IsPersistent() const { return persistent_; }
@ -655,8 +658,8 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
const gtl::ArraySlice<int64> target_tensor_ids,
const gtl::ArraySlice<int64> source_tensor_ids,
const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) {
gtl::ArraySlice<Gradient*> output_gradients, std::vector<Gradient*>* result,
bool build_default_zeros_grads) {
std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
source_tensor_ids.end());
BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
@ -717,14 +720,14 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
const int64 id = trace.output_tensor_info[i].GetID();
auto grad_it = gradients.find(id);
if (grad_it == gradients.end()) {
auto func_name_it =
FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() &&
func_name_it->second.find(i) != func_name_it->second.end()) {
out_gradients.push_back(nullptr);
} else {
out_gradients.push_back(nullptr);
zero_indices.push_back(i);
out_gradients.push_back(nullptr);
if (build_default_zeros_grads) {
auto func_name_it =
FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type);
if (func_name_it == FunctionsAcceptingNoneForIndicesMap()->end() ||
func_name_it->second.find(i) == func_name_it->second.end()) {
zero_indices.push_back(i);
}
}
} else {
any_gradient_nonzero = true;
@ -745,6 +748,7 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
}
}
std::vector<Gradient*> in_gradients;
DCHECK(build_default_zeros_grads || zero_indices.empty());
if (any_gradient_nonzero) {
for (const auto i : zero_indices) {
out_gradients[i] = trace.output_tensor_info[i].ZerosLike();

View File

@ -191,8 +191,8 @@ void* TF_LoadSharedLibrary(const char* library_filename, TF_Status* status) {
void* handle = nullptr;
TF_SetStatus(status, TF_OK, "");
::tensorflow::Set_TF_Status_from_Status(
status,
::tensorflow::Env::Default()->LoadLibrary(library_filename, &handle));
status, ::tensorflow::Env::Default()->LoadDynamicLibrary(library_filename,
&handle));
return handle;
}

View File

@ -35,8 +35,8 @@ using UniquePtrTo_TF_Status =
::std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
Status ModularFileSystem::NewRandomAccessFile(
const std::string& fname,
std::unique_ptr<RandomAccessFile>* result /*, TransactionToken* token */) {
const std::string& fname, TransactionToken* token,
std::unique_ptr<RandomAccessFile>* result) {
if (ops_->new_random_access_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support NewRandomAccessFile()"));
@ -55,8 +55,8 @@ Status ModularFileSystem::NewRandomAccessFile(
}
Status ModularFileSystem::NewWritableFile(
const std::string& fname,
std::unique_ptr<WritableFile>* result /*, TransactionToken* token */) {
const std::string& fname, TransactionToken* token,
std::unique_ptr<WritableFile>* result) {
if (ops_->new_writable_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support NewWritableFile()"));
@ -75,8 +75,8 @@ Status ModularFileSystem::NewWritableFile(
}
Status ModularFileSystem::NewAppendableFile(
const std::string& fname,
std::unique_ptr<WritableFile>* result /*, TransactionToken* token */) {
const std::string& fname, TransactionToken* token,
std::unique_ptr<WritableFile>* result) {
if (ops_->new_appendable_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support NewAppendableFile()"));
@ -95,8 +95,8 @@ Status ModularFileSystem::NewAppendableFile(
}
Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile(
const std::string& fname, std::unique_ptr<ReadOnlyMemoryRegion>*
result /*, TransactionToken* token */) {
const std::string& fname, TransactionToken* token,
std::unique_ptr<ReadOnlyMemoryRegion>* result) {
if (ops_->new_read_only_memory_region_from_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname,
@ -116,8 +116,8 @@ Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::FileExists(
const std::string& fname /*, TransactionToken* token */) {
Status ModularFileSystem::FileExists(const std::string& fname,
TransactionToken* token) {
if (ops_->path_exists == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support FileExists()"));
@ -129,9 +129,9 @@ Status ModularFileSystem::FileExists(
return StatusFromTF_Status(plugin_status.get());
}
bool ModularFileSystem::FilesExist(
const std::vector<std::string>& files,
std::vector<Status>* status /*, TransactionToken* token */) {
bool ModularFileSystem::FilesExist(const std::vector<std::string>& files,
TransactionToken* token,
std::vector<Status>* status) {
if (ops_->paths_exist == nullptr)
return FileSystem::FilesExist(files, status);
@ -162,9 +162,9 @@ bool ModularFileSystem::FilesExist(
return result;
}
Status ModularFileSystem::GetChildren(
const std::string& dir,
std::vector<std::string>* result /*, TransactionToken* token */) {
Status ModularFileSystem::GetChildren(const std::string& dir,
TransactionToken* token,
std::vector<std::string>* result) {
if (ops_->get_children == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", dir, " does not support GetChildren()"));
@ -188,9 +188,9 @@ Status ModularFileSystem::GetChildren(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::GetMatchingPaths(
const std::string& pattern,
std::vector<std::string>* result /*, TransactionToken* token */) {
Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
TransactionToken* token,
std::vector<std::string>* result) {
if (ops_->get_matching_paths == nullptr)
return internal::GetMatchingPaths(this, Env::Default(), pattern, result);
@ -211,8 +211,8 @@ Status ModularFileSystem::GetMatchingPaths(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::DeleteFile(
const std::string& fname /*, TransactionToken* token */) {
Status ModularFileSystem::DeleteFile(const std::string& fname,
TransactionToken* token) {
if (ops_->delete_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support DeleteFile()"));
@ -224,9 +224,10 @@ Status ModularFileSystem::DeleteFile(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::DeleteRecursively(
const std::string& dirname, int64* undeleted_files,
int64* undeleted_dirs /*, TransactionToken* token */) {
Status ModularFileSystem::DeleteRecursively(const std::string& dirname,
TransactionToken* token,
int64* undeleted_files,
int64* undeleted_dirs) {
if (undeleted_files == nullptr || undeleted_dirs == nullptr)
return errors::FailedPrecondition(
"DeleteRecursively must not be called with `undeleted_files` or "
@ -247,8 +248,8 @@ Status ModularFileSystem::DeleteRecursively(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::DeleteDir(
const std::string& dirname /*, TransactionToken* token */) {
Status ModularFileSystem::DeleteDir(const std::string& dirname,
TransactionToken* token) {
if (ops_->delete_dir == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", dirname, " does not support DeleteDir()"));
@ -260,8 +261,8 @@ Status ModularFileSystem::DeleteDir(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::RecursivelyCreateDir(
const std::string& dirname /*, TransactionToken* token */) {
Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname,
TransactionToken* token) {
if (ops_->recursively_create_dir == nullptr)
return FileSystem::RecursivelyCreateDir(dirname);
@ -272,8 +273,8 @@ Status ModularFileSystem::RecursivelyCreateDir(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::CreateDir(
const std::string& dirname /*, TransactionToken* token */) {
Status ModularFileSystem::CreateDir(const std::string& dirname,
TransactionToken* token) {
if (ops_->create_dir == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", dirname, " does not support CreateDir()"));
@ -285,9 +286,8 @@ Status ModularFileSystem::CreateDir(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::Stat(
const std::string& fname,
FileStatistics* stat /*, TransactionToken* token */) {
Status ModularFileSystem::Stat(const std::string& fname,
TransactionToken* token, FileStatistics* stat) {
if (ops_->stat == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support Stat()"));
@ -310,8 +310,8 @@ Status ModularFileSystem::Stat(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::IsDirectory(
const std::string& name /*, TransactionToken* token */) {
Status ModularFileSystem::IsDirectory(const std::string& name,
TransactionToken* token) {
if (ops_->is_directory == nullptr) return FileSystem::IsDirectory(name);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
@ -321,9 +321,9 @@ Status ModularFileSystem::IsDirectory(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::GetFileSize(
const std::string& fname,
uint64* file_size /*, TransactionToken* token */) {
Status ModularFileSystem::GetFileSize(const std::string& fname,
TransactionToken* token,
uint64* file_size) {
if (ops_->get_file_size == nullptr) {
FileStatistics stat;
Status status = Stat(fname, &stat);
@ -342,9 +342,9 @@ Status ModularFileSystem::GetFileSize(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::RenameFile(
const std::string& src,
const std::string& target /*, TransactionToken* token */) {
Status ModularFileSystem::RenameFile(const std::string& src,
const std::string& target,
TransactionToken* token) {
if (ops_->rename_file == nullptr) {
Status status = CopyFile(src, target);
if (status.ok()) status = DeleteFile(src);
@ -359,9 +359,9 @@ Status ModularFileSystem::RenameFile(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::CopyFile(
const std::string& src,
const std::string& target /*, TransactionToken* token */) {
Status ModularFileSystem::CopyFile(const std::string& src,
const std::string& target,
TransactionToken* token) {
if (ops_->copy_file == nullptr) return FileSystem::CopyFile(src, target);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
@ -372,8 +372,7 @@ Status ModularFileSystem::CopyFile(
return StatusFromTF_Status(plugin_status.get());
}
std::string ModularFileSystem::TranslateName(
const std::string& name /*, TransactionToken* token */) const {
std::string ModularFileSystem::TranslateName(const std::string& name) const {
if (ops_->translate_name == nullptr) return FileSystem::TranslateName(name);
char* p = ops_->translate_name(filesystem_.get(), name.c_str());
@ -385,7 +384,7 @@ std::string ModularFileSystem::TranslateName(
return ret;
}
void ModularFileSystem::FlushCaches(/*TransactionToken* token*/) {
void ModularFileSystem::FlushCaches(TransactionToken* token) {
if (ops_->flush_caches != nullptr) ops_->flush_caches(filesystem_.get());
}
@ -462,7 +461,7 @@ Status RegisterFilesystemPlugin(const std::string& dso_path) {
// Step 1: Load plugin
Env* env = Env::Default();
void* dso_handle;
TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle));
TF_RETURN_IF_ERROR(env->LoadDynamicLibrary(dso_path.c_str(), &dso_handle));
// Step 2: Load symbol for `TF_InitPlugin`
void* dso_symbol;

View File

@ -59,71 +59,48 @@ class ModularFileSystem final : public FileSystem {
~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); }
TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT;
Status NewRandomAccessFile(
const std::string& fname,
std::unique_ptr<RandomAccessFile>*
result /*, TransactionToken* token = nullptr */) override;
Status NewWritableFile(
const std::string& fname,
std::unique_ptr<WritableFile>*
result /*, TransactionToken* token = nullptr */) override;
Status NewAppendableFile(
const std::string& fname,
std::unique_ptr<WritableFile>*
result /*, TransactionToken* token = nullptr */) override;
const std::string& fname, TransactionToken* token,
std::unique_ptr<RandomAccessFile>* result) override;
Status NewWritableFile(const std::string& fname, TransactionToken* token,
std::unique_ptr<WritableFile>* result) override;
Status NewAppendableFile(const std::string& fname, TransactionToken* token,
std::unique_ptr<WritableFile>* result) override;
Status NewReadOnlyMemoryRegionFromFile(
const std::string& fname,
std::unique_ptr<ReadOnlyMemoryRegion>*
result /*, TransactionToken* token = nullptr */) override;
Status FileExists(
const std::string& fname /*, TransactionToken* token = nullptr */)
override;
const std::string& fname, TransactionToken* token,
std::unique_ptr<ReadOnlyMemoryRegion>* result) override;
Status FileExists(const std::string& fname, TransactionToken* token) override;
bool FilesExist(const std::vector<std::string>& files,
std::vector<Status>*
status /*, TransactionToken* token = nullptr */) override;
Status GetChildren(
const std::string& dir,
std::vector<std::string>* result /*, TransactionToken* token = nullptr */)
override;
Status GetMatchingPaths(
const std::string& pattern,
std::vector<std::string>*
results /*, TransactionToken* token = nullptr */) override;
Status DeleteFile(
const std::string& fname /*, TransactionToken* token = nullptr */)
override;
Status DeleteRecursively(
const std::string& dirname, int64* undeleted_files,
int64* undeleted_dirs /*, TransactionToken* token = nullptr */) override;
Status DeleteDir(
const std::string& dirname /*, TransactionToken* token = nullptr */)
override;
Status RecursivelyCreateDir(
const std::string& dirname /*, TransactionToken* token = nullptr */)
override;
Status CreateDir(
const std::string& dirname /*, TransactionToken* token = nullptr */)
override;
Status Stat(
const std::string& fname,
FileStatistics* stat /*, TransactionToken* token = nullptr */) override;
Status IsDirectory(
const std::string& fname /*, TransactionToken* token = nullptr */)
override;
Status GetFileSize(
const std::string& fname,
uint64* file_size /*, TransactionToken* token = nullptr */) override;
Status RenameFile(
const std::string& src,
const std::string& target /*, TransactionToken* token = nullptr */)
override;
Status CopyFile(const std::string& src,
const std::string&
target /*, TransactionToken* token = nullptr */) override;
std::string TranslateName(
const std::string& name /*, TransactionToken* token = nullptr */)
const override;
void FlushCaches(/* TransactionToken* token=nullptr */) override;
TransactionToken* token,
std::vector<Status>* status) override;
Status GetChildren(const std::string& dir, TransactionToken* token,
std::vector<std::string>* result) override;
Status GetMatchingPaths(const std::string& pattern, TransactionToken* token,
std::vector<std::string>* results) override;
Status DeleteFile(const std::string& fname, TransactionToken* token) override;
Status DeleteRecursively(const std::string& dirname, TransactionToken* token,
int64* undeleted_files,
int64* undeleted_dirs) override;
Status DeleteDir(const std::string& dirname,
TransactionToken* token) override;
Status RecursivelyCreateDir(const std::string& dirname,
TransactionToken* token) override;
Status CreateDir(const std::string& dirname,
TransactionToken* token) override;
Status Stat(const std::string& fname, TransactionToken* token,
FileStatistics* stat) override;
Status IsDirectory(const std::string& fname,
TransactionToken* token) override;
Status GetFileSize(const std::string& fname, TransactionToken* token,
uint64* file_size) override;
Status RenameFile(const std::string& src, const std::string& target,
TransactionToken* token) override;
Status CopyFile(const std::string& src, const std::string& target,
TransactionToken* token) override;
std::string TranslateName(const std::string& name) const override;
void FlushCaches(TransactionToken* token) override;
private:
std::unique_ptr<TF_Filesystem> filesystem_;

View File

@ -33,7 +33,6 @@ limitations under the License.
// Windows defines the following macros to convert foo to fooA or fooW,
// depending on the type of the string argument. We don't use these macros, so
// undefine them here.
#undef LoadLibrary
#undef CopyFile
#undef DeleteFile
#undef TranslateName

View File

@ -33,6 +33,7 @@ cc_library(
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant",
],
)

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/types/variant.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
@ -663,28 +664,179 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
}
}
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
static void StatForObject(GCSFile* gcs_file, const std::string& path,
const std::string& bucket, const std::string& object,
GcsFileStat* stat, TF_Status* status) {
if (object.empty())
return TF_SetStatus(
status, TF_INVALID_ARGUMENT,
("'object' must be a non-empty string. (File: " + path + ")").c_str());
TF_SetStatus(status, TF_OK, "");
gcs_file->stat_cache->LookupOrCompute(
path, stat,
[gcs_file, bucket, object](const std::string& path, GcsFileStat* stat,
TF_Status* status) {
UncachedStatForObject(bucket, object, stat, &gcs_file->gcs_client,
status);
},
status);
}
static bool ObjectExists(GCSFile* gcs_file, const std::string& path,
const std::string& bucket, const std::string& object,
TF_Status* status) {
GcsFileStat stat;
StatForObject(gcs_file, path, bucket, object, &stat, status);
if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_NOT_FOUND)
return false;
if (TF_GetCode(status) == TF_NOT_FOUND) {
TF_SetStatus(status, TF_OK, "");
return false;
}
return !stat.base.is_directory;
}
static bool BucketExists(GCSFile* gcs_file, const std::string& bucket,
TF_Status* status) {
auto metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_NOT_FOUND)
return false;
if (TF_GetCode(status) == TF_NOT_FOUND) {
TF_SetStatus(status, TF_OK, "");
return false;
}
return true;
}
static std::vector<std::string> GetChildrenBounded(
GCSFile* gcs_file, std::string dir, uint64_t max_results, bool recursive,
bool include_self_directory_marker, TF_Status* status) {
std::string bucket, prefix;
MaybeAppendSlash(&dir);
ParseGCSPath(dir, true, &bucket, &prefix, status);
std::vector<std::string> result;
uint64_t count = 0;
std::string delimiter = recursive ? "" : "/";
for (auto&& item : gcs_file->gcs_client.ListObjectsAndPrefixes(
bucket, gcs::Prefix(prefix), gcs::Delimiter(delimiter))) {
if (count == max_results) {
TF_SetStatus(status, TF_OK, "");
return result;
}
if (!item) {
TF_SetStatusFromGCSStatus(item.status(), status);
return result;
}
auto value = *std::move(item);
std::string children = absl::holds_alternative<std::string>(value)
? absl::get<std::string>(value)
: absl::get<gcs::ObjectMetadata>(value).name();
auto pos = children.find(prefix);
if (pos != 0) {
TF_SetStatus(status, TF_INTERNAL,
("Unexpected response: the returned file name " + children +
" doesn't match the prefix " + prefix)
.c_str());
return result;
}
children.erase(0, prefix.length());
if (!children.empty() || include_self_directory_marker) {
result.emplace_back(children);
}
++count;
}
return result;
}
static bool FolderExists(GCSFile* gcs_file, std::string dir,
TF_Status* status) {
ExpiringLRUCache<GcsFileStat>::ComputeFunc compute_func =
[gcs_file](const std::string& dir, GcsFileStat* stat, TF_Status* status) {
auto children =
GetChildrenBounded(gcs_file, dir, 1, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
if (!children.empty()) {
stat->base = {0, 0, true};
return TF_SetStatus(status, TF_OK, "");
} else {
return TF_SetStatus(status, TF_INVALID_ARGUMENT, "Not a directory!");
}
};
GcsFileStat stat;
MaybeAppendSlash(&dir);
gcs_file->stat_cache->LookupOrCompute(dir, &stat, compute_func, status);
if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_INVALID_ARGUMENT)
return false;
if (TF_GetCode(status) == TF_INVALID_ARGUMENT) {
TF_SetStatus(status, TF_OK, "");
return false;
}
return true;
}
static void ClearFileCaches(GCSFile* gcs_file, const std::string& path) {
absl::ReaderMutexLock l(&gcs_file->block_cache_lock);
gcs_file->file_block_cache->RemoveFile(path);
gcs_file->stat_cache->Delete(path);
}
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) {
auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
TF_SetStatusFromGCSStatus(bucket_metadata.status(), status);
bool result = BucketExists(gcs_file, bucket, status);
if (result) return TF_SetStatus(status, TF_OK, "");
}
GcsFileStat stat;
StatForObject(gcs_file, path, bucket, object, &stat, status);
if (TF_GetCode(status) != TF_NOT_FOUND) return;
bool result = FolderExists(gcs_file, path, status);
if (TF_GetCode(status) != TF_OK || (TF_GetCode(status) == TF_OK && result))
return;
return TF_SetStatus(
status, TF_NOT_FOUND,
absl::StrCat("The path ", path, " does not exist.").c_str());
}
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string dir = path;
MaybeAppendSlash(&dir);
std::string bucket, object;
ParseGCSPath(dir, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) {
bool is_directory = BucketExists(gcs_file, bucket, status);
if (TF_GetCode(status) != TF_OK) return;
if (!is_directory)
TF_SetStatus(status, TF_NOT_FOUND,
("The specified bucket " + dir + " was not found.").c_str());
return;
}
MaybeAppendSlash(&object);
auto object_metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
TF_SetStatusFromGCSStatus(object_metadata.status(), status);
if (TF_GetCode(status) == TF_NOT_FOUND) {
auto insert_metadata =
gcs_file->gcs_client.InsertObject(bucket, object, "");
TF_SetStatusFromGCSStatus(insert_metadata.status(), status);
} else if (TF_GetCode(status) == TF_OK) {
PathExists(filesystem, dir.c_str(), status);
if (TF_GetCode(status) == TF_OK)
return TF_SetStatus(status, TF_ALREADY_EXISTS, path);
auto metadata = gcs_file->gcs_client.InsertObject(
bucket, object, "",
// Adding this parameter means HTTP_CODE_PRECONDITION_FAILED
// will be returned if the object already exists, so avoid reuploading.
gcs::IfGenerationMatch(0));
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) == TF_FAILED_PRECONDITION)
TF_SetStatus(status, TF_ALREADY_EXISTS, path);
}
}
// TODO(vnvo2409): `RecursivelyCreateDir` should use `CreateDir` instead of the
@ -700,79 +852,31 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path,
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object);
TF_SetStatusFromGCSStatus(gcs_status, status);
if (TF_GetCode(status) == TF_OK) ClearFileCaches(gcs_file, path);
}
// Checks that the directory is empty (i.e no objects with this prefix exist).
// Deletes the GCS directory marker if it exists.
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
MaybeAppendSlash(&object);
// A directory is considered empty either if there are no matching objects
// with the corresponding name prefix or if there is exactly one matching
// object and it is the directory marker. Therefore we need to retrieve
// at most two children for the prefix to detect if a directory is empty.
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
int object_count = 0;
for (auto&& metadata :
gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) {
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
}
++object_count;
// We consider a path is a non-empty directory in two cases:
// - There are more than two objects whose keys start with the name of this
// directory.
// - There is one object whose key contains the name of this directory ( but
// not equal ).
if (object_count > 1 || metadata->name() != object) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Cannot delete a non-empty directory.");
return;
}
}
auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object);
TF_SetStatusFromGCSStatus(gcs_status, status);
}
// TODO(vnvo2409): `DeleteRecursively` needs `GetChildrens` but there will be
// some differents compared to the default implementation. Will be refactored.
static void DeleteRecursively(const TF_Filesystem* filesystem, const char* path,
uint64_t* undeleted_files,
uint64_t* undeleted_dirs, TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
auto childrens = GetChildrenBounded(gcs_file, path, 2, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto gcs_status = gcs::DeleteByPrefix(gcs_file->gcs_client, bucket, object);
TF_SetStatusFromGCSStatus(gcs_status, status);
if (TF_GetCode(status) != TF_OK) return;
*undeleted_dirs = 0;
*undeleted_files = 0;
}
// TODO(vnvo2409): `RewriteObjectBlocking` will set `status` to `TF_NOT_FOUND`
// if the object does not exist. In that case, we will have to check if the
// `src` is a directory or not to set the correspondent `status` (i.e
// `TF_NOT_FOUND` if path `src` does not exist, `TF_FAILED_PRECONDITION` if
// path `src` is a directory).
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
std::string bucket_src, object_src;
ParseGCSPath(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
std::string bucket_dst, object_dst;
ParseGCSPath(dst, false, &bucket_dst, &object_dst, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto metadata = gcs_file->gcs_client.RewriteObjectBlocking(
bucket_src, object_src, bucket_dst, object_dst);
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (childrens.size() > 1 || (childrens.size() == 1 && !childrens[0].empty()))
return TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Cannot delete a non-empty directory.");
if (childrens.size() == 1 && childrens[0].empty()) {
// This is the directory marker object. Delete it.
std::string dir = path;
MaybeAppendSlash(&dir);
DeleteFile(filesystem, dir.c_str(), status);
return;
}
auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket_src, object_src);
TF_SetStatusFromGCSStatus(gcs_status, status);
TF_SetStatus(status, TF_OK, "");
}
void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
@ -791,31 +895,6 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
TF_SetStatusFromGCSStatus(metadata.status(), status);
}
// TODO(vnvo2409): This approach can cause a problem when our path is
// `path/to/dir` and there is an object with key `path/to/directory`. Will be
// fixed when refactoring.
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
for (auto&& metadata :
gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) {
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
}
// We consider a path exists if there is at least one object whose key
// contains the path.
return TF_SetStatus(status, TF_OK, "");
}
return TF_SetStatus(
status, TF_NOT_FOUND,
absl::StrCat("The path ", path, " does not exist.").c_str());
}
bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string bucket, object;
@ -824,41 +903,127 @@ bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) {
auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
TF_SetStatusFromGCSStatus(bucket_metadata.status(), status);
if (TF_GetCode(status) == TF_OK)
return true;
else
return false;
bool result = BucketExists(gcs_file, bucket, status);
if (TF_GetCode(status) != TF_OK) return false;
if (!result)
TF_SetStatus(
status, TF_NOT_FOUND,
("The specified bucket gs://" + bucket + " was not found.").c_str());
return result;
}
// We check if there is an object with this key on the GCS server.
auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
if (metadata) {
TF_SetStatus(status, TF_OK, "");
if (metadata->name().back() == '/')
return true;
else
return false;
}
bool is_folder = FolderExists(gcs_file, path, status);
if (TF_GetCode(status) != TF_OK) return false;
if (is_folder) return true;
// If there is no object with this key on the GCS server. We check if there is
// any object whose key contains that path.
MaybeAppendSlash(&object);
for (auto&& metadata :
gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) {
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return false;
}
TF_SetStatus(status, TF_OK, "");
return true;
bool is_object = ObjectExists(gcs_file, path, bucket, object, status);
if (TF_GetCode(status) != TF_OK) return false;
if (is_object) {
TF_SetStatus(
status, TF_FAILED_PRECONDITION,
absl::StrCat("The specified path ", path, " is not a directory.")
.c_str());
return false;
}
TF_SetStatus(status, TF_NOT_FOUND,
absl::StrCat("The path ", path, " does not exist.").c_str());
return false;
}
static void RenameObject(const TF_Filesystem* filesystem,
const std::string& src, const std::string& dst,
TF_Status* status) {
std::string bucket_src, object_src;
ParseGCSPath(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
std::string bucket_dst, object_dst;
ParseGCSPath(dst, false, &bucket_dst, &object_dst, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto metadata = gcs_file->gcs_client.RewriteObjectBlocking(
bucket_src, object_src, bucket_dst, object_dst);
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) != TF_OK) return;
ClearFileCaches(gcs_file, dst);
DeleteFile(filesystem, src.c_str(), status);
}
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
if (!IsDirectory(filesystem, src, status)) {
if (TF_GetCode(status) == TF_FAILED_PRECONDITION)
RenameObject(filesystem, src, dst, status);
return;
}
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
std::vector<std::string> childrens =
GetChildrenBounded(gcs_file, src, UINT64_MAX, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
std::string src_dir = src;
std::string dst_dir = dst;
MaybeAppendSlash(&src_dir);
MaybeAppendSlash(&dst_dir);
for (const std::string& children : childrens) {
RenameObject(filesystem, src_dir + children, dst_dir + children, status);
if (TF_GetCode(status) != TF_OK) return;
}
TF_SetStatus(status, TF_OK, "");
}
void DeleteRecursively(const TF_Filesystem* filesystem, const char* path,
uint64_t* undeleted_files, uint64_t* undeleted_dirs,
TF_Status* status) {
if (!undeleted_files || !undeleted_dirs)
return TF_SetStatus(
status, TF_INTERNAL,
"'undeleted_files' and 'undeleted_dirs' cannot be nullptr.");
*undeleted_files = 0;
*undeleted_dirs = 0;
if (!IsDirectory(filesystem, path, status)) {
*undeleted_dirs = 1;
return;
}
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
std::vector<std::string> childrens =
GetChildrenBounded(gcs_file, path, UINT64_MAX, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
std::string dir = path;
MaybeAppendSlash(&dir);
for (const std::string& children : childrens) {
const std::string& full_path = dir + children;
DeleteFile(filesystem, full_path.c_str(), status);
if (TF_GetCode(status) != TF_OK) {
if (IsDirectory(filesystem, full_path.c_str(), status))
// The object is a directory marker.
(*undeleted_dirs)++;
else
(*undeleted_files)++;
}
}
}
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
std::vector<std::string> childrens =
GetChildrenBounded(gcs_file, path, UINT64_MAX, false, false, status);
if (TF_GetCode(status) != TF_OK) return -1;
int num_entries = childrens.size();
*entries = static_cast<char**>(
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
for (int i = 0; i < num_entries; i++)
(*entries)[i] = strdup(childrens[i].c_str());
TF_SetStatus(status, TF_OK, "");
return num_entries;
}
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
std::string bucket, object;
@ -896,6 +1061,17 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
}
}
static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
return strdup(uri);
}
static void FlushCaches(const TF_Filesystem* filesystem) {
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
absl::ReaderMutexLock l(&gcs_file->block_cache_lock);
gcs_file->file_block_cache->Flush();
gcs_file->stat_cache->Clear();
}
} // namespace tf_gcs_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
@ -912,6 +1088,13 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
ops->read_only_memory_region_ops->cleanup =
tf_read_only_memory_region::Cleanup;
ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data;
ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
@ -921,6 +1104,20 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_gcs_filesystem::NewAppendableFile;
ops->filesystem_ops->new_read_only_memory_region_from_file =
tf_gcs_filesystem::NewReadOnlyMemoryRegionFromFile;
ops->filesystem_ops->create_dir = tf_gcs_filesystem::CreateDir;
ops->filesystem_ops->delete_file = tf_gcs_filesystem::DeleteFile;
ops->filesystem_ops->delete_dir = tf_gcs_filesystem::DeleteDir;
ops->filesystem_ops->delete_recursively =
tf_gcs_filesystem::DeleteRecursively;
ops->filesystem_ops->copy_file = tf_gcs_filesystem::CopyFile;
ops->filesystem_ops->path_exists = tf_gcs_filesystem::PathExists;
ops->filesystem_ops->is_directory = tf_gcs_filesystem::IsDirectory;
ops->filesystem_ops->stat = tf_gcs_filesystem::Stat;
ops->filesystem_ops->get_children = tf_gcs_filesystem::GetChildren;
ops->filesystem_ops->translate_name = tf_gcs_filesystem::TranslateName;
ops->filesystem_ops->flush_caches = tf_gcs_filesystem::FlushCaches;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {

View File

@ -29,5 +29,7 @@ cc_library(
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"//third_party/hadoop:hdfs",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
],
)

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <sstream>
#include <string>
#include "absl/synchronization/mutex.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
@ -162,27 +163,19 @@ class LibHDFS {
void* handle_;
};
static const LibHDFS* libhdfs(TF_Status* status) {
static const LibHDFS* libhdfs = new LibHDFS(status);
return libhdfs;
}
// We rely on HDFS connection caching here. The HDFS client calls
// org.apache.hadoop.fs.FileSystem.get(), which caches the connection
// internally.
hdfsFS Connect(const std::string& path, TF_Status* status) {
auto hdfs_file = libhdfs(status);
if (TF_GetCode(status) != TF_OK) return nullptr;
hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) {
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
std::string scheme, namenode, nodepath;
ParseHadoopPath(path, &scheme, &namenode, &nodepath);
hdfsBuilder* builder = hdfs_file->hdfsNewBuilder();
hdfsBuilder* builder = libhdfs->hdfsNewBuilder();
if (scheme == "file") {
hdfs_file->hdfsBuilderSetNameNode(builder, nullptr);
libhdfs->hdfsBuilderSetNameNode(builder, nullptr);
} else if (scheme == "viewfs") {
char* defaultFS = nullptr;
hdfs_file->hdfsConfGetStr("fs.defaultFS", &defaultFS);
libhdfs->hdfsConfGetStr("fs.defaultFS", &defaultFS);
std::string defaultScheme, defaultCluster, defaultPath;
ParseHadoopPath(defaultFS, &defaultScheme, &defaultCluster, &defaultPath);
@ -195,17 +188,17 @@ hdfsFS Connect(const std::string& path, TF_Status* status) {
// The default NameNode configuration will be used (from the XML
// configuration files). See:
// https://github.com/tensorflow/tensorflow/blob/v1.0.0/third_party/hadoop/hdfs.h#L259
hdfs_file->hdfsBuilderSetNameNode(builder, "default");
libhdfs->hdfsBuilderSetNameNode(builder, "default");
} else if (scheme == "har") {
std::string path_har = path;
SplitArchiveNameAndPath(&path_har, &namenode, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
hdfs_file->hdfsBuilderSetNameNode(builder, namenode.c_str());
libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str());
} else {
hdfs_file->hdfsBuilderSetNameNode(
libhdfs->hdfsBuilderSetNameNode(
builder, namenode.empty() ? "default" : namenode.c_str());
}
auto fs = hdfs_file->hdfsBuilderConnect(builder);
auto fs = libhdfs->hdfsBuilderConnect(builder);
if (fs == nullptr)
TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno));
else
@ -216,16 +209,178 @@ hdfsFS Connect(const std::string& path, TF_Status* status) {
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
typedef struct HDFSFile {
std::string path;
std::string hdfs_path;
hdfsFS fs;
LibHDFS* libhdfs;
absl::Mutex mu;
hdfsFile handle ABSL_GUARDED_BY(mu);
HDFSFile(std::string path, std::string hdfs_path, hdfsFS fs, LibHDFS* libhdfs,
hdfsFile handle)
: path(std::move(path)),
hdfs_path(std::move(hdfs_path)),
fs(fs),
libhdfs(libhdfs),
mu(),
handle(handle) {}
} HDFSFile;
// TODO(vnvo2409): Implement later
void Cleanup(TF_RandomAccessFile* file) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
{
absl::MutexLock l(&hdfs_file->mu);
if (hdfs_file->handle != nullptr) {
hdfs_file->libhdfs->hdfsCloseFile(hdfs_file->fs, hdfs_file->handle);
}
}
delete hdfs_file;
}
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
auto libhdfs = hdfs_file->libhdfs;
auto fs = hdfs_file->fs;
auto hdfs_path = hdfs_file->hdfs_path.c_str();
auto path = hdfs_file->path.c_str();
char* dst = buffer;
bool eof_retried = false;
int64_t r = 0;
while (TF_GetCode(status) == TF_OK && !eof_retried) {
// We lock inside the loop rather than outside so we don't block other
// concurrent readers.
absl::MutexLock l(&hdfs_file->mu);
auto handle = hdfs_file->handle;
// Max read length is INT_MAX-2, for hdfsPread function take a parameter
// of int32. -2 offset can avoid JVM OutOfMemoryError.
size_t read_n =
(std::min)(n, static_cast<size_t>(std::numeric_limits<int>::max() - 2));
r = libhdfs->hdfsPread(fs, handle, static_cast<tOffset>(offset), dst,
static_cast<tSize>(read_n));
if (r > 0) {
dst += r;
n -= r;
offset += r;
} else if (!eof_retried && r == 0) {
// Always reopen the file upon reaching EOF to see if there's more data.
// If writers are streaming contents while others are concurrently
// reading, HDFS requires that we reopen the file to see updated
// contents.
//
// Fixes #5438
if (handle != nullptr && libhdfs->hdfsCloseFile(fs, handle) != 0) {
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
handle = libhdfs->hdfsOpenFile(fs, hdfs_path, O_RDONLY, 0, 0, 0);
if (handle == nullptr) {
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
eof_retried = true;
} else if (eof_retried && r == 0) {
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
} else if (errno == EINTR || errno == EAGAIN) {
// hdfsPread may return EINTR too. Just retry.
} else {
TF_SetStatusFromIOError(status, errno, path);
}
}
return r;
}
} // namespace tf_random_access_file
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
typedef struct HDFSFile {
std::string hdfs_path;
hdfsFS fs;
LibHDFS* libhdfs;
hdfsFile handle;
HDFSFile(std::string hdfs_path, hdfsFS fs, LibHDFS* libhdfs, hdfsFile handle)
: hdfs_path(std::move(hdfs_path)),
fs(fs),
libhdfs(libhdfs),
handle(handle) {}
} HDFSFile;
// TODO(vnvo2409): Implement later
static void Cleanup(TF_WritableFile* file) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
hdfs_file->libhdfs->hdfsCloseFile(hdfs_file->fs, hdfs_file->handle);
hdfs_file->fs = nullptr;
hdfs_file->handle = nullptr;
delete hdfs_file;
}
void Append(const TF_WritableFile* file, const char* buffer, size_t n,
TF_Status* status) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
auto libhdfs = hdfs_file->libhdfs;
auto fs = hdfs_file->fs;
auto handle = hdfs_file->handle;
size_t cur_pos = 0, write_len = 0;
bool retry = false;
// max() - 2 can avoid OutOfMemoryError in JVM .
static const size_t max_len_once =
static_cast<size_t>(std::numeric_limits<tSize>::max() - 2);
while (cur_pos < n) {
write_len = (std::min)(n - cur_pos, max_len_once);
tSize w = libhdfs->hdfsWrite(fs, handle, buffer + cur_pos,
static_cast<tSize>(write_len));
if (w == -1) {
if (!retry && (errno == EINTR || errno == EAGAIN)) {
retry = true;
} else {
return TF_SetStatusFromIOError(status, errno,
hdfs_file->hdfs_path.c_str());
}
} else {
cur_pos += w;
}
}
TF_SetStatus(status, TF_OK, "");
}
int64_t Tell(const TF_WritableFile* file, TF_Status* status) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
int64_t position =
hdfs_file->libhdfs->hdfsTell(hdfs_file->fs, hdfs_file->handle);
if (position == -1)
TF_SetStatusFromIOError(status, errno, hdfs_file->hdfs_path.c_str());
else
TF_SetStatus(status, TF_OK, "");
return position;
}
void Flush(const TF_WritableFile* file, TF_Status* status) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
if (hdfs_file->libhdfs->hdfsHFlush(hdfs_file->fs, hdfs_file->handle) != 0)
TF_SetStatusFromIOError(status, errno, hdfs_file->hdfs_path.c_str());
else
TF_SetStatus(status, TF_OK, "");
}
void Sync(const TF_WritableFile* file, TF_Status* status) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
if (hdfs_file->libhdfs->hdfsHSync(hdfs_file->fs, hdfs_file->handle) != 0)
TF_SetStatusFromIOError(status, errno, hdfs_file->hdfs_path.c_str());
else
TF_SetStatus(status, TF_OK, "");
}
void Close(const TF_WritableFile* file, TF_Status* status) {
auto hdfs_file = static_cast<HDFSFile*>(file->plugin_file);
TF_SetStatus(status, TF_OK, "");
if (hdfs_file->libhdfs->hdfsCloseFile(hdfs_file->fs, hdfs_file->handle) != 0)
TF_SetStatusFromIOError(status, errno, hdfs_file->hdfs_path.c_str());
hdfs_file->fs = nullptr;
hdfs_file->handle = nullptr;
}
} // namespace tf_writable_file
@ -241,6 +396,248 @@ namespace tf_read_only_memory_region {
// ----------------------------------------------------------------------------
namespace tf_hadoop_filesystem {
void Init(TF_Filesystem* filesystem, TF_Status* status) {
filesystem->plugin_filesystem = new LibHDFS(status);
if (TF_GetCode(status) != TF_OK) return;
TF_SetStatus(status, TF_OK, "");
}
void Cleanup(TF_Filesystem* filesystem) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
delete libhdfs;
}
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
auto handle = libhdfs->hdfsOpenFile(fs, hdfs_path.c_str(), O_RDONLY, 0, 0, 0);
if (handle == nullptr) return TF_SetStatusFromIOError(status, errno, path);
file->plugin_file =
new tf_random_access_file::HDFSFile(path, hdfs_path, fs, libhdfs, handle);
TF_SetStatus(status, TF_OK, "");
}
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
auto handle = libhdfs->hdfsOpenFile(fs, hdfs_path.c_str(),
O_WRONLY | O_APPEND, 0, 0, 0);
if (handle == nullptr) return TF_SetStatusFromIOError(status, errno, path);
file->plugin_file =
new tf_writable_file::HDFSFile(hdfs_path, fs, libhdfs, handle);
TF_SetStatus(status, TF_OK, "");
}
void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
const char* path,
TF_ReadOnlyMemoryRegion* region,
TF_Status* status) {
// hadoopReadZero() technically supports this call with the following
// caveats:
// - It only works up to 2 GB. We'd have to Stat() the file to ensure that
// it fits.
// - If not on the local filesystem, the entire file will be read, making
// it inefficient for callers that assume typical mmap() behavior.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"HDFS does not support ReadOnlyMemoryRegion");
}
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
if (libhdfs->hdfsExists(fs, hdfs_path.c_str()) == 0)
TF_SetStatus(status, TF_OK, "");
else
TF_SetStatus(status, TF_NOT_FOUND,
(std::string(path) + " not found").c_str());
}
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
auto info = libhdfs->hdfsGetPathInfo(fs, hdfs_path.c_str());
if (info == nullptr) return TF_SetStatusFromIOError(status, errno, path);
stats->length = static_cast<int64_t>(info->mSize);
stats->mtime_nsec = static_cast<int64_t>(info->mLastMod) * 1e9;
stats->is_directory = info->mKind == kObjectKindDirectory;
libhdfs->hdfsFreeFileInfo(info, 1);
TF_SetStatus(status, TF_OK, "");
}
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return -1;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
auto info = libhdfs->hdfsGetPathInfo(fs, hdfs_path.c_str());
if (info == nullptr) {
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
TF_SetStatus(status, TF_OK, "");
auto size = static_cast<int64_t>(info->mSize);
libhdfs->hdfsFreeFileInfo(info, 1);
return size;
}
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
if (libhdfs->hdfsDelete(fs, hdfs_path.c_str(), /*recursive=*/0) != 0)
TF_SetStatusFromIOError(status, errno, path);
else
TF_SetStatus(status, TF_OK, "");
}
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
if (libhdfs->hdfsCreateDirectory(fs, hdfs_path.c_str()) != 0)
TF_SetStatusFromIOError(status, errno, path);
else
TF_SetStatus(status, TF_OK, "");
}
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
// Count the number of entries in the directory, and only delete if it's
// non-empty. This is consistent with the interface, but note that there's
// a race condition where a file may be added after this check, in which
// case the directory will still be deleted.
int entries = 0;
auto info = libhdfs->hdfsListDirectory(fs, hdfs_path.c_str(), &entries);
if (info != nullptr) libhdfs->hdfsFreeFileInfo(info, entries);
// Due to HDFS bug HDFS-8407, we can't distinguish between an error and empty
// folder, especially for Kerberos enable setup, EAGAIN is quite common when
// the call is actually successful. Check again by Stat.
if (info == nullptr && errno != 0) {
TF_FileStatistics stat;
Stat(filesystem, path, &stat, status);
if (TF_GetCode(status) != TF_OK) return;
}
if (entries > 0)
return TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Cannot delete a non-empty directory.");
if (libhdfs->hdfsDelete(fs, hdfs_path.c_str(), /*recursive=*/1) != 0)
TF_SetStatusFromIOError(status, errno, path);
else
TF_SetStatus(status, TF_OK, "");
}
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, src, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path_src, hdfs_path_dst;
ParseHadoopPath(src, &scheme, &namenode, &hdfs_path_src);
ParseHadoopPath(dst, &scheme, &namenode, &hdfs_path_dst);
if (libhdfs->hdfsExists(fs, hdfs_path_dst.c_str()) == 0 &&
libhdfs->hdfsDelete(fs, hdfs_path_dst.c_str(), /*recursive=*/0) != 0)
return TF_SetStatusFromIOError(status, errno, dst);
if (libhdfs->hdfsRename(fs, hdfs_path_src.c_str(), hdfs_path_dst.c_str()) !=
0)
TF_SetStatusFromIOError(status, errno, src);
else
TF_SetStatus(status, TF_OK, "");
}
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return -1;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
// hdfsListDirectory returns nullptr if the directory is empty. Do a separate
// check to verify the directory exists first.
TF_FileStatistics stat;
Stat(filesystem, path, &stat, status);
if (TF_GetCode(status) != TF_OK) return -1;
int num_entries = 0;
auto info = libhdfs->hdfsListDirectory(fs, hdfs_path.c_str(), &num_entries);
if (info == nullptr) {
if (stat.is_directory) {
// Assume it's an empty directory.
TF_SetStatus(status, TF_OK, "");
return 0;
}
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
*entries = static_cast<char**>(
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
auto BaseName = [](const std::string& name) {
return name.substr(name.find_last_of('/') + 1);
};
for (int i = 0; i < num_entries; i++) {
(*entries)[i] = strdup(BaseName(info[i].mName).c_str());
}
libhdfs->hdfsFreeFileInfo(info, num_entries);
TF_SetStatus(status, TF_OK, "");
return num_entries;
}
// TODO(vnvo2409): Implement later
} // namespace tf_hadoop_filesystem

View File

@ -1058,6 +1058,66 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path,
}
}
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
Aws::String bucket_src, object_src;
ParseS3Path(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
Aws::String copy_src = bucket_src + "/" + object_src;
Aws::String bucket_dst, object_dst;
ParseS3Path(dst, false, &bucket_dst, &object_dst, status);
if (TF_GetCode(status) != TF_OK) return;
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
GetS3Client(s3_file);
if (object_src.back() == '/') {
if (object_dst.back() != '/') {
object_dst.push_back('/');
}
} else {
if (object_dst.back() == '/') {
object_dst.pop_back();
}
}
Aws::S3::Model::DeleteObjectRequest delete_object_request;
Aws::S3::Model::ListObjectsRequest list_objects_request;
list_objects_request.WithBucket(bucket_src)
.WithPrefix(object_src)
.WithMaxKeys(kS3GetChildrenMaxKeys);
list_objects_request.SetResponseStreamFactory(
[]() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
Aws::S3::Model::ListObjectsResult list_objects_result;
do {
auto list_objects_outcome =
s3_file->s3_client->ListObjects(list_objects_request);
if (!list_objects_outcome.IsSuccess())
return TF_SetStatusFromAWSError(list_objects_outcome.GetError(), status);
list_objects_result = list_objects_outcome.GetResult();
for (const auto& object : list_objects_result.GetContents()) {
Aws::String key_src = object.GetKey();
Aws::String key_dst = key_src;
key_dst.replace(0, object_src.length(), object_dst);
CopyFile(filesystem, ("s3://" + bucket_src + "/" + key_src).c_str(),
("s3://" + bucket_dst + "/" + key_dst).c_str(), status);
if (TF_GetCode(status) != TF_OK) return;
delete_object_request.WithBucket(bucket_src).WithKey(key_src);
auto delete_object_outcome =
s3_file->s3_client->DeleteObject(delete_object_request);
if (!delete_object_outcome.IsSuccess())
return TF_SetStatusFromAWSError(delete_object_outcome.GetError(),
status);
}
list_objects_request.SetMarker(list_objects_result.GetNextMarker());
} while (list_objects_result.GetIsTruncated());
TF_SetStatus(status, TF_OK, "");
}
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
Aws::String bucket, prefix;
@ -1161,6 +1221,7 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
ops->filesystem_ops->delete_file = tf_s3_filesystem::DeleteFile;
ops->filesystem_ops->delete_dir = tf_s3_filesystem::DeleteDir;
ops->filesystem_ops->copy_file = tf_s3_filesystem::CopyFile;
ops->filesystem_ops->rename_file = tf_s3_filesystem::RenameFile;
ops->filesystem_ops->path_exists = tf_s3_filesystem::PathExists;
ops->filesystem_ops->get_file_size = tf_s3_filesystem::GetFileSize;
ops->filesystem_ops->stat = tf_s3_filesystem::Stat;

View File

@ -92,6 +92,10 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status);
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status);
void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
TF_Status* status);
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status);
} // namespace tf_s3_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_S3_FILESYSTEM_H_

View File

@ -439,6 +439,51 @@ TEST_F(S3FilesystemTest, StatFile) {
EXPECT_FALSE(stat.is_directory);
}
TEST_F(S3FilesystemTest, SimpleCopyFile) {
const std::string src = GetURIForPath("SimpleCopySrc");
const std::string dst = GetURIForPath("SimpleCopyDst");
WriteString(src, "test");
ASSERT_TF_OK(status_);
tf_s3_filesystem::CopyFile(filesystem_, src.c_str(), dst.c_str(), status_);
EXPECT_TF_OK(status_);
auto result = ReadAll(dst);
EXPECT_TF_OK(status_);
EXPECT_EQ(result, "test");
}
TEST_F(S3FilesystemTest, RenameFile) {
const std::string src = GetURIForPath("RenameFileSrc");
const std::string dst = GetURIForPath("RenameFileDst");
WriteString(src, "test");
ASSERT_TF_OK(status_);
tf_s3_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), status_);
EXPECT_TF_OK(status_);
auto result = ReadAll(dst);
EXPECT_TF_OK(status_);
EXPECT_EQ("test", result);
}
TEST_F(S3FilesystemTest, RenameFileOverwrite) {
const std::string src = GetURIForPath("RenameFileOverwriteSrc");
const std::string dst = GetURIForPath("RenameFileOverwriteDst");
WriteString(src, "test_old");
ASSERT_TF_OK(status_);
WriteString(dst, "test_new");
ASSERT_TF_OK(status_);
tf_s3_filesystem::PathExists(filesystem_, dst.c_str(), status_);
EXPECT_TF_OK(status_);
tf_s3_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), status_);
EXPECT_TF_OK(status_);
auto result = ReadAll(dst);
EXPECT_TF_OK(status_);
EXPECT_EQ("test_old", result);
}
// Test against large file.
TEST_F(S3FilesystemTest, ReadLargeFile) {
auto local_path = GetLocalLargeFile();
@ -458,6 +503,29 @@ TEST_F(S3FilesystemTest, ReadLargeFile) {
EXPECT_EQ(local_content, server_content);
}
TEST_F(S3FilesystemTest, CopyLargeFile) {
auto server_path = GetServerLargeFile();
if (server_path.empty()) GTEST_SKIP();
auto path = GetURIForPath("CopyLargeFile");
constexpr size_t buffer_size = 5 * 1024 * 1024;
auto s3_file =
static_cast<tf_s3_filesystem::S3File*>(filesystem_->plugin_filesystem);
s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD] =
buffer_size;
tf_s3_filesystem::CopyFile(filesystem_, server_path.c_str(), path.c_str(),
status_);
EXPECT_TF_OK(status_);
auto server_size =
tf_s3_filesystem::GetFileSize(filesystem_, server_path.c_str(), status_);
EXPECT_TF_OK(status_);
auto actual_size =
tf_s3_filesystem::GetFileSize(filesystem_, path.c_str(), status_);
EXPECT_TF_OK(status_);
EXPECT_EQ(server_size, actual_size);
}
} // namespace
} // namespace tensorflow

View File

@ -3,6 +3,24 @@ package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "array_grad",
srcs = ["array_grad.cc"],
hdrs = [
"array_grad.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/core/lib/llvm_rtti",
],
)
cc_library(
name = "math_grad",
srcs = ["math_grad.cc"],
@ -18,6 +36,7 @@ cc_library(
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/core/lib/llvm_rtti",
],
)

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/array_grad.h"
namespace tensorflow {
namespace gradients {
namespace {
using std::vector;
class IdentityNGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(grad_inputs.size(), nullptr);
for (int i = 0; i < grad_inputs.size(); i++) {
auto grad_input = grad_inputs[i];
// TODO(srbs): Should we add a copy contructor to AbstractTensorHandle
// that takes care of this similar to `Tensor`?
if (grad_input) {
grad_input->Ref();
}
(*grad_outputs)[i] = grad_input;
}
return Status::OK();
}
~IdentityNGradientFunction() override {}
};
} // namespace
BackwardFunction* IdentityNRegisterer(const ForwardOperation& op) {
auto gradient_function = new IdentityNGradientFunction;
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients
} // namespace tensorflow

View File

@ -12,17 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
#include "tensorflow/c/eager/gradients.h"
namespace mlir {
namespace tensorflow {
namespace gradients {
BackwardFunction* IdentityNRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow
namespace {
bool register_all_passes = ([] {
mhlo::registerAllMhloPasses();
lmhlo::registerAllLmhloPasses();
}(), true);
} // namespace
} // namespace mlir
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_

View File

@ -14,9 +14,15 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
using std::vector;
using tensorflow::ops::Conj;
using tensorflow::ops::Identity;
using tensorflow::ops::Mul;
namespace tensorflow {
namespace gradients {
@ -24,11 +30,10 @@ namespace {
class AddGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) override {
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
std::vector<AbstractTensorHandle*> identity_outputs(1);
vector<AbstractTensorHandle*> identity_outputs(1);
// TODO(b/145674566): Handle name unification in tracing code.
// TODO(b/161805092): Support broadcasting.
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
@ -44,10 +49,47 @@ class AddGradientFunction : public GradientFunction {
~AddGradientFunction() override {}
};
class ExpGradientFunction : public GradientFunction {
public:
explicit ExpGradientFunction(AbstractTensorHandle* exp) : exp_(exp) {
exp->Ref();
}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
vector<AbstractTensorHandle*> conj_outputs(1);
TF_RETURN_IF_ERROR(
Conj(ctx->ctx, {exp_.get()}, absl::MakeSpan(conj_outputs), "ExpConj"));
AbstractTensorHandlePtr conj_output_releaser(conj_outputs[0]);
grad_outputs->resize(1);
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {conj_outputs[0], grad_inputs[0]},
absl::MakeSpan(*grad_outputs), "ExpGradMul"));
return Status::OK();
}
~ExpGradientFunction() override {}
private:
AbstractTensorHandlePtr exp_;
};
} // namespace
GradientFunction* AddRegisterer(const ForwardOperation& op) {
return new AddGradientFunction;
BackwardFunction* AddRegisterer(const ForwardOperation& op) {
auto gradient_function = new AddGradientFunction;
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* ExpRegisterer(const ForwardOperation& op) {
auto gradient_function = new ExpGradientFunction(op.outputs[0]);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients
} // namespace tensorflow

View File

@ -19,7 +19,8 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
GradientFunction* AddRegisterer(const ForwardOperation& op);
BackwardFunction* AddRegisterer(const ForwardOperation& op);
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow

View File

@ -15,6 +15,7 @@ cc_library(
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
@ -22,3 +23,26 @@ cc_library(
"//tensorflow/core/platform:errors",
],
)
cc_library(
name = "math_ops",
srcs = [
"math_ops.cc",
],
hdrs = [
"math_ops.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":array_ops",
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
],
)

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {

View File

@ -15,9 +15,9 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
#define TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
namespace tensorflow {

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace ops {
using tensorflow::tracing::TracingOperation;
Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr mul_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(mul_op->Reset("Mul", /*raw_device_name=*/nullptr));
if (isa<TracingOperation>(mul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[1]));
int num_retvals = 1;
return mul_op->Execute(outputs, &num_retvals);
}
Status Conj(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
auto dtype = inputs[0]->DataType();
if (DataTypeIsFloating(BaseType(dtype)) ||
DataTypeIsInteger(BaseType(dtype))) {
TF_RETURN_IF_ERROR(Identity(ctx, inputs, outputs, name));
} else {
return errors::Unimplemented("Conj does not support complex types yet.");
}
return Status::OK();
}
} // namespace ops
} // namespace tensorflow

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_
#define TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
namespace tensorflow {
namespace ops {
Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Conj(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_

View File

@ -216,6 +216,23 @@ tf_cc_test(
],
)
tf_cc_test(
name = "signature_flattening_test",
srcs = [
"signature_flattening_test.cc",
],
deps = [
":saved_model_utils",
"//tensorflow/c/experimental/saved_model/core:tf_concrete_function_test_protos",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime/eager:core",
],
)
tf_cc_test(
name = "tf_concrete_function_loading_test",
srcs = [

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
@ -36,52 +37,8 @@ namespace tensorflow {
namespace internal {
namespace {
// This returns the size of `tf.nest.flatten(value)`, on values that are
// used in tf.function's input_signatures.
int FlattenedSize(const tensorflow::StructuredValue& value, Status* status) {
// This follows the logic from
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775
switch (value.kind_case()) {
case StructuredValue::kDictValue: {
const DictValue& dict = value.dict_value();
int size = 0;
for (const auto& field : dict.fields()) {
size += FlattenedSize(field.second, status);
}
return size;
}
case StructuredValue::kTupleValue: {
const TupleValue& tuple = value.tuple_value();
int size = 0;
for (const StructuredValue& value : tuple.values()) {
size += FlattenedSize(value, status);
}
return size;
}
case StructuredValue::kListValue: {
const ListValue& list = value.list_value();
int size = 0;
for (const StructuredValue& value : list.values()) {
size += FlattenedSize(value, status);
}
return size;
}
case StructuredValue::kTensorSpecValue: {
return 1;
}
case StructuredValue::kNoneValue: {
// Base case: do nothing.
// This arises, for example, as the top-level object of an output
// signature when there are no return values.
return 0;
}
default: {
status->Update(errors::Internal("Unhandled structured value kind ",
value.kind_case()));
return 0;
}
}
}
using StructuredValueDictEntry =
protobuf::MapPair<std::string, StructuredValue>;
// Perform some basic sanity checks on SavedConcreteFunction's input and
// output signatures with respect to the corresponding FunctionDef's input
@ -111,34 +68,34 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef(
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1974-L1979
const std::string& name = function_def->signature().name();
const StructuredValue& input_signature =
saved_concrete_function.canonicalized_input_signature();
Status status;
int input_signature_size = FlattenedSize(input_signature, &status);
TF_RETURN_IF_ERROR(status);
if (input_signature_size + saved_concrete_function.bound_inputs_size() !=
std::vector<const TensorSpecProto*> input_specs;
TF_RETURN_IF_ERROR(FlattenSignature(input_signature, &input_specs));
if (input_specs.size() + saved_concrete_function.bound_inputs_size() !=
function_def->signature().input_arg_size()) {
return errors::FailedPrecondition(
"FunctionDef ", name, " has ",
function_def->signature().input_arg_size(),
" inputs, but the SavedConcreteFunction has ", input_signature_size,
" inputs, but the SavedConcreteFunction has ", input_specs.size(),
" flattened user inputs and ",
saved_concrete_function.bound_inputs_size(), " captured inputs.");
}
const StructuredValue& output_signature =
saved_concrete_function.output_signature();
int output_signature_size = FlattenedSize(output_signature, &status);
TF_RETURN_IF_ERROR(status);
if (output_signature_size != function_def->signature().output_arg_size()) {
std::vector<const TensorSpecProto*> output_specs;
TF_RETURN_IF_ERROR(FlattenSignature(output_signature, &output_specs));
if (output_specs.size() != function_def->signature().output_arg_size()) {
return errors::FailedPrecondition(
"FunctionDef ", name, " has ",
function_def->signature().output_arg_size(),
" outputs, but the SavedConcreteFunction has ", output_signature_size,
" outputs, but the SavedConcreteFunction has ", output_specs.size(),
" flattened outputs.");
}
return status;
return Status();
}
} // namespace
@ -197,6 +154,62 @@ Status LoadTFConcreteFunction(
out);
}
Status FlattenSignature(const StructuredValue& signature,
std::vector<const TensorSpecProto*>* flattened_specs) {
// This follows the logic from
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775
switch (signature.kind_case()) {
case StructuredValue::kDictValue: {
// Dictionaries must be sorted in order of keys
const DictValue& dict = signature.dict_value();
std::vector<const StructuredValueDictEntry*> entries;
entries.reserve(dict.fields_size());
for (const auto& field : dict.fields()) {
entries.push_back(&field);
}
std::sort(entries.begin(), entries.end(),
[](const StructuredValueDictEntry* x,
const StructuredValueDictEntry* y) {
return x->first < y->first;
});
for (const auto& entry : entries) {
TF_RETURN_IF_ERROR(FlattenSignature(entry->second, flattened_specs));
}
return Status();
}
case StructuredValue::kTupleValue: {
const TupleValue& tuple = signature.tuple_value();
for (const StructuredValue& value : tuple.values()) {
TF_RETURN_IF_ERROR(FlattenSignature(value, flattened_specs));
}
return Status();
}
case StructuredValue::kListValue: {
const ListValue& list = signature.list_value();
for (const StructuredValue& value : list.values()) {
TF_RETURN_IF_ERROR(FlattenSignature(value, flattened_specs));
}
return Status();
}
case StructuredValue::kTensorSpecValue: {
flattened_specs->push_back(&signature.tensor_spec_value());
return Status();
}
case StructuredValue::kNoneValue: {
// Base case: do nothing.
// This arises, for example, as the top-level object of an output
// signature when there are no return values.
return Status();
}
default: {
return errors::Internal("Unhandled structured value kind ",
signature.kind_case());
}
}
}
const SavedObject* FindNodeAtPath(StringPiece path,
const SavedObjectGraph& object_graph) {
const auto& nodes = object_graph.nodes();

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
namespace internal {
@ -59,10 +60,17 @@ Status LoadTFConcreteFunction(
captured_objects,
ImmediateExecutionContext* ctx, std::unique_ptr<TFConcreteFunction>* out);
// Find the SavedObject in `object_graph` at location `path`. `path` must be a
// dot-delimited string of object names relative to the root object. If no
// object is found, returns nullptr. Callers must ensure `object_graph` outlives
// the returned pointer.
// Flattens `signature` into a vector of TensorSpecProto pointers back into
// `signature`. `signature` must outlive flattened_specs. `signature` must also
// be the input or output signature of a SavedConcreteFunction (i.e. "nested
// structures of tensorspecs").
Status FlattenSignature(const StructuredValue& signature,
std::vector<const TensorSpecProto*>* flattened_specs);
// Find the SavedObject in `object_graph` at location `path`. `path` must be
// a dot-delimited string of object names relative to the root object. If no
// object is found, returns nullptr. Callers must ensure `object_graph`
// outlives the returned pointer.
const SavedObject* FindNodeAtPath(StringPiece path,
const SavedObjectGraph& object_graph);

View File

@ -0,0 +1,133 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
namespace {
// Validates names, shapes, and dtypes of two tensorspecprotos are equivalent.
bool TensorSpecsAreEqual(const TensorSpecProto& spec,
const std::string& expected_name,
const PartialTensorShape& expected_shape,
DataType expected_dtype) {
return spec.name() == expected_name &&
PartialTensorShape(spec.shape()).IsIdenticalTo(expected_shape) &&
spec.dtype() == expected_dtype;
}
// This tests the common case for a tf.function w/o inputs. This ends up
// being serialized as a tuple of an empty tuple + empty dictionary
// (corresponding to the args, kwargs) of the function.
TEST(SignatureFlatteningTest, ZeroArgInputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::ZeroArgInputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 0);
}
// This tests the common case for a tf.function w/o outputs. This ends up
// being serialized as a "NoneValue".
TEST(SignatureFlatteningTest, ZeroRetOutputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::ZeroReturnOutputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 0);
}
TEST(SignatureFlatteningTest, SingleArgInputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::SingleArgInputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 1);
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0],
/* expected_name = */ "x",
/* expected_shape = */ {1, 10},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[0]->DebugString();
}
TEST(SignatureFlatteningTest, SingleReturnOutputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::SingleReturnOutputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 1);
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0],
/* expected_name = */ "",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[0]->DebugString();
}
TEST(SignatureFlatteningTest, ThreeArgInputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::ThreeArgInputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 3);
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0],
/* expected_name = */ "x",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[0]->DebugString();
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[1],
/* expected_name = */ "y",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[1]->DebugString();
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[2],
/* expected_name = */ "z",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[2]->DebugString();
}
// This test has an exotic outputsignature of tuple of a
// dictionary<string,tensor>, tensor
TEST(SignatureFlatteningTest, ThreeReturnOutputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::ThreeReturnOutputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 3);
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0],
/* expected_name = */ "0/a",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[0]->DebugString();
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[1],
/* expected_name = */ "0/b",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[1]->DebugString();
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[2],
/* expected_name = */ "1",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[2]->DebugString();
}
} // namespace
} // namespace tensorflow

View File

@ -47,6 +47,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/stringpiece.h"
@ -241,8 +242,11 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
// TODO(bmzhao): This requires using the newly added Save/Restore
// functions from
// https://github.com/tensorflow/tensorflow/commit/df6b21c13c82b5d0981642cfe18f10e60f78ea5c
return errors::Unimplemented(
"Restoring non-variable objects has not been implemented yet. ");
LOG(WARNING) << "Restoring non-variable objects has not been "
"implemented yet. (Kind="
<< bundle->saved_object_graph().nodes(node).kind_case()
<< ")";
return Status::OK();
}
Variable* variable =

View File

@ -38,8 +38,6 @@ cc_library(
":concrete_function_type",
":function_metadata",
":function_metadata_type",
":tensorhandle_list",
":tensorhandle_list_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status_internal",
"//tensorflow/c/eager:abstract_tensor_handle",
@ -167,38 +165,6 @@ cc_library(
],
)
cc_library(
name = "tensorhandle_list",
srcs = [
"tensorhandle_list.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:tensorhandle_list.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":tensorhandle_list_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
],
)
cc_library(
name = "tensorhandle_list_type",
hdrs = [
"tensorhandle_list_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
],
)
tf_cc_test(
name = "saved_model_api_test",
size = "small",
@ -216,7 +182,6 @@ tf_cc_test(
"//tensorflow/c/eager:c_api_test_util",
"//tensorflow/c/experimental/saved_model/public:concrete_function",
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
"//tensorflow/c/experimental/saved_model/public:tensorhandle_list",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",

View File

@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/status.h"

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/io/path.h"

View File

@ -24,7 +24,6 @@ exports_files(
"concrete_function_list.h",
"function_metadata.h",
"saved_model_api.h",
"tensorhandle_list.h",
],
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
)
@ -40,7 +39,6 @@ cc_library(
":concrete_function_list",
":function_metadata",
":saved_model_api",
":tensorhandle_list",
],
)
@ -63,8 +61,3 @@ alias(
name = "saved_model_api",
actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
)
alias(
name = "tensorhandle_list",
actual = "//tensorflow/c/experimental/saved_model/internal:tensorhandle_list",
)

View File

@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
// IWYU pragma: end_exports
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_

View File

@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
#ifdef __cplusplus
extern "C" {

View File

@ -1,43 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
#include <stddef.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type that is acts like a list of TF_ConcreteFunction pointers.
typedef struct TF_TensorHandleList TF_TensorHandleList;
// Returns the size of `list`.
TF_CAPI_EXPORT extern size_t TF_TensorHandleListSize(
const TF_TensorHandleList* list);
// Returns the `i`th TFE_TensorHandle in the list.
TF_CAPI_EXPORT extern TFE_TensorHandle* TF_TensorHandleListGet(
const TF_TensorHandleList* list, int i);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_

View File

@ -261,7 +261,6 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
size_t len, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
@ -311,3 +310,42 @@ TF_Tensor* TF_ForwardInputOrAllocateOutput(TF_OpKernelContext* context,
}
return tf_tensor_output;
}
TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
int64_t* dims, int num_dims,
TF_AllocatorAttributes* attributes,
TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
TF_SetStatus(status, TF_OK, "");
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
reinterpret_cast<tensorflow::int64*>(dims), num_dims);
if (attributes && !attributes->struct_size) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
"TF_AllocatorAttributes struct "
"size member must be set to TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE");
return nullptr;
}
tensorflow::AllocatorAttributes allocator_attr;
if (attributes && attributes->on_host) {
allocator_attr.set_on_host(true);
}
tensorflow::Status s;
tensorflow::Tensor tensor;
s = cc_ctx->allocate_temp(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimarray), &tensor,
allocator_attr);
if (!s.ok()) {
::tensorflow::Set_TF_Status_from_Status(status, s);
return nullptr;
}
TF_Tensor* tf_tensor;
tf_tensor = TF_TensorFromTensor(tensor, &s);
if (!s.ok()) {
::tensorflow::Set_TF_Status_from_Status(status, s);
return nullptr;
}
return tf_tensor;
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
// Macro to control visibility of exported symbols in the shared library (.so,
// .dylib, .dll).
@ -210,6 +211,15 @@ TF_CAPI_EXPORT TF_Tensor* TF_ForwardInputOrAllocateOutput(
int num_candidate_input_indices, int output_index, int64_t* output_dims,
int output_num_dims, int* forwarded_input, TF_Status* status);
// Allocates a temporary Tensor of the specified type and shape. The
// Tensor must not be used after kernel construction is
// complete.
//
// num_dims must equal the size of array dims
TF_CAPI_EXPORT extern TF_Tensor* TF_AllocateTemp(
TF_OpKernelContext* context, TF_DataType dtype, int64_t* dims, int num_dims,
TF_AllocatorAttributes* alloc_attrs, TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -24,6 +24,35 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "summary_op",
prefix = "summary_op",
deps = [
"//tensorflow/c:kernels",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_tensor",
"//tensorflow/c/kernels:tensor_shape_utils",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//third_party/eigen3",
],
)
tf_kernel_library(
name = "histogram_summary_op",
prefix = "histogram_summary_op",
deps = [
"//tensorflow/c:kernels",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_tensor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//third_party/eigen3",
],
)
tf_gen_op_libs(
op_lib_names = ["bitcast"],
deps = [
@ -35,6 +64,24 @@ tf_gen_op_libs(
],
)
tf_gen_op_libs(
op_lib_names = ["summary"],
deps = [
"//tensorflow/c:ops",
"//tensorflow/c:tf_status",
"//tensorflow/core:lib",
],
)
tf_gen_op_libs(
op_lib_names = ["histogram_summary"],
deps = [
"//tensorflow/c:ops",
"//tensorflow/c:tf_status",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "bitcast_op_test",
srcs = ["bitcast_op_test.cc"],
@ -48,6 +95,45 @@ tf_cc_test(
],
)
tf_cc_test(
name = "summary_op_test",
srcs = ["summary_op_test.cc"],
deps = [
":summary_op",
"//tensorflow/c:kernels",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "tensor_shape_utils",
srcs = ["tensor_shape_utils.cc"],
hdrs = ["tensor_shape_utils.h"],
visibility = ["//visibility:private"],
deps = [
"//tensorflow/c:tf_tensor",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "tensor_shape_utils_test",
srcs = ["tensor_shape_utils_test.cc"],
deps = [
":tensor_shape_utils",
"//tensorflow/c:tf_tensor_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# Changes to the Android srcs here should be replicated in
# tensorflow/contrib/makefile/tf_op_files.txt.
#
@ -59,11 +145,19 @@ filegroup(
name = "android_all_op_kernels",
srcs = [
"bitcast_op.cc",
"histogram_summary_op.cc",
"summary_op.cc",
"tensor_shape_utils.cc",
"tensor_shape_utils.h",
],
)
# LINT.ThenChange(//tensorflow/contrib/makefile/tf_op_files.txt)
filegroup(
name = "android_all_ops",
srcs = ["ops/bitcast.cc"],
srcs = [
"ops/bitcast.cc",
"ops/histogram_summary.cc",
"ops/summary.cc",
],
)

View File

@ -0,0 +1,163 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <sstream>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
namespace {
// Operators used to create a std::unique_ptr for TF_Tensor and TF_Status.
struct TFTensorDeleter {
void operator()(TF_Tensor* tf_tensor) const { TF_DeleteTensor(tf_tensor); }
};
struct TFStatusDeleter {
void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); }
};
// Struct that wraps TF_Tensor and TF_Status to delete once out of scope.
using Safe_TF_TensorPtr = std::unique_ptr<TF_Tensor, TFTensorDeleter>;
using Safe_TF_StatusPtr = std::unique_ptr<TF_Status, TFStatusDeleter>;
// Used to pass the operation node name from kernel construction to
// kernel computation.
struct HistogramSummaryOp {
std::string op_node_name;
};
void* HistogramSummaryOp_Create(TF_OpKernelConstruction* ctx) {
HistogramSummaryOp* kernel = new HistogramSummaryOp;
TF_StringView string_view_name = TF_OpKernelConstruction_GetName(ctx);
kernel->op_node_name =
std::string(string_view_name.data, string_view_name.len);
return kernel;
}
void HistogramSummaryOp_Delete(void* kernel) {
delete static_cast<HistogramSummaryOp*>(kernel);
}
template <typename T>
void HistogramSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
HistogramSummaryOp* k = static_cast<HistogramSummaryOp*>(kernel);
TF_Tensor* tags;
TF_Tensor* values;
Safe_TF_StatusPtr status(TF_NewStatus());
TF_GetInput(ctx, 0, &tags, status.get());
Safe_TF_TensorPtr safe_tags_ptr(tags);
if (TF_GetCode(status.get()) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
TF_GetInput(ctx, 1, &values, status.get());
Safe_TF_TensorPtr safe_values_ptr(values);
if (TF_GetCode(status.get()) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
if (TF_NumDims(safe_tags_ptr.get()) != 0) {
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, "tags must be scalar");
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
// Cast values to array to access tensor elements by index
auto values_array = static_cast<T*>(TF_TensorData(safe_values_ptr.get()));
tensorflow::histogram::Histogram histo;
for (int64_t i = 0; i < TF_TensorElementCount(safe_values_ptr.get()); ++i) {
const double double_val = static_cast<double>(values_array[i]);
if (Eigen::numext::isnan(double_val)) {
std::ostringstream err;
err << "Nan in summary histogram for: " << k->op_node_name;
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
return;
} else if (Eigen::numext::isinf(double_val)) {
std::ostringstream err;
err << "Infinity in Histogram for: " << k->op_node_name;
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
return;
}
histo.Add(double_val);
}
tensorflow::Summary s;
tensorflow::Summary::Value* v = s.add_value();
const tensorflow::tstring& tag =
*(static_cast<tensorflow::tstring*>(TF_TensorData(safe_tags_ptr.get())));
v->set_tag(tag.data(), tag.size());
histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
Safe_TF_TensorPtr summary_tensor(TF_AllocateOutput(
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_ExpectedOutputDataType(ctx, 0),
/*dims=*/nullptr, /*num_dims=*/0,
/*len=*/sizeof(tensorflow::tstring), status.get()));
if (TF_GetCode(status.get()) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
tensorflow::tstring* output_tstring = reinterpret_cast<tensorflow::tstring*>(
TF_TensorData(summary_tensor.get()));
CHECK(SerializeToTString(s, output_tstring));
}
template <typename T>
void RegisterHistogramSummaryOpKernel() {
TF_Status* status = TF_NewStatus();
{
auto* builder = TF_NewKernelBuilder(
"HistogramSummary", tensorflow::DEVICE_CPU, &HistogramSummaryOp_Create,
&HistogramSummaryOp_Compute<T>, &HistogramSummaryOp_Delete);
TF_KernelBuilder_TypeConstraint(
builder, "T",
static_cast<TF_DataType>(tensorflow::DataTypeToEnum<T>::v()), status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error while adding type constraint";
TF_RegisterKernelBuilder("HistogramSummary", builder, status);
CHECK_EQ(TF_OK, TF_GetCode(status))
<< "Error while registering Histogram Summmary kernel";
}
TF_DeleteStatus(status);
}
// A dummy static variable initialized by a lambda whose side-effect is to
// register the Histogram Summary kernel.
TF_ATTRIBUTE_UNUSED static bool IsHistogramSummaryOpKernelRegistered = []() {
if (SHOULD_REGISTER_OP_KERNEL("HistogramSummary")) {
RegisterHistogramSummaryOpKernel<tensorflow::int64>();
RegisterHistogramSummaryOpKernel<tensorflow::uint64>();
RegisterHistogramSummaryOpKernel<tensorflow::int32>();
RegisterHistogramSummaryOpKernel<tensorflow::uint32>();
RegisterHistogramSummaryOpKernel<tensorflow::uint16>();
RegisterHistogramSummaryOpKernel<tensorflow::int16>();
RegisterHistogramSummaryOpKernel<tensorflow::int8>();
RegisterHistogramSummaryOpKernel<tensorflow::uint8>();
RegisterHistogramSummaryOpKernel<Eigen::half>();
RegisterHistogramSummaryOpKernel<tensorflow::bfloat16>();
RegisterHistogramSummaryOpKernel<float>();
RegisterHistogramSummaryOpKernel<double>();
}
return true;
}();
} // namespace

View File

@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/ops.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
static void histogram_summary_shape_inference_fn(TF_ShapeInferenceContext* ctx,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
TF_ShapeHandle* result = TF_ShapeInferenceContextScalar(ctx);
TF_ShapeInferenceContextSetOutput(ctx, 0, result, status);
TF_DeleteShapeHandle(result);
}
void Register_HistogramSummaryOp() {
TF_Status* status = TF_NewStatus();
TF_OpDefinitionBuilder* op_builder =
TF_NewOpDefinitionBuilder("HistogramSummary");
TF_OpDefinitionBuilderAddInput(op_builder, "tag: string");
TF_OpDefinitionBuilderAddInput(op_builder, "values: T");
TF_OpDefinitionBuilderAddOutput(op_builder, "summary: string");
TF_OpDefinitionBuilderAddAttr(op_builder, "T: realnumbertype = DT_FLOAT");
TF_OpDefinitionBuilderSetShapeInferenceFunction(
op_builder, &histogram_summary_shape_inference_fn);
TF_RegisterOpDefinition(op_builder, status);
CHECK_EQ(TF_GetCode(status), TF_OK)
<< "HistogramSummary op registration failed: " << TF_Message(status);
TF_DeleteStatus(status);
}
TF_ATTRIBUTE_UNUSED static bool HistogramSummaryOpRegistered = []() {
if (SHOULD_REGISTER_OP("HistogramSummary")) {
Register_HistogramSummaryOp();
}
return true;
}();

View File

@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/ops.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
static void scalar_summary_shape_inference_fn(TF_ShapeInferenceContext* ctx,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
TF_ShapeHandle* result = TF_ShapeInferenceContextScalar(ctx);
TF_ShapeInferenceContextSetOutput(ctx, 0, result, status);
TF_DeleteShapeHandle(result);
}
void Register_ScalarSummaryOp() {
TF_Status* status = TF_NewStatus();
TF_OpDefinitionBuilder* op_builder =
TF_NewOpDefinitionBuilder("ScalarSummary");
TF_OpDefinitionBuilderAddInput(op_builder, "tags: string");
TF_OpDefinitionBuilderAddInput(op_builder, "values: T");
TF_OpDefinitionBuilderAddOutput(op_builder, "summary: string");
TF_OpDefinitionBuilderAddAttr(op_builder, "T: realnumbertype");
TF_OpDefinitionBuilderSetShapeInferenceFunction(
op_builder, &scalar_summary_shape_inference_fn);
TF_RegisterOpDefinition(op_builder, status);
CHECK_EQ(TF_GetCode(status), TF_OK)
<< "ScalarSummary op registration failed: " << TF_Message(status);
TF_DeleteStatus(status);
}
TF_ATTRIBUTE_UNUSED static bool SummaryScalarOpRegistered = []() {
if (SHOULD_REGISTER_OP("ScalarSummary")) {
Register_ScalarSummaryOp();
}
return true;
}();

View File

@ -0,0 +1,172 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <sstream>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/kernels/tensor_shape_utils.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
namespace {
// Struct that stores the status and TF_Tensor inputs to the opkernel.
// Used to delete tensor and status in its destructor upon kernel return.
struct Params {
TF_Tensor* tags;
TF_Tensor* values;
TF_Status* status;
explicit Params(TF_OpKernelContext* ctx)
: tags(nullptr), values(nullptr), status(nullptr) {
status = TF_NewStatus();
TF_GetInput(ctx, 0, &tags, status);
if (TF_GetCode(status) == TF_OK) {
TF_GetInput(ctx, 1, &values, status);
}
}
~Params() {
TF_DeleteStatus(status);
TF_DeleteTensor(tags);
TF_DeleteTensor(values);
}
};
// dummy functions used for kernel registration
void* ScalarSummaryOp_Create(TF_OpKernelConstruction* ctx) { return nullptr; }
void ScalarSummaryOp_Delete(void* kernel) {}
// Helper functions for compute method
bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2);
// Returns a string representation of a single tag or empty string if there
// are multiple tags
std::string SingleTag(TF_Tensor* tags);
template <typename T>
void ScalarSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
Params params(ctx);
if (TF_GetCode(params.status) != TF_OK) {
TF_OpKernelContext_Failure(ctx, params.status);
return;
}
if (!IsSameSize(params.tags, params.values)) {
std::ostringstream err;
err << "tags and values are not the same shape: "
<< tensorflow::ShapeDebugString(params.tags)
<< " != " << tensorflow::ShapeDebugString(params.values)
<< SingleTag(params.tags);
TF_SetStatus(params.status, TF_INVALID_ARGUMENT, err.str().c_str());
TF_OpKernelContext_Failure(ctx, params.status);
return;
}
// Convert tags and values tensor to array to access elements by index
tensorflow::Summary s;
auto tags_array =
static_cast<tensorflow::tstring*>(TF_TensorData(params.tags));
auto values_array = static_cast<T*>(TF_TensorData(params.values));
// Copy tags and values into summary protobuf
for (int i = 0; i < TF_TensorElementCount(params.tags); ++i) {
tensorflow::Summary::Value* v = s.add_value();
const tensorflow::tstring& Ttags_i = tags_array[i];
v->set_tag(Ttags_i.data(), Ttags_i.size());
v->set_simple_value(static_cast<float>(values_array[i]));
}
TF_Tensor* summary_tensor =
TF_AllocateOutput(ctx, 0, TF_ExpectedOutputDataType(ctx, 0), nullptr, 0,
sizeof(tensorflow::tstring), params.status);
if (TF_GetCode(params.status) != TF_OK) {
TF_DeleteTensor(summary_tensor);
TF_OpKernelContext_Failure(ctx, params.status);
return;
}
tensorflow::tstring* output_tstring =
reinterpret_cast<tensorflow::tstring*>(TF_TensorData(summary_tensor));
CHECK(SerializeToTString(s, output_tstring));
TF_DeleteTensor(summary_tensor);
}
bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2) {
if (TF_NumDims(tensor1) != TF_NumDims(tensor2)) {
return false;
}
for (int d = 0; d < TF_NumDims(tensor1); d++) {
if (TF_Dim(tensor1, d) != TF_Dim(tensor2, d)) {
return false;
}
}
return true;
}
std::string SingleTag(TF_Tensor* tags) {
if (TF_TensorElementCount(tags) == 1) {
const char* single_tag =
static_cast<tensorflow::tstring*>(TF_TensorData(tags))->c_str();
return tensorflow::strings::StrCat(" (tag '", single_tag, "')");
} else {
return "";
}
}
template <typename T>
void RegisterScalarSummaryOpKernel() {
TF_Status* status = TF_NewStatus();
{
auto* builder = TF_NewKernelBuilder(
"ScalarSummary", tensorflow::DEVICE_CPU, &ScalarSummaryOp_Create,
&ScalarSummaryOp_Compute<T>, &ScalarSummaryOp_Delete);
TF_KernelBuilder_TypeConstraint(
builder, "T",
static_cast<TF_DataType>(tensorflow::DataTypeToEnum<T>::v()), status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error while adding type constraint";
TF_RegisterKernelBuilder("ScalarSummary", builder, status);
CHECK_EQ(TF_OK, TF_GetCode(status))
<< "Error while registering Scalar Summmary kernel";
}
TF_DeleteStatus(status);
}
// A dummy static variable initialized by a lambda whose side-effect is to
// register the ScalarSummary kernel.
TF_ATTRIBUTE_UNUSED bool IsScalarSummaryOpKernelRegistered = []() {
if (SHOULD_REGISTER_OP_KERNEL("ScalarSummary")) {
RegisterScalarSummaryOpKernel<tensorflow::int64>();
RegisterScalarSummaryOpKernel<tensorflow::uint64>();
RegisterScalarSummaryOpKernel<tensorflow::int32>();
RegisterScalarSummaryOpKernel<tensorflow::uint32>();
RegisterScalarSummaryOpKernel<tensorflow::uint16>();
RegisterScalarSummaryOpKernel<tensorflow::int16>();
RegisterScalarSummaryOpKernel<tensorflow::int8>();
RegisterScalarSummaryOpKernel<tensorflow::uint8>();
RegisterScalarSummaryOpKernel<Eigen::half>();
RegisterScalarSummaryOpKernel<tensorflow::bfloat16>();
RegisterScalarSummaryOpKernel<float>();
RegisterScalarSummaryOpKernel<double>();
}
return true;
}();
} // namespace

View File

@ -0,0 +1,186 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/kernels.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
namespace tensorflow {
namespace {
class DummyDevice : public DeviceBase {
public:
explicit DummyDevice(Env* env) : DeviceBase(env) {}
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
return cpu_allocator();
}
};
// Helper for comparing ouput and expected output
void ExpectSummaryMatches(const Summary& actual, const string& expected_str) {
Summary expected;
ASSERT_TRUE(protobuf::TextFormat::ParseFromString(expected_str, &expected));
EXPECT_EQ(expected.DebugString(), actual.DebugString());
}
void TestScalarSummaryOp(Tensor* tags, Tensor* values, string expected_output,
error::Code expected_code) {
// Initialize node used to fetch OpKernel
Status status;
NodeDef def;
def.set_op("ScalarSummary");
def.set_device(DEVICE_CPU);
AttrValue valuesTypeAttr;
SetAttrValue(values->dtype(), &valuesTypeAttr);
(*def.mutable_attr())["T"] = valuesTypeAttr;
def.add_input(strings::StrCat("input1: ", DataTypeString(tags->dtype())));
def.add_input(strings::StrCat("input2: ", DataTypeString(values->dtype())));
std::unique_ptr<OpKernel> kernel =
CreateOpKernel(DeviceType(DEVICE_CPU), nullptr, nullptr, def, 1, &status);
ASSERT_TRUE(status.ok()) << status.ToString();
OpKernelContext::Params params;
DummyDevice dummy_device(nullptr);
params.device = &dummy_device;
params.op_kernel = kernel.get();
AllocatorAttributes alloc_attrs;
params.output_attr_array = &alloc_attrs;
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.emplace_back(tags);
inputs.emplace_back(values);
params.inputs = &inputs;
OpKernelContext ctx(&params, 1);
kernel->Compute(&ctx);
ASSERT_EQ(expected_code, ctx.status().code());
if (expected_code == error::OK) {
Summary summary;
ASSERT_TRUE(ParseProtoUnlimited(
&summary, ctx.mutable_output(0)->scalar<tstring>()()));
ExpectSummaryMatches(summary, expected_output);
} else {
EXPECT_TRUE(absl::StrContains(ctx.status().ToString(), expected_output))
<< ctx.status();
}
}
TEST(ScalarSummaryOpTest, SimpleFloat) {
int vectorSize = 3;
Tensor tags(DT_STRING, {vectorSize});
Tensor values(DT_FLOAT, {vectorSize});
tags.vec<tstring>()(0) = "tag1";
tags.vec<tstring>()(1) = "tag2";
tags.vec<tstring>()(2) = "tag3";
values.vec<float>()(0) = 1.0f;
values.vec<float>()(1) = -0.73f;
values.vec<float>()(2) = 10000.0f;
TestScalarSummaryOp(&tags, &values, R"(
value { tag: 'tag1' simple_value: 1.0 }
value { tag: 'tag2' simple_value: -0.73}
value { tag: 'tag3' simple_value: 10000.0})",
error::OK);
}
TEST(ScalarSummaryOpTest, SimpleDouble) {
int vectorSize = 3;
Tensor tags(DT_STRING, {vectorSize});
Tensor values(DT_DOUBLE, {vectorSize});
tags.vec<tstring>()(0) = "tag1";
tags.vec<tstring>()(1) = "tag2";
tags.vec<tstring>()(2) = "tag3";
values.vec<double>()(0) = 1.0;
values.vec<double>()(1) = -0.73;
values.vec<double>()(2) = 10000.0;
TestScalarSummaryOp(&tags, &values, R"(
value { tag: 'tag1' simple_value: 1.0 }
value { tag: 'tag2' simple_value: -0.73}
value { tag: 'tag3' simple_value: 10000.0})",
error::OK);
}
TEST(ScalarSummaryOpTest, SimpleHalf) {
int vectorSize = 3;
Tensor tags(DT_STRING, {vectorSize});
Tensor values(DT_HALF, {vectorSize});
tags.vec<tstring>()(0) = "tag1";
tags.vec<tstring>()(1) = "tag2";
tags.vec<tstring>()(2) = "tag3";
values.vec<Eigen::half>()(0) = Eigen::half(1.0);
values.vec<Eigen::half>()(1) = Eigen::half(-2.0);
values.vec<Eigen::half>()(2) = Eigen::half(10000.0);
TestScalarSummaryOp(&tags, &values, R"(
value { tag: 'tag1' simple_value: 1.0 }
value { tag: 'tag2' simple_value: -2.0}
value { tag: 'tag3' simple_value: 10000.0})",
error::OK);
}
TEST(ScalarSummaryOpTest, Error_WrongDimsTags) {
Tensor tags(DT_STRING, {2, 1});
Tensor values(DT_FLOAT, {2});
tags.matrix<tstring>()(0, 0) = "tag1";
tags.matrix<tstring>()(1, 0) = "tag2";
values.vec<float>()(0) = 1.0f;
values.vec<float>()(1) = -2.0f;
TestScalarSummaryOp(&tags, &values, "tags and values are not the same shape",
error::INVALID_ARGUMENT);
}
TEST(ScalarSummaryOpTest, Error_WrongValuesTags) {
Tensor tags(DT_STRING, {2});
Tensor values(DT_FLOAT, {2, 1});
tags.vec<tstring>()(0) = "tag1";
tags.vec<tstring>()(1) = "tag2";
values.matrix<float>()(0, 0) = 1.0f;
values.matrix<float>()(1, 0) = -2.0f;
TestScalarSummaryOp(&tags, &values, "tags and values are not the same shape",
error::INVALID_ARGUMENT);
}
TEST(ScalarSummaryOpTest, Error_WrongWithSingleTag) {
Tensor tags(DT_STRING, {1});
Tensor values(DT_FLOAT, {2, 1});
tags.vec<tstring>()(0) = "tag1";
values.matrix<float>()(0, 0) = 1.0f;
values.matrix<float>()(1, 0) = -2.0f;
TestScalarSummaryOp(&tags, &values, "tags and values are not the same shape",
error::INVALID_ARGUMENT);
}
TEST(ScalarSummaryOpTest, IsRegistered) {
const OpRegistrationData* reg;
TF_CHECK_OK(OpRegistry::Global()->LookUp("ScalarSummary", &reg));
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/kernels/tensor_shape_utils.h"
#include <string>
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
namespace tensorflow {
std::string ShapeDebugString(TF_Tensor* tensor) {
// A TF_Tensor cannot have an unknown rank.
CHECK_GE(TF_NumDims(tensor), 0);
tensorflow::string s = "[";
for (int i = 0; i < TF_NumDims(tensor); ++i) {
if (i > 0) tensorflow::strings::StrAppend(&s, ",");
int64_t dim = TF_Dim(tensor, i);
// A TF_Tensor cannot have an unknown dimension.
CHECK_GE(dim, 0);
tensorflow::strings::StrAppend(&s, dim);
}
tensorflow::strings::StrAppend(&s, "]");
return s;
}
} // namespace tensorflow

View File

@ -13,25 +13,25 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
// This file contains shape utilities to be used by kernels and is not part of
// the C API. As such, it is subject to change at any time.
#include <vector>
#ifndef TENSORFLOW_C_TENSOR_SHAPE_UTILS_H_
#define TENSORFLOW_C_TENSOR_SHAPE_UTILS_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include <string>
// Internal structures used by the SavedModel C API. These are likely to
// change and should not be depended on.
typedef struct TF_TensorHandleList TF_TensorHandleList;
#include "tensorflow/c/tf_tensor.h"
namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(
std::vector<tensorflow::ImmediateExecutionTensorHandle*>,
TF_TensorHandleList)
// The following are utils for the shape of a TF_Tensor type.
// These functions may later be subsumed by the methods for a
// TF_TensorShape type.
// Returns a string representation of the TF_Tensor shape.
std::string ShapeDebugString(TF_Tensor* tensor);
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_
#endif // TENSORFLOW_C_TENSOR_SHAPE_UTILS_H_

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/kernels/tensor_shape_utils.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
// A wrapper that will automatically delete the allocated TF_Tensor
// once out of scope.
struct TF_TensorWrapper {
TF_Tensor* tf_tensor;
explicit TF_TensorWrapper(TF_Tensor* tensor) { tf_tensor = tensor; }
~TF_TensorWrapper() { TF_DeleteTensor(tf_tensor); }
};
void TestShapeMatch(TensorShape shape) {
Tensor tensor(DT_FLOAT, shape);
Status status;
TF_Tensor* tf_tensor = TF_TensorFromTensor(tensor, &status);
TF_TensorWrapper tensor_wrapper = TF_TensorWrapper(tf_tensor);
ASSERT_TRUE(status.ok()) << status.ToString();
ASSERT_EQ(tensor.shape().DebugString(), ShapeDebugString(tf_tensor));
}
TEST(ShapeDebugString, RegularShape) { TestShapeMatch(TensorShape({5, 4, 7})); }
TEST(ShapeDebugString, ScalarShape) { TestShapeMatch(TensorShape({})); }
} // namespace
} // namespace tensorflow

View File

@ -368,6 +368,16 @@ class DeviceKernelOpTest : public OpsTestBase {
#endif
};
// Validates that the tensor has shape and type corresponding to
// dims and dtype.
void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
TF_DataType dtype);
// Copies data of length tensor_size_bytes from values to tensor.
template <typename T>
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
TF_OpKernelContext* ctx);
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
@ -379,22 +389,11 @@ TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
TF_Tensor* output = TF_AllocateOutput(
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
/*num_dims=*/1, /*len=*/tensor_size_bytes, s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
EXPECT_EQ(1, TF_NumDims(output));
EXPECT_EQ(1, TF_Dim(output, 0));
validate_tensor(output, &dim, 1, TF_FLOAT);
// Set output to 3
float* data = reinterpret_cast<float*>(TF_TensorData(output));
float value = 3.0f;
#if GOOGLE_CUDA
OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, &value,
tensor_size_bytes);
#else
*data = value;
#endif
float values[1] = {3.0f};
set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
@ -417,12 +416,8 @@ TEST_F(DeviceKernelOpTest, TestAllocateEmptyOutput) {
TF_Tensor* output = TF_AllocateOutput(
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
/*num_dims=*/1, /*len=*/0, s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
EXPECT_EQ(1, TF_NumDims(output));
EXPECT_EQ(0, TF_Dim(output, 0));
validate_tensor(output, &dim, 1, TF_FLOAT);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
@ -442,27 +437,16 @@ TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) {
TF_Status* s = TF_NewStatus();
// Allocate 2x3 output
int64_t dim[2] = {2, 3};
size_t tensor_size_bytes = 6 * TF_DataTypeSize(TF_FLOAT);
size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT) * 6;
TF_Tensor* output = TF_AllocateOutput(
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/dim,
/*num_dims=*/2, /*len=*/tensor_size_bytes, s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
EXPECT_EQ(2, TF_NumDims(output));
EXPECT_EQ(2, TF_Dim(output, 0));
EXPECT_EQ(3, TF_Dim(output, 1));
validate_tensor(output, dim, 2, TF_FLOAT);
// Set output to [1 2 3 4 5 6]
void* data = TF_TensorData(output);
float value[6] = {1, 2, 3, 4, 5, 6};
#if GOOGLE_CUDA
OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, value,
tensor_size_bytes);
#else
memcpy(data, value, tensor_size_bytes);
#endif
float values[6] = {1, 2, 3, 4, 5, 6};
set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
@ -475,6 +459,112 @@ TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) {
output->DebugString(100));
}
REGISTER_OP("AllocateTempOp1").Output("output1: float");
TEST_F(DeviceKernelOpTest, TestAllocateTempSizeOne) {
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
// Allocate scalar TF_Tensor
TF_Status* s = TF_NewStatus();
int64_t dim = 1;
TF_AllocatorAttributes alloc_attrs;
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA
alloc_attrs.on_host = 0;
#else
alloc_attrs.on_host = 1;
#endif
TF_Tensor* output = TF_AllocateTemp(
/*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
/*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s);
size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT);
EXPECT_EQ(TF_OK, TF_GetCode(s));
validate_tensor(output, &dim, 1, TF_FLOAT);
// Set TF_Tensor value to 3
float values[1] = {3.0f};
set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
TF_SetOutput(ctx, 0, output, s);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
SetupOp("AllocateTempOp1", "AllocateTemp1", my_compute_func);
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
EXPECT_EQ("Tensor<type: float shape: [1] values: 3>",
output->DebugString(100));
}
REGISTER_OP("AllocateTempOp0").Output("output1: float");
TEST_F(DeviceKernelOpTest, TestAllocateTempEmpty) {
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
TF_Status* s = TF_NewStatus();
// Allocate empty TF_Tensor
int64_t dim = 0;
TF_AllocatorAttributes alloc_attrs;
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA
alloc_attrs.on_host = 0;
#else
alloc_attrs.on_host = 1;
#endif
TF_Tensor* output = TF_AllocateTemp(
/*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
/*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
validate_tensor(output, &dim, 1, TF_FLOAT);
TF_SetOutput(ctx, 0, output, s);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
SetupOp("AllocateTempOp0", "AllocateTemp0", my_compute_func);
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
EXPECT_EQ("Tensor<type: float shape: [0] values: >",
output->DebugString(100));
}
REGISTER_OP("AllocateTempOp2x3").Output("output1: float");
TEST_F(DeviceKernelOpTest, TestAllocateTempSize2x3) {
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
TF_Status* s = TF_NewStatus();
size_t tensor_size_bytes = 6 * TF_DataTypeSize(TF_FLOAT);
// Allocate 2x3 TF_Tensor
int64_t dim[2] = {2, 3};
TF_AllocatorAttributes alloc_attrs;
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA
alloc_attrs.on_host = 0;
#else
alloc_attrs.on_host = 1;
#endif
TF_Tensor* output = TF_AllocateTemp(
/*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/dim,
/*num_dims=*/2, /*allocator_attributes*/ &alloc_attrs, s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
validate_tensor(output, dim, 2, TF_FLOAT);
// Set TF_Tensor values to [1 2 3 4 5 6]
float values[6] = {1, 2, 3, 4, 5, 6};
set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
TF_SetOutput(ctx, 0, output, s);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
SetupOp("AllocateTempOp2x3", "AllocateTempOp2x3", my_compute_func);
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
output->DebugString(100));
}
TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) {
const char* node_name = "TestForwardInputOrAllocateOutputKernel";
const char* op_name = "BazOp";
@ -484,7 +574,7 @@ TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) {
.Input("input1: float")
.Input("input2: float")
.Output("output1: float")
.Attr("SomeDataTypeAttr: type");;
.Attr("SomeDataTypeAttr: type");
// A kernel whose Compute function that forwards a scalar input to output
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
@ -501,6 +591,7 @@ TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) {
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
EXPECT_EQ(0, TF_NumDims(output));
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
@ -540,4 +631,26 @@ TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) {
ASSERT_EQ(123, ctx.mutable_output(0)->scalar<float>()());
}
}
void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
TF_DataType dtype) {
EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor));
EXPECT_EQ(num_dims, TF_NumDims(tensor));
for (int i = 0; i < num_dims; ++i) {
EXPECT_EQ(dims[i], TF_Dim(tensor, i));
}
}
template <typename T>
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
TF_OpKernelContext* ctx) {
T* data = reinterpret_cast<T*>(TF_TensorData(tensor));
#if GOOGLE_CUDA
OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, values,
tensor_size_bytes);
#else
memcpy(data, values, tensor_size_bytes);
#endif
}
} // namespace tensorflow

59
tensorflow/c/logging.cc Normal file
View File

@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/logging.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stringprintf.h"
static ::tensorflow::string BuildMessage(const char* fmt, va_list args) {
::tensorflow::string message;
::tensorflow::strings::Appendv(&message, fmt, args);
return message;
}
void TF_Log(TF_LogLevel level, const char* fmt, ...) {
if (level < TF_INFO || level > TF_FATAL) return;
va_list args;
va_start(args, fmt);
auto message = BuildMessage(fmt, args);
switch (level) {
case TF_INFO:
LOG(INFO) << message;
break;
case TF_WARNING:
LOG(WARNING) << message;
break;
case TF_ERROR:
LOG(ERROR) << message;
break;
case TF_FATAL:
LOG(FATAL) << message;
break;
}
}
void TF_VLog(int level, const char* fmt, ...) {
va_list args;
va_start(args, fmt);
auto message = BuildMessage(fmt, args);
VLOG(level) << message;
}
void TF_DVLog(int level, const char* fmt, ...) {
va_list args;
va_start(args, fmt);
auto message = BuildMessage(fmt, args);
DVLOG(level) << message;
}

View File

@ -12,25 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_LOGGING_H_
#define TENSORFLOW_C_LOGGING_H_
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
#include "tensorflow/c/c_api_macros.h"
#include <stddef.h>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
// --------------------------------------------------------------------------
// C API for tensorflow::Logging.
#ifdef __cplusplus
extern "C" {
#endif
size_t TF_TensorHandleListSize(const TF_TensorHandleList* list) {
return tensorflow::unwrap(list)->size();
typedef enum TF_LogLevel {
TF_INFO = 0,
TF_WARNING = 1,
TF_ERROR = 2,
TF_FATAL = 3,
} TF_LogLevel;
TF_CAPI_EXPORT extern void TF_Log(TF_LogLevel level, const char* fmt, ...);
TF_CAPI_EXPORT extern void TF_VLog(int level, const char* fmt, ...);
TF_CAPI_EXPORT extern void TF_DVLog(int level, const char* fmt, ...);
#ifdef __cplusplus
}
#endif
TFE_TensorHandle* TF_TensorHandleListGet(const TF_TensorHandleList* list,
int i) {
return tensorflow::wrap((*tensorflow::unwrap(list))[i]);
}
} // end extern "C"
#endif // TENSORFLOW_C_LOGGING_H_

View File

@ -288,7 +288,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
if (!tensor.CopyFrom(src, src.shape())) {
return nullptr;
}
return new TF_Tensor{new tensorflow::TensorInterface(tensor)};
return new TF_Tensor{new tensorflow::TensorInterface(std::move(tensor))};
}
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <stdbool.h>
#include <stdint.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
@ -45,6 +46,16 @@ limitations under the License.
extern "C" {
#endif
// Allocator Attributes used for tensor allocation.
typedef struct TF_AllocatorAttributes {
size_t struct_size;
// Set boolean to 1 for CPU allocation, else 0.
TF_Bool on_host;
} TF_AllocatorAttributes;
#define TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE \
TF_OFFSET_OF_END(TF_AllocatorAttributes, on_host)
// --------------------------------------------------------------------------
// TF_Tensor holds a multi-dimensional array of elements of a single data type.
// For all types other than TF_STRING, the data buffer stores elements

View File

@ -558,6 +558,7 @@ tf_gen_op_wrappers_cc(
"io_ops",
"linalg_ops",
"list_ops",
"map_ops",
"logging_ops",
"lookup_ops",
"manip_ops",

View File

@ -128,22 +128,6 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "xla_interpreter_device",
srcs = ["xla_interpreter_device.cc"],
visibility = [":friends"],
deps = [
":jit_compilation_passes",
":xla_device",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep
"@com_google_absl//absl/memory",
],
alwayslink = 1,
)
cc_library(
name = "xla_tensor",
srcs = ["xla_tensor.cc"],
@ -211,6 +195,7 @@ XLA_DEVICE_DEPS = [
"//tensorflow/core/kernels/data:optional_ops",
"//tensorflow/core/kernels/data:prefetch_dataset_op",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor:tf_allocator_adapter",
"//tensorflow/stream_executor/platform",
]
@ -221,16 +206,19 @@ cc_library(
"xla_device.cc",
"xla_device_context.cc",
"xla_device_ops.cc",
"xla_ops_on_regular_devices.cc",
"xla_platform_info.cc",
],
hdrs = [
"xla_compile_on_demand_op.h",
"xla_device.h",
"xla_device_context.h",
"xla_device_ops.h",
"xla_platform_info.h",
],
# Public visibility is needed for external TF/XLA backends.
visibility = ["//visibility:public"],
deps = XLA_DEVICE_DEPS,
deps = XLA_DEVICE_DEPS + [":xla_compilation_cache"],
)
cc_library(
@ -394,20 +382,23 @@ cc_library(
alwayslink = 1,
)
# Linked by tensorflow core, without registration of jit compilation passes
# which is not necessary to create and run a XlaLocalLaunchBase kernel.
# Linking jit compilation passes could cause programs stuck right now (b/140069592).
cc_library(
name = "xla_kernel_creator_util",
name = "xla_kernel_creator",
srcs = [
"xla_kernel_creator_util.cc",
"xla_kernel_creator.cc",
"xla_kernel_creator.h",
],
visibility = [
":internal",
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
"//tensorflow/core/common_runtime/eager:__pkg__",
],
hdrs = ["xla_kernel_creator_util.h"],
visibility = ["//tensorflow/core/common_runtime/eager:__pkg__"],
deps = [
":common",
":compilability_check_util",
":compilation_passes",
":flags",
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_op_registry",
@ -422,25 +413,6 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "xla_kernel_creator",
srcs = [
"xla_kernel_creator.cc",
"xla_kernel_creator.h",
],
deps = [
":compilability_check_util",
":flags",
":jit_compilation_passes",
":xla_kernel_creator_util",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
alwayslink = 1,
)
tf_cc_test(
name = "xla_kernel_creator_test",
srcs = [

View File

@ -159,7 +159,7 @@ void AllocateAndParseFlags() {
device_flags = new XlaDeviceFlags;
device_flags->tf_xla_compile_on_demand = false;
device_flags->tf_xla_enable_xla_devices = true;
device_flags->tf_xla_enable_xla_devices = false;
ops_flags = new XlaOpsCommonFlags;
ops_flags->tf_xla_always_defer_compilation = false;
@ -268,10 +268,4 @@ void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
AppendMarkForCompilationPassFlagsInternal(flag_list);
}
static bool xla_is_enabled = false;
void SetXlaIsEnabled() { xla_is_enabled = true; }
bool IsXlaEnabled() { return xla_is_enabled; }
} // namespace tensorflow

View File

@ -162,14 +162,6 @@ MlirCommonFlags* GetMlirCommonFlags();
void AppendMarkForCompilationPassFlags(
std::vector<tensorflow::Flag>* flag_list);
// Makes all future calls to `IsXlaEnabled()` return `true`.
//
// Should only be called when XLA is linked in.
void SetXlaIsEnabled();
// Returns whether XLA is enabled.
bool IsXlaEnabled();
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@ -63,38 +64,6 @@ namespace tensorflow {
namespace {
XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
DeviceType device_type = ctx->device_type();
se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr;
se::DeviceMemoryAllocator* custom_allocator = nullptr;
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
platform_id = se::host::kHostPlatformId;
} else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
platform_id = ctx->device()
->tensorflow_gpu_device_info()
->stream->parent()
->platform()
->id();
} else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
// If we are on an XlaDevice, use the underlying XLA platform's allocator
// directly. We could use the StreamExecutor's allocator which may
// theoretically be more correct, but XLA returns a nice OOM message in a
// Status and StreamExecutor does not.
//
// Importantly we can't use ctx->device()->GetAllocator() as the allocator
// (which xla_allocator above uses) as on an XlaDevice, this is a dummy
// allocator that returns XlaTensor objects. The XlaCompiler needs a real
// allocator to allocate real buffers.
platform_id = xla_device_metadata->platform()->id();
custom_allocator =
xla_device_metadata->client()->backend().memory_allocator();
}
return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
custom_allocator);
}
// A closure describing how to run a compiled version of a TensorFlow function.
//
@ -178,31 +147,6 @@ class XlaExecutableClosureStore {
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};
// Return allocator from platform info if non-null, or populate and return a
// pointer to the allocator adapter with allocator from context.
//
// This is necessary because for XLA devices the underlying TF allocator returns
// dummy tensors.
se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
OpKernelContext* ctx, const XlaPlatformInfo& platform_info) {
if (platform_info.custom_allocator()) {
return platform_info.custom_allocator();
}
if (!ctx->op_device_context()) {
// Stream is not set for the host platform.
se::Platform* platform =
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
.ValueOrDie();
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform);
return &tf_allocator_adapter->value();
}
// platform_info.
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}),
ctx->op_device_context()->stream());
return &tf_allocator_adapter->value();
}
} // namespace
XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
@ -214,65 +158,9 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
constants_(constants),
resources_(resources),
function_(function),
platform_info_(PlatformInfoFromContext(ctx)),
platform_info_(XlaPlatformInfoFromContext(ctx)),
has_ref_vars_(has_ref_vars) {}
static Status BuildCompilationCache(OpKernelContext* ctx,
const XlaPlatformInfo& platform_info,
XlaCompilationCache** cache) {
if (platform_info.xla_device_metadata()) {
*cache = new XlaCompilationCache(
platform_info.xla_device_metadata()->client(),
platform_info.xla_device_metadata()->jit_device_type());
return Status::OK();
}
auto platform =
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
if (!platform.ok()) {
return platform.status();
}
xla::StatusOr<xla::Compiler*> compiler_for_platform =
xla::Compiler::GetForPlatform(platform.ValueOrDie());
if (!compiler_for_platform.ok()) {
// In some rare cases (usually in unit tests with very small clusters) we
// may end up transforming an XLA cluster with at least one GPU operation
// (which would normally force the cluster to be compiled using XLA:GPU)
// into an XLA cluster with no GPU operations (i.e. containing only CPU
// operations). Such a cluster can fail compilation (in way that
// MarkForCompilation could not have detected) if the CPU JIT is not linked
// in.
//
// So bail out of _XlaCompile in this case, and let the executor handle the
// situation for us.
const Status& status = compiler_for_platform.status();
if (status.code() == error::NOT_FOUND) {
return errors::Unimplemented("Could not find compiler for platform ",
platform.ValueOrDie()->Name(), ": ",
status.ToString());
}
}
xla::LocalClientOptions client_options;
client_options.set_platform(platform.ValueOrDie());
client_options.set_intra_op_parallelism_threads(
ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
if (!client.ok()) {
return client.status();
}
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
&registration)) {
return errors::InvalidArgument("No JIT device registered for ",
platform_info.device_type().type());
}
*cache = new XlaCompilationCache(
client.ValueOrDie(), DeviceType(registration->compilation_device_name));
return Status::OK();
}
static Status CompileToLocalExecutable(
OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
const XlaPlatformInfo& platform_info,
@ -292,7 +180,7 @@ static Status CompileToLocalExecutable(
TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_cache", &cache,
[&](XlaCompilationCache** cache) {
return BuildCompilationCache(ctx, platform_info, cache);
return BuildXlaCompilationCache(ctx, platform_info, cache);
}));
// Hold the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
@ -302,32 +190,14 @@ static Status CompileToLocalExecutable(
*client = static_cast<xla::LocalClient*>(cache->client());
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options;
options.client = *client;
if (ctx->op_device_context() != nullptr) {
options.device_ordinal =
ctx->op_device_context()->stream()->parent()->device_ordinal();
}
options.device_type = cache->device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls =
(platform_info.platform_id() == se::host::kHostPlatformId);
options.device_allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info);
if (platform_info.xla_device_metadata()) {
options.shape_representation_fn =
platform_info.xla_device_metadata()->shape_representation_fn();
}
// If reference variables are not present in the graph, we can safely alias
// passthrough parameters without performing a copy.
options.alias_passthrough_params =
!has_ref_vars && !platform_info.is_on_xla_device();
XlaCompiler::Options options = GenerateCompilerOptions(
*cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter);
std::map<int, Tensor> constant_args;
for (int i : constants) {
constant_args.insert({i, ctx->input(i)});
}
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
// Optimization: where possible, have the computation return a naked array
@ -503,7 +373,7 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
constants_(ConstantsVector(ctx)),
resources_(ResourcesVector(ctx)),
function_(FunctionAttr(ctx)),
platform_info_(PlatformInfoFromContext(ctx)),
platform_info_(XlaPlatformInfoFromContext(ctx)),
must_compile_(MustCompileAttr(ctx)),
has_ref_vars_(HasRefVars(ctx)) {}
@ -591,7 +461,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
}
XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
: OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {}
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
void XlaRunOp::Compute(OpKernelContext* ctx) {
VLOG(3) << "XlaRunOp " << def().name();

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -31,61 +32,6 @@ limitations under the License.
namespace tensorflow {
// Holds some information about the platform on which an
// XlaLaunch/_XlaCompile/_XlaRun op must run on.
class XlaPlatformInfo {
public:
XlaPlatformInfo() : device_type_("") {}
XlaPlatformInfo(XlaPlatformInfo&&) = default;
explicit XlaPlatformInfo(const DeviceType device_type,
se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
se::DeviceMemoryAllocator* device_allocator)
: device_type_(device_type),
platform_id_(platform_id),
xla_device_metadata_(xla_device_metadata),
device_allocator_(device_allocator) {}
XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
bool UseMultipleStreams() const {
return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
}
// Non-null only when run on an XLA device.
se::DeviceMemoryAllocator* custom_allocator() const {
return device_allocator_;
}
DeviceType device_type() const { return device_type_; }
// This is equal to xla_device_metadata()->platform()->id() if
// xla_device_metadata() is not nullptr.
se::Platform::Id platform_id() const { return platform_id_; }
// This may be null if the op this XlaPlatformInfo is for was not placed on an
// XLA device.
const XlaDevice::Metadata* xla_device_metadata() const {
return xla_device_metadata_;
}
bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
private:
DeviceType device_type_;
se::Platform::Id platform_id_;
// xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
// XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
// XlaLaunch/_XlaCompile/_XlaRun OpKernel.
const XlaDevice::Metadata* xla_device_metadata_;
// If the op associated with this XlaPlatformInfo is placed on an XLA device
// then device_allocator_ is the xla::Backend's memory allocator. If the op
// is placed on a regular CPU or GPU device then device_allocator_ is null.
se::DeviceMemoryAllocator* device_allocator_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
};
// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
// The only difference is that it does not require arguments to follow

View File

@ -1952,6 +1952,7 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
"ParallelDynamicStitch",
"ParameterizedTruncatedNormal",
"PartitionedCall",
"PopulationCount",
"Qr",
"QuantizeAndDequantizeV2",
"QuantizeAndDequantizeV3",

View File

@ -44,6 +44,11 @@ using ::tensorflow::testing::FindNodeByName;
namespace tensorflow {
namespace {
static bool Initialized = [] {
tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
return true;
}();
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");

View File

@ -406,37 +406,6 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
}
TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output dynamic_slice_operand =
ops::Placeholder(s.WithOpName("dynamic_slice_operand"), DT_INT32,
ops::Placeholder::Attrs{});
Output dynamic_slice_begin = ops::Placeholder(
s.WithOpName("dynamic_slice_begin"), DT_INT32, ops::Placeholder::Attrs{});
Output dynamic_slice_size = ops::Placeholder(
s.WithOpName("dynamic_slice_size"), DT_INT32, ops::Placeholder::Attrs{});
Output dynamic_slice =
ops::XlaDynamicSlice(s.WithOpName("dynamic_slice"), dynamic_slice_operand,
dynamic_slice_begin, dynamic_slice_size);
Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
DT_FLOAT, ops::Placeholder::Attrs{});
Output reshape =
ops::Reshape(s.WithOpName("reshape"), reshape_input, dynamic_slice);
AddToCluster({dynamic_slice.node(), reshape.node()}, "cluster_0");
std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(s.ToGraph(graph.get()));
Node* n = FindNodeByName(*graph, "dynamic_slice");
ASSERT_NE(n, nullptr);
TF_ASSERT_OK(PartiallyDecluster(&graph));
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
}
TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) {
const char* const kClusteredProducer0Name = "ClusteredProducer0";
const char* const kClusteredProducer1Name = "ClusteredProducer1";

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@ -41,18 +42,21 @@ static std::vector<int> GetResourceVariableIndices(OpKernelContext* ctx) {
}
Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
const XlaDevice::Metadata& metadata,
XlaCompilationCache* cache,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable,
const ResourceVarsSnapshot& variable_args) {
xla::LocalClient* client = metadata.client();
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
// Builds an XLA allocator for the device.
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
XlaComputationLaunchContext launch_context(
client, client->backend().memory_allocator(),
client->default_device_ordinal(),
/*allocate_xla_tensors=*/true,
/*use_multiple_streams=*/metadata.UseMultipleStreams());
client, allocator, client->default_device_ordinal(),
/*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr,
platform_info_.xla_device_metadata()
? platform_info_.xla_device_metadata()->UseMultipleStreams()
: false);
std::map<int, const Tensor*> snapshot_ptrs;
for (auto& p : variable_args) {
@ -70,12 +74,11 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
TF_RET_CHECK(stream);
VLOG(2) << "Executing computation: " << name();
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_allocator(client->backend().memory_allocator());
run_options.set_allocator(allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed());
@ -94,71 +97,54 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
return Status::OK();
}
Status XlaCompileOnDemandOp::MustArgumentBeConstant(
const OpKernel* op_kernel, int64 argument_idx,
FunctionLibraryRuntime* flib_runtime, bool* result) {
*result = false;
Status XlaCompileOnDemandOp::Compile(
OpKernelContext* ctx, const XlaCompiler::CompilationResult** result,
XlaCompilationCache** cache, ResourceVarsSnapshot* variable_args,
xla::LocalExecutable** executable) {
std::map<int, Tensor> constant_arguments;
// TODO(jmolloy): This could be expensive, so memoize.
std::vector<int> constant_input_indices;
TF_RETURN_IF_ERROR(GetCompileTimeConstInputs(
op_kernel, &constant_input_indices, flib_runtime));
*result = absl::c_binary_search(constant_input_indices, argument_idx);
return Status::OK();
}
&ctx->op_kernel(), &constant_input_indices, ctx->function_library()));
CHECK(absl::c_is_sorted(constant_input_indices));
// TODO(ycao): Remove the need to call ShouldArgumentBeConstant. Its benefit is
// not clear yet and it causes heavy constant analysis to run twice.
Status XlaCompileOnDemandOp::ShouldArgumentBeConstant(
const OpKernel* op_kernel, int64 argument_idx,
FunctionLibraryRuntime* flib_runtime, bool* result) {
return MustArgumentBeConstant(op_kernel, argument_idx, flib_runtime, result);
}
Status XlaCompileOnDemandOp::Compile(
OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
const XlaCompiler::CompilationResult** result,
ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable) {
std::map<int, Tensor> constant_arguments;
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
const Tensor& device_tensor = ctx->input(i);
if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) {
if (xla_tensor->has_host_tensor()) {
bool should_arg_be_const;
TF_RETURN_IF_ERROR(ShouldArgumentBeConstant(&ctx->op_kernel(), i,
ctx->function_library(),
&should_arg_be_const));
if (should_arg_be_const) {
if (absl::c_binary_search(constant_input_indices, i)) {
constant_arguments[i] = xla_tensor->host_tensor();
}
}
}
if (constant_arguments.count(i) == 0) {
bool must_argument_be_const;
TF_RETURN_IF_ERROR(MustArgumentBeConstant(&ctx->op_kernel(), i,
ctx->function_library(),
&must_argument_be_const));
if (must_argument_be_const) {
// Slow path; the argument is not available as a host constant so we
// must fetch it synchronously.
Tensor host_tensor;
AllocatorAttributes attrs;
attrs.set_on_host(true);
TF_RETURN_IF_ERROR(ctx->allocate_temp(
device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
&device_tensor, "ConstantArgument",
reinterpret_cast<Device*>(ctx->device()), &host_tensor);
if (!status.ok()) {
LOG(ERROR) << "Copying tensor of shape "
<< device_tensor.shape().DebugString() << " from "
<< ctx->device()->name() << "to CPU failed with "
<< status.ToString();
return status;
if (!constant_arguments.count(i)) {
if (absl::c_binary_search(constant_input_indices, i)) {
if (ctx->input_memory_type(i) != HOST_MEMORY &&
ctx->op_device_context()) {
// Slow path; the argument is not available as a host constant so we
// must fetch it synchronously.
Tensor host_tensor;
AllocatorAttributes attrs;
attrs.set_on_host(true);
TF_RETURN_IF_ERROR(ctx->allocate_temp(device_tensor.dtype(),
device_tensor.shape(),
&host_tensor, attrs));
Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
&device_tensor, "ConstantArgument",
reinterpret_cast<Device*>(ctx->device()), &host_tensor);
if (!status.ok()) {
LOG(ERROR) << "Copying tensor of shape "
<< device_tensor.shape().DebugString() << " from "
<< ctx->device()->name() << "to CPU failed with "
<< status.ToString();
return status;
}
constant_arguments[i] = host_tensor;
} else {
constant_arguments[i] = device_tensor;
}
constant_arguments[i] = host_tensor;
}
}
}
@ -168,24 +154,16 @@ Status XlaCompileOnDemandOp::Compile(
ResourceMgr* rm = ctx->resource_manager();
CHECK(rm);
XlaCompilationCache* cache;
TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_cache", &cache,
[&](XlaCompilationCache** cache) {
*cache = new XlaCompilationCache(metadata.client(),
metadata.jit_device_type());
return Status::OK();
rm->default_container(), "xla_cache", cache,
[&](XlaCompilationCache** write_into_cache) {
return BuildXlaCompilationCache(ctx, platform_info_, write_into_cache);
}));
// Hold the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
XlaCompiler::Options options;
options.device_type = metadata.jit_device_type();
options.client = metadata.client();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.shape_representation_fn = metadata.shape_representation_fn();
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options =
GenerateCompilerOptions(**cache, ctx, platform_info_,
/*has_ref_vars=*/true, &tf_allocator_adapter);
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
@ -206,19 +184,25 @@ Status XlaCompileOnDemandOp::Compile(
constant_arguments, variable_infos, ctx, &args));
}
return cache->CompileSingleOp(options, args, ctx, compile_options, result,
executable);
return (*cache)->CompileSingleOp(options, args, ctx, compile_options, result,
executable);
}
void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
const XlaCompiler::CompilationResult* result;
xla::LocalExecutable* executable;
const XlaDevice::Metadata* metadata;
OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
ResourceVarsSnapshot variable_args;
XlaCompilationCache* cache;
OP_REQUIRES(ctx, ctx->function_library(),
errors::Internal("Function library missing"));
OP_REQUIRES_OK(ctx,
Compile(ctx, *metadata, &result, &variable_args, &executable));
OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable, variable_args));
Compile(ctx, &result, &cache, &variable_args, &executable));
// Hold the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
OP_REQUIRES_OK(ctx, Run(ctx, cache, result, executable, variable_args));
}
} // namespace tensorflow

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/framework/function.h"
@ -35,25 +36,24 @@ namespace tensorflow {
// vanilla TensorFlow op as long as the bridge supports it.
class XlaCompileOnDemandOp : public OpKernel {
public:
explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx)
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
void Compute(OpKernelContext* ctx) override;
private:
XlaCompiler::Argument CreateCompilerArgument(OpKernelContext* ctx, int64 i);
Status ShouldArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx,
FunctionLibraryRuntime* flib_runtime,
bool* result);
Status MustArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx,
FunctionLibraryRuntime* flib_runtime,
bool* result);
Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
Status Compile(OpKernelContext* ctx,
const XlaCompiler::CompilationResult** result,
XlaCompilationCache** cache,
ResourceVarsSnapshot* variable_args,
xla::LocalExecutable** executable);
Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
Status Run(OpKernelContext* ctx, XlaCompilationCache* cache,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable,
const ResourceVarsSnapshot& variable_args);
const XlaPlatformInfo platform_info_;
};
} // namespace tensorflow

View File

@ -61,6 +61,21 @@ limitations under the License.
namespace tensorflow {
// Default PaddedShapeFn implementation that simply returns the unpadded
// on-device shape. This is accurate for CPU and GPU devices that neither
// transpose nor pad tensors.
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
const tensorflow::XlaTensor* xla_tensor =
tensorflow::XlaTensor::FromTensor(&tensor);
if (xla_tensor == nullptr) {
return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
}
const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
*shape = shaped_buffer.on_device_shape();
return Status::OK();
}
// Caches a XlaDeviceAllocator per <backend, device ordinal> pair. A
// XlaDeviceAllocator is created on demand and is associated with a
// XlaDevice. It outlives the device itself (for instance, the buffer
@ -116,20 +131,6 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
namespace {
// Default PaddedShapeFn implementation that simply returns the unpadded
// on-device shape. This is accurate for CPU and GPU devices that neither
// transpose nor pad tensors.
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
const tensorflow::XlaTensor* xla_tensor =
tensorflow::XlaTensor::FromTensor(&tensor);
if (xla_tensor == nullptr) {
return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
}
const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
*shape = shaped_buffer.on_device_shape();
return Status::OK();
}
static DeviceAttributes BuildXlaDeviceAttributes(const string& name_prefix,
const string& device_name,

View File

@ -280,6 +280,8 @@ struct XlaDeviceOpRegistrations {
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
const char* jit_device);
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_

View File

@ -1,106 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Registers the XLA_INTERPRETER device which exposes the XLA Interpreter.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
namespace tensorflow {
const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER";
const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT";
constexpr std::array<DataType, 10> kExecAllTypes = {
{DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
class XlaInterpreterDeviceFactory : public DeviceFactory {
public:
Status ListPhysicalDevices(std::vector<string>* devices) override;
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override;
};
Status XlaInterpreterDeviceFactory::ListPhysicalDevices(
std::vector<string>* devices) {
devices->push_back(
absl::StrCat("/physical_device:", DEVICE_XLA_INTERPRETER, ":0"));
return Status::OK();
}
Status XlaInterpreterDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(
DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT);
(void)registrations;
XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
registration.autoclustering_policy =
XlaOpRegistry::AutoclusteringPolicy::kAlways;
registration.cluster_resource_variable_ops_unsafely = true;
registration.cluster_stack_ops = false;
registration.cluster_tensor_array_ops = true;
registration.cluster_stateful_rng_ops = true;
registration.cluster_control_trigger = true;
registration.elide_assert_and_checknumerics = true;
registration.cluster_variant_ops = true;
registration.cluster_slow_ops = true;
registration.cluster_inaccurate_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
registration);
TF_ASSIGN_OR_RETURN(
auto platform, se::MultiPlatformManager::PlatformWithName("Interpreter"));
XlaDevice::Options options;
options.platform = platform;
options.device_name_prefix = name_prefix;
options.device_name = DEVICE_XLA_INTERPRETER;
options.device_ordinal = 0;
options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
options.use_multiple_streams = false;
devices->push_back(absl::make_unique<XlaDevice>(session_options, options));
return Status::OK();
}
// Set priority to be below the default priority (50), so that Interpreter is
// not selected as a high priority device over other default devices. See
// constructor comments for Registrar in
// tensorflow/core/common_runtime/device_factory.h for a list of priority for
// devices.
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_INTERPRETER,
XlaInterpreterDeviceFactory, 40);
// Kernel registrations
static bool OpFilter(KernelDef* kdef) { return true; }
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp,
kExecAllTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp,
kExecAllTypes);
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes);
REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter);
} // namespace tensorflow

View File

@ -14,10 +14,62 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"
namespace {
// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take
// care to order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
class SinglePassSearch {
public:
// Creates a SinglePassSearch object that can be used to search in `values`.
// Does not take ownership of `values`. `values` must outlive this.
// `values` must be sorted.
explicit SinglePassSearch(const std::vector<int>* values)
: current_index_(0), values_(values) {}
// Scans forward in the vector looking for "value", updating the internal
// position in to the vector.
// Returns true iff the vector contains the given value at or after current
// position.
// Not thread-safe.
bool ScanForValue(int value) {
while (current_index_ < values_->size() &&
(*values_)[current_index_] <= value) {
if ((*values_)[current_index_] == value) {
current_index_++;
return true;
}
current_index_++;
}
return false;
}
private:
int current_index_;
const std::vector<int>* values_;
};
} // end namespace
namespace tensorflow {
@ -27,6 +79,121 @@ bool XlaKernelCreator::CanCreateKernel(
return CanCreateXlaKernel(props->node_def);
}
static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) {
if (!CanCreateXlaKernel(node_def)) {
return errors::Internal("Invalid node: ", node_def.ShortDebugString());
}
VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString();
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterCompilationKernels();
// Only check for compilability if the MLIR bridge is not enabled.
if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
uncompilable_node_info;
for (const auto& it : uncompilable_nodes_map) {
for (const auto& info : it.second.second) {
uncompilable_node_info.emplace_back(info);
}
}
string message = absl::StrCat(
"Function invoked by the following node is not compilable: ",
SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
absl::StrAppend(&message, "Uncompilable nodes:");
for (const auto& node_info : uncompilable_node_info) {
string node_message = absl::StrCat("\n", node_info.name, ": ",
node_info.uncompilable_reason, "\n",
"\tStacktrace:\n");
for (const auto& stack_frame : node_info.stack_trace) {
absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
stack_frame.name, stack_frame.function_name);
}
absl::StrAppend(&message, node_message);
}
VLOG(1) << message;
return errors::InvalidArgument(message);
}
}
// Get function body, constant args, and resource args.
const FunctionBody* fbody = nullptr;
std::vector<int> constant_arg_indices;
std::vector<int> resource_arg_indices;
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
// Set input and output memory types.
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
// These indices are used only for optimization purposes. They allow us
// to loop over constant_arg_indices and resource_arg_indices only once
// while iterating over all the function arguments checking if it is a
// resource or a constant.
// The reason we optimized this code is because functions can have a lot of
// captured arguments. For example, the backward pass of ResNet50 takes in all
// 214 variables and a similar number of activations.
SinglePassSearch constants_search(&constant_arg_indices);
SinglePassSearch resources_search(&resource_arg_indices);
for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
// Compile-time constants and resource handles are expected to be in
// host memory.
input_memory_types[i] = HOST_MEMORY;
}
}
// One might wonder, about the case where a compile-time constant argument
// (which must be in host memory) is also used as an input into an op,
// e.g. Add, that expects its inputs in device memory. Here is how it
// works now.
// First, what do we mean by "op expects an input in XYZ memory"?
// There are two types of "ops" here: the tf2xla kernel and the HLO
// computation it builds. The tf2xla kernel needs to retrieve the actual
// numeric value of the compile-time constant tensors, so it really expects
// them to be on in host memory. However, for other inputs, it refers to them
// using xla::ComputationDataHandle, which is just a symbolic handle that
// xla::ComputationBuilder assigns. How does this handle gets assigned for
// constant arguments? Even constant arguments get an _Arg node in the graph
// instantiated for Function compilation. The tf2xla kernel for constant _Arg
// nodes takes the constant value, converts it to XlaLiteral, and feeds it
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
// constant XlaLiteral is included in the HLO graph, and subsequently, in
// the actual executable, which is copied to the device before being
// executed. Thus, when this executable runs, the constant is available in
// device memory.
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
// in device memory except for resources.
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
if (fbody->ret_types[i] == DT_RESOURCE) {
output_memory_types[i] = HOST_MEMORY;
}
}
// Create the kernel.
NameAttrList function;
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
Device* dev = flr->device();
Status s;
auto props = std::make_shared<NodeProperties>(
&fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()),
flr, dev->resource_manager(), props,
input_memory_types, output_memory_types,
flr->graph_def_version(), &s);
*kernel = absl::make_unique<XlaLocalLaunchBase>(
&construction, constant_arg_indices, resource_arg_indices, function,
/*has_ref_vars=*/false);
return s;
}
Status XlaKernelCreator::CreateKernel(
FunctionLibraryRuntime* flr,
const std::shared_ptr<const NodeProperties>& props,
@ -34,19 +201,12 @@ Status XlaKernelCreator::CreateKernel(
return CreateXlaKernel(flr, props->node_def, kernel);
}
namespace {
bool RegisterLaunchOpCreator() {
static bool RegisterLaunchOpCreator() {
XlaKernelCreator* xla_kernel_creator = new XlaKernelCreator();
RegisterDefaultCustomKernelCreator(xla_kernel_creator);
return true;
}
static bool register_me = RegisterLaunchOpCreator();
static bool register_xla = [] {
SetXlaIsEnabled();
return true;
}();
} // end namespace
} // namespace tensorflow

View File

@ -1,186 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace {
// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take
// care to order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
class SinglePassSearch {
public:
// Creates a SinglePassSearch object that can be used to search in `values`.
// Does not take ownership of `values`. `values` must outlive this.
// `values` must be sorted.
explicit SinglePassSearch(const std::vector<int>* values)
: current_index_(0), values_(values) {}
// Scans forward in the vector looking for "value", updating the internal
// position in to the vector.
// Returns true iff the vector contains the given value at or after current
// position.
// Not thread-safe.
bool ScanForValue(int value) {
while (current_index_ < values_->size() &&
(*values_)[current_index_] <= value) {
if ((*values_)[current_index_] == value) {
current_index_++;
return true;
}
current_index_++;
}
return false;
}
private:
int current_index_;
const std::vector<int>* values_;
};
} // namespace
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) {
if (!CanCreateXlaKernel(node_def)) {
return errors::Internal("Invalid node: ", node_def.ShortDebugString());
}
VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString();
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterCompilationKernels();
// Only check for compilability if the MLIR bridge is not enabled.
if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
uncompilable_node_info;
for (const auto& it : uncompilable_nodes_map) {
for (const auto& info : it.second.second) {
uncompilable_node_info.emplace_back(info);
}
}
string message = absl::StrCat(
"Function invoked by the following node is not compilable: ",
SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
absl::StrAppend(&message, "Uncompilable nodes:");
for (const auto& node_info : uncompilable_node_info) {
string node_message = absl::StrCat("\n", node_info.name, ": ",
node_info.uncompilable_reason, "\n",
"\tStacktrace:\n");
for (const auto& stack_frame : node_info.stack_trace) {
absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
stack_frame.name, stack_frame.function_name);
}
absl::StrAppend(&message, node_message);
}
VLOG(1) << message;
return errors::InvalidArgument(message);
}
}
// Get function body, constant args, and resource args.
const FunctionBody* fbody = nullptr;
std::vector<int> constant_arg_indices;
std::vector<int> resource_arg_indices;
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
// Set input and output memory types.
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
// These indices are used only for optimization purposes. They allow us
// to loop over constant_arg_indices and resource_arg_indices only once
// while iterating over all the function arguments checking if it is a
// resource or a constant.
// The reason we optimized this code is because functions can have a lot of
// captured arguments. For example, the backward pass of ResNet50 takes in all
// 214 variables and a similar number of activations.
SinglePassSearch constants_search(&constant_arg_indices);
SinglePassSearch resources_search(&resource_arg_indices);
for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
// Compile-time constants and resource handles are expected to be in
// host memory.
input_memory_types[i] = HOST_MEMORY;
}
}
// One might wonder, about the case where a compile-time constant argument
// (which must be in host memory) is also used as an input into an op,
// e.g. Add, that expects its inputs in device memory. Here is how it
// works now.
// First, what do we mean by "op expects an input in XYZ memory"?
// There are two types of "ops" here: the tf2xla kernel and the HLO
// computation it builds. The tf2xla kernel needs to retrieve the actual
// numeric value of the compile-time constant tensors, so it really expects
// them to be on in host memory. However, for other inputs, it refers to them
// using xla::ComputationDataHandle, which is just a symbolic handle that
// xla::ComputationBuilder assigns. How does this handle gets assigned for
// constant arguments? Even constant arguments get an _Arg node in the graph
// instantiated for Function compilation. The tf2xla kernel for constant _Arg
// nodes takes the constant value, converts it to XlaLiteral, and feeds it
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
// constant XlaLiteral is included in the HLO graph, and subsequently, in
// the actual executable, which is copied to the device before being
// executed. Thus, when this executable runs, the constant is available in
// device memory.
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
// in device memory except for resources.
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
if (fbody->ret_types[i] == DT_RESOURCE) {
output_memory_types[i] = HOST_MEMORY;
}
}
// Create the kernel.
NameAttrList function;
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
Device* dev = flr->device();
Status s;
auto props = std::make_shared<NodeProperties>(
&fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()),
flr, dev->resource_manager(), props,
input_memory_types, output_memory_types,
flr->graph_def_version(), &s);
*kernel = absl::make_unique<XlaLocalLaunchBase>(
&construction, constant_arg_indices, resource_arg_indices, function,
/*has_ref_vars=*/false);
return s;
}
} // namespace tensorflow

View File

@ -0,0 +1,89 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Register XlaXXX operations on regular CPU/GPU devices using
// `XlaCompileOnDemandOp`.
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
#define REGISTER_XLA_OPS_ON_DEVICE(DEVICE) \
REGISTER_KERNEL_BUILDER(Name("XlaConv") \
.HostMemory("window_strides") \
.HostMemory("padding") \
.HostMemory("lhs_dilation") \
.HostMemory("rhs_dilation") \
.HostMemory("feature_group_count") \
.Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER( \
Name("XlaBroadcastHelper").HostMemory("broadcast_dims").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSelfAdjointEig").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSvd").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaDot").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaDynamicSlice").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaDynamicUpdateSlice").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaIf").Device(DEVICE), XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaPad").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaRecv").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaReduce").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaReduceWindow").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSelectAndScatter") \
.HostMemory("window_dimensions") \
.HostMemory("window_strides") \
.HostMemory("padding") \
.Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSend").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSort").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaKeyValueSort").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaWhile").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaDequantize").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaEinsum").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSpmdShardToFullShape").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSharding").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaReplicaId").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaGather") \
.HostMemory("start_indices") \
.HostMemory("slice_sizes") \
.Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaScatter").Device(DEVICE), \
XlaCompileOnDemandOp);
REGISTER_XLA_OPS_ON_DEVICE(DEVICE_CPU);
REGISTER_XLA_OPS_ON_DEVICE(DEVICE_GPU);
} // namespace tensorflow

View File

@ -0,0 +1,159 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/xla/client/client_library.h"
namespace tensorflow {
Status BuildXlaCompilationCache(OpKernelContext* ctx,
const XlaPlatformInfo& platform_info,
XlaCompilationCache** cache) {
if (platform_info.xla_device_metadata()) {
*cache = new XlaCompilationCache(
platform_info.xla_device_metadata()->client(),
platform_info.xla_device_metadata()->jit_device_type());
return Status::OK();
}
auto platform =
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
if (!platform.ok()) {
return platform.status();
}
xla::StatusOr<xla::Compiler*> compiler_for_platform =
xla::Compiler::GetForPlatform(platform.ValueOrDie());
if (!compiler_for_platform.ok()) {
// In some rare cases (usually in unit tests with very small clusters) we
// may end up transforming an XLA cluster with at least one GPU operation
// (which would normally force the cluster to be compiled using XLA:GPU)
// into an XLA cluster with no GPU operations (i.e. containing only CPU
// operations). Such a cluster can fail compilation (in way that
// MarkForCompilation could not have detected) if the CPU JIT is not linked
// in.
//
// So bail out of _XlaCompile in this case, and let the executor handle the
// situation for us.
const Status& status = compiler_for_platform.status();
if (status.code() == error::NOT_FOUND) {
return errors::Unimplemented("Could not find compiler for platform ",
platform.ValueOrDie()->Name(), ": ",
status.ToString());
}
}
xla::LocalClientOptions client_options;
client_options.set_platform(platform.ValueOrDie());
client_options.set_intra_op_parallelism_threads(
ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
if (!client.ok()) {
return client.status();
}
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
&registration)) {
return errors::InvalidArgument("No JIT device registered for ",
platform_info.device_type().type());
}
*cache = new XlaCompilationCache(
client.ValueOrDie(), DeviceType(registration->compilation_device_name));
return Status::OK();
}
XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
DeviceType device_type = ctx->device_type();
se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr;
se::DeviceMemoryAllocator* custom_allocator = nullptr;
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
platform_id = se::host::kHostPlatformId;
} else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
platform_id = ctx->device()
->tensorflow_gpu_device_info()
->stream->parent()
->platform()
->id();
} else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
// If we are on an XlaDevice, use the underlying XLA platform's allocator
// directly. We could use the StreamExecutor's allocator which may
// theoretically be more correct, but XLA returns a nice OOM message in a
// Status and StreamExecutor does not.
//
// Importantly we can't use ctx->device()->GetAllocator() as the allocator
// (which xla_allocator above uses) as on an XlaDevice, this is a dummy
// allocator that returns XlaTensor objects. The XlaCompiler needs a real
// allocator to allocate real buffers.
platform_id = xla_device_metadata->platform()->id();
custom_allocator =
xla_device_metadata->client()->backend().memory_allocator();
}
return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
custom_allocator);
}
se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
OpKernelContext* ctx, const XlaPlatformInfo& platform_info) {
if (platform_info.custom_allocator()) {
return platform_info.custom_allocator();
}
if (!ctx->op_device_context()) {
// Stream is not set for the host platform.
se::Platform* platform =
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
.ValueOrDie();
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform);
return &tf_allocator_adapter->value();
}
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}),
ctx->op_device_context()->stream());
return &tf_allocator_adapter->value();
}
XlaCompiler::Options GenerateCompilerOptions(
const XlaCompilationCache& cache, OpKernelContext* ctx,
const XlaPlatformInfo& platform_info, bool has_ref_vars,
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) {
CHECK(ctx->function_library());
XlaCompiler::Options options;
options.client = static_cast<xla::LocalClient*>(cache.client());
if (ctx->op_device_context() != nullptr) {
options.device_ordinal =
ctx->op_device_context()->stream()->parent()->device_ordinal();
}
options.device_type = cache.device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls =
(platform_info.platform_id() == se::host::kHostPlatformId);
options.device_allocator =
GetAllocator(tf_allocator_adapter, ctx, platform_info);
if (platform_info.xla_device_metadata()) {
options.shape_representation_fn =
platform_info.xla_device_metadata()->shape_representation_fn();
}
// If reference variables are not present in the graph, we can safely alias
// passthrough parameters without performing a copy.
options.alias_passthrough_params =
!has_ref_vars && !platform_info.is_on_xla_device();
return options;
}
} // namespace tensorflow

View File

@ -0,0 +1,108 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_
#define TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
namespace tensorflow {
// Holds some information about the platform on which an
// XlaLaunch/_XlaCompile/_XlaRun op must run on. Provides a common layer of
// abstraction for normal and XLA devices.
class XlaPlatformInfo {
public:
XlaPlatformInfo() : device_type_("") {}
XlaPlatformInfo(XlaPlatformInfo&&) = default;
explicit XlaPlatformInfo(const DeviceType device_type,
se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
se::DeviceMemoryAllocator* device_allocator)
: device_type_(device_type),
platform_id_(platform_id),
xla_device_metadata_(xla_device_metadata),
device_allocator_(device_allocator) {}
XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
bool UseMultipleStreams() const {
return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
}
// Non-null only when run on an XLA device.
se::DeviceMemoryAllocator* custom_allocator() const {
return device_allocator_;
}
DeviceType device_type() const { return device_type_; }
// This is equal to xla_device_metadata()->platform()->id() if
// xla_device_metadata() is not nullptr.
se::Platform::Id platform_id() const { return platform_id_; }
// This may be null if the op this XlaPlatformInfo is for was not placed on an
// XLA device.
const XlaDevice::Metadata* xla_device_metadata() const {
return xla_device_metadata_;
}
bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
private:
DeviceType device_type_;
se::Platform::Id platform_id_;
// xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
// XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
// XlaLaunch/_XlaCompile/_XlaRun OpKernel.
const XlaDevice::Metadata* xla_device_metadata_;
// If the op associated with this XlaPlatformInfo is placed on an XLA device
// then device_allocator_ is the xla::Backend's memory allocator. If the op
// is placed on a regular CPU or GPU device then device_allocator_ is null.
se::DeviceMemoryAllocator* device_allocator_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
};
// Returns created XLA compilation cache.
Status BuildXlaCompilationCache(OpKernelContext* ctx,
const XlaPlatformInfo& platform_info,
XlaCompilationCache** cache);
// Returns information about the platform from kernel context.
XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx);
// Returns allocator from platform info if non-null, or populate and return a
// pointer to the allocator adapter with allocator from context.
//
// This is necessary because for XLA devices the underlying TF allocator returns
// dummy tensors.
se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
OpKernelContext* ctx, const XlaPlatformInfo& platform_info);
// Returns created options for the XLA compiler, and writes the used allocator
// into `tf_allocator_adapter`.
XlaCompiler::Options GenerateCompilerOptions(
const XlaCompilationCache& cache, OpKernelContext* ctx,
const XlaPlatformInfo& platform_info, bool has_ref_vars,
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_

View File

@ -0,0 +1,4 @@
build
llvm-project
llvm-build

View File

@ -60,7 +60,7 @@ gentbl(
strip_include_prefix = "include/mlir-hlo/Dialect/mhlo/transforms/",
tbl_outs = [
(
"-gen-pass-decls",
"-gen-pass-decls -name MHLO",
"include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc",
),
],
@ -76,7 +76,7 @@ gentbl(
strip_include_prefix = "include/mlir-hlo/Dialect/mhlo/transforms/",
tbl_outs = [
(
"-gen-pass-decls",
"-gen-pass-decls -name LMHLO",
"include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc",
),
],
@ -341,6 +341,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
@ -403,6 +404,7 @@ cc_library(
cc_library(
name = "lhlo_legalize_to_llvm",
srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc"],
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"],
deps = [
":lhlo",
"@llvm-project//mlir:IR",
@ -758,8 +760,6 @@ cc_library(
":lhlo_legalize_to_llvm", # build-cleaner: keep
":materialize_broadcasts", # build-cleaner: keep
":unfuse_batch_norm", # build-cleaner: keep
"@llvm-project//mlir:AffineToStandardTransforms",
"@llvm-project//mlir:CFGTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:LLVMDialect",
@ -806,13 +806,6 @@ cc_library(
],
)
cc_library(
name = "register_all_passes",
srcs = ["lib/Dialect/mhlo/transforms/register_all_passes.cc"],
deps = [":all_passes"],
alwayslink = 1,
)
cc_binary(
name = "mlir-hlo-opt",
srcs = [

View File

@ -0,0 +1,94 @@
#
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cmake_minimum_required(VERSION 3.13.4)
if(POLICY CMP0068)
cmake_policy(SET CMP0068 NEW)
set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON)
endif()
if(POLICY CMP0075)
cmake_policy(SET CMP0075 NEW)
endif()
if(POLICY CMP0077)
cmake_policy(SET CMP0077 NEW)
endif()
#-------------------------------------------------------------------------------
# Project setup and globals
#-------------------------------------------------------------------------------
project(mlir-hlo LANGUAGES CXX C)
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 14)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
#-------------------------------------------------------------------------------
# Options and settings
#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
# MSVC defaults
#-------------------------------------------------------------------------------
if(MSVC)
add_compile_options(
$<$<CONFIG:>:/MD>
$<$<CONFIG:Debug>:/MD>
$<$<CONFIG:Release>:/MD>
)
endif()
#-------------------------------------------------------------------------------
# MLIR/LLVM Configuration
#-------------------------------------------------------------------------------
find_package(MLIR REQUIRED CONFIG)
message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}")
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
if(LLVM_ENABLE_ZLIB)
find_package(ZLIB)
endif()
include(TableGen)
include(AddLLVM)
include(AddMLIR)
include(HandleLLVMOptions)
include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_BINARY_DIR}/include)
include_directories(${PROJECT_BINARY_DIR}/)
link_directories(${LLVM_BUILD_LIBRARY_DIR})
add_definitions(${LLVM_DEFINITIONS})
#-------------------------------------------------------------------------------
# Directory setup
#-------------------------------------------------------------------------------
set(MLIR_HLO_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(MLIR_HLO_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
add_custom_target(check-mlir-hlo)
add_subdirectory(include/mlir-hlo)
add_subdirectory(lib)
add_subdirectory(tools)
add_subdirectory(tests)

View File

@ -0,0 +1,233 @@
# MLIR-HLO: A Standalone "HLO" MLIR-based Compiler
The code here exists in two places:
* https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir/hlo;
this is the canonical location and where contributions should be made using
GitHub pull-requests.
* https://github.com/tensorflow/mlir-hlo; this is a standalone repository with
a view to the same code to allow other projects to use this without
depending on the entire TF monorepo.
This implements a self-contained compiler for a linear algebra set of operations
inspired by XLA
[HLO IR](https://www.tensorflow.org/xla/architecture#how_does_xla_work) using
MLIR components. It is designed to provide an end-to-end flow independent of
TensorFlow and XLA, but usable inside of these projects.
Coding practice and conventions in this repository follow the
[MLIR Developer Guide](https://mlir.llvm.org/getting_started/DeveloperGuide/) in
this repo as part of the intent to act as an incubator for technology to
upstream.
## QuickStart: building and testing
These instructions work on Linux, you may have to adjust for your plaform.
To build the code in this repository, you need a clone of the LLVM/MLIR git
repository:
$ git clone https://github.com/llvm/llvm-project.git
You need to make sure you have the right commit checked out in the LLVM
repository (you need to do this every time you pull from this repo):
$ (cd llvm-project && git checkout $(cat build_tools/llvm_version.txt))
We provide a script to configure and build LLVM/MLIR:
$ build_tools/build_mlir.sh ${PWD}/llvm-project/ ${PWD}/llvm-build
Again this is something to do every time you pull from this repository and the
LLVM revision changes.
Finally you can build and test this repository:
$ mkdir build && cd build
$ cmake .. -GNinja \
-DLLVM_ENABLE_LLD=ON \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=On \
-DMLIR_DIR=${PWD}/../llvm-build/lib/cmake/mlir
$ ninja check-mlir-hlo
## Overview
MLIR-HLO aims to provide an end-to-end compiler for CPU and GPU, as well as
building reusable blocks for other accelerators. This is heavily inspired by the
success of XLA.
[XLA](https://www.tensorflow.org/xla/) (Accelerated Linear Algebra) is a
domain-specific compiler framework and execution environment for linear algebra,
which powers code-generation for ML frameworks like TensorFlow, JAX, and others.
A cornerstone of XLA is the HLO (High Level Optimizer) IR, which offers a
carefully fixed selected list of operations, mostly orthogonal to each other. It
provides an efficient optimizer for computations expressed with this set of
operations and generate codes for hardware platforms like CPU, GPU, and TPUs.
Its goal is to provide a uniform interface to compile and execute these
optimized HLO programs independently of the targeted device. It is not a
front-end ML system like TensorFlow or JAX, rather it is a backend framework
that optimizes HLO and lowers to machine code.
The HLO set of operations is closed and has well defined semantics. HLO
operations operate on immutable Tensors with static shapes (actually bounded
shapes to be exact) and explicit broadcasts.
[MLIR](https://mlir.llvm.org/) is a compiler infrastructure which intends to
come with "battery included", as such it intends to provide all the blocks
required to assemble graph optimization and codegen pipelines. The longer term
roadmap for MLIR is to provide a
[Tensor Compute Primitive](https://llvm.discourse.group/c/mlir/MLIR-TCP-WG/36)
(TCP) dialect, which should hopefully be general enough to model what HLO
represents today (see
[slides](https://drive.google.com/open?id=1iljcpTQ5NPaMfGpoPDFml1XkYxjK_6A4) and
[recording](https://drive.google.com/open?id=1jSPa8TwPKUt0WuLquGc8OgSUVYJHMvWZ)
for a technical discussion on this topic).
The work on MLIR-HLO can be seen as a stepping stone towards building TCP, while
integrating intermediate components into XLA itself by relying on the
well-proven HLO IR and introducing more pieces from upstream MLIR
([Linalg](https://mlir.llvm.org/docs/Dialects/Linalg/),
[Vector](https://mlir.llvm.org/docs/Dialects/Vector/),
[GPU](https://mlir.llvm.org/docs/Dialects/GPU/) dialect, ...).
[This document](https://www.tensorflow.org/mlir/xla_gpu_codegen) provides more
information on the current migration of the XLA GPU codegen.
## MLIR Dialects for XLA-style compilation
This repository defines three dialects to support a HLO-like compilation
pipeline using MLIR:
* `chlo`: the "client" HLO dialect, intended to be closer to the frontend
(including implicit broadcast semantics).
* `mhlo`: "meta"-HLO dialect ; similar to `xla_hlo`, but with extensions for
dynamic shape support.
* `lmhlo`: "late"-"meta"-HLO, it is the IR after buffer allocation is
performed. In XLA the buffer allocation is a side-datastructure which keeps
track of these informations, while this separate dialect materializes it in
the IR.
We describe these in more details below.
### HLO Client Dialect: `chlo`.
* It was originaly designed to map the
[XLA client APIs](https://www.tensorflow.org/xla/operation_semantics) (e.g.,
ops supports implicit broadcast and roughly modeled on XlaBuilder API)
modulo support for dynamic shapes and additional ops required to support
dynamic client side HLOs.
* Ops can be from either the XlaBuilder or XLA helper functions can be
converted into ops (e.g., given ambiguity in what constitutes these ops,
there is some freedom to decide), the goal of this dialect is to correspond
close to client level and enable a thin layer between client use and op
construction (making it cheap to construct and optimizations on the dialect
close to optimizations on the client ops).
Entry:
* The vast majority of old "client" interactions are via the XlaBuilder APIs.
These APIs are used by TF2XLA kernels, JAX, PyTorch bridge and directly. The
legalization path (described below) can also reuse the XlaBuilder's APIs to
construct XLA Client HLO ops directly (this uses MlirXlaBuilder which is a
subclass of XlaBuilder).
* The other entry point is during legalization from TensorFlow ops in the TF
Graph Compiler and other tools (e.g., SavedModel lowering and TFCompile).
Exit:
* MHLO
* May be exported to xla::HloInstructionProto by invoking the XlaBuilder APIs
(with regular XlaBuilder)
The `chlo` dialect started originally as mapping to the XLA client Builder APIs.
It enables it to both be constructed and converted back to existing XLA
interfaces using the XlaBuilder API. Due to the way that translation into and
out of the dialect works, there is no expectation that this dialect roundtrips
to XLA (e.g., it is only intended to be translated to MLIR and then legalized to
another dialect or translated to HloInstructionProto).
The export approach of reusing the XlaBuilders enables reusing a lot of logic
that was already implemented in terms of computing shapes, inserting broadcasts
etc.
An important topic here is that XLA Client HLO ops are not a well defined set.
And in particular what some would consider helper functions, others would
consider ops. It should be easy to move between these and so define a new op
along with the helper function or autogenerate the helper functions from the
descriptions of the ops. For the former, a simple approach would be to simply
consider the context in which the op is being constructed and if an MLIR one,
construct a op in the client dialect instead of further calls into XlaBuilder.
The latter could be implemented by adding the op and a legalization of the op to
other known ops, from which a helper function can get generated that could be
used as regular.
Status: Exists but need to be cleaned up.
### Meta HLO Dialect `mhlo`
* Dialect is closer to current HLO server ops (e.g., no implicit broadcast)
* MHLO dialect where we can deviate from the requirements of the client or
server dialect, in particular:
* Control flow ops with implicit capture to enable simpler optimizations
(e.g., generic LICM, unroll & jam, etc.)
* Multiple results ops (e.g., no tuples)
* More ops (for example, unique op or assert op), and ops that don't need
to be added to either client or server dialect.
* Op set not constrained by implementation (e.g., hlo.add operating on say
i79 or !mydialect.weird_type is allowed even though no XLA backend
supports it). Verification on types happening at the boundaries.
* It does not need to preserve some deprecated XLA constructs (e.g.
stateful RNG HLO).
* More dynamic shape support ops without need for updating all
users/backends.
* This dialect enables evolving HLO independently from XLA in order to
experiment with features we'd like to upstream in MLIR TCP. In particular it
intends to be user-extensible through
[interfaces](https://mlir.llvm.org/docs/Interfaces/).
* It should have no TensorFlow, or proto, or other Google internal
dependencies.
* It need not be a complete superset of ops compared to XLA HLO dialect.
Entry:
* Legalization from `chlo` dialect or conversion from XLA HLO.
* Directly emitted from TF Graph Compiler;
* Builder call (e.g., EDSL);
Exit:
* LMHLO, Linalg IREE, directly used in codegen.
* XLA HLO.
The MHLO dialect has no direct export format, it is only meant as an
intermediate optimization dialect/format. It is also where we can experiment
cheaply with new ops. This format will be where the representation would differ
from existing end points.
Status: Exists but need to be cleaned up and evolved, in particular with respect
to supporting dynamic shapes.
### LMHLO
LMHLO corresponds to late `mhlo` and operates on buffer domain (e.g., memref)
with side-effecting operations. The lowering from `mhlo` dialect proceeds by way
of scheduling, memory and buffer allocation. The current mapping is directly on
XLA Client HLOs but without implicit broadcast and with operation on memrefs.
This dialect will instead be rebased on `mhlo` dialect but operating on buffers
still.
Entry:
* Post buffer assignment on `mhlo` dialect, or from XLA after buffer
assignment.
Exit:
* Codegen (LLVM IR in the common cases at the moment)
## End-to-End pipeline
TODO

View File

@ -0,0 +1,52 @@
#!/bin/bash
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
if [[ $# -ne 2 ]] ; then
echo "Usage: $0 <path/to/llvm> <build_dir>"
exit 1
fi
# LLVM source
LLVM_SRC_DIR="$1"
build_dir="$2"
if ! [ -f "$LLVM_SRC_DIR/llvm/CMakeLists.txt" ]; then
echo "Expected the path to LLVM to be set correctly (got '$LLVM_SRC_DIR'): can't find CMakeLists.txt"
exit 1
fi
echo "Using LLVM source dir: $LLVM_SRC_DIR"
# Setup directories.
echo "Building MLIR in $build_dir"
mkdir -p "$build_dir"
echo "Beginning build (commands will echo)"
set -x
cmake -GNinja \
"-H$LLVM_SRC_DIR/llvm" \
"-B$build_dir" \
-DLLVM_INSTALL_UTILS=ON \
-DLLVM_ENABLE_LLD=ON \
-DLLVM_ENABLE_PROJECTS=mlir \
-DLLVM_TARGETS_TO_BUILD="X86;NVPTX;AMDGPU" \
-DLLVM_INCLUDE_TOOLS=ON \
-DLLVM_BUILD_TOOLS=OFF \
-DLLVM_INCLUDE_TESTS=OFF \
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
-DLLVM_ENABLE_ASSERTIONS=On
cmake --build "$build_dir" --target all --target mlir-cpu-runner

View File

@ -0,0 +1,2 @@
<LLVM_COMMIT>

View File

@ -0,0 +1,16 @@
#
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
add_subdirectory(Dialect)

View File

@ -0,0 +1,16 @@
#
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
add_subdirectory(mhlo)

View File

@ -0,0 +1,17 @@
#
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
add_subdirectory(IR)
add_subdirectory(transforms)

View File

@ -0,0 +1,31 @@
#
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Need a separate function because of the .cc vs .cpp used in the one provided by MLIR
function(add_mlir_hlo_dialect dialect dialect_namespace)
set(LLVM_TARGET_DEFINITIONS ${dialect}.td)
mlir_tablegen(${dialect}.h.inc -gen-op-decls)
mlir_tablegen(${dialect}.cc.inc -gen-op-defs)
mlir_tablegen(${dialect}_structs.h.inc -gen-struct-attr-decls)
mlir_tablegen(${dialect}_structs.cc.inc -gen-struct-attr-defs)
add_public_tablegen_target(MLIR${dialect}IncGen)
add_dependencies(mlir-headers MLIR${dialect}IncGen)
endfunction()
add_mlir_hlo_dialect(chlo_ops chlo)
add_mlir_hlo_dialect(hlo_ops mhlo)
add_mlir_hlo_dialect(lhlo_ops lmhlo)
add_mlir_interface(infer_fusibility_op_interface)

Some files were not shown because too many files have changed in this diff Show More