Compare commits


53 Commits

Author SHA1 Message Date
27a1657c4f Add rpi4ub-armv8 build configuration 2021-12-04 15:37:18 +00:00
Reuben Morais
4bdd395511 Revert "Move native_client specific WORKSPACE changes to root WORKSPACE file"
This reverts commit 9b67f161e5.
2021-09-21 16:39:29 +02:00
Reuben Morais
9173bdff3d Streamline Android build options 2021-09-21 15:54:32 +02:00
Reuben Morais
182e869fb8 Streamline RPi3 and RPi3-ARMv8 build options 2021-09-21 15:05:53 +02:00
Reuben Morais
dc74c09b8d
Merge pull request #4 from coqui-ai/pull_request_template 2021-08-03 15:49:04 +02:00
kdavis-coqui
90f5b25508 Added pull_request_template 2021-08-02 18:30:12 +02:00
Reuben Morais
9b67f161e5 Move native_client specific WORKSPACE changes to root WORKSPACE file 2021-07-23 16:21:28 +02:00
Reuben Morais
811608d4e4 Fix broken mirror link for AWS SDK 2021-07-23 16:02:00 +02:00
Reuben Morais
ca5d9fdf5c Add sigmoid and tanh to deepspeech_cwise_ops 2021-01-02 17:17:29 +00:00
lissyx
23ad988fcd
Merge pull request #124 from bernardohenz/layer-norm
Adding dependencies for layer normalization
2020-08-25 13:18:09 +02:00
Bernardo Henz
6dc2a1becf Adding dependencies for layer normalization 2020-08-19 13:05:58 -03:00
lissyx
4336a5b49f
Merge pull request #123 from lissyx/r2.3+moz
Mozilla TensorFlow r2.3
2020-08-06 22:39:47 +02:00
Alexandre Lissy
6fad14b203 Mozilla TensorFlow r2.3 2020-08-06 19:19:26 +02:00
Mihai Maruseac
ee598066c4
Merge pull request #41837 from tensorflow/ggadde-cp2-2-3
Update the release notes to fix some typos and missed changes.
2020-07-28 20:45:39 +00:00
Goldie Gadde
fd3b3ca6f5 Update the release notes to fix some typos and missed changes. 2020-07-28 13:34:56 -07:00
Goldie Gadde
b36436b087
Merge pull request #41670 from tensorflow-jenkins/version-numbers-2.3.0-24378
Update version numbers for TensorFlow 2.3.0
2020-07-23 17:09:13 -07:00
Mihai Maruseac
c4b2951888
Merge pull request #40701 from tensorflow-jenkins/relnotes-2.3.0rc0-13382
Update release notes for TensorFlow 2.3.0
2020-07-23 22:33:42 +00:00
TensorFlow Release Automation
82515cee8f Update version numbers to 2.3.0 2020-07-23 11:05:30 -07:00
Mihai Maruseac
ab9c694484
Merge pull request #41644 from tensorflow/update_v
Updating estimator version after estimator final release
2020-07-23 16:13:58 +00:00
Geeta Chavan
1f610fc5ae Updating version for final release 2020-07-22 17:44:09 -07:00
Geeta Chavan
44e3817ad0 Updating version for final release 2020-07-22 17:39:39 -07:00
Goldie Gadde
dbbdcde0fd
Update RELEASE.md 2020-07-22 14:21:11 -07:00
Goldie Gadde
ca2b7ba75c
Merge pull request #41636 from geetachavan1/cherrypicks_F4YCK
[CherryPick 2.3] Going back to forcing embedding layer variables on the CPU even within a tf.function as this is breaking some user code.
2020-07-22 14:12:10 -07:00
E
67ab71d747
Merge pull request #41633 from psybuzz/cherrypicks_WRXIO
This cherrypick is required to depend on a version of TensorBoard that
is compatible with TensorFlow 2.3.x, such that `tensorboard` starts
successfully.

Note: starting from TF/TB 2.3, the tensorboard dependency is more
relaxed to allow TensorBoard to be released at a more frequent cadence.
The major version is still synced, while the minor version is not.

PiperOrigin-RevId: 322474571
2020-07-22 12:47:19 -07:00
Rohan Jain
d60d7d3c7e Going back to forcing embedding layer variables on the CPU even within a tf.function as this is breaking some user code.
PiperOrigin-RevId: 321607029
Change-Id: Id159867f51b26e6604a1186d9ce526658ddd1e19
2020-07-22 11:45:58 -07:00
A. Unique TensorFlower
2ef4243a20 Update tensorboard dependency to 2.3.x
TensorBoard release: https://pypi.org/project/tensorboard/2.3.0/

PiperOrigin-RevId: 322474571
Change-Id: I7cb6bcbb101cb9b10d04b832279e62ea9066abca
2020-07-22 10:43:37 -07:00
Austin Anderson
f9233753f3
Reword tf.sysconfig.get_build_info notice
An internal bug report revealed that "support" could be mistaken for "official support," which is not intended.
2020-07-21 16:04:26 -07:00
Goldie Gadde
bb3c460114
Merge pull request #41498 from tensorflow-jenkins/version-numbers-2.3.0rc2-11909
Update version numbers for TensorFlow 2.3.0-rc2
2020-07-17 12:45:59 -07:00
Goldie Gadde
f923fa474b
Merge pull request #41483 from tomerk/cherrypicks_DC0YA
Raise an error when some but not all values passed to the first layer call arg are symbolic
2020-07-17 12:07:45 -07:00
Mihai Maruseac
8edf01fbdc
Merge pull request #41357 from kenfranko/cherrypicks_4ECL0
2.3-rc2 cherry-pick request: Correctly set the experimental_io_device when restoring variable from a checkpoint.
2020-07-17 18:48:31 +00:00
Mihai Maruseac
557fdcfcc9
Merge pull request #41464 from aaudiber/cherrypicks_DM5NO
Update "master" to "dispatch"/"dispatcher" in tf.data service terminology
2020-07-17 18:48:17 +00:00
TensorFlow Release Automation
f989a28561 Update version numbers to 2.3.0-rc2 2020-07-17 09:34:18 -07:00
Goldie Gadde
89bb4c3f42
Merge pull request #41385 from tomerk/cherrypicks_ORV3K
[Cherrypick:r2.3] Explicitly raise a (clearer) error message when models end up in invalid states
2020-07-17 09:20:32 -07:00
Goldie Gadde
b8c74daee7
Merge pull request #41426 from geetachavan1/cherrypicks_MI0K1
[Cherrypick r2.3] Fix critical bug with `add_loss` TFOpLayer graph construction that caused incorrect loss values and backprop issues
2020-07-17 09:16:51 -07:00
Tomer Kaftan
7e0abd9c89 Relax the error about functional API construction with a mix of symbolic and non-symbolic tensors for built-in layers (such as layers.add and layers.multiply, where using constants is a common user pattern)
PiperOrigin-RevId: 321698209
Change-Id: Ief13e59aec91b787361a7760318ecd47870d938f
2020-07-16 21:40:28 -07:00
Tomer Kaftan
7d310df2de Raise an error when some but not all values passed to the first layer call arg are symbolic. This setting can cause functional models to be constructed incorrectly.
Support for this will be added when we enable the KerasTensors refactoring.

Addresses GitHub Issue #40638

PiperOrigin-RevId: 321639068
Change-Id: Iebf0e1198018fe44b1f60673bd991a9262ecef7d
2020-07-16 21:40:28 -07:00
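
Taken together, the two commits above adjust which mixes of symbolic and non-symbolic tensors the Keras functional API accepts. A minimal, hedged sketch of the resulting behavior (assuming a TF 2.3 environment; the model below is illustrative, not taken from the commits):

```python
import tensorflow as tf

# Symbolic Keras input mixed with a plain constant: the common pattern that
# stays supported for built-in merge layers such as layers.add/layers.multiply.
inp = tf.keras.Input(shape=(4,))
out = tf.keras.layers.add([inp, tf.ones((4,))])
model = tf.keras.Model(inp, out)

# By contrast, passing a partially symbolic mix to an arbitrary layer's first
# call argument is the case that now raises an explicit error instead of
# silently building a broken functional model.
```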
Andrew Audibert
6b9a9d98bb Update "master" to "dispatch"/"dispatcher" in tf.data service terminology.
Dispatcher is more descriptive and follows the guidance in https://developers.google.com/style/word-list#master

PiperOrigin-RevId: 321613785
Change-Id: Iaa576d35f0581e21278101f8b31201ba737a6865
2020-07-16 13:07:28 -07:00
Mihai Maruseac
0aa1d61fad
Merge pull request #41435 from tensorflow/Fix-Estimator-min-bound
Fix typo in estimator min bound
2020-07-16 16:28:02 +00:00
Mihai Maruseac
0b2321fdd1
Fix typo in estimator min bound
Should be `2.3.0rc0`, not `2.3.0-rc0`
2020-07-15 17:25:43 -07:00
Goldie Gadde
2b03d7b7a0
Update RELEASE.md 2020-07-15 17:22:49 -07:00
Goldie Gadde
549064075e
Update RELEASE.md 2020-07-15 14:58:17 -07:00
Francois Chollet
d03e29a094 Fix critical bug with add_loss TFOpLayer graph construction
that caused incorrect loss values and backprop issues

PiperOrigin-RevId: 320257330
Change-Id: I0a030bc7632735b152454657fd15e41539b4e4bd
2020-07-15 12:18:38 -07:00
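
For context, a hedged sketch of the `add_loss` TFOpLayer pattern this fix targets (the model and loss are hypothetical):

```python
import tensorflow as tf

inp = tf.keras.Input(shape=(10,))
x = tf.keras.layers.Dense(1)(inp)
model = tf.keras.Model(inp, x)

# tf.reduce_mean(x) on a symbolic Keras tensor goes through the TFOpLayer
# graph-construction path whose loss values and backprop this commit repairs.
model.add_loss(tf.reduce_mean(x))
model.compile(optimizer="sgd", loss="mse")
```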
Tomer Kaftan
ef4db27b31 Explicitly raise a (clearer) error message when models end up in invalid states due to interleaving graph and eager.
In rare cases code may have run w/o crashing when in these invalid states, but it's safer to error with an explanation rather than risk silent failures/fragile behavior.

PiperOrigin-RevId: 321192744
Change-Id: I9e97ac3b7cea27c9b389e5202de9f1c09a4aa2b8
2020-07-14 11:45:20 -07:00
Shanqing Cai
13c4eadd25
Grammar tweaks in the Debugger V2 bullet point 2020-07-13 21:51:43 -04:00
Shanqing Cai
de4c4425b7
Add mentioned Debugger V2 to r2.3 release notes 2020-07-13 21:47:19 -04:00
Ken Franko
b8694e39d8 Correctly set the experimental_io_device when restoring variable from a checkpoint.
PiperOrigin-RevId: 320222381
Change-Id: I30187c7777ab8056e48004ef5e4ae747edc32227
2020-07-13 15:50:48 -07:00
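
A minimal sketch of the option this commit makes restore honor, assuming the `tf.saved_model.LoadOptions` API described in the r2.3 release notes below (the path is hypothetical):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
tf.saved_model.save(model, "/tmp/model")  # hypothetical path

# Route variable/checkpoint I/O through a specific device while restoring.
opts = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost")
restored = tf.saved_model.load("/tmp/model", options=opts)
```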
Goldie Gadde
98a59288c8
Update RELEASE.md 2020-07-08 18:22:59 -07:00
Goldie Gadde
257447e193
Update RELEASE.md 2020-07-06 15:23:52 -07:00
Austin Anderson
61b2024a19
Added point about tf.sysconfig.get_build_info() 2020-07-06 14:24:16 -07:00
Goldie Gadde
d3dc6a2071
Update RELEASE.md 2020-06-26 15:49:19 -07:00
Goldie Gadde
9310f2a180
Update RELEASE.md 2020-06-26 15:41:40 -07:00
Goldie Gadde
3b27581629
Update RELEASE.md 2020-06-26 07:38:48 -07:00
TensorFlow Release Automation
b4c95671f2 Insert release notes place-fill 2020-06-22 21:30:45 -07:00
82 changed files with 2385 additions and 403 deletions


@ -94,6 +94,9 @@ build:libc++ --linkopt -fuse-ld=lld
# https://docs.bazel.build/versions/master/user-manual.html#flag--fat_apk_cpu
build:android --crosstool_top=//external:android/crosstool
build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
build:android --copt=-D_GLIBCXX_USE_C99
build:android --cxxopt=-std=c++14
build:android --action_env ANDROID_NDK_API_LEVEL=21
build:android_arm --config=android
build:android_arm --cpu=armeabi-v7a
build:android_arm --fat_apk_cpu=armeabi-v7a
@ -202,6 +205,29 @@ build:sycl_asan --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --co
build:sycl_nodouble --config=sycl
build:sycl_trisycl --define=using_trisycl=true
build --copt=-DTFLITE_WITH_RUY_GEMV
build:rpi3 --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
build:rpi3 --crosstool_top=//third_party/toolchains/embedded/linaro-gcc72-armeabi:toolchain
build:rpi3 --cpu=armv7a --define=target_system=rpi3
build:rpi3 --copt=-march=armv7-a --copt=-mtune=cortex-a53 --copt=-mfloat-abi=hard --copt=-mfpu=neon-fp-armv8 --copt=-DRASPBERRY_PI --copt=-D_GLIBCXX_USE_CXX11_ABI=0 --copt=-std=gnu99 --copt=-mno-unaligned-access
build:rpi3 --define=tensorflow_mkldnn_contraction_kernel=0
build:rpi3_opt -c opt --config=rpi3 --copt=-funsafe-math-optimizations --copt=-ftree-vectorize --copt=-pipe
build:rpi3-armv8 --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
build:rpi3-armv8 --crosstool_top=//third_party/toolchains/embedded/linaro-gcc72-aarch64:toolchain
build:rpi3-armv8 --cpu=aarch64 --define=target_system=rpi3-armv8
build:rpi3-armv8 --copt=-march=armv8-a --copt=-mtune=cortex-a53 --copt=-DRASPBERRY_PI --copt=-D_GLIBCXX_USE_CXX11_ABI=0 --copt=-std=gnu99
build:rpi3-armv8 --define=tensorflow_mkldnn_contraction_kernel=0
build:rpi3-armv8_opt -c opt --config=rpi3-armv8 --copt=-funsafe-math-optimizations --copt=-ftree-vectorize --copt=-pipe
build:rpi4ub-armv8 --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
build:rpi4ub-armv8 --crosstool_top=//third_party/toolchains/embedded/linaro-gcc72-aarch64:toolchain
build:rpi4ub-armv8 --cpu=aarch64 --define=target_system=rpi4ub-armv8
build:rpi4ub-armv8 --copt=-march=armv8-a --copt=-mtune=cortex-a72 --copt=-DRASPBERRY_PI --copt=-D_GLIBCXX_USE_CXX11_ABI=0 --copt=-std=gnu99
build:rpi4ub-armv8 --define=tensorflow_mkldnn_contraction_kernel=0
build:rpi4ub-armv8_opt -c opt --config=rpi4ub-armv8 --copt=-funsafe-math-optimizations --copt=-ftree-vectorize --copt=-pipe
# Options extracted from configure script
build:ngraph --define=with_ngraph_support=true
build:numa --define=with_numa_support=true

.github/pull_request_template.md (new vendored file)

@ -0,0 +1,15 @@
# Pull request guidelines
Welcome to the 🐸tensorflow project! We are excited to see your interest, and appreciate your support!
This repository is governed by the Contributor Covenant Code of Conduct. For more details, see the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file.
In order to make a good pull request, please see our [CONTRIBUTING.md](CONTRIBUTING.md) file.
Before accepting your pull request, you will be asked to sign a [Contributor License Agreement](https://cla-assistant.io/coqui-ai/tensorflow).
This [Contributor License Agreement](https://cla-assistant.io/coqui-ai/tensorflow):
- Protects you, Coqui, and the users of the code.
- Does not change your rights to use your contributions for any purpose.
- Does not change the license of the 🐸tensorflow project. It just makes the terms of your contribution clearer and lets us know you are OK to contribute.


@ -1,19 +1,206 @@
# Release 2.3.0
## Breaking Changes
## Major Features and Improvements
* `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and save resources (see the sketch at the end of this section):
* [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot)
* [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service).
`tf.image.extract_glimpse` has been updated to correctly process the case
where `centered=False` and `normalized=False`. This is a breaking change as
the output is different from (incorrect) previous versions. Note this
breaking change only impacts `tf.image.extract_glimpse` and
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of the
existing C++ kernel `ExtractGlimpse` does not change either, so saved
models will not be impacted.
In addition, check out the detailed [guide](https://www.tensorflow.org/guide/data_performance_analysis) for analyzing input pipeline performance with TF Profiler.
* [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) is now a stable API and is no longer considered experimental (formerly `tf.distribute.experimental.TPUStrategy`).
* [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new tools: a memory profiler to visualize your model's memory usage over time and a [Python tracer](https://www.tensorflow.org/guide/profiler#events) which allows you to trace Python function calls in your model. Usability improvements include better diagnostic messages and [profile options](https://tensorflow.org/guide/profiler#collect_performance_data) to customize the host and device trace verbosity level.
* Introduces experimental support for Keras Preprocessing Layers API ([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly)) to handle data preprocessing operations, with support for composite tensor inputs. Please see below for additional details on these layers.
* TFLite now properly supports dynamic shapes during conversion and inference. We've also added opt-in support on Android and iOS for [XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack), a highly optimized set of CPU kernels, as well as opt-in support for [executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental).
* Libtensorflow packages are available in GCS starting this release. We have also started to [release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds).
* The experimental Python API [`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info) now allows you to instrument a TensorFlow program and dump debugging information to a directory on the file system. The directory can be read and visualized by a new interactive dashboard in TensorBoard 2.3 called [Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which reveals the details of the TensorFlow program including graph structures, history of op executions at the Python (eager) and intra-graph levels, the runtime dtype, shape, and numerical composition of tensors, as well as their code locations.
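A minimal sketch of the two new `tf.data` mechanisms above, assuming a tf.data service dispatcher is already running (the snapshot path and service address are hypothetical):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(100).map(lambda x: x * 2)

# Reuse expensive preprocessing across runs by snapshotting it to disk.
ds_snapshot = ds.apply(tf.data.experimental.snapshot("/tmp/snapshot_dir"))

# Offload processing to tf.data service workers via the dispatcher.
ds_distributed = ds.apply(tf.data.experimental.service.distribute(
    processing_mode="parallel_epochs",
    service="grpc://dispatcher-address:5000"))
```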
## Breaking Changes
* Increases the **minimum bazel version** required to build TF to **3.1.0**.
* `tf.data`
* Makes the following (breaking) changes to `tf.data`:
* C++ API: `IteratorBase::RestoreInternal`, `IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState` become pure-virtual and subclasses are now expected to provide an implementation.
* The deprecated `DatasetBase::IsStateful` method is removed in favor of `DatasetBase::CheckExternalState`.
* Deprecated overrides of `DatasetBase::MakeIterator` and `MakeIteratorFromInputElement` are removed.
* The signature of `tensorflow::data::IteratorBase::SaveInternal` and `tensorflow::data::IteratorBase::SaveInput` has been extended with a `SerializationContext` argument to enable overriding the default policy for handling external state during iterator checkpointing. This is not a backwards compatible change and all subclasses of `IteratorBase` *need to be updated* accordingly.
* `tf.keras`
* Add a new `BackupAndRestore` callback for handling distributed training failures & restarts. Please take a look at this [tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for details on how to use the callback.
* `tf.image.extract_glimpse` has been updated to correctly process the case
where `centered=False` and `normalized=False`. This is a breaking change as
the output is different from (incorrect) previous versions. Note this
breaking change only impacts `tf.image.extract_glimpse` and
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
existing C++ kernel `ExtractGlimpse` does not change either, so saved
models using `tf.raw_ops.ExtractGlimpse` will not be impacted.
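A hedged sketch of the affected call: with `centered=False` and `normalized=False`, offsets are now processed correctly, so the output differs from pre-2.3 releases (shapes are illustrative):

```python
import tensorflow as tf

images = tf.random.normal([1, 28, 28, 1])
# In this configuration, offsets are absolute (un-normalized) coordinates.
glimpse = tf.image.extract_glimpse(
    images, size=[8, 8], offsets=[[4.0, 4.0]],
    centered=False, normalized=False)
```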
## Known Caveats
* `tf.lite`
* Keras-based LSTM models must be converted with an explicit batch size in the input layer.
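A minimal sketch of the caveat above: fix the batch size in the input layer before converting (the model shape is illustrative):

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(28, 28), batch_size=1),
    tf.keras.layers.LSTM(20),
    tf.keras.layers.Dense(10),
])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
```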
## Bug Fixes and Other Changes
* Mutable tables now restore checkpointed values when loaded from SavedModel.
### TF Core:
* Set `tf2_behavior` to 1 to enable V2 for early loading cases.
* Add `execute_fn_for_device` function to dynamically choose the implementation based on underlying device placement.
* Eager:
* Add `reduce_logsumexp` benchmark with `experimental_compile`.
* Give `EagerTensor`s a meaningful `__array__` implementation.
* Add another version of defun matmul for performance analysis.
* `tf.function`/AutoGraph:
* `AutoGraph` now includes into TensorFlow loops any variables that are closed over by local functions. Previously, such variables were sometimes incorrectly ignored.
* Functions returned by the `get_concrete_function` method of `tf.function` objects can now be called with arguments consistent with the original arguments or type specs passed to `get_concrete_function` (see the sketch at the end of this section). This calling convention is now the preferred way to use concrete functions with nested values and composite tensors. Please check the [guide](https://www.tensorflow.org/guide/concrete_function) for more details on `concrete_function`.
* Update `tf.function`'s `experimental_relax_shapes` to handle composite tensors appropriately.
* Optimize `tf.function` invocation, by removing redundant list converter.
* `tf.function` will retrace when called with a different variable instead of simply using the `dtype` & `shape`.
* [Improve support](https://github.com/tensorflow/tensorflow/issues/33862) for dynamically-sized TensorArray inside `tf.function`.
* `tf.math`:
* Narrow down `argmin`/`argmax` contract to always return the smallest index for ties.
* `tf.math.reduce_variance` and `tf.math.reduce_std` return correct computation for complex types and no longer support integer types.
* Add Bessel functions of order 0 and 1 to `tf.math.special`.
* `tf.divide` now always returns a tensor to be consistent with documentation and other APIs.
* `tf.image`:
* Replaced [`tf.image.non_max_suppression_padded`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/image/non_max_suppression_padded?hl=en) with a new implementation that supports batched inputs, which is considerably faster on TPUs and GPUs. Boxes with area=0 will be ignored. Existing usage with single inputs should still work as before.
* `tf.linalg`
* Add `tf.linalg.banded_triangular_solve`.
* `tf.random`:
* Add `tf.random.stateless_parameterized_truncated_normal`.
* `tf.ragged`:
* Add `tf.ragged.cross` and `tf.ragged.cross_hashed` operations.
* `tf.RaggedTensor`:
* `RaggedTensor.to_tensor()` now preserves static shape.
* Add `tf.strings.format()` and `tf.print()` to support RaggedTensors.
* `tf.saved_model`:
* `@tf.function` from SavedModel no longer ignores args after a `RaggedTensor` when selecting the concrete function to run.
* Fix save model issue for ops with a list of functions.
* Add `tf.saved_model.LoadOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/LoadOptions?hl=en) as arg with default value `None` to choose the I/O device for loading models and weights.
* Update `tf.saved_model.SaveOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/SaveOptions?hl=en) as arg with default value `None` to choose the I/O device for saving models and weights.
* Mutable tables now restore checkpointed values when loaded from SavedModel.
* GPU
* TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities.
* Others
* Retain parent namescope for ops added inside `tf.while_loop`/`tf.cond`/`tf.switch_case`.
* Update `tf.vectorized_map` to support vectorizing `tf.while_loop` and TensorList operations.
* `tf.custom_gradient` can now be applied to functions that accept nested structures of `tensors` as inputs (instead of just a list of tensors). Note that Python structures such as tuples and lists now won't be treated as tensors, so if you still want them to be treated that way, you need to wrap them with `tf.convert_to_tensor`.
* No lowering on gradient case op when input is `DeviceIndex` op.
* Extend the ragged version of `tf.gather` to support `batch_dims` and `axis` args.
* Update `tf.map_fn` to support RaggedTensors and SparseTensors.
* Deprecate `tf.group`. It is not useful in eager mode.
* Add CPU and GPU implementations of a modified variation of [`FTRL`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/raw_ops/ApplyFtrl)/[`FTRLV2`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/raw_ops/ApplyFtrlV2) that can be triggered by `multiply_linear_by_lr`, allowing a learning rate of zero.
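A minimal sketch of the updated concrete-function calling convention flagged in the `tf.function` bullet above (the function and specs are illustrative):

```python
import tensorflow as tf

@tf.function
def f(data):
    return data["x"] + data["y"]

spec = {"x": tf.TensorSpec([], tf.float32), "y": tf.TensorSpec([], tf.float32)}
cf = f.get_concrete_function(spec)
# The concrete function can now be called with the original nested structure.
print(cf({"x": tf.constant(1.0), "y": tf.constant(2.0)}))  # tf.Tensor(3.0, ...)
```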
### `tf.data`:
* `tf.data.experimental.dense_to_ragged_batch` works correctly with tuples.
* `tf.data.experimental.dense_to_ragged_batch` can now output variable ragged rank.
* `tf.data.experimental.cardinality` is now a method on `tf.data.Dataset`.
* `tf.data.Dataset` now supports `len(Dataset)` when the cardinality is finite.
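A minimal sketch of the two cardinality additions above:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(42).batch(7)
print(ds.cardinality().numpy())  # 6 -- cardinality is now a Dataset method
print(len(ds))                   # 6 -- len() works because cardinality is finite
```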
### `tf.distribute`:
* Expose experimental [`tf.distribute.DistributedDataset`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedDataset?hl=en) and [`tf.distribute.DistributedIterator`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedIterator) to distribute input data when using `tf.distribute` to scale training on multiple devices.
* Added a [`get_next_as_optional`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedIterator?hl=en#get_next_as_optional) method for the [`tf.distribute.DistributedIterator`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedIterator?hl=en) class to return a `tf.experimental.Optional` instance that contains the next value for all replicas, or none, instead of raising an out-of-range error (see the sketch at the end of this section). Also see the *new* [guide on input distribution](https://www.tensorflow.org/tutorials/distribute/input).
* Allow var.assign on MirroredVariables with aggregation=NONE in replica context. Previously this would raise an error. We now allow this because many users and library writers find using `.assign` in replica context to be more convenient, instead of having to use `Strategy.extended.update` which was the previous way of updating variables in this situation.
* `tf.distribute.experimental.MultiWorkerMirroredStrategy` adds support for partial batches. Workers running out of data now continue to participate in the training with empty inputs, instead of raising an error. Learn more about [partial batches here](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
* Improve the performance of reading metrics eagerly under `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
* Fix the issue that `strategy.reduce()` inside `tf.function` may raise exceptions when the values to reduce are from loops or if-clauses.
* Fix the issue that `tf.distribute.MirroredStrategy` cannot be used together with `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
* Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization.
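A hedged sketch of `get_next_as_optional`, referenced above (the strategy and dataset are illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dist_ds = strategy.experimental_distribute_dataset(
    tf.data.Dataset.range(4).batch(2))
iterator = iter(dist_ds)

# Returns a tf.experimental.Optional instead of raising an out-of-range error.
opt = iterator.get_next_as_optional()
while opt.has_value():
    print(opt.get_value())
    opt = iterator.get_next_as_optional()
```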
### `tf.keras`:
* Introduces experimental preprocessing layers API (`tf.keras.layers.experimental.preprocessing`) to handle data preprocessing operations such as categorical feature encoding, text vectorization, data normalization, and data discretization (binning). The newly added layers provide a replacement for the legacy feature column API, and support composite tensor inputs.
* Added **categorical data** processing layers:
* `IntegerLookup` & `StringLookup`: build an index of categorical feature values
* `CategoryEncoding`: turn integer-encoded categories into one-hot, multi-hot, or tf-idf encoded representations
* `CategoryCrossing`: create new categorical features representing co-occurrences of previous categorical feature values
* `Hashing`: the hashing trick, for large-vocabulary categorical features
* `Discretization`: turn continuous numerical features into categorical features by binning their values
* Improved **image preprocessing** layers: `CenterCrop`, `Rescaling`
* Improved **image augmentation** layers: `RandomCrop`, `RandomFlip`, `RandomTranslation`, `RandomRotation`, `RandomHeight`, `RandomWidth`, `RandomZoom`, `RandomContrast`
* Improved **`TextVectorization`** layer, which handles string tokenization, n-gram generation, and token encoding
* The `TextVectorization` layer now accounts for the `mask_token` as part of the vocabulary size when `output_mode='int'`. This means that, if you have a `max_tokens` value of 5000, your output will have 5000 unique values (not 5001 as before).
* Change the return value of `TextVectorization.get_vocabulary()` from `byte` to `string`. Users who previously were calling 'decode' on the output of this method should no longer need to do so.
* Introduce new Keras dataset generation utilities:
* **[`image_dataset_from_directory`](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory)** is a utility based on `tf.data.Dataset`, meant to replace the legacy `ImageDataGenerator`. It takes you from a structured directory of images to a labeled dataset, in one function call. Note that it doesn't perform image data augmentation (which is meant to be done using preprocessing layers); see the sketch at the end of this section.
* **[`text_dataset_from_directory`](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/text_dataset_from_directory)** takes you from a structured directory of text files to a labeled dataset, in one function call.
* **[`timeseries_dataset_from_array`](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/timeseries_dataset_from_array)** is a `tf.data.Dataset`-based replacement of the legacy `TimeseriesGenerator`. It takes you from an array of timeseries data to a dataset of shifting windows with their targets.
* Added [`experimental_steps_per_execution`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/Model?hl=en#compile)
arg to `model.compile` to indicate the number of batches to run per `tf.function` call. This can speed up Keras Models on TPUs up to 3x.
* Extends `tf.keras.layers.Lambda` layers to support multi-argument lambdas, and keyword arguments when calling the layer.
* Functional models now get constructed if *any* tensor in a layer call's arguments/keyword arguments comes from a keras input. Previously the functional api would only work if all of the elements in the first argument to the layer came from a keras input.
* Clean up `BatchNormalization` layer's `trainable` property to act like standard python state when it's used inside `tf.functions` (frozen at tracing time), instead of acting like a pseudo-variable whose updates *kind of sometimes* get reflected in already-traced `tf.function` traces.
* Add the `Conv1DTranspose` layer.
* Refine the semantics of `SensitivitySpecificityBase` derived metrics. See the updated API docstrings for [`tf.keras.metrics.SensitivityAtSpecificity`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/metrics/SensitivityAtSpecificity) and [`tf.keras.metrics.SpecificityAtSensitivity`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/metrics/SpecificityAtSensitivity).
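A minimal sketch combining the new dataset utility with the experimental preprocessing layers, as referenced above (`flowers_dir` and the model are hypothetical):

```python
import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "flowers_dir", image_size=(180, 180), batch_size=32)

model = tf.keras.Sequential([
    preprocessing.Rescaling(1.0 / 255),      # replaces manual normalization
    preprocessing.RandomFlip("horizontal"),  # augmentation as a layer
    tf.keras.layers.Conv2D(16, 3, activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(5),
])
```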
### `tf.lite`:
* Converter
* Restored `inference_input_type` and `inference_output_type` flags in TF 2.x TFLiteConverter (backward compatible with TF 1.x) to support integer (tf.int8, tf.uint8) input and output types in post-training full-integer quantized models; see the sketch at the end of this section.
* Added support for converting and resizing models with dynamic (placeholder) dimensions. Previously, there was only limited support for dynamic batch size, and even that did not guarantee that the model could be properly resized at runtime.
* Enabled experimental support for a new quantization mode with 16-bit activations and 8-bit weights. See `lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8`.
* CPU
* Fix an issue w/ dynamic weights and `Conv2D` on x86.
* Add a runtime Android flag for enabling `XNNPACK` for optimized CPU performance.
* Add a runtime iOS flag for enabling `XNNPACK` for optimized CPU performance.
* Add a compiler flag to enable building a TFLite library that applies `XNNPACK` delegate automatically when the model has a `fp32` operation.
* GPU
* Allow GPU acceleration starting with internal graph nodes
* Experimental support for quantized models with the Android GPU delegate
* Add GPU delegate whitelist.
* Rename GPU whitelist -> compatibility (list).
* Improve GPU compatibility list entries from crash reports.
* NNAPI
* Set default value for `StatefulNnApiDelegate::Options::max_number_delegated_partitions` to 3.
* Add capability to disable `NNAPI` CPU and check `NNAPI` Errno.
* Fix crashes when using `NNAPI` with target accelerator specified with model containing Conv2d or FullyConnected or LSTM nodes with quantized weights.
* Fix `ANEURALNETWORKS_BAD_DATA` execution failures with `sum`/`max`/`min`/`reduce` operations with `scalar` inputs.
* Hexagon
* The TFLite Hexagon Delegate is out of experimental.
* Experimental `int8` support for most Hexagon ops.
* Experimental per-channel quant support for `conv` in Hexagon delegate.
* Support dynamic batch size in C++ API.
* CoreML
* Opensource CoreML delegate
* Misc
* Enable building Android TFLite targets on Windows
* Add support for `BatchMatMul`.
* Add support for `half_pixel_centers` with `ResizeNearestNeighbor`.
* Add 3D support for `BatchToSpaceND`.
* Add 5D support for `BroadcastSub`, `Maximum`, `Minimum`, `Transpose` and `BroadcastDiv`.
* Rename `kTfLiteActRelu1` to `kTfLiteActReluN1To1`.
* Enable flex delegate on tensorflow.lite.Interpreter Python package.
* Add `Buckettize`, `SparseCross` and `BoostedTreesBucketize` to the flex whitelist.
* Add support for selective registration of flex ops.
* Add missing kernels for flex delegate whitelisted ops.
* Fix issue when using direct `ByteBuffer` inputs with graphs that have dynamic shapes.
* Fix error checking supported operations in a model containing `HardSwish`.
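A hedged sketch of the restored converter flags for post-training full-integer quantization, as referenced above (the saved-model path and calibration generator are hypothetical):

```python
import tensorflow as tf

def representative_data():
    for _ in range(100):
        yield [tf.random.normal([1, 28, 28, 1])]

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8   # restored flag
converter.inference_output_type = tf.int8  # restored flag
tflite_model = converter.convert()
```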
### Packaging Support
* Added `tf.sysconfig.get_build_info()`. Returns a dict that describes the build environment of the currently installed TensorFlow package, e.g. the NVIDIA CUDA and cuDNN versions used when TensorFlow was built.
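A minimal sketch (the key names depend on the build and are illustrative):

```python
import tensorflow as tf

info = tf.sysconfig.get_build_info()
print(info.get("cuda_version"), info.get("cudnn_version"))
```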
### Profiler
* Fix a subtle use-after-free issue in `XStatVisitor::RefValue()`.
### TPU Enhancements
* Adds 3D mesh support in TPU configurations ops.
* Added TPU code for `FTRL` with `multiply_linear_by_lr`.
* Silently adds a new file system registry at `gstpu`.
* Support `restartType` in the Cloud TPU client.
* Depend on a specific version of google-api-python-client.
* Fixes apiclient import.
### Tracing and Debugging
* Add a `TFE_Py_Execute` traceme.
### XLA Support
* Implement stable `argmin` and `argmax`
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
902449@58880@bigcat_chen@ASIC, Abdul Baseer Khan, Abhineet Choudhary, Abolfazl Shahbazi, Adam Hillier, ag.ramesh, Agoniii, Ajay P, Alex Hoffman, Alexander Bayandin, Alexander Grund, Alexandre Abadie, Alexey Rogachevskiy, amoitra, Andrew Stevens, Angus-Luo, Anshuman Tripathy, Anush Elangovan, Artem Mavrin, Ashutosh Hathidara, autoih, Ayushman Kumar, ayushmankumar7, Bairen Yi, Bas Aarts, Bastian Eichenberger, Ben Barsdell, bhack, Bharat Raghunathan, Biagio Montaruli, Bigcat-Himax, blueyi, Bryan Cutler, Byambaa, Carlos Hernandez-Vaquero, Chen Lei, Chris Knorowski, Christian Clauss, chuanqiw, CuiYifeng, Daniel Situnayake, Daria Zhuravleva, Dayananda-V, Deven Desai, Devi Sandeep Endluri, Dmitry Zakharov, Dominic Jack, Duncan Riach, Edgar Liberis, Ehsan Toosi, ekuznetsov139, Elena Zhelezina, Eugene Kuznetsov, Eugene Mikhantiev, Evgenii Zheltonozhskii, Fabio Di Domenico, Fausto Morales, Fei Sun, feihugis, Felix E. Klee, flyingcat, Frederic Bastien, Fredrik Knutsson, frreiss, fsx950223, ganler, Gaurav Singh, Georgios Pinitas, Gian Marco Iodice, Giorgio Arena, Giuseppe Rossini, Gregory Keith, Guozhong Zhuang, gurushantj, Hahn Anselm, Harald Husum, Harjyot Bagga, Hristo Vrigazov, Ilya Persky, Ir1d, Itamar Turner-Trauring, jacco, Jake Tae, Janosh Riebesell, Jason Zaman, jayanth, Jeff Daily, Jens Elofsson, Jinzhe Zeng, JLZ, Jonas Skog, Jonathan Dekhtiar, Josh Meyer, Joshua Chia, Judd, justkw, Kaixi Hou, Kam D Kasravi, Kamil Rakoczy, Karol Gugala, Kayou, Kazuaki Ishizaki, Keith Smiley, Khaled Besrour, Kilaru Yasaswi Sri Chandra Gandhi, Kim, Young Soo, Kristian Hartikainen, Kwabena W. Agyeman, Leslie-Fang, Leslie-Fang-Intel, Li, Guizi, Lukas Geiger, Lutz Roeder, M\U00E5Ns Nilsson, Mahmoud Abuzaina, Manish, Marcel Koester, Marcin Sielski, marload, Martin Jul, Matt Conley, mdfaijul, Meng, Peng, Meteorix, Michael Käufl, Michael137, Milan Straka, Mitchell Vitez, Ml-0, Mokke Meguru, Mshr-H, nammbash, Nathan Luehr, naumkin, Neeraj Bhadani, ngc92, Nick Morgan, nihui, Niranjan Hasabnis, Niranjan Yadla, Nishidha Panpaliya, Oceania2018, oclyke, Ouyang Jin, OverLordGoldDragon, Owen Lyke, Patrick Hemmer, Paul Andrey, Peng Sun, periannath, Phil Pearl, Prashant Dandriyal, Prashant Kumar, Rahul Huilgol, Rajan Singh, Rajeshwar Reddy T, rangjiaheng, Rishit Dagli, Rohan Reddy, rpalakkal, rposts, Ruan Kunliang, Rushabh Vasani, Ryohei Ikegami, Semun Lee, Seo-Inyoung, Sergey Mironov, Sharada Shiddibhavi, ShengYang1, Shraiysh Vaishay, Shunya Ueta, shwetaoj, Siyavash Najafzade, Srinivasan Narayanamoorthy, Stephan Uphoff, storypku, sunchenggen, sunway513, Sven-Hendrik Haase, Swapnil Parekh, Tamas Bela Feher, Teng Lu, tigertang, tomas, Tomohiro Ubukata, tongxuan.ltx, Tony Tonev, Tzu-Wei Huang, Téo Bouvard, Uday Bondhugula, Vaibhav Jade, Vijay Tadikamalla, Vikram Dattu, Vincent Abriou, Vishnuvardhan Janapati, Vo Van Nghia, VoVAllen, Will Battel, William D. Irons, wyzhao, Xiaoming (Jason) Cui, Xiaoquan Kong, Xinan Jiang, xutianming, Yair Ehrenwald, Yasir Modak, Yasuhiro Matsumoto, Yixing Fu, Yong Tang, Yuan Tang, zhaozheng09, Zilin Zhu, zilinzhu, 张志豪
# Release 2.1.1


@ -18,6 +18,18 @@ load("//tensorflow:workspace.bzl", "tf_repositories")
# Please add all new TensorFlow dependencies in workspace.bzl.
tf_repositories()
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
git_repository(
name = "com_github_nelhage_rules_boost",
commit = "1e3a69bf2d5cd10c34b74f066054cd335d033d71",
remote = "https://github.com/nelhage/rules_boost",
shallow_since = "1591047380 -0700",
)
load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
boost_deps()
register_toolchains("@local_config_python//:py_toolchain")
load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")

native_client (symbolic link)

@ -0,0 +1 @@
../native_client


@ -28,8 +28,8 @@ tf_proto_library(
)
tf_proto_library(
name = "master_proto",
srcs = ["master.proto"],
name = "dispatcher_proto",
srcs = ["dispatcher.proto"],
has_services = 1,
cc_api_version = 2,
protodeps = tf_additional_all_protos() + [
@ -49,17 +49,17 @@ tf_proto_library(
)
cc_library(
name = "master_impl",
srcs = ["master_impl.cc"],
name = "dispatcher_impl",
srcs = ["dispatcher_impl.cc"],
hdrs = [
"master_impl.h",
"dispatcher_impl.h",
],
deps = [
":common_proto_cc",
":credentials_factory",
":data_service",
":dispatcher_proto_cc",
":grpc_util",
":master_proto_cc",
":worker_cc_grpc_proto",
":worker_proto_cc",
"//tensorflow/c:c_api_internal",
@ -86,9 +86,9 @@ cc_library(
deps = [
":common_proto_cc",
":credentials_factory",
":dispatcher_cc_grpc_proto",
":dispatcher_proto_cc",
":grpc_util",
":master_cc_grpc_proto",
":master_proto_cc",
":worker_proto_cc",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:tf_status_helper",
@ -207,12 +207,12 @@ tf_cc_test(
)
cc_library(
name = "grpc_master_impl",
srcs = ["grpc_master_impl.cc"],
hdrs = ["grpc_master_impl.h"],
name = "grpc_dispatcher_impl",
srcs = ["grpc_dispatcher_impl.cc"],
hdrs = ["grpc_dispatcher_impl.h"],
deps = [
":master_cc_grpc_proto",
":master_impl",
":dispatcher_cc_grpc_proto",
":dispatcher_impl",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
tf_grpc_cc_dependency(),
],
@ -250,7 +250,7 @@ cc_library(
],
deps = [
":credentials_factory",
":grpc_master_impl",
":grpc_dispatcher_impl",
":grpc_util",
":grpc_worker_impl",
"//tensorflow/core:lib",
@ -268,9 +268,9 @@ cc_library(
],
deps = [
":credentials_factory",
":dispatcher_cc_grpc_proto",
":dispatcher_proto_cc",
":grpc_util",
":master_cc_grpc_proto",
":master_proto_cc",
":worker_cc_grpc_proto",
":worker_proto_cc",
"//tensorflow/core:framework",
@ -287,12 +287,12 @@ tf_cc_test(
tags = ["no_windows"],
deps = [
":data_service",
":grpc_master_impl",
":dispatcher_cc_grpc_proto",
":dispatcher_proto_cc",
":grpc_dispatcher_impl",
":grpc_util",
":grpc_worker_impl",
":local_credentials_factory",
":master_cc_grpc_proto",
":master_proto_cc",
":server_lib",
":test_cluster",
":test_util",
@ -309,11 +309,11 @@ tf_cc_test(
)
cc_grpc_library(
name = "master_cc_grpc_proto",
srcs = [":master_proto"],
name = "dispatcher_cc_grpc_proto",
srcs = [":dispatcher_proto"],
generate_mocks = True,
grpc_only = True,
deps = [":master_proto_cc"],
deps = [":dispatcher_proto_cc"],
)
cc_grpc_library(


@ -18,8 +18,8 @@ limitations under the License.
#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/master.grpc.pb.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
@ -54,8 +54,8 @@ std::string ProcessingModeToString(ProcessingMode mode) {
}
}
Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
int64* dataset_id) {
Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset,
int64* dataset_id) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetOrRegisterDatasetRequest req;
*req.mutable_dataset()->mutable_graph() = dataset;
@ -69,9 +69,9 @@ Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
return Status::OK();
}
Status DataServiceMasterClient::CreateJob(int64 dataset_id,
ProcessingMode processing_mode,
int64* job_id) {
Status DataServiceDispatcherClient::CreateJob(int64 dataset_id,
ProcessingMode processing_mode,
int64* job_id) {
TF_RETURN_IF_ERROR(EnsureInitialized());
CreateJobRequest req;
req.set_dataset_id(dataset_id);
@ -88,11 +88,9 @@ Status DataServiceMasterClient::CreateJob(int64 dataset_id,
return Status::OK();
}
Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id,
ProcessingMode processing_mode,
const std::string& job_name,
int job_name_index,
int64* job_id) {
Status DataServiceDispatcherClient::GetOrCreateJob(
int64 dataset_id, ProcessingMode processing_mode,
const std::string& job_name, int job_name_index, int64* job_id) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetOrCreateJobRequest req;
req.set_dataset_id(dataset_id);
@ -112,9 +110,9 @@ Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id,
return Status::OK();
}
Status DataServiceMasterClient::GetTasks(int64 job_id,
std::vector<TaskInfo>* tasks,
bool* job_finished) {
Status DataServiceDispatcherClient::GetTasks(int64 job_id,
std::vector<TaskInfo>* tasks,
bool* job_finished) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetTasksRequest req;
req.set_job_id(job_id);
@ -132,7 +130,8 @@ Status DataServiceMasterClient::GetTasks(int64 job_id,
return Status::OK();
}
Status DataServiceMasterClient::GetWorkers(std::vector<WorkerInfo>* workers) {
Status DataServiceDispatcherClient::GetWorkers(
std::vector<WorkerInfo>* workers) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetWorkersRequest req;
GetWorkersResponse resp;
@ -148,12 +147,12 @@ Status DataServiceMasterClient::GetWorkers(std::vector<WorkerInfo>* workers) {
return Status::OK();
}
Status DataServiceMasterClient::EnsureInitialized() {
Status DataServiceDispatcherClient::EnsureInitialized() {
std::shared_ptr<grpc::ChannelCredentials> credentials;
TF_RETURN_IF_ERROR(
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
auto channel = grpc::CreateChannel(address_, credentials);
stub_ = MasterService::NewStub(channel);
stub_ = DispatcherService::NewStub(channel);
return Status::OK();
}
@ -187,10 +186,11 @@ Status DataServiceWorkerClient::EnsureInitialized() {
return Status::OK();
}
Status CreateDataServiceMasterClient(
Status CreateDataServiceDispatcherClient(
const std::string& address, const std::string& protocol,
std::unique_ptr<DataServiceMasterClient>* out) {
auto client = absl::make_unique<DataServiceMasterClient>(address, protocol);
std::unique_ptr<DataServiceDispatcherClient>* out) {
auto client =
absl::make_unique<DataServiceDispatcherClient>(address, protocol);
TF_RETURN_IF_ERROR(client->Initialize());
*out = std::move(client);
return Status::OK();


@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
#define TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
#include "tensorflow/core/data/service/master.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -67,11 +67,11 @@ class DataServiceClientBase {
const std::string protocol_;
};
// Client for communicating with the tf.data service master.
class DataServiceMasterClient : public DataServiceClientBase {
// Client for communicating with the tf.data service dispatcher.
class DataServiceDispatcherClient : public DataServiceClientBase {
public:
DataServiceMasterClient(const std::string& address,
const std::string& protocol)
DataServiceDispatcherClient(const std::string& address,
const std::string& protocol)
: DataServiceClientBase(address, protocol) {}
// Registers a dataset with the tf.data service, and stores the generated
@ -90,13 +90,13 @@ class DataServiceMasterClient : public DataServiceClientBase {
const std::string& job_name, int job_name_index,
int64* job_id);
// Queries the master for the tasks associated with the specified job.
// Queries the dispatcher for the tasks associated with the specified job.
// The tasks will be stored in *tasks, and whether the job is finished will
// be stored in `*job_finished`.
Status GetTasks(int64 job_id, std::vector<TaskInfo>* tasks,
bool* job_finished);
// Queries the master for its registered workers. The worker info will be
// Queries the dispatcher for its registered workers. The worker info will be
// stored in `*workers`.
Status GetWorkers(std::vector<WorkerInfo>* workers);
@ -104,7 +104,7 @@ class DataServiceMasterClient : public DataServiceClientBase {
Status EnsureInitialized() override;
private:
std::unique_ptr<MasterService::Stub> stub_;
std::unique_ptr<DispatcherService::Stub> stub_;
};
// Client for communicating with the tf.data service worker.
@ -127,10 +127,10 @@ class DataServiceWorkerClient : public DataServiceClientBase {
std::unique_ptr<WorkerService::Stub> stub_;
};
// Creates and initializes a new tf.data service master client.
Status CreateDataServiceMasterClient(
// Creates and initializes a new tf.data service dispatcher client.
Status CreateDataServiceDispatcherClient(
const std::string& address, const std::string& protocol,
std::unique_ptr<DataServiceMasterClient>* out);
std::unique_ptr<DataServiceDispatcherClient>* out);
// Creates and initializes a new tf.data service worker client.
Status CreateDataServiceWorkerClient(


@ -19,9 +19,9 @@ limitations under the License.
#include "grpcpp/security/credentials.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/data/compression_utils.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/master.grpc.pb.h"
#include "tensorflow/core/data/service/master.pb.h"
#include "tensorflow/core/data/service/server_lib.h"
#include "tensorflow/core/data/service/test_cluster.h"
#include "tensorflow/core/data/service/test_util.h"
@ -66,9 +66,10 @@ TEST(DataService, ProcessingModeToString) {
TEST(DataService, GetWorkers) {
TestCluster cluster(1);
TF_ASSERT_OK(cluster.Initialize());
DataServiceMasterClient master(cluster.MasterAddress(), kProtocol);
DataServiceDispatcherClient dispatcher(cluster.DispatcherAddress(),
kProtocol);
std::vector<WorkerInfo> workers;
TF_EXPECT_OK(master.GetWorkers(&workers));
TF_EXPECT_OK(dispatcher.GetWorkers(&workers));
EXPECT_EQ(1, workers.size());
}


@ -110,11 +110,11 @@ message GetWorkersResponse {
repeated WorkerInfo workers = 1;
}
service MasterService {
// Registers a worker with the master.
service DispatcherService {
// Registers a worker with the dispatcher.
rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse);
// Updates the master with information about the worker's state.
// Updates the dispatcher with information about the worker's state.
rpc WorkerUpdate(WorkerUpdateRequest) returns (WorkerUpdateResponse);
// Registers a dataset with the server, or returns its id if it is already
@ -134,6 +134,6 @@ service MasterService {
// Reports a list of all tasks for a job.
rpc GetTasks(GetTasksRequest) returns (GetTasksResponse);
// Reports a list of all workers registered with the master.
// Reports a list of all workers registered with the dispatcher.
rpc GetWorkers(GetWorkersRequest) returns (GetWorkersResponse);
}


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/service/master_impl.h"
#include "tensorflow/core/data/service/dispatcher_impl.h"
#include <memory>
#include <tuple>
@ -26,8 +26,8 @@ limitations under the License.
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/master.pb.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
@ -53,10 +53,10 @@ Status CreateWorkerStub(const std::string& address,
}
} // namespace
DataServiceMasterImpl::DataServiceMasterImpl(const std::string protocol)
DataServiceDispatcherImpl::DataServiceDispatcherImpl(const std::string protocol)
: protocol_(protocol) {}
Status DataServiceMasterImpl::RegisterWorker(
Status DataServiceDispatcherImpl::RegisterWorker(
const RegisterWorkerRequest* request, RegisterWorkerResponse* response) {
VLOG(3) << "Received register worker request";
mutex_lock l(mu_);
@ -86,8 +86,8 @@ Status DataServiceMasterImpl::RegisterWorker(
return Status::OK();
}
Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request,
WorkerUpdateResponse* response) {
Status DataServiceDispatcherImpl::WorkerUpdate(
const WorkerUpdateRequest* request, WorkerUpdateResponse* response) {
mutex_lock l(mu_);
int64 worker_id = request->worker_id();
for (auto& update : request->updates()) {
@ -106,7 +106,7 @@ Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request,
return Status::OK();
}
Status DataServiceMasterImpl::GetOrRegisterDataset(
Status DataServiceDispatcherImpl::GetOrRegisterDataset(
const GetOrRegisterDatasetRequest* request,
GetOrRegisterDatasetResponse* response) {
uint64 fingerprint;
@ -128,8 +128,8 @@ Status DataServiceMasterImpl::GetOrRegisterDataset(
return Status::OK();
}
int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
const DatasetDef& dataset)
int64 DataServiceDispatcherImpl::RegisterDataset(uint64 fingerprint,
const DatasetDef& dataset)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
int64 dataset_id = next_dataset_id_++;
auto new_dataset =
@ -142,8 +142,8 @@ int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
return dataset_id;
}
Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
CreateJobResponse* response) {
Status DataServiceDispatcherImpl::CreateJob(const CreateJobRequest* request,
CreateJobResponse* response) {
VLOG(3) << "Received create job request for dataset id "
<< request->dataset_id();
ProcessingMode processing_mode = ProcessingMode(request->processing_mode());
@ -157,7 +157,7 @@ Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
return Status::OK();
}
Status DataServiceMasterImpl::GetOrCreateJob(
Status DataServiceDispatcherImpl::GetOrCreateJob(
const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
VLOG(3) << "Received get or create job request for dataset id "
<< request->dataset_id() << " with name " << request->job_name()
@ -193,7 +193,7 @@ Status DataServiceMasterImpl::GetOrCreateJob(
}
// Validates that the job matches the given processing_mode and dataset_id.
Status DataServiceMasterImpl::ValidateMatchingJob(
Status DataServiceDispatcherImpl::ValidateMatchingJob(
const Job& job, ProcessingMode processing_mode, int64 dataset_id) {
DCHECK(job.name().has_value());
std::string job_name = job.name().value();
@ -214,10 +214,10 @@ Status DataServiceMasterImpl::ValidateMatchingJob(
return Status::OK();
}
Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
ProcessingMode processing_mode,
absl::optional<std::string> job_name,
int64* out_job_id) LOCKS_EXCLUDED(mu_) {
Status DataServiceDispatcherImpl::CreateJob(
int64 dataset_id, ProcessingMode processing_mode,
absl::optional<std::string> job_name, int64* out_job_id)
LOCKS_EXCLUDED(mu_) {
switch (processing_mode) {
case ProcessingMode::PARALLEL_EPOCHS:
break;
@ -274,14 +274,16 @@ Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
return Status::OK();
}
const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTask(
const DataServiceDispatcherImpl::Task& DataServiceDispatcherImpl::CreateTask(
Job* job, const std::string& worker_address) LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
return CreateTaskLocked(job, worker_address);
}
const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTaskLocked(
Job* job, const std::string& worker_address) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
const DataServiceDispatcherImpl::Task&
DataServiceDispatcherImpl::CreateTaskLocked(Job* job,
const std::string& worker_address)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
int64 task_id = next_task_id_++;
DCHECK(!tasks_.contains(task_id));
tasks_.insert({task_id, Task(task_id, job->job_id(), job->dataset_id(),
@ -290,7 +292,7 @@ const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTaskLocked(
return tasks_.at(task_id);
}
Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
Status DataServiceDispatcherImpl::EnsureWorkerStubInitialized(Worker* worker) {
if (!worker->stub()) {
std::unique_ptr<WorkerService::Stub> stub;
TF_RETURN_IF_ERROR(CreateWorkerStub(worker->address(), protocol_, &stub));
@ -299,8 +301,8 @@ Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
return Status::OK();
}
Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
Worker* worker)
Status DataServiceDispatcherImpl::AllocateTaskToWorker(const Task& task,
Worker* worker)
LOCKS_EXCLUDED(mu_) {
TF_RETURN_IF_ERROR(EnsureWorkerStubInitialized(worker));
grpc::ClientContext client_ctx;
@ -322,8 +324,8 @@ Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
return Status::OK();
}
Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
GetTasksResponse* response) {
Status DataServiceDispatcherImpl::GetTasks(const GetTasksRequest* request,
GetTasksResponse* response) {
mutex_lock l(mu_);
VLOG(3) << "Looking up tasks for job id " << request->job_id();
auto it = jobs_.find(request->job_id());
@ -346,8 +348,8 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
return Status::OK();
}
Status DataServiceMasterImpl::GetWorkers(const GetWorkersRequest* request,
GetWorkersResponse* response) {
Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request,
GetWorkersResponse* response) {
mutex_lock l(mu_);
VLOG(3) << "Enter GetWorkers";
for (auto& worker : workers_) {


@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
#define TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
#ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/master.pb.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
@ -40,11 +40,11 @@ namespace data {
// ProcessingModeDef which determines what data it produces.
// * Task: A job is broken into multiple tasks, which each represent
// iterating over all of or part of the dataset. Workers process tasks.
class DataServiceMasterImpl {
class DataServiceDispatcherImpl {
public:
explicit DataServiceMasterImpl(const std::string protocol);
explicit DataServiceDispatcherImpl(const std::string protocol);
// See master.proto for API documentation.
// See dispatcher.proto for API documentation.
/// Worker-facing API.
Status RegisterWorker(const RegisterWorkerRequest* request,
@ -191,7 +191,7 @@ class DataServiceMasterImpl {
// Creates a new task for a job, returning a reference to the task.
const Task& CreateTask(Job* job, const std::string& worker_address)
LOCKS_EXCLUDED(mu_);
// Same as `CreateTask`, but expects that the master lock is already held.
// Same as `CreateTask`, but expects that the dispatcher lock is already held.
const Task& CreateTaskLocked(Job* job, const std::string& worker_address)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Validates that an existing job matches the given processing_mode and
@ -225,10 +225,10 @@ class DataServiceMasterImpl {
absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_
TF_GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(DataServiceMasterImpl);
TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
#endif // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_


@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/service/grpc_master_impl.h"
#include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
#include "grpcpp/server_context.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
@ -25,18 +25,18 @@ using ::grpc::ServerBuilder;
using ::grpc::ServerContext;
using ::grpc::Status;
GrpcMasterImpl::GrpcMasterImpl(ServerBuilder* server_builder,
const std::string& protocol)
GrpcDispatcherImpl::GrpcDispatcherImpl(ServerBuilder* server_builder,
const std::string& protocol)
: impl_(protocol) {
server_builder->RegisterService(this);
VLOG(1) << "Registered data service master";
VLOG(1) << "Registered data service dispatcher";
}
#define HANDLER(method) \
Status GrpcMasterImpl::method(ServerContext* context, \
const method##Request* request, \
method##Response* response) { \
return ToGrpcStatus(impl_.method(request, response)); \
#define HANDLER(method) \
Status GrpcDispatcherImpl::method(ServerContext* context, \
const method##Request* request, \
method##Response* response) { \
return ToGrpcStatus(impl_.method(request, response)); \
}
HANDLER(RegisterWorker);
HANDLER(WorkerUpdate);

View File

@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
#include "grpcpp/server_builder.h"
#include "tensorflow/core/data/service/master.grpc.pb.h"
#include "tensorflow/core/data/service/master_impl.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher_impl.h"
namespace tensorflow {
namespace data {
@ -29,14 +29,14 @@ namespace data {
//
// ::grpc::ServerBuilder builder;
// // configure builder
// GrpcMasterImpl data_service(&builder);
// GrpcDispatcherImpl data_service(&builder);
// builder.BuildAndStart()
//
class GrpcMasterImpl : public MasterService::Service {
class GrpcDispatcherImpl : public DispatcherService::Service {
public:
explicit GrpcMasterImpl(grpc::ServerBuilder* server_builder,
const std::string& protocol);
~GrpcMasterImpl() override {}
explicit GrpcDispatcherImpl(grpc::ServerBuilder* server_builder,
const std::string& protocol);
~GrpcDispatcherImpl() override {}
#define HANDLER(method) \
grpc::Status method(grpc::ServerContext* context, \
@ -52,12 +52,12 @@ class GrpcMasterImpl : public MasterService::Service {
#undef HANDLER
private:
DataServiceMasterImpl impl_;
DataServiceDispatcherImpl impl_;
TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterImpl);
TF_DISALLOW_COPY_AND_ASSIGN(GrpcDispatcherImpl);
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
#endif // TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_

View File

@ -26,9 +26,9 @@ using ::grpc::ServerContext;
using ::grpc::Status;
GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder,
const std::string& master_address,
const std::string& dispatcher_address,
const std::string& protocol)
: impl_(master_address, protocol) {
: impl_(dispatcher_address, protocol) {
server_builder->RegisterService(this);
VLOG(1) << "Registered data service worker";
}

View File

@ -35,7 +35,7 @@ namespace data {
class GrpcWorkerImpl : public WorkerService::Service {
public:
explicit GrpcWorkerImpl(grpc::ServerBuilder* server_builder,
const std::string& master_address,
const std::string& dispatcher_address,
const std::string& protocol);
~GrpcWorkerImpl() override {}

View File

@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/core/data/service/server_lib.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/grpc_master_impl.h"
#include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/grpc_worker_impl.h"
#include "tensorflow/core/platform/errors.h"
@ -72,18 +72,18 @@ void GrpcDataServerBase::Join() { server_->Wait(); }
int GrpcDataServerBase::BoundPort() { return bound_port(); }
MasterGrpcDataServer::MasterGrpcDataServer(int port,
const std::string& protocol)
DispatchGrpcDataServer::DispatchGrpcDataServer(int port,
const std::string& protocol)
: GrpcDataServerBase(port, protocol) {}
MasterGrpcDataServer::~MasterGrpcDataServer() { delete service_; }
DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
void MasterGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
auto service = absl::make_unique<GrpcMasterImpl>(builder, protocol_);
void DispatchGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
auto service = absl::make_unique<GrpcDispatcherImpl>(builder, protocol_);
service_ = service.release();
}
Status MasterGrpcDataServer::NumWorkers(int* num_workers) {
Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
GetWorkersRequest req;
GetWorkersResponse resp;
grpc::ServerContext ctx;
@ -95,19 +95,18 @@ Status MasterGrpcDataServer::NumWorkers(int* num_workers) {
return Status::OK();
}
WorkerGrpcDataServer::WorkerGrpcDataServer(int port,
const std::string& protocol,
const std::string& master_address,
const std::string& worker_address)
WorkerGrpcDataServer::WorkerGrpcDataServer(
int port, const std::string& protocol,
const std::string& dispatcher_address, const std::string& worker_address)
: GrpcDataServerBase(port, protocol),
master_address_(master_address),
dispatcher_address_(dispatcher_address),
worker_address_(worker_address) {}
WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
void WorkerGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
auto service =
absl::make_unique<GrpcWorkerImpl>(builder, master_address_, protocol_);
auto service = absl::make_unique<GrpcWorkerImpl>(builder, dispatcher_address_,
protocol_);
service_ = service.release();
}
@ -123,25 +122,25 @@ Status WorkerGrpcDataServer::StartServiceInternal() {
return Status::OK();
}
Status NewMasterServer(int port, const std::string& protocol,
std::unique_ptr<MasterGrpcDataServer>* out_server) {
*out_server = absl::make_unique<MasterGrpcDataServer>(port, protocol);
Status NewDispatchServer(int port, const std::string& protocol,
std::unique_ptr<DispatchGrpcDataServer>* out_server) {
*out_server = absl::make_unique<DispatchGrpcDataServer>(port, protocol);
return Status::OK();
}
Status NewWorkerServer(int port, const std::string& protocol,
const std::string& master_address,
const std::string& dispatcher_address,
std::unique_ptr<WorkerGrpcDataServer>* out_server) {
return NewWorkerServer(port, protocol, master_address, /*worker_address=*/"",
out_server);
return NewWorkerServer(port, protocol, dispatcher_address,
/*worker_address=*/"", out_server);
}
Status NewWorkerServer(int port, const std::string& protocol,
const std::string& master_address,
const std::string& dispatcher_address,
const std::string& worker_address,
std::unique_ptr<WorkerGrpcDataServer>* out_server) {
*out_server = absl::make_unique<WorkerGrpcDataServer>(
port, protocol, master_address, worker_address);
port, protocol, dispatcher_address, worker_address);
return Status::OK();
}

View File

@ -25,7 +25,7 @@ namespace data {
// Forward declared because transitively depending on .grpc.pb.h files causes
// issues in the pywrap build.
class GrpcMasterImpl;
class GrpcDispatcherImpl;
class GrpcWorkerImpl;
// A grpc server for the tf.data service.
@ -35,7 +35,7 @@ class GrpcDataServerBase {
// server will find an available port in `Start()`. The chosen port can be
// found in the output of `Target()`.
//
// master_address is only needed for worker data servers.
// dispatcher_address is only needed for worker data servers.
GrpcDataServerBase(int requested_port, const std::string& protocol);
virtual ~GrpcDataServerBase() {}
@ -70,12 +70,12 @@ class GrpcDataServerBase {
std::unique_ptr<grpc::Server> server_;
};
class MasterGrpcDataServer : public GrpcDataServerBase {
class DispatchGrpcDataServer : public GrpcDataServerBase {
public:
MasterGrpcDataServer(int requested_port, const std::string& protocol);
~MasterGrpcDataServer() override;
DispatchGrpcDataServer(int requested_port, const std::string& protocol);
~DispatchGrpcDataServer() override;
// Returns the number of workers registered with the master.
// Returns the number of workers registered with the dispatcher.
Status NumWorkers(int* num_workers);
protected:
@ -83,14 +83,14 @@ class MasterGrpcDataServer : public GrpcDataServerBase {
Status StartServiceInternal() override { return Status::OK(); }
private:
// Owned. We use a raw pointer because GrpcMasterImpl is forward-declared.
GrpcMasterImpl* service_;
// Owned. We use a raw pointer because GrpcDispatcherImpl is forward-declared.
GrpcDispatcherImpl* service_;
};
class WorkerGrpcDataServer : public GrpcDataServerBase {
public:
WorkerGrpcDataServer(int requested_port, const std::string& protocol,
const std::string& master_address,
const std::string& dispatcher_address,
const std::string& worker_address);
~WorkerGrpcDataServer() override;
@ -99,15 +99,15 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
Status StartServiceInternal() override;
private:
const std::string master_address_;
const std::string dispatcher_address_;
const std::string worker_address_;
// Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared.
GrpcWorkerImpl* service_;
};
// Creates a master tf.data server and stores it in `*out_server`.
Status NewMasterServer(int port, const std::string& protocol,
std::unique_ptr<MasterGrpcDataServer>* out_server);
// Creates a dispatch tf.data server and stores it in `*out_server`.
Status NewDispatchServer(int port, const std::string& protocol,
std::unique_ptr<DispatchGrpcDataServer>* out_server);
// Creates a worker tf.data server and stores it in `*out_server`.
//
@ -115,18 +115,18 @@ Status NewMasterServer(int port, const std::string& protocol,
// will be chosen in Start(). This value can be queried with BoundPort().
//
// The worker_address argument is optional. If left empty, it will default to
// "localhost:%port%". When the worker registers with the master, the worker
// will report the worker address, so that the master can tell clients where to
// read from. The address may contain the placeholder "%port%", which will be
// "localhost:%port%". When the worker registers with the dispatcher, the worker
// will report the worker address, so that the dispatcher can tell clients where
// to read from. The address may contain the placeholder "%port%", which will be
// replaced with the value of BoundPort().
Status NewWorkerServer(int port, const std::string& protocol,
const std::string& master_address,
const std::string& dispatcher_address,
const std::string& worker_address,
std::unique_ptr<WorkerGrpcDataServer>* out_server);
// Creates a worker using the default worker_address.
Status NewWorkerServer(int port, const std::string& protocol,
const std::string& master_address,
const std::string& dispatcher_address,
std::unique_ptr<WorkerGrpcDataServer>* out_server);
} // namespace data

View File

@ -45,9 +45,9 @@ Status TestCluster::Initialize() {
"Test cluster has already been initialized.");
}
initialized_ = true;
TF_RETURN_IF_ERROR(NewMasterServer(/*port=*/0, kProtocol, &master_));
TF_RETURN_IF_ERROR(master_->Start());
master_address_ = absl::StrCat("localhost:", master_->BoundPort());
TF_RETURN_IF_ERROR(NewDispatchServer(/*port=*/0, kProtocol, &dispatcher_));
TF_RETURN_IF_ERROR(dispatcher_->Start());
dispatcher_address_ = absl::StrCat("localhost:", dispatcher_->BoundPort());
workers_.reserve(num_workers_);
worker_addresses_.reserve(num_workers_);
for (int i = 0; i < num_workers_; ++i) {
@ -59,14 +59,14 @@ Status TestCluster::Initialize() {
Status TestCluster::AddWorker() {
std::unique_ptr<WorkerGrpcDataServer> worker;
TF_RETURN_IF_ERROR(
NewWorkerServer(/*port=*/0, kProtocol, master_address_, &worker));
NewWorkerServer(/*port=*/0, kProtocol, dispatcher_address_, &worker));
TF_RETURN_IF_ERROR(worker->Start());
worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort()));
workers_.push_back(std::move(worker));
return Status::OK();
}
std::string TestCluster::MasterAddress() { return master_address_; }
std::string TestCluster::DispatcherAddress() { return dispatcher_address_; }
std::string TestCluster::WorkerAddress(int index) {
DCHECK_GE(index, 0);

View File

@ -24,7 +24,7 @@ namespace data {
// Helper class for unit testing a tf.data service cluster.
class TestCluster {
public:
// Creates a new test cluster with a master and `num_workers` workers.
// Creates a new test cluster with a dispatcher and `num_workers` workers.
explicit TestCluster(int num_workers);
// Initializes the test cluster. This must be called before interacting with
@ -32,8 +32,8 @@ class TestCluster {
Status Initialize();
// Adds a new worker to the cluster.
Status AddWorker();
// Returns the master address in the form "hostname:port".
std::string MasterAddress();
// Returns the dispatcher address in the form "hostname:port".
std::string DispatcherAddress();
// Returns the address of the worker at the specified index, in the form
// "hostname:port". The index must be non-negative and less than the number of
// workers in the cluster.
@ -42,8 +42,8 @@ class TestCluster {
private:
bool initialized_ = false;
int num_workers_;
std::unique_ptr<MasterGrpcDataServer> master_;
std::string master_address_;
std::unique_ptr<DispatchGrpcDataServer> dispatcher_;
std::string dispatcher_address_;
std::vector<std::unique_ptr<WorkerGrpcDataServer>> workers_;
std::vector<std::string> worker_addresses_;
};

View File

@ -21,9 +21,9 @@ limitations under the License.
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/data/dataset.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/master.grpc.pb.h"
#include "tensorflow/core/data/service/master.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
@ -45,9 +45,9 @@ auto* tf_data_service_created =
"has been created.");
} // namespace
DataServiceWorkerImpl::DataServiceWorkerImpl(const std::string& master_address,
const std::string& protocol)
: master_address_(master_address), protocol_(protocol) {
DataServiceWorkerImpl::DataServiceWorkerImpl(
const std::string& dispatcher_address, const std::string& protocol)
: dispatcher_address_(dispatcher_address), protocol_(protocol) {
tf_data_service_created->GetCell()->Set(true);
}
@ -67,14 +67,13 @@ void DataServiceWorkerImpl::Start(const std::string& worker_address) {
heartbeat_thread_.reset(thread);
Status s = Register();
while (!s.ok()) {
LOG(WARNING) << "Failed to register with master at " << master_address_
<< ": " << s;
LOG(WARNING) << "Failed to register with dispatcher at "
<< dispatcher_address_ << ": " << s;
Env::Default()->SleepForMicroseconds(kHeartbeatIntervalMicros);
s = Register();
}
}
Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request,
ProcessTaskResponse* response) {
mutex_lock l(mu_);
@ -169,29 +168,29 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
return Status::OK();
}
Status DataServiceWorkerImpl::EnsureMasterStubInitialized()
Status DataServiceWorkerImpl::EnsureDispatcherStubInitialized()
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!master_stub_) {
if (!dispatcher_stub_) {
::grpc::ChannelArguments args;
std::shared_ptr<::grpc::ChannelCredentials> credentials;
TF_RETURN_IF_ERROR(
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
auto channel =
::grpc::CreateCustomChannel(master_address_, credentials, args);
master_stub_ = MasterService::NewStub(channel);
::grpc::CreateCustomChannel(dispatcher_address_, credentials, args);
dispatcher_stub_ = DispatcherService::NewStub(channel);
}
return Status::OK();
}
Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
VLOG(3) << "Registering with master at " << master_address_;
TF_RETURN_IF_ERROR(EnsureMasterStubInitialized());
VLOG(3) << "Registering with dispatcher at " << dispatcher_address_;
TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
RegisterWorkerRequest req;
req.set_worker_address(worker_address_);
RegisterWorkerResponse resp;
grpc::ClientContext ctx;
grpc::Status s = master_stub_->RegisterWorker(&ctx, req, &resp);
grpc::Status s = dispatcher_stub_->RegisterWorker(&ctx, req, &resp);
if (!s.ok()) {
return grpc_util::WrapError("Failed to register worker", s);
}
@ -205,8 +204,8 @@ Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
VLOG(3) << "Sending " << pending_completed_tasks_.size()
<< " task updates to master";
TF_RETURN_IF_ERROR(EnsureMasterStubInitialized());
<< " task updates to dispatcher";
TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
WorkerUpdateRequest req;
req.set_worker_id(worker_id_);
for (int task_id : pending_completed_tasks_) {
@ -217,7 +216,7 @@ Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
WorkerUpdateResponse resp;
grpc::ClientContext ctx;
grpc::Status s = master_stub_->WorkerUpdate(&ctx, req, &resp);
grpc::Status s = dispatcher_stub_->WorkerUpdate(&ctx, req, &resp);
if (!s.ok()) {
return grpc_util::WrapError("Failed to send task updates", s);
}
@ -238,7 +237,7 @@ void DataServiceWorkerImpl::HeartbeatThread() {
}
Status s = SendTaskUpdate();
if (!s.ok()) {
LOG(WARNING) << "Failed to send task updates to master: " << s;
LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
}
}
}
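The worker above registers with the dispatcher in a retry loop, logging a warning and sleeping for the heartbeat interval after each failed attempt, and the heartbeat thread later reports completed tasks on the same cadence. A minimal Python sketch of that retry-until-registered pattern (the `register_fn` callback and the one-second interval are illustrative, not part of this diff):

```python
import logging
import time

HEARTBEAT_INTERVAL_S = 1.0  # illustrative stand-in for kHeartbeatIntervalMicros

def register_with_retry(register_fn, dispatcher_address):
    """Retry registration until it succeeds, as DataServiceWorkerImpl::Start does."""
    while True:
        try:
            register_fn()  # stand-in for Register() returning a non-OK Status
            return
        except ConnectionError as err:
            logging.warning("Failed to register with dispatcher at %s: %s",
                            dispatcher_address, err)
            time.sleep(HEARTBEAT_INTERVAL_S)
```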

View File

@ -17,7 +17,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/master.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/lib/core/status.h"
@ -29,17 +29,17 @@ namespace data {
// A TensorFlow DataService serves dataset elements over RPC.
class DataServiceWorkerImpl {
public:
explicit DataServiceWorkerImpl(const std::string& master_address,
explicit DataServiceWorkerImpl(const std::string& dispatcher_address,
const std::string& protocol);
~DataServiceWorkerImpl();
// Starts the worker. The worker needs to know its own address so that it can
// register with the master.
// register with the dispatcher.
void Start(const std::string& worker_address);
// See worker.proto for API documentation.
/// Master-facing API.
/// Dispatcher-facing API.
Status ProcessTask(const ProcessTaskRequest* request,
ProcessTaskResponse* response);
@ -48,15 +48,15 @@ class DataServiceWorkerImpl {
GetElementResponse* response);
private:
// Sets master_stub_ if it isn't already set.
Status EnsureMasterStubInitialized();
// Registers the worker with the master.
// Sets dispatcher_stub_ if it isn't already set.
Status EnsureDispatcherStubInitialized();
// Registers the worker with the dispatcher.
Status Register();
// Sends task status to the master.
// Sends task status to the dispatcher.
Status SendTaskUpdate();
// Creates an iterator to process a task.
Status ProcessTaskInternal(const TaskDef& task);
// A thread for updating the master with worker status.
// A thread for updating the dispatcher with worker status.
void HeartbeatThread();
typedef struct Task {
@ -67,18 +67,19 @@ class DataServiceWorkerImpl {
std::unique_ptr<standalone::Iterator> iterator;
} Task;
const std::string master_address_;
// Protocol for communicating with the master.
const std::string dispatcher_address_;
// Protocol for communicating with the dispatcher.
const std::string protocol_;
// The worker's own address.
std::string worker_address_;
mutex mu_;
int64 worker_id_ TF_GUARDED_BY(mu_);
std::unique_ptr<MasterService::Stub> master_stub_ TF_GUARDED_BY(mu_);
std::unique_ptr<DispatcherService::Stub> dispatcher_stub_ TF_GUARDED_BY(mu_);
// Information about tasks, keyed by task ids.
absl::flat_hash_map<int64, Task> tasks_ TF_GUARDED_BY(mu_);
// List of completed tasks which haven't yet been communicated to the master.
// List of completed tasks which haven't yet been communicated to the
// dispatcher.
std::vector<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
bool cancelled_ TF_GUARDED_BY(mu_) = false;
// Condition variable for notifying the heartbeat thread.

View File

@ -5864,15 +5864,15 @@ cc_library(
":string_format_op",
":string_join_op",
":string_length_op",
":string_lower_op",
# ":string_lower_op",
":string_ngrams_op",
":string_split_op",
":string_strip_op",
":string_to_hash_bucket_op",
":string_upper_op",
# ":string_upper_op",
":substr_op",
":unicode_ops",
":unicode_script_op",
# ":unicode_ops",
# ":unicode_script_op",
":unsorted_segment_join_op",
],
)
@ -5885,7 +5885,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@icu//:common",
# "@icu//:common",
],
)
@ -6041,7 +6041,7 @@ tf_kernel_library(
prefix = "string_lower_op",
deps = STRING_DEPS + [
"@com_google_absl//absl/strings",
"@icu//:common",
# "@icu//:common",
],
)
@ -6050,7 +6050,7 @@ tf_kernel_library(
prefix = "string_upper_op",
deps = STRING_DEPS + [
"@com_google_absl//absl/strings",
"@icu//:common",
# "@icu//:common",
],
)
@ -6096,7 +6096,7 @@ tf_kernel_library(
"//tensorflow/core:lib_internal",
"//third_party/eigen3",
"//third_party/icu/data:conversion_data",
"@icu//:common",
# "@icu//:common",
],
)
@ -7125,10 +7125,10 @@ filegroup(
"mutex_ops.*",
"batch_kernels.*",
"regex_replace_op.cc",
"string_lower_op.cc", # Requires ICU for unicode.
"string_upper_op.cc", # Requires ICU for unicode.
# "string_lower_op.cc", # Requires ICU for unicode.
# "string_upper_op.cc", # Requires ICU for unicode.
"unicode_ops.cc",
"unicode_script_op.cc",
# "unicode_script_op.cc",
# Ops that are inherently incompatible with Android (e.g. tied to x86 platform).
"mkl_*",
"xsmm_*",
@ -8620,7 +8620,7 @@ tf_kernel_library(
srcs = ["unicode_script_op.cc"],
deps = [
"//tensorflow/core:framework",
"@icu//:common",
# "@icu//:common",
],
)
@ -8652,6 +8652,39 @@ cc_library(
],
)
tf_kernel_library(
name = "deepspeech_cwise_ops",
srcs = [
"cwise_op_add_1.cc",
"cwise_op_add_2.cc",
"cwise_op_less.cc",
"cwise_op_minimum.cc",
"cwise_op_mul_1.cc",
"cwise_op_rsqrt.cc",
"cwise_op_squared_difference.cc",
"cwise_op_sub.cc",
"cwise_op_sigmoid.cc",
"cwise_op_tanh.cc",
],
gpu_srcs = [
"cwise_op_gpu_add.cu.cc",
"cwise_op_gpu_less.cu.cc",
"cwise_op_gpu_minimum.cu.cc",
"cwise_op_gpu_mul.cu.cc",
"cwise_op_gpu_rsqrt.cu.cc",
"cwise_op_gpu_squared_difference.cu.cc",
"cwise_op_gpu_sub.cu.cc",
"cwise_op_gpu_sigmoid.cu.cc",
"cwise_op_gpu_tanh.cu.cc",
],
deps = [
":cwise_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/eigen3",
],
)
# Header-only version of cwise_lib for clients that want to use the cwise_ops
# functionality in their own custom ops.
cc_header_only_library(
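The new `deepspeech_cwise_ops` target collects only the element-wise kernels a speech-model graph needs: add, sub, mul, minimum, less, rsqrt, squared_difference, and, per the later additions, sigmoid and tanh. For intuition, layer normalization decomposes into exactly these primitives (plus a mean reduction, which lives in a different kernel library); a rough TF Python sketch, with shapes and epsilon chosen for illustration:

```python
import tensorflow as tf

def layer_norm(x, eps=1e-6):
    # Mean is a reduction op, not part of the cwise bundle.
    mean = tf.reduce_mean(x, axis=-1, keepdims=True)
    # SquaredDifference is one of the bundled kernels.
    var = tf.reduce_mean(tf.math.squared_difference(x, mean),
                         axis=-1, keepdims=True)
    # Sub, Rsqrt, and Mul are bundled as well.
    return (x - mean) * tf.math.rsqrt(var + eps)

print(layer_norm(tf.random.normal([2, 8])).shape)  # (2, 8)
```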

View File

@ -116,6 +116,7 @@ REGISTER_KERNEL(GPU, int16);
REGISTER_KERNEL(GPU, qint16);
REGISTER_KERNEL(GPU, quint16);
REGISTER_KERNEL(GPU, uint32);
REGISTER_KERNEL(GPU, int32);
REGISTER_KERNEL(GPU, qint32);
REGISTER_KERNEL(GPU, int64);
REGISTER_KERNEL(GPU, uint64);

View File

@ -69,7 +69,7 @@ const int64 kDefaultTaskRefreshIntervalMs = 1000; // 1 second.
// Dataset for reading data from the tf.data service non-deterministically.
//
// This dataset interleaves dataset elements produced by multiple tf.data
// workers. We periodically query the tf.data master to determine which workers
// workers. We periodically query the dispatcher to determine which workers
// to read from (in case workers are added or removed).
class DataServiceDatasetOp::Dataset : public DatasetBase {
public:
@ -199,12 +199,13 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
Status Initialize(IteratorContext* ctx) override {
VLOG(3) << "Connecting to " << dataset()->address_
<< " in data service dataset op";
DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
DataServiceDispatcherClient dispatcher(dataset()->address_,
dataset()->protocol_);
if (dataset()->job_name_.empty()) {
TF_RETURN_IF_ERROR(master.CreateJob(
TF_RETURN_IF_ERROR(dispatcher.CreateJob(
dataset()->dataset_id_, dataset()->processing_mode_, &job_id_));
} else {
TF_RETURN_IF_ERROR(master.GetOrCreateJob(
TF_RETURN_IF_ERROR(dispatcher.GetOrCreateJob(
dataset()->dataset_id_, dataset()->processing_mode_,
dataset()->job_name_, iterator_index_, &job_id_));
}
@ -283,11 +284,12 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
// Periodically refresh the task list.
// Maintain one thread fetching elements for each task.
// TODO(aaudibert): Instead of polling, have master send updates when
// TODO(aaudibert): Instead of polling, have dispatcher send updates when
// the list of tasks changes.
void TaskThreadManager(std::unique_ptr<IteratorContext> ctx) {
VLOG(3) << "Starting task thread manager";
DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
DataServiceDispatcherClient dispatcher(dataset()->address_,
dataset()->protocol_);
uint64 next_check = Env::Default()->NowMicros();
while (true) {
{
@ -305,18 +307,19 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
return;
}
}
UpdateTasks(&master);
UpdateTasks(&dispatcher);
UpdateWorkerThreads(ctx.get());
next_check = Env::Default()->NowMicros() +
dataset()->task_refresh_interval_ms_ * 1000;
}
}
void UpdateTasks(DataServiceMasterClient* master) LOCKS_EXCLUDED(mu_) {
void UpdateTasks(DataServiceDispatcherClient* dispatcher)
LOCKS_EXCLUDED(mu_) {
VLOG(3) << "Updating tasks";
std::vector<TaskInfo> tasks;
bool job_finished;
Status s = master->GetTasks(job_id_, &tasks, &job_finished);
Status s = dispatcher->GetTasks(job_id_, &tasks, &job_finished);
if (!s.ok()) {
LOG(WARNING) << "Failed to get task info for job id " << job_id_ << ": "
<< s;

View File

@ -53,7 +53,7 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(
ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def));
DataServiceMasterClient client(address, protocol);
DataServiceDispatcherClient client(address, protocol);
int64 dataset_id;
OP_REQUIRES_OK(ctx, client.RegisterDataset(graph_def, &dataset_id));

View File

@ -25,7 +25,7 @@ namespace data {
// Registers a dataset with the tf.data service.
//
// The address and protocol inputs are used to connect to the tf.data master.
// The address and protocol inputs are used to connect to the dispatcher.
// The external state policy attribute determines whether to ignore, warn, or
// error out when the dataset contains external state.
// The op produces a dataset id for identifying the registered dataset.

View File

@ -26,7 +26,7 @@ limitations under the License.
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
// "-beta", "-rc", "-rc.1")
#define TF_VERSION_SUFFIX "-rc1"
#define TF_VERSION_SUFFIX ""
#define TF_STR_HELPER(x) #x
#define TF_STR(x) TF_STR_HELPER(x)

View File

@ -57,7 +57,6 @@ cc_library(
"//conditions:default": [],
}) + select({
"//tensorflow:fuchsia": [],
"//tensorflow:windows": [],
"//conditions:default": [
"//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
],

View File

@ -77,7 +77,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
amount of memory used, since `distribute` won't use more than
`element_size` * `max_outstanding_requests` of memory.
task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
the master for task changes.
the dispatcher for task changes.
"""
if job_name is None:
@ -173,7 +173,7 @@ def _distribute(processing_mode,
of memory used, since `distribute` won't use more than `element_size` *
`max_outstanding_requests` of memory.
task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
master for task changes.
dispatcher for task changes.
Returns:
Dataset: A `Dataset` of the elements produced by the data service.

View File

@ -19,5 +19,5 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops.data_service_ops import distribute
from tensorflow.python.data.experimental.service.server_lib import MasterServer
from tensorflow.python.data.experimental.service.server_lib import DispatchServer
from tensorflow.python.data.experimental.service.server_lib import WorkerServer

View File

@ -24,35 +24,35 @@ from tensorflow.python.data.experimental.service import _pywrap_server_lib
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.service.MasterServer", v1=[])
class MasterServer(object):
"""An in-process tf.data service master server.
@tf_export("data.experimental.service.DispatchServer", v1=[])
class DispatchServer(object):
"""An in-process tf.data service dispatch server.
A `tf.data.experimental.service.MasterServer` coordinates a cluster of
A `tf.data.experimental.service.DispatchServer` coordinates a cluster of
`tf.data.experimental.service.WorkerServer`s. When the workers start, they
register themselves with the master.
register themselves with the dispatcher.
>>> master = tf.data.experimental.service.MasterServer(port=0)
>>> master_address = master.target.split("://")[1]
>>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
>>> dispatcher_address = dispatcher.target.split("://")[1]
>>> worker = tf.data.experimental.service.WorkerServer(
... port=0, master_address=master_address)
... port=0, dispatcher_address=dispatcher_address)
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=master.target))
... processing_mode="parallel_epochs", service=dispatcher.target))
>>> print(list(dataset.as_numpy_iterator()))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
When starting a dedicated tf.data master process, use join() to block
When starting a dedicated tf.data dispatch process, use join() to block
indefinitely after starting up the server.
```
master = tf.data.experimental.service.MasterServer(port=5050)
master.join()
dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
dispatcher.join()
```
"""
def __init__(self, port, protocol=None, start=True):
"""Creates a new master server.
"""Creates a new dispatch server.
Args:
port: Specifies the port to bind to.
@ -68,15 +68,16 @@ class MasterServer(object):
if protocol is None:
protocol = "grpc"
self._protocol = protocol
self._server = _pywrap_server_lib.TF_DATA_NewMasterServer(port, protocol)
self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(port, protocol)
if start:
self._server.start()
def start(self):
"""Starts this server.
>>> master = tf.data.experimental.service.MasterServer(port=0, start=False)
>>> master.start()
>>> dispatcher = tf.data.experimental.service.DispatchServer(port=0,
... start=False)
>>> dispatcher.start()
Raises:
tf.errors.OpError: Or one of its subclasses if an error occurs while
@ -87,11 +88,11 @@ class MasterServer(object):
def join(self):
"""Blocks until the server has shut down.
This is useful when starting a dedicated master process.
This is useful when starting a dedicated dispatch process.
```
master = tf.data.experimental.service.MasterServer(port=5050)
master.join()
dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
dispatcher.join()
```
Raises:
@ -104,10 +105,10 @@ class MasterServer(object):
def target(self):
"""Returns a target that can be used to connect to the server.
>>> master = tf.data.experimental.service.MasterServer(port=0)
>>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=master.target))
... processing_mode="parallel_epochs", service=dispatcher.target))
The returned string will be in the form protocol://address, e.g.
"grpc://localhost:5050".
@ -136,7 +137,7 @@ class MasterServer(object):
return "localhost:{0}".format(self._server.bound_port())
def _num_workers(self):
"""Returns the number of workers registered with the master."""
"""Returns the number of workers registered with the dispatcher."""
return self._server.num_workers()
@ -147,15 +148,15 @@ class WorkerServer(object):
A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset`
processing for user-defined datasets, and provides the resulting elements over
RPC. A worker is associated with a single
`tf.data.experimental.service.MasterServer`.
`tf.data.experimental.service.DispatchServer`.
>>> master = tf.data.experimental.service.MasterServer(port=0)
>>> master_address = master.target.split("://")[1]
>>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
>>> dispatcher_address = dispatcher.target.split("://")[1]
>>> worker = tf.data.experimental.service.WorkerServer(
... port=0, master_address=master_address)
... port=0, dispatcher_address=dispatcher_address)
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=master.target))
... processing_mode="parallel_epochs", service=dispatcher.target))
>>> print(list(dataset.as_numpy_iterator()))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
@ -164,14 +165,14 @@ class WorkerServer(object):
```
worker = tf.data.experimental.service.WorkerServer(
port=5051, master_address="grpc://localhost:5050")
port=5051, dispatcher_address="grpc://localhost:5050")
worker.join()
```
"""
def __init__(self,
port,
master_address,
dispatcher_address,
worker_address=None,
protocol=None,
start=True):
@ -180,11 +181,12 @@ class WorkerServer(object):
Args:
port: Specifies the port to bind to. A value of 0 indicates that the
worker can bind to any available port.
master_address: Specifies the address of the master server.
dispatcher_address: Specifies the address of the dispatcher.
worker_address: (Optional.) Specifies the address of the worker server.
This address is passed to the master server so that the master can tell
clients how to connect to this worker. Defaults to `"localhost:%port%"`,
where `%port%` will be replaced with the port used by the worker.
This address is passed to the dispatcher so that the dispatcher can
tell clients how to connect to this worker. Defaults to
`"localhost:%port%"`, where `%port%` will be replaced with the port used
by the worker.
protocol: (Optional.) Specifies the protocol to be used by the server.
Acceptable values include `"grpc", "grpc+local"`. Defaults to `"grpc"`.
start: (Optional.) Boolean, indicating whether to start the server after
@ -201,7 +203,7 @@ class WorkerServer(object):
self._protocol = protocol
self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
port, protocol, master_address, worker_address)
port, protocol, dispatcher_address, worker_address)
if start:
self._server.start()
@ -221,7 +223,7 @@ class WorkerServer(object):
```
worker_server = tf.data.experimental.service.WorkerServer(
port=5051, master_address="grpc://localhost:5050")
port=5051, dispatcher_address="grpc://localhost:5050")
worker_server.join()
```
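For a dedicated deployment, the two docstring fragments above combine into one pattern: start the server for the process's role, then `join()` to block. A sketch (the role switch and addresses are illustrative):

```python
import sys
import tensorflow as tf

# Run `python serve.py dispatcher` on the coordinator and
# `python serve.py worker` on each worker machine.
if sys.argv[1] == "dispatcher":
    server = tf.data.experimental.service.DispatchServer(port=5050)
else:
    server = tf.data.experimental.service.WorkerServer(
        port=5051, dispatcher_address="grpc://localhost:5050")
server.join()  # blocks indefinitely, as the docstrings recommend
```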

View File

@ -25,68 +25,68 @@ from tensorflow.python.platform import test
class ServerLibTest(test.TestCase):
def testStartMaster(self):
master = server_lib.MasterServer(0, start=False)
master.start()
def testStartDispatcher(self):
dispatcher = server_lib.DispatchServer(0, start=False)
dispatcher.start()
def testMultipleStartMaster(self):
master = server_lib.MasterServer(0, start=True)
master.start()
def testMultipleStartDispatcher(self):
dispatcher = server_lib.DispatchServer(0, start=True)
dispatcher.start()
def testStartWorker(self):
master = server_lib.MasterServer(0)
worker = server_lib.WorkerServer(0, master._address, start=False)
dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, dispatcher._address, start=False)
worker.start()
def testMultipleStartWorker(self):
master = server_lib.MasterServer(0)
worker = server_lib.WorkerServer(0, master._address, start=True)
dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, dispatcher._address, start=True)
worker.start()
def testStopMaster(self):
master = server_lib.MasterServer(0)
master._stop()
master._stop()
def testStopDispatcher(self):
dispatcher = server_lib.DispatchServer(0)
dispatcher._stop()
dispatcher._stop()
def testStopWorker(self):
master = server_lib.MasterServer(0)
worker = server_lib.WorkerServer(0, master._address)
dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, dispatcher._address)
worker._stop()
worker._stop()
def testStopStartMaster(self):
master = server_lib.MasterServer(0)
master._stop()
def testStopStartDispatcher(self):
dispatcher = server_lib.DispatchServer(0)
dispatcher._stop()
with self.assertRaisesRegex(
RuntimeError, "Server cannot be started after it has been stopped"):
master.start()
dispatcher.start()
def testStopStartWorker(self):
master = server_lib.MasterServer(0)
worker = server_lib.WorkerServer(0, master._address)
dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, dispatcher._address)
worker._stop()
with self.assertRaisesRegex(
RuntimeError, "Server cannot be started after it has been stopped"):
worker.start()
def testJoinMaster(self):
master = server_lib.MasterServer(0)
master._stop()
master.join()
def testJoinDispatcher(self):
dispatcher = server_lib.DispatchServer(0)
dispatcher._stop()
dispatcher.join()
def testJoinWorker(self):
master = server_lib.MasterServer(0)
worker = server_lib.WorkerServer(0, master._address)
dispatcher = server_lib.DispatchServer(0)
worker = server_lib.WorkerServer(0, dispatcher._address)
worker._stop()
worker.join()
def testMasterNumWorkers(self):
master = server_lib.MasterServer(0)
self.assertEqual(0, master._num_workers())
worker1 = server_lib.WorkerServer(0, master._address) # pylint: disable=unused-variable
self.assertEqual(1, master._num_workers())
worker2 = server_lib.WorkerServer(0, master._address) # pylint: disable=unused-variable
self.assertEqual(2, master._num_workers())
def testDispatcherNumWorkers(self):
dispatcher = server_lib.DispatchServer(0)
self.assertEqual(0, dispatcher._num_workers())
worker1 = server_lib.WorkerServer(0, dispatcher._address) # pylint: disable=unused-variable
self.assertEqual(1, dispatcher._num_workers())
worker2 = server_lib.WorkerServer(0, dispatcher._address) # pylint: disable=unused-variable
self.assertEqual(2, dispatcher._num_workers())
if __name__ == "__main__":

View File

@ -28,13 +28,14 @@ limitations under the License.
namespace py = pybind11;
PYBIND11_MODULE(_pywrap_server_lib, m) {
py::class_<tensorflow::data::MasterGrpcDataServer>(m, "MasterGrpcDataServer")
.def("start", &tensorflow::data::MasterGrpcDataServer::Start)
.def("stop", &tensorflow::data::MasterGrpcDataServer::Stop)
.def("join", &tensorflow::data::MasterGrpcDataServer::Join)
.def("bound_port", &tensorflow::data::MasterGrpcDataServer::BoundPort)
py::class_<tensorflow::data::DispatchGrpcDataServer>(m,
"DispatchGrpcDataServer")
.def("start", &tensorflow::data::DispatchGrpcDataServer::Start)
.def("stop", &tensorflow::data::DispatchGrpcDataServer::Stop)
.def("join", &tensorflow::data::DispatchGrpcDataServer::Join)
.def("bound_port", &tensorflow::data::DispatchGrpcDataServer::BoundPort)
.def("num_workers",
[](tensorflow::data::MasterGrpcDataServer* server) -> int {
[](tensorflow::data::DispatchGrpcDataServer* server) -> int {
int num_workers;
tensorflow::Status status = server->NumWorkers(&num_workers);
tensorflow::MaybeRaiseFromStatus(status);
@ -48,12 +49,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
.def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort);
m.def(
"TF_DATA_NewMasterServer",
"TF_DATA_NewDispatchServer",
[](int port, std::string protocol)
-> std::unique_ptr<tensorflow::data::MasterGrpcDataServer> {
std::unique_ptr<tensorflow::data::MasterGrpcDataServer> server;
-> std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> {
std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> server;
tensorflow::Status status =
tensorflow::data::NewMasterServer(port, protocol, &server);
tensorflow::data::NewDispatchServer(port, protocol, &server);
tensorflow::MaybeRaiseFromStatus(status);
return server;
},
@ -61,12 +62,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
m.def(
"TF_DATA_NewWorkerServer",
[](int port, std::string protocol, std::string master_address,
[](int port, std::string protocol, std::string dispatcher_address,
std::string worker_address)
-> std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> {
std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server;
tensorflow::Status status = tensorflow::data::NewWorkerServer(
port, protocol, master_address, worker_address, &server);
port, protocol, dispatcher_address, worker_address, &server);
tensorflow::MaybeRaiseFromStatus(status);
return server;
},

View File

@ -59,23 +59,25 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
num_workers: The number of workers in the cluster.
Returns:
The address of the master.
The address of the dispatcher.
"""
self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
self._servers = []
for _ in range(num_workers):
self._servers.append(
server_lib.WorkerServer(
port=0, master_address=self._master._address, protocol=PROTOCOL))
port=0,
dispatcher_address=self._dispatcher._address,
protocol=PROTOCOL))
return self._master._address
return self._dispatcher._address
@combinations.generate(test_base.eager_only_combinations())
def testDistributeBasic(self):
num_elements = 10
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
results = [elem.numpy() for elem in ds]
self.assertEqual(list(range(num_elements)), results)
@ -83,10 +85,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testDifferentShuffleOrders(self):
random_seed.set_random_seed(None)
num_elements = 100
master_address = self.create_cluster(2)
dispatcher_address = self.create_cluster(2)
ds = dataset_ops.Dataset.range(num_elements)
ds = ds.shuffle(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
output = [elem.numpy() for elem in ds]
# The output will be two sequences of range(num_elements)
@ -104,9 +106,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testMultipleEpochs(self):
num_elements = 3
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
for _ in range(10):
self.assertEqual(list(range(num_elements)), [elem.numpy() for elem in ds])
@ -114,9 +116,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testRepeatedDataset(self):
num_elements = 10
num_repetitions = 5
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
ds = ds.repeat(num_repetitions)
self.assertDatasetProduces(
ds, expected_output=num_repetitions * list(range(num_elements)))
@ -125,12 +127,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testConcurrentEpoch(self):
num_elements = 10
num_datasets = 3
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
iterators = []
results = []
for _ in range(num_datasets):
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
iterators.append(iter(ds))
results.append([])
@ -146,9 +148,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
self.skipTest("Not yet implemented")
num_elements = 10
num_iterators = 3
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
result = []
iterators = []
for _ in range(num_iterators):
@ -170,20 +172,20 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testMultiWorker(self):
num_workers = 3
num_elements = 10
master_address = self.create_cluster(num_workers)
dispatcher_address = self.create_cluster(num_workers)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
results = [elem.numpy() for elem in ds]
self.assertCountEqual(num_workers * list(range(num_elements)), results)
@combinations.generate(test_base.eager_only_combinations())
def testAddWorkerMidJob(self):
self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
self._worker = server_lib.WorkerServer(
port=0, master_address=self._master._address, protocol=PROTOCOL)
port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
num_elements = 100
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, self._master._address)
ds = _make_distributed_dataset(ds, self._dispatcher._address)
iterator = iter(ds)
results = []
# Read halfway through the dataset.
@ -191,10 +193,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
results.append(next(iterator).numpy())
self._new_worker = server_lib.WorkerServer(
port=0, master_address=self._master._address, protocol=PROTOCOL)
port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
# Wait for the new worker to register with the master.
while self._master._num_workers() < 2:
# Wait for the new worker to register with the dispatcher.
while self._dispatcher._num_workers() < 2:
time.sleep(10 / 1000) # 10ms
for elem in iterator:
@ -206,12 +208,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
combinations.times(test_base.eager_only_combinations(),
combinations.combine(use_same_port=[True, False])))
def testRestartWorker(self, use_same_port):
self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
self._worker = server_lib.WorkerServer(
port=0, master_address=self._master._address, protocol=PROTOCOL)
port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
num_elements = 100
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, self._master._address)
ds = _make_distributed_dataset(ds, self._dispatcher._address)
iterator = iter(ds)
# Read halfway through the dataset.
midpoint = num_elements // 2
@ -224,7 +226,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
port = int(self._worker._address.split(":")[1])
self._worker._stop()
self._new_worker = server_lib.WorkerServer(
port=port, master_address=self._master._address, protocol=PROTOCOL)
port=port,
dispatcher_address=self._dispatcher._address,
protocol=PROTOCOL)
# There may have been some elements prefetched from the first worker
# before it was stopped.
@ -259,12 +263,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testInsideFunction(self):
num_workers = 3
num_elements = 10
master_address = self.create_cluster(num_workers)
dispatcher_address = self.create_cluster(num_workers)
@def_function.function
def f():
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
result = tensor_array_ops.TensorArray(
dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
i = 0
@ -279,10 +283,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testSharedJobName(self):
num_elements = 100
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
iter1 = iter(ds1)
iter2 = iter(ds2)
results = []
@ -298,20 +302,22 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testDifferentJobNames(self):
num_elements = 10
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name1")
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name2")
ds1 = _make_distributed_dataset(
ds, dispatcher_address, job_name="job_name1")
ds2 = _make_distributed_dataset(
ds, dispatcher_address, job_name="job_name2")
self.assertDatasetProduces(ds1, list(range(num_elements)))
self.assertDatasetProduces(ds2, list(range(num_elements)))
@combinations.generate(test_base.eager_only_combinations())
def testSharedJobNameMultiIteration(self):
num_elements = 10
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
# iteration 1
self.assertDatasetProduces(ds1, list(range(num_elements)))
self.assertDatasetProduces(ds2, [])
@ -323,11 +329,11 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testSharedJobNameRepeat(self):
num_elements = 100
num_repetitions = 3
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
ds1 = ds1.repeat(num_repetitions)
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
ds2 = ds2.repeat(num_repetitions)
results = []
iter1 = iter(ds1)
@ -345,7 +351,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testApplyDeterminismOption(self):
elements = list(range(10))
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
def dataset_fn(delay_ms):
@ -362,7 +368,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
opts = dataset_ops.Options()
opts.experimental_deterministic = False
ds = ds.with_options(opts)
ds = _make_distributed_dataset(ds, master_address)
ds = _make_distributed_dataset(ds, dispatcher_address)
return ds
self.checkDeterminism(
@ -379,8 +385,8 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
options.experimental_external_state_policy = external_state_policy
ds = ds.with_options(options)
master_address = self.create_cluster(3)
ds = _make_distributed_dataset(ds, master_address)
dispatcher_address = self.create_cluster(3)
ds = _make_distributed_dataset(ds, dispatcher_address)
next(iter(ds))
@combinations.generate(
@ -400,12 +406,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testDistributeFromInterleave(self):
master_address = self.create_cluster(1)
dispatcher_address = self.create_cluster(1)
ds = dataset_ops.Dataset.range(2)
def interleave_fn(_):
ds = dataset_ops.Dataset.range(2)
_make_distributed_dataset(ds, master_address)
_make_distributed_dataset(ds, dispatcher_address)
return ds
with self.assertRaisesRegex(

View File

@ -4690,7 +4690,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
labels=target, logits=output, axis=axis)
if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
output.op.type == 'Softmax'):
output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
# When softmax activation function is used for output operation, we
# use logits from the softmax function directly to compute loss in order
# to prevent collapsing zero when training.
@ -4735,7 +4735,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
if (not from_logits and
not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
output.op.type == 'Softmax'):
output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
# When softmax activation function is used for output operation, we
# use logits from the softmax function directly to compute loss in order
# to prevent collapsing zero when training.
@ -4814,7 +4814,7 @@ def binary_crossentropy(target, output, from_logits=False):
return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
output.op.type == 'Sigmoid'):
output.op.type == 'Sigmoid') and not hasattr(output, '_keras_history'):
# When sigmoid activation function is used for output operation, we
# use logits from the sigmoid function directly to compute loss in order
# to prevent collapsing zero when training.
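All three backend losses gain the same extra condition: the optimization that reaches into `output.op` to recover pre-activation logits is now skipped when the output carries `_keras_history`, i.e. when it is a symbolic tensor tracked by a functional model. A short sketch of what the guard distinguishes (not the TF implementation itself):

```python
import tensorflow as tf

inp = tf.keras.Input(shape=(3,))
out = tf.keras.layers.Activation("softmax")(inp)

# Functional-model outputs are tracked by Keras, so the rewrite is skipped:
print(hasattr(out, "_keras_history"))                             # True
# A plain tensor has no history, so the logits rewrite may still apply:
print(hasattr(tf.constant([[1.0, 2.0, 3.0]]), "_keras_history"))  # False
```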

View File

@ -921,7 +921,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# >> inputs = tf.keras.Input(10)
# >> outputs = MyLayer()(inputs) # Functional construction mode.
# >> model = tf.keras.Model(inputs, outputs)
if _in_functional_construction_mode(inputs, args, kwargs, input_list):
if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
return self._functional_construction_call(inputs, args, kwargs,
input_list)
@ -3205,7 +3205,7 @@ class AddMetric(Layer):
return config
def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylint: disable=unused-argument
def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list): # pylint: disable=unused-argument
"""Check the arguments to see if we are constructing a functional model."""
if keras_tensor.keras_tensors_enabled():
# We are constructing a functional model if any of the inputs
@ -3215,7 +3215,20 @@ def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylin
for tensor in nest.flatten([inputs, args, kwargs]))
else:
if context.executing_eagerly():
return all(tf_utils.is_symbolic_tensor(t) for t in input_list)
all_inputs_symbolic = all(
tf_utils.is_symbolic_tensor(t) for t in input_list)
if (base_layer_utils.is_subclassed(layer) and
any(tf_utils.is_symbolic_tensor(t) for t in nest.flatten(
[inputs, args, kwargs])) and not all_inputs_symbolic):
raise ValueError('It appears you are trying to construct a '
'functional model, but not all of the inputs in '
'the first positional argument of your layer call '
'are symbolic tensors. '
'(Input objects, or the output of another layer) '
'Functional models cannot correctly track custom '
'layers unless all values in the first call argument '
'are symbolic.')
return all_inputs_symbolic
else:
return (base_layer_utils.is_in_keras_graph() or
all(hasattr(t, '_keras_history') for t in input_list))
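With the layer itself now passed in, the eager path can reject a subclassed layer whose first call argument mixes symbolic and concrete values, instead of silently mistracking it. The first new test below exercises exactly this; a standalone repro consistent with that test (eager mode, KerasTensors disabled) would look like:

```python
import tensorflow as tf

class MyAddAll(tf.keras.layers.Layer):
    def call(self, inputs):
        x = inputs[0]
        for inp in inputs[1:]:
            if inp is not None:
                x = x + inp
        return x

input1 = tf.keras.Input(10)
input2 = tf.keras.Input(10)
# A float, two symbolic Inputs, and Nones in the first argument:
# under this change the call raises
# ValueError: ... trying to construct a functional model ...
MyAddAll()([0.0, input1, None, input2, None])
```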

View File

@ -252,6 +252,9 @@ class Layer(base_layer.Layer):
# might want to turn it off, like Sequential model.
self._auto_track_sub_layers = True
# Mark this layer as having been originally built as a tf1 layer/model
self._originally_built_as_v1 = True
@trackable.no_automatic_dependency_tracking
@generic_utils.default
def build(self, input_shape):
@ -651,6 +654,8 @@ class Layer(base_layer.Layer):
ValueError: if the layer's `call` method returns None (an invalid value).
RuntimeError: if `super().__init__()` was not called in the constructor.
"""
self._assert_built_as_v1()
if not hasattr(self, '_thread_local'):
raise RuntimeError(
'You must call `super().__init__()` in the layer constructor.')
@ -818,6 +823,20 @@ class Layer(base_layer.Layer):
return outputs
def _assert_built_as_v1(self):
if not hasattr(self, '_originally_built_as_v1'):
raise ValueError(
'Your Layer or Model is in an invalid state. This can happen if you '
'are interleaving estimator/non-estimator models or '
'interleaving models/layers made in tf.compat.v1.Graph.as_default() '
'with models/layers created outside of it. '
'Converting a model to an estimator (via model_to_estimator) '
'invalidates all models/layers made before the conversion (even '
'if they were not the model converted to an estimator). '
'Similarly, making a layer or a model inside a '
'tf.compat.v1.Graph invalidates all layers/models you previously '
'made outside of the graph.')
@property
def dtype(self):
return self._dtype_policy.variable_dtype
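An illustrative sketch (an assumption for clarity, not from the diff) of the interleaving the message above describes: creating a layer under a v1 Graph and then reusing it outside that graph is the kind of state this guard now reports early.

    import tensorflow as tf

    with tf.compat.v1.Graph().as_default():
        layer = tf.compat.v1.keras.layers.Dense(4)  # built against the v1 graph

    # Reusing `layer` afterwards (in eager mode, in another graph, or after
    # model_to_estimator has run) is the invalid interleaving that
    # _assert_built_as_v1 turns into the explicit ValueError above.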

View File

@ -34,6 +34,7 @@ from tensorflow.python.keras import combinations
from tensorflow.python.keras import initializers
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import layers
from tensorflow.python.keras import losses
from tensorflow.python.keras import models
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import base_layer
@ -931,6 +932,72 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
@combinations.generate(combinations.times(
combinations.keras_mode_combinations(mode='eager'),
combinations.combine(use_keras_tensors=False)))
def test_only_some_in_first_arg_derived_from_keras_layer(self):
class MyAddAll(layers.Layer):
def call(self, inputs):
x = inputs[0]
for inp in inputs[1:]:
if inp is not None:
x = x + inp
return x
input1 = input_layer_lib.Input(10)
input2 = input_layer_lib.Input(10)
layer = MyAddAll()
with self.assertRaisesRegexp(ValueError, 'construct a functional'):
layer([0.0, input1, None, input2, None])
@combinations.generate(combinations.times(
combinations.keras_mode_combinations(mode='eager'),
combinations.combine(use_keras_tensors=True)))
def test_only_some_in_first_arg_derived_from_keras_layer_keras_tensors(self):
# This functionality is unsupported in v1 graphs
class MyAddAll(layers.Layer):
def call(self, inputs):
x = inputs[0]
for inp in inputs[1:]:
if inp is not None:
x = x + inp
return x
input1 = input_layer_lib.Input(10)
input2 = input_layer_lib.Input(10)
layer = MyAddAll()
outputs = layer([0.0, input1, None, input2, None])
model = training_lib.Model([input1, input2], outputs)
self.assertIn(layer, model.layers)
model.compile(
'sgd',
'mse',
run_eagerly=testing_utils.should_run_eagerly())
history = model.fit(
x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
y=10 * np.ones((10, 10)),
batch_size=2)
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
# Check serialization.
model = training_lib.Model.from_config(
model.get_config(), custom_objects={'MyAddAll': MyAddAll})
model.compile(
'sgd',
'mse',
run_eagerly=testing_utils.should_run_eagerly())
history = model.fit(
x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
y=10 * np.ones((10, 10)),
batch_size=2)
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
@combinations.generate(combinations.keras_mode_combinations())
def test_call_kwarg_derived_from_keras_layer(self):
@ -1069,7 +1136,8 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
input2 = input_layer_lib.Input(10)
input3 = input_layer_lib.Input(10)
outputs = AddAll()(
layer = AddAll()
outputs = layer(
[input1, 4 * array_ops.ones((1, 10))],
x3={
'a': input2,
@ -1077,6 +1145,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
'c': 5 * array_ops.ones((1, 10))
})
model = training_lib.Model([input1, input2, input3], outputs)
self.assertIn(layer, model.layers)
model.compile(
'sgd',
'mse',
@ -1833,6 +1902,37 @@ class AddLossTest(keras_parameterized.TestCase):
self.assertAllClose(model.get_weights(), model2.get_weights())
def test_add_loss_crossentropy_backtracking(self):
inputs = input_layer_lib.Input((2,))
labels = input_layer_lib.Input((1,))
outputs = layers.Dense(1, activation='sigmoid')(inputs)
model = functional.Functional([inputs, labels], outputs)
model.add_loss(losses.binary_crossentropy(labels, outputs))
model.compile('adam')
x = np.random.random((2, 2))
y = np.random.random((2, 1))
model.fit([x, y])
inputs = input_layer_lib.Input((2,))
labels = input_layer_lib.Input((2,))
outputs = layers.Dense(2, activation='softmax')(inputs)
model = functional.Functional([inputs, labels], outputs)
model.add_loss(losses.categorical_crossentropy(labels, outputs))
model.compile('adam')
x = np.random.random((2, 2))
y = np.random.random((2, 2))
model.fit([x, y])
inputs = input_layer_lib.Input((2,))
labels = input_layer_lib.Input((1,), dtype='int32')
outputs = layers.Dense(2, activation='softmax')(inputs)
model = functional.Functional([inputs, labels], outputs)
model.add_loss(losses.sparse_categorical_crossentropy(labels, outputs))
model.compile('adam')
x = np.random.random((2, 2))
y = np.random.randint(0, 2, size=(2, 1))
model.fit([x, y])
@combinations.generate(combinations.keras_mode_combinations())
class WeightAccessTest(keras_parameterized.TestCase):

View File

@ -303,6 +303,7 @@ class Model(training_lib.Model):
ValueError: In case of invalid arguments for
`optimizer`, `loss`, `metrics` or `sample_weight_mode`.
"""
self._assert_built_as_v1()
self._run_eagerly = kwargs.pop('run_eagerly', None)
self._experimental_run_tf_function = kwargs.pop(
'experimental_run_tf_function', True)
@ -773,6 +774,7 @@ class Model(training_lib.Model):
ValueError: In case of mismatch between the provided input data
and what the model expects.
"""
self._assert_built_as_v1()
_keras_api_gauge.get_cell('fit_v1').set(True)
# Legacy support
if 'nb_epoch' in kwargs:
@ -893,6 +895,7 @@ class Model(training_lib.Model):
Raises:
ValueError: in case of invalid arguments.
"""
self._assert_built_as_v1()
_keras_api_gauge.get_cell('evaluate_v1').set(True)
self._assert_compile_was_called()
self._check_call_args('evaluate')
@ -972,6 +975,7 @@ class Model(training_lib.Model):
or in case a stateful model receives a number of samples
that is not a multiple of the batch size.
"""
self._assert_built_as_v1()
_keras_api_gauge.get_cell('predict_v1').set(True)
self._check_call_args('predict')

View File

@ -132,8 +132,7 @@ class Embedding(Layer):
# right now. Checking for the presence of GPUs to avoid complicating the
# TPU codepaths which can handle sparse optimizers. But if we are within
# a tf.function, we go back the graph mode logic and rely on the placer.
if (context.executing_eagerly() and context.context().num_gpus() and
not ops.inside_function()):
if context.executing_eagerly() and context.context().num_gpus():
with ops.device('cpu:0'):
self.embeddings = self.add_weight(
shape=(self.input_dim, self.output_dim),
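A sketch of the placement rule this hunk leaves in place (illustrative only, using the same internal modules as the diff): with eager execution and a visible GPU, the embedding table is pinned to the CPU, since GPUs lack support for the sparse optimizer updates mentioned in the comment above, and the tf.function carve-out is removed so the rule applies uniformly.

    import tensorflow as tf
    from tensorflow.python.eager import context
    from tensorflow.python.framework import ops

    def make_embedding_table(input_dim, output_dim):
        # Pin the table to the CPU whenever a GPU is visible in eager mode.
        if context.executing_eagerly() and context.context().num_gpus():
            with ops.device('cpu:0'):
                return tf.Variable(tf.random.uniform([input_dim, output_dim]))
        return tf.Variable(tf.random.uniform([input_dim, output_dim]))

    table = make_embedding_table(1000, 64)
    print(table.device)  # ends in device:CPU:0 whenever a GPU is visible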

View File

@ -293,9 +293,10 @@ class CheckpointPosition(object):
checkpoint_key = serialized_tensor.checkpoint_key
dtype = self._checkpoint.dtype_map[checkpoint_key]
base_type = dtype.base_dtype
io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
with ops.init_scope():
with ops.device("/cpu:0"):
# Run the restore itself on the CPU.
with ops.device(io_device):
# Run the restore itself on the io_device (CPU by default, or as specified).
value, = io_ops.restore_v2(
prefix=self._checkpoint.save_path_tensor,
tensor_names=[checkpoint_key],
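The user-facing side of this change, sketched with the TF 2.3 checkpoint API: tf.train.CheckpointOptions(experimental_io_device=...) routes the restore ops to the chosen device instead of the previously hard-coded "/cpu:0". The path below is illustrative.

    import tensorflow as tf

    ckpt = tf.train.Checkpoint(v=tf.Variable(1.0))
    opts = tf.train.CheckpointOptions(experimental_io_device='/job:localhost')
    path = ckpt.save('/tmp/ckpt', options=opts)   # save I/O on io_device
    ckpt.restore(path, options=opts)              # restore I/O on io_device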

View File

@ -1,6 +1,6 @@
path: "tensorflow.data.experimental.service.MasterServer"
path: "tensorflow.data.experimental.service.DispatchServer"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.MasterServer\'>"
is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.DispatchServer\'>"
is_instance: "<type \'object\'>"
member {
name: "target"

View File

@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'port\', \'master_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
argspec: "args=[\'self\', \'port\', \'dispatcher_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
}
member_method {
name: "join"

View File

@ -1,7 +1,7 @@
path: "tensorflow.data.experimental.service"
tf_module {
member {
name: "MasterServer"
name: "DispatchServer"
mtype: "<type \'type\'>"
}
member {

View File

@ -99,8 +99,8 @@ tensorflow::data::GrpcDataServerBase::Join
tensorflow::data::GrpcDataServerBase::Start
tensorflow::data::GrpcDataServerBase::Stop
tensorflow::data::GrpcDataServerBase::BoundPort
tensorflow::data::MasterGrpcDataServer::NumWorkers
tensorflow::data::NewMasterServer
tensorflow::data::DispatchGrpcDataServer::NumWorkers
tensorflow::data::NewDispatchServer
tensorflow::data::NewWorkerServer
[protos_all] # device_lib, dtypes
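A usage sketch of the renamed tf.data service API (mirroring the TF 2.3 documentation example): the "master" terminology becomes "dispatcher" throughout, in the Python classes, the exported C++ symbols, and the WorkerServer constructor argument.

    import tensorflow as tf

    dispatcher = tf.data.experimental.service.DispatchServer(port=0)
    dispatcher_address = dispatcher.target.split('://')[1]
    worker = tf.data.experimental.service.WorkerServer(
        port=0, dispatcher_address=dispatcher_address)

    dataset = tf.data.Dataset.range(5)
    dataset = dataset.apply(tf.data.experimental.service.distribute(
        processing_mode='parallel_epochs', service=dispatcher.target))
    print(list(dataset.as_numpy_iterator()))  # values 0..4, served remotely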

View File

@ -49,7 +49,7 @@ from setuptools.dist import Distribution
# result for pip.
# Also update tensorflow/tensorflow.bzl and
# tensorflow/core/public/version.h
_VERSION = '2.3.0-rc1'
_VERSION = '2.3.0'
REQUIRED_PACKAGES = [
'absl-py >= 0.7.0',
@ -63,8 +63,8 @@ REQUIRED_PACKAGES = [
'numpy >= 1.16.0, < 1.19.0',
'opt_einsum >= 2.3.2',
'protobuf >= 3.9.2',
'tensorboard >= 2.2.0, < 2.3.0',
'tensorflow_estimator >= 2.3.0-rc0, < 2.4.0',
'tensorboard >= 2.3.0, < 3',
'tensorflow_estimator >= 2.3.0, < 2.4.0',
'termcolor >= 1.1.0',
'wrapt >= 1.11.1',
'wheel >= 0.26',

View File

@ -292,6 +292,26 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
],
)
tf_http_archive(
name = "LinaroArmGcc72",
build_file = clean_dep("//third_party/toolchains/embedded/linaro-gcc72-armeabi:linaro-gcc72-armeabi.BUILD"),
strip_prefix = "gcc-linaro-7.2.1-2017.11-x86_64_arm-linux-gnueabihf/",
urls = [
"https://releases.linaro.org/components/toolchain/binaries/7.2-2017.11/arm-linux-gnueabihf/gcc-linaro-7.2.1-2017.11-x86_64_arm-linux-gnueabihf.tar.xz",
],
sha256 = "cee0087b1f1205b73996651b99acd3a926d136e71047048f1758ffcec69b1ca2",
)
tf_http_archive(
name = "LinaroAarch64Gcc72",
build_file = clean_dep("//third_party/toolchains/embedded/linaro-gcc72-aarch64:linaro-gcc72-aarch64.BUILD"),
strip_prefix = "gcc-linaro-7.2.1-2017.11-x86_64_aarch64-linux-gnu/",
urls = [
"https://releases.linaro.org/components/toolchain/binaries/7.2-2017.11/aarch64-linux-gnu/gcc-linaro-7.2.1-2017.11-x86_64_aarch64-linux-gnu.tar.xz",
],
sha256 = "20181f828e1075f1a493947ff91e82dd578ce9f8638fbdfc39e24b62857d8f8d",
)
tf_http_archive(
name = "libxsmm_archive",
build_file = clean_dep("//third_party:libxsmm.BUILD"),

View File

@ -9,7 +9,7 @@ def repo():
third_party_http_archive(
name = "aws",
urls = [
"https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.7.336.tar.gz",
"https://mirror.tensorflow.orgg/github.com/aws/aws-sdk-cpp/archive/1.7.336.tar.gz",
"https://github.com/aws/aws-sdk-cpp/archive/1.7.336.tar.gz",
],
sha256 = "758174f9788fed6cc1e266bcecb20bf738bd5ef1c3d646131c9ed15c2d6c5720",

View File

@ -1,3 +1,14 @@
--- ./absl/time/internal/cctz/include/cctz/civil_time_detail.h 2020-08-06 01:33:56.005757145 +0200
+++ ./absl/time/internal/cctz/include/cctz/civil_time_detail.h 2020-08-06 01:33:35.460579387 +0200
@@ -23,7 +23,7 @@
#include "absl/base/config.h"
// Disable constexpr support unless we are in C++14 mode.
-#if __cpp_constexpr >= 201304 || (defined(_MSC_VER) && _MSC_VER >= 1910)
+#if (!defined(NO_CONSTEXPR_FOR_YOU) && __cpp_constexpr >= 201304) || (defined(_MSC_VER) && _MSC_VER >= 1910)
#define CONSTEXPR_D constexpr // data
#define CONSTEXPR_F constexpr // function
#define CONSTEXPR_M constexpr // member
--- ./absl/time/internal/cctz/BUILD.bazel 2019-09-23 13:20:52.000000000 -0700
+++ ./absl/time/internal/cctz/BUILD.bazel.fixed 2019-09-23 13:20:48.000000000 -0700
@@ -74,15 +74,6 @@
@ -301,4 +312,3 @@
+ .internal_compressed_tuple::template Storage<CompressedTuple, I>::get();
}
};

View File

@ -1,5 +1,8 @@
# We make everything here private so that any dependency on ICU becomes a build
# failure that is easier/faster to track down; ICU is not needed for DeepSpeech and
# causes linking problems on Windows.
package(
default_visibility = ["//visibility:public"],
default_visibility = ["//visibility:private"],
)
licenses(["notice"]) # Apache 2.0

View File

@ -67,5 +67,4 @@ config_setting(
)
%{PYTHON_INCLUDE_GENRULE}
%{NUMPY_INCLUDE_GENRULE}
%{PYTHON_IMPORT_LIB_GENRULE}

View File

@ -210,7 +210,7 @@ def _create_local_python_repository(repository_ctx):
python_lib = _get_python_lib(repository_ctx, python_bin)
_check_python_lib(repository_ctx, python_lib)
python_include = _get_python_include(repository_ctx, python_bin)
numpy_include = _get_numpy_include(repository_ctx, python_bin) + "/numpy"
# numpy_include = _get_numpy_include(repository_ctx, python_bin) + "/numpy"
python_include_rule = _symlink_genrule_for_dir(
repository_ctx,
python_include,
@ -233,12 +233,12 @@ def _create_local_python_repository(repository_ctx):
[python_import_lib_src],
[python_import_lib_name],
)
numpy_include_rule = _symlink_genrule_for_dir(
repository_ctx,
numpy_include,
"numpy_include/numpy",
"numpy_include",
)
#numpy_include_rule = _symlink_genrule_for_dir(
# repository_ctx,
# numpy_include,
# "numpy_include/numpy",
# "numpy_include",
#)
platform_constraint = ""
if repository_ctx.attr.platform_constraint:
@ -247,7 +247,7 @@ def _create_local_python_repository(repository_ctx):
"%{PYTHON_BIN_PATH}": python_bin,
"%{PYTHON_INCLUDE_GENRULE}": python_include_rule,
"%{PYTHON_IMPORT_LIB_GENRULE}": python_import_lib_genrule,
"%{NUMPY_INCLUDE_GENRULE}": numpy_include_rule,
#"%{NUMPY_INCLUDE_GENRULE}": numpy_include_rule,
"%{PLATFORM_CONSTRAINT}": platform_constraint,
})

View File

@ -16,6 +16,8 @@
_SINGLE_URL_WHITELIST = depset([
"arm_compiler",
"LinaroArmGcc72",
"LinaroAarch64Gcc72",
])
def _is_windows(ctx):

View File

@ -0,0 +1,67 @@
# This is the entry point for --crosstool_top.
#
# The cc_toolchain rule used is found by:
#
# 1. Finding the appropriate toolchain in the CROSSTOOL file based on the --cpu
# and --compiler command line flags (if they exist, otherwise using the
# "default_target_cpu" / "default_toolchain" fields in the CROSSTOOL file)
# 2. Concatenating the "target_cpu" and "compiler" fields of the toolchain in
# use and using that as a key in the map in the "toolchains" attribute
package(default_visibility = ["//visibility:public"])
load(":linaro_toolchain_config.bzl", "linaro_toolchain_config")
cc_toolchain_suite(
name = "toolchain",
toolchains = {
"aarch64": ":cc-compiler-aarch64",
},
)
filegroup(
name = "empty",
srcs = [],
)
filegroup(
name = "gcc_linux_all_files",
srcs = [
"//third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc:tool-wrappers",
"@LinaroAarch64Gcc72//:compiler_pieces",
],
)
filegroup(
name = "gcc_linux_linker_files",
srcs = [
"//third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc:ld",
"//third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc:ar",
"@LinaroAarch64Gcc72//:compiler_pieces",
],
)
filegroup(
name = "gcc_linux_compiler_files",
srcs = [
"//third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc:gcc",
"//third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc:as",
],
)
linaro_toolchain_config(name = "linaro_aarch64")
cc_toolchain(
name = "cc-compiler-aarch64",
all_files = ":gcc_linux_all_files",
compiler_files = ":gcc_linux_compiler_files",
toolchain_identifier = "gcc72_linaro_aarch64",
toolchain_config = ":linaro_aarch64",
dwp_files = ":empty",
dynamic_runtime_lib = ":empty",
linker_files = ":gcc_linux_linker_files",
objcopy_files = "//third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc:objcopy",
static_runtime_lib = ":empty",
strip_files = "//third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc:strip",
supports_param_files = 1,
visibility = ["//visibility:public"],
)

View File

@ -0,0 +1,79 @@
package(default_visibility = ['//third_party/toolchains/embedded/linaro-gcc72-aarch64:__pkg__'])
filegroup(
name = 'gcc',
srcs = [
'@LinaroAarch64Gcc72//:gcc',
'aarch64-linux-gnu-gcc',
],
)
filegroup(
name = 'ar',
srcs = [
'@LinaroAarch64Gcc72//:ar',
'aarch64-linux-gnu-ar',
],
)
filegroup(
name = 'ld',
srcs = [
'@LinaroAarch64Gcc72//:ld',
'aarch64-linux-gnu-ld',
],
)
filegroup(
name = 'nm',
srcs = [
'@LinaroAarch64Gcc72//:nm',
'aarch64-linux-gnu-nm',
],
)
filegroup(
name = 'objcopy',
srcs = [
'@LinaroAarch64Gcc72//:objcopy',
'aarch64-linux-gnu-objcopy',
],
)
filegroup(
name = 'objdump',
srcs = [
'@LinaroAarch64Gcc72//:objdump',
'aarch64-linux-gnu-objdump',
],
)
filegroup(
name = 'strip',
srcs = [
'@LinaroAarch64Gcc72//:strip',
'aarch64-linux-gnu-strip',
],
)
filegroup(
name = 'as',
srcs = [
'@LinaroAarch64Gcc72//:as',
'aarch64-linux-gnu-as',
],
)
filegroup(
name = 'tool-wrappers',
srcs = [
':gcc',
':ar',
':ld',
':nm',
':objcopy',
':objdump',
':strip',
':as',
],
)

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a aarch64-linux-gnu-ar \
external/LinaroAarch64Gcc72/bin/aarch64-linux-gnu-ar \
"$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a aarch64-linux-gnu-as \
external/LinaroAarch64Gcc72/bin/aarch64-linux-gnu-as \
"$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a aarch64-linux-gnu-cpp \
external/LinaroAarch64Gcc72/bin/aarch64-linux-gnu-cpp \
"$@"

View File

@ -0,0 +1,6 @@
#!/bin/bash --norc
PATH="external/LinaroAarch64Gcc72/libexec/gcc/aarch64-linux-gnu/7.2.1/:$PATH" \
exec \
external/LinaroAarch64Gcc72/bin/aarch64-linux-gnu-gcc \
"$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a aarch64-linux-gnu-gcov \
external/LinaroAarch64Gcc72/bin/aarch64-linux-gnu-gcov \
"$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a aarch64-linux-gnu-ld \
external/LinaroAarch64Gcc72/bin/aarch64-linux-gnu-ld \
"$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a aarch64-linux-gnu-nm \
external/LinaroAarch64Gcc72/bin/aarch64-linux-gnu-nm \
"$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a aarch64-linux-gnu-objcopy \
external/LinaroAarch64Gcc72/bin/aarch64-linux-gnu-objcopy \
"$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a aarch64-linux-gnu-objdump \
external/LinaroAarch64Gcc72/bin/aarch64-linux-gnu-objdump \
"$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a aarch64-linux-gnu-strip \
external/LinaroAarch64Gcc72/bin/aarch64-linux-gnu-strip \
"$@"

View File

@ -0,0 +1,81 @@
package(default_visibility = ['//visibility:public'])
filegroup(
name = 'gcc',
srcs = [
'bin/aarch64-linux-gnu-gcc',
],
)
filegroup(
name = 'ar',
srcs = [
'bin/aarch64-linux-gnu-ar',
],
)
filegroup(
name = 'ld',
srcs = [
'bin/aarch64-linux-gnu-ld',
],
)
filegroup(
name = 'nm',
srcs = [
'bin/aarch64-linux-gnu-nm',
],
)
filegroup(
name = 'objcopy',
srcs = [
'bin/aarch64-linux-gnu-objcopy',
],
)
filegroup(
name = 'objdump',
srcs = [
'bin/aarch64-linux-gnu-objdump',
],
)
filegroup(
name = 'strip',
srcs = [
'bin/aarch64-linux-gnu-strip',
],
)
filegroup(
name = 'as',
srcs = [
'bin/aarch64-linux-gnu-as',
],
)
filegroup(
name = 'compiler_pieces',
srcs = glob([
'aarch64-linux-gnu/**',
'libexec/**',
'lib/gcc/aarch64-linux-gnu/**',
'include/**',
]),
)
filegroup(
name = 'compiler_components',
srcs = [
':gcc',
':ar',
':ld',
':nm',
':objcopy',
':objdump',
':strip',
':as',
],
)

View File

@ -0,0 +1,484 @@
# Copyright 2019 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Starlark cc_toolchain configuration rule"""
load("@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl",
"action_config",
"artifact_name_pattern",
"env_entry",
"env_set",
"feature",
"feature_set",
"flag_group",
"flag_set",
"make_variable",
"tool",
"tool_path",
"variable_with_value",
"with_feature_set",
)
load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES")
all_compile_actions = [
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.assemble,
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.clif_match,
ACTION_NAMES.lto_backend,
]
all_cpp_compile_actions = [
ACTION_NAMES.cpp_compile,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.clif_match,
]
preprocessor_compile_actions = [
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.clif_match,
]
codegen_compile_actions = [
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.assemble,
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.lto_backend,
]
all_link_actions = [
ACTION_NAMES.cpp_link_executable,
ACTION_NAMES.cpp_link_dynamic_library,
ACTION_NAMES.cpp_link_nodeps_dynamic_library,
]
def _impl(ctx):
abi_version = "aarch64"
abi_libc_version = "glibc_2.24"
builtin_sysroot = None
compiler = "gcc"
host_system_name = "aarch64"
needs_pic = True
supports_gold_linker = False
supports_incremental_linker = False
supports_fission = False
supports_interface_shared_objects = False
supports_normalizing_ar = False
supports_start_end_lib = False
supports_thin_archives = False
target_libc = "glibc_2.24"
target_cpu = "armv8"
target_system_name = "armv8"
toolchain_identifier = "gcc72_linaro_aarch64"
cc_target_os = None
action_configs = []
supports_pic_feature = feature(name = "supports_pic", enabled = True)
supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True)
user_compile_flags_feature = feature(
name = "user_compile_flags",
enabled = True,
flag_sets = [
flag_set(
actions = [
ACTION_NAMES.assemble,
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.lto_backend,
ACTION_NAMES.clif_match,
],
flag_groups = [
flag_group(
flags = ["%{user_compile_flags}"],
iterate_over = "user_compile_flags",
expand_if_available = "user_compile_flags",
),
],
),
],
)
user_link_flags_feature = feature(
name = "user_link_flags",
flag_sets = [
flag_set(
actions = all_link_actions,
flag_groups = [
flag_group(
flags = ["%{user_link_flags}"],
iterate_over = "user_link_flags",
expand_if_available = "user_link_flags",
),
],
),
],
)
shared_flag_feature = feature(
name = "shared_flag",
flag_sets = [
flag_set(
actions = [
ACTION_NAMES.cpp_link_dynamic_library,
ACTION_NAMES.cpp_link_nodeps_dynamic_library,
ACTION_NAMES.lto_index_for_dynamic_library,
ACTION_NAMES.lto_index_for_nodeps_dynamic_library,
],
flag_groups = [flag_group(flags = ["-shared"])],
),
],
)
sysroot_feature = feature(
name = "sysroot",
enabled = True,
flag_sets = [
flag_set(
actions = [
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.lto_backend,
ACTION_NAMES.clif_match,
ACTION_NAMES.cpp_link_executable,
ACTION_NAMES.cpp_link_dynamic_library,
ACTION_NAMES.cpp_link_nodeps_dynamic_library,
],
flag_groups = [
flag_group(
flags = ["--sysroot=%{sysroot}"],
expand_if_available = "sysroot",
),
],
),
],
)
objcopy_embed_flags_feature = feature(
name = "objcopy_embed_flags",
enabled = True,
flag_sets = [
flag_set(
actions = ["objcopy_embed_data"],
flag_groups = [flag_group(flags = ["-I", "binary"])],
),
],
)
unfiltered_compile_flags_feature = feature(
name = "unfiltered_compile_flags",
enabled = True,
flag_sets = [
flag_set(
actions = [
ACTION_NAMES.assemble,
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.lto_backend,
ACTION_NAMES.clif_match,
],
flag_groups = [
flag_group(
flags = [
# Make C++ compilation deterministic. Use linkstamping instead of these
# compiler symbols.
"-Wno-builtin-macro-redefined",
"-D__DATE__=\"redacted\"",
"-D__TIMESTAMP__=\"redacted\"",
"-D__TIME__=\"redacted\"",
# This makes GCC and Clang do what we want when called through symlinks.
"-no-canonical-prefixes",
],
),
],
),
],
)
default_compile_flags_feature = feature(
name = "default_compile_flags",
enabled = True,
flag_sets = [
flag_set(
actions = [
ACTION_NAMES.assemble,
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.lto_backend,
ACTION_NAMES.clif_match,
],
flag_groups = [
flag_group(
flags = [
"-U_FORTIFY_SOURCE",
"-D_FORTIFY_SOURCE=1",
"-fstack-protector",
],
),
],
),
flag_set(
actions = [
ACTION_NAMES.assemble,
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.lto_backend,
ACTION_NAMES.clif_match,
],
flag_groups = [flag_group(flags = ["-g"])],
with_features = [with_feature_set(features = ["dbg"])],
),
flag_set(
actions = [
ACTION_NAMES.assemble,
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.lto_backend,
ACTION_NAMES.clif_match,
],
flag_groups = [
flag_group(
flags = [
"-g0",
"-O2",
"-DNDEBUG",
"-ffunction-sections",
"-fdata-sections",
],
),
],
with_features = [with_feature_set(features = ["opt"])],
),
flag_set(
actions = [
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.lto_backend,
ACTION_NAMES.clif_match,
],
flag_groups = [
flag_group(
flags = [
"-std=c++11",
"--sysroot=external/LinaroAarch64Gcc72/aarch64-linux-gnu/libc",
"-pthread",
"-nostdinc",
"-isystem",
"external/LinaroAarch64Gcc72/aarch64-linux-gnu/include/c++/7.2.1/aarch64-linux-gnu",
"-isystem",
"external/LinaroAarch64Gcc72/aarch64-linux-gnu/include/c++/7.2.1",
"-isystem",
"external/LinaroAarch64Gcc72/lib/gcc/aarch64-linux-gnu/7.2.1/include",
"-isystem",
"external/LinaroAarch64Gcc72/aarch64-linux-gnu/libc/usr/include",
"-isystem",
"external/LinaroAarch64Gcc72/lib/gcc/aarch64-linux-gnu/7.2.1/include-fixed",
"-isystem",
"external/LinaroAarch64Gcc72/aarch64-linux-gnu/libc/usr/include",
"-isystem",
"external/LinaroAarch64Gcc72/aarch64-linux-gnu/libc/usr/include/aarch64-linux-gnu",
"-isystem",
"external/LinaroAarch64Gcc72/lib/gcc/aarch64-linux-gnu/7.2.1/include",
"-isystem",
"external/LinaroAarch64Gcc72/include/c++/7.2.1/aarch64-linux-gnu",
"-isystem",
"external/LinaroAarch64Gcc72/include/c++/7.2.1",
# Security hardening on by default.
"-fstack-protector",
"-fPIE",
# All warnings are enabled. Maybe enable -Werror as well?
"-Wall",
# Enable a few more warnings that aren't part of -Wall.
"-Wunused-but-set-parameter",
# But disable some that are problematic.
"-Wno-free-nonheap-object", # has false positives
# Keep stack frames for debugging, even in opt mode.
"-fno-omit-frame-pointer",
# Enable coloring even if there's no attached terminal. Bazel removes the
# escape sequences if --nocolor is specified.
"-fdiagnostics-color=always",
],
),
],
),
],
)
default_link_flags_feature = feature(
name = "default_link_flags",
enabled = True,
flag_sets = [
flag_set(
actions = all_link_actions,
flag_groups = [
flag_group(
flags = [
# "-target",
# "aarch64-linux-gnu",
"--sysroot=external/LinaroAarch64Gcc72/aarch64-linux-gnu/libc",
"-pass-exit-codes",
"-pie",
"-lstdc++",
"-lm",
"-lpthread",
"-Wl,--dynamic-linker=/lib/ld-linux-aarch64.so.1",
"-Wl,-no-as-needed",
"-Wl,-z,relro,-z,now",
"-no-canonical-prefixes",
# Stamp the binary with a unique identifier.
"-Wl,--build-id=md5",
"-Wl,--hash-style=gnu",
"-Lexternal/LinaroAarch64Gcc72/aarch64-linux-gnu/lib",
"-Lexternal/LinaroAarch64Gcc72/aarch64-linux-gnu/libc/lib",
"-Lexternal/LinaroAarch64Gcc72/aarch64-linux-gnu/libc/usr/lib",
"-Bexternal/LinaroAarch64Gcc72/aarch64-linux-gnu/bin",
],
),
],
),
flag_set(
actions = all_link_actions,
flag_groups = [flag_group(flags = ["-Wl,--gc-sections"])],
with_features = [with_feature_set(features = ["opt"])],
),
],
)
opt_feature = feature(name = "opt")
dbg_feature = feature(name = "dbg")
features = [
default_compile_flags_feature,
default_link_flags_feature,
supports_dynamic_linker_feature,
supports_pic_feature,
objcopy_embed_flags_feature,
opt_feature,
dbg_feature,
user_compile_flags_feature,
user_link_flags_feature,
shared_flag_feature,
sysroot_feature,
unfiltered_compile_flags_feature,
]
cxx_builtin_include_directories = [
"%package(@LinaroAarch64Gcc72//include)%",
"%package(@LinaroAarch64Gcc72//aarch64-linux-gnu/libc/usr/include)%",
"%package(@LinaroAarch64Gcc72//aarch64-linux-gnu/libc/lib/gcc/aarch64-linux-gnu/7.2.1/include-fixed)%",
"%package(@LinaroAarch64Gcc72//include)%/c++/7.2.1",
"%package(@LinaroAarch64Gcc72//aarch64-linux-gnu/libc/lib/gcc/aarch64-linux-gnu/7.2.1/include)%",
"%package(@LinaroAarch64Gcc72//aarch64-linux-gnu/libc/lib/gcc/aarch64-linux-gnu/7.2.1/include-fixed)%",
"%package(@LinaroAarch64Gcc72//lib/gcc/aarch64-linux-gnu/7.2.1/include)%",
"%package(@LinaroAarch64Gcc72//lib/gcc/aarch64-linux-gnu/7.2.1/include-fixed)%",
"%package(@LinaroAarch64Gcc72//aarch64-linux-gnu/include)%/c++/7.2.1",
]
artifact_name_patterns = []
make_variables = []
tool_paths = [
tool_path(name = "ar", path = "gcc/aarch64-linux-gnu-ar"),
tool_path(name = "compat-ld", path = "gcc/aarch64-linux-gnu-ld"),
tool_path(name = "cpp", path = "gcc/aarch64-linux-gnu-cpp"),
tool_path(name = "dwp", path = "gcc/aarch64-linux-gnu-dwp"),
tool_path(name = "gcc", path = "gcc/aarch64-linux-gnu-gcc"),
tool_path(name = "gcov", path = "arm-frc-linux-gnueabi/arm-frc-linux-gnueabi-gcov-4.9"),
# C(++) compiles invoke the compiler (as that is the one knowing where
# to find libraries), but we provide LD so other rules can invoke the linker.
tool_path(name = "ld", path = "gcc/aarch64-linux-gnu-ld"),
tool_path(name = "nm", path = "gcc/aarch64-linux-gnu-nm"),
tool_path(name = "objcopy", path = "gcc/aarch64-linux-gnu-objcopy"),
tool_path(name = "objdump", path = "gcc/aarch64-linux-gnu-objdump"),
tool_path(name = "strip", path = "gcc/aarch64-linux-gnu-strip"),
]
return cc_common.create_cc_toolchain_config_info(
ctx = ctx,
features = features,
action_configs = action_configs,
artifact_name_patterns = artifact_name_patterns,
cxx_builtin_include_directories = cxx_builtin_include_directories,
toolchain_identifier = toolchain_identifier,
host_system_name = host_system_name,
target_system_name = target_system_name,
target_cpu = target_cpu,
target_libc = target_libc,
compiler = compiler,
abi_version = abi_version,
abi_libc_version = abi_libc_version,
tool_paths = tool_paths,
make_variables = make_variables,
builtin_sysroot = builtin_sysroot,
cc_target_os = cc_target_os,
)
linaro_toolchain_config = rule(
implementation = _impl,
attrs = {},
provides = [CcToolchainConfigInfo],
)

View File

@ -0,0 +1,67 @@
# This is the entry point for --crosstool_top.
#
# The cc_toolchain rule used is found by:
#
# 1. Finding the appropriate toolchain in the CROSSTOOL file based on the --cpu
# and --compiler command line flags (if they exist, otherwise using the
# "default_target_cpu" / "default_toolchain" fields in the CROSSTOOL file)
# 2. Concatenating the "target_cpu" and "compiler" fields of the toolchain in
# use and using that as a key in the map in the "toolchains" attribute
package(default_visibility = ["//visibility:public"])
load(":linaro_toolchain_config.bzl", "linaro_toolchain_config")
cc_toolchain_suite(
name = "toolchain",
toolchains = {
"armv7a": ":cc-compiler-armv7a",
},
)
filegroup(
name = "empty",
srcs = [],
)
filegroup(
name = "gcc_linux_all_files",
srcs = [
"//third_party/toolchains/embedded/linaro-gcc72-armeabi/gcc:tool-wrappers",
"@LinaroArmGcc72//:compiler_pieces",
],
)
filegroup(
name = "gcc_linux_linker_files",
srcs = [
"//third_party/toolchains/embedded/linaro-gcc72-armeabi/gcc:ld",
"//third_party/toolchains/embedded/linaro-gcc72-armeabi/gcc:ar",
"@LinaroArmGcc72//:compiler_pieces",
],
)
filegroup(
name = "gcc_linux_compiler_files",
srcs = [
"//third_party/toolchains/embedded/linaro-gcc72-armeabi/gcc:gcc",
"//third_party/toolchains/embedded/linaro-gcc72-armeabi/gcc:as",
],
)
linaro_toolchain_config(name = "linaro_armeabi-v7a")
cc_toolchain(
name = "cc-compiler-armv7a",
all_files = ":gcc_linux_all_files",
compiler_files = ":gcc_linux_compiler_files",
toolchain_identifier = "gcc72_linaro_armhf",
toolchain_config = ":linaro_armeabi-v7a",
dwp_files = ":empty",
dynamic_runtime_lib = ":empty",
linker_files = ":gcc_linux_linker_files",
objcopy_files = "//third_party/toolchains/embedded/linaro-gcc72-armeabi/gcc:objcopy",
static_runtime_lib = ":empty",
strip_files = "//third_party/toolchains/embedded/linaro-gcc72-armeabi/gcc:strip",
supports_param_files = 0,
visibility = ["//visibility:public"],
)

View File

@ -0,0 +1,79 @@
package(default_visibility = ['//third_party/toolchains/embedded/linaro-gcc72-armeabi:__pkg__'])
filegroup(
name = 'gcc',
srcs = [
'@LinaroArmGcc72//:gcc',
'arm-linux-gnueabihf-gcc',
],
)
filegroup(
name = 'ar',
srcs = [
'@LinaroArmGcc72//:ar',
'arm-linux-gnueabihf-ar',
],
)
filegroup(
name = 'ld',
srcs = [
'@LinaroArmGcc72//:ld',
'arm-linux-gnueabihf-ld',
],
)
filegroup(
name = 'nm',
srcs = [
'@LinaroArmGcc72//:nm',
'arm-linux-gnueabihf-nm',
],
)
filegroup(
name = 'objcopy',
srcs = [
'@LinaroArmGcc72//:objcopy',
'arm-linux-gnueabihf-objcopy',
],
)
filegroup(
name = 'objdump',
srcs = [
'@LinaroArmGcc72//:objdump',
'arm-linux-gnueabihf-objdump',
],
)
filegroup(
name = 'strip',
srcs = [
'@LinaroArmGcc72//:strip',
'arm-linux-gnueabihf-strip',
],
)
filegroup(
name = 'as',
srcs = [
'@LinaroArmGcc72//:as',
'arm-linux-gnueabihf-as',
],
)
filegroup(
name = 'tool-wrappers',
srcs = [
':gcc',
':ar',
':ld',
':nm',
':objcopy',
':objdump',
':strip',
':as',
],
)

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a arm-linux-gnueabihf-ar \
external/LinaroArmGcc72/bin/arm-linux-gnueabihf-ar \
"$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a arm-linux-gnueabihf-as \
external/LinaroArmGcc72/bin/arm-linux-gnueabihf-as \
"$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a arm-linux-gnueabihf-cpp \
external/LinaroArmGcc72/bin/arm-linux-gnueabihf-cpp \
"$@"

View File

@ -0,0 +1,6 @@
#!/bin/bash --norc
PATH="external/LinaroArmGcc72/libexec/gcc/arm-linux-gnueabihf/7.2.1/:$PATH" \
exec \
external/LinaroArmGcc72/bin/arm-linux-gnueabihf-gcc \
"$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a arm-linux-gnueabihf-gcov \
external/LinaroArmGcc72/bin/arm-linux-gnueabihf-gcov \
"$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a arm-linux-gnueabihf-ld \
external/LinaroArmGcc72/bin/arm-linux-gnueabihf-ld \
"$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a arm-linux-gnueabihf-nm \
external/LinaroArmGcc72/bin/arm-linux-gnueabihf-nm \
"$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a arm-linux-gnueabihf-objcopy \
external/LinaroArmGcc72/bin/arm-linux-gnueabihf-objcopy \
"$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a arm-linux-gnueabihf-objdump \
external/LinaroArmGcc72/bin/arm-linux-gnueabihf-objdump \
"$@"

View File

@ -0,0 +1,5 @@
#!/bin/bash --norc
exec -a arm-linux-gnueabihf-strip \
external/LinaroArmGcc72/bin/arm-linux-gnueabihf-strip \
"$@"

View File

@ -0,0 +1,81 @@
package(default_visibility = ['//visibility:public'])
filegroup(
name = 'gcc',
srcs = [
'bin/arm-linux-gnueabihf-gcc',
],
)
filegroup(
name = 'ar',
srcs = [
'bin/arm-linux-gnueabihf-ar',
],
)
filegroup(
name = 'ld',
srcs = [
'bin/arm-linux-gnueabihf-ld',
],
)
filegroup(
name = 'nm',
srcs = [
'bin/arm-linux-gnueabihf-nm',
],
)
filegroup(
name = 'objcopy',
srcs = [
'bin/arm-linux-gnueabihf-objcopy',
],
)
filegroup(
name = 'objdump',
srcs = [
'bin/arm-linux-gnueabihf-objdump',
],
)
filegroup(
name = 'strip',
srcs = [
'bin/arm-linux-gnueabihf-strip',
],
)
filegroup(
name = 'as',
srcs = [
'bin/arm-linux-gnueabihf-as',
],
)
filegroup(
name = 'compiler_pieces',
srcs = glob([
'arm-linux-gnueabihf/**',
'libexec/**',
'lib/gcc/arm-linux-gnueabihf/**',
'include/**',
]),
)
filegroup(
name = 'compiler_components',
srcs = [
':gcc',
':ar',
':ld',
':nm',
':objcopy',
':objdump',
':strip',
':as',
],
)

View File

@ -0,0 +1,484 @@
# Copyright 2019 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Starlark cc_toolchain configuration rule"""
load("@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl",
"action_config",
"artifact_name_pattern",
"env_entry",
"env_set",
"feature",
"feature_set",
"flag_group",
"flag_set",
"make_variable",
"tool",
"tool_path",
"variable_with_value",
"with_feature_set",
)
load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES")
all_compile_actions = [
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.assemble,
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.clif_match,
ACTION_NAMES.lto_backend,
]
all_cpp_compile_actions = [
ACTION_NAMES.cpp_compile,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.clif_match,
]
preprocessor_compile_actions = [
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.clif_match,
]
codegen_compile_actions = [
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.assemble,
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.lto_backend,
]
all_link_actions = [
ACTION_NAMES.cpp_link_executable,
ACTION_NAMES.cpp_link_dynamic_library,
ACTION_NAMES.cpp_link_nodeps_dynamic_library,
]
def _impl(ctx):
abi_version = "armeabi"
abi_libc_version = "glibc_2.24"
builtin_sysroot = None
compiler = "gcc"
host_system_name = "armeabi"
needs_pic = True
supports_gold_linker = False
supports_incremental_linker = False
supports_fission = False
supports_interface_shared_objects = False
supports_normalizing_ar = False
supports_start_end_lib = False
supports_thin_archives = False
target_libc = "glibc_2.24"
target_cpu = "armv7"
target_system_name = "armeabi-v7a"
toolchain_identifier = "gcc72_linaro_armhf"
cc_target_os = None
action_configs = []
supports_pic_feature = feature(name = "supports_pic", enabled = True)
supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True)
user_compile_flags_feature = feature(
name = "user_compile_flags",
enabled = True,
flag_sets = [
flag_set(
actions = [
ACTION_NAMES.assemble,
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.lto_backend,
ACTION_NAMES.clif_match,
],
flag_groups = [
flag_group(
flags = ["%{user_compile_flags}"],
iterate_over = "user_compile_flags",
expand_if_available = "user_compile_flags",
),
],
),
],
)
user_link_flags_feature = feature(
name = "user_link_flags",
flag_sets = [
flag_set(
actions = all_link_actions,
flag_groups = [
flag_group(
flags = ["%{user_link_flags}"],
iterate_over = "user_link_flags",
expand_if_available = "user_link_flags",
),
],
),
],
)
shared_flag_feature = feature(
name = "shared_flag",
flag_sets = [
flag_set(
actions = [
ACTION_NAMES.cpp_link_dynamic_library,
ACTION_NAMES.cpp_link_nodeps_dynamic_library,
ACTION_NAMES.lto_index_for_dynamic_library,
ACTION_NAMES.lto_index_for_nodeps_dynamic_library,
],
flag_groups = [flag_group(flags = ["-shared"])],
),
],
)
sysroot_feature = feature(
name = "sysroot",
enabled = True,
flag_sets = [
flag_set(
actions = [
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.lto_backend,
ACTION_NAMES.clif_match,
ACTION_NAMES.cpp_link_executable,
ACTION_NAMES.cpp_link_dynamic_library,
ACTION_NAMES.cpp_link_nodeps_dynamic_library,
],
flag_groups = [
flag_group(
flags = ["--sysroot=%{sysroot}"],
expand_if_available = "sysroot",
),
],
),
],
)
objcopy_embed_flags_feature = feature(
name = "objcopy_embed_flags",
enabled = True,
flag_sets = [
flag_set(
actions = ["objcopy_embed_data"],
flag_groups = [flag_group(flags = ["-I", "binary"])],
),
],
)
unfiltered_compile_flags_feature = feature(
name = "unfiltered_compile_flags",
enabled = True,
flag_sets = [
flag_set(
actions = [
ACTION_NAMES.assemble,
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.lto_backend,
ACTION_NAMES.clif_match,
],
flag_groups = [
flag_group(
flags = [
# Make C++ compilation deterministic. Use linkstamping instead of these
# compiler symbols.
"-Wno-builtin-macro-redefined",
"-D__DATE__=\"redacted\"",
"-D__TIMESTAMP__=\"redacted\"",
"-D__TIME__=\"redacted\"",
# This makes GCC and Clang do what we want when called through symlinks.
"-no-canonical-prefixes",
],
),
],
),
],
)
default_compile_flags_feature = feature(
name = "default_compile_flags",
enabled = True,
flag_sets = [
flag_set(
actions = [
ACTION_NAMES.assemble,
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.lto_backend,
ACTION_NAMES.clif_match,
],
flag_groups = [
flag_group(
flags = [
"-U_FORTIFY_SOURCE",
"-D_FORTIFY_SOURCE=1",
"-fstack-protector",
],
),
],
),
flag_set(
actions = [
ACTION_NAMES.assemble,
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.lto_backend,
ACTION_NAMES.clif_match,
],
flag_groups = [flag_group(flags = ["-g"])],
with_features = [with_feature_set(features = ["dbg"])],
),
flag_set(
actions = [
ACTION_NAMES.assemble,
ACTION_NAMES.preprocess_assemble,
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.c_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.lto_backend,
ACTION_NAMES.clif_match,
],
flag_groups = [
flag_group(
flags = [
"-g0",
"-O2",
"-DNDEBUG",
"-ffunction-sections",
"-fdata-sections",
],
),
],
with_features = [with_feature_set(features = ["opt"])],
),
flag_set(
actions = [
ACTION_NAMES.linkstamp_compile,
ACTION_NAMES.cpp_compile,
ACTION_NAMES.cpp_header_parsing,
ACTION_NAMES.cpp_module_compile,
ACTION_NAMES.cpp_module_codegen,
ACTION_NAMES.lto_backend,
ACTION_NAMES.clif_match,
],
flag_groups = [
flag_group(
flags = [
"-std=c++11",
"--sysroot=external/LinaroArmGcc72/arm-linux-gnueabihf/libc",
"-pthread",
"-nostdinc",
"-isystem",
"external/LinaroArmGcc72/arm-linux-gnueabihf/include/c++/7.2.1/arm-linux-gnueabihf",
"-isystem",
"external/LinaroArmGcc72/arm-linux-gnueabihf/include/c++/7.2.1",
"-isystem",
"external/LinaroArmGcc72/lib/gcc/arm-linux-gnueabihf/7.2.1/include",
"-isystem",
"external/LinaroArmGcc72/arm-linux-gnueabihf/libc/usr/include",
"-isystem",
"external/LinaroArmGcc72/lib/gcc/arm-linux-gnueabihf/7.2.1/include-fixed",
"-isystem",
"external/LinaroArmGcc72/arm-linux-gnueabihf/libc/usr/include",
"-isystem",
"external/LinaroArmGcc72/arm-linux-gnueabihf/libc/usr/include/arm-linux-gnueabihf",
"-isystem",
"external/LinaroArmGcc72/lib/gcc/arm-linux-gnueabihf/7.2.1/include",
"-isystem",
"external/LinaroArmGcc72/include/c++/7.2.1/arm-linux-gnueabihf",
"-isystem",
"external/LinaroArmGcc72/include/c++/7.2.1",
# Security hardening on by default.
"-fstack-protector",
"-fPIE",
# All warnings are enabled. Maybe enable -Werror as well?
"-Wall",
# Enable a few more warnings that aren't part of -Wall.
"-Wunused-but-set-parameter",
# But disable some that are problematic.
"-Wno-free-nonheap-object", # has false positives
# Keep stack frames for debugging, even in opt mode.
"-fno-omit-frame-pointer",
# Enable coloring even if there's no attached terminal. Bazel removes the
# escape sequences if --nocolor is specified.
"-fdiagnostics-color=always",
],
),
],
),
],
)
default_link_flags_feature = feature(
name = "default_link_flags",
enabled = True,
flag_sets = [
flag_set(
actions = all_link_actions,
flag_groups = [
flag_group(
flags = [
# "-target",
# "arm-linux-gnueabihf",
"--sysroot=external/LinaroArmGcc72/arm-linux-gnueabihf/libc",
"-pass-exit-codes",
"-pie",
"-lstdc++",
"-lm",
"-lpthread",
"-Wl,--dynamic-linker=/lib/ld-linux-armhf.so.3",
"-Wl,-no-as-needed",
"-Wl,-z,relro,-z,now",
"-no-canonical-prefixes",
# Stamp the binary with a unique identifier.
"-Wl,--build-id=md5",
"-Wl,--hash-style=gnu",
"-Lexternal/LinaroArmGcc72/arm-linux-gnueabihf/lib",
"-Lexternal/LinaroArmGcc72/arm-linux-gnueabihf/libc/lib",
"-Lexternal/LinaroArmGcc72/arm-linux-gnueabihf/libc/usr/lib",
"-Bexternal/LinaroArmGcc72/arm-linux-gnueabihf/bin",
],
),
],
),
flag_set(
actions = all_link_actions,
flag_groups = [flag_group(flags = ["-Wl,--gc-sections"])],
with_features = [with_feature_set(features = ["opt"])],
),
],
)
opt_feature = feature(name = "opt")
dbg_feature = feature(name = "dbg")
features = [
default_compile_flags_feature,
default_link_flags_feature,
supports_dynamic_linker_feature,
supports_pic_feature,
objcopy_embed_flags_feature,
opt_feature,
dbg_feature,
user_compile_flags_feature,
user_link_flags_feature,
shared_flag_feature,
sysroot_feature,
unfiltered_compile_flags_feature,
]
cxx_builtin_include_directories = [
"%package(@LinaroArmGcc72//include)%",
"%package(@LinaroArmGcc72//arm-linux-gnueabihf/libc/usr/include)%",
"%package(@LinaroArmGcc72//arm-linux-gnueabihf/libc/lib/gcc/arm-linux-gnueabihf/7.2.1/include-fixed)%",
"%package(@LinaroArmGcc72//include)%/c++/7.2.1",
"%package(@LinaroArmGcc72//arm-linux-gnueabihf/libc/lib/gcc/arm-linux-gnueabihf/7.2.1/include)%",
"%package(@LinaroArmGcc72//arm-linux-gnueabihf/libc/lib/gcc/arm-linux-gnueabihf/7.2.1/include-fixed)%",
"%package(@LinaroArmGcc72//lib/gcc/arm-linux-gnueabihf/7.2.1/include)%",
"%package(@LinaroArmGcc72//lib/gcc/arm-linux-gnueabihf/7.2.1/include-fixed)%",
"%package(@LinaroArmGcc72//arm-linux-gnueabihf/include)%/c++/7.2.1",
]
artifact_name_patterns = []
make_variables = []
tool_paths = [
tool_path(name = "ar", path = "gcc/arm-linux-gnueabihf-ar"),
tool_path(name = "compat-ld", path = "gcc/arm-linux-gnueabihf-ld"),
tool_path(name = "cpp", path = "gcc/arm-linux-gnueabihf-cpp"),
tool_path(name = "dwp", path = "gcc/arm-linux-gnueabihf-dwp"),
tool_path(name = "gcc", path = "gcc/arm-linux-gnueabihf-gcc"),
tool_path(name = "gcov", path = "arm-frc-linux-gnueabi/arm-frc-linux-gnueabi-gcov-4.9"),
# C(++) compiles invoke the compiler (as that is the one knowing where
# to find libraries), but we provide LD so other rules can invoke the linker.
tool_path(name = "ld", path = "gcc/arm-linux-gnueabihf-ld"),
tool_path(name = "nm", path = "gcc/arm-linux-gnueabihf-nm"),
tool_path(name = "objcopy", path = "gcc/arm-linux-gnueabihf-objcopy"),
tool_path(name = "objdump", path = "gcc/arm-linux-gnueabihf-objdump"),
tool_path(name = "strip", path = "gcc/arm-linux-gnueabihf-strip"),
]
return cc_common.create_cc_toolchain_config_info(
ctx = ctx,
features = features,
action_configs = action_configs,
artifact_name_patterns = artifact_name_patterns,
cxx_builtin_include_directories = cxx_builtin_include_directories,
toolchain_identifier = toolchain_identifier,
host_system_name = host_system_name,
target_system_name = target_system_name,
target_cpu = target_cpu,
target_libc = target_libc,
compiler = compiler,
abi_version = abi_version,
abi_libc_version = abi_libc_version,
tool_paths = tool_paths,
make_variables = make_variables,
builtin_sysroot = builtin_sysroot,
cc_target_os = cc_target_os,
)
linaro_toolchain_config = rule(
implementation = _impl,
attrs = {},
provides = [CcToolchainConfigInfo],
)