Merge branch 'master' into 16x8_pad

Elena Zhelezina 2020-06-16 13:39:45 +01:00 committed by GitHub
commit 7452075897
3250 changed files with 170226 additions and 57732 deletions

View File

@@ -200,6 +200,8 @@ build:nogcp --define=no_gcp_support=true
 build:nohdfs --define=no_hdfs_support=true
 build:nonccl --define=no_nccl_support=true
+
+build:stackdriver_support --define=stackdriver_support=true
 build --define=use_fast_cpp_protos=true
 build --define=allow_oversize_protos=true
@@ -386,32 +388,32 @@ build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1
 test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
 build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
-build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
-build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
-build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
-build:rbe_linux_cuda_nvcc --host_platform="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
-build:rbe_linux_cuda_nvcc --platforms="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
-build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
-build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
-build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
+build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
+build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
+build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --host_platform="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
+build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
+build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
 build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true
 test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base
 build:rbe_linux_cuda_nvcc_base --config=rbe_linux_cuda_base
-build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
-build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
-build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
-build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
-build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
-build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
-build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
-build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
+build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
+build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64"
+build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform"
+build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda"
+build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt"
+build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl"
 build:rbe_linux_cuda_nvcc_base --define=using_cuda_nvcc=true
-build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
-build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
-build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
-build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
-build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
+build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7"
+build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5"
+build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6"
+build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7"
+build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8"
 build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base
 build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain"
@@ -441,8 +443,8 @@ build:rbe_linux_py3 --python_path="/usr/bin/python3"
 build:rbe_linux_py3 --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-manylinux2010-py3_config_python"
 build:rbe_win --config=rbe
-build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:toolchain"
-build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/bazel_211:cc-toolchain-x64_windows"
+build:rbe_win --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/win/tf_win_08062020:toolchain"
+build:rbe_win --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/win/tf_win_08062020:cc-toolchain-x64_windows"
 build:rbe_win --host_javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
 build:rbe_win --javabase="@org_tensorflow//third_party/toolchains/preconfig/win:windows_jdk8"
 build:rbe_win --extra_execution_platforms="@org_tensorflow//third_party/toolchains/preconfig/win:rbe_windows_ltsc2019"

View File

@@ -1 +1 @@
-3.0.0
+3.1.0

View File

@@ -1,7 +1,11 @@
# TensorFlow Code of Conduct

In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation.

## Our Standards

View File

@@ -4,26 +4,31 @@ https://stackoverflow.com/questions/tagged/tensorflow
If you open a GitHub issue, here is our policy:

1. It must be a bug, a feature request, or a significant problem with the documentation (for small docs fixes please send a PR instead).
2. The form below must be filled out.
3. It shouldn't be a TensorBoard issue. Those go [here](https://github.com/tensorflow/tensorboard/issues).

**Here's why we have that policy**: TensorFlow developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow.

------------------------

### System information

- **Have I written custom code (as opposed to using a stock example script provided in TensorFlow)**:
- **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**:
- **Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on a mobile device**:
- **TensorFlow installed from (source or binary)**:
- **TensorFlow version (use command below)**:
- **Python version**:
- **Bazel version (if compiling from source)**:
- **GCC/Compiler version (if compiling from source)**:
- **CUDA/cuDNN version**:
- **GPU model and memory**:
- **Exact command to reproduce**:

You can collect some of this information using our environment capture script:
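For reference, a quick way to gather the version details requested above (a hedged sketch, not the project's official environment capture script):

import sys
import tensorflow as tf

# Print the TensorFlow and Python versions asked for by the template above.
print("TensorFlow:", tf.version.VERSION, tf.version.GIT_VERSION)
print("Python:", sys.version.split()[0])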

View File

@@ -90,89 +90,150 @@ Coinciding with this change, new releases of [TensorFlow's Docker images](https:
* The current TensorFlow release now **requires** [gast](https://pypi.org/project/gast/) version 0.3.3.

## Bug Fixes and Other Changes

* `tf.data`:
  * Removed `autotune_algorithm` from experimental optimization options.
* TF Core:
  * `tf.constant` always creates CPU tensors irrespective of the current device context.
  * Eager `TensorHandles` maintain a list of mirrors for any copies to local or remote devices. This avoids any redundant copies due to op execution.
  * For `tf.Tensor` & `tf.Variable`, `.experimental_ref()` is no longer experimental and is available as simply `.ref()`.
  * `pfor/vectorized_map`: Added support for vectorizing 56 more ops. Vectorizing `tf.cond` is also supported now.
  * Set as much partial shape as we can infer statically within the gradient impl of the gather op.
  * Gradient of `tf.while_loop` emits `StatelessWhile` op if `cond` and body functions are stateless. This allows multiple gradients while ops to run in parallel under distribution strategy.
  * Speed up `GradientTape` in eager mode by auto-generating a list of op inputs/outputs which are unused and hence not cached for gradient functions.
  * Support `back_prop=False` in `while_v2` but mark it as deprecated.
  * Improve error message when attempting to use `None` in data-dependent control flow.
  * Add `RaggedTensor.numpy()`.
  * Update `RaggedTensor.__getitem__` to preserve uniform dimensions & allow indexing into uniform dimensions.
  * Update `tf.expand_dims` to always insert the new dimension as a non-ragged dimension.
  * Update `tf.embedding_lookup` to use `partition_strategy` and `max_norm` when `ids` is ragged.
  * Allow `batch_dims==rank(indices)` in `tf.gather`.
  * Add support for bfloat16 in `tf.print`.
* `tf.distribute`:
  * Support `embedding_column` with variable-length input features for `MultiWorkerMirroredStrategy`.
* `tf.keras`:
  * Added `experimental_aggregate_gradients` argument to `tf.keras.optimizer.Optimizer.apply_gradients`. This allows custom gradient aggregation and processing of aggregated gradients in a custom training loop.
  * Allow `pathlib.Path` paths for loading models via the Keras API.
* `tf.function`/AutoGraph:
  * AutoGraph is now available in `ReplicaContext.merge_call`, `Strategy.extended.update` and `Strategy.extended.update_non_slot`.
  * Experimental support for shape invariants has been enabled in `tf.function`. See the API docs for `tf.autograph.experimental.set_loop_options` for additional info.
  * AutoGraph error messages now exclude frames corresponding to APIs internal to AutoGraph.
  * Improve shape inference for `tf.function` input arguments to unlock more Grappler optimizations in TensorFlow 2.x.
  * Improve automatic control dependency management of resources by allowing resource reads to occur in parallel and synchronizing only on writes.
  * Fix execution order of multiple stateful calls to `experimental_run_v2` in `tf.function`.
  * You can now iterate over `RaggedTensors` using a for loop inside `tf.function`.
* `tf.lite`:
  * Migrated the `tf.lite` C inference API out of experimental into lite/c.
  * Add an option to disallow `NNAPI` CPU / partial acceleration on Android 10.
  * TFLite Android AARs now include the C headers and APIs required to use TFLite from native code.
  * Refactors the delegate and delegate kernel sources to allow usage in the linter.
  * Limit delegated ops to actually supported ones if a device name is specified or `NNAPI` CPU Fallback is disabled.
  * TFLite now supports the `tf.math.reciprocal` op by lowering it to the `tf.div` op.
  * TFLite's unpack op now supports boolean tensor inputs.
  * Microcontroller and embedded code moved from experimental to the main TensorFlow Lite folder.
  * Check for large TFLite tensors.
  * Fix GPU delegate crash with C++17.
  * Add 5D support to TFLite `strided_slice`.
  * Fix error in delegation of `DEPTH_TO_SPACE` to `NNAPI` causing the op not to be accelerated.
  * Fix segmentation fault when running a model with LSTM nodes using the `NNAPI` delegate.
  * Fix `NNAPI` delegate failure when an operand for a Maximum/Minimum operation is a scalar.
  * Fix `NNAPI` delegate failure when the Axis input for a reduce operation is a scalar.
  * Expose option to limit the number of partitions that will be delegated to `NNAPI`.
  * If a target accelerator is specified, use its feature level to determine operations to delegate instead of the SDK version.
* `tf.random`:
  * Various random number generation improvements:
    * Add a fast path for default `random_uniform`.
    * `random_seed` documentation improvement.
    * `RandomBinomial` broadcasts and appends the sample shape to the left rather than the right.
  * Added `tf.random.stateless_binomial`, `tf.random.stateless_gamma`, `tf.random.stateless_poisson`.
  * `tf.random.stateless_uniform` now supports unbounded sampling of `int` types.
* Math and Linear Algebra:
  * Add `tf.linalg.LinearOperatorTridiag`.
  * Add `LinearOperatorBlockLowerTriangular`.
  * Add broadcasting support to `tf.linalg.triangular_solve` [#26204](https://github.com/tensorflow/tensorflow/issues/26204) and `tf.math.invert_permutation`.
  * Add `tf.math.sobol_sample` op.
  * Add `tf.math.xlog1py`.
  * Add `tf.math.special.{dawsn,expi,fresnel_cos,fresnel_sin,spence}`.
  * Add a Modified Discrete Cosine Transform (MDCT) and its inverse to `tf.signal`.
* TPU Enhancements:
  * Refactor `TpuClusterResolver` to move shared logic to a separate pip package.
  * Support configuring the TPU software version from the cloud tpu client.
  * Allowed the TPU embedding weight decay factor to be multiplied by the learning rate.
* XLA Support:
  * Add standalone XLA AOT runtime target + relevant .cc sources to the pip package.
  * Add a check for memory alignment to MemoryAllocation::MemoryAllocation() on 32-bit ARM. This ensures a deterministic early exit instead of a hard-to-debug bus error later.
  * `saved_model_cli aot_compile_cpu` allows you to compile saved models to XLA header+object files and include them in your C++ programs.
  * Enable `Igamma`, `Igammac` for XLA.
* Deterministic Op Functionality:
  * The XLA reduction emitter is deterministic when the environment variable `TF_DETERMINISTIC_OPS` is set to "true" or "1". This extends deterministic `tf.nn.bias_add` back-prop functionality (and therefore also deterministic back-prop of bias-addition in Keras layers) to include when XLA JIT compilation is enabled.
  * Fix a problem, when running on a CUDA GPU and when either environment variable `TF_DETERMINISTIC_OPS` or environment variable `TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer configurations led to an exception with the message "No algorithm worked!"
* Tracing and Debugging:
  * Add source, destination name to `_send` traceme to allow easier debugging.
  * Add traceme event to `fastpathexecute`.
* Other:
  * Fix an issue with `AUC.reset_states` for multi-label AUC [#35852](https://github.com/tensorflow/tensorflow/issues/35852).
  * Fix the TF upgrade script to not delete files when there is a parsing error and the output mode is `in-place`.
  * Move `tensorflow/core:framework/*_pyclif` rules to `tensorflow/core/framework:*_pyclif`.
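As a hedged illustration of two items above (not part of the release notes themselves): `.ref()` yields a hashable reference to a `tf.Variable` or `tf.Tensor`, and `TF_DETERMINISTIC_OPS` is read from the process environment, so it is best set before TensorFlow executes any ops.

import os
os.environ["TF_DETERMINISTIC_OPS"] = "1"  # opt in to deterministic kernels

import tensorflow as tf

v = tf.Variable(1.0)
lookup = {v.ref(): "first variable"}  # .ref() replaces .experimental_ref()
print(lookup[v.ref()])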
## Thanks to our Contributors

View File

@@ -114,6 +114,14 @@ http_archive(
     ],
 )
+
+http_archive(
+    name = "person_detect_data",
+    sha256 = "170542270da256994ce24d1e357f6e84a54fdaf7d28ff2b74725a40b70b082cf",
+    urls = [
+        "https://storage.googleapis.com/download.tensorflow.org/data/tf_lite_micro_person_data_grayscale_2020_05_24.zip",
+    ],
+)
 # Required for dependency @com_github_grpc_grpc
 load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
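A hedged sketch (not part of the build) of how the sha256 pinned above can be checked by hand before trusting a mirrored copy of the archive:

import hashlib
import urllib.request

URL = ("https://storage.googleapis.com/download.tensorflow.org/data/"
       "tf_lite_micro_person_data_grayscale_2020_05_24.zip")
EXPECTED_SHA256 = "170542270da256994ce24d1e357f6e84a54fdaf7d28ff2b74725a40b70b082cf"

# Download the archive and compare its digest with the WORKSPACE pin.
data = urllib.request.urlopen(URL).read()
assert hashlib.sha256(data).hexdigest() == EXPECTED_SHA256, "checksum mismatch"
print("archive matches the pinned sha256")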

View File

@@ -49,7 +49,7 @@ _TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
 _TF_WORKSPACE_ROOT = ''
 _TF_BAZELRC = ''
 _TF_CURRENT_BAZEL_VERSION = None
-_TF_MIN_BAZEL_VERSION = '2.0.0'
+_TF_MIN_BAZEL_VERSION = '3.1.0'
 _TF_MAX_BAZEL_VERSION = '3.99.0'
 
 NCCL_LIB_PATHS = [
@@ -484,8 +484,8 @@ def check_bazel_version(min_version, max_version):
   stderr = open(os.devnull, 'wb')
   curr_version = run_shell(['bazel', '--version'],
-                           allow_non_zero = True,
-                           stderr = stderr)
+                           allow_non_zero=True,
+                           stderr=stderr)
   if curr_version.startswith('bazel '):
     curr_version = curr_version.split('bazel ')[1]
@@ -1011,17 +1011,15 @@ def set_tf_cuda_compute_capabilities(environ_cp):
       default_cuda_compute_capabilities = native_cuda_compute_capabilities
 
     ask_cuda_compute_capabilities = (
-        'Please specify a list of comma-separated '
-        'CUDA compute capabilities you want to '
-        'build with.\nYou can find the compute '
-        'capability of your device at: '
-        'https://developer.nvidia.com/cuda-gpus.\nPlease'
-        ' note that each additional compute '
-        'capability significantly increases your '
-        'build time and binary size, and that '
-        'TensorFlow only supports compute '
-        'capabilities >= 3.5 [Default is: %s]: ' %
-        default_cuda_compute_capabilities)
+        'Please specify a list of comma-separated CUDA compute capabilities '
+        'you want to build with.\nYou can find the compute capability of your '
+        'device at: https://developer.nvidia.com/cuda-gpus. Each capability '
+        'can be specified as "x.y" or "compute_xy" to include both virtual and'
+        ' binary GPU code, or as "sm_xy" to only include the binary '
+        'code.\nPlease note that each additional compute capability '
+        'significantly increases your build time and binary size, and that '
+        'TensorFlow only supports compute capabilities >= 3.5 [Default is: '
+        '%s]: ' % default_cuda_compute_capabilities)
     tf_cuda_compute_capabilities = get_from_env_or_user_or_default(
         environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES',
         ask_cuda_compute_capabilities, default_cuda_compute_capabilities)
@@ -1033,8 +1031,23 @@ def set_tf_cuda_compute_capabilities(environ_cp):
     for compute_capability in tf_cuda_compute_capabilities.split(','):
       m = re.match('[0-9]+.[0-9]+', compute_capability)
       if not m:
-        print('Invalid compute capability: %s' % compute_capability)
-        all_valid = False
+        # We now support sm_35,sm_50,sm_60,compute_70.
+        sm_compute_match = re.match('(sm|compute)_?([0-9]+[0-9]+)',
+                                    compute_capability)
+        if not sm_compute_match:
+          print('Invalid compute capability: %s' % compute_capability)
+          all_valid = False
+        else:
+          ver = int(sm_compute_match.group(2))
+          if ver < 30:
+            print(
+                'ERROR: TensorFlow only supports small CUDA compute'
+                ' capabilities of sm_30 and higher. Please re-specify the list'
+                ' of compute capabilities excluding version %s.' % ver)
+            all_valid = False
+          if ver < 35:
+            print('WARNING: XLA does not support CUDA compute capabilities '
+                  'lower than sm_35. Disable XLA when running on older GPUs.')
       else:
         ver = float(m.group(0))
         if ver < 3.0:
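A minimal standalone sketch of the validation logic added above (simplified; the real configure.py also warns when an "x.y" capability is below 3.5): it accepts "x.y", "sm_xy", or "compute_xy" strings and rejects anything below sm_30 / 3.0.

import re

def is_valid_compute_capability(capability):
  # Dotted form, e.g. "7.5": must be at least 3.0.
  if re.match(r'^[0-9]+\.[0-9]+$', capability):
    return float(capability) >= 3.0
  # "sm_xy" / "compute_xy" form, e.g. "sm_75" or "compute_70": at least sm_30.
  m = re.match(r'(sm|compute)_?([0-9]+[0-9]+)', capability)
  return bool(m) and int(m.group(2)) >= 30

print(is_valid_compute_capability('7.5'))         # True
print(is_valid_compute_capability('compute_70'))  # True
print(is_valid_compute_capability('sm_21'))       # False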
@@ -1225,7 +1238,8 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
  only, as of 2019-11-19). TensorFlow needs this flag to massively reduce
  compile times, but until 16.4 is officially released, we can't depend on it.

  See also
  https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion

  Because it's very annoying to check this manually (to check the MSVC installed
  versions, you need to use the registry, and it's not clear if Bazel will be
@@ -1372,7 +1386,7 @@ def main():
     current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION,
                                                 _TF_MAX_BAZEL_VERSION)
   except subprocess.CalledProcessError as e:
-    print("Error checking bazel version: ", e.output.decode('UTF-8').strip())
+    print('Error checking bazel version: ', e.output.decode('UTF-8').strip())
     raise e
 
   _TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)
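With the minimum Bazel version raised to 3.1.0 above, here is a hedged sketch (not configure.py's own convert_version_to_int) of checking the installed Bazel against that floor, assuming a plain x.y.z version string:

import subprocess

def installed_bazel_version():
  # `bazel --version` prints something like "bazel 3.1.0".
  out = subprocess.check_output(['bazel', '--version']).decode('utf-8').strip()
  return out.split('bazel ')[-1].split()[0]

def at_least(version, minimum='3.1.0'):
  as_tuple = lambda v: tuple(int(x) for x in v.split('.')[:3])
  return as_tuple(version) >= as_tuple(minimum)

print(at_least(installed_bazel_version()))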

View File

@@ -298,6 +298,13 @@ config_setting(
     visibility = ["//visibility:public"],
 )
+
+# Experimental features
+config_setting(
+    name = "stackdriver_support",
+    define_values = {"stackdriver_support": "true"},
+    visibility = ["//visibility:public"],
+)
 
 # Crosses between platforms and file system libraries not supported on those
 # platforms due to limitations in nested select() statements.
 config_setting(

@@ -531,6 +538,7 @@ package_group(
 # Packages that use composite tensors or dispatch.
 # TODO(b/154762408) Remove this package group once it's no longer needed.
+# If this is modified, then copy.bara.sky must also be modified.
 package_group(name = "composite_tensor_whitelist")
 
 # Packages that use private types symbols, until they are exported.

@@ -540,6 +548,11 @@ package_group(
     packages = ["//learning/deepmind/tensorflow/replicator/..."],
 )
+
+# Packages that use StructuredTensors.
+# TODO(b/159007891) Remove this package once StructuredTensor is exported.
+# If this is modified, then copy.bara.sky must also be modified.
+package_group(name = "structured_tensor_whitelist")
 
 filegroup(
     name = "intel_binary_blob",
     data = if_mkl_ml(

View File

@@ -216,6 +216,7 @@ tf_cuda_library(
     ],
     visibility = [
         "//tensorflow/c:__subpackages__",
+        "//tensorflow/compiler/mlir/tensorflow/c:__subpackages__",
     ],
     deps = select({
         "//tensorflow:android": [

View File

@@ -589,14 +589,16 @@ void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; }
 
 TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
   TF_DeviceList* response = new TF_DeviceList;
-  status->status = session->session->ListDevices(&response->response);
+  if (session && session->session)
+    status->status = session->session->ListDevices(&response->response);
   return response;
 }
 
 TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
                                                TF_Status* status) {
   TF_DeviceList* response = new TF_DeviceList;
-  status->status = session->session->ListDevices(&response->response);
+  if (session && session->session)
+    status->status = session->session->ListDevices(&response->response);
   return response;
 }

@@ -1384,6 +1386,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
     cpp_type v;                                                     \
     status->status =                                                \
         tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \
+    if (!status->status.ok()) return;                               \
     *value = static_cast<c_type>(v);                                \
   }                                                                 \
   void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \

@@ -2178,6 +2181,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
     }
     return new_session;
   } else {
+    LOG(ERROR) << status->status;
     DCHECK_EQ(nullptr, session);
     return nullptr;
   }

View File

@@ -144,6 +144,24 @@ cc_library(
     ],
 )
+
+cc_library(
+    name = "c_api_unified_internal",
+    hdrs = [
+        "c_api_unified_experimental_internal.h",
+    ],
+    visibility = [
+        "//tensorflow:internal",
+    ],
+    deps = [
+        ":c_api",
+        ":c_api_experimental",
+        "//tensorflow/c:c_api_internal",
+        "//tensorflow/c:tf_status",
+        "//tensorflow/core/platform:casts",
+        "//tensorflow/core/platform:types",
+    ],
+)
 cc_library(
     name = "tensor_handle_interface",
     hdrs = ["tensor_handle_interface.h"],
@@ -184,7 +202,6 @@ cc_library(
         ":operation_interface",
         ":tensor_handle_interface",
         "//tensorflow/c:tensor_interface",
-        "//tensorflow/c/experimental/saved_model/core:saved_model_api",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
@@ -351,8 +368,41 @@ tf_cuda_cc_test(
     args = ["--heap_check=local"],
     extra_copts = tfe_xla_copts(),
     tags = [
+        "no_windows",
+        "noasan",  # leaks gRPC server instances
+    ],
+    deps = [
+        ":c_api",
+        ":c_api_experimental",
+        ":c_api_internal",
+        ":c_api_test_util",
+        ":tfe_tensorhandle_internal",
+        "//tensorflow/c:c_test_util",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/common_runtime:function_optimization_registry",
+        "//tensorflow/core/common_runtime/eager:eager_operation",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
+tf_cuda_cc_test(
+    name = "c_api_distributed_test",
+    size = "small",
+    srcs = [
+        "c_api_distributed_test.cc",
+    ],
+    # TODO(b/136478427): Figure out how to correctly shut the server down
+    args = ["--heap_check=local"],
+    extra_copts = tfe_xla_copts(),
+    tags = [
+        "no_windows",
         "noasan",  # leaks gRPC server instances
+        "notsan",  # b/157098283
     ],
     deps = [
         ":c_api",
@@ -383,7 +433,10 @@ tf_cuda_cc_test(
     # TODO(b/136478427): Figure out how to correctly shut the server down
     args = ["--heap_check=local"],
     extra_copts = tfe_xla_copts(),
-    tags = ["noasan"],  # leaks gRPC server instances
+    tags = [
+        "no_windows",
+        "noasan",  # leaks gRPC server instances
+    ],
     deps = [
         ":c_api",
         ":c_api_experimental",
@@ -514,6 +567,7 @@ tf_cuda_cc_test(
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_test_util",
         "//tensorflow/cc/profiler",
+        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",

View File

@@ -1397,23 +1397,17 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
         tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
     return;
   }
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  status->status = context->AddFunctionDef(function_def);
+  status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
 }
 
 void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
                             TF_Status* status) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  status->status = context->AddFunctionDef(function->fdef);
+  status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function->fdef);
 }
 
 void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
                                TF_Status* status) {
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  status->status = context->RemoveFunction(name);
+  status->status = tensorflow::unwrap(ctx)->RemoveFunction(name);
 }
 
 unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
@@ -1479,14 +1473,10 @@ const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
 }
 
 void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
-  tensorflow::AttrValueMap m;
-  tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
   tensorflow::EagerOperation* operation =
       OperationFromInterface(tensorflow::unwrap(op));
   tensorflow::AttrBuilder* destination = operation->MutableAttrs();
-  for (const auto& attribute : m) {
-    destination->Set(attribute.first, attribute.second);
-  }
+  destination->CopyAttributes(*tensorflow::unwrap(attrs));
 }
 
 void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,

View File

@@ -30,26 +30,6 @@ namespace {
 
 using ::tensorflow::string;
 
-tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
-  tensorflow::ServerDef server_def;
-  server_def.set_protocol("grpc");
-  server_def.set_job_name(job_name);
-  server_def.set_task_index(0);
-  tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
-  tensorflow::JobDef* job_def = cluster_def->add_job();
-  job_def->set_name(job_name);
-  for (int i = 0; i < num_tasks; i++) {
-    int port = tensorflow::testing::PickUnusedPortOrDie();
-    job_def->mutable_tasks()->insert(
-        {i, tensorflow::strings::StrCat("localhost", ":", port)});
-  }
-  return server_def;
-}
-
-tensorflow::ServerDef GetServerDef(int num_tasks) {
-  return GetServerDef("localhost", num_tasks);
-}
-
 void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
   tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
   int port = tensorflow::testing::PickUnusedPortOrDie();
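The removed helper above builds a single-job, multi-task gRPC ServerDef. A hedged Python sketch of the equivalent construction (assuming the portpicker package for free local ports, much as the C++ test utility does):

import portpicker  # assumption: only used here to grab free local ports
import tensorflow as tf

def get_server_def(job_name="localhost", num_tasks=2):
  # One job with num_tasks localhost tasks, task 0 as the client, gRPC transport.
  tasks = ["localhost:%d" % portpicker.pick_unused_port()
           for _ in range(num_tasks)]
  cluster = tf.train.ClusterSpec({job_name: tasks})
  return tf.train.ServerDef(cluster=cluster.as_cluster_def(),
                            job_name=job_name,
                            task_index=0,
                            protocol="grpc")

print(get_server_def())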

View File

@@ -0,0 +1,506 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace {
using ::tensorflow::string;
// Add the values of three variables on three different tasks.
string AddVariablesFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'AddVariablesFunction'"
" input_arg {"
" name: 'var'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'sum'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read1'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read2'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add1'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read1:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add2'"
" op: 'Add'"
" input: 'add1:z:0'"
" input: 'read2:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'sum'"
" value: 'add2:z:0'"
" }",
&def));
return def.SerializeAsString();
}
void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(op, var_handle, status);
TFE_TensorHandle* is_initialized[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(op, &is_initialized[0], &num_retvals, status);
CHECK_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
bool initialized = false;
memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
EXPECT_EQ(initialized, true);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(is_initialized[0]);
TFE_DeleteOp(op);
delete status;
}
void TestFunctionWithPackedInput(const bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
// Create one variable per task.
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
// Add a sync point in order to make sure that variables have been initialized
// before the function execution starts.
// TODO(b/155789951): Remove once b/155789951 is fixed.
VarIsInitialized(ctx, h1);
VarIsInitialized(ctx, h2);
// Pack 3 variable handles into one TFE_TensorHandle.
int num_replicas = 3;
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
TFE_TensorHandle* packed_handle =
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
const string composite_device_name =
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
composite_device_name);
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
composite_device_name);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// Register and run a function which returns the sum of 3 variables.
const string function_def = AddVariablesFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, packed_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
if (remote) {
TFE_OpSetDevice(func, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(packed_handle);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(sum, 6.0);
TFE_DeleteTensorHandle(h0);
TFE_DeleteTensorHandle(h1);
TFE_DeleteTensorHandle(h2);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, TestLocalFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/false);
}
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/true);
}
string VariableAddFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'VariableAddFunction'"
" input_arg {"
" name: 'var0'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'var0_value'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read0:value:0'"
" device: '/job:localhost/task:1/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'identity'"
" op: 'Identity'"
" input: 'add:z:0'"
" device: '/job:localhost/task:0/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'var0_value'"
" value: 'identity:output:0'"
" }",
&def));
return def.SerializeAsString();
}
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
public:
FunctionErrorInjectionPass(string error_node, string error_device)
: error_node_(error_node), error_device_(error_device) {}
tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
const tensorflow::ConfigProto& config_proto,
std::unique_ptr<tensorflow::Graph>* graph,
tensorflow::FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) override {
// Inject failure to function instantiation if finding a node that contains
// the given node name (error_node_) and requested device (error_device_).
for (const auto node : graph->get()->nodes()) {
if (node->name().find(error_node_) != string::npos &&
node->requested_device() == error_device_) {
return tensorflow::errors::Internal("Injected graph pass error.");
}
}
return tensorflow::Status::OK();
}
private:
const string error_node_;
const string error_device_;
};
void TestDistributedFunctionCancellation(bool inject_error) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
if (inject_error) {
// Inject a function optimization pass failure when it sees the 'read0' op
// having a requested device `dev2_name`. During execution:
// * task:0 processes the main function `VariableAddFunction` and places
// the read0 op on task:2
// * task:0 partitions the main function with a subgraph containing read0
// sent to task:2
// * task:2 graph pass reports an error when it sees read0 with dev2_name
tensorflow::function_optimization_registration::
FunctionOptimizationPassRegistration register_test_pass(
std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name));
}
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle, nullptr);
const string function_def = VariableAddFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, var_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
if (inject_error) {
ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
} else {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
ASSERT_EQ(sum, 4.0);
}
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(var_handle);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, DistributedFunctionNoError) {
TestDistributedFunctionCancellation(false);
}
TEST(CAPI, DistributedFunctionCancelledOnError) {
TestDistributedFunctionCancellation(true);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Use large matrices so that RPCs don't return before we get a chance
// to call TFE_DeleteContext.
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* h1_task1 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h0_task1);
TFE_DeleteTensorHandle(h1_task1);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
TFE_DeleteContext(ctx);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
}
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
}
} // namespace

View File

@ -35,26 +35,6 @@ namespace {
using ::tensorflow::string; using ::tensorflow::string;
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost:", port)});
}
return server_def;
}
tensorflow::ServerDef GetServerDef(int num_tasks) {
return GetServerDef("localhost", num_tasks);
}
void TestRemoteExecute(bool async) { void TestRemoteExecute(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2); tensorflow::ServerDef server_def = GetServerDef(2);
@ -356,472 +336,4 @@ TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
/*heavy_load_on_streaming_rpc=*/true); /*heavy_load_on_streaming_rpc=*/true);
} }
// Add the values of three variables on three different tasks.
string AddVariablesFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'AddVariablesFunction'"
" input_arg {"
" name: 'var'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'sum'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read1'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'read2'"
" op: 'ReadVariableOp'"
" input: 'var'"
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add1'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read1:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add2'"
" op: 'Add'"
" input: 'add1:z:0'"
" input: 'read2:value:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'sum'"
" value: 'add2:z:0'"
" }",
&def));
return def.SerializeAsString();
}
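// For orientation, a minimal sketch of how the tests in this file consume the
// serialized FunctionDef above. Every call shown here already appears
// elsewhere in this change; `ctx`, `packed_handle`, and `status` stand for the
// context, packed resource handle, and status created by the surrounding test.
//
//   const string function_def = AddVariablesFunction();
//   TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
//                             status);
//   // Each ReadVariableOp in the function is pinned to a different task, so
//   // the single DT_RESOURCE argument is fed as a packed handle whose
//   // components live on those tasks.
//   TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
//   TFE_OpAddInput(func, packed_handle, status);
//   TFE_TensorHandle* retvals[1] = {nullptr};
//   int num_retvals = 1;
//   TFE_Execute(func, &retvals[0], &num_retvals, status);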
void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(op, var_handle, status);
TFE_TensorHandle* is_initialized[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(op, &is_initialized[0], &num_retvals, status);
CHECK_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
bool initialized = false;
memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
EXPECT_EQ(initialized, true);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(is_initialized[0]);
TFE_DeleteOp(op);
TF_DeleteStatus(status);
}
void TestFunctionWithPackedInput(const bool remote) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
// Create one variable per task.
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
// Add a sync point in order to make sure that variables have been initialized
// before the function execution starts.
// TODO(b/155789951): Remove once b/155789951 is fixed.
VarIsInitialized(ctx, h1);
VarIsInitialized(ctx, h2);
// Pack 3 variable handles into one TFE_TensorHandle.
int num_replicas = 3;
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
TFE_TensorHandle* packed_handle =
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
const string composite_device_name =
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
composite_device_name);
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
composite_device_name);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// Register and run a function which returns the sum of 3 variables.
const string function_def = AddVariablesFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, packed_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
if (remote) {
TFE_OpSetDevice(func, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(packed_handle);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(sum, 6.0);
TFE_DeleteTensorHandle(h0);
TFE_DeleteTensorHandle(h1);
TFE_DeleteTensorHandle(h2);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, TestLocalFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/false);
}
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
TestFunctionWithPackedInput(/*remote=*/true);
}
string VariableAddFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'VariableAddFunction'"
" input_arg {"
" name: 'var0'"
" type: DT_RESOURCE"
" }"
" output_arg {"
" name: 'var0_value'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'read0'"
" op: 'ReadVariableOp'"
" input: 'var0'"
" attr {"
" key: 'dtype'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'add'"
" op: 'Add'"
" input: 'read0:value:0'"
" input: 'read0:value:0'"
" device: '/job:localhost/task:1/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" node_def {"
" name: 'identity'"
" op: 'Identity'"
" input: 'add:z:0'"
" device: '/job:localhost/task:0/device:CPU:0'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'var0_value'"
" value: 'identity:output:0'"
" }",
&def));
return def.SerializeAsString();
}
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
public:
FunctionErrorInjectionPass(string error_node, string error_device)
: error_node_(error_node), error_device_(error_device) {}
tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
const tensorflow::ConfigProto& config_proto,
std::unique_ptr<tensorflow::Graph>* graph,
tensorflow::FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) override {
// Inject a failure into function instantiation when it finds a node whose
// name contains the given node name (error_node_) and whose requested
// device matches the given device (error_device_).
for (const auto node : graph->get()->nodes()) {
if (node->name().find(error_node_) != string::npos &&
node->requested_device() == error_device_) {
return tensorflow::errors::Internal("Injected graph pass error.");
}
}
return tensorflow::Status::OK();
}
private:
const string error_node_;
const string error_device_;
};
void TestDistributedFunctionCancellation(bool inject_error) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
if (inject_error) {
// Inject a function optimization pass failure when it sees the 'read0' op
// with requested device `dev2_name`. During execution:
// * task:0 processes the main function `VariableAddFunction` and places
// the read0 op on task:2
// * task:0 partitions the main function with a subgraph containing read0
// sent to task:2
// * task:2 graph pass reports an error when it sees read0 with dev2_name
tensorflow::function_optimization_registration::
FunctionOptimizationPassRegistration register_test_pass(
std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name));
}
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle, nullptr);
const string function_def = VariableAddFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(func, var_handle, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(func, &retvals[0], &num_retvals, status);
if (inject_error) {
ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
} else {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retvals[0]);
float sum = 0;
ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
ASSERT_EQ(sum, 4.0);
}
TFE_DeleteOp(func);
TFE_DeleteTensorHandle(var_handle);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}
TEST(CAPI, DistributedFunctionNoError) {
TestDistributedFunctionCancellation(false);
}
TEST(CAPI, DistributedFunctionCancelledOnError) {
TestDistributedFunctionCancellation(true);
}
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Use large matrices so that RPCs don't return before we get a chance
// to call TFE_DeleteContext.
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
auto* h0_task1 =
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* h1_task1 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h0_task1);
TFE_DeleteTensorHandle(h1_task1);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
TFE_DeleteContext(ctx);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
}
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
}
} // namespace } // namespace

View File

@ -18,7 +18,9 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
using tensorflow::string; using tensorflow::string;
@ -296,3 +298,23 @@ bool GetDeviceName(TFE_Context* ctx, string* device_name,
TF_DeleteDeviceList(devices); TF_DeleteDeviceList(devices);
return false; return false;
} }
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost:", port)});
}
return server_def;
}
tensorflow::ServerDef GetServerDef(int num_tasks) {
return GetServerDef("localhost", num_tasks);
}
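// A minimal usage sketch, mirroring the tests above (which link in GrpcServer;
// ports are chosen by PickUnusedPortOrDie, so the actual addresses vary):
// build a two-task "localhost" cluster and start task 1 as an in-process gRPC
// worker.
//
//   tensorflow::ServerDef server_def = GetServerDef(2);  // job "localhost", task_index 0
//   string serialized = server_def.SerializeAsString();  // later passed to TFE_ContextSetServerDef
//   server_def.set_task_index(1);
//   std::unique_ptr<tensorflow::GrpcServer> worker_server;
//   ASSERT_TRUE(tensorflow::GrpcServer::Create(
//                   server_def, tensorflow::Env::Default(), &worker_server)
//                   .ok());
//   ASSERT_TRUE(worker_server->Start().ok());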

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
// Return a tensor handle containing a float scalar // Return a tensor handle containing a float scalar
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value); TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value);
@ -72,4 +73,11 @@ TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name, bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name,
const char* device_type); const char* device_type);
// Create a ServerDef with the given `job_name` and add `num_tasks` tasks to it.
tensorflow::ServerDef GetServerDef(const tensorflow::string& job_name,
int num_tasks);
// Create a ServerDef with job name "localhost" and add `num_tasks` tasks to it.
tensorflow::ServerDef GetServerDef(int num_tasks);
#endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ #endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_

View File

@ -58,7 +58,7 @@ T* dyncast(S source) {
// GraphContext and vice-versa). // GraphContext and vice-versa).
class AbstractTensor { class AbstractTensor {
protected: protected:
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor }; enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor };
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {} explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
public: public:
@ -101,7 +101,7 @@ class AbstractFunction {
// on a given context, with the same or different input tensors. // on a given context, with the same or different input tensors.
class AbstractOp { class AbstractOp {
protected: protected:
enum AbstractOpKind { kGraphOp, kEagerOp }; enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp };
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {} explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
public: public:
@ -129,7 +129,7 @@ class AbstractOp {
// eager implementation or to a graph implementation. // eager implementation or to a graph implementation.
struct ExecutionContext { struct ExecutionContext {
protected: protected:
enum ExecutionContextKind { kGraphContext, kEagerContext }; enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext };
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {} explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
public: public:

View File

@ -477,7 +477,8 @@ TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
TF_DeleteExecutionContext(eager_execution_ctx); TF_DeleteExecutionContext(eager_execution_ctx);
} }
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef")); INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Values("graphdef", "mlir"));
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/tensor_interface.h" #include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
@ -84,11 +84,10 @@ class AbstractContextInterface {
// Create an operation to perform op execution // Create an operation to perform op execution
virtual AbstractOperationInterface* CreateOperation() = 0; virtual AbstractOperationInterface* CreateOperation() = 0;
// Load a SavedModelAPI object from the given directory and tags // Returns whether the runtime is backed by TFRT or the legacy TF Eager
virtual std::unique_ptr<SavedModelAPI> LoadSavedModelAPI( // Runtime. This is necessary to decouple runtime-dependent
const std::string& directory, // code that is layered on top of the runtime.
const absl::optional<std::unordered_set<std::string>>& tags, virtual bool UsesTFRT() = 0;
tensorflow::Status* status) = 0;
// List attributes of available devices // List attributes of available devices
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0; virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
@ -104,6 +103,14 @@ class AbstractContextInterface {
// Block until all pending nodes are finished. // Block until all pending nodes are finished.
virtual Status AsyncWait() = 0; virtual Status AsyncWait() = 0;
// Add a function (serialized FunctionDef protocol buffer) so that it can
// be executed as an op. Returns an error if a function with the same name
// already exists.
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
// Remove a function. The 'func' argument is the name of a previously added
// FunctionDef, i.e. the value of fdef.signature.name.
virtual Status RemoveFunction(const string& func) = 0;
protected: protected:
virtual ~AbstractContextInterface() {} virtual ~AbstractContextInterface() {}
}; };
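// A hedged illustration of how AddFunctionDef and RemoveFunction are meant to
// be used together. The helper below is hypothetical and not part of this
// change; it only assumes a concrete AbstractContextInterface implementation.
//
//   tensorflow::Status RegisterRunAndRemove(AbstractContextInterface* ctx,
//                                           const tensorflow::FunctionDef& fdef) {
//     // Fails if a function with the same signature name was already added.
//     tensorflow::Status s = ctx->AddFunctionDef(fdef);
//     if (!s.ok()) return s;
//     // ... create and execute an operation named fdef.signature().name() ...
//     // Remove it again using the name from fdef.signature.name.
//     return ctx->RemoveFunction(fdef.signature().name());
//   }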

View File

@ -12,39 +12,98 @@ package(
# need a second rule that omits .cc files, in # need a second rule that omits .cc files, in
# tensorflow/python:_pywrap_parallel_device. # tensorflow/python:_pywrap_parallel_device.
filegroup( filegroup(
name = "headers", name = "lib_headers",
srcs = ["parallel_device_lib.h"],
)
filegroup(
name = "lib_sources",
srcs = ["parallel_device_lib.cc"],
)
filegroup(
name = "device_headers",
srcs = ["parallel_device.h"], srcs = ["parallel_device.h"],
)
filegroup(
name = "device_sources",
srcs = ["parallel_device.cc"],
)
filegroup(
name = "headers",
srcs = [
":device_headers",
":lib_headers",
],
visibility = ["//tensorflow/python:__pkg__"], visibility = ["//tensorflow/python:__pkg__"],
) )
filegroup( filegroup(
name = "sources", name = "sources",
srcs = ["parallel_device.cc"], srcs = [
":device_sources",
":lib_sources",
],
visibility = ["//tensorflow/python:__pkg__"], visibility = ["//tensorflow/python:__pkg__"],
) )
cc_library( cc_library(
name = "parallel_device", name = "parallel_device",
srcs = [":sources"], srcs = [":device_sources"],
hdrs = [":headers"], hdrs = [":device_headers"],
visibility = ["//tensorflow:internal"],
deps = [
":parallel_device_lib",
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
],
)
cc_library(
name = "parallel_device_lib",
srcs = [":lib_sources"],
hdrs = [":lib_headers"],
visibility = ["//tensorflow:internal"], visibility = ["//tensorflow:internal"],
deps = [ deps = [
"//tensorflow/c:c_api", "//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental", "//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant", "@com_google_absl//absl/types:variant",
], ],
) )
cc_library(
name = "parallel_device_testlib",
testonly = 1,
srcs = ["parallel_device_testlib.cc"],
hdrs = ["parallel_device_testlib.h"],
deps = [
":parallel_device",
":parallel_device_ops",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test( tf_cc_test(
name = "parallel_device_test", name = "parallel_device_test",
srcs = ["parallel_device_test.cc"], srcs = ["parallel_device_test.cc"],
deps = [ deps = [
":parallel_device", ":parallel_device",
":parallel_device_ops", ":parallel_device_ops",
":parallel_device_testlib",
"//tensorflow/c:c_api", "//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental", "//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api",
@ -55,6 +114,27 @@ tf_cc_test(
], ],
) )
tf_cc_test(
name = "parallel_device_remote_test",
srcs = ["parallel_device_remote_test.cc"],
# TODO(b/136478427): Enable global heap checking when servers shut down
# cleanly.
args = ["--heap_check=local"],
deps = [
":parallel_device",
":parallel_device_ops",
":parallel_device_testlib",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
],
)
# Note: ParallelDevice-specific ops are experimental and not currently linked in # Note: ParallelDevice-specific ops are experimental and not currently linked in
# to TensorFlow by default, just used in a few tests. # to TensorFlow by default, just used in a few tests.
filegroup( filegroup(

View File

@ -23,25 +23,13 @@ limitations under the License.
#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow { namespace tensorflow {
namespace eager { namespace parallel_device {
namespace { namespace {
// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) const {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
class OpDeleter { class OpDeleter {
public: public:
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); } void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
@ -49,224 +37,46 @@ class OpDeleter {
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>; using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
class ExecutorDeleter {
public:
void operator()(TFE_Executor* to_delete) const {
TFE_DeleteExecutor(to_delete);
}
};
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class ParallelTensor;
using MaybeParallelTensorOwned = using MaybeParallelTensorOwned =
absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>; absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
using MaybeParallelTensorUnowned = using MaybeParallelTensorUnowned =
absl::variant<ParallelTensor*, TFE_TensorHandle*>; absl::variant<ParallelTensor*, TFE_TensorHandle*>;
// Creates a vector of `count` new executors (threads). // A ParallelDevice on its own is not registered with a TFE_Context, and so has
std::vector<ExecutorPtr> MakeExecutors(size_t count) { // no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
std::vector<ExecutorPtr> executors; // name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
executors.reserve(count); // placed on the parallel device.
for (int i = 0; i < count; ++i) { class NamedParallelDevice {
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
}
return executors;
}
// A representation of the custom device passed in and out of the TFE custom
// device APIs, providing context about the parallel device to
// ParallelDeviceExecute.
class ParallelDevice {
public: public:
ParallelDevice(const std::string& name, NamedParallelDevice(const std::string& name,
const std::vector<std::string>& devices); std::unique_ptr<ParallelDevice> parallel_device)
: device_name_(name), parallel_device_(std::move(parallel_device)) {}
// Helper to copy a tensor handle from another device once for each component const std::string& name() const { return device_name_; }
// of the ParallelDevice. const ParallelDevice& device() const { return *parallel_device_; }
//
// Sets a bad status and returns a nullptr if `tensor` is already on the
// ParallelDevice, or if the individual copies fail.
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
TF_Status* status) const;
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors (or
// implicitly-mirrored tensors on other devices). Wraps the resulting
// per-device and per-output TFE_TensorHandles into one ParallelTensor per
// output of the original operation.
//
// `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or
// un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput
// requires non-parallel tensors, and TPUReplicatedOutput requires a parallel
// tensor, but other operations will implicitly broadcast non-parallel input
// tensors across the ParallelDevice's component devices.
//
// Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput,
// pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput
// causes `Execute` to return non-parallel tensors.
//
// Attributes are forwarded to executed operations unmodified.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK.
absl::optional<std::vector<MaybeParallelTensorOwned>> Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
// Implements the parallel case for `Execute`, where all of the outputs of the
// operation are ParallelTensors, and all inputs are either ParallelTensors or
// should be implicitly broadcast. This means the operation is not
// TPUReplicatedInput or TPUReplicatedOutput.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
// if sanity checks on dtypes/metadata fail.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ExecuteParallelOperation(TFE_Context* context,
std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name,
const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
const std::string& device_name() const { return device_name_; }
private: private:
// The name of the parallel device std::string device_name_;
// (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0") std::unique_ptr<ParallelDevice> parallel_device_;
const std::string device_name_;
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
const std::vector<std::string> underlying_devices_;
// A sequence of TFE_Executors, one per device, for executing operations in
// parallel.
const std::vector<ExecutorPtr> executors_;
}; };
// The internal representation of a TFE_TensorHandle placed on a absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
// ParallelDevice. Contains a tuple of tensors, one on each of the const ParallelDevice& parallel_device,
// `underlying_devices_` of the ParallelDevice. const std::string& parallel_device_name, TFE_Context* context,
class ParallelTensor { std::vector<MaybeParallelTensorUnowned> inputs, const char* operation_name,
public: const TFE_OpAttrs* attributes, int expected_max_outputs,
// Construct a ParallelTensor from TensorHandles placed on the component TF_Status* status) {
// devices of a ParallelDevice.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status);
// Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it.
static TensorHandlePtr AsTensorHandle(TFE_Context* context,
std::unique_ptr<ParallelTensor> t,
TF_Status* status);
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
private:
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors,
std::vector<int64_t> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(std::move(shape)),
dtype_(dtype) {}
const ParallelDevice& device_;
const std::vector<TensorHandlePtr> tensors_;
const std::vector<int64_t> shape_;
const TF_DataType dtype_;
};
ParallelDevice::ParallelDevice(const std::string& name,
const std::vector<std::string>& devices)
: device_name_(name),
underlying_devices_(devices),
executors_(MakeExecutors(underlying_devices_.size())) {}
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
const char* current_device = TFE_TensorHandleDeviceName(tensor, status);
if (device_name_ == current_device) {
std::string message(absl::StrCat(
"Tried to copy a TensorHandle to its existing device: ", device_name_));
TF_SetStatus(status, TF_INTERNAL, message.c_str());
return nullptr;
}
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (const std::string& underlying_device_name : underlying_devices_) {
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, underlying_device_name.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(t);
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
},
nullptr),
TF_DeleteTensor);
// TODO(allenl): Here and when executing regular operations, we could hold
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
// device names repeatedly.
OpPtr const_op(TFE_NewOp(context, "Const", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(device_handle);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<MaybeParallelTensorOwned>> result; absl::optional<std::vector<MaybeParallelTensorOwned>> result;
// TODO(allenl): We should remove "TPU" from these op names at the very least, // TODO(allenl): We should remove "TPU" from these op names at the very least,
// or consider other ways of packing/unpacking parallel tensors. // or consider other ways of packing/unpacking parallel tensors.
if (operation_name == std::string("TPUReplicatedInput")) { if (operation_name == std::string("TPUReplicatedInput")) {
// Special-cased operation for packing per-device tensors into one parallel // Special-cased operation for packing per-device tensors into one parallel
// tensor. // tensor.
if (inputs.size() != underlying_devices_.size()) { if (inputs.size() != parallel_device.num_underlying_devices()) {
std::string message(absl::StrCat( std::string message(absl::StrCat(
"The parallel device ", device_name_, " expected ", "The parallel device ", parallel_device_name, " expected ",
underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ", parallel_device.num_underlying_devices(),
inputs.size())); " inputs to TPUReplicatedInput, but got ", inputs.size()));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str()); TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result; return result;
} }
@ -289,7 +99,7 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
std::vector<MaybeParallelTensorOwned> result_content; std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1); result_content.reserve(1);
result_content.push_back(ParallelTensor::FromTensorHandles( result_content.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status)); parallel_device, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result; if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content)); result.emplace(std::move(result_content));
return result; return result;
@ -300,10 +110,10 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
TFE_OpAddAttrs(op.get(), attributes); TFE_OpAddAttrs(op.get(), attributes);
int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status); int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
if (TF_GetCode(status) != TF_OK) return result; if (TF_GetCode(status) != TF_OK) return result;
if (expected_outputs != underlying_devices_.size()) { if (expected_outputs != parallel_device.num_underlying_devices()) {
std::string message(absl::StrCat( std::string message(absl::StrCat(
"The parallel device ", device_name_, " expected ", "The parallel device ", parallel_device_name, " expected ",
underlying_devices_.size(), parallel_device.num_underlying_devices(),
" outputs for TPUReplicatedOutput, but got ", expected_outputs)); " outputs for TPUReplicatedOutput, but got ", expected_outputs));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str()); TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result; return result;
@ -329,15 +139,38 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
} else if (operation_name == std::string("DeviceID")) { } else if (operation_name == std::string("DeviceID")) {
std::vector<MaybeParallelTensorOwned> result_content; std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1); result_content.reserve(1);
result_content.push_back(DeviceIDs(context, status)); result_content.push_back(parallel_device.DeviceIDs(context, status));
if (TF_GetCode(status) != TF_OK) return result; if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content)); result.emplace(std::move(result_content));
return result; return result;
} }
std::vector<ParallelTensor*> parallel_inputs;
std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors;
parallel_inputs.reserve(inputs.size());
implicitly_broadcast_tensors.reserve(inputs.size()); // not tight
for (const auto& input : inputs) {
if (absl::holds_alternative<TFE_TensorHandle*>(input)) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
std::unique_ptr<ParallelTensor> parallel_tensor(
parallel_device.CopyToParallelDevice(
context, absl::get<TFE_TensorHandle*>(input), status));
if (TF_GetCode(status) != TF_OK) return result;
parallel_inputs.push_back(parallel_tensor.get());
implicitly_broadcast_tensors.emplace_back(std::move(parallel_tensor));
} else {
parallel_inputs.push_back(absl::get<ParallelTensor*>(input));
}
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results( maybe_parallel_results(
ExecuteParallelOperation(context, std::move(inputs), operation_name, parallel_device.Execute(context, parallel_inputs, operation_name,
attributes, expected_max_outputs, status)); attributes, expected_max_outputs, status));
if (!maybe_parallel_results.has_value()) return result; if (!maybe_parallel_results.has_value()) return result;
std::vector<std::unique_ptr<ParallelTensor>> parallel_results( std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
std::move(maybe_parallel_results.value())); std::move(maybe_parallel_results.value()));
@ -351,144 +184,6 @@ absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
return result; return result;
} }
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::ExecuteParallelOperation(
TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const {
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
per_device_output_tensors.reserve(underlying_devices_.size());
// TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
// setting the thread-local executor like this.
TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
auto reset_executor = gtl::MakeCleanup([context, previous_executor]() {
TFE_ContextSetExecutorForThread(context, previous_executor);
TFE_DeleteExecutor(previous_executor);
});
int first_op_output_count;
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// Note that the `reset_executor` cleanup sets the thread's executor back to
// the value before this function ran.
TFE_ContextSetExecutorForThread(context, executor);
OpPtr op(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return result;
TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
status);
TFE_OpAddAttrs(op.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
// Non-parallel tensors are implicitly broadcast, i.e. set as the input
// to each parallel operation.
//
// TODO(allenl): There may be smarter ways to do this copy in some
// cases, i.e. with a collective broadcast. We'll need to be careful
// about things that are taken as inputs on the host or on their
// existing device (for multi-device functions).
TFE_OpAddInput(op.get(),
absl::get<TFE_TensorHandle*>(inputs[input_index]),
status);
if (TF_GetCode(status) != TF_OK) return result;
} else {
// Parallel tensors are divided between operations by device.
TFE_OpAddInput(op.get(),
absl::get<ParallelTensor*>(inputs[input_index])
->tensor(device_index),
status);
if (TF_GetCode(status) != TF_OK) return result;
}
}
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
// For nested devices, the inner device sees the async executor we've
// set. Inner parallel devices will just overwrite this with their own and
// then set it back to ours before returning. This means parallel devices
// which consist of several aliased parallel devices would hypothetically
// deadlock if the outer parallel device ran one collective with a group
// size equal to the total number of aliased physical devices. Currently
// physical devices cannot participate in a single collective reduction
// multiple times, so this would fail earlier.
//
// TODO(allenl): Keep a map from outer executor to list of inner executors
// rather than a single list of executors so aliased nested parallel devices
// don't re-use an executor.
TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
if (device_index == 0) {
first_op_output_count = real_num_outputs;
} else {
if (real_num_outputs != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
if (TF_GetCode(status) != TF_OK) return result;
std::vector<TensorHandlePtr> this_outputs;
this_outputs.reserve(real_num_outputs);
for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
this_outputs.emplace_back(op_outputs[output_num]);
}
per_device_output_tensors.push_back(std::move(this_outputs));
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
per_device_outputs.reserve(first_op_output_count);
for (int i = 0; i < first_op_output_count; ++i) {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int j = 0; j < underlying_devices_.size(); ++j) {
components.push_back(std::move(per_device_output_tensors[j][i]));
}
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(per_device_outputs));
return result;
}
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
std::vector<int64_t> shape(
TFE_TensorHandleNumDims(components[0].get(), status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
// Verify that the TensorHandle's shape and dtype match all of the component
// shapes and dtypes.
for (TensorHandlePtr& component : components) {
for (int i = 0; i < shape.size(); ++i) {
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (tensor_dim != shape[i]) {
// TODO(allenl): Allow shapes to differ.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"Components of a ParallelTensor must currently all have "
"the same shape");
return nullptr;
}
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
}
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
parallel_device, std::move(components), std::move(shape), dtype));
}
// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how // Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their // ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
// reference counts drop to zero. // reference counts drop to zero.
@ -496,17 +191,18 @@ void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<ParallelTensor*>(data); delete reinterpret_cast<ParallelTensor*>(data);
} }
TensorHandlePtr ParallelTensor::AsTensorHandle( TensorHandlePtr ParallelTensorToTensorHandle(
TFE_Context* context, std::unique_ptr<ParallelTensor> t, const std::string& parallel_device_name, TFE_Context* context,
TF_Status* status) { std::unique_ptr<ParallelTensor> t, TF_Status* status) {
// The resulting TensorHandle owns an opaque pointer to "device memory", which // The resulting TensorHandle owns an opaque pointer to "device memory", which
// for a ParallelDevice is really a ParallelTensor. When the TensorHandle is // for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
// deleted, it will call ParallelTensorDeallocator to free the struct. // deleted, it will call ParallelTensorDeallocator to free the struct.
ParallelTensor* t_released = t.release(); ParallelTensor* t_released = t.release();
const std::vector<int64_t>& shape(t_released->shape());
return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory( return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory(
context, t_released->device_.device_name().c_str(), t_released->dtype_, context, parallel_device_name.c_str(), t_released->dtype(), shape.data(),
- t_released->shape_.data(), t_released->shape_.size(), t_released, 1,
- &ParallelTensorDeallocator, nullptr, status));
+ shape.size(), t_released, 1, &ParallelTensorDeallocator, nullptr,
+ status));
}
// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
@ -522,12 +218,14 @@ TensorHandlePtr ParallelTensor::AsTensorHandle(
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
- ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
+ NamedParallelDevice* named_device =
+     reinterpret_cast<NamedParallelDevice*>(device_info);
+ const ParallelDevice& dev = named_device->device();
std::unique_ptr<ParallelTensor> parallel_tensor(
- dev->CopyToParallelDevice(context, tensor, status));
+ dev.CopyToParallelDevice(context, tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
- return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor),
-     status)
+ return ParallelTensorToTensorHandle(named_device->name(), context,
+     std::move(parallel_tensor), status)
.release();
}
@ -561,14 +259,15 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* status,
void* device_info) {
- ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
+ NamedParallelDevice* named_device =
+     reinterpret_cast<NamedParallelDevice*>(device_info);
std::vector<MaybeParallelTensorUnowned> typed_inputs;
typed_inputs.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
const char* tensor_handle_device =
TFE_TensorHandleDeviceName(inputs[i], status);
if (TF_GetCode(status) != TF_OK) return;
- if (dev->device_name() == tensor_handle_device) {
+ if (named_device->name() == tensor_handle_device) {
// We assume that any tensors already placed on this device are
// ParallelTensors.
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
@ -580,8 +279,9 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
}
absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
- dev->Execute(context, std::move(typed_inputs), operation_name, attributes,
-     *num_outputs, status));
+ ExecuteWithSpecialOps(named_device->device(), named_device->name(),
+     context, std::move(typed_inputs), operation_name,
+     attributes, *num_outputs, status));
if (TF_GetCode(status) != TF_OK) return;
if (!maybe_typed_outputs.has_value()) {
TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
@ -602,8 +302,8 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
} else {
- outputs[i] = ParallelTensor::AsTensorHandle(
-     context,
+ outputs[i] = ParallelTensorToTensorHandle(
+     named_device->name(), context,
std::move(absl::get<std::unique_ptr<ParallelTensor>>(
typed_output)),
status)
@ -620,7 +320,7 @@ void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
// device_info is passed in using a C-style generic. It must always be a
// ParallelDevice.
void DeleteParallelDevice(void* device_info) {
- delete reinterpret_cast<ParallelDevice*>(device_info);
+ delete reinterpret_cast<NamedParallelDevice*>(device_info);
}
}  // namespace
@ -639,8 +339,10 @@ void AllocateParallelDevice(const char* device_name,
++device_index) {
underlying_devices_vector.push_back(underlying_devices[device_index]);
}
- *device_info = new ParallelDevice(device_name, underlying_devices_vector);
+ std::unique_ptr<ParallelDevice> parallel_device(
+     new ParallelDevice(underlying_devices_vector));
+ *device_info =
+     new NamedParallelDevice{device_name, std::move(parallel_device)};
}
- }  // namespace eager
+ }  // namespace parallel_device
}  // namespace tensorflow

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
- namespace eager {
+ namespace parallel_device {
// Allocate a parallel device named `device_name` which forwards operations to
// `underlying_devices`, maintaining "parallel tensors" with components placed
@ -59,7 +59,7 @@ void AllocateParallelDevice(const char* device_name,
int num_underlying_devices,
TFE_CustomDevice* device, void** device_info);
- }  // namespace eager
+ }  // namespace parallel_device
}  // namespace tensorflow
#endif  // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
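As an illustrative aside, here is a minimal sketch of how AllocateParallelDevice is typically wired into an eager context via TFE_RegisterCustomDevice, mirroring the RegisterParallelDevice helper used by the tests further down; the wrapper name RegisterExampleParallelDevice and the device strings are placeholder assumptions.

// Illustrative wiring of AllocateParallelDevice; device names are placeholders.
#include <array>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device.h"

void RegisterExampleParallelDevice(TFE_Context* context, TF_Status* status) {
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  TFE_CustomDevice device;
  void* device_info;
  tensorflow::parallel_device::AllocateParallelDevice(
      device_name, underlying_devices.data(), underlying_devices.size(),
      &device, &device_info);
  // device_info is later released through the custom device's delete callback.
  TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
}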

View File

@ -0,0 +1,376 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace parallel_device {
namespace {
class OpDeleter {
public:
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
};
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
class StatusDeleter {
public:
void operator()(TF_Status* to_delete) const { TF_DeleteStatus(to_delete); }
};
using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;
} // namespace
// Allows a single op at a time to be launched without blocking.
//
// DeviceThread itself is thread-safe, in that StartExecute will block if there
// is a pending execution. Since StartExecute is equivalent to grabbing a lock,
// multiple DeviceThreads should always be accessed in the same order to avoid
// deadlocks.
class DeviceThread {
public:
// Starts a background thread waiting for `StartExecute`.
explicit DeviceThread(const std::string& device)
: status_(TF_NewStatus()),
device_(device),
op_(nullptr),
thread_(tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "parallel_device_execute",
std::bind(&DeviceThread::Run, this))) {}
~DeviceThread();
// Requests that the worker thread execute the specified operation. Blocks
// until the previously pending operation (a StartExecute without a Join) has
// finished, if any.
void StartExecute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes, int expected_max_outputs);
// Block until the previous `StartExecute` operation has executed. Forwards
// the status from `TFE_Execute` and returns outputs if the status is OK.
std::vector<TensorHandlePtr> Join(TF_Status* status);
private:
void Run();
void Execute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes, int expected_max_outputs,
std::vector<TensorHandlePtr>* outputs, TF_Status* status) const
TF_EXCLUSIVE_LOCKS_REQUIRED(execution_mutex_);
enum class ExecutionState {
kReadyToExecute,
kHasResult,
kIdle,
kShuttingDown,
};
tensorflow::mutex execution_mutex_;
ExecutionState execution_state_ TF_GUARDED_BY(execution_mutex_) =
ExecutionState::kIdle;
// Tells the worker thread that there is new work.
tensorflow::condition_variable start_execute_;
// The worker thread notifies that work has finished.
tensorflow::condition_variable finished_execute_;
// Notifies a StartExecute that the previous Join has finished.
tensorflow::condition_variable finished_join_;
// Temporary state between `StartExecute` and `Join`.
// Inputs
TFE_Context* context_ TF_GUARDED_BY(execution_mutex_);
const char* operation_name_ TF_GUARDED_BY(execution_mutex_);
std::vector<TFE_TensorHandle*> op_inputs_ TF_GUARDED_BY(execution_mutex_);
const TFE_OpAttrs* attributes_ TF_GUARDED_BY(execution_mutex_);
int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
// Outputs
std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
const std::string device_;
mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
std::unique_ptr<Thread> thread_;
};
DeviceThread::~DeviceThread() {
{
tensorflow::mutex_lock l(execution_mutex_);
execution_state_ = ExecutionState::kShuttingDown;
}
start_execute_.notify_one();
}
void DeviceThread::Run() {
while (true) {
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ == ExecutionState::kIdle ||
execution_state_ == ExecutionState::kHasResult) {
start_execute_.wait(l);
}
if (execution_state_ == ExecutionState::kShuttingDown) {
return;
} else if (execution_state_ == ExecutionState::kReadyToExecute) {
// op_outputs_ may have been std::moved
op_outputs_ = std::vector<TensorHandlePtr>();
Execute(context_, operation_name_, std::move(op_inputs_), attributes_,
expected_max_outputs_, &op_outputs_, status_.get());
execution_state_ = ExecutionState::kHasResult;
}
}
finished_execute_.notify_one();
}
}
void DeviceThread::StartExecute(TFE_Context* context,
const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes,
int expected_max_outputs) {
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ != ExecutionState::kIdle) {
// If there's already a pending execution, wait until Join finishes before
// starting on the next operation.
finished_join_.wait(l);
}
context_ = context;
operation_name_ = operation_name;
op_inputs_ = inputs;
attributes_ = attributes;
expected_max_outputs_ = expected_max_outputs;
execution_state_ = ExecutionState::kReadyToExecute;
}
start_execute_.notify_one();
}
std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
std::vector<TensorHandlePtr> result;
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ != ExecutionState::kHasResult) {
finished_execute_.wait(l);
}
if (TF_GetCode(status_.get()) != TF_OK) {
TF_SetStatus(status, TF_GetCode(status_.get()),
TF_Message(status_.get()));
}
execution_state_ = ExecutionState::kIdle;
result = std::move(op_outputs_);
}
finished_join_.notify_one();
return result;
}
void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes,
int expected_max_outputs,
std::vector<TensorHandlePtr>* outputs,
TF_Status* status) const {
if (op_ == nullptr) {
op_.reset(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op_.get(), device_.c_str(), status);
if (TF_GetCode(status) != TF_OK) return;
} else {
TFE_OpReset(op_.get(), operation_name, device_.c_str(), status);
if (TF_GetCode(status) != TF_OK) return;
}
TFE_OpAddAttrs(op_.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
TFE_OpAddInput(op_.get(), inputs[input_index], status);
if (TF_GetCode(status) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> unwrapped_results(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
if (TF_GetCode(status) != TF_OK) return;
TFE_Execute(op_.get(), unwrapped_results.data(), &real_num_outputs, status);
if (TF_GetCode(status) != TF_OK) return;
unwrapped_results.resize(real_num_outputs);
outputs->reserve(real_num_outputs);
for (TFE_TensorHandle* unwrapped_result : unwrapped_results) {
outputs->emplace_back(unwrapped_result);
}
}
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
: underlying_devices_(devices) {
device_threads_.reserve(devices.size());
for (int device_index = 0; device_index < devices.size(); ++device_index) {
device_threads_.emplace_back(
new DeviceThread(devices[device_index].c_str()));
}
}
// Necessary for a unique_ptr to a forward-declared type.
ParallelDevice::~ParallelDevice() = default;
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (const std::string& underlying_device_name : underlying_devices_) {
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, context, underlying_device_name.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(t);
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
TFE_Context* context, TF_Status* status) const {
// TODO(allenl): We could cache DeviceIDs (keyed by context).
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
int64_t* device_id = new int64_t;
*device_id = device_index;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(
TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
sizeof(int64_t),
[](void* data, size_t, void* arg) {
delete reinterpret_cast<int64_t*>(data);
},
nullptr),
TF_DeleteTensor);
// TODO(allenl): Here and when executing regular operations, we could hold
// on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
// device names repeatedly.
OpPtr const_op(TFE_NewOp(context, "Const", status));
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
TFE_TensorHandle* device_handle;
int num_outputs = 1;
TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
components.emplace_back(device_handle);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
return ParallelTensor::FromTensorHandles(*this, std::move(components),
status);
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(TFE_Context* context,
const std::vector<ParallelTensor*>& inputs,
const char* operation_name,
const TFE_OpAttrs* attributes, int expected_max_outputs,
TF_Status* status) const {
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
per_device_output_tensors.reserve(underlying_devices_.size());
int first_op_output_count = 0;
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
DeviceThread* device_thread = device_threads_[device_index].get();
std::vector<TFE_TensorHandle*> device_inputs;
device_inputs.reserve(inputs.size());
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
// Parallel tensors are divided between operations by device.
device_inputs.push_back(inputs[input_index]->tensor(device_index));
}
device_thread->StartExecute(context, operation_name,
std::move(device_inputs), attributes,
expected_max_outputs);
}
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
DeviceThread* device_thread = device_threads_[device_index].get();
per_device_output_tensors.push_back(device_thread->Join(status));
if (TF_GetCode(status) != TF_OK) return result;
if (device_index == 0) {
first_op_output_count = per_device_output_tensors.rbegin()->size();
} else {
if (per_device_output_tensors.rbegin()->size() != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.
std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
per_device_outputs.reserve(first_op_output_count);
for (int i = 0; i < first_op_output_count; ++i) {
std::vector<TensorHandlePtr> components;
components.reserve(underlying_devices_.size());
for (int j = 0; j < underlying_devices_.size(); ++j) {
components.push_back(std::move(per_device_output_tensors[j][i]));
}
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(per_device_outputs));
return result;
}
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
std::vector<int64_t> shape(
TFE_TensorHandleNumDims(components[0].get(), status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
// Verify that the TensorHandle's shape and dtype match all of the component
// shapes and dtypes.
for (TensorHandlePtr& component : components) {
for (int i = 0; i < shape.size(); ++i) {
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (tensor_dim != shape[i]) {
// TODO(allenl): Allow shapes to differ.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"Components of a ParallelTensor must currently all have "
"the same shape");
return nullptr;
}
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
}
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
parallel_device, std::move(components), std::move(shape), dtype));
}
} // namespace parallel_device
} // namespace tensorflow
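As an aside, the StartExecute/Join handshake above amounts to a small four-state machine per device thread. The condensed sketch below reproduces just that handshake with the TensorFlow types replaced by a std::function<int()> payload; SketchDeviceThread and everything in it are illustrative, not library API.

// Minimal sketch of the StartExecute/Join handshake with a generic payload.
#include <condition_variable>
#include <functional>
#include <mutex>
#include <thread>

class SketchDeviceThread {
 public:
  SketchDeviceThread() : thread_([this] { Run(); }) {}
  ~SketchDeviceThread() {
    {
      std::lock_guard<std::mutex> l(mu_);
      state_ = State::kShuttingDown;  // Any not-yet-started work is dropped.
    }
    start_.notify_one();
    thread_.join();
  }
  // Hands `work` to the worker thread; blocks until the previous result (if
  // any) has been consumed by Join, like DeviceThread::StartExecute above.
  void StartExecute(std::function<int()> work) {
    std::unique_lock<std::mutex> l(mu_);
    finished_join_.wait(l, [this] { return state_ == State::kIdle; });
    work_ = std::move(work);
    state_ = State::kReadyToExecute;
    l.unlock();
    start_.notify_one();
  }
  // Blocks until the worker has a result, then hands it back.
  int Join() {
    int result = 0;
    {
      std::unique_lock<std::mutex> l(mu_);
      finished_execute_.wait(l, [this] { return state_ == State::kHasResult; });
      result = result_;
      state_ = State::kIdle;
    }
    finished_join_.notify_one();
    return result;
  }

 private:
  enum class State { kIdle, kReadyToExecute, kHasResult, kShuttingDown };
  void Run() {
    while (true) {
      {
        std::unique_lock<std::mutex> l(mu_);
        start_.wait(l, [this] {
          return state_ == State::kReadyToExecute ||
                 state_ == State::kShuttingDown;
        });
        if (state_ == State::kShuttingDown) return;
        result_ = work_();  // Executes under the lock, as in DeviceThread.
        state_ = State::kHasResult;
      }
      finished_execute_.notify_one();
    }
  }
  std::mutex mu_;
  State state_ = State::kIdle;
  std::condition_variable start_, finished_execute_, finished_join_;
  std::function<int()> work_;
  int result_ = 0;
  std::thread thread_;  // Declared last so the worker sees initialized members.
};

A caller creates one such object per device, calls StartExecute on all of them first and only then Joins each in the same fixed order, which is the fan-out/fan-in pattern ParallelDevice::Execute uses above to avoid deadlocking collective ops.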

View File

@ -0,0 +1,141 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#include <memory>
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
namespace parallel_device {
// Functor for making unique_ptrs slightly more ergonomic. Using
// decltype(delete_fn) in the unique_ptr's second template argument requires
// passing a function pointer to delete_fn when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) const {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
class ParallelTensor;
class DeviceThread;
// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.
class ParallelDevice {
public:
explicit ParallelDevice(const std::vector<std::string>& devices);
~ParallelDevice();
// Helper to copy a tensor handle from another device once for each component
// of the ParallelDevice.
//
// Sets a bad status and returns a nullptr if `tensor` is already on the
// ParallelDevice, or if the individual copies fail.
std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status) const;
// A parallel tensor with scalar integers numbering component devices.
std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
TF_Status* status) const;
// The number of devices operations run on.
size_t num_underlying_devices() const { return underlying_devices_.size(); }
// Takes a description of a single operation being executed on the
// ParallelDevice, and in turn runs one operation per component device with
// its corresponding inputs from the input ParallelTensors. Wraps the
// resulting per-device and per-output TFE_TensorHandles into one
// ParallelTensor per output of the original operation.
//
// Attributes are forwarded to executed operations unmodified.
//
// The returned optional has a value if and only if `status` evaluates to
// TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
// if sanity checks on dtypes/metadata fail.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
private:
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
const std::vector<std::string> underlying_devices_;
// A sequence of thread wrappers, one per device, for executing operations in
// parallel.
//
// Conceptually this is a thread pool with one thread per device. It requires
// less synchronization than a thread pool would for this task, since Execute
// acquires each thread in order (and so only one Execute will schedule
// blocking collective operations at a time), and avoids some dynamic
// allocation/scheduling.
//
// TODO(allenl): Keep a map from outer thread to list of inner threads rather
// than a single list of threads so aliased nested parallel devices don't
// re-use a thread.
std::vector<std::unique_ptr<DeviceThread>> device_threads_;
};
// Contains a tuple of tensors, one on each of the `underlying_devices_` of the
// ParallelDevice.
class ParallelTensor {
public:
// Construct a ParallelTensor from TensorHandles placed on the component
// devices of a ParallelDevice.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status);
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
// A generalization of the shapes of the underlying tensors.
const std::vector<int64_t>& shape() const { return shape_; }
TF_DataType dtype() const { return dtype_; }
private:
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors,
std::vector<int64_t> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(std::move(shape)),
dtype_(dtype) {}
const ParallelDevice& device_;
const std::vector<TensorHandlePtr> tensors_;
const std::vector<int64_t> shape_;
const TF_DataType dtype_;
};
} // namespace parallel_device
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
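As an illustrative aside, a rough sketch of driving this library directly, without registering a custom device; the helper name BroadcastAndRun, the CPU device strings, and the use of the Identity op (whose "T" attribute is assumed to be supplied through `attributes`) are assumptions for illustration.

// Sketch: mirror a tensor onto two devices and run one op per component.
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

namespace tensorflow {
namespace parallel_device {

absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> BroadcastAndRun(
    TFE_Context* context, TFE_TensorHandle* input,
    const TFE_OpAttrs* attributes, TF_Status* status) {
  ParallelDevice device({"/job:localhost/replica:0/task:0/device:CPU:0",
                         "/job:localhost/replica:0/task:0/device:CPU:1"});
  // One copy of `input` lands on each component device.
  std::unique_ptr<ParallelTensor> mirrored =
      device.CopyToParallelDevice(context, input, status);
  if (TF_GetCode(status) != TF_OK) return absl::nullopt;
  // Runs "Identity" once per component; per-device outputs are repacked into
  // one ParallelTensor per output of the original op.
  return device.Execute(context, {mirrored.get()}, "Identity", attributes,
                        /*expected_max_outputs=*/1, status);
}

}  // namespace parallel_device
}  // namespace tensorflow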

View File

@ -0,0 +1,147 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <array>
#include <string>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/test.h"
tensorflow::ServerDef GetServerDef(const std::string& job_name, int num_tasks) {
tensorflow::ServerDef server_def;
server_def.set_protocol("grpc");
server_def.set_job_name(job_name);
server_def.set_task_index(0);
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
tensorflow::JobDef* job_def = cluster_def->add_job();
job_def->set_name(job_name);
for (int i = 0; i < num_tasks; i++) {
int port = tensorflow::testing::PickUnusedPortOrDie();
job_def->mutable_tasks()->insert(
{i, tensorflow::strings::StrCat("localhost", ":", port)});
}
return server_def;
}
TEST(PARALLEL_DEVICE, TestRemoteBasic) {
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
tensorflow::ServerDef server_def = GetServerDef("worker", 3);
// This server def has the task index set to 0.
std::string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TFE_ContextSetServerDef(context.get(), 0, serialized.data(),
serialized.size(), status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
BasicTestsForTwoDevices(context.get(),
"/job:worker/replica:0/task:1/device:CPU:0",
"/job:worker/replica:0/task:2/device:CPU:0");
worker_server1.release();
worker_server2.release();
}
TEST(PARALLEL_DEVICE, TestAsyncCopyOff) {
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
tensorflow::ServerDef server_def = GetServerDef("worker", 3);
// This server def has the task index set to 0.
std::string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TFE_ContextSetServerDef(context.get(), 0, serialized.data(),
serialized.size(), status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
const char* first_device = "/job:worker/replica:0/task:1/device:CPU:0";
const char* second_device = "/job:worker/replica:0/task:2/device:CPU:0";
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{first_device, second_device};
RegisterParallelDevice(context.get(), device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TFE_TensorHandle*, 2> in_components{value_one.get(),
value_two.get()};
TensorHandlePtr combined_value = CreatePerDeviceValues(
context.get(), in_components, device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Loop to make synchronization failures more deterministic
for (int i = 0; i < 100; ++i) {
TensorHandlePtr multiply_result(
Multiply(context.get(), combined_value.get(), combined_value.get(),
status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> out_components;
ExtractPerDeviceValues(context.get(), multiply_result.get(),
&out_components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(out_components[0].get(), 9.);
ExpectScalarEq<float>(out_components[1].get(), 4.);
}
worker_server1.release();
worker_server2.release();
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
// NOTE(allenl): These tests currently go through TFE_Execute and so are // NOTE(allenl): These tests currently go through TFE_Execute and so are
@ -28,390 +29,6 @@ limitations under the License.
// correspond fairly well to the implementation, but testing the C++ directly is // correspond fairly well to the implementation, but testing the C++ directly is
// another option. // another option.
// Functor for making unique_ptr to TFE_TensorHandle slightly more
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
// template argument requires passing a function pointer to
// TFE_DeleteTensorHandle when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
// A helper for performing common operations on variables. A much more
// restricted stand-in for tf.Variable in Python.
class Variable {
public:
// Construct a Variable from a resource-dtype TFE_TensorHandle and an
// indication of the dtype of the variable's value.
//
// Note that creating this resource-dtype handle can fail, so `Create` is a
// separate static method which returns a status.
Variable(TFE_TensorHandle* handle, TF_DataType type)
: handle_(handle), type_(type) {}
// Helper for constructing a resource handle and wrapping it in a `Variable`
// object.
static Variable* Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status);
// Dereferences the backing buffer for the variable. Note that since this can
// fail (it runs operations), it must be called explicitly and the resulting
// `status` checked.
void Destroy(TFE_Context* context, TF_Status* status);
// Reads from the variable.
TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
// Assigns a new value to the variable.
void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
// Adds `value` to the existing value of the variable.
void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status);
private:
// Helper for running any single-argument assignment ops (Assign, AssignAdd,
// AssignSub, ...).
void GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status);
// The a handle for the resource-dtype tensor pointing to the variable's
// buffer.
TFE_TensorHandle* handle_;
// The dtype of the variable's buffer (input dtype for assignments, output
// dtype of read operations).
TF_DataType type_;
};
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type);
TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
TFE_OpSetAttrString(op.get(), "container", "", 0);
// Use the special GUID for no buffer sharing
//
// TODO(allenl): Should we provide a better API for this? AFAIK this is the
// only reasonable way to make variables with no aliasing using the eager C
// API.
std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
no_sharing.length());
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
TFE_Execute(op.get(), &var_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return new Variable(var_handle, type);
}
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
// Free the backing buffer for the variable.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
// Delete the variable handle itself.
TFE_DeleteTensorHandle(handle_);
}
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type_);
int num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(var_value);
}
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrType(op.get(), "dtype", type_);
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), value, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
}
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignAddVariableOp", context, value, status);
}
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignVariableOp", context, value, status);
}
// Passed to `TF_NewTensor` to indicate how an array of floats should be
// deleted.
static void FloatDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<float*>(data);
}
// Creates a TFE_TensorHandle with value `v`.
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
const int num_bytes = sizeof(float);
float* values = new float[1];
values[0] = v;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Creates a rank-one TFE_TensorHandle with value `v`.
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
TF_Status* status) {
const int num_bytes = v.size() * sizeof(float);
float* values = new float[v.size()];
memcpy(values, v.data(), num_bytes);
int64_t dims = v.size();
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
&FloatDeallocator, nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
template <std::size_t num_replicas>
void ExtractPerDeviceValues(
TFE_Context* context, TFE_TensorHandle* input,
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
TFE_OpAddInput(op.get(), input, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(input, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_TensorHandle* result_handles[num_replicas];
int num_retvals = num_replicas;
TFE_Execute(op.get(), result_handles, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
for (int i = 0; i < num_replicas; ++i) {
(*components)[i].reset(result_handles[i]);
}
}
// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
template <std::size_t num_replicas>
TensorHandlePtr CreatePerDeviceValues(
TFE_Context* context,
const std::array<TFE_TensorHandle*, num_replicas>& components,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrInt(op.get(), "N", num_replicas);
for (int i = 0; i < num_replicas; ++i) {
TFE_OpAddInput(op.get(), components[i], status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
TFE_TensorHandle* second, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), second, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* first_device = TFE_TensorHandleDeviceName(first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), first_device, status);
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
// Assert that `handle` is equal to `expected_value`.
template <typename value_type>
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
EXPECT_EQ(expected_value,
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
}
template <std::size_t num_devices>
void RegisterParallelDevice(
TFE_Context* context, const char* device_name,
const std::array<const char*, num_devices>& underlying_devices,
TF_Status* status) {
TFE_CustomDevice device;
void* device_info;
tensorflow::eager::AllocateParallelDevice(
device_name, underlying_devices.data(), underlying_devices.size(),
&device, &device_info);
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
}
// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
const char* second_device) {
// Register the custom device
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{first_device, second_device};
RegisterParallelDevice(context, device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle (uninitialized to start) placed on the parallel
// device.
std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
to_delete->Destroy(context, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
delete to_delete;
};
std::unique_ptr<Variable, decltype(variable_deleter)> variable(
Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
status.get()),
variable_deleter);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Assign an initial value to the variable, implicitly mirroring it to each
// component device.
{
TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
variable->Assign(context, initial_value.get(), status.get());
}
// Read from the variable and verify that we have a parallel tensor.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(components[0].get(), 20.);
ExpectScalarEq<float>(components[1].get(), 20.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Add a parallel tensor with different values on each device to the variable.
{
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
value_two.get()};
TensorHandlePtr combined_value =
CreatePerDeviceValues(context, components, device_name, status.get());
variable->AssignAdd(context, combined_value.get(), status.get());
}
// Read the variable and verify that each component has the right modified
// value.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(components[0].get(), 23.);
ExpectScalarEq<float>(components[1].get(), 18.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Compute the device ID twice and verify the result
for (int i = 0; i < 2; ++i) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetDevice(op.get(), device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, result_handle, &components, status.get());
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int64_t>(components[0].get(), 0);
ExpectScalarEq<int64_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
}
TEST(PARALLEL_DEVICE, TestBasicCPU) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -790,7 +407,7 @@ TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
return TensorHandlePtr(result_handle);
}
- TEST(PARALLEL_DEVICE, TestCollective) {
+ void TestCollective(bool async) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
@ -806,6 +423,9 @@ TEST(PARALLEL_DEVICE, TestCollective) {
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ std::unique_ptr<TFE_Executor, decltype(&TFE_DeleteExecutor)> executor(
+     TFE_NewExecutor(async), TFE_DeleteExecutor);
+ TFE_ContextSetExecutorForThread(context.get(), executor.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{
@ -835,8 +455,16 @@ TEST(PARALLEL_DEVICE, TestCollective) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(result_components[0].get(), 3.);
ExpectScalarEq<float>(result_components[1].get(), 3.);
+ // Destroying the context's default executor first isn't safe.
+ context.reset();
}
+ TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }
+ // Note that ops on the parallel device currently don't execute
+ // asynchronously. The test is just that we don't get deadlocks.
+ TEST(PARALLEL_DEVICE, TestCollectiveAsync) { TestCollective(/*async=*/true); }
void RegisterCollectiveMulFunction(TFE_Context* context,
const char* function_name, int group_size,
TF_Status* status) {

View File

@ -0,0 +1,308 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include <array>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/test.h"
// NOTE(allenl): These tests currently go through TFE_Execute and so are
// integration testing rather than purely testing the parallel device. They
// correspond fairly well to the implementation, but testing the C++ directly is
// another option.
Variable* Variable::Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "VarHandleOp", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type);
TFE_OpSetAttrShape(op.get(), "shape", dims, num_dims, status);
TFE_OpSetAttrString(op.get(), "container", "", 0);
// Use the special GUID for no buffer sharing
//
// TODO(allenl): Should we provide a better API for this? AFAIK this is the
// only reasonable way to make variables with no aliasing using the eager C
// API.
std::string no_sharing = "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
TFE_OpSetAttrString(op.get(), "shared_name", no_sharing.c_str(),
no_sharing.length());
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
TFE_Execute(op.get(), &var_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return new Variable(var_handle, type);
}
void Variable::Destroy(TFE_Context* context, TF_Status* status) {
// Free the backing buffer for the variable.
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DestroyResourceOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
// Delete the variable handle itself.
TFE_DeleteTensorHandle(handle_);
}
TensorHandlePtr Variable::Read(TFE_Context* context, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "ReadVariableOp", status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op.get(), "dtype", type_);
int num_retvals = 1;
TFE_TensorHandle* var_value = nullptr;
TFE_Execute(op.get(), &var_value, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(var_value);
}
void Variable::GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, op_name, status), &TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrType(op.get(), "dtype", type_);
TFE_OpAddInput(op.get(), handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpAddInput(op.get(), value, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(handle_, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
int num_retvals = 0;
TFE_Execute(op.get(), nullptr, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
}
void Variable::AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignAddVariableOp", context, value, status);
}
void Variable::Assign(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status) {
GeneralAssignment("AssignVariableOp", context, value, status);
}
// Passed to `TF_NewTensor` to indicate how an array of floats should be
// deleted.
static void FloatDeallocator(void* data, size_t, void* arg) {
delete[] static_cast<float*>(data);
}
// Creates a TFE_TensorHandle with value `v`.
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status) {
const int num_bytes = sizeof(float);
float* values = new float[1];
values[0] = v;
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, &FloatDeallocator,
nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Creates a rank-one TFE_TensorHandle with value `v`.
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
TF_Status* status) {
const int num_bytes = v.size() * sizeof(float);
float* values = new float[v.size()];
memcpy(values, v.data(), num_bytes);
int64_t dims = v.size();
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
TF_NewTensor(TF_FLOAT, &dims, 1 /* num_dims */, values, num_bytes,
&FloatDeallocator, nullptr),
TF_DeleteTensor);
return TensorHandlePtr(TFE_NewTensorHandle(tensor.get(), status));
}
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
template <std::size_t num_replicas>
void ExtractPerDeviceValues(
TFE_Context* context, TFE_TensorHandle* input,
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedOutput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetAttrInt(op.get(), "num_replicas", num_replicas);
TFE_OpAddInput(op.get(), input, status);
if (TF_GetCode(status) != TF_OK) return;
const char* device = TFE_TensorHandleDeviceName(input, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return;
TFE_TensorHandle* result_handles[num_replicas];
int num_retvals = num_replicas;
TFE_Execute(op.get(), result_handles, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return;
for (int i = 0; i < num_replicas; ++i) {
(*components)[i].reset(result_handles[i]);
}
}
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
TFE_TensorHandle* second, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "Mul", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpAddInput(op.get(), second, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
const char* first_device = TFE_TensorHandleDeviceName(first, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetDevice(op.get(), first_device, status);
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
const char* second_device) {
// Register the custom device
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{first_device, second_device};
RegisterParallelDevice(context, device_name, underlying_devices,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Create a variable handle (uninitialized to start) placed on the parallel
// device.
std::function<void(Variable*)> variable_deleter = [&](Variable* to_delete) {
to_delete->Destroy(context, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
delete to_delete;
};
std::unique_ptr<Variable, decltype(variable_deleter)> variable(
Variable::Create(context, TF_FLOAT, /* Scalar */ {}, 0, device_name,
status.get()),
variable_deleter);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// Assign an initial value to the variable, implicitly mirroring it to each
// component device.
{
TensorHandlePtr initial_value = FloatTensorHandle(20., status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
variable->Assign(context, initial_value.get(), status.get());
}
// Read from the variable and verify that we have a parallel tensor.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(components[0].get(), 20.);
ExpectScalarEq<float>(components[1].get(), 20.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Add a parallel tensor with different values on each device to the variable.
{
TensorHandlePtr value_one(FloatTensorHandle(3., status.get()));
TensorHandlePtr value_two(FloatTensorHandle(-2., status.get()));
std::array<TFE_TensorHandle*, 2> components{value_one.get(),
value_two.get()};
TensorHandlePtr combined_value =
CreatePerDeviceValues(context, components, device_name, status.get());
variable->AssignAdd(context, combined_value.get(), status.get());
}
// Read the variable and verify that each component has the right modified
// value.
{
TensorHandlePtr read = variable->Read(context, status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, read.get(), &components, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(components[0].get(), 23.);
ExpectScalarEq<float>(components[1].get(), 18.);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
// Compute the device ID twice and verify the result
for (int i = 0; i < 2; ++i) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "DeviceID", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetDevice(op.get(), device_name, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::array<TensorHandlePtr, 2> components;
ExtractPerDeviceValues(context, result_handle, &components, status.get());
TFE_DeleteTensorHandle(result_handle);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<int64_t>(components[0].get(), 0);
ExpectScalarEq<int64_t>(components[1].get(), 1);
std::string first_device =
TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
ASSERT_EQ(underlying_devices[0], first_device);
std::string second_device =
TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
ASSERT_EQ(underlying_devices[1], second_device);
}
}
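For orientation, here is a minimal sketch of how BasicTestsForTwoDevices could be driven from a test body. It is not part of this change; the test name is invented, and it assumes an eager context that already exposes two local CPU devices (the real tests configure this through the context options).

TEST(ParallelDeviceSketch, TwoLocalCPUs) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  // Assumes the options have been set up so that a second local CPU device exists.
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
  BasicTestsForTwoDevices(context.get(),
                          "/job:localhost/replica:0/task:0/device:CPU:0",
                          "/job:localhost/replica:0/task:0/device:CPU:1");
}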

View File

@ -0,0 +1,174 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include <array>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/platform/test.h"
// Functor for making unique_ptr to TFE_TensorHandle slightly more
// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second
// template argument requires passing a function pointer to
// TFE_DeleteTensorHandle when constructing the unique_ptr.
class TensorHandleDeleter {
public:
void operator()(TFE_TensorHandle* to_delete) {
TFE_DeleteTensorHandle(to_delete);
}
};
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
// A helper for performing common operations on variables. A much more
// restricted stand-in for tf.Variable in Python.
class Variable {
public:
// Construct a Variable from a resource-dtype TFE_TensorHandle and an
// indication of the dtype of the variable's value.
//
// Note that creating this resource-dtype handle can fail, so `Create` is a
// separate static method which returns a status.
Variable(TFE_TensorHandle* handle, TF_DataType type)
: handle_(handle), type_(type) {}
// Helper for constructing a resource handle and wrapping it in a `Variable`
// object.
static Variable* Create(TFE_Context* context, TF_DataType type,
const int64_t* dims, const int num_dims,
const char* device, TF_Status* status);
// Dereferences the backing buffer for the variable. Note that since this can
// fail (it runs operations), it must be called explicitly and the resulting
// `status` checked.
void Destroy(TFE_Context* context, TF_Status* status);
// Reads from the variable.
TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
// Assigns a new value to the variable.
void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
// Adds `value` to the existing value of the variable.
void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
TF_Status* status);
private:
// Helper for running any single-argument assignment ops (Assign, AssignAdd,
// AssignSub, ...).
void GeneralAssignment(const char* op_name, TFE_Context* context,
TFE_TensorHandle* value, TF_Status* status);
// A handle for the resource-dtype tensor pointing to the variable's
// buffer.
TFE_TensorHandle* handle_;
// The dtype of the variable's buffer (input dtype for assignments, output
// dtype of read operations).
TF_DataType type_;
};
// Creates a TFE_TensorHandle with value `v`.
TensorHandlePtr FloatTensorHandle(float v, TF_Status* status);
// Creates a rank-one TFE_TensorHandle with value `v`.
TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
TF_Status* status);
// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
template <std::size_t num_replicas>
void ExtractPerDeviceValues(
TFE_Context* context, TFE_TensorHandle* input,
std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status);
// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
template <std::size_t num_replicas>
TensorHandlePtr CreatePerDeviceValues(
TFE_Context* context,
const std::array<TFE_TensorHandle*, num_replicas>& components,
const char* device, TF_Status* status);
TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
TFE_TensorHandle* second, TF_Status* status);
// Assert that `handle` is equal to `expected_value`.
template <typename value_type>
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value);
template <std::size_t num_devices>
void RegisterParallelDevice(
TFE_Context* context, const char* device_name,
const std::array<const char*, num_devices>& underlying_devices,
TF_Status* status);
// Create and modify a variable placed on a parallel device which composes
// `first_device` and `second_device`.
void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
const char* second_device);
// Implementations of templated functions ******************************
template <std::size_t num_replicas>
TensorHandlePtr CreatePerDeviceValues(
TFE_Context* context,
const std::array<TFE_TensorHandle*, num_replicas>& components,
const char* device, TF_Status* status) {
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrInt(op.get(), "N", num_replicas);
for (int i = 0; i < num_replicas; ++i) {
TFE_OpAddInput(op.get(), components[i], status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
TFE_OpSetDevice(op.get(), device, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* result_handle;
int num_retvals = 1;
TFE_Execute(op.get(), &result_handle, &num_retvals, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
return TensorHandlePtr(result_handle);
}
template <typename value_type>
void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_zero(
TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
EXPECT_EQ(expected_value,
*static_cast<value_type*>(TF_TensorData(value_zero.get())));
}
template <std::size_t num_devices>
void RegisterParallelDevice(
TFE_Context* context, const char* device_name,
const std::array<const char*, num_devices>& underlying_devices,
TF_Status* status) {
TFE_CustomDevice device;
void* device_info;
tensorflow::parallel_device::AllocateParallelDevice(
device_name, underlying_devices.data(), underlying_devices.size(),
&device, &device_info);
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
}
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
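As a usage note for the templated helpers declared above, the following sketch packs two scalars into one parallel handle and unpacks them again. PackUnpackRoundTrip is a hypothetical function name; it assumes a parallel device has already been registered under `device_name` with exactly two underlying devices.

void PackUnpackRoundTrip(TFE_Context* context, const char* device_name,
                         TF_Status* status) {
  TensorHandlePtr one = FloatTensorHandle(1., status);
  if (TF_GetCode(status) != TF_OK) return;
  TensorHandlePtr two = FloatTensorHandle(2., status);
  if (TF_GetCode(status) != TF_OK) return;
  // Pack the two per-device components into a single parallel handle.
  std::array<TFE_TensorHandle*, 2> components{one.get(), two.get()};
  TensorHandlePtr packed =
      CreatePerDeviceValues(context, components, device_name, status);
  if (TF_GetCode(status) != TF_OK) return;
  // Unpack and check that each component carries the original scalar.
  std::array<TensorHandlePtr, 2> unpacked;
  ExtractPerDeviceValues(context, packed.get(), &unpacked, status);
  if (TF_GetCode(status) != TF_OK) return;
  ExpectScalarEq<float>(unpacked[0].get(), 1.);
  ExpectScalarEq<float>(unpacked[1].get(), 2.);
}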

View File

@ -0,0 +1,31 @@
# Experimental gcs filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
package(
licenses = ["notice"], # Apache 2.0
)
# Filesystem implementation for GCS environments
tf_cc_shared_object(
name = "gcs_filesystem",
framework_so = [],
linkstatic = False,
per_os_targets = True,
visibility = ["//visibility:public"],
deps = [":gcs_filesystem_impl"],
)
# The real implementation of the filesystem.
cc_library(
name = "gcs_filesystem_impl",
srcs = ["gcs_filesystem.cc"],
copts = select({
"//conditions:default": [],
"//tensorflow:windows": get_win_copts(),
}),
deps = [
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
],
)

View File

@ -0,0 +1,101 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdlib.h>
#include <string.h>
#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for GCS environments.
// This filesystem will support `gs://` URI schemes.
namespace gcs = google::cloud::storage;
// We can cast `google::cloud::StatusCode` to `TF_Code` because they have the
// same integer values. See
// https://github.com/googleapis/google-cloud-cpp/blob/6c09cbfa0160bc046e5509b4dd2ab4b872648b4a/google/cloud/status.h#L32-L52
static inline void TF_SetStatusFromGCSStatus(
const google::cloud::Status& gcs_status, TF_Status* status) {
TF_SetStatus(status, static_cast<TF_Code>(gcs_status.code()),
gcs_status.message().c_str());
}
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
// TODO(vnvo2409): Implement later
} // namespace tf_random_access_file
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
// TODO(vnvo2409): Implement later
} // namespace tf_writable_file
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {
// TODO(vnvo2409): Implement later
} // namespace tf_read_only_memory_region
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_gcs_filesystem {
// TODO(vnvo2409): Add lazy-loading and customizing parameters.
static void Init(TF_Filesystem* filesystem, TF_Status* status) {
google::cloud::StatusOr<gcs::Client> client =
gcs::Client::CreateDefaultClient();
if (!client) {
TF_SetStatusFromGCSStatus(client.status(), status);
return;
}
filesystem->plugin_filesystem = plugin_memory_allocate(sizeof(gcs::Client));
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
(*gcs_client) = client.value();
TF_SetStatus(status, TF_OK, "");
}
// TODO(vnvo2409): Implement later
} // namespace tf_gcs_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
info->plugin_memory_allocate = plugin_memory_allocate;
info->plugin_memory_free = plugin_memory_free;
info->num_schemes = 1;
info->ops = static_cast<TF_FilesystemPluginOps*>(
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
ProvideFilesystemSupportFor(&info->ops[0], "gs");
}
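To illustrate the status-conversion pattern above, here is a sketch of how TF_SetStatusFromGCSStatus might be used by one of the operations the TODOs still leave open. The function name and its `bucket`/`object` parameters are hypothetical; only the client stored in `plugin_filesystem` and the conversion helper come from the code above.

static void SketchStatObject(const TF_Filesystem* filesystem, const char* bucket,
                             const char* object, TF_Status* status) {
  // The plugin stores a gcs::Client in `plugin_filesystem` during Init above.
  auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
  google::cloud::StatusOr<gcs::ObjectMetadata> metadata =
      gcs_client->GetObjectMetadata(bucket, object);
  if (!metadata) {
    // Reuse the cast-based conversion from google::cloud::Status to TF_Status.
    TF_SetStatusFromGCSStatus(metadata.status(), status);
    return;
  }
  TF_SetStatus(status, TF_OK, "");
}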

View File

@ -57,6 +57,7 @@ cc_library(
":concrete_function", ":concrete_function",
":saved_model_api", ":saved_model_api",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
], ],
) )

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
namespace tensorflow { namespace tensorflow {
@ -51,7 +52,7 @@ std::vector<ConcreteFunction*> TFSavedModelAPIImpl::ListFunctions() {
Status TFSavedModelAPIImpl::Load( Status TFSavedModelAPIImpl::Load(
const std::string& directory, const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags, const absl::optional<std::unordered_set<std::string>>& tags,
TFSavedModelAPIImpl* out) { EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out) {
// TODO(bmzhao): Add support for loading a TFSavedModelImpl. // TODO(bmzhao): Add support for loading a TFSavedModelImpl.
return errors::Unimplemented( return errors::Unimplemented(
"TFSavedModelAPIImpl loading is unimplemented currently"); "TFSavedModelAPIImpl loading is unimplemented currently");

View File

@ -23,14 +23,13 @@ limitations under the License.
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
namespace tensorflow { namespace tensorflow {
class TFSavedModelAPIImpl : public SavedModelAPI { class TFSavedModelAPIImpl : public SavedModelAPI {
public: public:
TFSavedModelAPIImpl() = default;
Status GetFunction(const std::string& function_path, Status GetFunction(const std::string& function_path,
ConcreteFunction** function) override; ConcreteFunction** function) override;
@ -40,13 +39,14 @@ class TFSavedModelAPIImpl : public SavedModelAPI {
static Status Load( static Status Load(
const std::string& directory, const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags, const absl::optional<std::unordered_set<std::string>>& tags,
TFSavedModelAPIImpl* out); EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out);
std::vector<ConcreteFunction*> ListFunctions() override; std::vector<ConcreteFunction*> ListFunctions() override;
~TFSavedModelAPIImpl() override = default; ~TFSavedModelAPIImpl() override = default;
private: private:
TFSavedModelAPIImpl() = default;
std::vector<ConcreteFunction> functions_; std::vector<ConcreteFunction> functions_;
}; };

View File

@ -144,7 +144,9 @@ cc_library(
"//tensorflow/c:tf_status_internal", "//tensorflow/c:tf_status_internal",
"//tensorflow/c/eager:tfe_context_internal", "//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/experimental/saved_model/core:saved_model_api", "//tensorflow/c/experimental/saved_model/core:saved_model_api",
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_impl",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
], ],
) )

View File

@ -22,11 +22,15 @@ limitations under the License.
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "tensorflow/c/eager/tfe_context_internal.h" #include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h" #include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h" #include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
extern "C" { extern "C" {
@ -34,10 +38,21 @@ extern "C" {
TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx, TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
TF_Status* status) { TF_Status* status) {
std::string saved_model_dir(dirname); std::string saved_model_dir(dirname);
std::unique_ptr<tensorflow::SavedModelAPI> result;
if (tensorflow::unwrap(ctx)->UsesTFRT()) {
status->status = tensorflow::errors::Unimplemented(
"TFRT SavedModel implementation will be added in the future");
} else {
std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
status->status = tensorflow::TFSavedModelAPIImpl::Load(
dirname, absl::nullopt,
tensorflow::down_cast<tensorflow::EagerContext*>(
tensorflow::unwrap(ctx)),
&saved_model);
result = std::move(saved_model);
}
std::unique_ptr<tensorflow::SavedModelAPI> result =
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
&status->status);
if (!status->status.ok()) { if (!status->status.ok()) {
return nullptr; return nullptr;
} }
@ -54,9 +69,20 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
tagset.insert(std::string(tags[i])); tagset.insert(std::string(tags[i]));
} }
std::unique_ptr<tensorflow::SavedModelAPI> result = std::unique_ptr<tensorflow::SavedModelAPI> result;
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset), if (tensorflow::unwrap(ctx)->UsesTFRT()) {
&status->status); status->status = tensorflow::errors::Unimplemented(
"TFRT SavedModel implementation will be added in the future");
} else {
std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
status->status = tensorflow::TFSavedModelAPIImpl::Load(
dirname, tagset,
tensorflow::down_cast<tensorflow::EagerContext*>(
tensorflow::unwrap(ctx)),
&saved_model);
result = std::move(saved_model);
}
if (!status->status.ok()) { if (!status->status.ok()) {
return nullptr; return nullptr;
} }

View File

@ -106,6 +106,7 @@ cc_library(
hdrs = ["loader.h"], hdrs = ["loader.h"],
deps = [ deps = [
":constants", ":constants",
":loader_util",
":reader", ":reader",
] + if_not_mobile([ ] + if_not_mobile([
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
@ -132,6 +133,17 @@ cc_library(
], ],
) )
cc_library(
name = "loader_util",
srcs = ["loader_util.cc"],
hdrs = ["loader_util.h"],
deps = [":constants"] + if_not_mobile([
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
]),
)
tf_cc_test( tf_cc_test(
name = "bundle_v2_test", name = "bundle_v2_test",
srcs = ["bundle_v2_test.cc"], srcs = ["bundle_v2_test.cc"],

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_set> #include <unordered_set>
#include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h"
@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h" #include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session.h"
@ -191,41 +191,6 @@ Status RunInitOp(const RunOptions& run_options, const string& export_dir,
return Status::OK(); return Status::OK();
} }
// A SavedModel may store the name of the initialization op to run in the
// SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name) {
const auto& sig_def_map = meta_graph_def.signature_def();
const auto& init_op_sig_it =
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
if (init_op_sig_it != sig_def_map.end()) {
*init_op_name = init_op_sig_it->second.outputs()
.find(kSavedModelInitOpSignatureKey)
->second.name();
return Status::OK();
}
const auto& collection_def_map = meta_graph_def.collection_def();
string init_op_collection_key;
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
init_op_collection_key = kSavedModelMainOpKey;
} else {
init_op_collection_key = kSavedModelLegacyInitOpKey;
}
const auto init_op_it = collection_def_map.find(init_op_collection_key);
if (init_op_it != collection_def_map.end()) {
if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
*init_op_name = init_op_it->second.node_list().value(0);
}
return Status::OK();
}
Status RunRestore(const RunOptions& run_options, const string& export_dir, Status RunRestore(const RunOptions& run_options, const string& export_dir,
const StringPiece restore_op_name, const StringPiece restore_op_name,
const StringPiece variable_filename_const_op_name, const StringPiece variable_filename_const_op_name,
@ -263,32 +228,6 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
nullptr /* outputs */, &run_metadata, session); nullptr /* outputs */, &run_metadata, session);
} }
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs) {
// With SavedModel v2, we write asset file def into metagraph instead of
// collection, so read from metagraph first.
if (meta_graph_def.asset_file_def_size() > 0) {
for (const auto& asset : meta_graph_def.asset_file_def()) {
asset_file_defs->push_back(asset);
}
return Status::OK();
}
// Fall back to read from collection to be backward compatible with v1.
const auto& collection_def_map = meta_graph_def.collection_def();
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
if (assets_it == collection_def_map.end()) {
return Status::OK();
}
const auto& any_assets = assets_it->second.any_list().value();
for (const auto& any_asset : any_assets) {
AssetFileDef asset_file_def;
TF_RETURN_IF_ERROR(
ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
asset_file_defs->push_back(asset_file_def);
}
return Status::OK();
}
Status ReadSavedModelDebugInfoIfPresent( Status ReadSavedModelDebugInfoIfPresent(
const string& export_dir, const string& export_dir,
std::unique_ptr<GraphDebugInfo>* debug_info_proto) { std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
@ -322,7 +261,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
std::vector<AssetFileDef> asset_file_defs; std::vector<AssetFileDef> asset_file_defs;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs)); internal::GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
RunRestore(run_options, export_dir, RunRestore(run_options, export_dir,
bundle->meta_graph_def.saver_def().restore_op_name(), bundle->meta_graph_def.saver_def().restore_op_name(),
@ -336,7 +275,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
const uint64 graph_init_start_microseconds = Env::Default()->NowMicros(); const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
string init_op_name; string init_op_name;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name)); internal::GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def, TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
asset_file_defs, bundle->session.get(), asset_file_defs, bundle->session.get(),
init_op_name)); init_op_name));

View File

@ -0,0 +1,90 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/loader_util.h"
#include <vector>
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf_internal.h"
namespace tensorflow {
namespace internal {
// A SavedModel may store the name of the initialization op to run in the
// SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name) {
const auto& sig_def_map = meta_graph_def.signature_def();
const auto& init_op_sig_it =
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
if (init_op_sig_it != sig_def_map.end()) {
*init_op_name = init_op_sig_it->second.outputs()
.find(kSavedModelInitOpSignatureKey)
->second.name();
return Status::OK();
}
const auto& collection_def_map = meta_graph_def.collection_def();
string init_op_collection_key;
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
init_op_collection_key = kSavedModelMainOpKey;
} else {
init_op_collection_key = kSavedModelLegacyInitOpKey;
}
const auto init_op_it = collection_def_map.find(init_op_collection_key);
if (init_op_it != collection_def_map.end()) {
if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
*init_op_name = init_op_it->second.node_list().value(0);
}
return Status::OK();
}
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs) {
// With SavedModel v2, we write asset file def into metagraph instead of
// collection, so read from metagraph first.
if (meta_graph_def.asset_file_def_size() > 0) {
for (const auto& asset : meta_graph_def.asset_file_def()) {
asset_file_defs->push_back(asset);
}
return Status::OK();
}
// Fall back to read from collection to be backward compatible with v1.
const auto& collection_def_map = meta_graph_def.collection_def();
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
if (assets_it == collection_def_map.end()) {
return Status::OK();
}
const auto& any_assets = assets_it->second.any_list().value();
for (const auto& any_asset : any_assets) {
AssetFileDef asset_file_def;
TF_RETURN_IF_ERROR(
ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
asset_file_defs->push_back(asset_file_def);
}
return Status::OK();
}
} // namespace internal
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
#define TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
#include <string>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
namespace tensorflow {
namespace internal {
// A SavedModel may store the name of the initialization op to run in the
// SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name);
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs);
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
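A brief sketch of how the two helpers declared above fit together once a MetaGraphDef has been read from a SavedModel (for example via the reader in this directory). The wrapper function name is invented, and the usual errors header is assumed for TF_RETURN_IF_ERROR.

tensorflow::Status CollectInitMetadataSketch(
    const std::string& export_dir,
    const tensorflow::MetaGraphDef& meta_graph_def) {
  std::string init_op_name;
  TF_RETURN_IF_ERROR(tensorflow::internal::GetInitOp(
      export_dir, meta_graph_def, &init_op_name));
  std::vector<tensorflow::AssetFileDef> asset_file_defs;
  TF_RETURN_IF_ERROR(tensorflow::internal::GetAssetFileDefs(
      meta_graph_def, &asset_file_defs));
  // In loader.cc these results are forwarded to RunRestore and RunInitOp.
  return tensorflow::Status::OK();
}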

View File

@ -67,13 +67,13 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep "@llvm-project//llvm:ARMCodeGen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep "@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
"@llvm-project//llvm:target", "@llvm-project//llvm:Target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep "@llvm-project//llvm:X86CodeGen", # fixdeps: keep
"//tensorflow/core:regexp_internal", "//tensorflow/core:regexp_internal",
] + if_llvm_aarch64_available([ ] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep "@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep
]), ]),
) )
@ -94,8 +94,8 @@ tf_cc_test(
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core/platform:resource_loader", "//tensorflow/core/platform:resource_loader",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", # fixdeps: keep "@llvm-project//llvm:Support", # fixdeps: keep
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep "@llvm-project//llvm:X86CodeGen", # fixdeps: keep
], ],
) )
@ -109,12 +109,12 @@ cc_library(
name = "llvm_targets", name = "llvm_targets",
visibility = ["//tensorflow/python:__pkg__"], visibility = ["//tensorflow/python:__pkg__"],
deps = [ deps = [
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep "@llvm-project//llvm:ARMCodeGen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep "@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
"@llvm-project//llvm:target", "@llvm-project//llvm:Target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep "@llvm-project//llvm:X86CodeGen", # fixdeps: keep
] + if_llvm_aarch64_available([ ] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep "@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep
]), ]),
) )
@ -286,9 +286,9 @@ cc_library(
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"@llvm-project//llvm:core", "@llvm-project//llvm:Core",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//llvm:target", "@llvm-project//llvm:Target",
], ],
) )

View File

@ -1,5 +1,5 @@
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s -dump-input-on-failure # RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=Bridge --debug_info=%s.debug.pbtxt 2>&1 | FileCheck %s
# RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s -dump-input-on-failure # RUN: not tfcompile --graph=%s --config=%s.config.pbtxt --mlir_components=None 2>&1 | FileCheck -check-prefix=OLD %s
# Checks the error message produced by tfcompile with mlir_component # Checks the error message produced by tfcompile with mlir_component
# Checks that source debug information is used in the output error message and # Checks that source debug information is used in the output error message and

View File

@ -4,6 +4,7 @@ traces: {
value: { value: {
file_line_cols: { file_line_cols: {
line: 1 line: 1
col: 1
} }
} }
} }
@ -12,9 +13,11 @@ traces: {
value: { value: {
file_line_cols: { file_line_cols: {
line: 3 line: 3
col: 1
} }
file_line_cols: { file_line_cols: {
line: 4 line: 4
col: 1
} }
} }
} }
@ -23,6 +26,7 @@ traces: {
value: { value: {
file_line_cols: { file_line_cols: {
line: 2 line: 2
col: 1
} }
} }
} }

View File

@ -33,6 +33,7 @@ MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags; XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags; XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags; IntroduceFloatingPointJitterPassFlags* jitter_flags;
MlirCommonFlags* mlir_flags;
std::vector<Flag>* flag_list; std::vector<Flag>* flag_list;
absl::once_flag flags_init; absl::once_flag flags_init;
@ -166,6 +167,9 @@ void AllocateAndParseFlags() {
jitter_flags = new IntroduceFloatingPointJitterPassFlags; jitter_flags = new IntroduceFloatingPointJitterPassFlags;
jitter_flags->jitter_amount = 1e-5; jitter_flags->jitter_amount = 1e-5;
mlir_flags = new MlirCommonFlags;
mlir_flags->tf_mlir_enable_mlir_bridge = false;
auto setter_for_jitter_tensor_names = [](string sequence) { auto setter_for_jitter_tensor_names = [](string sequence) {
jitter_flags->tensor_names = absl::StrSplit(sequence, ','); jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
return true; return true;
@ -211,7 +215,11 @@ void AllocateAndParseFlags() {
Flag("tf_introduce_floating_point_jitter_amount", Flag("tf_introduce_floating_point_jitter_amount",
&jitter_flags->jitter_amount, &jitter_flags->jitter_amount,
"The amount of jitter to introduce. This amount is added to each " "The amount of jitter to introduce. This amount is added to each "
"element in the tensors named in `tensor_names.")}); "element in the tensors named in `tensor_names."),
Flag("tf_mlir_enable_mlir_bridge",
&mlir_flags->tf_mlir_enable_mlir_bridge,
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
AppendMarkForCompilationPassFlagsInternal(flag_list); AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list); xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
@ -250,6 +258,11 @@ GetIntroduceFloatingPointJitterPassFlags() {
return *jitter_flags; return *jitter_flags;
} }
MlirCommonFlags* GetMlirCommonFlags() {
absl::call_once(flags_init, &AllocateAndParseFlags);
return mlir_flags;
}
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) { void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
absl::call_once(flags_init, &AllocateAndParseFlags); absl::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list); AppendMarkForCompilationPassFlagsInternal(flag_list);

View File

@ -133,6 +133,11 @@ struct IntroduceFloatingPointJitterPassFlags {
std::vector<string> tensor_names; std::vector<string> tensor_names;
}; };
// Flags for common MLIR configurations.
struct MlirCommonFlags {
bool tf_mlir_enable_mlir_bridge;
};
// Return a pointer to the DumpGraphFlags struct; // Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer. // repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned. // This should be called only after Flags::Parse() has returned.
@ -148,6 +153,8 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
const IntroduceFloatingPointJitterPassFlags& const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags(); GetIntroduceFloatingPointJitterPassFlags();
MlirCommonFlags* GetMlirCommonFlags();
// Appends the flag definitions associated with // Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`. // MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
// //
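A minimal sketch of how a caller might consult the new flag. The wrapper function is invented, and the include path assumes this is the JIT flags header (tensorflow/compiler/jit/flags.h).

#include "tensorflow/compiler/jit/flags.h"

// Returns true when the experimental MLIR bridge has been requested, e.g. via
// TF_XLA_FLAGS=--tf_mlir_enable_mlir_bridge=true (parsed by AllocateAndParseFlags).
bool ShouldUseMlirBridgeSketch() {
  return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
}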

View File

@ -395,12 +395,11 @@ static void ShowXlaDeviceDeprecationWarning(
if (absl::StrContains(compilation_device_name, "CPU") || if (absl::StrContains(compilation_device_name, "CPU") ||
absl::StrContains(compilation_device_name, "GPU")) { absl::StrContains(compilation_device_name, "GPU")) {
absl::call_once(once, [] { absl::call_once(once, [] {
LOG(WARNING) LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be "
<< "XLA_GPU and XLA_CPU devices are deprecated and will be " "removed in subsequent releases. Instead, use either "
"removed in subsequent releases. Instead, use either " "@tf.function(experimental_compile=True) for must-compile "
"@tf.function(experimental_compile=True) for must-compile " "semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
"semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 " "for auto-clustering best-effort compilation.";
"for auto-clustering best-effort compilation.";
}); });
} }
} }

View File

@ -91,7 +91,7 @@ Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
} }
string message = absl::StrCat( string message = absl::StrCat(
"Function invoked by the following node is not compilable: ", "Function invoked by the following node is not compilable: ",
SummarizeNodeDef(node_def), ".\n"); SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
absl::StrAppend(&message, "Uncompilable nodes:"); absl::StrAppend(&message, "Uncompilable nodes:");
for (const auto& node_info : uncompilable_node_info) { for (const auto& node_info : uncompilable_node_info) {
string node_message = string node_message =

View File

@ -201,9 +201,7 @@ void XlaComputationLaunchContext::PopulateInputs(
se::Stream* stream = se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Build ShapedBuffers that point directly to the Tensor buffers. // Build ShapedBuffers that point directly to the Tensor buffers.
arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1); arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
arg_buffers_.resize(kernel->xla_input_shapes.size());
arg_ptrs_ = std::vector<ShapedBuffer*>(arg_buffers_.size());
// Pass remaining parameters. // Pass remaining parameters.
const Tensor* t; const Tensor* t;
@ -239,11 +237,11 @@ void XlaComputationLaunchContext::PopulateInputs(
<< " not the same as on-host shape " << " not the same as on-host shape "
<< xla::ShapeUtil::HumanStringWithLayout(shape); << xla::ShapeUtil::HumanStringWithLayout(shape);
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t); se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
arg_buffers_[i] = absl::make_unique<ShapedBuffer>( arg_buffers_.emplace_back(
/*on_host_shape=*/shape, /*on_device_shape=*/shape, /*on_host_shape=*/shape, /*on_device_shape=*/shape,
client_->platform(), client_->default_device_ordinal()); client_->platform(), client_->default_device_ordinal());
arg_buffers_[i]->set_buffer(dmem, /*index=*/{}); arg_buffers_.back().set_buffer(dmem, /*index=*/{});
arg_ptrs_[i] = arg_buffers_[i].get(); arg_ptrs_[i] = &arg_buffers_.back();
} }
} }
} }
@ -470,10 +468,6 @@ Status XlaComputationLaunchContext::PopulateOutputs(
<< "Invalid input for outputs " << i << ": " << input_index; << "Invalid input for outputs " << i << ": " << input_index;
ctx->set_output(i, ctx->input(input_index)); ctx->set_output(i, ctx->input(input_index));
} else { } else {
if (MustAliasOutput(input_output_alias, output_num)) {
DCHECK(output.buffer({output_num}).is_null())
<< "Expected output buffer to be aliased, but it is not nil.";
}
if (allocate_xla_tensors_) { if (allocate_xla_tensors_) {
TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors( TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
input_output_alias, output_num, ctx, i, shape, &output, input_output_alias, output_num, ctx, i, shape, &output,

View File

@ -165,7 +165,7 @@ class XlaComputationLaunchContext {
se::DeviceMemoryAllocator* xla_allocator_; se::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_; bool allocate_xla_tensors_;
bool use_multiple_streams_; bool use_multiple_streams_;
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_; std::deque<xla::ShapedBuffer> arg_buffers_;
std::vector<xla::ShapedBuffer*> arg_ptrs_; std::vector<xla::ShapedBuffer*> arg_ptrs_;
}; };
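The switch from std::vector<std::unique_ptr<xla::ShapedBuffer>> to std::deque<xla::ShapedBuffer> relies on the fact that deque growth at the back never invalidates references to existing elements, which is what lets arg_ptrs_ keep raw pointers into arg_buffers_ as they are populated in PopulateInputs above, without a separate heap allocation per buffer. A self-contained illustration of that property:

#include <cassert>
#include <deque>

// Standalone illustration: references into a std::deque survive emplace_back,
// which is not guaranteed for std::vector once it reallocates.
void DequePointerStabilitySketch() {
  std::deque<int> buffers;
  buffers.emplace_back(42);
  int* first = &buffers.front();
  for (int i = 0; i < 1000; ++i) buffers.emplace_back(i);
  assert(first == &buffers.front() && *first == 42);
}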

View File

@ -27,10 +27,6 @@ namespace tensorflow {
return xla_tensor; return xla_tensor;
} }
/*static*/ bool XlaTensor::RefCountIsOne(const Tensor& tensor) {
return tensor.RefCountIsOne();
}
/*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor( /*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor(
const Tensor& tensor) { const Tensor& tensor) {
const XlaTensor* xla_tensor = FromTensor(&tensor); const XlaTensor* xla_tensor = FromTensor(&tensor);

View File

@ -39,8 +39,6 @@ class XlaTensor {
// fails. // fails.
static XlaTensor* FromTensor(const Tensor* tensor); static XlaTensor* FromTensor(const Tensor* tensor);
static bool RefCountIsOne(const Tensor& tensor);
// Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in // Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in
// which case the returned value is shaped_buffer()->root_buffer(), or a // which case the returned value is shaped_buffer()->root_buffer(), or a
// normal Tensor in which case the returned value is // normal Tensor in which case the returned value is
@ -57,7 +55,7 @@ class XlaTensor {
// manage the memory for these tensors a ShapedBuffer may be required. // manage the memory for these tensors a ShapedBuffer may be required.
// Return true if this XlaTensor contains a ShapedBuffer. // Return true if this XlaTensor contains a ShapedBuffer.
bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; } bool has_shaped_buffer() const { return shaped_buffer_.has_value(); }
// Return the contained ShapedBuffer. // Return the contained ShapedBuffer.
// REQUIRES: has_shaped_buffer() // REQUIRES: has_shaped_buffer()
const xla::ShapedBuffer& shaped_buffer() const { const xla::ShapedBuffer& shaped_buffer() const {
@ -70,8 +68,7 @@ class XlaTensor {
} }
// Mutates the XlaTensor to set the ShapedBuffer. // Mutates the XlaTensor to set the ShapedBuffer.
void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) { void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
shaped_buffer_ = shaped_buffer_ = std::move(shaped_buffer);
absl::make_unique<xla::ScopedShapedBuffer>(std::move(shaped_buffer));
} }
// Some tensors on the device may have known values on the host. We use these // Some tensors on the device may have known values on the host. We use these
@ -79,14 +76,12 @@ class XlaTensor {
// host value already. // host value already.
// Return true if this XlaTensor contains a host tensor. // Return true if this XlaTensor contains a host tensor.
bool has_host_tensor() const { return host_tensor_ != nullptr; } bool has_host_tensor() const { return host_tensor_.has_value(); }
// Return the contained host tensor. // Return the contained host tensor.
// REQUIRES: has_host_tensor() // REQUIRES: has_host_tensor()
const Tensor& host_tensor() const { return *host_tensor_; } const Tensor& host_tensor() const { return *host_tensor_; }
// Sets the contained host tensor. // Sets the contained host tensor.
void set_host_tensor(const Tensor& tensor) { void set_host_tensor(const Tensor& tensor) { host_tensor_.emplace(tensor); }
host_tensor_.reset(new Tensor(tensor));
}
// Adds synchronization events to 'stream' that wait for this tensor to be // Adds synchronization events to 'stream' that wait for this tensor to be
// defined on 'stream'. Does nothing if the tensor is already defined on that // defined on 'stream'. Does nothing if the tensor is already defined on that
@ -113,9 +108,9 @@ class XlaTensor {
private: private:
// The optional contained ShapedBuffer. // The optional contained ShapedBuffer.
std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_; absl::optional<xla::ScopedShapedBuffer> shaped_buffer_;
// An optional host tensor value. // An optional host tensor value.
std::unique_ptr<Tensor> host_tensor_; absl::optional<Tensor> host_tensor_;
// An optional event that is triggered when the tensor's content has been // An optional event that is triggered when the tensor's content has been
// defined. If this event is nullptr, it is assumed that the tensor's content // defined. If this event is nullptr, it is assumed that the tensor's content
// is always defined. // is always defined.

View File

@ -30,7 +30,7 @@ cc_library(
hdrs = ["op_or_arg_name_mapper.h"], hdrs = ["op_or_arg_name_mapper.h"],
deps = [ deps = [
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
) )
@ -42,7 +42,7 @@ cc_library(
":init_mlir", ":init_mlir",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core/platform:logging", "//tensorflow/core/platform:logging",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirOptLib", "@llvm-project//mlir:MlirOptLib",
@ -86,7 +86,7 @@ cc_library(
hdrs = ["init_mlir.h"], hdrs = ["init_mlir.h"],
deps = [ deps = [
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
], ],
) )
@ -102,8 +102,9 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
], ],
alwayslink = 1, alwayslink = 1,
@ -154,7 +155,7 @@ tf_cc_binary(
"//tensorflow/core:tensorflow", "//tensorflow/core:tensorflow",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",

View File

@ -216,16 +216,16 @@ cc_library(
"ir/tfl_ops.h", "ir/tfl_ops.h",
"transforms/passes.h", "transforms/passes.h",
"utils/attribute_utils.h", "utils/attribute_utils.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h",
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
], ],
deps = [ deps = [
":tensorflow_lite_ops_inc_gen", ":tensorflow_lite_ops_inc_gen",
":validators", ":validators",
"//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators", "//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect", "@llvm-project//mlir:Dialect",
@ -253,7 +253,7 @@ cc_library(
deps = [ deps = [
":tensorflow_lite", ":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
@ -272,23 +272,7 @@ cc_library(
deps = [ deps = [
":tensorflow_lite", ":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
tf_cc_test(
name = "tftext_utils_test",
size = "small",
srcs = ["utils/lstm_utils_test.cc"],
deps = [
":lstm_utils",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
@ -305,7 +289,7 @@ cc_library(
], ],
deps = [ deps = [
":tensorflow_lite", ":tensorflow_lite",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
], ],
@ -320,7 +304,7 @@ tf_cc_test(
"//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
@ -330,6 +314,7 @@ tf_cc_test(
cc_library( cc_library(
name = "tensorflow_lite_legalize_tf", name = "tensorflow_lite_legalize_tf",
srcs = [ srcs = [
"transforms/device_index_selector.cc",
"transforms/dilated_conv.cc", "transforms/dilated_conv.cc",
"transforms/generated_legalize_tf.inc", "transforms/generated_legalize_tf.inc",
"transforms/generated_lower_static_tensor_list.inc", "transforms/generated_lower_static_tensor_list.inc",
@ -373,7 +358,7 @@ cc_library(
"//tensorflow/core/platform:logging", "//tensorflow/core/platform:logging",
"@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -399,7 +384,7 @@ cc_library(
":validators", ":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -432,7 +417,7 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std", "//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -457,7 +442,7 @@ cc_library(
"@com_google_absl//absl/base", "@com_google_absl//absl/base",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
@ -510,8 +495,8 @@ tf_native_cc_binary(
"converter_gen.cc", "converter_gen.cc",
], ],
deps = [ deps = [
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//llvm:tablegen", "@llvm-project//llvm:TableGen",
"@llvm-project//mlir:TableGen", "@llvm-project//mlir:TableGen",
], ],
) )
@ -557,8 +542,8 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@flatbuffers", "@flatbuffers",
"@llvm-project//llvm:analysis", "@llvm-project//llvm:Analysis",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:TransformUtils",
], ],
@ -635,7 +620,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@flatbuffers", "@flatbuffers",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
@ -669,7 +654,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
@ -729,7 +714,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirTranslateMain", "@llvm-project//mlir:MlirTranslateMain",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",
@ -759,7 +744,7 @@ cc_library(
"tf_tfl_translate_cl.h", "tf_tfl_translate_cl.h",
], ],
deps = [ deps = [
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -771,7 +756,7 @@ cc_library(
], ],
deps = [ deps = [
"//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
], ],
) )
@ -796,7 +781,7 @@ tf_cc_binary(
":tf_tfl_translate_cl_options", ":tf_tfl_translate_cl_options",
":tf_to_tfl_flatbuffer", ":tf_to_tfl_flatbuffer",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
# TODO(b/155809683): Link only necessary dialects. # TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
@ -821,7 +806,7 @@ tf_cc_binary(
":flatbuffer_translate_lib", ":flatbuffer_translate_lib",
":flatbuffer_translate_registeration", ":flatbuffer_translate_registeration",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
# TODO(b/155809683): Link only necessary dialects. # TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
@ -890,7 +875,7 @@ cc_library(
"//tensorflow/lite/tools/optimize:quantize_weights", "//tensorflow/lite/tools/optimize:quantize_weights",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser", "@llvm-project//mlir:Parser",
@ -910,6 +895,6 @@ cc_library(
"//tensorflow/lite/experimental/mlir:__subpackages__", "//tensorflow/lite/experimental/mlir:__subpackages__",
], ],
deps = [ deps = [
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
], ],
) )

View File

@ -32,7 +32,6 @@ struct PassConfig {
lower_tensor_list_ops(false), lower_tensor_list_ops(false),
trim_functions_whitelist({}), trim_functions_whitelist({}),
quant_specs(std::move(specs)), quant_specs(std::move(specs)),
skip_control_dialect(false),
form_clusters(false), form_clusters(false),
unfold_batch_matmul(true), unfold_batch_matmul(true),
legalize_tf_while(true), legalize_tf_while(true),
@ -49,13 +48,8 @@ struct PassConfig {
llvm::ArrayRef<std::string> trim_functions_whitelist; llvm::ArrayRef<std::string> trim_functions_whitelist;
// All information about quantization. // All information about quantization.
QuantizationSpecs quant_specs; QuantizationSpecs quant_specs;
// If `skip_control_dialect` is true, TF executor dialect is not converted to // If `form_clusters` is true, clusters are formed by grouping consecutive
// TF control dialect prior to legalization to TF Lite. // ops of the same device, under a `tf_device.launch` op.
// TODO(b/142911013): Remove flag once control dialect is removed.
bool skip_control_dialect;
// If `form_clusters` is true (and `skip_control_dialect` is true), clusters
// are formed by grouping consecutive ops of the same device, under a
// `tf_device.launch` op.
bool form_clusters; bool form_clusters;
// if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set // if `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set
// of tfl.fully_connected ops. // of tfl.fully_connected ops.
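A minimal standalone sketch of the pass configuration after this hunk: the removed skip_control_dialect flag is gone and only the knobs visible above are kept, with defaults taken from the constructor lines. The struct name is hypothetical and the real PassConfig carries additional members not shown here.

#include <iostream>

// Minimal stand-in for the TFL PassConfig after this change; other members of
// the real struct are omitted.
struct PassConfigSketch {
  bool lower_tensor_list_ops = false;
  bool form_clusters = false;       // group consecutive same-device ops under tf_device.launch
  bool unfold_batch_matmul = true;  // tf.BatchMatMul -> set of tfl.fully_connected ops
  bool legalize_tf_while = true;
};

int main() {
  PassConfigSketch config;
  config.form_clusters = true;  // opt in to clustering
  std::cout << std::boolalpha << "form_clusters=" << config.form_clusters
            << " unfold_batch_matmul=" << config.unfold_batch_matmul << '\n';
  return 0;
}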

View File

@ -525,11 +525,16 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) {
auto *val = trait.getDef().getValue("tflRuntimePredicate"); auto *val = trait.getDef().getValue("tflRuntimePredicate");
if (!val) continue; if (!val) continue;
auto desc = trait.getDef().getValueAsString("tflRuntimeDescription");
mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue())); mlir::tblgen::Pred pred(dyn_cast<llvm::DefInit>(val->getValue()));
os << tgfmt( os << tgfmt(
" if (!($0)) {\n " " if (!($0)) {\n "
" return ::mlir::LogicalResult::Failure;\n }\n", " if (failure_on_operand_type_mismatch) {\n"
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx)); " return top.emitOpError(\"failed to verify that $1\");\n"
" } else {\n"
" return ::mlir::LogicalResult::Failure;\n }\n }\n",
&verify_ctx, tgfmt(pred.getCondition(), &verify_ctx), desc);
} }
os << " return top.verify();\n}\n"; os << " return top.verify();\n}\n";
} }
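A plain C++ sketch (not the actual TableGen output) of the two failure modes the generated runtime verifier now switches between; the flag name failure_on_operand_type_mismatch and the "failed to verify that ..." wording come from the hunk above, everything else is illustrative.

#include <iostream>
#include <optional>
#include <string>

// Result of a runtime verification check: either success, or a failure that
// may carry a human-readable description.
struct VerifyResult {
  bool ok;
  std::optional<std::string> message;
};

// `condition` stands for the emitted predicate, `description` for the
// tflRuntimeDescription string attached to the trait.
VerifyResult CheckPredicate(bool condition, const std::string& description,
                            bool failure_on_operand_type_mismatch) {
  if (condition) return {true, std::nullopt};
  if (failure_on_operand_type_mismatch) {
    // Corresponds to `top.emitOpError("failed to verify that ...")`.
    return {false, "failed to verify that " + description};
  }
  // Corresponds to returning a bare ::mlir::LogicalResult::Failure.
  return {false, std::nullopt};
}

int main() {
  auto verbose = CheckPredicate(false, "operand types match", true);
  auto silent = CheckPredicate(false, "operand types match", false);
  std::cout << (verbose.message ? *verbose.message : "(silent failure)") << '\n';
  std::cout << (silent.message ? *silent.message : "(silent failure)") << '\n';
  return 0;
}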

View File

@ -424,6 +424,10 @@ StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor, StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
const std::vector<uint8_t>& buffer, const std::vector<uint8_t>& buffer,
OpBuilder builder, Location loc) { OpBuilder builder, Location loc) {
if (buffer.empty()) {
return errors::InvalidArgument("Constant's buffer may not be empty");
}
TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder, TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
/*shapeless_are_scalars=*/true, /*shapeless_are_scalars=*/true,
/*is_constant=*/true)); /*is_constant=*/true));
@ -695,8 +699,6 @@ StatusOr<absl::flat_hash_set<const tflite::OperatorT*>> PruneSubgraph(
for (int32_t output : output_indices) { for (int32_t output : output_indices) {
if (auto& op = defining_op[output]) { if (auto& op = defining_op[output]) {
queue.push_back(op); queue.push_back(op);
} else {
return errors::InvalidArgument("Output tensor doesn't have defining op");
} }
} }
@ -801,9 +803,17 @@ StatusOr<FuncOp> ConvertSubgraph(
} }
for (auto output : func_outputs) { for (auto output : func_outputs) {
bool is_constant = !is_op_output[output]; const bool is_func_input = input_index_set.contains(output);
bool is_constant = !is_op_output[output] && !is_func_input;
// There are two cases where a tensor is a scalar when it doesn't have a shape
// in the flatbuffer:
// 1. `is_constant` = true, which means this tensor is created from a constant op.
// 2. `is_func_input` = true and `is_entry_point` = true, which means this
// tensor is a function input and the function input type is a scalar tensor.
const bool shapeless_is_scalar =
is_constant || (is_func_input && is_entry_point);
auto type_or_err = GetTensorType(*subgraph.tensors.at(output), builder, auto type_or_err = GetTensorType(*subgraph.tensors.at(output), builder,
/*shapeless_are_scalars=*/is_constant, shapeless_is_scalar,
/*is_constant=*/is_constant); /*is_constant=*/is_constant);
if (!type_or_err.ok()) { if (!type_or_err.ok()) {
emitError(func_loc, "error reading return types") emitError(func_loc, "error reading return types")
@ -858,6 +868,8 @@ StatusOr<FuncOp> ConvertSubgraph(
subgraph, &builder, "outputs", func_outputs)); subgraph, &builder, "outputs", func_outputs));
} }
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes)); func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
} else {
func.setVisibility(FuncOp::Visibility::Private);
} }
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops; absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
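A standalone restatement of the scalar-inference rule introduced above, reduced to its boolean logic; the helper name and the truth-table checks are illustrative only and do not reproduce the real GetTensorType call.

#include <cassert>

// Mirrors: is_constant = !is_op_output && !is_func_input;
//          shapeless_is_scalar = is_constant || (is_func_input && is_entry_point)
bool ShapelessIsScalar(bool is_op_output, bool is_func_input,
                       bool is_entry_point) {
  const bool is_constant = !is_op_output && !is_func_input;
  return is_constant || (is_func_input && is_entry_point);
}

int main() {
  // A constant feeding an output directly: treated as a scalar when shapeless.
  assert(ShapelessIsScalar(/*is_op_output=*/false, /*is_func_input=*/false,
                           /*is_entry_point=*/false));
  // An entry-point argument returned unchanged: also a scalar when shapeless.
  assert(ShapelessIsScalar(false, true, true));
  // A pass-through argument of a non-entry function: left unranked.
  assert(!ShapelessIsScalar(false, true, false));
  return 0;
}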

View File

@ -46,28 +46,183 @@ namespace mlir {
#include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc" #include "tensorflow/compiler/mlir/lite/ir/tfl_structs.cc.inc"
namespace TFL { namespace TFL {
// Returns true when the given two types have the same shape or broadcastable // Returns true when the given operand arguments have the same shape or
// shape within the given rank. If any given shapes are non-static, this method // broadcastable shape within the given rank. If any given shapes are
// returns true. // non-static and maximum rank is within the given rank, this method returns
bool IsBinaryOperandsHaveSameShapesOrBroadcastableShape(Type lhs, Type rhs, // true.
int max_bcast_rank) { bool VerifyOperandsHaveSameShapesOrBroadcastableShape(
// Ignore shape checking on the non-static shapes for model compatibility. Operation *op, ArrayRef<unsigned> indices, int max_bcast_rank) {
auto lhs_shaped_type = lhs.dyn_cast<ShapedType>(); if (indices.empty()) return true;
if (!lhs_shaped_type || !lhs_shaped_type.hasStaticShape()) return true;
auto rhs_shaped_type = rhs.dyn_cast<ShapedType>();
if (!rhs_shaped_type || !rhs_shaped_type.hasStaticShape()) return true;
if (lhs_shaped_type.getShape().equals(rhs_shaped_type.getShape())) // First, check whether there are any inputs that have unknown rank.
return true; bool has_unknown_shape_input = false;
bool has_same_shape = true;
bool reach_first_known_shape = false;
int64_t max_rank = -1;
ArrayRef<int64_t> pivot_shape;
SmallVector<int64_t, 4> current_shape;
SmallVector<int64_t, 4> result_shape; SmallVector<int64_t, 4> result_shape;
if (!OpTrait::util::getBroadcastedShape(lhs_shaped_type.getShape(),
rhs_shaped_type.getShape(), for (unsigned index : indices) {
result_shape)) { ShapedType shaped_type =
return false; op->getOperand(index).getType().dyn_cast<ShapedType>();
if (!shaped_type || !shaped_type.hasRank()) {
// Marks that we have an unknown rank input.
has_unknown_shape_input = true;
continue;
}
max_rank = std::max(max_rank, shaped_type.getRank());
if (!shaped_type.hasStaticShape()) {
// Marks that we have an unknown shape input.
has_unknown_shape_input = true;
continue;
}
ArrayRef<int64_t> shape = shaped_type.getShape();
if (!reach_first_known_shape) {
pivot_shape = shape;
current_shape.assign(shape.begin(), shape.end());
reach_first_known_shape = true;
continue;
}
if (!pivot_shape.equals(shape)) {
has_same_shape = false;
}
// Checks if all the inputs are broadcastable, since they do not all have the
// same shape.
if (!OpTrait::util::getBroadcastedShape(current_shape, shape,
result_shape)) {
return false;
}
current_shape = result_shape;
} }
return lhs_shaped_type.getRank() <= max_bcast_rank &&
rhs_shaped_type.getRank() <= max_bcast_rank; // It will treat the unknown shape inputs as acceptable inputs for model
// compatibility unless there is an known rank that is bigger than the allowed
// broadcast maximum rank.
if (has_unknown_shape_input) return max_rank <= max_bcast_rank;
// If all the shapes are known and the same, CPU kernels are able to handle inputs
// regardless of dimension size.
return has_same_shape || max_rank <= max_bcast_rank;
}
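For readers who want the shape rule outside of MLIR types, the following is a self-contained sketch that mirrors the intent of VerifyOperandsHaveSameShapesOrBroadcastableShape over plain shape vectors (-1 marks a dynamic dimension); it simplifies away the unknown-rank bookkeeping and is not the production implementation.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified stand-in for OpTrait::util::getBroadcastedShape. Returns false
// when the shapes cannot be broadcast together.
bool BroadcastShapes(std::vector<int64_t> a, std::vector<int64_t> b,
                     std::vector<int64_t>& out) {
  if (a.size() < b.size()) a.insert(a.begin(), b.size() - a.size(), 1);
  if (b.size() < a.size()) b.insert(b.begin(), a.size() - b.size(), 1);
  out.resize(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] == -1 || b[i] == -1) { out[i] = -1; continue; }
    if (a[i] == b[i] || a[i] == 1 || b[i] == 1) {
      out[i] = std::max(a[i], b[i]);
    } else {
      return false;
    }
  }
  return true;
}

// Same shape, or broadcastable with every rank <= max_bcast_rank.
bool SameOrBroadcastable(const std::vector<int64_t>& lhs,
                         const std::vector<int64_t>& rhs, int max_bcast_rank) {
  if (lhs == rhs) return true;
  std::vector<int64_t> result;
  if (!BroadcastShapes(lhs, rhs, result)) return false;
  return static_cast<int>(lhs.size()) <= max_bcast_rank &&
         static_cast<int>(rhs.size()) <= max_bcast_rank;
}

int main() {
  assert(SameOrBroadcastable({4, 1, 2}, {4, 3, 2}, 5));   // broadcastable
  assert(!SameOrBroadcastable({4, 3, 2}, {4, 2, 2}, 5));  // mismatched dim
  // Identical shapes pass even above the broadcast rank limit, matching the
  // has_same_shape escape hatch in the verifier above.
  assert(SameOrBroadcastable({1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, 5));
  return 0;
}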
// Return true when the given element_type is QI8.
bool IsQI8Type(Type element_type) {
auto quantized_type = element_type.dyn_cast<QuantizedType>();
return quantized_type != nullptr &&
quantized_type.getStorageTypeIntegralWidth() == 8 &&
quantized_type.isSigned();
}
// Return true when the given element_type is QUI8.
bool IsQUI8Type(Type element_type) {
auto quantized_type = element_type.dyn_cast<QuantizedType>();
return quantized_type != nullptr &&
quantized_type.getStorageTypeIntegralWidth() == 8 &&
!quantized_type.isSigned();
}
// Return true when the given element_type is QI16.
bool IsQI16Type(Type element_type) {
auto quantized_type = element_type.dyn_cast<QuantizedType>();
return quantized_type != nullptr &&
quantized_type.getStorageTypeIntegralWidth() == 16 &&
quantized_type.isSigned();
}
// Return true when the given element_type is I32.
bool IsI32Type(Type element_type) {
return element_type.isInteger(32) && !element_type.isUnsignedInteger();
}
// Return true if the given Add operation has the CPU kernel supported shapes.
bool VerifyAddOpShapeConstraints(AddOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
// Allows F32, QI8, and QUI8 outputs when the operands have valid shapes,
// which are broadcastable shapes up to five dimensions or have the same shapes.
if (element_type.isF32() || IsQI8Type(element_type) ||
IsQUI8Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows I32 output when the operands have valid shapes, which are
// broadcastable shapes up to four dimensions or have the same shapes.
if (IsI32Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);
}
// Allows QI16 output when operands have the same shape.
if (IsQI16Type(element_type)) {
return succeeded(
mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
}
return false;
}
// Return true if the given Sub operation has the CPU kernel supported shapes.
bool VerifySubOpShapeConstraints(SubOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
// Allows F32, I32, QUI8, and QI16 outputs when the operands have valid shapes,
// which are broadcastable shapes up to five dimensions or have the same shapes.
if (element_type.isF32() || IsI32Type(element_type) ||
IsQUI8Type(element_type) || IsQI16Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows QI8 output when the operands have valid shapes, which are
// broadcastable shapes up to four dimensions or have the same shapes.
if (IsQI8Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);
}
return false;
}
// Return true if the given Mul operation has the CPU kernel supported shapes.
bool VerifyMulOpShapeConstraints(MulOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
// Allows QI8 and QUI8 inputs with broadcasting up to five dimensions unless
// the output type is QI16. If the output type is QI16, allows only operands
// with the same shape.
if (IsQI8Type(element_type) || IsQUI8Type(element_type)) {
if (IsQI16Type(getElementTypeOrSelf(op.lhs().getType()))) {
return succeeded(
mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()));
}
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows F32 output when the operands have valid shapes, which are
// broadcastable shapes up to five dimensions or have the same shapes.
if (element_type.isF32()) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);
}
// Allows I32 and QI16 outputs when the operands have valid shapes, which are
// broadcastable shapes up to four dimensions or have the same shapes.
if (IsI32Type(element_type) || IsQI16Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/4);
}
return false;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1882,7 +2037,7 @@ static LogicalResult Verify(TransposeConvOp op) {
auto expected_output_type = auto expected_output_type =
RankedTensorType::get(output_shape, output_type.getElementType()); RankedTensorType::get(output_shape, output_type.getElementType());
if (output_type != expected_output_type) { if (failed(mlir::verifyCompatibleShape(output_type, expected_output_type))) {
return op.emitOpError(llvm::formatv("expect output type {0}, got {1}", return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
expected_output_type, output_type)); expected_output_type, output_type));
} }
@ -1966,9 +2121,9 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
} }
static LogicalResult Verify(TransposeOp op) { static LogicalResult Verify(TransposeOp op) {
auto input_type = op.x().getType().cast<ShapedType>(); auto input_type = op.input().getType().cast<ShapedType>();
auto perm_type = op.perm().getType().cast<ShapedType>(); auto perm_type = op.perm().getType().cast<ShapedType>();
auto output_type = op.y().getType().cast<ShapedType>(); auto output_type = op.output().getType().cast<ShapedType>();
if (input_type.hasStaticShape() && perm_type.hasStaticShape()) { if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
if (perm_type.getNumElements() != input_type.getRank()) { if (perm_type.getNumElements() != input_type.getRank()) {
return op.emitOpError( return op.emitOpError(
@ -2004,7 +2159,8 @@ static LogicalResult Verify(TransposeOp op) {
} }
auto expected_output_type = auto expected_output_type =
RankedTensorType::get(transposed_shape, input_type.getElementType()); RankedTensorType::get(transposed_shape, input_type.getElementType());
if (output_type != expected_output_type) { if (failed(
mlir::verifyCompatibleShape(output_type, expected_output_type))) {
return op.emitOpError(llvm::formatv("expect output type {0}, got {1}", return op.emitOpError(llvm::formatv("expect output type {0}, got {1}",
expected_output_type, output_type)); expected_output_type, output_type));
} }
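The TransposeConv and Transpose verifiers above move from exact type equality to mlir::verifyCompatibleShape, which tolerates dynamic dimensions. Below is a simplified standalone analogue of that compatibility rule (ranked shapes only, -1 as the dynamic marker); it is a sketch, not the MLIR implementation.

#include <cassert>
#include <cstdint>
#include <vector>

// Two shapes are compatible when they have the same rank and every pair of
// static dimensions matches; a dynamic dimension matches anything.
bool ShapesCompatible(const std::vector<int64_t>& a,
                      const std::vector<int64_t>& b) {
  if (a.size() != b.size()) return false;
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] == -1 || b[i] == -1) continue;  // dynamic dims always match
    if (a[i] != b[i]) return false;
  }
  return true;
}

int main() {
  // Exact match still verifies.
  assert(ShapesCompatible({1, 64, 84, 32}, {1, 64, 84, 32}));
  // A dynamic output dimension no longer triggers the "expect output type"
  // error, which is the effect of the verifier change above.
  assert(ShapesCompatible({-1, 64, 84, 32}, {1, 64, 84, 32}));
  // A genuinely different static dimension is still rejected.
  assert(!ShapesCompatible({1, 64, 84, 16}, {1, 64, 84, 32}));
  return 0;
}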

File diff suppressed because it is too large

View File

@ -27,7 +27,7 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc", "//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
@ -56,7 +56,7 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc", "//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
@ -85,7 +85,7 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc", "//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",

View File

@ -55,8 +55,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
std::vector<string> node_names; std::vector<string> node_names;
std::vector<string> node_dtypes; std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes; std::vector<std::vector<int>> node_shapes;
std::vector<double> node_mins; std::vector<llvm::Optional<double>> node_mins;
std::vector<double> node_maxs; std::vector<llvm::Optional<double>> node_maxs;
// Populate quantization specs. // Populate quantization specs.
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs( TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(

View File

@ -125,8 +125,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
std::vector<string> node_names; std::vector<string> node_names;
std::vector<string> node_dtypes; std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes; std::vector<std::vector<int>> node_shapes;
std::vector<double> node_mins; std::vector<llvm::Optional<double>> node_mins;
std::vector<double> node_maxs; std::vector<llvm::Optional<double>> node_maxs;
// Populate quantization specs. // Populate quantization specs.
TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs( TF_RETURN_IF_ERROR(internal::PopulateQuantizationSpecs(

View File

@ -177,14 +177,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
return RegisterCustomBuiltinOps(extra_tf_opdefs); return RegisterCustomBuiltinOps(extra_tf_opdefs);
} }
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags, Status PopulateQuantizationSpecs(
const toco::TocoFlags& toco_flags, const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs, mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
std::vector<string>* node_names, std::vector<string>* node_dtypes,
std::vector<string>* node_dtypes, std::vector<std::vector<int>>* node_shapes,
std::vector<std::vector<int>>* node_shapes, std::vector<llvm::Optional<double>>* node_mins,
std::vector<double>* node_mins, std::vector<llvm::Optional<double>>* node_maxs) {
std::vector<double>* node_maxs) {
quant_specs->inference_input_type = quant_specs->inference_input_type =
ConvertIODataTypeToDataType(toco_flags.inference_input_type()); ConvertIODataTypeToDataType(toco_flags.inference_input_type());
tensorflow::DataType inference_type = tensorflow::DataType inference_type =
@ -211,11 +210,16 @@ Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
flag.shape().dims().end())); flag.shape().dims().end()));
// Currently, only UINT8 and INT8 require inputs stats // Currently, only UINT8 and INT8 require inputs stats
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) { if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
TF_ASSIGN_OR_RETURN( if (flag.has_mean_value() && flag.has_std_value()) {
auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(), TF_ASSIGN_OR_RETURN(
inference_type)); auto min_max, InputStatsToMinMax(flag.mean_value(),
node_mins->push_back(min_max.first); flag.std_value(), inference_type));
node_maxs->push_back(min_max.second); node_mins->push_back(min_max.first);
node_maxs->push_back(min_max.second);
} else {
node_mins->push_back(llvm::None);
node_maxs->push_back(llvm::None);
}
} }
} }
@ -254,7 +258,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
std::string error_message; std::string error_message;
auto output = mlir::openOutputFile(filename, &error_message); auto output = mlir::openOutputFile(filename, &error_message);
if (!error_message.empty()) { if (!error_message.empty()) {
return errors::InvalidArgument("Failed to open file in %s.", filename); return errors::InvalidArgument("Failed to open file in ", filename);
} }
mlir::PassManager pm(module.getContext()); mlir::PassManager pm(module.getContext());
pm.addPass(mlir::createPrintOpGraphPass(output->os())); pm.addPass(mlir::createPrintOpGraphPass(output->os()));
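The min/max plumbing above switches from plain doubles to llvm::Optional<double>, so inputs without mean/std flags record "no range" instead of a fabricated value. A standalone sketch of the same pattern using std::optional follows; the StatsToMinMax formula is illustrative only, not the exact InputStatsToMinMax computation.

#include <iostream>
#include <optional>
#include <utility>
#include <vector>

// Illustrative stand-in for InputStatsToMinMax: derive a (min, max) range
// from mean/std quantization flags.
std::pair<double, double> StatsToMinMax(double mean, double std_dev) {
  return {(0.0 - mean) / std_dev, (255.0 - mean) / std_dev};
}

int main() {
  // Mirrors node_mins/node_maxs after the change: entries may be empty.
  std::vector<std::optional<double>> node_mins, node_maxs;

  // Input with stats: push a concrete range.
  auto range = StatsToMinMax(/*mean=*/127.5, /*std_dev=*/127.5);
  node_mins.push_back(range.first);
  node_maxs.push_back(range.second);

  // Input without mean/std flags: push an empty optional (llvm::None in the
  // real code) instead of a bogus number.
  node_mins.push_back(std::nullopt);
  node_maxs.push_back(std::nullopt);

  for (size_t i = 0; i < node_mins.size(); ++i) {
    if (node_mins[i] && node_maxs[i]) {
      std::cout << "input " << i << ": [" << *node_mins[i] << ", "
                << *node_maxs[i] << "]\n";
    } else {
      std::cout << "input " << i << ": no range provided\n";
    }
  }
  return 0;
}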

View File

@ -34,14 +34,13 @@ Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags);
// Populate quantization specs (or not) given user specified ranges for each // Populate quantization specs (or not) given user specified ranges for each
// input arrays. // input arrays.
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags, Status PopulateQuantizationSpecs(
const toco::TocoFlags& toco_flags, const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs, mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
std::vector<string>* node_names, std::vector<string>* node_dtypes,
std::vector<string>* node_dtypes, std::vector<std::vector<int>>* node_shapes,
std::vector<std::vector<int>>* node_shapes, std::vector<llvm::Optional<double>>* node_mins,
std::vector<double>* node_mins, std::vector<llvm::Optional<double>>* node_maxs);
std::vector<double>* node_maxs);
// Convert imported MLIR file to TfLite flatbuffer. // Convert imported MLIR file to TfLite flatbuffer.
// This will also run relevant passes as well. // This will also run relevant passes as well.

View File

@ -3,6 +3,10 @@ load(
"//tensorflow/core/platform:build_config.bzl", "//tensorflow/core/platform:build_config.bzl",
"tf_proto_library", "tf_proto_library",
) )
load(
"//third_party/mlir:tblgen.bzl",
"gentbl",
)
package( package(
default_visibility = [ default_visibility = [
@ -23,6 +27,7 @@ package_group(
exports_files([ exports_files([
"quantization_traits.h", "quantization_traits.h",
"quantization_config.h", "quantization_config.h",
"quantization_utils.h",
]) ])
filegroup( filegroup(
@ -34,6 +39,25 @@ filegroup(
], ],
) )
gentbl(
name = "quantization_interfaces_inc_gen",
tbl_outs = [
(
"-gen-op-interface-decls",
"quantization_interface.h.inc",
),
(
"-gen-op-interface-defs",
"quantization_interface.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "quantization.td",
td_srcs = [
":quantization_td_files",
],
)
tf_proto_library( tf_proto_library(
name = "quantization_info_proto", name = "quantization_info_proto",
srcs = [ srcs = [
@ -56,7 +80,7 @@ cc_library(
"//tensorflow/core:lib_proto_parsing", "//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -71,16 +95,18 @@ cc_library(
name = "quantization_lib", name = "quantization_lib",
srcs = [ srcs = [
"quantization_driver.cc", "quantization_driver.cc",
"quantization_interface.cc.inc",
"quantization_utils.cc", "quantization_utils.cc",
], ],
hdrs = [ hdrs = [
"quantization_interface.h.inc",
"quantization_traits.h", "quantization_traits.h",
"quantization_utils.h", "quantization_utils.h",
], ],
deps = [ deps = [
"//tensorflow/core:lib_proto_parsing", "//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
@ -99,7 +125,7 @@ cc_library(
deps = [ deps = [
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
], ],
) )
@ -109,8 +135,8 @@ tf_native_cc_binary(
"tools/op_quant_spec_getters_gen.cc", "tools/op_quant_spec_getters_gen.cc",
], ],
deps = [ deps = [
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//llvm:tablegen", "@llvm-project//llvm:TableGen",
"@llvm-project//mlir:TableGen", "@llvm-project//mlir:TableGen",
], ],
) )
@ -131,7 +157,7 @@ cc_library(
deps = [ deps = [
":numerical_utils", ":numerical_utils",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
@ -146,7 +172,7 @@ cc_library(
":device_target", ":device_target",
":quantization_lib", ":quantization_lib",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",

View File

@ -36,7 +36,7 @@ cc_library(
"//tensorflow/lite/core/api", "//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
], ],
@ -49,14 +49,16 @@ cc_library(
], ],
hdrs = [ hdrs = [
"tfl_to_std.h", "tfl_to_std.h",
"//tensorflow/compiler/mlir/lite/quantization:quantization_utils.h",
], ],
deps = [ deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tensorflow_lite",
"@com_google_absl//absl/strings", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
], ],
) )
@ -71,7 +73,7 @@ tf_cc_binary(
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllPassesAndDialects",
], ],
) )

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
namespace mlir { namespace mlir {
namespace TFL { namespace TFL {
@ -47,12 +48,18 @@ void ConvertMlirQuantOpsToTFLQuantOps(FuncOp func) {
auto dcast = b.create<DequantizeOp>(dq.getLoc(), dq.getResult().getType(), auto dcast = b.create<DequantizeOp>(dq.getLoc(), dq.getResult().getType(),
dq.arg()); dq.arg());
dq.getResult().replaceAllUsesWith(dcast); dq.getResult().replaceAllUsesWith(dcast);
if (auto extra_attr = op->getAttr(mlir::quant::kVolatileOpAttrName)) {
dcast.setAttr(mlir::quant::kVolatileOpAttrName, extra_attr);
}
dq.erase(); dq.erase();
} else if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(op)) { } else if (auto q = llvm::dyn_cast<quant::QuantizeCastOp>(op)) {
auto out_type = q.getResult().getType(); auto out_type = q.getResult().getType();
auto qcast = b.create<QuantizeOp>(q.getLoc(), out_type, q.arg(), auto qcast = b.create<QuantizeOp>(q.getLoc(), out_type, q.arg(),
TypeAttr::get(out_type)); TypeAttr::get(out_type));
q.getResult().replaceAllUsesWith(qcast); q.getResult().replaceAllUsesWith(qcast);
if (auto extra_attr = op->getAttr(mlir::quant::kVolatileOpAttrName)) {
qcast.setAttr(mlir::quant::kVolatileOpAttrName, extra_attr);
}
q.erase(); q.erase();
} }
}); });

View File

@ -63,6 +63,22 @@ def QI32 : QuantizedType<"Uniform", [32], 1>;
// https://www.tensorflow.org/lite/performance/quantization_spec // https://www.tensorflow.org/lite/performance/quantization_spec
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TODO(b/157870442): replace all FixedResultScale trait
def FixedOutputRangeInterface : OpInterface<
"FixedOutputRangeInterface"> {
let description = [{
Interface for defining the fixed output range.
}];
let methods = [
InterfaceMethod<
[{Returns the fixed output range.}],
"UniformQuantizedType", "GetFixedOutputRange",
(ins "bool":$sign, "int":$bit_width)
>,
];
}
// Specify this trait if the op has a fixed output value range. // Specify this trait if the op has a fixed output value range.
class FixedResultScale<QuantizedType qt> : NativeOpTrait<!strconcat( class FixedResultScale<QuantizedType qt> : NativeOpTrait<!strconcat(
"quant::FixedResult", qt.name, "Scale<", qt.asTraitArgsStr, ">::Impl")>; "quant::FixedResult", qt.name, "Scale<", qt.asTraitArgsStr, ">::Impl")>;

View File

@ -45,7 +45,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
absl::string_view inference_type, absl::string_view inference_type,
QuantizationSpecs* quant_specs) { QuantizationSpecs* quant_specs) {
std::vector<std::string> input_nodes = absl::StrSplit(node_names, ','); std::vector<std::string> input_nodes = absl::StrSplit(node_names, ',');
std::vector<double> node_mins; std::vector<llvm::Optional<double>> node_mins;
if (!min_values.empty()) { if (!min_values.empty()) {
std::vector<std::string> node_mins_str = absl::StrSplit(min_values, ','); std::vector<std::string> node_mins_str = absl::StrSplit(min_values, ',');
for (int i = 0; i < node_mins_str.size(); i++) { for (int i = 0; i < node_mins_str.size(); i++) {
@ -57,7 +57,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
} }
} }
std::vector<double> node_maxs; std::vector<llvm::Optional<double>> node_maxs;
if (!max_values.empty()) { if (!max_values.empty()) {
std::vector<std::string> node_maxs_str = absl::StrSplit(max_values, ','); std::vector<std::string> node_maxs_str = absl::StrSplit(max_values, ',');
for (int i = 0; i < node_maxs_str.size(); i++) { for (int i = 0; i < node_maxs_str.size(); i++) {
@ -79,11 +79,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
quant_specs); quant_specs);
} }
bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names, bool GetInputNodeQuantSpecs(
const std::vector<double>& node_mins, const std::vector<std::string>& node_names,
const std::vector<double>& node_maxs, const std::vector<llvm::Optional<double>>& node_mins,
tensorflow::DataType inference_type, const std::vector<llvm::Optional<double>>& node_maxs,
QuantizationSpecs* quant_specs) { tensorflow::DataType inference_type, QuantizationSpecs* quant_specs) {
quant_specs->inference_type = inference_type; quant_specs->inference_type = inference_type;
// If min/max are not specified, just return; // If min/max are not specified, just return;

View File

@ -69,7 +69,8 @@ struct QuantizationSpecs {
// arguments. They are only used when `weight_quantization` is set to false, // arguments. They are only used when `weight_quantization` is set to false,
// and the model is required to have quantization parameters, either from // and the model is required to have quantization parameters, either from
// quantization aware training or calibration, for the remaining tensors. // quantization aware training or calibration, for the remaining tensors.
std::vector<std::pair<double, double>> input_ranges; std::vector<std::pair<llvm::Optional<double>, llvm::Optional<double>>>
input_ranges;
// The default ranges can be used when a tensor doesn't have quantization // The default ranges can be used when a tensor doesn't have quantization
// parameters and couldn't be quantized. Used only for latency tests. // parameters and couldn't be quantized. Used only for latency tests.
@ -130,11 +131,11 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names,
// Gets the quantization specification for input arrays. The array names are not // Gets the quantization specification for input arrays. The array names are not
// stored in the spec, and will be matched by position. The min/max will be // stored in the spec, and will be matched by position. The min/max will be
// ignored if the inference_type isn't a quantized type. Returns true if failed. // ignored if the inference_type isn't a quantized type. Returns true if failed.
bool GetInputNodeQuantSpecs(const std::vector<std::string>& node_names, bool GetInputNodeQuantSpecs(
const std::vector<double>& node_mins, const std::vector<std::string>& node_names,
const std::vector<double>& node_maxs, const std::vector<llvm::Optional<double>>& node_mins,
tensorflow::DataType inference_type, const std::vector<llvm::Optional<double>>& node_maxs,
QuantizationSpecs* quant_specs); tensorflow::DataType inference_type, QuantizationSpecs* quant_specs);
} // namespace TFL } // namespace TFL
} // namespace mlir } // namespace mlir

View File

@ -494,6 +494,13 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
auto quantize = builder_.create<quant::QuantizeCastOp>(loc, new_type, value); auto quantize = builder_.create<quant::QuantizeCastOp>(loc, new_type, value);
auto dequantize = builder_.create<quant::DequantizeCastOp>( auto dequantize = builder_.create<quant::DequantizeCastOp>(
loc, expressed_type, quantize.getResult()); loc, expressed_type, quantize.getResult());
// This attribute is set to distinguish the quantize ops being added by the
// quantization pass. These ops can be removed without losing original
// program accuracy.
// TODO(fengliuai): make the attribute being part of op definition.
quantize.setAttr(kVolatileOpAttrName, builder_.getUnitAttr());
// `original_result` has a use to `quantize`, so this will replace that use // `original_result` has a use to `quantize`, so this will replace that use
// by the result of `dequantize`. Remember to reset that use afterwards // by the result of `dequantize`. Remember to reset that use afterwards
value.replaceAllUsesWith(dequantize); value.replaceAllUsesWith(dequantize);
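A toy standalone sketch of the volatile-marker pattern introduced above: ops inserted by the quantization pass are tagged so a later cleanup can erase them without affecting the original program's numerics. In the real pass only the quantize op carries the attribute (and it is copied onto the TFL cast ops later); here both inserted nodes are tagged to keep the cleanup trivial.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy node with boolean attributes, standing in for an MLIR op.
struct Node {
  std::string name;
  std::map<std::string, bool> attrs;
};

// "Quantization pass": insert a quantize/dequantize pair and tag it volatile,
// mirroring quantize.setAttr(kVolatileOpAttrName, ...) above.
void InsertQdqPair(std::vector<Node>& graph) {
  graph.push_back({"quant.qcast", {{"volatile", true}}});
  graph.push_back({"quant.dcast", {{"volatile", true}}});
}

// Later cleanup: nodes tagged volatile were added by the pass itself, so they
// can be erased without changing the original program's accuracy.
void EraseVolatile(std::vector<Node>& graph) {
  std::vector<Node> kept;
  for (const auto& node : graph) {
    auto it = node.attrs.find("volatile");
    if (it != node.attrs.end() && it->second) continue;
    kept.push_back(node);
  }
  graph.swap(kept);
}

int main() {
  std::vector<Node> graph = {{"tfl.conv_2d", {}}};
  InsertQdqPair(graph);
  EraseVolatile(graph);
  for (const auto& node : graph) std::cout << node.name << '\n';  // tfl.conv_2d
  return 0;
}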

View File

@ -21,13 +21,18 @@ limitations under the License.
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project
namespace mlir {
namespace OpTrait {
namespace quant {
using QuantizedType = mlir::quant::QuantizedType; using QuantizedType = mlir::quant::QuantizedType;
using UniformQuantizedType = mlir::quant::UniformQuantizedType; using UniformQuantizedType = mlir::quant::UniformQuantizedType;
namespace mlir {
// This includes the interface class definition. It couldn't be in a namespace
// because the table gen doesn't emit the namespace when it is used.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_interface.h.inc"
namespace OpTrait {
namespace quant {
// The base class that all the quantization related OpTrait implements. // The base class that all the quantization related OpTrait implements.
template <typename ConcreteType, template <typename> class TraitType> template <typename ConcreteType, template <typename> class TraitType>
struct QuantizationSpecTraitBase : public TraitBase<ConcreteType, TraitType> { struct QuantizationSpecTraitBase : public TraitBase<ConcreteType, TraitType> {

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
@ -436,6 +437,16 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
llvm::SmallVector<quant::StatisticsOp, 16> all_stats_ops; llvm::SmallVector<quant::StatisticsOp, 16> all_stats_ops;
llvm::DenseSet<Operation*> redundant_stats_ops; llvm::DenseSet<Operation*> redundant_stats_ops;
// Step 0: remove any quant::StatisticsOp which is used by a tfl.quantize
// op, in case it overrides the information from training FakeQuant ops.
func.walk([&](quant::QuantizeCastOp q) {
auto input_op = q.arg().getDefiningOp();
if (auto stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(input_op)) {
q.setOperand(stats.arg());
if (stats.use_empty()) stats.erase();
}
});
// Step 1: forward pass: propagate any value scales which are not produces // Step 1: forward pass: propagate any value scales which are not produces
// by `SameOperandsAndResultsScale`. Additionally, remove the value scales // by `SameOperandsAndResultsScale`. Additionally, remove the value scales
// which are produced by the `restricted_output_params`. // which are produced by the `restricted_output_params`.

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h" #include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
@ -42,6 +43,11 @@ limitations under the License.
namespace mlir { namespace mlir {
namespace quant { namespace quant {
// A unit attribute can be attached to the quantize/dequantize ops which are
// added by the quantization passes. These ops can be erased without
// losing accuracy.
constexpr char kVolatileOpAttrName[] = "volatile";
using QuantParams = quant::QuantizedType; using QuantParams = quant::QuantizedType;
using SignedInteger = std::pair<unsigned, unsigned>; // bitwidth and sign using SignedInteger = std::pair<unsigned, unsigned>; // bitwidth and sign
using QuantParamsForResults = llvm::SmallVector<QuantParams, 4>; using QuantParamsForResults = llvm::SmallVector<QuantParams, 4>;
@ -380,7 +386,8 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
Operation* def = pre_quantized.getDefiningOp(); Operation* def = pre_quantized.getDefiningOp();
if (!def) return failure(); if (!def) return failure();
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() || if (llvm::isa<FixedOutputRangeInterface>(def) ||
def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() ||
def->hasTrait<OpTrait::quant::NoQuantizableResult>()) { def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
return failure(); return failure();
} }

View File

@ -27,7 +27,7 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -quant-import-stats --quant-test-stats='entries { name: "op" params { min_max { min: -1 max: 1 } } } entries { name: "op_0:0" params { min_max { min: -2 max: 2 } } } entries { name_regex: "op_*" params { min_max { min: -3 max: 3 } } }' | FileCheck %s --dump-input-on-failure // RUN: tf-opt %s -quant-import-stats --quant-test-stats='entries { name: "op" params { min_max { min: -1 max: 1 } } } entries { name: "op_0:0" params { min_max { min: -2 max: 2 } } } entries { name_regex: "op_*" params { min_max { min: -3 max: 3 } } }' | FileCheck %s
// CHECK-LABEL: import_stats_skip // CHECK-LABEL: import_stats_skip

View File

@ -32,7 +32,7 @@ cc_library(
"//tensorflow/lite/core/api", "//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
], ],

View File

@ -1,4 +1,4 @@
// RUN: tf-opt -pass-pipeline='func(canonicalize)' %s | FileCheck %s --dump-input-on-failure // RUN: tf-opt -pass-pipeline='func(canonicalize)' %s | FileCheck %s
// Checks that tfl.reshape should be removed if its output's only user is // Checks that tfl.reshape should be removed if its output's only user is
// another tfl.reshape // another tfl.reshape

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -canonicalize | FileCheck %s --dump-input-on-failure // RUN: tf-opt %s -canonicalize | FileCheck %s
// CHECK-LABEL: @add_float // CHECK-LABEL: @add_float
func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) { func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -tfl-identify-dilated-conv | FileCheck %s --dump-input-on-failure // RUN: tf-opt %s -tfl-identify-dilated-conv | FileCheck %s
func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> { func @testDilatedConv(%arg0: tensor<1x128x128x3xf32>, %arg1: tensor<5x5x3x8xf32>) -> tensor<1x128x128x8xf32> {
%cst = constant dense<[2, 2]> : tensor<2xi32> %cst = constant dense<[2, 2]> : tensor<2xi32>

View File

@ -1,4 +1,4 @@
# RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s # RUN: tf_tfl_translate -tf-input-arrays=input0,input1 -tf-input-shapes=4:4 -tf-input-data-types=DT_INT32,DT_INT32 -tf-output-arrays=Add %s -o - | flatbuffer_to_string - | FileCheck %s
# Add two tensor<4xi32> inputs and return the result # Add two tensor<4xi32> inputs and return the result

File diff suppressed because it is too large

View File

@ -1,4 +1,4 @@
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,1,1,256 -tf-input-data-types=DT_FLOAT -tf-inference-type=DT_QINT8 -tf-input-min-values='-33.614346' -tf-input-max-values='21.54917' -tf-output-arrays=output %s -o - --output-mlir 2>&1 | FileCheck --check-prefix=MLIR %s --dump-input-on-failure # RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,1,1,256 -tf-input-data-types=DT_FLOAT -tf-inference-type=DT_QINT8 -tf-input-min-values='-33.614346' -tf-input-max-values='21.54917' -tf-output-arrays=output %s -o - --output-mlir 2>&1 | FileCheck --check-prefix=MLIR %s
# RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,1,1,256 -tf-input-data-types=DT_FLOAT -tf-inference-type=DT_QINT8 -tf-input-min-values='-33.614346' -tf-input-max-values='21.54917' -tf-output-arrays=output %s -o - | flatbuffer_to_string - | FileCheck %s # RUN: tf_tfl_translate -tf-input-arrays=input -tf-input-shapes=1,1,1,256 -tf-input-data-types=DT_FLOAT -tf-inference-type=DT_QINT8 -tf-input-min-values='-33.614346' -tf-input-max-values='21.54917' -tf-output-arrays=output %s -o - | flatbuffer_to_string - | FileCheck %s
node { node {

View File

@ -1,4 +1,4 @@
# RUN: tf_tfl_translate -tf-input-arrays=unranked -tf-input-shapes=1,8,8,2 -tf-input-data-types=DT_INT32 -tf-output-arrays=unranked,static,static_10 %s -o - --output-mlir | FileCheck %s --dump-input-on-failure # RUN: tf_tfl_translate -tf-input-arrays=unranked -tf-input-shapes=1,8,8,2 -tf-input-data-types=DT_INT32 -tf-output-arrays=unranked,static,static_10 %s -o - --output-mlir | FileCheck %s
node { node {
name: "tf.Const" name: "tf.Const"

View File

@ -0,0 +1,421 @@
# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=4:4 -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s
node {
name: "tf.Less"
op: "Less"
input: "a"
input: "b"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
node {
name: "my_equal"
op: "Equal"
input: "a"
input: "b"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
node {
name: "cst0"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 4
}
}
float_val: 1.0
float_val: 2.0
float_val: 3.0
float_val: 4.0
}
}
}
}
node {
name: "cst1"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 4
}
}
float_val: 5.0
float_val: 6.0
float_val: 7.0
float_val: 8.0
}
}
}
}
node {
name: "StatefulIf"
op: "If"
input: "tf.Less"
input: "a"
input: "b"
input: "cst0"
input: "cst1"
attr {
key: "Tcond"
value {
type: DT_BOOL
}
}
attr {
key: "Tin"
value {
list {
type: DT_FLOAT
type: DT_FLOAT
type: DT_FLOAT
type: DT_FLOAT
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_FLOAT
}
}
}
attr {
key: "else_branch"
value {
func {
name: "cond_false"
}
}
}
attr {
key: "then_branch"
value {
func {
name: "cond_true"
}
}
}
experimental_debug_info {
}
}
node {
name: "StatelessIf"
op: "StatelessIf"
input: "my_equal"
input: "a"
input: "b"
attr {
key: "Tcond"
value {
type: DT_BOOL
}
}
attr {
key: "Tin"
value {
list {
type: DT_FLOAT
type: DT_FLOAT
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_FLOAT
}
}
}
attr {
key: "else_branch"
value {
func {
name: "cond_false_1"
}
}
}
attr {
key: "then_branch"
value {
func {
name: "cond_true_1"
}
}
}
experimental_debug_info {
}
}
node {
name: "main"
op: "_Retval"
input: "StatefulIf"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "index"
value {
i: 0
}
}
}
node {
name: "main1"
op: "_Retval"
input: "StatelessIf"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "index"
value {
i: 1
}
}
}
node {
name: "a"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
node {
name: "b"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
}
}
library {
function {
signature {
name: "cond_true"
input_arg {
name: "cond_true_arg0"
type: DT_FLOAT
}
input_arg {
name: "cond_true_arg1"
type: DT_FLOAT
}
input_arg {
name: "cond_true_arg2"
type: DT_FLOAT
}
input_arg {
name: "cond_true_arg3"
type: DT_FLOAT
}
output_arg {
name: "cond_true_ret"
type: DT_FLOAT
}
}
node_def {
name: "tf.Add"
op: "Add"
input: "cond_true_arg2"
input: "cond_true_arg3"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "tf.Add"
}
}
ret {
key: "cond_true_ret"
value: "tf.Add:z:0"
}
}
function {
signature {
name: "cond_false"
input_arg {
name: "cond_false_arg0"
type: DT_FLOAT
}
input_arg {
name: "cond_false_arg1"
type: DT_FLOAT
}
input_arg {
name: "cond_false_arg2"
type: DT_FLOAT
}
input_arg {
name: "cond_false_arg3"
type: DT_FLOAT
}
output_arg {
name: "cond_false_ret"
type: DT_FLOAT
}
}
node_def {
name: "tf.Mul"
op: "Mul"
input: "cond_false_arg0"
input: "cond_false_arg3"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "tf.Mul"
}
}
ret {
key: "cond_false_ret"
value: "tf.Mul:z:0"
}
}
function {
signature {
name: "cond_true_1"
input_arg {
name: "cond_true_arg0"
type: DT_FLOAT
}
input_arg {
name: "cond_true_arg1"
type: DT_FLOAT
}
output_arg {
name: "cond_true_ret"
type: DT_FLOAT
}
}
node_def {
name: "tf.Sub"
op: "Sub"
input: "cond_true_arg0"
input: "cond_true_arg1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "tf.Sub"
}
}
ret {
key: "cond_true_ret"
value: "tf.Sub:z:0"
}
}
function {
signature {
name: "cond_false_1"
input_arg {
name: "cond_false_arg0"
type: DT_FLOAT
}
input_arg {
name: "cond_false_arg1"
type: DT_FLOAT
}
output_arg {
name: "cond_false_ret"
type: DT_FLOAT
}
}
node_def {
name: "tf.Div"
op: "Div"
input: "cond_false_arg0"
input: "cond_false_arg1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "tf.Div"
}
}
ret {
key: "cond_false_ret"
value: "tf.Div:z:0"
}
}
}
versions {
producer: 115
min_consumer: 12
}
# CHECK: func @StatefulIf_else
# CHECK-NEXT: constant dense<[5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]>
# CHECK-NEXT: tfl.mul
# CHECK: func @StatefulIf_then
# CHECK-NEXT: constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
# CHECK-NEXT: return
# CHECK: func @StatelessIf_else
# CHECK-NEXT: tfl.div
# CHECK: func @StatelessIf_then
# CHECK-NEXT: tfl.sub
# CHECK: "tf.If"{{.+}}else_branch = @StatelessIf_else{{.+}}then_branch = @StatelessIf_then
# CHECK: "tf.If"{{.+}}else_branch = @StatefulIf_else{{.+}}then_branch = @StatefulIf_then

View File

@ -54,7 +54,7 @@ tf_native_cc_binary(
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
], ],
) )
@ -70,6 +70,6 @@ tf_native_cc_binary(
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
], ],
) )

View File

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Ensure basic_lstm roundtrip exactly // Ensure basic_lstm roundtrip exactly
func @main(%arg0: tensor<1x384xf32>, %arg1: tensor<1x96xf32>, %arg2: tensor<384x480xf32>, %arg3: tensor<384xf32>, %arg4: tensor<1x96xf32>) -> tensor<1x96xf32> { func @main(%arg0: tensor<1x384xf32>, %arg1: tensor<1x96xf32>, %arg2: tensor<384x480xf32>, %arg3: tensor<384xf32>, %arg4: tensor<1x96xf32>) -> tensor<1x96xf32> {

View File

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// Ensure constants roundtrip exactly // Ensure constants roundtrip exactly
func @bool() -> tensor<4xi1> { func @bool() -> tensor<4xi1> {

View File

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> { func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> {
%0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "Convolution2DTransposeBias", custom_option = opaque<"tfl", "0x010000000200000002000000"> : tensor<12xi8>} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32> %0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "Convolution2DTransposeBias", custom_option = opaque<"tfl", "0x010000000200000002000000"> : tensor<12xi8>} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32>

View File

@ -1,4 +1,4 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
// CHECK: func @main(%arg0: tensor<?x19x19x3xf32>) -> tensor<?x9x9x4xf32> // CHECK: func @main(%arg0: tensor<?x19x19x3xf32>) -> tensor<?x9x9x4xf32>
func @main(%arg0: tensor<?x19x19x3xf32>) -> tensor<?x9x9x4xf32> { func @main(%arg0: tensor<?x19x19x3xf32>) -> tensor<?x9x9x4xf32> {

View File

@@ -1,4 +1,4 @@
-// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir --use-external-constant - -o - | FileCheck --dump-input-on-failure %s
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir --use-external-constant - -o - | FileCheck %s
 // Ensure that `tfl.external_const` is imported when the flag `-use-external-constant` is enabled.
 func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {

View File

@@ -1,4 +1,4 @@
-// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
 // Confirm function references in if ops are preserved
 func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
 // CHECK: %{{.*}} = "tf.If"(%{{.*}}, %{{.*}}, %{{.*}}) {else_branch = @cond_false, is_stateless = false, then_branch = @cond_true} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>

View File

@@ -1,4 +1,4 @@
-// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s
+// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s
 // CHECK: %cst = constant unit
 // CHECK: %[[RES0:.*]] = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 0 : i32, stride_w = 0 : i32} : (tensor<256x32x32x3xf32>, tensor<16x3x3x3xf32>, none) -> tensor<256x32x32x16xf32>

View File

@@ -1,4 +1,4 @@
-// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
 // This only tests that the exporter and importer are working without min/max quantization parameters.
 func @main(tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32> {

View File

@@ -1,4 +1,4 @@
-// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -input-arrays=squared_difference --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate -input-arrays=squared_difference --experimental-prune-unreachable-nodes-unconditionally --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
 // Tests -input-arrays flag.
 func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {

View File

@@ -1,4 +1,4 @@
-// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
 // Tests that input and output names from FlatBuffer are added to the `tf.entry_function` attribute.

View File

@@ -1,15 +1,15 @@
-// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
 // Ensure lstm roundtrip exactly
-func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
+func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> {
 %cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
 %cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const")
-%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
+%24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
 return %24 : tensor<1x4xf32>
 // CHECK-LABEL: main
 // separate lines since there is no region for this op. third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252
 // CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( {
-// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
+// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
 // CHECK: return %[[RES0]]
 }

View File

@@ -1,4 +1,4 @@
-// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
 // Confirm a wide array of attributes survives the round-trip
 func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> {

View File

@@ -1,4 +1,4 @@
-// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
 // Confirm float constants and operators survive a roundtrip
 func @main(tensor<4xf32>) -> tensor<4xf32> {

View File

@@ -1,4 +1,4 @@
-// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s
+// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s
 // Test to make sure optional parameters survive a roundtrip
 func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {

View File

@@ -1,4 +1,4 @@
-// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck --dump-input-on-failure %s
+// RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir -o - | FileCheck %s
 // This test checks that if the flatbuffer omits the last optional input `bias` of the tfl.conv_2d op, the flatbuffer_importer automatically adds a `none` value to tfl.conv_2d.

Some files were not shown because too many files have changed in this diff.