Merge branch 'master' into 16x8_pad

This commit is contained in:
Elena Zhelezina 2020-06-16 13:39:45 +01:00 committed by GitHub
commit 7452075897
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3250 changed files with 170226 additions and 57732 deletions
.bazelrc.bazelversionCODE_OF_CONDUCT.mdISSUE_TEMPLATE.mdRELEASE.mdWORKSPACEconfigure.py
tensorflow
BUILD
c
cc/saved_model
compiler

View File

@ -200,6 +200,8 @@ build:nogcp --define=no_gcp_support=true
build:nohdfs --define=no_hdfs_support=true
build:nonccl --define=no_nccl_support=true
build:stackdriver_support --define=stackdriver_support=true
build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true
@ -386,32 +388,32 @@ build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --host_platform="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
build:rbe_linux_cuda_nvcc_base --define=using_cuda_nvcc=true
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
@ -441,8 +443,8 @@ build:rbe_linux_py3 --python_path="/usr/bin/python3"
build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-manylinux2010-py3_config_python"
build:rbe_win --config=rbe
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:toolchain"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:cc-toolchain-x64_windows"
build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/tf_win_08062020:toolchain"
build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/tf_win_08062020:cc-toolchain-x64_windows"
build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"

View File

@ -1 +1 @@
3.0.0
3.1.0

View File

@ -1,7 +1,11 @@
# TensorFlow Code of Conduct
In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation.
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and our
community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, gender identity and expression, level of
experience, nationality, personal appearance, race, religion, or sexual identity
and orientation.
## Our Standards

View File

@ -4,26 +4,31 @@ https://stackoverflow.com/questions/tagged/tensorflow
If you open a GitHub issue, here is our policy:
1. It must be a bug, a feature request, or a significant problem with documentation (for small docs fixes please send a PR instead).
2. The form below must be filled out.
3. It shouldn't be a TensorBoard issue. Those go [here](https://github.com/tensorflow/tensorboard/issues).
1. It must be a bug, a feature request, or a significant problem with the
documentation (for small docs fixes please send a PR instead).
2. The form below must be filled out.
3. It shouldn't be a TensorBoard issue. Those go
[here](https://github.com/tensorflow/tensorboard/issues).
**Here's why we have that policy**: TensorFlow developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow.
------------------------
### System information
- **Have I written custom code (as opposed to using a stock example script provided in TensorFlow)**:
- **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**:
- **Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device**:
- **TensorFlow installed from (source or binary)**:
- **TensorFlow version (use command below)**:
- **Python version**:
- **Bazel version (if compiling from source)**:
- **GCC/Compiler version (if compiling from source)**:
- **CUDA/cuDNN version**:
- **GPU model and memory**:
- **Exact command to reproduce**:
- **Have I written custom code (as opposed to using a stock example script
provided in TensorFlow)**:
- **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**:
- **Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue
happens on a mobile device**:
- **TensorFlow installed from (source or binary)**:
- **TensorFlow version (use command below)**:
- **Python version**:
- **Bazel version (if compiling from source)**:
- **GCC/Compiler version (if compiling from source)**:
- **CUDA/cuDNN version**:
- **GPU model and memory**:
- **Exact command to reproduce**:
You can collect some of this information using our environment capture script:

View File

@ -90,89 +90,150 @@ Coinciding with this change, new releases of [TensorFlow's Docker images](https:
* The current TensorFlow release now **requires** [gast](https://pypi.org/project/gast/) version 0.3.3.
## Bug Fixes and Other Changes
* `tf.data`:
* Removed `autotune_algorithm` from experimental optimization options.
* TF Core:
* `tf.constant` always creates CPU tensors irrespective of the current device context.
* Eager `TensorHandles` maintain a list of mirrors for any copies to local or remote devices. This avoids any redundant copies due to op execution.
* For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer experimental and is available as simply `.ref()`.
* `pfor/vectorized_map`: Added support for vectorizing 56 more ops. Vectorizing `tf.cond` is also supported now.
* Set as much partial shape as we can infer statically within the gradient impl of the gather op.
* Gradient of `tf.while_loop` emits `StatelessWhile` op if `cond` and body functions are stateless. This allows multiple gradients while ops to run in parallel under distribution strategy.
* Speed up `GradientTape` in eager mode by auto-generating list of op inputs/outputs which are unused and hence not cached for gradient functions.
* Support `back_prop=False` in `while_v2` but mark it as deprecated.
* Improve error message when attempting to use `None` in data-dependent control flow.
* Add `RaggedTensor.numpy()`.
* Update `RaggedTensor.__getitem__` to preserve uniform dimensions & allow indexing into uniform dimensions.
* Update `tf.expand_dims` to always insert the new dimension as a non-ragged dimension.
* Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm` when `ids` is ragged.
* Allow `batch_dims==rank(indices)` in `tf.gather`.
* Add support for bfloat16 in `tf.print`.
* `tf.distribute`:
* Support `embedding_column` with variable-length input features for `MultiWorkerMirroredStrategy`.
* `tf.keras`:
* Added `experimental_aggregate_gradients` argument to `tf.keras.optimizer.Optimizer.apply_gradients`. This allows custom gradient aggregation and processing aggregated gradients in custom training loop.
* Allow `pathlib.Path` paths for loading models via Keras API.
* `tf.function`/AutoGraph:
* AutoGraph is now available in `ReplicaContext.merge_call`, `Strategy.extended.update` and `Strategy.extended.update_non_slot`.
* Experimental support for shape invariants has been enabled in `tf.function`. See the API docs for `tf.autograph.experimental.set_loop_options` for additonal info.
* AutoGraph error messages now exclude frames corresponding to APIs internal to AutoGraph.
* Improve shape inference for `tf.function` input arguments to unlock more Grappler optimizations in TensorFlow 2.x.
* Improve automatic control dependency management of resources by allowing resource reads to occur in parallel and synchronizing only on writes.
* Fix execution order of multiple stateful calls to `experimental_run_v2` in `tf.function`.
* You can now iterate over `RaggedTensors` using a for loop inside `tf.function`.
* `tf.lite`:
* Migrated the `tf.lite` C inference API out of experimental into lite/c.
* Add an option to disallow `NNAPI` CPU / partial acceleration on Android 10
* TFLite Android AARs now include the C headers and APIs are required to use TFLite from native code.
* Refactors the delegate and delegate kernel sources to allow usage in the linter.
* Limit delegated ops to actually supported ones if a device name is specified or `NNAPI` CPU Fallback is disabled.
* TFLite now supports `tf.math.reciprocal1` op by lowering to `tf.div op`.
* TFLite's unpack op now supports boolean tensor inputs.
* Microcontroller and embedded code moved from experimental to main TensorFlow Lite folder
* Check for large TFLite tensors.
* Fix GPU delegate crash with C++17.
* Add 5D support to TFLite `strided_slice`.
* Fix error in delegation of `DEPTH_TO_SPACE` to `NNAPI` causing op not to be accelerated.
* Fix segmentation fault when running a model with LSTM nodes using `NNAPI` Delegate
* Fix `NNAPI` delegate failure when an operand for Maximum/Minimum operation is a scalar.
* Fix `NNAPI` delegate failure when Axis input for reduce operation is a scalar.
* Expose option to limit the number of partitions that will be delegated to `NNAPI`.
* If a target accelerator is specified, use its feature level to determine operations to delegate instead of SDK version.
* `tf.random`:
* Various random number generation improvements:
* Add a fast path for default `random_uniform`
* `random_seed` documentation improvement.
* `RandomBinomial` broadcasts and appends the sample shape to the left rather than the right.
* Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`, `tf.random.stateless_poisson`
* `tf.random.stateless_uniform` now supports unbounded sampling of `int` types.
* Math and Linear Algebra:
* Add `tf.linalg.LinearOperatorTridiag`.
* Add `LinearOperatorBlockLowerTriangular`
* Add broadcasting support to tf.linalg.triangular_solve[#26204](https://github.com/tensorflow/tensorflow/issues/26204), tf.math.invert_permutation.
* Add `tf.math.sobol_sample` op.
* Add `tf.math.xlog1py`.
* Add `tf.math.special.{dawsn,expi,fresnel_cos,fresnel_sin,spence}`.
* Add a Modified Discrete Cosine Transform (MDCT) and its inverse to `tf.signal`.
* TPU Enhancements:
* Refactor `TpuClusterResolver` to move shared logic to a separate pip package.
* Support configuring TPU software version from cloud tpu client.
* Allowed TPU embedding weight decay factor to be multiplied by learning rate.
* XLA Support:
* Add standalone XLA AOT runtime target + relevant .cc sources to pip package.
* Add check for memory alignment to MemoryAllocation::MemoryAllocation() on 32-bit ARM. This ensures a deterministic early exit instead of a hard to debug bus error later.
* `saved_model_cli aot_compile_cpu` allows you to compile saved models to XLA header+object files and include them in your C++ programs.
* Enable `Igamma`, `Igammac` for XLA.
* Deterministic Op Functionality:
* XLA reduction emitter is deterministic when the environment variable `TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends deterministic `tf.nn.bias_add` back-prop functionality (and therefore also deterministic back-prop of bias-addition in Keras layers) to include when XLA JIT complilation is enabled.
* Fix problem, when running on a CUDA GPU and when either environment variable `TF_DETERMINSTIC_OPS` or environment variable `TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer configurations led to an exception with the message "No algorithm worked!"
* Tracing and Debugging:
* Add source, destination name to `_send` traceme to allow easier debugging.
* Add traceme event to `fastpathexecute`.
* Other:
* Fix an issue with AUC.reset_states for multi-label AUC [#35852](https://github.com/tensorflow/tensorflow/issues/35852)
* Fix the TF upgrade script to not delete files when there is a parsing error and the output mode is `in-place`.
* Move `tensorflow/core:framework/*_pyclif` rules to `tensorflow/core/framework:*_pyclif`.
* `tf.data`:
* Removed `autotune_algorithm` from experimental optimization options.
* TF Core:
* `tf.constant` always creates CPU tensors irrespective of the current
device context.
* Eager `TensorHandles` maintain a list of mirrors for any copies to local
or remote devices. This avoids any redundant copies due to op execution.
* For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer
experimental and is available as simply `.ref()`.
* `pfor/vectorized_map`: Added support for vectorizing 56 more ops.
Vectorizing `tf.cond` is also supported now.
* Set as much partial shape as we can infer statically within the gradient
impl of the gather op.
* Gradient of `tf.while_loop` emits `StatelessWhile` op if `cond` and body
functions are stateless. This allows multiple gradients while ops to run
in parallel under distribution strategy.
* Speed up `GradientTape` in eager mode by auto-generating list of op
inputs/outputs which are unused and hence not cached for gradient
functions.
* Support `back_prop=False` in `while_v2` but mark it as deprecated.
* Improve error message when attempting to use `None` in data-dependent
control flow.
* Add `RaggedTensor.numpy()`.
* Update `RaggedTensor.__getitem__` to preserve uniform dimensions & allow
indexing into uniform dimensions.
* Update `tf.expand_dims` to always insert the new dimension as a
non-ragged dimension.
* Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm`
when `ids` is ragged.
* Allow `batch_dims==rank(indices)` in `tf.gather`.
* Add support for bfloat16 in `tf.print`.
* `tf.distribute`:
* Support `embedding_column` with variable-length input features for
`MultiWorkerMirroredStrategy`.
* `tf.keras`:
* Added `experimental_aggregate_gradients` argument to
`tf.keras.optimizer.Optimizer.apply_gradients`. This allows custom
gradient aggregation and processing aggregated gradients in custom
training loop.
* Allow `pathlib.Path` paths for loading models via Keras API.
* `tf.function`/AutoGraph:
* AutoGraph is now available in `ReplicaContext.merge_call`,
`Strategy.extended.update` and `Strategy.extended.update_non_slot`.
* Experimental support for shape invariants has been enabled in
`tf.function`. See the API docs for
`tf.autograph.experimental.set_loop_options` for additonal info.
* AutoGraph error messages now exclude frames corresponding to APIs
internal to AutoGraph.
* Improve shape inference for `tf.function` input arguments to unlock more
Grappler optimizations in TensorFlow 2.x.
* Improve automatic control dependency management of resources by allowing
resource reads to occur in parallel and synchronizing only on writes.
* Fix execution order of multiple stateful calls to `experimental_run_v2`
in `tf.function`.
* You can now iterate over `RaggedTensors` using a for loop inside
`tf.function`.
* `tf.lite`:
* Migrated the `tf.lite` C inference API out of experimental into lite/c.
* Add an option to disallow `NNAPI` CPU / partial acceleration on Android
10
* TFLite Android AARs now include the C headers and APIs are required to
use TFLite from native code.
* Refactors the delegate and delegate kernel sources to allow usage in the
linter.
* Limit delegated ops to actually supported ones if a device name is
specified or `NNAPI` CPU Fallback is disabled.
* TFLite now supports `tf.math.reciprocal1` op by lowering to `tf.div op`.
* TFLite's unpack op now supports boolean tensor inputs.
* Microcontroller and embedded code moved from experimental to main
TensorFlow Lite folder
* Check for large TFLite tensors.
* Fix GPU delegate crash with C++17.
* Add 5D support to TFLite `strided_slice`.
* Fix error in delegation of `DEPTH_TO_SPACE` to `NNAPI` causing op not to
be accelerated.
* Fix segmentation fault when running a model with LSTM nodes using
`NNAPI` Delegate
* Fix `NNAPI` delegate failure when an operand for Maximum/Minimum
operation is a scalar.
* Fix `NNAPI` delegate failure when Axis input for reduce operation is a
scalar.
* Expose option to limit the number of partitions that will be delegated
to `NNAPI`.
* If a target accelerator is specified, use its feature level to determine
operations to delegate instead of SDK version.
* `tf.random`:
* Various random number generation improvements:
* Add a fast path for default `random_uniform`
* `random_seed` documentation improvement.
* `RandomBinomial` broadcasts and appends the sample shape to the left
rather than the right.
* Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`,
`tf.random.stateless_poisson`
* `tf.random.stateless_uniform` now supports unbounded sampling of `int`
types.
* Math and Linear Algebra:
* Add `tf.linalg.LinearOperatorTridiag`.
* Add `LinearOperatorBlockLowerTriangular`
* Add broadcasting support to
tf.linalg.triangular_solve[#26204](https://github.com/tensorflow/tensorflow/issues/26204),
tf.math.invert_permutation.
* Add `tf.math.sobol_sample` op.
* Add `tf.math.xlog1py`.
* Add `tf.math.special.{dawsn,expi,fresnel_cos,fresnel_sin,spence}`.
* Add a Modified Discrete Cosine Transform (MDCT) and its inverse to
`tf.signal`.
* TPU Enhancements:
* Refactor `TpuClusterResolver` to move shared logic to a separate pip
package.
* Support configuring TPU software version from cloud tpu client.
* Allowed TPU embedding weight decay factor to be multiplied by learning
rate.
* XLA Support:
* Add standalone XLA AOT runtime target + relevant .cc sources to pip
package.
* Add check for memory alignment to MemoryAllocation::MemoryAllocation()
on 32-bit ARM. This ensures a deterministic early exit instead of a hard
to debug bus error later.
* `saved_model_cli aot_compile_cpu` allows you to compile saved models to
XLA header+object files and include them in your C++ programs.
* Enable `Igamma`, `Igammac` for XLA.
* Deterministic Op Functionality:
* XLA reduction emitter is deterministic when the environment variable
`TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends
deterministic `tf.nn.bias_add` back-prop functionality (and therefore
also deterministic back-prop of bias-addition in Keras layers) to
include when XLA JIT compilation is enabled.
* Fix problem, when running on a CUDA GPU and when either environment
variable `TF_DETERMINSTIC_OPS` or environment variable
`TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer
configurations led to an exception with the message "No algorithm
worked!"
* Tracing and Debugging:
* Add source, destination name to `_send` traceme to allow easier
debugging.
* Add traceme event to `fastpathexecute`.
* Other:
* Fix an issue with AUC.reset_states for multi-label AUC
[#35852](https://github.com/tensorflow/tensorflow/issues/35852)
* Fix the TF upgrade script to not delete files when there is a parsing
error and the output mode is `in-place`.
* Move `tensorflow/core:framework/*_pyclif` rules to
`tensorflow/core/framework:*_pyclif`.
## Thanks to our Contributors

View File

@ -114,6 +114,14 @@ http_archive(
],
)
http_archive(
name = "person_detect_data",
sha256 = "170542270da256994ce24d1e357f6e84a54fdaf7d28ff2b74725a40b70b082cf",
urls = [
"https://storage.googleapis.com/download.tensorflow.org/data/tf_lite_micro_person_data_grayscale_2020_05_24.zip",
],
)
# Required for dependency @com_github_grpc_grpc
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")

View File

@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
_TF_WORKSPACE_ROOT = ''
_TF_BAZELRC = ''
_TF_CURRENT_BAZEL_VERSION = None
_TF_MIN_BAZEL_VERSION = '2.0.0'
_TF_MIN_BAZEL_VERSION = '3.1.0'
_TF_MAX_BAZEL_VERSION = '3.99.0'
NCCL_LIB_PATHS = [
@ -484,8 +484,8 @@ def check_bazel_version(min_version, max_version):
stderr = open(os.devnull, 'wb')
curr_version = run_shell(['bazel', '--version'],
allow_non_zero = True,
stderr = stderr)
allow_non_zero=True,
stderr=stderr)
if curr_version.startswith('bazel '):
curr_version = curr_version.split('bazel ')[1]
@ -1011,17 +1011,15 @@ def set_tf_cuda_compute_capabilities(environ_cp):
default_cuda_compute_capabilities = native_cuda_compute_capabilities
ask_cuda_compute_capabilities = (
'Please specify a list of comma-separated '
'CUDA compute capabilities you want to '
'build with.\nYou can find the compute '
'capability of your device at: '
'https://developer.nvidia.com/cuda-gpus.\nPlease'
' note that each additional compute '
'capability significantly increases your '
'build time and binary size, and that '
'TensorFlow only supports compute '
'capabilities >= 3.5 [Default is: %s]: ' %
default_cuda_compute_capabilities)
'Please specify a list of comma-separated CUDA compute capabilities '
'you want to build with.\nYou can find the compute capability of your '
'device at: https://developer.nvidia.com/cuda-gpus. Each capability '
'can be specified as "x.y" or "compute_xy" to include both virtual and'
' binary GPU code, or as "sm_xy" to only include the binary '
'code.\nPlease note that each additional compute capability '
'significantly increases your build time and binary size, and that '
'TensorFlow only supports compute capabilities >= 3.5 [Default is: '
'%s]: ' % default_cuda_compute_capabilities)
tf_cuda_compute_capabilities = get_from_env_or_user_or_default(
environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES',
ask_cuda_compute_capabilities, default_cuda_compute_capabilities)
@ -1033,8 +1031,23 @@ def set_tf_cuda_compute_capabilities(environ_cp):
for compute_capability in tf_cuda_compute_capabilities.split(','):
m = re.match('[0-9]+.[0-9]+', compute_capability)
if not m:
print('Invalid compute capability: %s' % compute_capability)
all_valid = False
# We now support sm_35,sm_50,sm_60,compute_70.
sm_compute_match = re.match('(sm|compute)_?([0-9]+[0-9]+)',
compute_capability)
if not sm_compute_match:
print('Invalid compute capability: %s' % compute_capability)
all_valid = False
else:
ver = int(sm_compute_match.group(2))
if ver < 30:
print(
'ERROR: TensorFlow only supports small CUDA compute'
' capabilities of sm_30 and higher. Please re-specify the list'
' of compute capabilities excluding version %s.' % ver)
all_valid = False
if ver < 35:
print('WARNING: XLA does not support CUDA compute capabilities '
'lower than sm_35. Disable XLA when running on older GPUs.')
else:
ver = float(m.group(0))
if ver < 3.0:
@ -1225,7 +1238,8 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
only, as of 2019-11-19). TensorFlow needs this flag to massively reduce
compile times, but until 16.4 is officially released, we can't depend on it.
See also https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
See also
https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
Because it's very annoying to check this manually (to check the MSVC installed
versions, you need to use the registry, and it's not clear if Bazel will be
@ -1372,7 +1386,7 @@ def main():
current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION,
_TF_MAX_BAZEL_VERSION)
except subprocess.CalledProcessError as e:
print("Error checking bazel version: ", e.output.decode('UTF-8').strip())
print('Error checking bazel version: ', e.output.decode('UTF-8').strip())
raise e
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)

View File

@ -298,6 +298,13 @@ config_setting(
visibility = ["//visibility:public"],
)
# Experimental features
config_setting(
name = "stackdriver_support",
define_values = {"stackdriver_support": "true"},
visibility = ["//visibility:public"],
)
# Crosses between platforms and file system libraries not supported on those
# platforms due to limitations in nested select() statements.
config_setting(
@ -531,6 +538,7 @@ package_group(
# Packages that use composite tensors or dispatch.
# TODO(b/154762408) Remove this package group once it's no longer needed.
# If this is modified, then copy.bara.sky must also be modified.
package_group(name = "composite_tensor_whitelist")
# Packages that use private types symbols, until they are exported.
@ -540,6 +548,11 @@ package_group(
packages = ["//learning/deepmind/tensorflow/replicator/..."],
)
# Packages that use StructuredTensors.
# TODO(b/159007891) Remove this package once StructuredTensor is exported.
# If this is modified, then copy.bara.sky must also be modified.
package_group(name = "structured_tensor_whitelist")
filegroup(
name = "intel_binary_blob",
data = if_mkl_ml(

View File

@ -216,6 +216,7 @@ tf_cuda_library(
],
visibility = [
"//tensorflow/c:__subpackages__",
"//tensorflow/compiler/mlir/tensorflow/c:__subpackages__",
],
deps = select({
"//tensorflow:android": [

View File

@ -589,14 +589,16 @@ void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; }
TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
TF_DeviceList* response = new TF_DeviceList;
status->status = session->session->ListDevices(&response->response);
if (session && session->session)
status->status = session->session->ListDevices(&response->response);
return response;
}
TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
TF_Status* status) {
TF_DeviceList* response = new TF_DeviceList;
status->status = session->session->ListDevices(&response->response);
if (session && session->session)
status->status = session->session->ListDevices(&response->response);
return response;
}
@ -1384,6 +1386,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
cpp_type v; \
status->status = \
tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \
if (!status->status.ok()) return; \
*value = static_cast<c_type>(v); \
} \
void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
@ -2178,6 +2181,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
}
return new_session;
} else {
LOG(ERROR) << status->status;
DCHECK_EQ(nullptr, session);
return nullptr;
}

View File

@ -144,6 +144,24 @@ cc_library(
],
)
cc_library(
name = "c_api_unified_internal",
hdrs = [
"c_api_unified_experimental_internal.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":c_api",
":c_api_experimental",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:tf_status",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:types",
],
)
cc_library(
name = "tensor_handle_interface",
hdrs = ["tensor_handle_interface.h"],
@ -184,7 +202,6 @@ cc_library(
":operation_interface",
":tensor_handle_interface",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -351,8 +368,41 @@ tf_cuda_cc_test(
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = [
"no_windows",
"noasan", # leaks gRPC server instances
],
deps = [
":c_api",
":c_api_experimental",
":c_api_internal",
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:function_optimization_registry",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/strings",
],
)
tf_cuda_cc_test(
name = "c_api_distributed_test",
size = "small",
srcs = [
"c_api_distributed_test.cc",
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = [
"no_windows",
"noasan", # leaks gRPC server instances
"notsan", # b/157098283
],
deps = [
":c_api",
@ -383,7 +433,10 @@ tf_cuda_cc_test(
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = ["noasan"], # leaks gRPC server instances
tags = [
"no_windows",
"noasan", # leaks gRPC server instances
],
deps = [
":c_api",
":c_api_experimental",
@ -514,6 +567,7 @@ tf_cuda_cc_test(
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",

View File

@ -1397,23 +1397,17 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
return;
}
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->AddFunctionDef(function_def);
status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
}
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->AddFunctionDef(function->fdef);
status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function->fdef);
}
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->RemoveFunction(name);
status->status = tensorflow::unwrap(ctx)->RemoveFunction(name);
}
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
@ -1479,14 +1473,10 @@ const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
}
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m;
tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(op));
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
for (const auto& attribute : m) {
destination->Set(attribute.first, attribute.second);
}
destination->CopyAttributes(*tensorflow::unwrap(attrs));
}
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,

View File

@ -30,26 +30,6 @@ namespace {
using ::tensorflow::string;
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost", ":", port)});
}
return server_def;
}
tensorflow::ServerDef GetServerDef(int num_tasks) {
return GetServerDef("localhost", num_tasks);
}
void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
int port = tensorflow::testing::PickUnusedPortOrDie();

View File

@ -0,0 +1,506 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace {
using ::tensorflow::string;
// Add the values of three variables on three different tasks.
string AddVariablesFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'AddVariablesFunction'"
" input_arg {"
" name: 'var'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'sum'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read1'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read2'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add1'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read1:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add2'"
" op: 'Add'"
" input: 'add1:z:0'"
" input: 'read2:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'sum'"
" value: 'add2:z:0'"
" }",
&def));
return def.SerializeAsString();
}
void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(op, var_handle, status);
TFE_TensorHandle* is_initialized[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(op, &is_initialized[0], &num_retvals, status);
CHECK_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
bool initialized = false;
memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
EXPECT_EQ(initialized, true);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(is_initialized[0]);
TFE_DeleteOp(op);
delete status;
}
void TestFunctionWithPackedInput(const bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
// Create one variable per task.
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
// Add a sync point in order to make sure that variables have been initialized
// before the function execution starts.
// TODO(b/155789951): Remove once b/155789951 is fixed.
VarIsInitialized(ctx, h1);
VarIsInitialized(ctx, h2);
// Pack 3 variable handles into one TFE_TensorHandle.
int num_replicas = 3;
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
TFE_TensorHandle* packed_handle =
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
const string composite_device_name =
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
composite_device_name);
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
composite_device_name);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// Register and run a function which returns the sum of 3 variables.
const string function_def = AddVariablesFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, packed_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
if (remote) {
TFE_OpSetDevice(func, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(packed_handle);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(sum, 6.0);
TFE_DeleteTensorHandle(h0);
TFE_DeleteTensorHandle(h1);
TFE_DeleteTensorHandle(h2);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, TestLocalFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/false);
}
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/true);
}
string VariableAddFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'VariableAddFunction'"
" input_arg {"
" name: 'var0'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'var0_value'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read0:value:0'"
" device: '/job:localhost/task:1/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'identity'"
" op: 'Identity'"
" input: 'add:z:0'"
" device: '/job:localhost/task:0/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'var0_value'"
" value: 'identity:output:0'"
" }",
&def));
return def.SerializeAsString();
}
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
public:
FunctionErrorInjectionPass(string error_node, string error_device)
: error_node_(error_node), error_device_(error_device) {}
tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
const tensorflow::ConfigProto& config_proto,
std::unique_ptr<tensorflow::Graph>* graph,
tensorflow::FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) override {
// Inject failure to function instantiation if finding a node that contains
// the given node name (error_node_) and requested device (error_device_).
for (const auto node : graph->get()->nodes()) {
if (node->name().find(error_node_) != string::npos &&
node->requested_device() == error_device_) {
return tensorflow::errors::Internal("Injected graph pass error.");
}
}
return tensorflow::Status::OK();
}
private:
const string error_node_;
const string error_device_;
};
void TestDistributedFunctionCancellation(bool inject_error) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
if (inject_error) {
// Inject a function optimization pass failure when it sees the 'read0' op
// having a requested device `dev2_name`. During execution:
// * task:0 processes the main function `VariableAddFunction` and places
// the read0 op on task:2
// * task:0 partitions the main function with a subgraph containing read0
// sent to task:2
// * task:2 graph pass reports an error when it sees read0 with dev2_name
tensorflow::function_optimization_registration::
FunctionOptimizationPassRegistration register_test_pass(
std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name));
}
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle, nullptr);
const string function_def = VariableAddFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, var_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
if (inject_error) {
ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
} else {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
ASSERT_EQ(sum, 4.0);
}
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(var_handle);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, DistributedFunctionNoError) {
TestDistributedFunctionCancellation(false);
}
TEST(CAPI, DistributedFunctionCancelledOnError) {
TestDistributedFunctionCancellation(true);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Use large matrices so that RPCs don't return before we get a chance
// to call TFE_DeleteContext.
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* h1_task1 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h0_task1);
TFE_DeleteTensorHandle(h1_task1);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
TFE_DeleteContext(ctx);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
}
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
}
} // namespace

View File

@ -35,26 +35,6 @@ namespace {
using ::tensorflow::string;
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost:", port)});
}
return server_def;
}
tensorflow::ServerDef GetServerDef(int num_tasks) {
return GetServerDef("localhost", num_tasks);
}
void TestRemoteExecute(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
@ -356,472 +336,4 @@ TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
/*heavy_load_on_streaming_rpc=*/true);
}
// Add the values of three variables on three different tasks.
string AddVariablesFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'AddVariablesFunction'"
" input_arg {"
" name: 'var'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'sum'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read1'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read2'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add1'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read1:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add2'"
" op: 'Add'"
" input: 'add1:z:0'"
" input: 'read2:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'sum'"
" value: 'add2:z:0'"
" }",
&def));
return def.SerializeAsString();
}
void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(op, var_handle, status);
TFE_TensorHandle* is_initialized[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(op, &is_initialized[0], &num_retvals, status);
CHECK_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
bool initialized = false;
memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
EXPECT_EQ(initialized, true);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(is_initialized[0]);
TFE_DeleteOp(op);
delete status;
}
void TestFunctionWithPackedInput(const bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
// Create one variable per task.
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
// Add a sync point in order to make sure that variables have been initialized
// before the function execution starts.
// TODO(b/155789951): Remove once b/155789951 is fixed.
VarIsInitialized(ctx, h1);
VarIsInitialized(ctx, h2);
// Pack 3 variable handles into one TFE_TensorHandle.
int num_replicas = 3;
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
TFE_TensorHandle* packed_handle =
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
const string composite_device_name =
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
composite_device_name);
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
composite_device_name);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// Register and run a function which returns the sum of 3 variables.
const string function_def = AddVariablesFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, packed_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
if (remote) {
TFE_OpSetDevice(func, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(packed_handle);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(sum, 6.0);
TFE_DeleteTensorHandle(h0);
TFE_DeleteTensorHandle(h1);
TFE_DeleteTensorHandle(h2);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, TestLocalFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/false);
}
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/true);
}
string VariableAddFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'VariableAddFunction'"
" input_arg {"
" name: 'var0'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'var0_value'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read0:value:0'"
" device: '/job:localhost/task:1/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'identity'"
" op: 'Identity'"
" input: 'add:z:0'"
" device: '/job:localhost/task:0/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'var0_value'"
" value: 'identity:output:0'"
" }",
&def));
return def.SerializeAsString();
}
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
public:
FunctionErrorInjectionPass(string error_node, string error_device)
: error_node_(error_node), error_device_(error_device) {}
tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
const tensorflow::ConfigProto& config_proto,
std::unique_ptr<tensorflow::Graph>* graph,
tensorflow::FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) override {
// Inject failure to function instantiation if finding a node that contains
// the given node name (error_node_) and requested device (error_device_).
for (const auto node : graph->get()->nodes()) {
if (node->name().find(error_node_) != string::npos &&
node->requested_device() == error_device_) {
return tensorflow::errors::Internal("Injected graph pass error.");
}
}
return tensorflow::Status::OK();
}
private:
const string error_node_;
const string error_device_;
};
void TestDistributedFunctionCancellation(bool inject_error) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
if (inject_error) {
// Inject a function optimization pass failure when it sees the 'read0' op
// having a requested device `dev2_name`. During execution:
// * task:0 processes the main function `VariableAddFunction` and places
// the read0 op on task:2
// * task:0 partitions the main function with a subgraph containing read0
// sent to task:2
// * task:2 graph pass reports an error when it sees read0 with dev2_name
tensorflow::function_optimization_registration::
FunctionOptimizationPassRegistration register_test_pass(
std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name));
}
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle, nullptr);
const string function_def = VariableAddFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, var_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
if (inject_error) {
ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
} else {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
ASSERT_EQ(sum, 4.0);
}
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(var_handle);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, DistributedFunctionNoError) {
TestDistributedFunctionCancellation(false);
}
TEST(CAPI, DistributedFunctionCancelledOnError) {
TestDistributedFunctionCancellation(true);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Use large matrices so that RPCs don't return before we get a chance
// to call TFE_DeleteContext.
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* h1_task1 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h0_task1);
TFE_DeleteTensorHandle(h1_task1);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
TFE_DeleteContext(ctx);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
}
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
}
} // namespace

View File

@ -18,7 +18,9 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
using tensorflow::string;
@ -296,3 +298,23 @@ bool GetDeviceName(TFE_Context* ctx, string* device_name,
TF_DeleteDeviceList(devices);
return false;
}
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost:", port)});
}
return server_def;
}
tensorflow::ServerDef GetServerDef(int num_tasks) {
return GetServerDef("localhost", num_tasks);
}

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
// Return a tensor handle containing a float scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value);
@ -72,4 +73,11 @@ TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name,
const char* device_type);
// Create a ServerDef with the given `job_name` and add `num_tasks` tasks in it.
tensorflow::ServerDef GetServerDef(const tensorflow::string& job_name,
int num_tasks);
// Create a ServerDef with job name "localhost" and add `num_tasks` tasks in it.
tensorflow::ServerDef GetServerDef(int num_tasks);
#endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_

View File

@ -58,7 +58,7 @@ T* dyncast(S source) {
// GraphContext and vice-versa).
class AbstractTensor {
protected:
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor };
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
public:
@ -101,7 +101,7 @@ class AbstractFunction {
// on a given context, with the same or different input tensors.
class AbstractOp {
protected:
enum AbstractOpKind { kGraphOp, kEagerOp };
enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp };
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
public:
@ -129,7 +129,7 @@ class AbstractOp {
// eager implementation or to a graph implementation.
struct ExecutionContext {
protected:
enum ExecutionContextKind { kGraphContext, kEagerContext };
enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext };
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
public:

View File

@ -477,7 +477,8 @@ TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
TF_DeleteExecutionContext(eager_execution_ctx);
}
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef"));
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Values("graphdef", "mlir"));
} // namespace
} // namespace tensorflow

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
@ -84,11 +84,10 @@ class AbstractContextInterface {
// Create an operation to perform op execution
virtual AbstractOperationInterface* CreateOperation() = 0;
// Load a SavedModelAPI object from the given directory and tags
virtual std::unique_ptr<SavedModelAPI> LoadSavedModelAPI(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
tensorflow::Status* status) = 0;
// Returns whether the runtime is backed by TFRT or the legacy TF Eager
// Runtime. This is necessary to decouple runtime-dependent
// code that is layered on top of the runtime.
virtual bool UsesTFRT() = 0;
// List attributes of available devices
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
@ -104,6 +103,14 @@ class AbstractContextInterface {
// Block until all pending nodes are finished.
virtual Status AsyncWait() = 0;
// Add a function (serialized FunctionDef protocol buffer) so that it can
// be executed as an op. Return error if the function with the same name
// already exists.
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
// Remove a function. 'func' argument is the name of a previously added
// FunctionDef. The name is in fdef.signature.name.
virtual Status RemoveFunction(const string& func) = 0;
protected:
virtual ~AbstractContextInterface() {}
};

View File

@ -12,39 +12,98 @@ package(
# need a second rule that omits .cc files, in
# tensorflow/python:_pywrap_parallel_device.
filegroup(
name = "headers",
name = "lib_headers",
srcs = ["parallel_device_lib.h"],
)
filegroup(
name = "lib_sources",
srcs = ["parallel_device_lib.cc"],
)
filegroup(
name = "device_headers",
srcs = ["parallel_device.h"],
)
filegroup(
name = "device_sources",
srcs = ["parallel_device.cc"],
)
filegroup(
name = "headers",
srcs = [
":device_headers",
":lib_headers",
],
visibility = ["//tensorflow/python:__pkg__"],
)
filegroup(
name = "sources",
srcs = ["parallel_device.cc"],
srcs = [
":device_sources",
":lib_sources",
],
visibility = ["//tensorflow/python:__pkg__"],
)
cc_library(
name = "parallel_device",
srcs = [":sources"],
hdrs = [":headers"],
srcs = [":device_sources"],
hdrs = [":device_headers"],
visibility = ["//tensorflow:internal"],
deps = [
":parallel_device_lib",
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
],
)
cc_library(
name = "parallel_device_lib",
srcs = [":lib_sources"],
hdrs = [":lib_headers"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
],
)
cc_library(
name = "parallel_device_testlib",
testonly = 1,
srcs = ["parallel_device_testlib.cc"],
hdrs = ["parallel_device_testlib.h"],
deps = [
":parallel_device",
":parallel_device_ops",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "parallel_device_test",
srcs = ["parallel_device_test.cc"],
deps = [
":parallel_device",
":parallel_device_ops",
":parallel_device_testlib",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
@ -55,6 +114,27 @@ tf_cc_test(
],
)
tf_cc_test(
name = "parallel_device_remote_test",
srcs = ["parallel_device_remote_test.cc"],
# TODO(b/136478427): Enable global heap checking when servers shut down
# cleanly.
args = ["--heap_check=local"],
deps = [
":parallel_device",
":parallel_device_ops",
":parallel_device_testlib",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
],
)
# Note: ParallelDevice-specific ops are experimental and not currently linked in
# to TensorFlow by default, just used in a few tests.
filegroup(

View File

@ -23,25 +23,13 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace eager {
namespace parallel_device {
namespace {
// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) const {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
class OpDeleter {
public:
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
@ -49,224 +37,46 @@ class OpDeleter {
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
class ExecutorDeleter {
public:
void operator()(TFE_Executor* to_delete) const {
TFE_DeleteExecutor(to_delete);
}
};
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class ParallelTensor;
using MaybeParallelTensorOwned =
absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
using MaybeParallelTensorUnowned =
absl::variant<ParallelTensor*, TFE_TensorHandle*>;
// Creates a vector of `count` new executors (threads).
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
std::vector<ExecutorPtr> executors;
executors.reserve(count);
for (int i = 0; i < count; ++i) {
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
}
return executors;
}
// A representation of the custom device passed in and out of the TFE custom
// device APIs, providing context about the parallel device to
// ParallelDeviceExecute.
class ParallelDevice {
// A ParallelDevice on its own is not registered with a TFE_Context, and so has
// no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
// name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
// placed on the parallel device.
class NamedParallelDevice {
public:
ParallelDevice(const std::string& name,
const std::vector<std::string>& devices);
// Helper to copy a tensor handle from another device once for each component
// of the ParallelDevice.
//
// Sets a bad status and returns a nullptr if `tensor` is already on the
// ParallelDevice, or if the individual copies fail.
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
TF_Status* status) const;
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
// implicitly-mirrored tensors on other devices). Wraps the resulting
// per-device and per-output TFE_TensorHandles into one ParallelTensor per
// output of the original operation.
//
// `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or
// un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput
// requires non-parallel tensors, and TPUReplicatedOutput requires a parallel
// tensor, but other operations will implicitly broadcast non-parallel input
// tensors across the ParallelDevice's component devices.
//
// Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput,
// pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput
// causes `Execute` to return non-parallel tensors.
//
// Attributes are forwarded to executed operations unmodified.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK.
absl::optional<std::vector<MaybeParallelTensorOwned>> Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
// Implements the parallel case for `Execute`, where all of the outputs of the
// operation are ParallelTensors, and all inputs are either ParallelTensors or
// should be implicitly broadcast. This means the operation is not
// TPUReplicatedInput or TPUReplicatedOutput.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
// if sanity checks on dtypes/metadata fail.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ExecuteParallelOperation(TFE_Context* context,
std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name,
const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
const std::string& device_name() const { return device_name_; }
NamedParallelDevice(const std::string& name,
std::unique_ptr<ParallelDevice> parallel_device)
: device_name_(name), parallel_device_(std::move(parallel_device)) {}
const std::string& name() const { return device_name_; }
const ParallelDevice& device() const { return *parallel_device_; }
private:
// The name of the parallel device
// (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0")
const std::string device_name_;
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
const std::vector<std::string> underlying_devices_;
// A sequence of TFE_Executors, one per device, for executing operations in
// parallel.
const std::vector<ExecutorPtr> executors_;
std::string device_name_;
std::unique_ptr<ParallelDevice> parallel_device_;
};
// The internal representation of a TFE_TensorHandle placed on a
// ParallelDevice. Contains a tuple of tensors, one on each of the
// `underlying_devices_` of the ParallelDevice.
class ParallelTensor {
public:
// Construct a ParallelTensor from TensorHandles placed on the component
// devices of a ParallelDevice.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status);
// Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it.
static TensorHandlePtr AsTensorHandle(TFE_Context* context,
std::unique_ptr<ParallelTensor> t,
TF_Status* status);
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
private:
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors,
std::vector<int64_t> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(std::move(shape)),
dtype_(dtype) {}
const ParallelDevice& device_;
const std::vector<TensorHandlePtr> tensors_;
const std::vector<int64_t> shape_;
const TF_DataType dtype_;
};
ParallelDevice::ParallelDevice(const std::string& name,
const std::vector<std::string>& devices)
: device_name_(name),
underlying_devices_(devices),
executors_(MakeExecutors(underlying_devices_.size())) {}
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
const char* current_device = TFE_TensorHandleDeviceName(tensor, status);
if (device_name_ == current_device) {
std::string message(absl::StrCat(
"Tried to copy a TensorHandle to its existing device: ", device_name_));
TF_SetStatus(status, TF_INTERNAL, message.c_str());
return nullptr;
}
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (const std::string& underlying_device_name : underlying_devices_) {
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, underlying_device_name.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(t);
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
},
nullptr),
TF_DeleteTensor);
// TODO(allenl): Here and when executing regular operations, we could hold
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
// device names repeatedly.
OpPtr const_op(TFE_NewOp(context, "Const", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(device_handle);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
const ParallelDevice& parallel_device,
const std::string& parallel_device_name, TFE_Context* context,
std::vector<MaybeParallelTensorUnowned> inputs, const char* operation_name,
const TFE_OpAttrs* attributes, int expected_max_outputs,
TF_Status* status) {
absl::optional<std::vector<MaybeParallelTensorOwned>> result;
// TODO(allenl): We should remove "TPU" from these op names at the very least,
// or consider other ways of packing/unpacking parallel tensors.
if (operation_name == std::string("TPUReplicatedInput")) {
// Special-cased operation for packing per-device tensors into one parallel
// tensor.
if (inputs.size() != underlying_devices_.size()) {
if (inputs.size() != parallel_device.num_underlying_devices()) {
std::string message(absl::StrCat(
"The parallel device ", device_name_, " expected ",
underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ",
inputs.size()));
"The parallel device ", parallel_device_name, " expected ",
parallel_device.num_underlying_devices(),
" inputs to TPUReplicatedInput, but got ", inputs.size()));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
@ -289,7 +99,7 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
parallel_device, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
@ -300,10 +110,10 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_OpAddAttrs(op.get(), attributes);
int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
if (TF_GetCode(status) != TF_OK) return result;
if (expected_outputs != underlying_devices_.size()) {
if (expected_outputs != parallel_device.num_underlying_devices()) {
std::string message(absl::StrCat(
"The parallel device ", device_name_, " expected ",
underlying_devices_.size(),
"The parallel device ", parallel_device_name, " expected ",
parallel_device.num_underlying_devices(),
" outputs for TPUReplicatedOutput, but got ", expected_outputs));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
@ -329,15 +139,38 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
} else if (operation_name == std::string("DeviceID")) {
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(DeviceIDs(context, status));
result_content.push_back(parallel_device.DeviceIDs(context, status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
}
std::vector<ParallelTensor*> parallel_inputs;
std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
parallel_inputs.reserve(inputs.size());
implicitly_broadcast_tensors.reserve(inputs.size()); // not tight
for (const auto& input : inputs) {
if (absl::holds_alternative<TFE_TensorHandle*>(input)) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
std::unique_ptr<ParallelTensor> parallel_tensor(
parallel_device.CopyToParallelDevice(
context, absl::get<TFE_TensorHandle*>(input), status));
if (TF_GetCode(status) != TF_OK) return result;
parallel_inputs.push_back(parallel_tensor.get());
implicitly_broadcast_tensors.emplace_back(std::move(parallel_tensor));
} else {
parallel_inputs.push_back(absl::get<ParallelTensor*>(input));
}
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results(
ExecuteParallelOperation(context, std::move(inputs), operation_name,
attributes, expected_max_outputs, status));
parallel_device.Execute(context, parallel_inputs, operation_name,
attributes, expected_max_outputs, status));
if (!maybe_parallel_results.has_value()) return result;
std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
std::move(maybe_parallel_results.value()));
@ -351,144 +184,6 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::ExecuteParallelOperation(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
per_device_output_tensors.reserve(underlying_devices_.size());
// TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
// setting the thread-local executor like this.
TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
auto reset_executor = gtl::MakeCleanup([context, previous_executor]() {
TFE_ContextSetExecutorForThread(context, previous_executor);
TFE_DeleteExecutor(previous_executor);
});
int first_op_output_count;
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// Note that the `reset_executor` cleanup sets the thread's executor back to
// the value before this function ran.
TFE_ContextSetExecutorForThread(context, executor);
OpPtr op(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return result;
TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
status);
TFE_OpAddAttrs(op.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
TFE_OpAddInput(op.get(),
absl::get<TFE_TensorHandle*>(inputs[input_index]),
status);
if (TF_GetCode(status) != TF_OK) return result;
} else {
// Parallel tensors are divided between operations by device.
TFE_OpAddInput(op.get(),
absl::get<ParallelTensor*>(inputs[input_index])
->tensor(device_index),
status);
if (TF_GetCode(status) != TF_OK) return result;
}
}
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
// For nested devices, the inner device sees the async executor we've
// set. Inner parallel devices will just overwrite this with their own and
// then set it back to ours before returning. This means parallel devices
// which consist of several aliased parallel devices would hypothetically
// deadlock if the outer parallel device ran one collective with a group
// size equal to the total number of aliased physical devices. Currently
// physical devices cannot participate in a single collective reduction
// multiple times, so this would fail earlier.
//
// TODO(allenl): Keep a map from outer executor to list of inner executors
// rather than a single list of executors so aliased nested parallel devices
// don't re-use an executor.
TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
if (device_index == 0) {
first_op_output_count = real_num_outputs;
} else {
if (real_num_outputs != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
if (TF_GetCode(status) != TF_OK) return result;
std::vector<TensorHandlePtr> this_outputs;
this_outputs.reserve(real_num_outputs);
for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
this_outputs.emplace_back(op_outputs[output_num]);
}
per_device_output_tensors.push_back(std::move(this_outputs));
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
per_device_outputs.reserve(first_op_output_count);
for (int i = 0; i < first_op_output_count; ++i) {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int j = 0; j < underlying_devices_.size(); ++j) {
components.push_back(std::move(per_device_output_tensors[j][i]));
}
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(per_device_outputs));
return result;
}
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
std::vector<int64_t> shape(
TFE_TensorHandleNumDims(components[0].get(), status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
// Verify that the TensorHandle's shape and dtype match all of the component
// shapes and dtypes.
for (TensorHandlePtr& component : components) {
for (int i = 0; i < shape.size(); ++i) {
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (tensor_dim != shape[i]) {
// TODO(allenl): Allow shapes to differ.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"Components of a ParallelTensor must currently all have "
"the same shape");
return nullptr;
}
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
}
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
parallel_device, std::move(components), std::move(shape), dtype));
}
// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
// reference counts drop to zero.
@ -496,17 +191,18 @@ void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<ParallelTensor*>(data);
}
TensorHandlePtr ParallelTensor::AsTensorHandle(
TFE_Context* context, std::unique_ptr<ParallelTensor> t,
TF_Status* status) {
TensorHandlePtr ParallelTensorToTensorHandle(
const std::string& parallel_device_name, TFE_Context* context,
std::unique_ptr<ParallelTensor> t, TF_Status* status) {
// The resulting TensorHandle owns an opaque pointer to "device memory", which
// for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
// deleted, it will call ParallelTensorDeallocator to free the struct.
ParallelTensor* t_released = t.release();
const std::vector<int64_t>& shape(t_released->shape());
return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory(
context, t_released->device_.device_name().c_str(), t_released->dtype_,
t_released->shape_.data(), t_released->shape_.size(), t_released, 1,
&ParallelTensorDeallocator, nullptr, status));
context, parallel_device_name.c_str(), t_released->dtype(), shape.data(),
shape.size(), t_released, 1, &ParallelTensorDeallocator, nullptr,
status));
}
// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
@ -522,12 +218,14 @@ TensorHandlePtr ParallelTensor::AsTensorHandle(
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
NamedParallelDevice* named_device =
reinterpret_cast<NamedParallelDevice*>(device_info);
const ParallelDevice& dev = named_device->device();
std::unique_ptr<ParallelTensor> parallel_tensor(
dev->CopyToParallelDevice(context, tensor, status));
dev.CopyToParallelDevice(context, tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor),
status)
return ParallelTensorToTensorHandle(named_device->name(), context,
std::move(parallel_tensor), status)
.release();
}
@ -561,14 +259,15 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* status,
void* device_info) {
ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
NamedParallelDevice* named_device =
reinterpret_cast<NamedParallelDevice*>(device_info);
std::vector<MaybeParallelTensorUnowned> typed_inputs;
typed_inputs.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
const char* tensor_handle_device =
TFE_TensorHandleDeviceName(inputs[i], status);
if (TF_GetCode(status) != TF_OK) return;
if (dev->device_name() == tensor_handle_device) {
if (named_device->name() == tensor_handle_device) {
// We assume that any tensors already placed on this device are
// ParallelTensors.
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
@ -580,8 +279,9 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
}
absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
dev->Execute(context, std::move(typed_inputs), operation_name, attributes,
*num_outputs, status));
ExecuteWithSpecialOps(named_device->device(), named_device->name(),
context, std::move(typed_inputs), operation_name,
attributes, *num_outputs, status));
if (TF_GetCode(status) != TF_OK) return;
if (!maybe_typed_outputs.has_value()) {
TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
@ -602,8 +302,8 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
} else {
outputs[i] = ParallelTensor::AsTensorHandle(
context,
outputs[i] = ParallelTensorToTensorHandle(
named_device->name(), context,
std::move(absl::get<std::unique_ptr<ParallelTensor>>(
typed_output)),
status)
@ -620,7 +320,7 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void DeleteParallelDevice(void* device_info) {
delete reinterpret_cast<ParallelDevice*>(device_info);
delete reinterpret_cast<NamedParallelDevice*>(device_info);
}
} // namespace
@ -639,8 +339,10 @@ void AllocateParallelDevice(const char* device_name,
++device_index) {
underlying_devices_vector.push_back(underlying_devices[device_index]);
}
*device_info = new ParallelDevice(device_name, underlying_devices_vector);
std::unique_ptr<ParallelDevice> parallel_device(
new ParallelDevice(underlying_devices_vector));
*device_info =
new NamedParallelDevice{device_name, std::move(parallel_device)};
}
} // namespace eager
} // namespace parallel_device
} // namespace tensorflow

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
namespace eager {
namespace parallel_device {
// Allocate a parallel device named `device_name` which forwards operations to
// `underlying_devices`, maintaining "parallel tensors" with components placed
@ -59,7 +59,7 @@ void AllocateParallelDevice(const char* device_name,
int num_underlying_devices,
TFE_CustomDevice* device, void** device_info);
} // namespace eager
} // namespace parallel_device
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_

View File

@ -0,0 +1,376 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace parallel_device {
namespace {
class OpDeleter {
public:
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
class StatusDeleter {
public:
void operator()(TF_Status* to_delete) const { TF_DeleteStatus(to_delete); }
};
using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;
} // namespace
// Allows a single op at a time to be launched without blocking.
//
// DeviceThread itself is thread-safe, in that StartExecute will block if there
// is a pending execution. Since StartExecute is equivalent to grabbing a lock,
// multiple DeviceThreads should always be accessed in the same order to avoid
// deadlocks.
class DeviceThread {
public:
// Starts a background thread waiting for `StartExecute`.
explicit DeviceThread(const std::string& device)
: status_(TF_NewStatus()),
device_(device),
op_(nullptr),
thread_(tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "parallel_device_execute",
std::bind(&DeviceThread::Run, this))) {}
~DeviceThread();
// Requests that the worker thread execute the specified operation. Blocks
// until the previously pending operation (a StartExecute without a Join) has
// finished, if any.
void StartExecute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes, int expected_max_outputs);
// Block until the previous `StartExecute` operation has executed. Forwards
// the status from `TFE_Execute` and returns outputs if the status is OK.
std::vector<TensorHandlePtr> Join(TF_Status* status);
private:
void Run();
void Execute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes, int expected_max_outputs,
std::vector<TensorHandlePtr>* outputs, TF_Status* status) const
TF_EXCLUSIVE_LOCKS_REQUIRED(execution_mutex_);
enum class ExecutionState {
kReadyToExecute,
kHasResult,
kIdle,
kShuttingDown,
};
tensorflow::mutex execution_mutex_;
ExecutionState execution_state_ TF_GUARDED_BY(execution_mutex_) =
ExecutionState::kIdle;
// Tells the worker thread that there is new work.
tensorflow::condition_variable start_execute_;
// The worker thread notifies that work has finished.
tensorflow::condition_variable finished_execute_;
// Notifies a StartExecute that the previous Join has finished.
tensorflow::condition_variable finished_join_;
// Temporary state between `StartExecute` and `Join`.
// Inputs
TFE_Context* context_ TF_GUARDED_BY(execution_mutex_);
const char* operation_name_ TF_GUARDED_BY(execution_mutex_);
std::vector<TFE_TensorHandle*> op_inputs_ TF_GUARDED_BY(execution_mutex_);
const TFE_OpAttrs* attributes_ TF_GUARDED_BY(execution_mutex_);
int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
// Outputs
std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
const std::string device_;
mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
std::unique_ptr<Thread> thread_;
};
DeviceThread::~DeviceThread() {
{
tensorflow::mutex_lock l(execution_mutex_);
execution_state_ = ExecutionState::kShuttingDown;
}
start_execute_.notify_one();
}
void DeviceThread::Run() {
while (true) {
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ == ExecutionState::kIdle ||
execution_state_ == ExecutionState::kHasResult) {
start_execute_.wait(l);
}
if (execution_state_ == ExecutionState::kShuttingDown) {
return;
} else if (execution_state_ == ExecutionState::kReadyToExecute) {
// op_outputs_ may have been std::moved
op_outputs_ = std::vector<TensorHandlePtr>();
Execute(context_, operation_name_, std::move(op_inputs_), attributes_,
expected_max_outputs_, &op_outputs_, status_.get());
execution_state_ = ExecutionState::kHasResult;
}
}
finished_execute_.notify_one();
}
}
void DeviceThread::StartExecute(TFE_Context* context,
const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes,
int expected_max_outputs) {
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ != ExecutionState::kIdle) {
// If there's already a pending execution, wait until Join finishes before
// starting on the next operation.
finished_join_.wait(l);
}
context_ = context;
operation_name_ = operation_name;
op_inputs_ = inputs;
attributes_ = attributes;
expected_max_outputs_ = expected_max_outputs;
execution_state_ = ExecutionState::kReadyToExecute;
}
start_execute_.notify_one();
}
std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
std::vector<TensorHandlePtr> result;
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ != ExecutionState::kHasResult) {
finished_execute_.wait(l);
}
if (TF_GetCode(status_.get()) != TF_OK) {
TF_SetStatus(status, TF_GetCode(status_.get()),
TF_Message(status_.get()));
}
execution_state_ = ExecutionState::kIdle;
result = std::move(op_outputs_);
}
finished_join_.notify_one();
return result;
}
void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes,
int expected_max_outputs,
std::vector<TensorHandlePtr>* outputs,
TF_Status* status) const {
if (op_ == nullptr) {
op_.reset(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op_.get(), device_.c_str(), status);
if (TF_GetCode(status) != TF_OK) return;
} else {
TFE_OpReset(op_.get(), operation_name, device_.c_str(), status);
if (TF_GetCode(status) != TF_OK) return;
}
TFE_OpAddAttrs(op_.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
TFE_OpAddInput(op_.get(), inputs[input_index], status);
if (TF_GetCode(status) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> unwrapped_results(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
if (TF_GetCode(status) != TF_OK) return;
TFE_Execute(op_.get(), unwrapped_results.data(), &real_num_outputs, status);
if (TF_GetCode(status) != TF_OK) return;
unwrapped_results.resize(real_num_outputs);
outputs->reserve(real_num_outputs);
for (TFE_TensorHandle* unwrapped_result : unwrapped_results) {
outputs->emplace_back(unwrapped_result);
}
}
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
: underlying_devices_(devices) {
device_threads_.reserve(devices.size());
for (int device_index = 0; device_index < devices.size(); ++device_index) {
device_threads_.emplace_back(
new DeviceThread(devices[device_index].c_str()));
}
}
// Necessary for a unique_ptr to a forward-declared type.
ParallelDevice::~ParallelDevice() = default;
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (const std::string& underlying_device_name : underlying_devices_) {
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, underlying_device_name.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(t);
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
},
nullptr),
TF_DeleteTensor);
// TODO(allenl): Here and when executing regular operations, we could hold
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
// device names repeatedly.
OpPtr const_op(TFE_NewOp(context, "Const", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(device_handle);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
const std::vector<ParallelTensor*>& inputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int expected_max_outputs,
TF_Status* status) const {
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
per_device_output_tensors.reserve(underlying_devices_.size());
int first_op_output_count = 0;
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
DeviceThread* device_thread = device_threads_[device_index].get();
std::vector<TFE_TensorHandle*> device_inputs;
device_inputs.reserve(device_inputs.size());
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
// Parallel tensors are divided between operations by device.
device_inputs.push_back(inputs[input_index]->tensor(device_index));
}
device_thread->StartExecute(context, operation_name,
std::move(device_inputs), attributes,
expected_max_outputs);
}
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
DeviceThread* device_thread = device_threads_[device_index].get();
per_device_output_tensors.push_back(device_thread->Join(status));
if (TF_GetCode(status) != TF_OK) return result;
if (device_index == 0) {
first_op_output_count = per_device_output_tensors.rbegin()->size();
} else {
if (per_device_output_tensors.rbegin()->size() != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
per_device_outputs.reserve(first_op_output_count);
for (int i = 0; i < first_op_output_count; ++i) {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int j = 0; j < underlying_devices_.size(); ++j) {
components.push_back(std::move(per_device_output_tensors[j][i]));
}
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(per_device_outputs));
return result;
}
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
std::vector<int64_t> shape(
TFE_TensorHandleNumDims(components[0].get(), status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
// Verify that the TensorHandle's shape and dtype match all of the component
// shapes and dtypes.
for (TensorHandlePtr& component : components) {
for (int i = 0; i < shape.size(); ++i) {
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (tensor_dim != shape[i]) {
// TODO(allenl): Allow shapes to differ.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"Components of a ParallelTensor must currently all have "
"the same shape");
return nullptr;
}
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
}
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
parallel_device, std::move(components), std::move(shape), dtype));
}
} // namespace parallel_device
} // namespace tensorflow

View File

@ -0,0 +1,141 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#include <memory>
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
namespace parallel_device {
// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) const {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
class ParallelTensor;
class DeviceThread;
// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.
class ParallelDevice {
public:
explicit ParallelDevice(const std::vector<std::string>& devices);
~ParallelDevice();
// Helper to copy a tensor handle from another device once for each component
// of the ParallelDevice.
//
// Sets a bad status and returns a nullptr if `tensor` is already on the
// ParallelDevice, or if the individual copies fail.
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
TF_Status* status) const;
// The number of devices operations run on.
size_t num_underlying_devices() const { return underlying_devices_.size(); }
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors. Wraps the
// resulting per-device and per-output TFE_TensorHandles into one
// ParallelTensor per output of the original operation.
//
// Attributes are forwarded to executed operations unmodified.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
// if sanity checks on dtypes/metadata fail.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
private:
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
const std::vector<std::string> underlying_devices_;
// A sequence of thread wrappers, one per device, for executing operations in
// parallel.
//
// Conceptually this is a thread pool with one thread per device. It requires
// less synchronization than a thread pool would for this task, since Execute
// acquires each thread in order (and so only one Execute will schedule
// blocking collective operations at a time), and avoids some dynamic
// allocation/scheduling.
//
// TODO(allenl): Keep a map from outer thread to list of inner threads rather
// than a single list of threads so aliased nested parallel devices don't
// re-use a thread.
std::vector<std::unique_ptr<DeviceThread>> device_threads_;
};
// Contains a tuple of tensors, one on each of the `underlying_devices_` of the
// ParallelDevice.
class ParallelTensor {
public:
// Construct a ParallelTensor from TensorHandles placed on the component
// devices of a ParallelDevice.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status);
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
// A generalization of the shapes of the underlying tensors.
const std::vector<int64_t>& shape() const { return shape_; }
TF_DataType dtype() const { return dtype_; }
private:
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors,
std::vector<int64_t> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(std::move(shape)),
dtype_(dtype) {}
const ParallelDevice& device_;
const std::vector<TensorHandlePtr> tensors_;
const std::vector<int64_t> shape_;
const TF_DataType dtype_;
};
} // namespace parallel_device
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_

View File

@ -0,0 +1,147 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <array>
#include <string>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/test.h"
tensorflow::ServerDef GetServerDef(const std::string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost", ":", port)});
}
return server_def;
}
TEST(PARALLEL_DEVICE, TestRemoteBasic) {
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
tensorflow::ServerDef server_def = GetServerDef("worker", 3);
// This server def has the task index set to 0.
std::string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TFE_ContextSetServerDef(context.get(), 0, serialized.data(),
serialized.size(), status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
BasicTestsForTwoDevices(context.get(),
"/job:worker/replica:0/task:1/device:CPU:0",
"/job:worker/replica:0/task:2/device:CPU:0");
worker_server1.release();
worker_server2.release();
}
TEST(PARALLEL_DEVICE, TestAsyncCopyOff) {
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
tensorflow::ServerDef server_def = GetServerDef("worker", 3);
// This server def has the task index set to 0.
std::string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TFE_ContextSetServerDef(context.get(), 0, serialized.data(),
serialized.size(), status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
const char* first_device = "/job:worker/replica:0/task:1/device:CPU:0";
const char* second_device = "/job:worker/replica:0/task:2/device:CPU:0";
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{first_device, second_device};
RegisterParallelDevice(context.get(), device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 2> in_components{value_one.get(),
value_two.get()};
TensorHandlePtr combined_value = CreatePerDeviceValues(
context.get(), in_components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Loop to make synchronization failures more deterministic
for (int i = 0; i < 100; ++i) {
TensorHandlePtr multiply_result(
Multiply(context.get(), combined_value.get(), combined_value.get(),
status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> out_components;
ExtractPerDeviceValues(context.get(), multiply_result.get(),
&out_components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(out_components[0].get(), 9.);
ExpectScalarEq<float>(out_components[1].get(), 4.);
}
worker_server1.release();
worker_server2.release();
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include "tensorflow/core/platform/test.h"
// NOTE(allenl): These tests currently go through TFE_Execute and so are
@ -28,390 +29,6 @@ limitations under the License.
// correspond fairly well to the implementation, but testing the C++ directly is
// another option.
// Functor for making unique_ptr to TFE_TensorHandle slightly more
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
// template argument requires passing a function pointer to
// TFE_DeleteTensorHandle when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
// A helper for performing common operations on variables. A much more
// restricted stand-in for tf.Variable in Python.
class Variable {
public:
// Construct a Variable from a resource-dtype TFE_TensorHandle and an
// indication of the dtype of the variable's value.
//
// Note that creating this resource-dtype handle can fail, so `Create` is a
// separate static method which returns a status.
Variable(TFE_TensorHandle* handle, TF_DataType type)
: handle_(handle), type_(type) {}
// Helper for constructing a resource handle and wrapping it in a `Variable`
// object.
static Variable* Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status);
// Dereferences the backing buffer for the variable. Note that since this can
// fail (it runs operations), it must be called explicitly and the resulting
// `status` checked.
void Destroy(TFE_Context* context, TF_Status* status);
// Reads from the variable.
TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
// Assigns a new value to the variable.
void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
// Adds `value` to the existing value of the variable.
void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status);
private:
// Helper for running any single-argument assignment ops (Assign, AssignAdd,
// AssignSub, ...).
void GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status);
// The a handle for the resource-dtype tensor pointing to the variable's
// buffer.
TFE_TensorHandle* handle_;
// The dtype of the variable's buffer (input dtype for assignments, output
// dtype of read operations).
TF_DataType type_;
};
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type);
TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
TFE_OpSetAttrString(op.get(), "container", "", 0);
// Use the special GUID for no buffer sharing
//
// TODO(allenl): Should we provide a better API for this? AFAIK this is the
// only reasonable way to make variables with no aliasing using the eager C
// API.
std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
no_sharing.length());
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
TFE_Execute(op.get(), &var_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return new Variable(var_handle, type);
}
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
// Free the backing buffer for the variable.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
// Delete the variable handle itself.
TFE_DeleteTensorHandle(handle_);
}
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type_);
int num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(var_value);
}
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrType(op.get(), "dtype", type_);
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), value, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
}
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignAddVariableOp", context, value, status);
}
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignVariableOp", context, value, status);
}
// Passed to `TF_NewTensor` to indicate how an array of floats should be
// deleted.
static void FloatDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<float*>(data);
}
// Creates a TFE_TensorHandle with value `v`.
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
const int num_bytes = sizeof(float);
float* values = new float[1];
values[0] = v;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Creates a rank-one TFE_TensorHandle with value `v`.
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
TF_Status* status) {
const int num_bytes = v.size() * sizeof(float);
float* values = new float[v.size()];
memcpy(values, v.data(), num_bytes);
int64_t dims = v.size();
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
&FloatDeallocator, nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
template <std::size_t num_replicas>
void ExtractPerDeviceValues(
TFE_Context* context, TFE_TensorHandle* input,
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
TFE_OpAddInput(op.get(), input, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(input, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_TensorHandle* result_handles[num_replicas];
int num_retvals = num_replicas;
TFE_Execute(op.get(), result_handles, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
for (int i = 0; i < num_replicas; ++i) {
(*components)[i].reset(result_handles[i]);
}
}
// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
template <std::size_t num_replicas>
TensorHandlePtr CreatePerDeviceValues(
TFE_Context* context,
const std::array<TFE_TensorHandle*, num_replicas>& components,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrInt(op.get(), "N", num_replicas);
for (int i = 0; i < num_replicas; ++i) {
TFE_OpAddInput(op.get(), components[i], status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
TFE_TensorHandle* second, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), second, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* first_device = TFE_TensorHandleDeviceName(first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), first_device, status);
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
// Assert that `handle` is equal to `expected_value`.
template <typename value_type>
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
EXPECT_EQ(expected_value,
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
}
template <std::size_t num_devices>
void RegisterParallelDevice(
TFE_Context* context, const char* device_name,
const std::array<const char*, num_devices>& underlying_devices,
TF_Status* status) {
TFE_CustomDevice device;
void* device_info;
tensorflow::eager::AllocateParallelDevice(
device_name, underlying_devices.data(), underlying_devices.size(),
&device, &device_info);
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
}
// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
const char* second_device) {
// Register the custom device
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{first_device, second_device};
RegisterParallelDevice(context, device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle (uninitialized to start) placed on the parallel
// device.
std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
to_delete->Destroy(context, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
delete to_delete;
};
std::unique_ptr<Variable, decltype(variable_deleter)> variable(
Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
status.get()),
variable_deleter);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Assign an initial value to the variable, implicitly mirroring it to each
// component device.
{
TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
variable->Assign(context, initial_value.get(), status.get());
}
// Read from the variable and verify that we have a parallel tensor.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(components[0].get(), 20.);
ExpectScalarEq<float>(components[1].get(), 20.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Add a parallel tensor with different values on each device to the variable.
{
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
value_two.get()};
TensorHandlePtr combined_value =
CreatePerDeviceValues(context, components, device_name, status.get());
variable->AssignAdd(context, combined_value.get(), status.get());
}
// Read the variable and verify that each component has the right modified
// value.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(components[0].get(), 23.);
ExpectScalarEq<float>(components[1].get(), 18.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Compute the device ID twice and verify the result
for (int i = 0; i < 2; ++i) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetDevice(op.get(), device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, result_handle, &components, status.get());
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int64_t>(components[0].get(), 0);
ExpectScalarEq<int64_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
}
TEST(PARALLEL_DEVICE, TestBasicCPU) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -790,7 +407,7 @@ TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
return TensorHandlePtr(result_handle);
}
TEST(PARALLEL_DEVICE, TestCollective) {
void TestCollective(bool async) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
@ -806,6 +423,9 @@ TEST(PARALLEL_DEVICE, TestCollective) {
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Executor, decltype(&TFE_DeleteExecutor)> executor(
TFE_NewExecutor(async), TFE_DeleteExecutor);
TFE_ContextSetExecutorForThread(context.get(), executor.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{
@ -835,8 +455,16 @@ TEST(PARALLEL_DEVICE, TestCollective) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(result_components[0].get(), 3.);
ExpectScalarEq<float>(result_components[1].get(), 3.);
// Destroying the context's default executor first isn't safe.
context.reset();
}
TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }
// Note that ops on the parallel device currently don't execute
// asynchronously. The test is just that we don't get deadlocks.
TEST(PARALLEL_DEVICE, TestCollectiveAsync) { TestCollective(/*async=*/true); }
void RegisterCollectiveMulFunction(TFE_Context* context,
const char* function_name, int group_size,
TF_Status* status) {

View File

@ -0,0 +1,308 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include <array>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/test.h"
// NOTE(allenl): These tests currently go through TFE_Execute and so are
// integration testing rather than purely testing the parallel device. They
// correspond fairly well to the implementation, but testing the C++ directly is
// another option.
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type);
TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
TFE_OpSetAttrString(op.get(), "container", "", 0);
// Use the special GUID for no buffer sharing
//
// TODO(allenl): Should we provide a better API for this? AFAIK this is the
// only reasonable way to make variables with no aliasing using the eager C
// API.
std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
no_sharing.length());
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
TFE_Execute(op.get(), &var_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return new Variable(var_handle, type);
}
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
// Free the backing buffer for the variable.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
// Delete the variable handle itself.
TFE_DeleteTensorHandle(handle_);
}
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type_);
int num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(var_value);
}
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrType(op.get(), "dtype", type_);
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), value, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
}
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignAddVariableOp", context, value, status);
}
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignVariableOp", context, value, status);
}
// Passed to `TF_NewTensor` to indicate how an array of floats should be
// deleted.
static void FloatDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<float*>(data);
}
// Creates a TFE_TensorHandle with value `v`.
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
const int num_bytes = sizeof(float);
float* values = new float[1];
values[0] = v;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Creates a rank-one TFE_TensorHandle with value `v`.
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
TF_Status* status) {
const int num_bytes = v.size() * sizeof(float);
float* values = new float[v.size()];
memcpy(values, v.data(), num_bytes);
int64_t dims = v.size();
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
&FloatDeallocator, nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
template <std::size_t num_replicas>
void ExtractPerDeviceValues(
TFE_Context* context, TFE_TensorHandle* input,
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
TFE_OpAddInput(op.get(), input, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(input, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_TensorHandle* result_handles[num_replicas];
int num_retvals = num_replicas;
TFE_Execute(op.get(), result_handles, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
for (int i = 0; i < num_replicas; ++i) {
(*components)[i].reset(result_handles[i]);
}
}
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
TFE_TensorHandle* second, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), second, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* first_device = TFE_TensorHandleDeviceName(first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), first_device, status);
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
const char* second_device) {
// Register the custom device
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{first_device, second_device};
RegisterParallelDevice(context, device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle (uninitialized to start) placed on the parallel
// device.
std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
to_delete->Destroy(context, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
delete to_delete;
};
std::unique_ptr<Variable, decltype(variable_deleter)> variable(
Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
status.get()),
variable_deleter);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Assign an initial value to the variable, implicitly mirroring it to each
// component device.
{
TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
variable->Assign(context, initial_value.get(), status.get());
}
// Read from the variable and verify that we have a parallel tensor.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(components[0].get(), 20.);
ExpectScalarEq<float>(components[1].get(), 20.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Add a parallel tensor with different values on each device to the variable.
{
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
value_two.get()};
TensorHandlePtr combined_value =
CreatePerDeviceValues(context, components, device_name, status.get());
variable->AssignAdd(context, combined_value.get(), status.get());
}
// Read the variable and verify that each component has the right modified
// value.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(components[0].get(), 23.);
ExpectScalarEq<float>(components[1].get(), 18.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Compute the device ID twice and verify the result
for (int i = 0; i < 2; ++i) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetDevice(op.get(), device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, result_handle, &components, status.get());
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int64_t>(components[0].get(), 0);
ExpectScalarEq<int64_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
}

View File

@ -0,0 +1,174 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include <array>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/test.h"
// Functor for making unique_ptr to TFE_TensorHandle slightly more
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
// template argument requires passing a function pointer to
// TFE_DeleteTensorHandle when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
// A helper for performing common operations on variables. A much more
// restricted stand-in for tf.Variable in Python.
class Variable {
public:
// Construct a Variable from a resource-dtype TFE_TensorHandle and an
// indication of the dtype of the variable's value.
//
// Note that creating this resource-dtype handle can fail, so `Create` is a
// separate static method which returns a status.
Variable(TFE_TensorHandle* handle, TF_DataType type)
: handle_(handle), type_(type) {}
// Helper for constructing a resource handle and wrapping it in a `Variable`
// object.
static Variable* Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status);
// Dereferences the backing buffer for the variable. Note that since this can
// fail (it runs operations), it must be called explicitly and the resulting
// `status` checked.
void Destroy(TFE_Context* context, TF_Status* status);
// Reads from the variable.
TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
// Assigns a new value to the variable.
void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
// Adds `value` to the existing value of the variable.
void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status);
private:
// Helper for running any single-argument assignment ops (Assign, AssignAdd,
// AssignSub, ...).
void GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status);
// The a handle for the resource-dtype tensor pointing to the variable's
// buffer.
TFE_TensorHandle* handle_;
// The dtype of the variable's buffer (input dtype for assignments, output
// dtype of read operations).
TF_DataType type_;
};
// Creates a TFE_TensorHandle with value `v`.
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status);
// Creates a rank-one TFE_TensorHandle with value `v`.
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
TF_Status* status);
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
template <std::size_t num_replicas>
void ExtractPerDeviceValues(
TFE_Context* context, TFE_TensorHandle* input,
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status);
// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
template <std::size_t num_replicas>
TensorHandlePtr CreatePerDeviceValues(
TFE_Context* context,
const std::array<TFE_TensorHandle*, num_replicas>& components,
const char* device, TF_Status* status);
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
TFE_TensorHandle* second, TF_Status* status);
// Assert that `handle` is equal to `expected_value`.
template <typename value_type>
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value);
template <std::size_t num_devices>
void RegisterParallelDevice(
TFE_Context* context, const char* device_name,
const std::array<const char*, num_devices>& underlying_devices,
TF_Status* status);
// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
const char* second_device);
// Implementations of templated functions ******************************
template <std::size_t num_replicas>
TensorHandlePtr CreatePerDeviceValues(
TFE_Context* context,
const std::array<TFE_TensorHandle*, num_replicas>& components,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrInt(op.get(), "N", num_replicas);
for (int i = 0; i < num_replicas; ++i) {
TFE_OpAddInput(op.get(), components[i], status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
template <typename value_type>
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
EXPECT_EQ(expected_value,
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
}
template <std::size_t num_devices>
void RegisterParallelDevice(
TFE_Context* context, const char* device_name,
const std::array<const char*, num_devices>& underlying_devices,
TF_Status* status) {
TFE_CustomDevice device;
void* device_info;
tensorflow::parallel_device::AllocateParallelDevice(
device_name, underlying_devices.data(), underlying_devices.size(),
&device, &device_info);
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
}
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_

View File

@ -0,0 +1,31 @@
# Experimental gcs filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
package(
licenses = ["notice"], # Apache 2.0
)
# Filesystem implementation for GCS environments
tf_cc_shared_object(
name = "gcs_filesystem",
framework_so = [],
linkstatic = False,
per_os_targets = 1,
visibility = ["//visibility:public"],
deps = [":gcs_filesystem_impl"],
)
# The real implementation of the filesystem.
cc_library(
name = "gcs_filesystem_impl",
srcs = ["gcs_filesystem.cc"],
copts = select({
"//conditions:default": [],
"//tensorflow:windows": get_win_copts(),
}),
deps = [
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
],
)

View File

@ -0,0 +1,101 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdlib.h>
#include <string.h>
#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for GCS environments.
// This filesystem will support `gs://` URI schemes.
namespace gcs = google::cloud::storage;
// We can cast `google::cloud::StatusCode` to `TF_Code` because they have the
// same integer values. See
// https://github.com/googleapis/google-cloud-cpp/blob/6c09cbfa0160bc046e5509b4dd2ab4b872648b4a/google/cloud/status.h#L32-L52
static inline void TF_SetStatusFromGCSStatus(
const google::cloud::Status& gcs_status, TF_Status* status) {
TF_SetStatus(status, static_cast<TF_Code>(gcs_status.code()),
gcs_status.message().c_str());
}
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
// TODO(vnvo2409): Implement later
} // namespace tf_random_access_file
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
// TODO(vnvo2409): Implement later
} // namespace tf_writable_file
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {
// TODO(vnvo2409): Implement later
} // namespace tf_read_only_memory_region
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_gcs_filesystem {
// TODO(vnvo2409): Add lazy-loading and customizing parameters.
static void Init(TF_Filesystem* filesystem, TF_Status* status) {
google::cloud::StatusOr<gcs::Client> client =
gcs::Client::CreateDefaultClient();
if (!client) {
TF_SetStatusFromGCSStatus(client.status(), status);
return;
}
filesystem->plugin_filesystem = plugin_memory_allocate(sizeof(gcs::Client));
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
(*gcs_client) = client.value();
TF_SetStatus(status, TF_OK, "");
}
// TODO(vnvo2409): Implement later
} // namespace tf_gcs_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
info->plugin_memory_allocate = plugin_memory_allocate;
info->plugin_memory_free = plugin_memory_free;
info->num_schemes = 1;
info->ops = static_cast<TF_FilesystemPluginOps*>(
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
ProvideFilesystemSupportFor(&info->ops[0], "gs");
}

View File

@ -57,6 +57,7 @@ cc_library(
":concrete_function",
":saved_model_api",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
@ -51,7 +52,7 @@ std::vector<ConcreteFunction*> TFSavedModelAPIImpl::ListFunctions() {
Status TFSavedModelAPIImpl::Load(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
TFSavedModelAPIImpl* out) {
EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out) {
// TODO(bmzhao): Add support for loading a TFSavedModelImpl.
return errors::Unimplemented(
"TFSavedModelAPIImpl loading is unimplemented currently");

View File

@ -23,14 +23,13 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
class TFSavedModelAPIImpl : public SavedModelAPI {
public:
TFSavedModelAPIImpl() = default;
Status GetFunction(const std::string& function_path,
ConcreteFunction** function) override;
@ -40,13 +39,14 @@ class TFSavedModelAPIImpl : public SavedModelAPI {
static Status Load(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
TFSavedModelAPIImpl* out);
EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out);
std::vector<ConcreteFunction*> ListFunctions() override;
~TFSavedModelAPIImpl() override = default;
private:
TFSavedModelAPIImpl() = default;
std::vector<ConcreteFunction> functions_;
};

View File

@ -144,7 +144,9 @@ cc_library(
"//tensorflow/c:tf_status_internal",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_impl",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -22,11 +22,15 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
extern "C" {
@ -34,10 +38,21 @@ extern "C" {
TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
TF_Status* status) {
std::string saved_model_dir(dirname);
std::unique_ptr<tensorflow::SavedModelAPI> result;
if (tensorflow::unwrap(ctx)->UsesTFRT()) {
status->status = tensorflow::errors::Unimplemented(
"TFRT SavedModel implementation will be added in the future");
} else {
std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
status->status = tensorflow::TFSavedModelAPIImpl::Load(
dirname, absl::nullopt,
tensorflow::down_cast<tensorflow::EagerContext*>(
tensorflow::unwrap(ctx)),
&saved_model);
result = std::move(saved_model);
}
std::unique_ptr<tensorflow::SavedModelAPI> result =
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
&status->status);
if (!status->status.ok()) {
return nullptr;
}
@ -54,9 +69,20 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
tagset.insert(std::string(tags[i]));
}
std::unique_ptr<tensorflow::SavedModelAPI> result =
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset),
&status->status);
std::unique_ptr<tensorflow::SavedModelAPI> result;
if (tensorflow::unwrap(ctx)->UsesTFRT()) {
status->status = tensorflow::errors::Unimplemented(
"TFRT SavedModel implementation will be added in the future");
} else {
std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
status->status = tensorflow::TFSavedModelAPIImpl::Load(
dirname, tagset,
tensorflow::down_cast<tensorflow::EagerContext*>(
tensorflow::unwrap(ctx)),
&saved_model);
result = std::move(saved_model);
}
if (!status->status.ok()) {
return nullptr;
}

View File

@ -106,6 +106,7 @@ cc_library(
hdrs = ["loader.h"],
deps = [
":constants",
":loader_util",
":reader",
] + if_not_mobile([
"//tensorflow/core:core_cpu",
@ -132,6 +133,17 @@ cc_library(
],
)
cc_library(
name = "loader_util",
srcs = ["loader_util.cc"],
hdrs = ["loader_util.h"],
deps = [":constants"] + if_not_mobile([
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
]),
)
tf_cc_test(
name = "bundle_v2_test",
srcs = ["bundle_v2_test.cc"],

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
@ -191,41 +191,6 @@ Status RunInitOp(const RunOptions& run_options, const string& export_dir,
return Status::OK();
}
// A SavedModel may store the name of the initialization op to run in the
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name) {
const auto& sig_def_map = meta_graph_def.signature_def();
const auto& init_op_sig_it =
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
if (init_op_sig_it != sig_def_map.end()) {
*init_op_name = init_op_sig_it->second.outputs()
.find(kSavedModelInitOpSignatureKey)
->second.name();
return Status::OK();
}
const auto& collection_def_map = meta_graph_def.collection_def();
string init_op_collection_key;
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
init_op_collection_key = kSavedModelMainOpKey;
} else {
init_op_collection_key = kSavedModelLegacyInitOpKey;
}
const auto init_op_it = collection_def_map.find(init_op_collection_key);
if (init_op_it != collection_def_map.end()) {
if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
*init_op_name = init_op_it->second.node_list().value(0);
}
return Status::OK();
}
Status RunRestore(const RunOptions& run_options, const string& export_dir,
const StringPiece restore_op_name,
const StringPiece variable_filename_const_op_name,
@ -263,32 +228,6 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
nullptr /* outputs */, &run_metadata, session);
}
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs) {
// With SavedModel v2, we write asset file def into metagraph instead of
// collection, so read from metagraph first.
if (meta_graph_def.asset_file_def_size() > 0) {
for (const auto& asset : meta_graph_def.asset_file_def()) {
asset_file_defs->push_back(asset);
}
return Status::OK();
}
// Fall back to read from collection to be backward compatible with v1.
const auto& collection_def_map = meta_graph_def.collection_def();
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
if (assets_it == collection_def_map.end()) {
return Status::OK();
}
const auto& any_assets = assets_it->second.any_list().value();
for (const auto& any_asset : any_assets) {
AssetFileDef asset_file_def;
TF_RETURN_IF_ERROR(
ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
asset_file_defs->push_back(asset_file_def);
}
return Status::OK();
}
Status ReadSavedModelDebugInfoIfPresent(
const string& export_dir,
std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
@ -322,7 +261,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
std::vector<AssetFileDef> asset_file_defs;
TF_RETURN_IF_ERROR(
GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
internal::GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
TF_RETURN_IF_ERROR(
RunRestore(run_options, export_dir,
bundle->meta_graph_def.saver_def().restore_op_name(),
@ -336,7 +275,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
string init_op_name;
TF_RETURN_IF_ERROR(
GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
internal::GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
asset_file_defs, bundle->session.get(),
init_op_name));

View File

@ -0,0 +1,90 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/loader_util.h"
#include <vector>
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf_internal.h"
namespace tensorflow {
namespace internal {
// A SavedModel may store the name of the initialization op to run in the
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name) {
const auto& sig_def_map = meta_graph_def.signature_def();
const auto& init_op_sig_it =
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
if (init_op_sig_it != sig_def_map.end()) {
*init_op_name = init_op_sig_it->second.outputs()
.find(kSavedModelInitOpSignatureKey)
->second.name();
return Status::OK();
}
const auto& collection_def_map = meta_graph_def.collection_def();
string init_op_collection_key;
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
init_op_collection_key = kSavedModelMainOpKey;
} else {
init_op_collection_key = kSavedModelLegacyInitOpKey;
}
const auto init_op_it = collection_def_map.find(init_op_collection_key);
if (init_op_it != collection_def_map.end()) {
if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
*init_op_name = init_op_it->second.node_list().value(0);
}
return Status::OK();
}
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs) {
// With SavedModel v2, we write asset file def into metagraph instead of
// collection, so read from metagraph first.
if (meta_graph_def.asset_file_def_size() > 0) {
for (const auto& asset : meta_graph_def.asset_file_def()) {
asset_file_defs->push_back(asset);
}
return Status::OK();
}
// Fall back to read from collection to be backward compatible with v1.
const auto& collection_def_map = meta_graph_def.collection_def();
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
if (assets_it == collection_def_map.end()) {
return Status::OK();
}
const auto& any_assets = assets_it->second.any_list().value();
for (const auto& any_asset : any_assets) {
AssetFileDef asset_file_def;
TF_RETURN_IF_ERROR(
ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
asset_file_defs->push_back(asset_file_def);
}
return Status::OK();
}
} // namespace internal
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
#define TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
#include <string>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
namespace tensorflow {
namespace internal {
// A SavedModel may store the name of the initialization op to run in the
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name);
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs);
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_

View File

@ -67,13 +67,13 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
"@llvm-project//llvm:ARMCodeGen", # fixdeps: keep
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
"@llvm-project//llvm:Target",
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
"//tensorflow/core:regexp_internal",
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
"@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep
]),
)
@ -94,8 +94,8 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core/platform:resource_loader",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support", # fixdeps: keep
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
"@llvm-project//llvm:Support", # fixdeps: keep
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
],
)
@ -109,12 +109,12 @@ cc_library(
name = "llvm_targets",
visibility = ["//tensorflow/python:__pkg__"],
deps = [
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
"@llvm-project//llvm:ARMCodeGen", # fixdeps: keep
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
"@llvm-project//llvm:Target",
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
"@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep
]),
)
@ -286,9 +286,9 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:core",
"@llvm-project//llvm:support",
"@llvm-project//llvm:target",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
],
)

View File

@ -1,5 +1,5 @@
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s -dump-input-on-failure
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s -dump-input-on-failure
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s
# Checks the error message produced by tfcompile with mlir_component
# Checks that source debug information is used in the output error message and

View File

@ -4,6 +4,7 @@ traces: {
value: {
file_line_cols: {
line: 1
col: 1
}
}
}
@ -12,9 +13,11 @@ traces: {
value: {
file_line_cols: {
line: 3
col: 1
}
file_line_cols: {
line: 4
col: 1
}
}
}
@ -23,6 +26,7 @@ traces: {
value: {
file_line_cols: {
line: 2
col: 1
}
}
}

View File

@ -33,6 +33,7 @@ MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;
MlirCommonFlags* mlir_flags;
std::vector<Flag>* flag_list;
absl::once_flag flags_init;
@ -166,6 +167,9 @@ void AllocateAndParseFlags() {
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
jitter_flags->jitter_amount = 1e-5;
mlir_flags = new MlirCommonFlags;
mlir_flags->tf_mlir_enable_mlir_bridge = false;
auto setter_for_jitter_tensor_names = [](string sequence) {
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
return true;
@ -211,7 +215,11 @@ void AllocateAndParseFlags() {
Flag("tf_introduce_floating_point_jitter_amount",
&jitter_flags->jitter_amount,
"The amount of jitter to introduce. This amount is added to each "
"element in the tensors named in `tensor_names.")});
"element in the tensors named in `tensor_names."),
Flag("tf_mlir_enable_mlir_bridge",
&mlir_flags->tf_mlir_enable_mlir_bridge,
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
@ -250,6 +258,11 @@ GetIntroduceFloatingPointJitterPassFlags() {
return *jitter_flags;
}
MlirCommonFlags* GetMlirCommonFlags() {
absl::call_once(flags_init, &AllocateAndParseFlags);
return mlir_flags;
}
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
absl::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list);

View File

@ -133,6 +133,11 @@ struct IntroduceFloatingPointJitterPassFlags {
std::vector<string> tensor_names;
};
// Flags for common MLIR configurations.
struct MlirCommonFlags {
bool tf_mlir_enable_mlir_bridge;
};
// Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
@ -148,6 +153,8 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags();
MlirCommonFlags* GetMlirCommonFlags();
// Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
//

View File

@ -395,12 +395,11 @@ static void ShowXlaDeviceDeprecationWarning(
if (absl::StrContains(compilation_device_name, "CPU") ||
absl::StrContains(compilation_device_name, "GPU")) {
absl::call_once(once, [] {
LOG(WARNING)
<< "XLA_GPU and XLA_CPU devices are deprecated and will be "
"removed in subsequent releases. Instead, use either "
"@tf.function(experimental_compile=True) for must-compile "
"semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
"for auto-clustering best-effort compilation.";
LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be "
"removed in subsequent releases. Instead, use either "
"@tf.function(experimental_compile=True) for must-compile "
"semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
"for auto-clustering best-effort compilation.";
});
}
}

View File

@ -91,7 +91,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
}
string message = absl::StrCat(
"Function invoked by the following node is not compilable: ",
SummarizeNodeDef(node_def), ".\n");
SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
absl::StrAppend(&message, "Uncompilable nodes:");
for (const auto& node_info : uncompilable_node_info) {
string node_message =

View File

@ -201,9 +201,7 @@ void XlaComputationLaunchContext::PopulateInputs(
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Build ShapedBuffers that point directly to the Tensor buffers.
arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
arg_buffers_.resize(kernel->xla_input_shapes.size());
arg_ptrs_ = std::vector<ShapedBuffer*>(arg_buffers_.size());
arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
// Pass remaining parameters.
const Tensor* t;
@ -239,11 +237,11 @@ void XlaComputationLaunchContext::PopulateInputs(
<< " not the same as on-host shape "
<< xla::ShapeUtil::HumanStringWithLayout(shape);
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
arg_buffers_[i] = absl::make_unique<ShapedBuffer>(
arg_buffers_.emplace_back(
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
client_->platform(), client_->default_device_ordinal());
arg_buffers_[i]->set_buffer(dmem, /*index=*/{});
arg_ptrs_[i] = arg_buffers_[i].get();
arg_buffers_.back().set_buffer(dmem, /*index=*/{});
arg_ptrs_[i] = &arg_buffers_.back();
}
}
}
@ -470,10 +468,6 @@ Status XlaComputationLaunchContext::PopulateOutputs(
<< "Invalid input for outputs " << i << ": " << input_index;
ctx->set_output(i, ctx->input(input_index));
} else {
if (MustAliasOutput(input_output_alias, output_num)) {
DCHECK(output.buffer({output_num}).is_null())
<< "Expected output buffer to be aliased, but it is not nil.";
}
if (allocate_xla_tensors_) {
TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
input_output_alias, output_num, ctx, i, shape, &output,

View File

@ -165,7 +165,7 @@ class XlaComputationLaunchContext {
se::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;
bool use_multiple_streams_;
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
std::deque<xla::ShapedBuffer> arg_buffers_;
std::vector<xla::ShapedBuffer*> arg_ptrs_;
};

View File

@ -27,10 +27,6 @@ namespace tensorflow {
return xla_tensor;
}
/*static*/ bool XlaTensor::RefCountIsOne(const Tensor& tensor) {
return tensor.RefCountIsOne();
}
/*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor(
const Tensor& tensor) {
const XlaTensor* xla_tensor = FromTensor(&tensor);

View File

@ -39,8 +39,6 @@ class XlaTensor {
// fails.
static XlaTensor* FromTensor(const Tensor* tensor);
static bool RefCountIsOne(const Tensor& tensor);
// Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in
// which case the returned value is shaped_buffer()->root_buffer(), or a
// normal Tensor in which case the returned value is
@ -57,7 +55,7 @@ class XlaTensor {
// manage the memory for these tensors a ShapedBuffer may be required.
// Return true if this XlaTensor contains a ShapedBuffer.
bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; }
bool has_shaped_buffer() const { return shaped_buffer_.has_value(); }
// Return the contained ShapedBuffer.
// REQUIRES: has_shaped_buffer()
const xla::ShapedBuffer& shaped_buffer() const {
@ -70,8 +68,7 @@ class XlaTensor {
}
// Mutates the XlaTensor to set the ShapedBuffer.
void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
shaped_buffer_ =
absl::make_unique<xla::ScopedShapedBuffer>(std::move(shaped_buffer));
shaped_buffer_ = std::move(shaped_buffer);
}
// Some tensors on the device may have known values on the host. We use these
@ -79,14 +76,12 @@ class XlaTensor {
// host value already.
// Return true if this XlaTensor contains a host tensor.
bool has_host_tensor() const { return host_tensor_ != nullptr; }
bool has_host_tensor() const { return host_tensor_.has_value(); }
// Return the contained host tensor.
// REQUIRES: has_host_tensor()
const Tensor& host_tensor() const { return *host_tensor_; }
// Sets the contained host tensor.
void set_host_tensor(const Tensor& tensor) {
host_tensor_.reset(new Tensor(tensor));
}
void set_host_tensor(const Tensor& tensor) { host_tensor_.emplace(tensor); }
// Adds synchronization events to 'stream' that wait for this tensor to be
// defined on 'stream'. Does nothing if the tensor is already defined on that
@ -113,9 +108,9 @@ class XlaTensor {
private:
// The optional contained ShapedBuffer.
std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_;
absl::optional<xla::ScopedShapedBuffer> shaped_buffer_;
// An optional host tensor value.
std::unique_ptr<Tensor> host_tensor_;
absl::optional<Tensor> host_tensor_;
// An optional event that is triggered when the tensor's content has been
// defined. If this event is nullptr, it is assumed that the tensor's content
// is always defined.

View File

@ -30,7 +30,7 @@ cc_library(
hdrs = ["op_or_arg_name_mapper.h"],
deps = [
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -42,7 +42,7 @@ cc_library(
":init_mlir",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirOptLib",
@ -86,7 +86,7 @@ cc_library(
hdrs = ["init_mlir.h"],
deps = [
"//tensorflow/core:lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@ -102,8 +102,9 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:core_cpu",
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
],
alwayslink = 1,
@ -154,7 +155,7 @@ tf_cc_binary(
"//tensorflow/core:tensorflow",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",

View File

@ -216,16 +216,16 @@ cc_library(
"ir/tfl_ops.h",
"transforms/passes.h",
"utils/attribute_utils.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
],
deps = [
":tensorflow_lite_ops_inc_gen",
":validators",
"//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect",
@ -253,7 +253,7 @@ cc_library(
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
@ -272,23 +272,7 @@ cc_library(
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
tf_cc_test(
name = "tftext_utils_test",
size = "small",
srcs = ["utils/lstm_utils_test.cc"],
deps = [
":lstm_utils",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
@ -305,7 +289,7 @@ cc_library(
],
deps = [
":tensorflow_lite",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
],
@ -320,7 +304,7 @@ tf_cc_test(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
@ -330,6 +314,7 @@ tf_cc_test(
cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
"transforms/device_index_selector.cc",
"transforms/dilated_conv.cc",
"transforms/generated_legalize_tf.inc",
"transforms/generated_lower_static_tensor_list.inc",
@ -373,7 +358,7 @@ cc_library(
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -399,7 +384,7 @@ cc_library(
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -432,7 +417,7 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -457,7 +442,7 @@ cc_library(
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
@ -510,8 +495,8 @@ tf_native_cc_binary(
"converter_gen.cc",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:tablegen",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//mlir:TableGen",
],
)
@ -557,8 +542,8 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm-project//llvm:analysis",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Analysis",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:TransformUtils",
],
@ -635,7 +620,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
@ -669,7 +654,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
@ -729,7 +714,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirTranslateMain",
"@llvm-project//mlir:QuantOps",
@ -759,7 +744,7 @@ cc_library(
"tf_tfl_translate_cl.h",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
alwayslink = 1,
)
@ -771,7 +756,7 @@ cc_library(
],
deps = [
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@ -796,7 +781,7 @@ tf_cc_binary(
":tf_tfl_translate_cl_options",
":tf_to_tfl_flatbuffer",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
# TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
@ -821,7 +806,7 @@ tf_cc_binary(
":flatbuffer_translate_lib",
":flatbuffer_translate_registeration",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
# TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
@ -890,7 +875,7 @@ cc_library(
"//tensorflow/lite/tools/optimize:quantize_weights",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
@ -910,6 +895,6 @@ cc_library(
"//tensorflow/lite/experimental/mlir:__subpackages__",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)

View File

@ -32,7 +32,6 @@ struct PassConfig {
lower_tensor_list_ops(false),
trim_functions_whitelist({}),
quant_specs(std::move(specs)),
skip_control_dialect(false),
form_clusters(false),
unfold_batch_matmul(true),
legalize_tf_while(true),
@ -49,13 +48,8 @@ struct PassConfig {
llvm::ArrayRef<std::string> trim_functions_whitelist;
// All information about quantization.
QuantizationSpecs quant_specs;
// If `skip_control_dialect` is true, TF executor dialect is not converted to
// TF control dialect prior to legalization to TF Lite.
// TODO(b/142911013): Remove flag once control dialect is removed.
bool skip_control_dialect;
// If `form_clusters` is true (and `skip_control_dialect` is true), clusters
// are formed by grouping consecutive ops of the same device, under a
// `tf_device.launch` op.
// If `form_clusters` is true , clusters are formed by grouping consecutive
// ops of the same device, under a `tf_device.launch` op.
bool form_clusters;
// if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
// of tfl.fully_connected ops.

View File

@ -525,11 +525,16 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
auto *val = trait.getDef().getValue("tflRuntimePredicate");
if (!val) continue;
auto desc = trait.getDef().getValueAsString("tflRuntimeDescription");
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
os << tgfmt(
" if (!($0)) {\n "
" return ::mlir::LogicalResult::Failure;\n }\n",
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx));
" if (failure_on_operand_type_mismatch) {\n"
" return top.emitOpError(\"failed to verify that $1\");\n"
" } else {\n"
" return ::mlir::LogicalResult::Failure;\n }\n }\n",
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx), desc);
}
os << " return top.verify();\n}\n";
}

View File

@ -424,6 +424,10 @@ StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
const std::vector<uint8_t>& buffer,
OpBuilder builder, Location loc) {
if (buffer.empty()) {
return errors::InvalidArgument("Constant's buffer may not be empty");
}
TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
/*shapeless_are_scalars=*/true,
/*is_constant=*/true));
@ -695,8 +699,6 @@ StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
for (int32_t output : output_indices) {
if (auto& op = defining_op[output]) {
queue.push_back(op);
} else {
return errors::InvalidArgument("Output tensor doesn't have defining op");
}
}
@ -801,9 +803,17 @@ StatusOr<FuncOp> ConvertSubgraph(
}
for (auto output : func_outputs) {
bool is_constant = !is_op_output[output];
const bool is_func_input = input_index_set.contains(output);
bool is_constant = !is_op_output[output] && !is_func_input;
// There are 2 cases tensor is scalar when it doesn't have a shape in
// flatbuffer:
// 1. `is_constant` = true, means this tensor is created from a constant op.
// 2. `is_func_input` = true and `is_entry_point` = true, which means this
// tensor is function input and function input type is a scalar tensor.
const bool shapeless_is_scalar =
is_constant || (is_func_input && is_entry_point);
auto type_or_err = GetTensorType(*subgraph.tensors.at(output), builder,
/*shapeless_are_scalars=*/is_constant,
shapeless_is_scalar,
/*is_constant=*/is_constant);
if (!type_or_err.ok()) {
emitError(func_loc, "error reading return types")
@ -858,6 +868,8 @@ StatusOr<FuncOp> ConvertSubgraph(
subgraph, &builder, "outputs", func_outputs));
}
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
} else {
func.setVisibility(FuncOp::Visibility::Private);
}
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;

View File

@ -46,28 +46,183 @@ namespace mlir {
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
namespace TFL {
// Returns true when the given two types have the same shape or broadcastable
// shape within the given rank. If any given shapes are non-static, this method
// returns true.
bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs,
int max_bcast_rank) {
// Ignore shape checking on the non-static shapes for model compatibility.
auto lhs_shaped_type = lhs.dyn_cast<ShapedType>();
if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
// Returns true when the given operand arguments have the same shape or
// broadcastable shape within the given rank. If any given shapes are
// non-static and maximum rank is within the given rank, this method returns
// true.
bool VerifyOperandsHaveSameShapesOrBroadcastableShape(
Operation *op, ArrayRef<unsigned> indices, int max_bcast_rank) {
if (indices.empty()) return true;
if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape()))
return true;
// First, it checks there are any inputs that has unknown rank.
bool has_unknown_shape_input = false;
bool has_same_shape = true;
bool reach_first_known_shape = false;
int64_t max_rank = -1;
ArrayRef<int64_t> pivot_shape;
SmallVector<int64_t, 4> current_shape;
SmallVector<int64_t, 4> result_shape;
if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
rhs_shaped_type.getShape(),
result_shape)) {
return false;
for (unsigned index : indices) {
ShapedType shaped_type =
op->getOperand(index).getType().dyn_cast<ShapedType>();
if (!shaped_type || !shaped_type.hasRank()) {
// Marks that we have an unknown rank input.
has_unknown_shape_input = true;
continue;
}
max_rank = std::max(max_rank, shaped_type.getRank());
if (!shaped_type.hasStaticShape()) {
// Marks that we have an unknown shape input.
has_unknown_shape_input = true;
continue;
}
ArrayRef<int64_t> shape = shaped_type.getShape();
if (!reach_first_known_shape) {
pivot_shape = shape;
current_shape.assign(shape.begin(), shape.end());
reach_first_known_shape = true;
continue;
}
if (!pivot_shape.equals(shape)) {
has_same_shape = false;
}
// Checks if all the inputs are broadcastable since they have not all the
// same shapes.
if (!OpTrait::util::getBroadcastedShape(current_shape, shape,
result_shape)) {
return false;
}
current_shape = result_shape;
}
return lhs_shaped_type.getRank() <= max_bcast_rank &&
rhs_shaped_type.getRank() <= max_bcast_rank;
// It will treat the unknown shape inputs as acceptable inputs for model
// compatibility unless there is an known rank that is bigger than the allowed
// broadcast maximum rank.
if (has_unknown_shape_input) return max_rank <= max_bcast_rank;
// If all the shape is known and same, CPU kernels are able to handle inputs
// regardless of dimension size.
return has_same_shape || max_rank <= max_bcast_rank;
}
// Return true when the given element_type is QI8.
bool IsQI8Type(Type element_type) {
auto quantized_type = element_type.dyn_cast<QuantizedType>();
return quantized_type != nullptr &&
quantized_type.getStorageTypeIntegralWidth() == 8 &&
quantized_type.isSigned();
}
// Return true when the given element_type is QUI8.
bool IsQUI8Type(Type element_type) {
auto quantized_type = element_type.dyn_cast<QuantizedType>();
return quantized_type != nullptr &&
quantized_type.getStorageTypeIntegralWidth() == 8 &&
!quantized_type.isSigned();
}
// Return true when the given element_type is QI16.
bool IsQI16Type(Type element_type) {
auto quantized_type = element_type.dyn_cast<QuantizedType>();
return quantized_type != nullptr &&
quantized_type.getStorageTypeIntegralWidth() == 16 &&
quantized_type.isSigned();
}
// Return true when the given element_type is I32.
bool IsI32Type(Type element_type) {
return element_type.isInteger(32) && !element_type.isUnsignedInteger();
}
// Return true if the given Add operation has the CPU kernel supported shapes.
bool VerifyAddOpShapeConstraints(AddOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
// Allows F32, QI8, and QUI8 outputs when the operands have valid shapes,
// which are broadcastable shapes up to five dimension or have same shapes.
if (element_type.isF32() || IsQI8Type(element_type) ||
IsQUI8Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows I32 output when the operands have valid shapes, which are
// broadcastable shapes up to four dimension or have same shapes.
if (IsI32Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);
}
// Allows QI16 output when operands have the same shape.
if (IsQI16Type(element_type)) {
return succeeded(
mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
}
return false;
}
// Return true if the given Sub operation has the CPU kernel supported shapes.
bool VerifySubOpShapeConstraints(SubOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
// Allows F32, QUI8, and QI16 outputs when the operands have valid shapes,
// which are broadcastable shapes up to five dimension or have same shapes.
if (element_type.isF32() || IsI32Type(element_type) ||
IsQUI8Type(element_type) || IsQI16Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows QI8 output when the operands have valid shapes, which are
// broadcastable shapes up to four dimension or have same shapes.
if (IsQI8Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);
}
return false;
}
// Return true if the given Mul operation has the CPU kernel supported shapes.
bool VerifyMulOpShapeConstraints(MulOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
// Allows QI8 and QUI8 inputs up to five dimension broadcasting unless the
// output type is not QI16. If the output type is Q16, allows onlt the same
// shape operands.
if (IsQI8Type(element_type) || IsQUI8Type(element_type)) {
if (IsQI16Type(getElementTypeOrSelf(op.lhs().getType()))) {
return succeeded(
mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
}
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows F32 output when the operands have valid shapes, which are
// broadcastable shapes up to five dimension or have same shapes.
if (element_type.isF32()) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows I32 and QI16 outputs when the operands have valid shapes, which are
// broadcastable shapes up to four dimension or have same shapes.
if (IsI32Type(element_type) || IsQI16Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);
}
return false;
}
//===----------------------------------------------------------------------===//
@ -1882,7 +2037,7 @@ static LogicalResult Verify(TransposeConvOp op) {
auto expected_output_type =
RankedTensorType::get(output_shape, output_type.getElementType());
if (output_type != expected_output_type) {
if (failed(mlir::verifyCompatibleShape(output_type, expected_output_type))) {
return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
expected_output_type, output_type));
}
@ -1966,9 +2121,9 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
}
static LogicalResult Verify(TransposeOp op) {
auto input_type = op.x().getType().cast<ShapedType>();
auto input_type = op.input().getType().cast<ShapedType>();
auto perm_type = op.perm().getType().cast<ShapedType>();
auto output_type = op.y().getType().cast<ShapedType>();
auto output_type = op.output().getType().cast<ShapedType>();
if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
if (perm_type.getNumElements() != input_type.getRank()) {
return op.emitOpError(
@ -2004,7 +2159,8 @@ static LogicalResult Verify(TransposeOp op) {
}
auto expected_output_type =
RankedTensorType::get(transposed_shape, input_type.getElementType());
if (output_type != expected_output_type) {
if (failed(
mlir::verifyCompatibleShape(output_type, expected_output_type))) {
return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
expected_output_type, output_type));
}

File diff suppressed because it is too large Load Diff

View File

@ -27,7 +27,7 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
@ -56,7 +56,7 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
@ -85,7 +85,7 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",

View File

@ -55,8 +55,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<double> node_mins;
std::vector<double> node_maxs;
std::vector<llvm::Optional<double>> node_mins;
std::vector<llvm::Optional<double>> node_maxs;
// Populate quantization specs.
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(

View File

@ -125,8 +125,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<double> node_mins;
std::vector<double> node_maxs;
std::vector<llvm::Optional<double>> node_mins;
std::vector<llvm::Optional<double>> node_maxs;
// Populate quantization specs.
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(

View File

@ -177,14 +177,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
return RegisterCustomBuiltinOps(extra_tf_opdefs);
}
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs,
std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<double>* node_mins,
std::vector<double>* node_maxs) {
Status PopulateQuantizationSpecs(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<llvm::Optional<double>>* node_mins,
std::vector<llvm::Optional<double>>* node_maxs) {
quant_specs->inference_input_type =
ConvertIODataTypeToDataType(toco_flags.inference_input_type());
tensorflow::DataType inference_type =
@ -211,11 +210,16 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
flag.shape().dims().end()));
// Currently, only UINT8 and INT8 require inputs stats
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
TF_ASSIGN_OR_RETURN(
auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(),
inference_type));
node_mins->push_back(min_max.first);
node_maxs->push_back(min_max.second);
if (flag.has_mean_value() && flag.has_std_value()) {
TF_ASSIGN_OR_RETURN(
auto min_max, InputStatsToMinMax(flag.mean_value(),
flag.std_value(), inference_type));
node_mins->push_back(min_max.first);
node_maxs->push_back(min_max.second);
} else {
node_mins->push_back(llvm::None);
node_maxs->push_back(llvm::None);
}
}
}
@ -254,7 +258,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
std::string error_message;
auto output = mlir::openOutputFile(filename, &error_message);
if (!error_message.empty()) {
return errors::InvalidArgument("Failed to open file in %s.", filename);
return errors::InvalidArgument("Failed to open file in ", filename);
}
mlir::PassManager pm(module.getContext());
pm.addPass(mlir::createPrintOpGraphPass(output->os()));

View File

@ -34,14 +34,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags);
// Populate quantization specs (or not) given user specified ranges for each
// input arrays.
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs,
std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<double>* node_mins,
std::vector<double>* node_maxs);
Status PopulateQuantizationSpecs(
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<llvm::Optional<double>>* node_mins,
std::vector<llvm::Optional<double>>* node_maxs);
// Convert imported MLIR file to TfLite flatbuffer.
// This will also run relevant passes as well.

View File

@ -3,6 +3,10 @@ load(
"//tensorflow/core/platform:build_config.bzl",
"tf_proto_library",
)
load(
"//third_party/mlir:tblgen.bzl",
"gentbl",
)
package(
default_visibility = [
@ -23,6 +27,7 @@ package_group(
exports_files([
"quantization_traits.h",
"quantization_config.h",
"quantization_utils.h",
])
filegroup(
@ -34,6 +39,25 @@ filegroup(
],
)
gentbl(
name = "quantization_interfaces_inc_gen",
tbl_outs = [
(
"-gen-op-interface-decls",
"quantization_interface.h.inc",
),
(
"-gen-op-interface-defs",
"quantization_interface.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "quantization.td",
td_srcs = [
":quantization_td_files",
],
)
tf_proto_library(
name = "quantization_info_proto",
srcs = [
@ -56,7 +80,7 @@ cc_library(
"//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -71,16 +95,18 @@ cc_library(
name = "quantization_lib",
srcs = [
"quantization_driver.cc",
"quantization_interface.cc.inc",
"quantization_utils.cc",
],
hdrs = [
"quantization_interface.h.inc",
"quantization_traits.h",
"quantization_utils.h",
],
deps = [
"//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
@ -99,7 +125,7 @@ cc_library(
deps = [
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@ -109,8 +135,8 @@ tf_native_cc_binary(
"tools/op_quant_spec_getters_gen.cc",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:tablegen",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//mlir:TableGen",
],
)
@ -131,7 +157,7 @@ cc_library(
deps = [
":numerical_utils",
"@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support",
@ -146,7 +172,7 @@ cc_library(
":device_target",
":quantization_lib",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",

View File

@ -36,7 +36,7 @@ cc_library(
"//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
@ -49,14 +49,16 @@ cc_library(
],
hdrs = [
"tfl_to_std.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_utils.h",
],
deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
@ -71,7 +73,7 @@ tf_cc_binary(
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
],
)

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
namespace mlir {
namespace TFL {
@ -47,12 +48,18 @@ void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func) {
auto dcast = b.create<DequantizeOp>(dq.getLoc(), dq.getResult().getType(),
dq.arg());
dq.getResult().replaceAllUsesWith(dcast);
if (auto extra_attr = op->getAttr(mlir::quant::kVolatileOpAttrName)) {
dcast.setAttr(mlir::quant::kVolatileOpAttrName, extra_attr);
}
dq.erase();
} else if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(op)) {
auto out_type = q.getResult().getType();
auto qcast = b.create<QuantizeOp>(q.getLoc(), out_type, q.arg(),
TypeAttr::get(out_type));
q.getResult().replaceAllUsesWith(qcast);
if (auto extra_attr = op->getAttr(mlir::quant::kVolatileOpAttrName)) {
qcast.setAttr(mlir::quant::kVolatileOpAttrName, extra_attr);
}
q.erase();
}
});

View File

@ -63,6 +63,22 @@ def QI32 : QuantizedType<"Uniform", [32], 1>;
// https://www.tensorflow.org/lite/performance/quantization_spec
//===----------------------------------------------------------------------===//
// TODO(b/157870442): replace all FixedResultScale trait
def FixedOutputRangeInterface : OpInterface<
"FixedOutputRangeInterface"> {
let description = [{
Interface for defining the fixed output range.
}];
let methods = [
InterfaceMethod<
[{Returns the fixed output range.}],
"UniformQuantizedType", "GetFixedOutputRange",
(ins "bool":$sign, "int":$bit_width)
>,
];
}
// Specify this trait if the op has a fixed output value range.
class FixedResultScale<QuantizedType qt> : NativeOpTrait<!strconcat(
"quant::FixedResult", qt.name, "Scale<", qt.asTraitArgsStr, ">::Impl")>;

View File

@ -45,7 +45,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
absl::string_view inference_type,
QuantizationSpecs* quant_specs) {
std::vector<std::string> input_nodes = absl::StrSplit(node_names, ',');
std::vector<double> node_mins;
std::vector<llvm::Optional<double>> node_mins;
if (!min_values.empty()) {
std::vector<std::string> node_mins_str = absl::StrSplit(min_values, ',');
for (int i = 0; i < node_mins_str.size(); i++) {
@ -57,7 +57,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
}
}
std::vector<double> node_maxs;
std::vector<llvm::Optional<double>> node_maxs;
if (!max_values.empty()) {
std::vector<std::string> node_maxs_str = absl::StrSplit(max_values, ',');
for (int i = 0; i < node_maxs_str.size(); i++) {
@ -79,11 +79,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
quant_specs);
}
bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names,
const std::vector<double>& node_mins,
const std::vector<double>& node_maxs,
tensorflow::DataType inference_type,
QuantizationSpecs* quant_specs) {
bool GetInputNodeQuantSpecs(
const std::vector<std::string>& node_names,
const std::vector<llvm::Optional<double>>& node_mins,
const std::vector<llvm::Optional<double>>& node_maxs,
tensorflow::DataType inference_type, QuantizationSpecs* quant_specs) {
quant_specs->inference_type = inference_type;
// If min/max are not specified, just return;

View File

@ -69,7 +69,8 @@ struct QuantizationSpecs {
// arguments. They are only used when `weight_quantization` is set to false,
// and the model is required to have quantization parameters, either from
// quantization aware training or calibration, for the remaining tensors.
std::vector<std::pair<double, double>> input_ranges;
std::vector<std::pair<llvm::Optional<double>, llvm::Optional<double>>>
input_ranges;
// The default ranges can be used when a tensor doesn't have quantization
// parameters and couldn't be quantized. Used only for latency tests.
@ -130,11 +131,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
// Gets the quantization specification for input arrays. The array names are not
// stored in the spec, and will be matched by position. The min/max will be
// ignored if the inference_type isn't a quantized type. Returns true if failed.
bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names,
const std::vector<double>& node_mins,
const std::vector<double>& node_maxs,
tensorflow::DataType inference_type,
QuantizationSpecs* quant_specs);
bool GetInputNodeQuantSpecs(
const std::vector<std::string>& node_names,
const std::vector<llvm::Optional<double>>& node_mins,
const std::vector<llvm::Optional<double>>& node_maxs,
tensorflow::DataType inference_type, QuantizationSpecs* quant_specs);
} // namespace TFL
} // namespace mlir

View File

@ -494,6 +494,13 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
auto quantize = builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
auto dequantize = builder_.create<quant::DequantizeCastOp>(
loc, expressed_type, quantize.getResult());
// This attribute is set to distinguish the quantize ops being added by the
// quantization pass. These ops can be removed without losing original
// program accuracy.
// TODO(fengliuai): make the attribute being part of op definition.
quantize.setAttr(kVolatileOpAttrName, builder_.getUnitAttr());
// `original_result` has a use to `quantize`, so this will replace that use
// by the result of `dequantize`. Remember to reset that use afterwards
value.replaceAllUsesWith(dequantize);

View File

@ -21,13 +21,18 @@ limitations under the License.
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
namespace mlir {
namespace OpTrait {
namespace quant {
using QuantizedType = mlir::quant::QuantizedType;
using UniformQuantizedType = mlir::quant::UniformQuantizedType;
namespace mlir {
// This includes the interface class definition. It couldn't be in a namespace
// because the table gen doesn't emit the namespace when it is used.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_interface.h.inc"
namespace OpTrait {
namespace quant {
// The base class that all the quantization related OpTrait implements.
template <typename ConcreteType, template <typename> class TraitType>
struct QuantizationSpecTraitBase : public TraitBase<ConcreteType, TraitType> {

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
@ -436,6 +437,16 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
llvm::SmallVector<quant::StatisticsOp, 16> all_stats_ops;
llvm::DenseSet<Operation*> redundant_stats_ops;
// Step 0: remove the quant::StatisticsOp which are used by the tfl.quantize
// op in case it overrides the information from training FakeQuant ops.
func.walk([&](quant::QuantizeCastOp q) {
auto input_op = q.arg().getDefiningOp();
if (auto stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(input_op)) {
q.setOperand(stats.arg());
if (stats.use_empty()) stats.erase();
}
});
// Step 1: forward pass: propagate any value scales which are not produces
// by `SameOperandsAndResultsScale`. Additionally, remove the value scales
// which are produced by the `restricted_output_params`.

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
@ -42,6 +43,11 @@ limitations under the License.
namespace mlir {
namespace quant {
// A unit attribute can be attached to the quantize/dequantize ops which are
// added by the quantization passes. These ops can be removed erased without
// losing accuracy.
constexpr char kVolatileOpAttrName[] = "volatile";
using QuantParams = quant::QuantizedType;
using SignedInteger = std::pair<unsigned, unsigned>; // bitwidth and sign
using QuantParamsForResults = llvm::SmallVector<QuantParams, 4>;
@ -380,7 +386,8 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
Operation* def = pre_quantized.getDefiningOp();
if (!def) return failure();
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
if (llvm::isa<FixedOutputRangeInterface>(def) ||
def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
return failure();
}

View File

@ -27,7 +27,7 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -quant-import-stats --quant-test-stats='entries { name: "op" params { min_max { min: -1 max: 1 } } } entries { name: "op_0:0" params { min_max { min: -2 max: 2 } } } entries { name_regex: "op_*" params { min_max { min: -3 max: 3 } } }' | FileCheck %s --dump-input-on-failure
// RUN: tf-opt %s -quant-import-stats --quant-test-stats='entries { name: "op" params { min_max { min: -1 max: 1 } } } entries { name: "op_0:0" params { min_max { min: -2 max: 2 } } } entries { name_regex: "op_*" params { min_max { min: -3 max: 3 } } }' | FileCheck %s
// CHECK-LABEL: import_stats_skip

View File

@ -32,7 +32,7 @@ cc_library(
"//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],

View File

@ -1,4 +1,4 @@
// RUN: tf-opt -pass-pipeline='func(canonicalize)' %s | FileCheck %s --dump-input-on-failure
// RUN: tf-opt -pass-pipeline='func(canonicalize)' %s | FileCheck %s
// Checks that tfl.reshape should be removed if its output's only user is
// another tfl.reshape

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -canonicalize | FileCheck %s --dump-input-on-failure
// RUN: tf-opt %s -canonicalize | FileCheck %s
// CHECK-LABEL: @add_float
func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -tfl-identify-dilated-conv | FileCheck %s --dump-input-on-failure
// RUN: tf-opt %s -tfl-identify-dilated-conv | FileCheck %s
func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32>

View File

@ -1,4 +1,4 @@
# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s
# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add %s -o - | flatbuffer_to_string - | FileCheck %s
# Add two tensor<4xi32> inputs and return the result

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,4 @@
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,1,1,256 -tf-input-data-types=DT_FLOAT -tf-inference-type=DT_QINT8 -tf-input-min-values='-33.614346' -tf-input-max-values='21.54917' -tf-output-arrays=output %s -o - --output-mlir 2>&1 | FileCheck --check-prefix=MLIR %s --dump-input-on-failure
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,1,1,256 -tf-input-data-types=DT_FLOAT -tf-inference-type=DT_QINT8 -tf-input-min-values='-33.614346' -tf-input-max-values='21.54917' -tf-output-arrays=output %s -o - --output-mlir 2>&1 | FileCheck --check-prefix=MLIR %s
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,1,1,256 -tf-input-data-types=DT_FLOAT -tf-inference-type=DT_QINT8 -tf-input-min-values='-33.614346' -tf-input-max-values='21.54917' -tf-output-arrays=output %s -o - | flatbuffer_to_string - | FileCheck %s
node {

View File

@ -1,4 +1,4 @@
# RUN: tf_tfl_translate -tf-input-arrays=unranked -tf-input-shapes=1,8,8,2 -tf-input-data-types=DT_INT32 -tf-output-arrays=unranked,static,static_10 %s -o - --output-mlir | FileCheck %s --dump-input-on-failure
# RUN: tf_tfl_translate -tf-input-arrays=unranked -tf-input-shapes=1,8,8,2 -tf-input-data-types=DT_INT32 -tf-output-arrays=unranked,static,static_10 %s -o - --output-mlir | FileCheck %s
node {
name: "tf.Const"

View File

@ -0,0 +1,421 @@
# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=4:4 -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s
node {
name: "tf.Less"
op: "Less"
input: "a"
input: "b"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
node {
name: "my_equal"
op: "Equal"
input: "a"
input: "b"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
node {
name: "cst0"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 4
}
}
float_val: 1.0
float_val: 2.0
float_val: 3.0
float_val: 4.0
}
}
}
}
node {
name: "cst1"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 4
}
}
float_val: 5.0
float_val: 6.0
float_val: 7.0
float_val: 8.0
}
}
}
}
node {
name: "StatefulIf"
op: "If"
input: "tf.Less"
input: "a"
input: "b"
input: "cst0"
input: "cst1"
attr {
key: "Tcond"
value {
type: DT_BOOL
}
}
attr {
key: "Tin"
value {
list {
type: DT_FLOAT
type: DT_FLOAT
type: DT_FLOAT
type: DT_FLOAT
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_FLOAT
}
}
}
attr {
key: "else_branch"
value {
func {
name: "cond_false"
}
}
}
attr {
key: "then_branch"
value {
func {
name: "cond_true"
}
}
}
experimental_debug_info {
}
}
node {
name: "StatelessIf"
op: "StatelessIf"
input: "my_equal"
input: "a"
input: "b"
attr {
key: "Tcond"
value {
type: DT_BOOL
}
}
attr {
key: "Tin"
value {
list {
type: DT_FLOAT
type: DT_FLOAT
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_FLOAT
}
}
}
attr {
key: "else_branch"
value {
func {
name: "cond_false_1"
}
}
}
attr {
key: "then_branch"
value {
func {
name: "cond_true_1"
}
}
}
experimental_debug_info {
}
}
node {
name: "main"
op: "_Retval"
input: "StatefulIf"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "index"
value {
i: 0
}
}
}
node {
name: "main1"
op: "_Retval"
input: "StatelessIf"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "index"
value {
i: 1
}
}
}
node {
name: "a"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
node {
name: "b"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
library {
function {
signature {
name: "cond_true"
input_arg {
name: "cond_true_arg0"
type: DT_FLOAT
}
input_arg {
name: "cond_true_arg1"
type: DT_FLOAT
}
input_arg {
name: "cond_true_arg2"
type: DT_FLOAT
}
input_arg {
name: "cond_true_arg3"
type: DT_FLOAT
}
output_arg {
name: "cond_true_ret"
type: DT_FLOAT
}
}
node_def {
name: "tf.Add"
op: "Add"
input: "cond_true_arg2"
input: "cond_true_arg3"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "tf.Add"
}
}
ret {
key: "cond_true_ret"
value: "tf.Add:z:0"
}
}
function {
signature {
name: "cond_false"
input_arg {
name: "cond_false_arg0"
type: DT_FLOAT
}
input_arg {
name: "cond_false_arg1"
type: DT_FLOAT
}
input_arg {
name: "cond_false_arg2"
type: DT_FLOAT
}
input_arg {
name: "cond_false_arg3"
type: DT_FLOAT
}
output_arg {
name: "cond_false_ret"
type: DT_FLOAT
}
}
node_def {
name: "tf.Mul"
op: "Mul"
input: "cond_false_arg0"
input: "cond_false_arg3"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "tf.Mul"
}
}
ret {
key: "cond_false_ret"
value: "tf.Mul:z:0"
}
}
function {
signature {
name: "cond_true_1"
input_arg {
name: "cond_true_arg0"
type: DT_FLOAT
}
input_arg {
name: "cond_true_arg1"
type: DT_FLOAT
}
output_arg {
name: "cond_true_ret"
type: DT_FLOAT
}
}
node_def {
name: "tf.Sub"
op: "Sub"
input: "cond_true_arg0"
input: "cond_true_arg1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "tf.Sub"
}
}
ret {
key: "cond_true_ret"
value: "tf.Sub:z:0"
}
}
function {
signature {
name: "cond_false_1"
input_arg {
name: "cond_false_arg0"
type: DT_FLOAT
}
input_arg {
name: "cond_false_arg1"
type: DT_FLOAT
}
output_arg {
name: "cond_false_ret"
type: DT_FLOAT
}
}
node_def {
name: "tf.Div"
op: "Div"
input: "cond_false_arg0"
input: "cond_false_arg1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "tf.Div"
}
}
ret {
key: "cond_false_ret"
value: "tf.Div:z:0"
}
}
}
versions {
producer: 115
min_consumer: 12
}
# CHECK: func @StatefulIf_else
# CHECK-NEXT: constant dense<[5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]>
# CHECK-NEXT: tfl.mul
# CHECK: func @StatefulIf_then
# CHECK-NEXT: constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
# CHECK-NEXT: return
# CHECK: func @StatelessIf_else
# CHECK-NEXT: tfl.div
# CHECK: func @StatelessIf_then
# CHECK-NEXT: tfl.sub
# CHECK: "tf.If"{{.+}}else_branch = @StatelessIf_else{{.+}}then_branch = @StatelessIf_then
# CHECK: "tf.If"{{.+}}else_branch = @StatefulIf_else{{.+}}then_branch = @StatefulIf_then

View File

@ -54,7 +54,7 @@ tf_native_cc_binary(
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@ -70,6 +70,6 @@ tf_native_cc_binary(
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)

View File

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Ensure basic_lstm roundtrip exactly
func @main(%arg0: tensor<1x384xf32>, %arg1: tensor<1x96xf32>, %arg2: tensor<384x480xf32>, %arg3: tensor<384xf32>, %arg4: tensor<1x96xf32>) -> tensor<1x96xf32> {

View File

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Ensure constants roundtrip exactly
func @bool() -> tensor<4xi1> {

View File

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> {
%0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "Convolution2DTransposeBias", custom_option = opaque<"tfl", "0x010000000200000002000000"> : tensor<12xi8>} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>

View File

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// CHECK: func @main(%arg0: tensor<?x19x19x3xf32>) -> tensor<?x9x9x4xf32>
func @main(%arg0: tensor<?x19x19x3xf32>) -> tensor<?x9x9x4xf32> {

View File

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir --use-external-constant - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir --use-external-constant - -o - | FileCheck %s
// Ensure that `tfl.external_const` is imported when the flag `-use-external-constant` is enabled.
func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {

View File

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Confirm function references in if ops are preserved
func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// CHECK: %{{.*}} = "tf.If"(%{{.*}}, %{{.*}}, %{{.*}}) {else_branch = @cond_false, is_stateless = false, then_branch = @cond_true} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>

View File

@ -1,4 +1,4 @@
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s
// CHECK: %cst = constant unit
// CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 0 : i32, stride_w = 0 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32>

View File

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// This only test the exporter and importer are working without min/max quantization parameters.
func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {

View File

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -input-arrays=squared_difference --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -input-arrays=squared_difference --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Tests -input-arrays flag.
func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {

View File

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Tests input and output names from FlatBuffer are added to `tf.entry_function` attribute.

View File

@ -1,15 +1,15 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Ensure lstm roundtrip exactly
func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
%cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %24 : tensor<1x4xf32>
// CHECK-LABEL: main
// seperate lines since there is no region for this op. third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( {
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
// CHECK: return %[[RES0]]
}

View File

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Confirm a wide array of attribute survives the round-trip
func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> {

View File

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Confirm float constants and operators survive a roundtrip
func @main(tensor<4xf32>) -> tensor<4xf32> {

View File

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Test to make sure optional parameters survive a roundtrip
func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {

View File

@ -1,4 +1,4 @@
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s
// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s
// This test is to test that if the flatbuffer omits the last optional input `bias` of tfl.conv_2d op, the flatbuffer_importer will automatically adds `none` value to tfl.conv_2d.

Some files were not shown because too many files have changed in this diff Show More