Compare commits
87 Commits
v2.4.0-rc2
...
rei/fork-r
Author | SHA1 | Date | |
---|---|---|---|
27a1657c4f | |||
|
4bdd395511 | ||
|
9173bdff3d | ||
|
182e869fb8 | ||
|
dc74c09b8d | ||
|
90f5b25508 | ||
|
9b67f161e5 | ||
|
811608d4e4 | ||
|
ca5d9fdf5c | ||
|
23ad988fcd | ||
|
6dc2a1becf | ||
|
4336a5b49f | ||
|
6fad14b203 | ||
|
ee598066c4 | ||
|
fd3b3ca6f5 | ||
|
b36436b087 | ||
|
c4b2951888 | ||
|
82515cee8f | ||
|
ab9c694484 | ||
|
1f610fc5ae | ||
|
44e3817ad0 | ||
|
dbbdcde0fd | ||
|
ca2b7ba75c | ||
|
67ab71d747 | ||
|
d60d7d3c7e | ||
|
2ef4243a20 | ||
|
f9233753f3 | ||
|
bb3c460114 | ||
|
f923fa474b | ||
|
8edf01fbdc | ||
|
557fdcfcc9 | ||
|
f989a28561 | ||
|
89bb4c3f42 | ||
|
b8c74daee7 | ||
|
7e0abd9c89 | ||
|
7d310df2de | ||
|
6b9a9d98bb | ||
|
0aa1d61fad | ||
|
0b2321fdd1 | ||
|
2b03d7b7a0 | ||
|
549064075e | ||
|
d03e29a094 | ||
|
ef4db27b31 | ||
|
13c4eadd25 | ||
|
de4c4425b7 | ||
|
b8694e39d8 | ||
|
98a59288c8 | ||
|
14b2d686d6 | ||
|
bf61dd6420 | ||
|
6dc322fdfb | ||
|
2f2ace8bc3 | ||
|
f89998cac6 | ||
|
602c151bb1 | ||
|
1725ab6962 | ||
|
fc9d68a5e8 | ||
|
12a91f913c | ||
|
257447e193 | ||
|
61b2024a19 | ||
|
4703d0ddce | ||
|
ed8d661b5d | ||
|
10ba65efab | ||
|
8cdfc53a63 | ||
|
a3e64f721c | ||
|
02ad000479 | ||
|
d3dc6a2071 | ||
|
9310f2a180 | ||
|
3b27581629 | ||
|
99fea8da0d | ||
|
8964fa8419 | ||
|
9cc469ac21 | ||
|
fbaa6d4346 | ||
|
5af8da4a17 | ||
|
04bd81d7e6 | ||
|
2f810e1d36 | ||
|
b878ae3918 | ||
|
91bdadc08e | ||
|
890eae3e88 | ||
|
83e4305ca4 | ||
|
e5a4b2534b | ||
|
f41584353a | ||
|
fa305ce9a7 | ||
|
01f5259e79 | ||
|
48064efcfd | ||
|
88e3eebad5 | ||
|
f9ca9abe78 | ||
|
8d12b76309 | ||
|
b4c95671f2 |
26
.bazelrc
26
.bazelrc
@ -94,6 +94,9 @@ build:libc++ --linkopt -fuse-ld=lld
|
|||||||
# https://docs.bazel.build/versions/master/user-manual.html#flag--fat_apk_cpu
|
# https://docs.bazel.build/versions/master/user-manual.html#flag--fat_apk_cpu
|
||||||
build:android --crosstool_top=//external:android/crosstool
|
build:android --crosstool_top=//external:android/crosstool
|
||||||
build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
|
build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
|
||||||
|
build:android --copt=-D_GLIBCXX_USE_C99
|
||||||
|
build:android --cxxopt=-std=c++14
|
||||||
|
build:android --action_env ANDROID_NDK_API_LEVEL=21
|
||||||
build:android_arm --config=android
|
build:android_arm --config=android
|
||||||
build:android_arm --cpu=armeabi-v7a
|
build:android_arm --cpu=armeabi-v7a
|
||||||
build:android_arm --fat_apk_cpu=armeabi-v7a
|
build:android_arm --fat_apk_cpu=armeabi-v7a
|
||||||
@ -202,6 +205,29 @@ build:sycl_asan --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --co
|
|||||||
build:sycl_nodouble --config=sycl
|
build:sycl_nodouble --config=sycl
|
||||||
build:sycl_trisycl --define=using_trisycl=true
|
build:sycl_trisycl --define=using_trisycl=true
|
||||||
|
|
||||||
|
build --copt=-DTFLITE_WITH_RUY_GEMV
|
||||||
|
|
||||||
|
build:rpi3 --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
|
||||||
|
build:rpi3 --crosstool_top=//third_party/toolchains/embedded/linaro-gcc72-armeabi:toolchain
|
||||||
|
build:rpi3 --cpu=armv7a --define=target_system=rpi3
|
||||||
|
build:rpi3 --copt=-march=armv7-a --copt=-mtune=cortex-a53 --copt=-mfloat-abi=hard --copt=-mfpu=neon-fp-armv8 --copt=-DRASPBERRY_PI --copt=-D_GLIBCXX_USE_CXX11_ABI=0 --copt=-std=gnu99 --copt=-mno-unaligned-access
|
||||||
|
build:rpi3 --define=tensorflow_mkldnn_contraction_kernel=0
|
||||||
|
build:rpi3_opt -c opt --config=rpi3 --copt=-funsafe-math-optimizations --copt=-ftree-vectorize --copt=-pipe
|
||||||
|
|
||||||
|
build:rpi3-armv8 --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
|
||||||
|
build:rpi3-armv8 --crosstool_top=//third_party/toolchains/embedded/linaro-gcc72-aarch64:toolchain
|
||||||
|
build:rpi3-armv8 --cpu=aarch64 --define=target_system=rpi3-armv8
|
||||||
|
build:rpi3-armv8 --copt=-march=armv8-a --copt=-mtune=cortex-a53 --copt=-DRASPBERRY_PI --copt=-D_GLIBCXX_USE_CXX11_ABI=0 --copt=-std=gnu99
|
||||||
|
build:rpi3-armv8 --define=tensorflow_mkldnn_contraction_kernel=0
|
||||||
|
build:rpi3-armv8_opt -c opt --config=rpi3-armv8 --copt=-funsafe-math-optimizations --copt=-ftree-vectorize --copt=-pipe
|
||||||
|
|
||||||
|
build:rpi4ub-armv8 --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
|
||||||
|
build:rpi4ub-armv8 --crosstool_top=//third_party/toolchains/embedded/linaro-gcc72-aarch64:toolchain
|
||||||
|
build:rpi4ub-armv8 --cpu=aarch64 --define=target_system=rpi4ub-armv8
|
||||||
|
build:rpi4ub-armv8 --copt=-march=armv8-a --copt=-mtune=cortex-a72 --copt=-DRASPBERRY_PI --copt=-D_GLIBCXX_USE_CXX11_ABI=0 --copt=-std=gnu99
|
||||||
|
build:rpi4ub-armv8 --define=tensorflow_mkldnn_contraction_kernel=0
|
||||||
|
build:rpi4ub-armv8_opt -c opt --config=rpi4ub-armv8 --copt=-funsafe-math-optimizations --copt=-ftree-vectorize --copt=-pipe
|
||||||
|
|
||||||
# Options extracted from configure script
|
# Options extracted from configure script
|
||||||
build:ngraph --define=with_ngraph_support=true
|
build:ngraph --define=with_ngraph_support=true
|
||||||
build:numa --define=with_numa_support=true
|
build:numa --define=with_numa_support=true
|
||||||
|
15
.github/pull_request_template.md
vendored
Normal file
15
.github/pull_request_template.md
vendored
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# Pull request guidelines
|
||||||
|
|
||||||
|
Welcome to the 🐸tensorflow project! We are excited to see your interest, and appreciate your support!
|
||||||
|
|
||||||
|
This repository is governed by the Contributor Covenant Code of Conduct. For more details, see the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file.
|
||||||
|
|
||||||
|
In order to make a good pull request, please see our [CONTRIBUTING.md](CONTRIBUTING.md) file.
|
||||||
|
|
||||||
|
Before accepting your pull request, you will be asked to sign a [Contributor License Agreement](https://cla-assistant.io/coqui-ai/tensorflow).
|
||||||
|
|
||||||
|
This [Contributor License Agreement](https://cla-assistant.io/coqui-ai/tensorflow):
|
||||||
|
|
||||||
|
- Protects you, Coqui, and the users of the code.
|
||||||
|
- Does not change your rights to use your contributions for any purpose.
|
||||||
|
- Does not change the license of the 🐸tensorflow project. It just makes the terms of your contribution clearer and lets us know you are OK to contribute.
|
209
RELEASE.md
209
RELEASE.md
@ -1,15 +1,206 @@
|
|||||||
# Release 2.3.0
|
# Release 2.3.0
|
||||||
|
|
||||||
## Breaking Changes
|
## Major Features and Improvements
|
||||||
|
* `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and save resources:
|
||||||
|
* [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot)
|
||||||
|
* [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service).
|
||||||
|
|
||||||
|
In addition checkout the detailed [guide](https://www.tensorflow.org/guide/data_performance_analysis) for analyzing input pipeline performance with TF Profiler.
|
||||||
|
|
||||||
|
* [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) is now a stable API and no longer considered experimental for TensorFlow. (earlier `tf.distribute.experimental.TPUStrategy`).
|
||||||
|
|
||||||
|
* [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new tools: a memory profiler to visualize your model’s memory usage over time and a [python tracer](https://www.tensorflow.org/guide/profiler#events) which allows you to trace python function calls in your model. Usability improvements include better diagnostic messages and [profile options](https://tensorflow.org/guide/profiler#collect_performance_data) to customize the host and device trace verbosity level.
|
||||||
|
|
||||||
|
* Introduces experimental support for Keras Preprocessing Layers API ([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly)) to handle data preprocessing operations, with support for composite tensor inputs. Please see below for additional details on these layers.
|
||||||
|
|
||||||
|
* TFLite now properly supports dynamic shapes during conversion and inference. We’ve also added opt-in support on Android and iOS for [XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack), a highly optimized set of CPU kernels, as well as opt-in support for [executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental).
|
||||||
|
|
||||||
|
* Libtensorflow packages are available in GCS starting this release. We have also started to [release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds).
|
||||||
|
|
||||||
|
* The experimental Python API [`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info) now allows you to instrument a TensorFlow program and dump debugging information to a directory on the file system. The directory can be read and visualized by a new interactive dashboard in TensorBoard 2.3 called [Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which reveals the details of the TensorFlow program including graph structures, history of op executions at the Python (eager) and intra-graph levels, the runtime dtype, shape, and numerical composistion of tensors, as well as their code locations.
|
||||||
|
|
||||||
|
## Breaking Changes
|
||||||
|
* Increases the **minimum bazel version** required to build TF to **3.1.0**.
|
||||||
|
* `tf.data`
|
||||||
|
* Makes the following (breaking) changes to the `tf.data`.
|
||||||
|
* C++ API: - `IteratorBase::RestoreInternal`, `IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState` become pure-virtual and subclasses are now expected to provide an implementation.
|
||||||
|
* The deprecated `DatasetBase::IsStateful` method is removed in favor of `DatasetBase::CheckExternalState`.
|
||||||
|
* Deprecated overrides of `DatasetBase::MakeIterator` and `MakeIteratorFromInputElement` are removed.
|
||||||
|
* The signature of `tensorflow::data::IteratorBase::SaveInternal` and `tensorflow::data::IteratorBase::SaveInput` has been extended with `SerializationContext` argument to enable overriding the default policy for the handling external state during iterator checkpointing. This is not a backwards compatible change and all subclasses of `IteratorBase` *need to be updated* accordingly.
|
||||||
|
* `tf.keras`
|
||||||
|
* Add a new `BackupAndRestore` callback for handling distributed training failures & restarts. Please take a look at this [tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for details on how to use the callback.
|
||||||
|
* `tf.image.extract_glimpse` has been updated to correctly process the case
|
||||||
|
where `centered=False` and `normalized=False`. This is a breaking change as
|
||||||
|
the output is different from (incorrect) previous versions. Note this
|
||||||
|
breaking change only impacts `tf.image.extract_glimpse` and
|
||||||
|
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
|
||||||
|
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
|
||||||
|
exsiting C++ kernel `ExtractGlimpse` does not change either, so saved
|
||||||
|
models using `tf.raw_ops.ExtractGlimpse` will not be impacted.
|
||||||
|
|
||||||
|
## Known Caveats
|
||||||
|
* `tf.lite`
|
||||||
|
* Keras-based LSTM models must be converted with an explicit batch size in the input layer.
|
||||||
|
|
||||||
|
## Bug Fixes and Other Changes
|
||||||
|
|
||||||
|
### TF Core:
|
||||||
|
* Set `tf2_behavior` to 1 to enable V2 for early loading cases.
|
||||||
|
* Add `execute_fn_for_device function` to dynamically choose the implementation based on underlying device placement.
|
||||||
|
* Eager:
|
||||||
|
* Add `reduce_logsumexp` benchmark with experiment compile.
|
||||||
|
* Give `EagerTensor`s a meaningful `__array__` implementation.
|
||||||
|
* Add another version of defun matmul for performance analysis.
|
||||||
|
* `tf.function`/AutoGraph:
|
||||||
|
* `AutoGraph` now includes into TensorFlow loops any variables that are closed over by local functions. Previously, such variables were sometimes incorrectly ignored.
|
||||||
|
* functions returned by the `get_concrete_function` method of `tf.function` objects can now be called with arguments consistent with the original arguments or type specs passed to `get_concrete_function`. This calling convention is now the preferred way to use concrete functions with nested values and composite tensors. Please check the [guide](https://www.tensorflow.org/guide/concrete_function) for more details on `concrete_ function`.
|
||||||
|
* Update `tf.function`'s `experimental_relax_shapes` to handle composite tensors appropriately.
|
||||||
|
* Optimize `tf.function` invocation, by removing redundant list converter.
|
||||||
|
* `tf.function` will retrace when called with a different variable instead of simply using the `dtype` & `shape`.
|
||||||
|
* [Improve support](https://github.com/tensorflow/tensorflow/issues/33862) for dynamically-sized TensorArray inside `tf.function`.
|
||||||
|
* `tf.math`:
|
||||||
|
* Narrow down `argmin`/`argmax` contract to always return the smallest index for ties.
|
||||||
|
* `tf.math.reduce_variance` and `tf.math.reduce_std` return correct computation for complex types and no longer support integer types.
|
||||||
|
* Add Bessel functions of order 0,1 to `tf.math.special`.
|
||||||
|
* `tf.divide` now always returns a tensor to be consistent with documentation and other APIs.
|
||||||
|
* `tf.image`:
|
||||||
|
* Replaced [`tf.image.non_max_suppression_padded`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/image/non_max_suppression_padded?hl=en) with a new implementation that supports batched inputs, which is considerably faster on TPUs and GPUs. Boxes with area=0 will be ignored. Existing usage with single inputs should still work as before.
|
||||||
|
* `tf.linalg`
|
||||||
|
* Add `tf.linalg.banded_triangular_solve`.
|
||||||
|
* `tf.random`:
|
||||||
|
* Add `tf.random.stateless_parameterized_truncated_normal`.
|
||||||
|
* `tf.ragged`:
|
||||||
|
* Add `tf.ragged.cross` and `tf.ragged.cross_hashed` operations.
|
||||||
|
* `tf.RaggedTensor`:
|
||||||
|
* `RaggedTensor.to_tensor()` now preserves static shape.
|
||||||
|
* Add `tf.strings.format()` and `tf.print()` to support RaggedTensors.
|
||||||
|
* `tf.saved_model`:
|
||||||
|
* `@tf.function` from SavedModel no longer ignores args after a `RaggedTensor` when selecting the concrete function to run.
|
||||||
|
* Fix save model issue for ops with a list of functions.
|
||||||
|
* Add `tf.saved_model.LoadOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/LoadOptions?hl=en) as arg with default value `None` to choose the I/O device for loading models and weights.
|
||||||
|
* Update `tf.saved_model.SaveOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/SaveOptions?hl=en) as arg with default value `None` to choose the I/O device for saving models and weights.
|
||||||
|
* Mutable tables now restore checkpointed values when loaded from SavedModel.
|
||||||
|
* GPU
|
||||||
|
* TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities.
|
||||||
|
* Others
|
||||||
|
* Retain parent namescope for ops added inside `tf.while_loop`/`tf.cond`/`tf.switch_case`.
|
||||||
|
* Update `tf.vectorized_map` to support vectorizing `tf.while_loop` and TensorList operations.
|
||||||
|
* `tf.custom_gradient` can now be applied to functions that accept nested structures of `tensors` as inputs (instead of just a list of tensors). Note that Python structures such as tuples and lists now won't be treated as tensors, so if you still want them to be treated that way, you need to wrap them with `tf.convert_to_tensor`.
|
||||||
|
* No lowering on gradient case op when input is `DeviceIndex` op.
|
||||||
|
* Extend the ragged version of `tf.gather` to support `batch_dims` and `axis` args.
|
||||||
|
* Update `tf.map_fn` to support RaggedTensors and SparseTensors.
|
||||||
|
* Deprecate `tf.group`. It is not useful in eager mode.
|
||||||
|
* Add CPU and GPU implementation of modified variation of [`FTRL`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/raw_ops/ApplyFtrl)/[`FTRLV2`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/raw_ops/ApplyFtrlV2) that can triggerred by `multiply_linear_by_lr` allowing a learning rate of zero.
|
||||||
|
|
||||||
|
### `tf.data`:
|
||||||
|
* `tf.data.experimental.dense_to_ragged_batch` works correctly with tuples.
|
||||||
|
* `tf.data.experimental.dense_to_ragged_batch` to output variable ragged rank.
|
||||||
|
* `tf.data.experimental.cardinality` is now a method on `tf.data.Dataset`.
|
||||||
|
* `tf.data.Dataset` now supports `len(Dataset)` when the cardinality is finite.
|
||||||
|
|
||||||
|
### `tf.distribute`:
|
||||||
|
* Expose experimental [`tf.distribute.DistributedDataset`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedDataset?hl=en) and [`tf.distribute.DistributedIterator`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedIterator) to distribute input data when using `tf.distribute` to scale training on multiple devices.
|
||||||
|
* Added a [`get_next_as_optional`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedIterator?hl=en#get_next_as_optional) method for [`tf.distribute.DistributedIterator`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedIterator?hl=en) class to return a `tf.experimental.Optional` instance that contains the next value for all replicas or none instead of raising an out of range error. Also see *new* [guide on input distribution](https://www.tensorflow.org/tutorials/distribute/input).
|
||||||
|
* Allow var.assign on MirroredVariables with aggregation=NONE in replica context. Previously this would raise an error. We now allow this because many users and library writers find using `.assign` in replica context to be more convenient, instead of having to use `Strategy.extended.update` which was the previous way of updating variables in this situation.
|
||||||
|
* `tf.distribute.experimental.MultiWorkerMirroredStrategy` adds support for partial batches. Workers running out of data now continue to participate in the training with empty inputs, instead of raising an error. Learn more about [partial batches here](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).
|
||||||
|
* Improve the performance of reading metrics eagerly under `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
|
||||||
|
* Fix the issue that `strategy.reduce()` inside `tf.function` may raise exceptions when the values to reduce are from loops or if-clauses.
|
||||||
|
* Fix the issue that `tf.distribute.MirroredStrategy` cannot be used together with `tf.distribute.experimental.MultiWorkerMirroredStrategy`.
|
||||||
|
* Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization.
|
||||||
|
|
||||||
|
### `tf.keras`:
|
||||||
|
* Introduces experimental preprocessing layers API (`tf.keras.layers.experimental.preprocessing`) to handle data preprocessing operations such as categorical feature encoding, text vectorization, data normalization, and data discretization (binning). The newly added layers provide a replacement for the legacy feature column API, and support composite tensor inputs.
|
||||||
|
* Added **categorical data** processing layers:
|
||||||
|
* `IntegerLookup` & `StringLookup`: build an index of categorical feature values
|
||||||
|
* `CategoryEncoding`: turn integer-encoded categories into one-hot, multi-hot, or tf-idf encoded representations
|
||||||
|
* `CategoryCrossing`: create new categorical features representing co-occurrences of previous categorical feature values
|
||||||
|
* `Hashing`: the hashing trick, for large-vocabulary categorical features
|
||||||
|
* `Discretization`: turn continuous numerical features into categorical features by binning their values
|
||||||
|
* Improved **image preprocessing** layers: `CenterCrop`, `Rescaling`
|
||||||
|
* Improved **image augmentation** layers: `RandomCrop`, `RandomFlip`, `RandomTranslation`, `RandomRotation`, `RandomHeight`, `RandomWidth`, `RandomZoom`, `RandomContrast`
|
||||||
|
* Improved **`TextVectorization`** layer, which handles string tokenization, n-gram generation, and token encoding
|
||||||
|
* The `TextVectorization` layer now accounts for the mask_token as part of the vocabulary size when output_mode='int'. This means that, if you have a max_tokens value of 5000, your output will have 5000 unique values (not 5001 as before).
|
||||||
|
* Change the return value of `TextVectorization.get_vocabulary()` from `byte` to `string`. Users who previously were calling 'decode' on the output of this method should no longer need to do so.
|
||||||
|
* Introduce new Keras dataset generation utilities :
|
||||||
|
* **[`image_dataset_from_directory`](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory)** is a utility based on `tf.data.Dataset`, meant to replace the legacy `ImageDataGenerator`. It takes you from a structured directory of images to a labeled dataset, in one function call. Note that it doesn't perform image data augmentation (which is meant to be done using preprocessing layers).
|
||||||
|
* **[`text_dataset_from_directory`](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/text_dataset_from_directory)** takes you from a structured directory of text files to a labeled dataset, in one function call.
|
||||||
|
* **[`timeseries_dataset_from_array`](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/timeseries_dataset_from_array)** is a `tf.data.Dataset`-based replacement of the legacy `TimeseriesGenerator`. It takes you from an array of timeseries data to a dataset of shifting windows with their targets.
|
||||||
|
* Added [`experimental_steps_per_execution`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/Model?hl=en#compile)
|
||||||
|
arg to `model.compile` to indicate the number of batches to run per `tf.function` call. This can speed up Keras Models on TPUs up to 3x.
|
||||||
|
* Extends `tf.keras.layers.Lambda` layers to support multi-argument lambdas, and keyword arguments when calling the layer.
|
||||||
|
* Functional models now get constructed if *any* tensor in a layer call's arguments/keyword arguments comes from a keras input. Previously the functional api would only work if all of the elements in the first argument to the layer came from a keras input.
|
||||||
|
* Clean up `BatchNormalization` layer's `trainable` property to act like standard python state when it's used inside `tf.functions` (frozen at tracing time), instead of acting like a pseudo-variable whose updates *kind of sometimes* get reflected in already-traced `tf.function` traces.
|
||||||
|
* Add the `Conv1DTranspose` layer.
|
||||||
|
* Refine the semantics of `SensitivitySpecificityBase` derived metrics. See the updated API docstrings for [`tf.keras.metrics.SensitivityAtSpecificity`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/metrics/SensitivityAtSpecificity) and [`tf.keras.metrics.SpecificityAtSensitivty`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/metrics/SpecificityAtSensitivity).
|
||||||
|
|
||||||
|
### `tf.lite`:
|
||||||
|
* Converter
|
||||||
|
* Restored `inference_input_type` and `inference_output_type` flags in TF 2.x TFLiteConverter (backward compatible with TF 1.x) to support integer (tf.int8, tf.uint8) input and output types in post training full integer quantized models.
|
||||||
|
* Added support for converting and resizing models with dynamic (placeholder) dimensions. Previously, there was only limited support for dynamic batch size, and even that did not guarantee that the model could be properly resized at runtime.
|
||||||
|
* Enabled experimental support for a new quantization mode with 16-bit activations and 8-bit weights. See `lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8`.
|
||||||
|
* CPU
|
||||||
|
* Fix an issue w/ dynamic weights and `Conv2D` on x86.
|
||||||
|
* Add a runtime Android flag for enabling `XNNPACK` for optimized CPU performance.
|
||||||
|
* Add a runtime iOS flag for enabling `XNNPACK` for optimized CPU performance.
|
||||||
|
* Add a compiler flag to enable building a TFLite library that applies `XNNPACK` delegate automatically when the model has a `fp32` operation.
|
||||||
|
* GPU
|
||||||
|
* Allow GPU acceleration starting with internal graph nodes
|
||||||
|
* Experimental support for quantized models with the Android GPU delegate
|
||||||
|
* Add GPU delegate whitelist.
|
||||||
|
* Rename GPU whitelist -> compatibility (list).
|
||||||
|
* Improve GPU compatibility list entries from crash reports.
|
||||||
|
* NNAPI
|
||||||
|
* Set default value for `StatefulNnApiDelegate::Options::max_number_delegated_partitions` to 3.
|
||||||
|
* Add capability to disable `NNAPI` CPU and check `NNAPI` Errno.
|
||||||
|
* Fix crashes when using `NNAPI` with target accelerator specified with model containing Conv2d or FullyConnected or LSTM nodes with quantized weights.
|
||||||
|
* Fix `ANEURALNETWORKS_BAD_DATA` execution failures with `sum`/`max`/`min`/`reduce` operations with `scalar` inputs.
|
||||||
|
* Hexagon
|
||||||
|
* TFLite Hexagon Delegate out of experimental.
|
||||||
|
* Experimental `int8` support for most hexagon ops.
|
||||||
|
* Experimental per-channel quant support for `conv` in Hexagon delegate.
|
||||||
|
* Support dynamic batch size in C++ API.
|
||||||
|
* CoreML
|
||||||
|
* Opensource CoreML delegate
|
||||||
|
* Misc
|
||||||
|
* Enable building Android TFLite targets on Windows
|
||||||
|
* Add support for `BatchMatMul`.
|
||||||
|
* Add support for `half_pixel_centers` with `ResizeNearestNeighbor`.
|
||||||
|
* Add 3D support for `BatchToSpaceND`.
|
||||||
|
* Add 5D support for `BroadcastSub`, `Maximum`, `Minimum`, `Transpose` and `BroadcastDiv`.
|
||||||
|
* Rename `kTfLiteActRelu1` to `kTfLiteActReluN1To1`.
|
||||||
|
* Enable flex delegate on tensorflow.lite.Interpreter Python package.
|
||||||
|
* Add `Buckettize`, `SparseCross` and `BoostedTreesBucketize` to the flex whitelist.
|
||||||
|
* Add support for selective registration of flex ops.
|
||||||
|
* Add missing kernels for flex delegate whitelisted ops.
|
||||||
|
* Fix issue when using direct `ByteBuffer` inputs with graphs that have dynamic shapes.
|
||||||
|
* Fix error checking supported operations in a model containing `HardSwish`.
|
||||||
|
|
||||||
|
### Packaging Support
|
||||||
|
* Added `tf.sysconfig.get_build_info()`. Returns a dict that describes the build environment of the currently installed TensorFlow package, e.g. the NVIDIA CUDA and NVIDIA CuDNN versions used when TensorFlow was built.
|
||||||
|
|
||||||
|
### Profiler
|
||||||
|
* Fix a subtle use-after-free issue in `XStatVisitor::RefValue()`.
|
||||||
|
|
||||||
|
### TPU Enhancements
|
||||||
|
* Adds 3D mesh support in TPU configurations ops.
|
||||||
|
* Added TPU code for `FTRL` with `multiply_linear_by_lr`.
|
||||||
|
* Silently adds a new file system registry at `gstpu`.
|
||||||
|
* Support `restartType` in cloud tpu client.
|
||||||
|
* Depend on a specific version of google-api-python-client.
|
||||||
|
* Fixes apiclient import.
|
||||||
|
|
||||||
|
### Tracing and Debugging
|
||||||
|
* Add a `TFE_Py_Execute` traceme.
|
||||||
|
|
||||||
|
### XLA Support
|
||||||
|
* Implement stable `argmin` and `argmax`
|
||||||
|
|
||||||
|
## Thanks to our Contributors
|
||||||
|
|
||||||
|
This release contains contributions from many people at Google, as well as:
|
||||||
|
|
||||||
|
902449@58880@bigcat_chen@ASIC, Abdul Baseer Khan, Abhineet Choudhary, Abolfazl Shahbazi, Adam Hillier, ag.ramesh, Agoniii, Ajay P, Alex Hoffman, Alexander Bayandin, Alexander Grund, Alexandre Abadie, Alexey Rogachevskiy, amoitra, Andrew Stevens, Angus-Luo, Anshuman Tripathy, Anush Elangovan, Artem Mavrin, Ashutosh Hathidara, autoih, Ayushman Kumar, ayushmankumar7, Bairen Yi, Bas Aarts, Bastian Eichenberger, Ben Barsdell, bhack, Bharat Raghunathan, Biagio Montaruli, Bigcat-Himax, blueyi, Bryan Cutler, Byambaa, Carlos Hernandez-Vaquero, Chen Lei, Chris Knorowski, Christian Clauss, chuanqiw, CuiYifeng, Daniel Situnayake, Daria Zhuravleva, Dayananda-V, Deven Desai, Devi Sandeep Endluri, Dmitry Zakharov, Dominic Jack, Duncan Riach, Edgar Liberis, Ehsan Toosi, ekuznetsov139, Elena Zhelezina, Eugene Kuznetsov, Eugene Mikhantiev, Evgenii Zheltonozhskii, Fabio Di Domenico, Fausto Morales, Fei Sun, feihugis, Felix E. Klee, flyingcat, Frederic Bastien, Fredrik Knutsson, frreiss, fsx950223, ganler, Gaurav Singh, Georgios Pinitas, Gian Marco Iodice, Giorgio Arena, Giuseppe Rossini, Gregory Keith, Guozhong Zhuang, gurushantj, Hahn Anselm, Harald Husum, Harjyot Bagga, Hristo Vrigazov, Ilya Persky, Ir1d, Itamar Turner-Trauring, jacco, Jake Tae, Janosh Riebesell, Jason Zaman, jayanth, Jeff Daily, Jens Elofsson, Jinzhe Zeng, JLZ, Jonas Skog, Jonathan Dekhtiar, Josh Meyer, Joshua Chia, Judd, justkw, Kaixi Hou, Kam D Kasravi, Kamil Rakoczy, Karol Gugala, Kayou, Kazuaki Ishizaki, Keith Smiley, Khaled Besrour, Kilaru Yasaswi Sri Chandra Gandhi, Kim, Young Soo, Kristian Hartikainen, Kwabena W. Agyeman, Leslie-Fang, Leslie-Fang-Intel, Li, Guizi, Lukas Geiger, Lutz Roeder, M\U00E5Ns Nilsson, Mahmoud Abuzaina, Manish, Marcel Koester, Marcin Sielski, marload, Martin Jul, Matt Conley, mdfaijul, Meng, Peng, Meteorix, Michael Käufl, Michael137, Milan Straka, Mitchell Vitez, Ml-0, Mokke Meguru, Mshr-H, nammbash, Nathan Luehr, naumkin, Neeraj Bhadani, ngc92, Nick Morgan, nihui, Niranjan Hasabnis, Niranjan Yadla, Nishidha Panpaliya, Oceania2018, oclyke, Ouyang Jin, OverLordGoldDragon, Owen Lyke, Patrick Hemmer, Paul Andrey, Peng Sun, periannath, Phil Pearl, Prashant Dandriyal, Prashant Kumar, Rahul Huilgol, Rajan Singh, Rajeshwar Reddy T, rangjiaheng, Rishit Dagli, Rohan Reddy, rpalakkal, rposts, Ruan Kunliang, Rushabh Vasani, Ryohei Ikegami, Semun Lee, Seo-Inyoung, Sergey Mironov, Sharada Shiddibhavi, ShengYang1, Shraiysh Vaishay, Shunya Ueta, shwetaoj, Siyavash Najafzade, Srinivasan Narayanamoorthy, Stephan Uphoff, storypku, sunchenggen, sunway513, Sven-Hendrik Haase, Swapnil Parekh, Tamas Bela Feher, Teng Lu, tigertang, tomas, Tomohiro Ubukata, tongxuan.ltx, Tony Tonev, Tzu-Wei Huang, Téo Bouvard, Uday Bondhugula, Vaibhav Jade, Vijay Tadikamalla, Vikram Dattu, Vincent Abriou, Vishnuvardhan Janapati, Vo Van Nghia, VoVAllen, Will Battel, William D. Irons, wyzhao, Xiaoming (Jason) Cui, Xiaoquan Kong, Xinan Jiang, xutianming, Yair Ehrenwald, Yasir Modak, Yasuhiro Matsumoto, Yixing Fu, Yong Tang, Yuan Tang, zhaozheng09, Zilin Zhu, zilinzhu, 张志豪
|
||||||
|
|
||||||
* `tf.image.extract_glimpse` has been updated to correctly process the case
|
|
||||||
where `centered=False` and `normalized=False`. This is a breaking change as
|
|
||||||
the output is different from (incorrect) previous versions. Note this
|
|
||||||
breaking change only impacts `tf.image.extract_glimpse` and
|
|
||||||
`tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of
|
|
||||||
`tf.compat.v1.image.extract_glimpse` does not change. The behavior of
|
|
||||||
exsiting C++ kernel `ExtractGlimpse` does not change as well, so saved
|
|
||||||
models will not be impacted.
|
|
||||||
|
|
||||||
# Release 2.1.1
|
# Release 2.1.1
|
||||||
|
|
||||||
|
12
WORKSPACE
12
WORKSPACE
@ -18,6 +18,18 @@ load("//tensorflow:workspace.bzl", "tf_repositories")
|
|||||||
# Please add all new TensorFlow dependencies in workspace.bzl.
|
# Please add all new TensorFlow dependencies in workspace.bzl.
|
||||||
tf_repositories()
|
tf_repositories()
|
||||||
|
|
||||||
|
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
|
||||||
|
|
||||||
|
git_repository(
|
||||||
|
name = "com_github_nelhage_rules_boost",
|
||||||
|
commit = "1e3a69bf2d5cd10c34b74f066054cd335d033d71",
|
||||||
|
remote = "https://github.com/nelhage/rules_boost",
|
||||||
|
shallow_since = "1591047380 -0700",
|
||||||
|
)
|
||||||
|
|
||||||
|
load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
|
||||||
|
boost_deps()
|
||||||
|
|
||||||
register_toolchains("@local_config_python//:py_toolchain")
|
register_toolchains("@local_config_python//:py_toolchain")
|
||||||
|
|
||||||
load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")
|
load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")
|
||||||
|
1
native_client
Symbolic link
1
native_client
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../native_client
|
@ -221,8 +221,7 @@ Status TfDataTypeFormDlDataType(const DLDataType& dtype,
|
|||||||
// Wraps the deleter function of DLManagedTensor to match the function signature
|
// Wraps the deleter function of DLManagedTensor to match the function signature
|
||||||
// TFE_NewTensorHandleFromDeviceMemory.
|
// TFE_NewTensorHandleFromDeviceMemory.
|
||||||
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
|
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
|
||||||
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr);
|
TFE_CallDLManagedTensorDeleter(dlmt_vptr);
|
||||||
dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Checks whether the stride array matches the layout of compact, row-majored
|
// Checks whether the stride array matches the layout of compact, row-majored
|
||||||
@ -324,7 +323,7 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status,
|
|||||||
|
|
||||||
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
|
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
|
||||||
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
|
ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
|
||||||
total_bytes, &DeallocatorWrapperFunc, &dlmt, status);
|
total_bytes, &DeallocatorWrapperFunc, dlmt, status);
|
||||||
|
|
||||||
return handle;
|
return handle;
|
||||||
}
|
}
|
||||||
|
@ -476,10 +476,36 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
|||||||
stream->ThenRecordEvent(definition_event.get());
|
stream->ThenRecordEvent(definition_event.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<TensorShape> output_tensor_shapes;
|
||||||
|
output_tensor_shapes.reserve(ctx->num_outputs());
|
||||||
|
if (output.on_host_shape().is_dynamic()) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto transfer_manager,
|
||||||
|
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
|
||||||
|
|
||||||
|
xla::Shape output_host_shape = output.on_host_shape();
|
||||||
|
xla::Shape output_device_shape = output.on_device_shape();
|
||||||
|
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
|
||||||
|
stream, &output, &output_host_shape, &output_device_shape));
|
||||||
|
|
||||||
|
output.set_shapes(output_host_shape, output_device_shape);
|
||||||
|
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||||
|
const xla::Shape& subshape =
|
||||||
|
xla::ShapeUtil::GetSubshape(output_host_shape, {i});
|
||||||
|
TensorShape shape;
|
||||||
|
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
|
||||||
|
output_tensor_shapes.push_back(shape);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||||
|
output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Copy XLA results to the OpOutputList.
|
// Copy XLA results to the OpOutputList.
|
||||||
int output_num = 0;
|
int output_num = 0;
|
||||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||||
const TensorShape& shape = compilation_result->outputs[i].shape;
|
const TensorShape& shape = output_tensor_shapes[i];
|
||||||
const DataType& type = compilation_result->outputs[i].type;
|
const DataType& type = compilation_result->outputs[i].type;
|
||||||
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
|
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
|
||||||
<< DataTypeString(type);
|
<< DataTypeString(type);
|
||||||
|
@ -30,6 +30,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
|||||||
from tensorflow.compiler.tests import xla_test
|
from tensorflow.compiler.tests import xla_test
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import gen_image_ops
|
from tensorflow.python.ops import gen_image_ops
|
||||||
from tensorflow.python.ops import image_ops
|
from tensorflow.python.ops import image_ops
|
||||||
@ -774,6 +775,7 @@ class ResizeBilinearNonAlignCornersTest(xla_test.XLATestCase):
|
|||||||
|
|
||||||
class NonMaxSuppressionTest(xla_test.XLATestCase):
|
class NonMaxSuppressionTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testNMS128From1024(self):
|
def testNMS128From1024(self):
|
||||||
num_boxes = 1024
|
num_boxes = 1024
|
||||||
boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
|
boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
|
||||||
@ -808,6 +810,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
|
|||||||
|
|
||||||
self.assertEqual(indices_tf.size, max_output_size)
|
self.assertEqual(indices_tf.size, max_output_size)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testNMS3From6Boxes(self):
|
def testNMS3From6Boxes(self):
|
||||||
# Three boxes are selected based on IOU.
|
# Three boxes are selected based on IOU.
|
||||||
boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
|
boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
|
||||||
@ -849,6 +852,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
|
|||||||
self.assertEqual(num_valid, 3)
|
self.assertEqual(num_valid, 3)
|
||||||
self.assertAllClose(indices_tf[:num_valid], [3, 0, 5])
|
self.assertAllClose(indices_tf[:num_valid], [3, 0, 5])
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testNMS3Then2WithScoreThresh(self):
|
def testNMS3Then2WithScoreThresh(self):
|
||||||
# Three boxes are selected based on IOU.
|
# Three boxes are selected based on IOU.
|
||||||
# One is filtered out by score threshold.
|
# One is filtered out by score threshold.
|
||||||
@ -891,6 +895,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
|
|||||||
self.assertEqual(num_valid, 2)
|
self.assertEqual(num_valid, 2)
|
||||||
self.assertAllClose(indices_tf[:num_valid], [3, 0])
|
self.assertAllClose(indices_tf[:num_valid], [3, 0])
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testNMS3Then1WithScoreMaxThresh(self):
|
def testNMS3Then1WithScoreMaxThresh(self):
|
||||||
# Three boxes are selected based on IOU.
|
# Three boxes are selected based on IOU.
|
||||||
# One is filtered out by score threshold.
|
# One is filtered out by score threshold.
|
||||||
@ -934,6 +939,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
|
|||||||
self.assertEqual(num_valid, 1)
|
self.assertEqual(num_valid, 1)
|
||||||
self.assertAllClose(indices_tf[:num_valid], [3])
|
self.assertAllClose(indices_tf[:num_valid], [3])
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testSelectFromContinuousOverLap(self):
|
def testSelectFromContinuousOverLap(self):
|
||||||
# Tests that a suppressed box does not itself suppress other boxes.
|
# Tests that a suppressed box does not itself suppress other boxes.
|
||||||
|
|
||||||
@ -978,6 +984,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
|
|||||||
|
|
||||||
class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSFrom6(self):
|
def testBatchedNMSFrom6(self):
|
||||||
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
||||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
||||||
@ -1015,6 +1022,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
|||||||
indices_output)
|
indices_output)
|
||||||
self.assertAllEqual([5, 4], num_valid_output)
|
self.assertAllEqual([5, 4], num_valid_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSFrom6Max3(self):
|
def testBatchedNMSFrom6Max3(self):
|
||||||
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
||||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
||||||
@ -1048,6 +1056,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
|||||||
self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output)
|
self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output)
|
||||||
self.assertAllEqual([3, 3], num_valid_output)
|
self.assertAllEqual([3, 3], num_valid_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSSingleFrom6Max3(self):
|
def testBatchedNMSSingleFrom6Max3(self):
|
||||||
boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
||||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
|
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
|
||||||
@ -1078,6 +1087,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
|||||||
self.assertAllEqual([0, 1, 2], indices_output)
|
self.assertAllEqual([0, 1, 2], indices_output)
|
||||||
self.assertAllEqual(3, num_valid_output)
|
self.assertAllEqual(3, num_valid_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSSingleFrom6NoPad(self):
|
def testBatchedNMSSingleFrom6NoPad(self):
|
||||||
boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
||||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
|
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
|
||||||
@ -1107,6 +1117,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
|||||||
self.assertAllEqual([0, 1, 2, 4, 5], indices_output)
|
self.assertAllEqual([0, 1, 2, 4, 5], indices_output)
|
||||||
self.assertAllEqual(5, num_valid_output)
|
self.assertAllEqual(5, num_valid_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSBatchDimsFrom6Max3(self):
|
def testBatchedNMSBatchDimsFrom6Max3(self):
|
||||||
boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
||||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
||||||
@ -1140,6 +1151,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
|||||||
self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output)
|
self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output)
|
||||||
self.assertAllEqual([[3, 3]], num_valid_output)
|
self.assertAllEqual([[3, 3]], num_valid_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSScoreThresholdFrom6Max3(self):
|
def testBatchedNMSScoreThresholdFrom6Max3(self):
|
||||||
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
||||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
||||||
@ -1175,6 +1187,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
|||||||
self.assertAllEqual([3, 2], num_valid_output)
|
self.assertAllEqual([3, 2], num_valid_output)
|
||||||
self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
|
self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSUnsortedInputFrom6(self):
|
def testBatchedNMSUnsortedInputFrom6(self):
|
||||||
boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1],
|
boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1],
|
||||||
[0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]],
|
[0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]],
|
||||||
@ -1211,6 +1224,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
|||||||
indices_output)
|
indices_output)
|
||||||
self.assertAllEqual([5, 4], num_valid_output)
|
self.assertAllEqual([5, 4], num_valid_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSNoncanonicalizedInputFrom6(self):
|
def testBatchedNMSNoncanonicalizedInputFrom6(self):
|
||||||
boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4],
|
boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4],
|
||||||
[1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]],
|
[1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]],
|
||||||
@ -1248,6 +1262,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
|||||||
indices_output)
|
indices_output)
|
||||||
self.assertAllEqual([5, 4], num_valid_output)
|
self.assertAllEqual([5, 4], num_valid_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self):
|
def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self):
|
||||||
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
||||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
||||||
@ -1283,6 +1298,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
|||||||
self.assertAllEqual([3, 2], num_valid_output)
|
self.assertAllEqual([3, 2], num_valid_output)
|
||||||
self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
|
self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSFrom6DynamicInput(self):
|
def testBatchedNMSFrom6DynamicInput(self):
|
||||||
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
||||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
||||||
|
@ -1202,6 +1202,9 @@ cc_library(
|
|||||||
srcs = ["transfer_manager.cc"],
|
srcs = ["transfer_manager.cc"],
|
||||||
hdrs = ["transfer_manager.h"],
|
hdrs = ["transfer_manager.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":compiler",
|
||||||
|
":executable",
|
||||||
|
":maybe_owning_device_memory",
|
||||||
":shaped_buffer",
|
":shaped_buffer",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
@ -1210,8 +1213,6 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/compiler/xla/service:executable",
|
|
||||||
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
|
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core:stream_executor_no_cuda",
|
||||||
"//tensorflow/stream_executor:device_memory",
|
"//tensorflow/stream_executor:device_memory",
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||||
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
|
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
@ -33,6 +34,7 @@ limitations under the License.
|
|||||||
using absl::StrCat;
|
using absl::StrCat;
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
/* static */ tensorflow::mutex
|
/* static */ tensorflow::mutex
|
||||||
TransferManager::platform_transfer_manager_mutex_(
|
TransferManager::platform_transfer_manager_mutex_(
|
||||||
tensorflow::LINKER_INITIALIZED);
|
tensorflow::LINKER_INITIALIZED);
|
||||||
@ -200,6 +202,67 @@ void TransferManager::TransferArrayFromDevice(
|
|||||||
std::move(done), transfer_metadata);
|
std::move(done), transfer_metadata);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status TransferManager::ReadDynamicShapes(se::Stream* stream,
|
||||||
|
ShapedBuffer* device_buffer,
|
||||||
|
Shape* host_shape,
|
||||||
|
Shape* device_shape) {
|
||||||
|
DCHECK(device_shape->is_dynamic());
|
||||||
|
Shape original_device_shape = *device_shape;
|
||||||
|
Shape original_host_shape = *host_shape;
|
||||||
|
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(auto compiler,
|
||||||
|
Compiler::GetForPlatform(stream->parent()->platform()));
|
||||||
|
TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachMutableElementWithStatus(
|
||||||
|
[&](const ShapeIndex& index, se::DeviceMemoryBase* buffer) {
|
||||||
|
const Shape& buffer_shape =
|
||||||
|
ShapeUtil::GetSubshape(*device_shape, index);
|
||||||
|
if (buffer_shape.IsTuple()) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Shape& host_sub_shape =
|
||||||
|
*ShapeUtil::GetMutableSubshape(host_shape, index);
|
||||||
|
Shape& device_sub_shape =
|
||||||
|
*ShapeUtil::GetMutableSubshape(device_shape, index);
|
||||||
|
if (device_sub_shape.is_static()) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the dynamic shape metadata from the device stream.
|
||||||
|
auto shape_size_fn = compiler->ShapeSizeBytesFunction();
|
||||||
|
Shape buffer_shape_static = ShapeUtil::MakeStaticShape(buffer_shape);
|
||||||
|
const int64 offset = shape_size_fn(buffer_shape_static);
|
||||||
|
int64 metadata_size = shape_size_fn(buffer_shape) - offset;
|
||||||
|
if (metadata_size == 0) {
|
||||||
|
return InvalidArgument("Dynamic shape metadata size should not be 0");
|
||||||
|
}
|
||||||
|
auto buffer_8 = se::DeviceMemory<uint8>(*buffer);
|
||||||
|
auto metadata_buffer =
|
||||||
|
stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto metadata,
|
||||||
|
TransferArrayFromDevice(
|
||||||
|
stream,
|
||||||
|
ShapeUtil::MakeShape(S32, {buffer_shape.dimensions_size()}),
|
||||||
|
metadata_buffer));
|
||||||
|
|
||||||
|
// Update shape size from metadata.
|
||||||
|
for (int64 i = 0; i < metadata.element_count(); ++i) {
|
||||||
|
host_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
|
||||||
|
device_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}));
|
||||||
|
host_shape->clear_dynamic_dimensions();
|
||||||
|
device_shape->clear_dynamic_dimensions();
|
||||||
|
|
||||||
|
TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape,
|
||||||
|
original_device_shape));
|
||||||
|
TF_RET_CHECK(
|
||||||
|
ShapeUtil::DynamicShapeIsCompatible(*host_shape, original_host_shape));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
/* static */ void TransferManager::RegisterTransferManager(
|
/* static */ void TransferManager::RegisterTransferManager(
|
||||||
se::Platform::Id platform_id,
|
se::Platform::Id platform_id,
|
||||||
TransferManagerCreationFunction creation_function) {
|
TransferManagerCreationFunction creation_function) {
|
||||||
|
@ -184,6 +184,15 @@ class TransferManager {
|
|||||||
const se::DeviceMemoryBase& source,
|
const se::DeviceMemoryBase& source,
|
||||||
const TransferMetadata* transfer_metadata = nullptr);
|
const TransferMetadata* transfer_metadata = nullptr);
|
||||||
|
|
||||||
|
// Read from a device buffer and update the dynamic dimension sizes of
|
||||||
|
// `host_shape` and `device_shape`. The function takes in bounded dynamic
|
||||||
|
// shapes, and returns static shapes with dynamic shapes updated.
|
||||||
|
// The shape of the buffer also have to be compatible with the host shape and
|
||||||
|
// device shape.
|
||||||
|
virtual Status ReadDynamicShapes(se::Stream* stream,
|
||||||
|
ShapedBuffer* device_buffer,
|
||||||
|
Shape* host_shape, Shape* device_shape);
|
||||||
|
|
||||||
// Transfers the given literal into the Infeed interface of the device,
|
// Transfers the given literal into the Infeed interface of the device,
|
||||||
// using the given executor.
|
// using the given executor.
|
||||||
virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor,
|
virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor,
|
||||||
|
@ -264,86 +264,28 @@ Status UpdateDynamicInputs(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
xla::StatusOr<xla::Literal> ReadMetadataLiteral(
|
|
||||||
se::Stream* stream, se::DeviceMemoryBase buffer,
|
|
||||||
const xla::Shape& buffer_shape, xla::TransferManager* transfer_manager) {
|
|
||||||
TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
|
|
||||||
stream->parent()->platform()));
|
|
||||||
auto shape_size_fn = compiler->ShapeSizeBytesFunction();
|
|
||||||
xla::Shape buffer_shape_static =
|
|
||||||
xla::ShapeUtil::MakeStaticShape(buffer_shape);
|
|
||||||
const int64 offset = shape_size_fn(buffer_shape_static);
|
|
||||||
int64 metadata_size = shape_size_fn(buffer_shape) - offset;
|
|
||||||
TF_RET_CHECK(metadata_size != 0);
|
|
||||||
auto buffer_8 = se::DeviceMemory<uint8>(buffer);
|
|
||||||
auto metadata_buffer =
|
|
||||||
stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
|
|
||||||
return transfer_manager->TransferArrayFromDevice(
|
|
||||||
stream,
|
|
||||||
xla::ShapeUtil::MakeShape(xla::S32, {buffer_shape.dimensions_size()}),
|
|
||||||
metadata_buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
// For each subshape in the result buffer that's dynamic, read the dynamic
|
|
||||||
// dimension sizes from the metadata, and update output shapes. The result shape
|
|
||||||
// is a static and concrete shape.
|
|
||||||
xla::Status UpdateDynamicOutputs(se::Stream* stream,
|
|
||||||
const xla::ShapedBuffer& shaped_buffer,
|
|
||||||
xla::Shape* output_host_shape,
|
|
||||||
xla::Shape* output_device_shape) {
|
|
||||||
DCHECK(output_device_shape->is_dynamic());
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
|
||||||
auto transfer_manager,
|
|
||||||
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
|
|
||||||
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
|
|
||||||
TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachElementWithStatus(
|
|
||||||
[&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) {
|
|
||||||
const xla::Shape& buffer_shape =
|
|
||||||
xla::ShapeUtil::GetSubshape(*output_device_shape, index);
|
|
||||||
if (buffer_shape.IsTuple()) {
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
xla::Shape& host_shape =
|
|
||||||
*xla::ShapeUtil::GetMutableSubshape(output_host_shape, index);
|
|
||||||
xla::Shape& device_shape =
|
|
||||||
*xla::ShapeUtil::GetMutableSubshape(output_device_shape, index);
|
|
||||||
if (device_shape.is_static()) {
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
TF_ASSIGN_OR_RETURN(auto metadata,
|
|
||||||
ReadMetadataLiteral(stream, buffer, buffer_shape,
|
|
||||||
transfer_manager));
|
|
||||||
// Update shape size from metadata.
|
|
||||||
for (int64 i = 0; i < metadata.element_count(); ++i) {
|
|
||||||
host_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
|
|
||||||
device_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}));
|
|
||||||
output_host_shape->clear_dynamic_dimensions();
|
|
||||||
output_device_shape->clear_dynamic_dimensions();
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
|
xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
|
||||||
se::Stream* stream, xla::ExecutionOutput run_result, xla::Backend* backend,
|
se::Stream* stream, xla::ExecutionOutput run_result, xla::Backend* backend,
|
||||||
int device_ordinal) {
|
int device_ordinal) {
|
||||||
XRTTupleAllocation* output_tuple;
|
XRTTupleAllocation* output_tuple;
|
||||||
const xla::ScopedShapedBuffer& shaped_buffer = run_result.Result();
|
xla::ScopedShapedBuffer* shaped_buffer = run_result.MutableResult();
|
||||||
if (shaped_buffer.on_device_shape().is_dynamic()) {
|
if (shaped_buffer->on_device_shape().is_dynamic()) {
|
||||||
// Update dynamic shapes from output buffer, and create a XRT tensor with
|
// Update dynamic shapes from output buffer, and create a XRT tensor with
|
||||||
// dimension sizes read from metadata.
|
// dimension sizes read from metadata.
|
||||||
xla::Shape output_host_shape = shaped_buffer.on_host_shape();
|
xla::Shape output_host_shape = shaped_buffer->on_host_shape();
|
||||||
xla::Shape output_device_shape = shaped_buffer.on_device_shape();
|
xla::Shape output_device_shape = shaped_buffer->on_device_shape();
|
||||||
TF_RETURN_IF_ERROR(UpdateDynamicOutputs(
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
auto transfer_manager,
|
||||||
|
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
|
||||||
|
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
|
||||||
stream, shaped_buffer, &output_host_shape, &output_device_shape));
|
stream, shaped_buffer, &output_host_shape, &output_device_shape));
|
||||||
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
|
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
|
||||||
shaped_buffer, output_host_shape, output_device_shape, backend,
|
*shaped_buffer, output_host_shape, output_device_shape, backend,
|
||||||
device_ordinal, &output_tuple));
|
device_ordinal, &output_tuple));
|
||||||
} else {
|
} else {
|
||||||
// Fast-path: Don't copy shapes of output buffer.
|
// Fast-path: Don't copy shapes of output buffer.
|
||||||
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
|
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
|
||||||
shaped_buffer, backend, device_ordinal, &output_tuple));
|
*shaped_buffer, backend, device_ordinal, &output_tuple));
|
||||||
}
|
}
|
||||||
// After the output tuple is created, we can release the output result
|
// After the output tuple is created, we can release the output result
|
||||||
// buffers, to make sure they won't be cleared by its destructor.
|
// buffers, to make sure they won't be cleared by its destructor.
|
||||||
|
@ -28,8 +28,8 @@ tf_proto_library(
|
|||||||
)
|
)
|
||||||
|
|
||||||
tf_proto_library(
|
tf_proto_library(
|
||||||
name = "master_proto",
|
name = "dispatcher_proto",
|
||||||
srcs = ["master.proto"],
|
srcs = ["dispatcher.proto"],
|
||||||
has_services = 1,
|
has_services = 1,
|
||||||
cc_api_version = 2,
|
cc_api_version = 2,
|
||||||
protodeps = tf_additional_all_protos() + [
|
protodeps = tf_additional_all_protos() + [
|
||||||
@ -49,17 +49,17 @@ tf_proto_library(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "master_impl",
|
name = "dispatcher_impl",
|
||||||
srcs = ["master_impl.cc"],
|
srcs = ["dispatcher_impl.cc"],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"master_impl.h",
|
"dispatcher_impl.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":common_proto_cc",
|
":common_proto_cc",
|
||||||
":credentials_factory",
|
":credentials_factory",
|
||||||
":data_service",
|
":data_service",
|
||||||
|
":dispatcher_proto_cc",
|
||||||
":grpc_util",
|
":grpc_util",
|
||||||
":master_proto_cc",
|
|
||||||
":worker_cc_grpc_proto",
|
":worker_cc_grpc_proto",
|
||||||
":worker_proto_cc",
|
":worker_proto_cc",
|
||||||
"//tensorflow/c:c_api_internal",
|
"//tensorflow/c:c_api_internal",
|
||||||
@ -86,9 +86,9 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":common_proto_cc",
|
":common_proto_cc",
|
||||||
":credentials_factory",
|
":credentials_factory",
|
||||||
|
":dispatcher_cc_grpc_proto",
|
||||||
|
":dispatcher_proto_cc",
|
||||||
":grpc_util",
|
":grpc_util",
|
||||||
":master_cc_grpc_proto",
|
|
||||||
":master_proto_cc",
|
|
||||||
":worker_proto_cc",
|
":worker_proto_cc",
|
||||||
"//tensorflow/c:c_api_internal",
|
"//tensorflow/c:c_api_internal",
|
||||||
"//tensorflow/c:tf_status_helper",
|
"//tensorflow/c:tf_status_helper",
|
||||||
@ -207,12 +207,12 @@ tf_cc_test(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "grpc_master_impl",
|
name = "grpc_dispatcher_impl",
|
||||||
srcs = ["grpc_master_impl.cc"],
|
srcs = ["grpc_dispatcher_impl.cc"],
|
||||||
hdrs = ["grpc_master_impl.h"],
|
hdrs = ["grpc_dispatcher_impl.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":master_cc_grpc_proto",
|
":dispatcher_cc_grpc_proto",
|
||||||
":master_impl",
|
":dispatcher_impl",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||||
tf_grpc_cc_dependency(),
|
tf_grpc_cc_dependency(),
|
||||||
],
|
],
|
||||||
@ -250,7 +250,7 @@ cc_library(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":credentials_factory",
|
":credentials_factory",
|
||||||
":grpc_master_impl",
|
":grpc_dispatcher_impl",
|
||||||
":grpc_util",
|
":grpc_util",
|
||||||
":grpc_worker_impl",
|
":grpc_worker_impl",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
@ -268,9 +268,9 @@ cc_library(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":credentials_factory",
|
":credentials_factory",
|
||||||
|
":dispatcher_cc_grpc_proto",
|
||||||
|
":dispatcher_proto_cc",
|
||||||
":grpc_util",
|
":grpc_util",
|
||||||
":master_cc_grpc_proto",
|
|
||||||
":master_proto_cc",
|
|
||||||
":worker_cc_grpc_proto",
|
":worker_cc_grpc_proto",
|
||||||
":worker_proto_cc",
|
":worker_proto_cc",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
@ -287,12 +287,12 @@ tf_cc_test(
|
|||||||
tags = ["no_windows"],
|
tags = ["no_windows"],
|
||||||
deps = [
|
deps = [
|
||||||
":data_service",
|
":data_service",
|
||||||
":grpc_master_impl",
|
":dispatcher_cc_grpc_proto",
|
||||||
|
":dispatcher_proto_cc",
|
||||||
|
":grpc_dispatcher_impl",
|
||||||
":grpc_util",
|
":grpc_util",
|
||||||
":grpc_worker_impl",
|
":grpc_worker_impl",
|
||||||
":local_credentials_factory",
|
":local_credentials_factory",
|
||||||
":master_cc_grpc_proto",
|
|
||||||
":master_proto_cc",
|
|
||||||
":server_lib",
|
":server_lib",
|
||||||
":test_cluster",
|
":test_cluster",
|
||||||
":test_util",
|
":test_util",
|
||||||
@ -309,11 +309,11 @@ tf_cc_test(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cc_grpc_library(
|
cc_grpc_library(
|
||||||
name = "master_cc_grpc_proto",
|
name = "dispatcher_cc_grpc_proto",
|
||||||
srcs = [":master_proto"],
|
srcs = [":dispatcher_proto"],
|
||||||
generate_mocks = True,
|
generate_mocks = True,
|
||||||
grpc_only = True,
|
grpc_only = True,
|
||||||
deps = [":master_proto_cc"],
|
deps = [":dispatcher_proto_cc"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_grpc_library(
|
cc_grpc_library(
|
||||||
|
@ -18,8 +18,8 @@ limitations under the License.
|
|||||||
#include "grpcpp/create_channel.h"
|
#include "grpcpp/create_channel.h"
|
||||||
#include "grpcpp/security/credentials.h"
|
#include "grpcpp/security/credentials.h"
|
||||||
#include "tensorflow/core/data/service/credentials_factory.h"
|
#include "tensorflow/core/data/service/credentials_factory.h"
|
||||||
|
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
|
||||||
#include "tensorflow/core/data/service/grpc_util.h"
|
#include "tensorflow/core/data/service/grpc_util.h"
|
||||||
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
|
||||||
#include "tensorflow/core/data/service/worker.grpc.pb.h"
|
#include "tensorflow/core/data/service/worker.grpc.pb.h"
|
||||||
#include "tensorflow/core/framework/dataset.h"
|
#include "tensorflow/core/framework/dataset.h"
|
||||||
|
|
||||||
@ -54,8 +54,8 @@ std::string ProcessingModeToString(ProcessingMode mode) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
|
Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset,
|
||||||
int64* dataset_id) {
|
int64* dataset_id) {
|
||||||
TF_RETURN_IF_ERROR(EnsureInitialized());
|
TF_RETURN_IF_ERROR(EnsureInitialized());
|
||||||
GetOrRegisterDatasetRequest req;
|
GetOrRegisterDatasetRequest req;
|
||||||
*req.mutable_dataset()->mutable_graph() = dataset;
|
*req.mutable_dataset()->mutable_graph() = dataset;
|
||||||
@ -69,9 +69,9 @@ Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceMasterClient::CreateJob(int64 dataset_id,
|
Status DataServiceDispatcherClient::CreateJob(int64 dataset_id,
|
||||||
ProcessingMode processing_mode,
|
ProcessingMode processing_mode,
|
||||||
int64* job_id) {
|
int64* job_id) {
|
||||||
TF_RETURN_IF_ERROR(EnsureInitialized());
|
TF_RETURN_IF_ERROR(EnsureInitialized());
|
||||||
CreateJobRequest req;
|
CreateJobRequest req;
|
||||||
req.set_dataset_id(dataset_id);
|
req.set_dataset_id(dataset_id);
|
||||||
@ -88,11 +88,9 @@ Status DataServiceMasterClient::CreateJob(int64 dataset_id,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id,
|
Status DataServiceDispatcherClient::GetOrCreateJob(
|
||||||
ProcessingMode processing_mode,
|
int64 dataset_id, ProcessingMode processing_mode,
|
||||||
const std::string& job_name,
|
const std::string& job_name, int job_name_index, int64* job_id) {
|
||||||
int job_name_index,
|
|
||||||
int64* job_id) {
|
|
||||||
TF_RETURN_IF_ERROR(EnsureInitialized());
|
TF_RETURN_IF_ERROR(EnsureInitialized());
|
||||||
GetOrCreateJobRequest req;
|
GetOrCreateJobRequest req;
|
||||||
req.set_dataset_id(dataset_id);
|
req.set_dataset_id(dataset_id);
|
||||||
@ -112,9 +110,9 @@ Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceMasterClient::GetTasks(int64 job_id,
|
Status DataServiceDispatcherClient::GetTasks(int64 job_id,
|
||||||
std::vector<TaskInfo>* tasks,
|
std::vector<TaskInfo>* tasks,
|
||||||
bool* job_finished) {
|
bool* job_finished) {
|
||||||
TF_RETURN_IF_ERROR(EnsureInitialized());
|
TF_RETURN_IF_ERROR(EnsureInitialized());
|
||||||
GetTasksRequest req;
|
GetTasksRequest req;
|
||||||
req.set_job_id(job_id);
|
req.set_job_id(job_id);
|
||||||
@ -132,7 +130,8 @@ Status DataServiceMasterClient::GetTasks(int64 job_id,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceMasterClient::GetWorkers(std::vector<WorkerInfo>* workers) {
|
Status DataServiceDispatcherClient::GetWorkers(
|
||||||
|
std::vector<WorkerInfo>* workers) {
|
||||||
TF_RETURN_IF_ERROR(EnsureInitialized());
|
TF_RETURN_IF_ERROR(EnsureInitialized());
|
||||||
GetWorkersRequest req;
|
GetWorkersRequest req;
|
||||||
GetWorkersResponse resp;
|
GetWorkersResponse resp;
|
||||||
@ -148,12 +147,12 @@ Status DataServiceMasterClient::GetWorkers(std::vector<WorkerInfo>* workers) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceMasterClient::EnsureInitialized() {
|
Status DataServiceDispatcherClient::EnsureInitialized() {
|
||||||
std::shared_ptr<grpc::ChannelCredentials> credentials;
|
std::shared_ptr<grpc::ChannelCredentials> credentials;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
|
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
|
||||||
auto channel = grpc::CreateChannel(address_, credentials);
|
auto channel = grpc::CreateChannel(address_, credentials);
|
||||||
stub_ = MasterService::NewStub(channel);
|
stub_ = DispatcherService::NewStub(channel);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -187,10 +186,11 @@ Status DataServiceWorkerClient::EnsureInitialized() {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CreateDataServiceMasterClient(
|
Status CreateDataServiceDispatcherClient(
|
||||||
const std::string& address, const std::string& protocol,
|
const std::string& address, const std::string& protocol,
|
||||||
std::unique_ptr<DataServiceMasterClient>* out) {
|
std::unique_ptr<DataServiceDispatcherClient>* out) {
|
||||||
auto client = absl::make_unique<DataServiceMasterClient>(address, protocol);
|
auto client =
|
||||||
|
absl::make_unique<DataServiceDispatcherClient>(address, protocol);
|
||||||
TF_RETURN_IF_ERROR(client->Initialize());
|
TF_RETURN_IF_ERROR(client->Initialize());
|
||||||
*out = std::move(client);
|
*out = std::move(client);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
|
#ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
|
||||||
#define TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
|
#define TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
|
||||||
|
|
||||||
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
|
||||||
#include "tensorflow/core/data/service/worker.grpc.pb.h"
|
#include "tensorflow/core/data/service/worker.grpc.pb.h"
|
||||||
#include "tensorflow/core/framework/dataset.h"
|
#include "tensorflow/core/framework/dataset.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
@ -67,11 +67,11 @@ class DataServiceClientBase {
|
|||||||
const std::string protocol_;
|
const std::string protocol_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Client for communicating with the tf.data service master.
|
// Client for communicating with the tf.data service dispatcher.
|
||||||
class DataServiceMasterClient : public DataServiceClientBase {
|
class DataServiceDispatcherClient : public DataServiceClientBase {
|
||||||
public:
|
public:
|
||||||
DataServiceMasterClient(const std::string& address,
|
DataServiceDispatcherClient(const std::string& address,
|
||||||
const std::string& protocol)
|
const std::string& protocol)
|
||||||
: DataServiceClientBase(address, protocol) {}
|
: DataServiceClientBase(address, protocol) {}
|
||||||
|
|
||||||
// Registers a dataset with the tf.data service, and stores the generated
|
// Registers a dataset with the tf.data service, and stores the generated
|
||||||
@ -90,13 +90,13 @@ class DataServiceMasterClient : public DataServiceClientBase {
|
|||||||
const std::string& job_name, int job_name_index,
|
const std::string& job_name, int job_name_index,
|
||||||
int64* job_id);
|
int64* job_id);
|
||||||
|
|
||||||
// Queries the master for the tasks associated with the specified job.
|
// Queries the dispatcher for the tasks associated with the specified job.
|
||||||
// The tasks will be stored in *tasks, and whether the job is finished will
|
// The tasks will be stored in *tasks, and whether the job is finished will
|
||||||
// be stored in `*job_finished`.
|
// be stored in `*job_finished`.
|
||||||
Status GetTasks(int64 job_id, std::vector<TaskInfo>* tasks,
|
Status GetTasks(int64 job_id, std::vector<TaskInfo>* tasks,
|
||||||
bool* job_finished);
|
bool* job_finished);
|
||||||
|
|
||||||
// Queries the master for its registered workers. The worker info will be
|
// Queries the dispatcher for its registered workers. The worker info will be
|
||||||
// stored in `*workers`.
|
// stored in `*workers`.
|
||||||
Status GetWorkers(std::vector<WorkerInfo>* workers);
|
Status GetWorkers(std::vector<WorkerInfo>* workers);
|
||||||
|
|
||||||
@ -104,7 +104,7 @@ class DataServiceMasterClient : public DataServiceClientBase {
|
|||||||
Status EnsureInitialized() override;
|
Status EnsureInitialized() override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<MasterService::Stub> stub_;
|
std::unique_ptr<DispatcherService::Stub> stub_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Client for communicating with the tf.data service worker.
|
// Client for communicating with the tf.data service worker.
|
||||||
@ -127,10 +127,10 @@ class DataServiceWorkerClient : public DataServiceClientBase {
|
|||||||
std::unique_ptr<WorkerService::Stub> stub_;
|
std::unique_ptr<WorkerService::Stub> stub_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Creates and initializes a new tf.data service master client.
|
// Creates and initializes a new tf.data service dispatcher client.
|
||||||
Status CreateDataServiceMasterClient(
|
Status CreateDataServiceDispatcherClient(
|
||||||
const std::string& address, const std::string& protocol,
|
const std::string& address, const std::string& protocol,
|
||||||
std::unique_ptr<DataServiceMasterClient>* out);
|
std::unique_ptr<DataServiceDispatcherClient>* out);
|
||||||
|
|
||||||
// Creates and initializes a new tf.data service worker client.
|
// Creates and initializes a new tf.data service worker client.
|
||||||
Status CreateDataServiceWorkerClient(
|
Status CreateDataServiceWorkerClient(
|
||||||
|
@ -19,9 +19,9 @@ limitations under the License.
|
|||||||
#include "grpcpp/security/credentials.h"
|
#include "grpcpp/security/credentials.h"
|
||||||
#include "absl/strings/str_split.h"
|
#include "absl/strings/str_split.h"
|
||||||
#include "tensorflow/core/data/compression_utils.h"
|
#include "tensorflow/core/data/compression_utils.h"
|
||||||
|
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
|
||||||
|
#include "tensorflow/core/data/service/dispatcher.pb.h"
|
||||||
#include "tensorflow/core/data/service/grpc_util.h"
|
#include "tensorflow/core/data/service/grpc_util.h"
|
||||||
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
|
||||||
#include "tensorflow/core/data/service/master.pb.h"
|
|
||||||
#include "tensorflow/core/data/service/server_lib.h"
|
#include "tensorflow/core/data/service/server_lib.h"
|
||||||
#include "tensorflow/core/data/service/test_cluster.h"
|
#include "tensorflow/core/data/service/test_cluster.h"
|
||||||
#include "tensorflow/core/data/service/test_util.h"
|
#include "tensorflow/core/data/service/test_util.h"
|
||||||
@ -66,9 +66,10 @@ TEST(DataService, ProcessingModeToString) {
|
|||||||
TEST(DataService, GetWorkers) {
|
TEST(DataService, GetWorkers) {
|
||||||
TestCluster cluster(1);
|
TestCluster cluster(1);
|
||||||
TF_ASSERT_OK(cluster.Initialize());
|
TF_ASSERT_OK(cluster.Initialize());
|
||||||
DataServiceMasterClient master(cluster.MasterAddress(), kProtocol);
|
DataServiceDispatcherClient dispatcher(cluster.DispatcherAddress(),
|
||||||
|
kProtocol);
|
||||||
std::vector<WorkerInfo> workers;
|
std::vector<WorkerInfo> workers;
|
||||||
TF_EXPECT_OK(master.GetWorkers(&workers));
|
TF_EXPECT_OK(dispatcher.GetWorkers(&workers));
|
||||||
EXPECT_EQ(1, workers.size());
|
EXPECT_EQ(1, workers.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,11 +110,11 @@ message GetWorkersResponse {
|
|||||||
repeated WorkerInfo workers = 1;
|
repeated WorkerInfo workers = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
service MasterService {
|
service DispatcherService {
|
||||||
// Registers a worker with the master.
|
// Registers a worker with the dispatcher.
|
||||||
rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse);
|
rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse);
|
||||||
|
|
||||||
// Updates the master with information about the worker's state.
|
// Updates the dispatcher with information about the worker's state.
|
||||||
rpc WorkerUpdate(WorkerUpdateRequest) returns (WorkerUpdateResponse);
|
rpc WorkerUpdate(WorkerUpdateRequest) returns (WorkerUpdateResponse);
|
||||||
|
|
||||||
// Registers a dataset with the server, or returns its id if it is already
|
// Registers a dataset with the server, or returns its id if it is already
|
||||||
@ -134,6 +134,6 @@ service MasterService {
|
|||||||
// Reports a list of all tasks for a job.
|
// Reports a list of all tasks for a job.
|
||||||
rpc GetTasks(GetTasksRequest) returns (GetTasksResponse);
|
rpc GetTasks(GetTasksRequest) returns (GetTasksResponse);
|
||||||
|
|
||||||
// Reports a list of all workers registered with the master.
|
// Reports a list of all workers registered with the dispatcher.
|
||||||
rpc GetWorkers(GetWorkersRequest) returns (GetWorkersResponse);
|
rpc GetWorkers(GetWorkersRequest) returns (GetWorkersResponse);
|
||||||
}
|
}
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/data/service/master_impl.h"
|
#include "tensorflow/core/data/service/dispatcher_impl.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
@ -26,8 +26,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/data/service/common.pb.h"
|
#include "tensorflow/core/data/service/common.pb.h"
|
||||||
#include "tensorflow/core/data/service/credentials_factory.h"
|
#include "tensorflow/core/data/service/credentials_factory.h"
|
||||||
#include "tensorflow/core/data/service/data_service.h"
|
#include "tensorflow/core/data/service/data_service.h"
|
||||||
|
#include "tensorflow/core/data/service/dispatcher.pb.h"
|
||||||
#include "tensorflow/core/data/service/grpc_util.h"
|
#include "tensorflow/core/data/service/grpc_util.h"
|
||||||
#include "tensorflow/core/data/service/master.pb.h"
|
|
||||||
#include "tensorflow/core/data/service/worker.grpc.pb.h"
|
#include "tensorflow/core/data/service/worker.grpc.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb.h"
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||||
@ -53,10 +53,10 @@ Status CreateWorkerStub(const std::string& address,
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
DataServiceMasterImpl::DataServiceMasterImpl(const std::string protocol)
|
DataServiceDispatcherImpl::DataServiceDispatcherImpl(const std::string protocol)
|
||||||
: protocol_(protocol) {}
|
: protocol_(protocol) {}
|
||||||
|
|
||||||
Status DataServiceMasterImpl::RegisterWorker(
|
Status DataServiceDispatcherImpl::RegisterWorker(
|
||||||
const RegisterWorkerRequest* request, RegisterWorkerResponse* response) {
|
const RegisterWorkerRequest* request, RegisterWorkerResponse* response) {
|
||||||
VLOG(3) << "Received register worker request";
|
VLOG(3) << "Received register worker request";
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
@ -86,8 +86,8 @@ Status DataServiceMasterImpl::RegisterWorker(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request,
|
Status DataServiceDispatcherImpl::WorkerUpdate(
|
||||||
WorkerUpdateResponse* response) {
|
const WorkerUpdateRequest* request, WorkerUpdateResponse* response) {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
int64 worker_id = request->worker_id();
|
int64 worker_id = request->worker_id();
|
||||||
for (auto& update : request->updates()) {
|
for (auto& update : request->updates()) {
|
||||||
@ -106,7 +106,7 @@ Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceMasterImpl::GetOrRegisterDataset(
|
Status DataServiceDispatcherImpl::GetOrRegisterDataset(
|
||||||
const GetOrRegisterDatasetRequest* request,
|
const GetOrRegisterDatasetRequest* request,
|
||||||
GetOrRegisterDatasetResponse* response) {
|
GetOrRegisterDatasetResponse* response) {
|
||||||
uint64 fingerprint;
|
uint64 fingerprint;
|
||||||
@ -128,8 +128,8 @@ Status DataServiceMasterImpl::GetOrRegisterDataset(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
|
int64 DataServiceDispatcherImpl::RegisterDataset(uint64 fingerprint,
|
||||||
const DatasetDef& dataset)
|
const DatasetDef& dataset)
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||||
int64 dataset_id = next_dataset_id_++;
|
int64 dataset_id = next_dataset_id_++;
|
||||||
auto new_dataset =
|
auto new_dataset =
|
||||||
@ -142,8 +142,8 @@ int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
|
|||||||
return dataset_id;
|
return dataset_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
|
Status DataServiceDispatcherImpl::CreateJob(const CreateJobRequest* request,
|
||||||
CreateJobResponse* response) {
|
CreateJobResponse* response) {
|
||||||
VLOG(3) << "Received create job request for dataset id "
|
VLOG(3) << "Received create job request for dataset id "
|
||||||
<< request->dataset_id();
|
<< request->dataset_id();
|
||||||
ProcessingMode processing_mode = ProcessingMode(request->processing_mode());
|
ProcessingMode processing_mode = ProcessingMode(request->processing_mode());
|
||||||
@ -157,7 +157,7 @@ Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceMasterImpl::GetOrCreateJob(
|
Status DataServiceDispatcherImpl::GetOrCreateJob(
|
||||||
const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
|
const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
|
||||||
VLOG(3) << "Received get or create job request for dataset id "
|
VLOG(3) << "Received get or create job request for dataset id "
|
||||||
<< request->dataset_id() << " with name " << request->job_name()
|
<< request->dataset_id() << " with name " << request->job_name()
|
||||||
@ -193,7 +193,7 @@ Status DataServiceMasterImpl::GetOrCreateJob(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Validates that the job matches the given processing_mode and dataset_id.
|
// Validates that the job matches the given processing_mode and dataset_id.
|
||||||
Status DataServiceMasterImpl::ValidateMatchingJob(
|
Status DataServiceDispatcherImpl::ValidateMatchingJob(
|
||||||
const Job& job, ProcessingMode processing_mode, int64 dataset_id) {
|
const Job& job, ProcessingMode processing_mode, int64 dataset_id) {
|
||||||
DCHECK(job.name().has_value());
|
DCHECK(job.name().has_value());
|
||||||
std::string job_name = job.name().value();
|
std::string job_name = job.name().value();
|
||||||
@ -214,10 +214,10 @@ Status DataServiceMasterImpl::ValidateMatchingJob(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
|
Status DataServiceDispatcherImpl::CreateJob(
|
||||||
ProcessingMode processing_mode,
|
int64 dataset_id, ProcessingMode processing_mode,
|
||||||
absl::optional<std::string> job_name,
|
absl::optional<std::string> job_name, int64* out_job_id)
|
||||||
int64* out_job_id) LOCKS_EXCLUDED(mu_) {
|
LOCKS_EXCLUDED(mu_) {
|
||||||
switch (processing_mode) {
|
switch (processing_mode) {
|
||||||
case ProcessingMode::PARALLEL_EPOCHS:
|
case ProcessingMode::PARALLEL_EPOCHS:
|
||||||
break;
|
break;
|
||||||
@ -274,14 +274,16 @@ Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTask(
|
const DataServiceDispatcherImpl::Task& DataServiceDispatcherImpl::CreateTask(
|
||||||
Job* job, const std::string& worker_address) LOCKS_EXCLUDED(mu_) {
|
Job* job, const std::string& worker_address) LOCKS_EXCLUDED(mu_) {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
return CreateTaskLocked(job, worker_address);
|
return CreateTaskLocked(job, worker_address);
|
||||||
}
|
}
|
||||||
|
|
||||||
const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTaskLocked(
|
const DataServiceDispatcherImpl::Task&
|
||||||
Job* job, const std::string& worker_address) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
DataServiceDispatcherImpl::CreateTaskLocked(Job* job,
|
||||||
|
const std::string& worker_address)
|
||||||
|
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||||
int64 task_id = next_task_id_++;
|
int64 task_id = next_task_id_++;
|
||||||
DCHECK(!tasks_.contains(task_id));
|
DCHECK(!tasks_.contains(task_id));
|
||||||
tasks_.insert({task_id, Task(task_id, job->job_id(), job->dataset_id(),
|
tasks_.insert({task_id, Task(task_id, job->job_id(), job->dataset_id(),
|
||||||
@ -290,7 +292,7 @@ const DataServiceMasterImpl::Task& DataServiceMasterImpl::CreateTaskLocked(
|
|||||||
return tasks_.at(task_id);
|
return tasks_.at(task_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
|
Status DataServiceDispatcherImpl::EnsureWorkerStubInitialized(Worker* worker) {
|
||||||
if (!worker->stub()) {
|
if (!worker->stub()) {
|
||||||
std::unique_ptr<WorkerService::Stub> stub;
|
std::unique_ptr<WorkerService::Stub> stub;
|
||||||
TF_RETURN_IF_ERROR(CreateWorkerStub(worker->address(), protocol_, &stub));
|
TF_RETURN_IF_ERROR(CreateWorkerStub(worker->address(), protocol_, &stub));
|
||||||
@ -299,8 +301,8 @@ Status DataServiceMasterImpl::EnsureWorkerStubInitialized(Worker* worker) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
|
Status DataServiceDispatcherImpl::AllocateTaskToWorker(const Task& task,
|
||||||
Worker* worker)
|
Worker* worker)
|
||||||
LOCKS_EXCLUDED(mu_) {
|
LOCKS_EXCLUDED(mu_) {
|
||||||
TF_RETURN_IF_ERROR(EnsureWorkerStubInitialized(worker));
|
TF_RETURN_IF_ERROR(EnsureWorkerStubInitialized(worker));
|
||||||
grpc::ClientContext client_ctx;
|
grpc::ClientContext client_ctx;
|
||||||
@ -322,8 +324,8 @@ Status DataServiceMasterImpl::AllocateTaskToWorker(const Task& task,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
|
Status DataServiceDispatcherImpl::GetTasks(const GetTasksRequest* request,
|
||||||
GetTasksResponse* response) {
|
GetTasksResponse* response) {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
VLOG(3) << "Looking up tasks for job id " << request->job_id();
|
VLOG(3) << "Looking up tasks for job id " << request->job_id();
|
||||||
auto it = jobs_.find(request->job_id());
|
auto it = jobs_.find(request->job_id());
|
||||||
@ -346,8 +348,8 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceMasterImpl::GetWorkers(const GetWorkersRequest* request,
|
Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request,
|
||||||
GetWorkersResponse* response) {
|
GetWorkersResponse* response) {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
VLOG(3) << "Enter GetWorkers";
|
VLOG(3) << "Enter GetWorkers";
|
||||||
for (auto& worker : workers_) {
|
for (auto& worker : workers_) {
|
@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
|
#ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
|
||||||
#define TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
|
#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/core/data/service/common.pb.h"
|
#include "tensorflow/core/data/service/common.pb.h"
|
||||||
#include "tensorflow/core/data/service/data_service.h"
|
#include "tensorflow/core/data/service/data_service.h"
|
||||||
#include "tensorflow/core/data/service/master.pb.h"
|
#include "tensorflow/core/data/service/dispatcher.pb.h"
|
||||||
#include "tensorflow/core/data/service/worker.grpc.pb.h"
|
#include "tensorflow/core/data/service/worker.grpc.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
@ -40,11 +40,11 @@ namespace data {
|
|||||||
// ProcessingModeDef which determines what data it produces.
|
// ProcessingModeDef which determines what data it produces.
|
||||||
// * Task: A job is broken into multiple tasks, which each represent
|
// * Task: A job is broken into multiple tasks, which each represent
|
||||||
// iterating over all of or part of the dataset. Workers process tasks.
|
// iterating over all of or part of the dataset. Workers process tasks.
|
||||||
class DataServiceMasterImpl {
|
class DataServiceDispatcherImpl {
|
||||||
public:
|
public:
|
||||||
explicit DataServiceMasterImpl(const std::string protocol);
|
explicit DataServiceDispatcherImpl(const std::string protocol);
|
||||||
|
|
||||||
// See master.proto for API documentation.
|
// See dispatcher.proto for API documentation.
|
||||||
|
|
||||||
/// Worker-facing API.
|
/// Worker-facing API.
|
||||||
Status RegisterWorker(const RegisterWorkerRequest* request,
|
Status RegisterWorker(const RegisterWorkerRequest* request,
|
||||||
@ -191,7 +191,7 @@ class DataServiceMasterImpl {
|
|||||||
// Creates a new task for a job, returning a reference to the task.
|
// Creates a new task for a job, returning a reference to the task.
|
||||||
const Task& CreateTask(Job* job, const std::string& worker_address)
|
const Task& CreateTask(Job* job, const std::string& worker_address)
|
||||||
LOCKS_EXCLUDED(mu_);
|
LOCKS_EXCLUDED(mu_);
|
||||||
// Same as `CreateTask`, but expects that the master lock is already held.
|
// Same as `CreateTask`, but expects that the dispatcher lock is already held.
|
||||||
const Task& CreateTaskLocked(Job* job, const std::string& worker_address)
|
const Task& CreateTaskLocked(Job* job, const std::string& worker_address)
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
// Validates that an existing job matches the given processing_mode and
|
// Validates that an existing job matches the given processing_mode and
|
||||||
@ -225,10 +225,10 @@ class DataServiceMasterImpl {
|
|||||||
absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_
|
absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_
|
||||||
TF_GUARDED_BY(mu_);
|
TF_GUARDED_BY(mu_);
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(DataServiceMasterImpl);
|
TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_DATA_SERVICE_MASTER_IMPL_H_
|
#endif // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/data/service/grpc_master_impl.h"
|
#include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
|
||||||
|
|
||||||
#include "grpcpp/server_context.h"
|
#include "grpcpp/server_context.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||||
@ -25,18 +25,18 @@ using ::grpc::ServerBuilder;
|
|||||||
using ::grpc::ServerContext;
|
using ::grpc::ServerContext;
|
||||||
using ::grpc::Status;
|
using ::grpc::Status;
|
||||||
|
|
||||||
GrpcMasterImpl::GrpcMasterImpl(ServerBuilder* server_builder,
|
GrpcDispatcherImpl::GrpcDispatcherImpl(ServerBuilder* server_builder,
|
||||||
const std::string& protocol)
|
const std::string& protocol)
|
||||||
: impl_(protocol) {
|
: impl_(protocol) {
|
||||||
server_builder->RegisterService(this);
|
server_builder->RegisterService(this);
|
||||||
VLOG(1) << "Registered data service master";
|
VLOG(1) << "Registered data service dispatcher";
|
||||||
}
|
}
|
||||||
|
|
||||||
#define HANDLER(method) \
|
#define HANDLER(method) \
|
||||||
Status GrpcMasterImpl::method(ServerContext* context, \
|
Status GrpcDispatcherImpl::method(ServerContext* context, \
|
||||||
const method##Request* request, \
|
const method##Request* request, \
|
||||||
method##Response* response) { \
|
method##Response* response) { \
|
||||||
return ToGrpcStatus(impl_.method(request, response)); \
|
return ToGrpcStatus(impl_.method(request, response)); \
|
||||||
}
|
}
|
||||||
HANDLER(RegisterWorker);
|
HANDLER(RegisterWorker);
|
||||||
HANDLER(WorkerUpdate);
|
HANDLER(WorkerUpdate);
|
@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
|
#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
|
||||||
#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
|
#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
|
||||||
|
|
||||||
#include "grpcpp/server_builder.h"
|
#include "grpcpp/server_builder.h"
|
||||||
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
|
||||||
#include "tensorflow/core/data/service/master_impl.h"
|
#include "tensorflow/core/data/service/dispatcher_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace data {
|
namespace data {
|
||||||
@ -29,14 +29,14 @@ namespace data {
|
|||||||
//
|
//
|
||||||
// ::grpc::ServerBuilder builder;
|
// ::grpc::ServerBuilder builder;
|
||||||
// // configure builder
|
// // configure builder
|
||||||
// GrpcMasterImpl data_service(&builder);
|
// GrpcDispatcherImpl data_service(&builder);
|
||||||
// builder.BuildAndStart()
|
// builder.BuildAndStart()
|
||||||
//
|
//
|
||||||
class GrpcMasterImpl : public MasterService::Service {
|
class GrpcDispatcherImpl : public DispatcherService::Service {
|
||||||
public:
|
public:
|
||||||
explicit GrpcMasterImpl(grpc::ServerBuilder* server_builder,
|
explicit GrpcDispatcherImpl(grpc::ServerBuilder* server_builder,
|
||||||
const std::string& protocol);
|
const std::string& protocol);
|
||||||
~GrpcMasterImpl() override {}
|
~GrpcDispatcherImpl() override {}
|
||||||
|
|
||||||
#define HANDLER(method) \
|
#define HANDLER(method) \
|
||||||
grpc::Status method(grpc::ServerContext* context, \
|
grpc::Status method(grpc::ServerContext* context, \
|
||||||
@ -52,12 +52,12 @@ class GrpcMasterImpl : public MasterService::Service {
|
|||||||
#undef HANDLER
|
#undef HANDLER
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DataServiceMasterImpl impl_;
|
DataServiceDispatcherImpl impl_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterImpl);
|
TF_DISALLOW_COPY_AND_ASSIGN(GrpcDispatcherImpl);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_DATA_SERVICE_GRPC_MASTER_IMPL_H_
|
#endif // TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_
|
@ -26,9 +26,9 @@ using ::grpc::ServerContext;
|
|||||||
using ::grpc::Status;
|
using ::grpc::Status;
|
||||||
|
|
||||||
GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder,
|
GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder,
|
||||||
const std::string& master_address,
|
const std::string& dispatcher_address,
|
||||||
const std::string& protocol)
|
const std::string& protocol)
|
||||||
: impl_(master_address, protocol) {
|
: impl_(dispatcher_address, protocol) {
|
||||||
server_builder->RegisterService(this);
|
server_builder->RegisterService(this);
|
||||||
VLOG(1) << "Registered data service worker";
|
VLOG(1) << "Registered data service worker";
|
||||||
}
|
}
|
||||||
|
@ -35,7 +35,7 @@ namespace data {
|
|||||||
class GrpcWorkerImpl : public WorkerService::Service {
|
class GrpcWorkerImpl : public WorkerService::Service {
|
||||||
public:
|
public:
|
||||||
explicit GrpcWorkerImpl(grpc::ServerBuilder* server_builder,
|
explicit GrpcWorkerImpl(grpc::ServerBuilder* server_builder,
|
||||||
const std::string& master_address,
|
const std::string& dispatcher_address,
|
||||||
const std::string& protocol);
|
const std::string& protocol);
|
||||||
~GrpcWorkerImpl() override {}
|
~GrpcWorkerImpl() override {}
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/data/service/server_lib.h"
|
#include "tensorflow/core/data/service/server_lib.h"
|
||||||
|
|
||||||
#include "tensorflow/core/data/service/credentials_factory.h"
|
#include "tensorflow/core/data/service/credentials_factory.h"
|
||||||
#include "tensorflow/core/data/service/grpc_master_impl.h"
|
#include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
|
||||||
#include "tensorflow/core/data/service/grpc_util.h"
|
#include "tensorflow/core/data/service/grpc_util.h"
|
||||||
#include "tensorflow/core/data/service/grpc_worker_impl.h"
|
#include "tensorflow/core/data/service/grpc_worker_impl.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
@ -72,18 +72,18 @@ void GrpcDataServerBase::Join() { server_->Wait(); }
|
|||||||
|
|
||||||
int GrpcDataServerBase::BoundPort() { return bound_port(); }
|
int GrpcDataServerBase::BoundPort() { return bound_port(); }
|
||||||
|
|
||||||
MasterGrpcDataServer::MasterGrpcDataServer(int port,
|
DispatchGrpcDataServer::DispatchGrpcDataServer(int port,
|
||||||
const std::string& protocol)
|
const std::string& protocol)
|
||||||
: GrpcDataServerBase(port, protocol) {}
|
: GrpcDataServerBase(port, protocol) {}
|
||||||
|
|
||||||
MasterGrpcDataServer::~MasterGrpcDataServer() { delete service_; }
|
DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
|
||||||
|
|
||||||
void MasterGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
|
void DispatchGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
|
||||||
auto service = absl::make_unique<GrpcMasterImpl>(builder, protocol_);
|
auto service = absl::make_unique<GrpcDispatcherImpl>(builder, protocol_);
|
||||||
service_ = service.release();
|
service_ = service.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MasterGrpcDataServer::NumWorkers(int* num_workers) {
|
Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
|
||||||
GetWorkersRequest req;
|
GetWorkersRequest req;
|
||||||
GetWorkersResponse resp;
|
GetWorkersResponse resp;
|
||||||
grpc::ServerContext ctx;
|
grpc::ServerContext ctx;
|
||||||
@ -95,19 +95,18 @@ Status MasterGrpcDataServer::NumWorkers(int* num_workers) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
WorkerGrpcDataServer::WorkerGrpcDataServer(int port,
|
WorkerGrpcDataServer::WorkerGrpcDataServer(
|
||||||
const std::string& protocol,
|
int port, const std::string& protocol,
|
||||||
const std::string& master_address,
|
const std::string& dispatcher_address, const std::string& worker_address)
|
||||||
const std::string& worker_address)
|
|
||||||
: GrpcDataServerBase(port, protocol),
|
: GrpcDataServerBase(port, protocol),
|
||||||
master_address_(master_address),
|
dispatcher_address_(dispatcher_address),
|
||||||
worker_address_(worker_address) {}
|
worker_address_(worker_address) {}
|
||||||
|
|
||||||
WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
|
WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
|
||||||
|
|
||||||
void WorkerGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
|
void WorkerGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
|
||||||
auto service =
|
auto service = absl::make_unique<GrpcWorkerImpl>(builder, dispatcher_address_,
|
||||||
absl::make_unique<GrpcWorkerImpl>(builder, master_address_, protocol_);
|
protocol_);
|
||||||
service_ = service.release();
|
service_ = service.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -123,25 +122,25 @@ Status WorkerGrpcDataServer::StartServiceInternal() {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status NewMasterServer(int port, const std::string& protocol,
|
Status NewDispatchServer(int port, const std::string& protocol,
|
||||||
std::unique_ptr<MasterGrpcDataServer>* out_server) {
|
std::unique_ptr<DispatchGrpcDataServer>* out_server) {
|
||||||
*out_server = absl::make_unique<MasterGrpcDataServer>(port, protocol);
|
*out_server = absl::make_unique<DispatchGrpcDataServer>(port, protocol);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status NewWorkerServer(int port, const std::string& protocol,
|
Status NewWorkerServer(int port, const std::string& protocol,
|
||||||
const std::string& master_address,
|
const std::string& dispatcher_address,
|
||||||
std::unique_ptr<WorkerGrpcDataServer>* out_server) {
|
std::unique_ptr<WorkerGrpcDataServer>* out_server) {
|
||||||
return NewWorkerServer(port, protocol, master_address, /*worker_address=*/"",
|
return NewWorkerServer(port, protocol, dispatcher_address,
|
||||||
out_server);
|
/*worker_address=*/"", out_server);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status NewWorkerServer(int port, const std::string& protocol,
|
Status NewWorkerServer(int port, const std::string& protocol,
|
||||||
const std::string& master_address,
|
const std::string& dispatcher_address,
|
||||||
const std::string& worker_address,
|
const std::string& worker_address,
|
||||||
std::unique_ptr<WorkerGrpcDataServer>* out_server) {
|
std::unique_ptr<WorkerGrpcDataServer>* out_server) {
|
||||||
*out_server = absl::make_unique<WorkerGrpcDataServer>(
|
*out_server = absl::make_unique<WorkerGrpcDataServer>(
|
||||||
port, protocol, master_address, worker_address);
|
port, protocol, dispatcher_address, worker_address);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ namespace data {
|
|||||||
|
|
||||||
// Forward declared because transitively depending on .grpc.pb.h files causes
|
// Forward declared because transitively depending on .grpc.pb.h files causes
|
||||||
// issues in the pywrap build.
|
// issues in the pywrap build.
|
||||||
class GrpcMasterImpl;
|
class GrpcDispatcherImpl;
|
||||||
class GrpcWorkerImpl;
|
class GrpcWorkerImpl;
|
||||||
|
|
||||||
// A grpc server for the tf.data service.
|
// A grpc server for the tf.data service.
|
||||||
@ -35,7 +35,7 @@ class GrpcDataServerBase {
|
|||||||
// server will find an available port in `Start()`. The chosen port can be
|
// server will find an available port in `Start()`. The chosen port can be
|
||||||
// found in the output of `Target()`.
|
// found in the output of `Target()`.
|
||||||
//
|
//
|
||||||
// master_address is only needed for worker data servers.
|
// dispatcher_address is only needed for worker data servers.
|
||||||
GrpcDataServerBase(int requested_port, const std::string& protocol);
|
GrpcDataServerBase(int requested_port, const std::string& protocol);
|
||||||
virtual ~GrpcDataServerBase() {}
|
virtual ~GrpcDataServerBase() {}
|
||||||
|
|
||||||
@ -70,12 +70,12 @@ class GrpcDataServerBase {
|
|||||||
std::unique_ptr<grpc::Server> server_;
|
std::unique_ptr<grpc::Server> server_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class MasterGrpcDataServer : public GrpcDataServerBase {
|
class DispatchGrpcDataServer : public GrpcDataServerBase {
|
||||||
public:
|
public:
|
||||||
MasterGrpcDataServer(int requested_port, const std::string& protocol);
|
DispatchGrpcDataServer(int requested_port, const std::string& protocol);
|
||||||
~MasterGrpcDataServer() override;
|
~DispatchGrpcDataServer() override;
|
||||||
|
|
||||||
// Returns the number of workers registerd with the master.
|
// Returns the number of workers registerd with the dispatcher.
|
||||||
Status NumWorkers(int* num_workers);
|
Status NumWorkers(int* num_workers);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -83,14 +83,14 @@ class MasterGrpcDataServer : public GrpcDataServerBase {
|
|||||||
Status StartServiceInternal() override { return Status::OK(); }
|
Status StartServiceInternal() override { return Status::OK(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Owned. We use a raw pointer because GrpcMasterImpl is forward-declared.
|
// Owned. We use a raw pointer because GrpcDispatcherImpl is forward-declared.
|
||||||
GrpcMasterImpl* service_;
|
GrpcDispatcherImpl* service_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class WorkerGrpcDataServer : public GrpcDataServerBase {
|
class WorkerGrpcDataServer : public GrpcDataServerBase {
|
||||||
public:
|
public:
|
||||||
WorkerGrpcDataServer(int requested_port, const std::string& protocol,
|
WorkerGrpcDataServer(int requested_port, const std::string& protocol,
|
||||||
const std::string& master_address,
|
const std::string& dispatcher_address,
|
||||||
const std::string& worker_address);
|
const std::string& worker_address);
|
||||||
~WorkerGrpcDataServer() override;
|
~WorkerGrpcDataServer() override;
|
||||||
|
|
||||||
@ -99,15 +99,15 @@ class WorkerGrpcDataServer : public GrpcDataServerBase {
|
|||||||
Status StartServiceInternal() override;
|
Status StartServiceInternal() override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const std::string master_address_;
|
const std::string dispatcher_address_;
|
||||||
const std::string worker_address_;
|
const std::string worker_address_;
|
||||||
// Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared.
|
// Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared.
|
||||||
GrpcWorkerImpl* service_;
|
GrpcWorkerImpl* service_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Creates a master tf.data server and stores it in `*out_server`.
|
// Creates a dispatch tf.data server and stores it in `*out_server`.
|
||||||
Status NewMasterServer(int port, const std::string& protocol,
|
Status NewDispatchServer(int port, const std::string& protocol,
|
||||||
std::unique_ptr<MasterGrpcDataServer>* out_server);
|
std::unique_ptr<DispatchGrpcDataServer>* out_server);
|
||||||
|
|
||||||
// Creates a worker tf.data server and stores it in `*out_server`.
|
// Creates a worker tf.data server and stores it in `*out_server`.
|
||||||
//
|
//
|
||||||
@ -115,18 +115,18 @@ Status NewMasterServer(int port, const std::string& protocol,
|
|||||||
// will be chosen in Start(). This value can be queried with BoundPort().
|
// will be chosen in Start(). This value can be queried with BoundPort().
|
||||||
//
|
//
|
||||||
// The worker_address argument is optional. If left empty, it will default to
|
// The worker_address argument is optional. If left empty, it will default to
|
||||||
// "localhost:%port%". When the worker registers with the master, the worker
|
// "localhost:%port%". When the worker registers with the dispatcher, the worker
|
||||||
// will report the worker address, so that the master can tell clients where to
|
// will report the worker address, so that the dispatcher can tell clients where
|
||||||
// read from. The address may contain the placeholder "%port%", which will be
|
// to read from. The address may contain the placeholder "%port%", which will be
|
||||||
// replaced with the value of BoundPort().
|
// replaced with the value of BoundPort().
|
||||||
Status NewWorkerServer(int port, const std::string& protocol,
|
Status NewWorkerServer(int port, const std::string& protocol,
|
||||||
const std::string& master_address,
|
const std::string& dispatcher_address,
|
||||||
const std::string& worker_address,
|
const std::string& worker_address,
|
||||||
std::unique_ptr<WorkerGrpcDataServer>* out_server);
|
std::unique_ptr<WorkerGrpcDataServer>* out_server);
|
||||||
|
|
||||||
// Creates a worker using the default worker_address.
|
// Creates a worker using the default worker_address.
|
||||||
Status NewWorkerServer(int port, const std::string& protocol,
|
Status NewWorkerServer(int port, const std::string& protocol,
|
||||||
const std::string& master_address,
|
const std::string& dispatcher_address,
|
||||||
std::unique_ptr<WorkerGrpcDataServer>* out_server);
|
std::unique_ptr<WorkerGrpcDataServer>* out_server);
|
||||||
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
|
@ -45,9 +45,9 @@ Status TestCluster::Initialize() {
|
|||||||
"Test cluster has already been initialized.");
|
"Test cluster has already been initialized.");
|
||||||
}
|
}
|
||||||
initialized_ = true;
|
initialized_ = true;
|
||||||
TF_RETURN_IF_ERROR(NewMasterServer(/*port=*/0, kProtocol, &master_));
|
TF_RETURN_IF_ERROR(NewDispatchServer(/*port=*/0, kProtocol, &dispatcher_));
|
||||||
TF_RETURN_IF_ERROR(master_->Start());
|
TF_RETURN_IF_ERROR(dispatcher_->Start());
|
||||||
master_address_ = absl::StrCat("localhost:", master_->BoundPort());
|
dispatcher_address_ = absl::StrCat("localhost:", dispatcher_->BoundPort());
|
||||||
workers_.reserve(num_workers_);
|
workers_.reserve(num_workers_);
|
||||||
worker_addresses_.reserve(num_workers_);
|
worker_addresses_.reserve(num_workers_);
|
||||||
for (int i = 0; i < num_workers_; ++i) {
|
for (int i = 0; i < num_workers_; ++i) {
|
||||||
@ -59,14 +59,14 @@ Status TestCluster::Initialize() {
|
|||||||
Status TestCluster::AddWorker() {
|
Status TestCluster::AddWorker() {
|
||||||
std::unique_ptr<WorkerGrpcDataServer> worker;
|
std::unique_ptr<WorkerGrpcDataServer> worker;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
NewWorkerServer(/*port=*/0, kProtocol, master_address_, &worker));
|
NewWorkerServer(/*port=*/0, kProtocol, dispatcher_address_, &worker));
|
||||||
TF_RETURN_IF_ERROR(worker->Start());
|
TF_RETURN_IF_ERROR(worker->Start());
|
||||||
worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort()));
|
worker_addresses_.push_back(absl::StrCat("localhost:", worker->BoundPort()));
|
||||||
workers_.push_back(std::move(worker));
|
workers_.push_back(std::move(worker));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string TestCluster::MasterAddress() { return master_address_; }
|
std::string TestCluster::DispatcherAddress() { return dispatcher_address_; }
|
||||||
|
|
||||||
std::string TestCluster::WorkerAddress(int index) {
|
std::string TestCluster::WorkerAddress(int index) {
|
||||||
DCHECK_GE(index, 0);
|
DCHECK_GE(index, 0);
|
||||||
|
@ -24,7 +24,7 @@ namespace data {
|
|||||||
// Helper class for unit testing a tf.data service cluster.
|
// Helper class for unit testing a tf.data service cluster.
|
||||||
class TestCluster {
|
class TestCluster {
|
||||||
public:
|
public:
|
||||||
// Creates a new test cluster with a master and `num_workers` workers.
|
// Creates a new test cluster with a dispatcher and `num_workers` workers.
|
||||||
explicit TestCluster(int num_workers);
|
explicit TestCluster(int num_workers);
|
||||||
|
|
||||||
// Initializes the test cluster. This must be called before interacting with
|
// Initializes the test cluster. This must be called before interacting with
|
||||||
@ -32,8 +32,8 @@ class TestCluster {
|
|||||||
Status Initialize();
|
Status Initialize();
|
||||||
// Adds a new worker to the cluster.
|
// Adds a new worker to the cluster.
|
||||||
Status AddWorker();
|
Status AddWorker();
|
||||||
// Returns the master address in the form "hostname:port".
|
// Returns the dispatcher address in the form "hostname:port".
|
||||||
std::string MasterAddress();
|
std::string DispatcherAddress();
|
||||||
// Returns the address of the worker at the specified index, in the form
|
// Returns the address of the worker at the specified index, in the form
|
||||||
// "hostname:port". The index must be non-negative and less than the number of
|
// "hostname:port". The index must be non-negative and less than the number of
|
||||||
// workers in the cluster.
|
// workers in the cluster.
|
||||||
@ -42,8 +42,8 @@ class TestCluster {
|
|||||||
private:
|
private:
|
||||||
bool initialized_ = false;
|
bool initialized_ = false;
|
||||||
int num_workers_;
|
int num_workers_;
|
||||||
std::unique_ptr<MasterGrpcDataServer> master_;
|
std::unique_ptr<DispatchGrpcDataServer> dispatcher_;
|
||||||
std::string master_address_;
|
std::string dispatcher_address_;
|
||||||
std::vector<std::unique_ptr<WorkerGrpcDataServer>> workers_;
|
std::vector<std::unique_ptr<WorkerGrpcDataServer>> workers_;
|
||||||
std::vector<std::string> worker_addresses_;
|
std::vector<std::string> worker_addresses_;
|
||||||
};
|
};
|
||||||
|
@ -21,9 +21,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/core/data/dataset.pb.h"
|
#include "tensorflow/core/data/dataset.pb.h"
|
||||||
#include "tensorflow/core/data/service/credentials_factory.h"
|
#include "tensorflow/core/data/service/credentials_factory.h"
|
||||||
|
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
|
||||||
|
#include "tensorflow/core/data/service/dispatcher.pb.h"
|
||||||
#include "tensorflow/core/data/service/grpc_util.h"
|
#include "tensorflow/core/data/service/grpc_util.h"
|
||||||
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
|
||||||
#include "tensorflow/core/data/service/master.pb.h"
|
|
||||||
#include "tensorflow/core/data/standalone.h"
|
#include "tensorflow/core/data/standalone.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb.h"
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
@ -45,9 +45,9 @@ auto* tf_data_service_created =
|
|||||||
"has been created.");
|
"has been created.");
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
DataServiceWorkerImpl::DataServiceWorkerImpl(const std::string& master_address,
|
DataServiceWorkerImpl::DataServiceWorkerImpl(
|
||||||
const std::string& protocol)
|
const std::string& dispatcher_address, const std::string& protocol)
|
||||||
: master_address_(master_address), protocol_(protocol) {
|
: dispatcher_address_(dispatcher_address), protocol_(protocol) {
|
||||||
tf_data_service_created->GetCell()->Set(true);
|
tf_data_service_created->GetCell()->Set(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,14 +67,13 @@ void DataServiceWorkerImpl::Start(const std::string& worker_address) {
|
|||||||
heartbeat_thread_.reset(thread);
|
heartbeat_thread_.reset(thread);
|
||||||
Status s = Register();
|
Status s = Register();
|
||||||
while (!s.ok()) {
|
while (!s.ok()) {
|
||||||
LOG(WARNING) << "Failed to register with master at " << master_address_
|
LOG(WARNING) << "Failed to register with dispatcher at "
|
||||||
<< ": " << s;
|
<< dispatcher_address_ << ": " << s;
|
||||||
Env::Default()->SleepForMicroseconds(kHeartbeatIntervalMicros);
|
Env::Default()->SleepForMicroseconds(kHeartbeatIntervalMicros);
|
||||||
s = Register();
|
s = Register();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request,
|
Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request,
|
||||||
ProcessTaskResponse* response) {
|
ProcessTaskResponse* response) {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
@ -169,29 +168,29 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceWorkerImpl::EnsureMasterStubInitialized()
|
Status DataServiceWorkerImpl::EnsureDispatcherStubInitialized()
|
||||||
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||||
if (!master_stub_) {
|
if (!dispatcher_stub_) {
|
||||||
::grpc::ChannelArguments args;
|
::grpc::ChannelArguments args;
|
||||||
std::shared_ptr<::grpc::ChannelCredentials> credentials;
|
std::shared_ptr<::grpc::ChannelCredentials> credentials;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
|
CredentialsFactory::CreateClientCredentials(protocol_, &credentials));
|
||||||
auto channel =
|
auto channel =
|
||||||
::grpc::CreateCustomChannel(master_address_, credentials, args);
|
::grpc::CreateCustomChannel(dispatcher_address_, credentials, args);
|
||||||
master_stub_ = MasterService::NewStub(channel);
|
dispatcher_stub_ = DispatcherService::NewStub(channel);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||||
VLOG(3) << "Registering with master at " << master_address_;
|
VLOG(3) << "Registering with dispatcher at " << dispatcher_address_;
|
||||||
TF_RETURN_IF_ERROR(EnsureMasterStubInitialized());
|
TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
|
||||||
RegisterWorkerRequest req;
|
RegisterWorkerRequest req;
|
||||||
req.set_worker_address(worker_address_);
|
req.set_worker_address(worker_address_);
|
||||||
RegisterWorkerResponse resp;
|
RegisterWorkerResponse resp;
|
||||||
|
|
||||||
grpc::ClientContext ctx;
|
grpc::ClientContext ctx;
|
||||||
grpc::Status s = master_stub_->RegisterWorker(&ctx, req, &resp);
|
grpc::Status s = dispatcher_stub_->RegisterWorker(&ctx, req, &resp);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
return grpc_util::WrapError("Failed to register worker", s);
|
return grpc_util::WrapError("Failed to register worker", s);
|
||||||
}
|
}
|
||||||
@ -205,8 +204,8 @@ Status DataServiceWorkerImpl::Register() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
|||||||
|
|
||||||
Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||||
VLOG(3) << "Sending " << pending_completed_tasks_.size()
|
VLOG(3) << "Sending " << pending_completed_tasks_.size()
|
||||||
<< " task updates to master";
|
<< " task updates to dispatcher";
|
||||||
TF_RETURN_IF_ERROR(EnsureMasterStubInitialized());
|
TF_RETURN_IF_ERROR(EnsureDispatcherStubInitialized());
|
||||||
WorkerUpdateRequest req;
|
WorkerUpdateRequest req;
|
||||||
req.set_worker_id(worker_id_);
|
req.set_worker_id(worker_id_);
|
||||||
for (int task_id : pending_completed_tasks_) {
|
for (int task_id : pending_completed_tasks_) {
|
||||||
@ -217,7 +216,7 @@ Status DataServiceWorkerImpl::SendTaskUpdate() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
|||||||
|
|
||||||
WorkerUpdateResponse resp;
|
WorkerUpdateResponse resp;
|
||||||
grpc::ClientContext ctx;
|
grpc::ClientContext ctx;
|
||||||
grpc::Status s = master_stub_->WorkerUpdate(&ctx, req, &resp);
|
grpc::Status s = dispatcher_stub_->WorkerUpdate(&ctx, req, &resp);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
return grpc_util::WrapError("Failed to send task updates", s);
|
return grpc_util::WrapError("Failed to send task updates", s);
|
||||||
}
|
}
|
||||||
@ -238,7 +237,7 @@ void DataServiceWorkerImpl::HeartbeatThread() {
|
|||||||
}
|
}
|
||||||
Status s = SendTaskUpdate();
|
Status s = SendTaskUpdate();
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
LOG(WARNING) << "Failed to send task updates to master: " << s;
|
LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/core/data/service/common.pb.h"
|
#include "tensorflow/core/data/service/common.pb.h"
|
||||||
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
|
||||||
#include "tensorflow/core/data/service/worker.pb.h"
|
#include "tensorflow/core/data/service/worker.pb.h"
|
||||||
#include "tensorflow/core/data/standalone.h"
|
#include "tensorflow/core/data/standalone.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
@ -29,17 +29,17 @@ namespace data {
|
|||||||
// A TensorFlow DataService serves dataset elements over RPC.
|
// A TensorFlow DataService serves dataset elements over RPC.
|
||||||
class DataServiceWorkerImpl {
|
class DataServiceWorkerImpl {
|
||||||
public:
|
public:
|
||||||
explicit DataServiceWorkerImpl(const std::string& master_address,
|
explicit DataServiceWorkerImpl(const std::string& dispatcher_address,
|
||||||
const std::string& protocol);
|
const std::string& protocol);
|
||||||
~DataServiceWorkerImpl();
|
~DataServiceWorkerImpl();
|
||||||
|
|
||||||
// Starts the worker. The worker needs to know its own address so that it can
|
// Starts the worker. The worker needs to know its own address so that it can
|
||||||
// register with the master.
|
// register with the dispatcher.
|
||||||
void Start(const std::string& worker_address);
|
void Start(const std::string& worker_address);
|
||||||
|
|
||||||
// See worker.proto for API documentation.
|
// See worker.proto for API documentation.
|
||||||
|
|
||||||
/// Master-facing API.
|
/// Dispatcher-facing API.
|
||||||
Status ProcessTask(const ProcessTaskRequest* request,
|
Status ProcessTask(const ProcessTaskRequest* request,
|
||||||
ProcessTaskResponse* response);
|
ProcessTaskResponse* response);
|
||||||
|
|
||||||
@ -48,15 +48,15 @@ class DataServiceWorkerImpl {
|
|||||||
GetElementResponse* response);
|
GetElementResponse* response);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Sets master_stub_ if it isn't already set.
|
// Sets dispatcher_stub_ if it isn't already set.
|
||||||
Status EnsureMasterStubInitialized();
|
Status EnsureDispatcherStubInitialized();
|
||||||
// Registers the worker with the master.
|
// Registers the worker with the dispatcher.
|
||||||
Status Register();
|
Status Register();
|
||||||
// Sends task status to the master.
|
// Sends task status to the dispatcher.
|
||||||
Status SendTaskUpdate();
|
Status SendTaskUpdate();
|
||||||
// Creates an iterator to process a task.
|
// Creates an iterator to process a task.
|
||||||
Status ProcessTaskInternal(const TaskDef& task);
|
Status ProcessTaskInternal(const TaskDef& task);
|
||||||
// A thread for updating the master with worker status.
|
// A thread for updating the dispatcher with worker status.
|
||||||
void HeartbeatThread();
|
void HeartbeatThread();
|
||||||
|
|
||||||
typedef struct Task {
|
typedef struct Task {
|
||||||
@ -67,18 +67,19 @@ class DataServiceWorkerImpl {
|
|||||||
std::unique_ptr<standalone::Iterator> iterator;
|
std::unique_ptr<standalone::Iterator> iterator;
|
||||||
} Task;
|
} Task;
|
||||||
|
|
||||||
const std::string master_address_;
|
const std::string dispatcher_address_;
|
||||||
// Protocol for communicating with the master.
|
// Protocol for communicating with the dispatcher.
|
||||||
const std::string protocol_;
|
const std::string protocol_;
|
||||||
// The worker's own address.
|
// The worker's own address.
|
||||||
std::string worker_address_;
|
std::string worker_address_;
|
||||||
|
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
int64 worker_id_ TF_GUARDED_BY(mu_);
|
int64 worker_id_ TF_GUARDED_BY(mu_);
|
||||||
std::unique_ptr<MasterService::Stub> master_stub_ TF_GUARDED_BY(mu_);
|
std::unique_ptr<DispatcherService::Stub> dispatcher_stub_ TF_GUARDED_BY(mu_);
|
||||||
// Information about tasks, keyed by task ids.
|
// Information about tasks, keyed by task ids.
|
||||||
absl::flat_hash_map<int64, Task> tasks_ TF_GUARDED_BY(mu_);
|
absl::flat_hash_map<int64, Task> tasks_ TF_GUARDED_BY(mu_);
|
||||||
// List of completed tasks which haven't yet been communicated to the master.
|
// List of completed tasks which haven't yet been communicated to the
|
||||||
|
// dispatcher.
|
||||||
std::vector<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
|
std::vector<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
|
||||||
bool cancelled_ TF_GUARDED_BY(mu_) = false;
|
bool cancelled_ TF_GUARDED_BY(mu_) = false;
|
||||||
// Condition variable for notifying the heartbeat thread.
|
// Condition variable for notifying the heartbeat thread.
|
||||||
|
@ -5864,15 +5864,15 @@ cc_library(
|
|||||||
":string_format_op",
|
":string_format_op",
|
||||||
":string_join_op",
|
":string_join_op",
|
||||||
":string_length_op",
|
":string_length_op",
|
||||||
":string_lower_op",
|
# ":string_lower_op",
|
||||||
":string_ngrams_op",
|
":string_ngrams_op",
|
||||||
":string_split_op",
|
":string_split_op",
|
||||||
":string_strip_op",
|
":string_strip_op",
|
||||||
":string_to_hash_bucket_op",
|
":string_to_hash_bucket_op",
|
||||||
":string_upper_op",
|
# ":string_upper_op",
|
||||||
":substr_op",
|
":substr_op",
|
||||||
":unicode_ops",
|
# ":unicode_ops",
|
||||||
":unicode_script_op",
|
# ":unicode_script_op",
|
||||||
":unsorted_segment_join_op",
|
":unsorted_segment_join_op",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -5885,7 +5885,7 @@ cc_library(
|
|||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"@icu//:common",
|
# "@icu//:common",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -6041,7 +6041,7 @@ tf_kernel_library(
|
|||||||
prefix = "string_lower_op",
|
prefix = "string_lower_op",
|
||||||
deps = STRING_DEPS + [
|
deps = STRING_DEPS + [
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@icu//:common",
|
# "@icu//:common",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -6050,7 +6050,7 @@ tf_kernel_library(
|
|||||||
prefix = "string_upper_op",
|
prefix = "string_upper_op",
|
||||||
deps = STRING_DEPS + [
|
deps = STRING_DEPS + [
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@icu//:common",
|
# "@icu//:common",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -6096,7 +6096,7 @@ tf_kernel_library(
|
|||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
"//third_party/icu/data:conversion_data",
|
"//third_party/icu/data:conversion_data",
|
||||||
"@icu//:common",
|
# "@icu//:common",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -7125,10 +7125,10 @@ filegroup(
|
|||||||
"mutex_ops.*",
|
"mutex_ops.*",
|
||||||
"batch_kernels.*",
|
"batch_kernels.*",
|
||||||
"regex_replace_op.cc",
|
"regex_replace_op.cc",
|
||||||
"string_lower_op.cc", # Requires ICU for unicode.
|
# "string_lower_op.cc", # Requires ICU for unicode.
|
||||||
"string_upper_op.cc", # Requires ICU for unicode.
|
# "string_upper_op.cc", # Requires ICU for unicode.
|
||||||
"unicode_ops.cc",
|
"unicode_ops.cc",
|
||||||
"unicode_script_op.cc",
|
# "unicode_script_op.cc",
|
||||||
# Ops that are inherently incompatible with Android (e.g. tied to x86 platform).
|
# Ops that are inherently incompatible with Android (e.g. tied to x86 platform).
|
||||||
"mkl_*",
|
"mkl_*",
|
||||||
"xsmm_*",
|
"xsmm_*",
|
||||||
@ -8620,7 +8620,7 @@ tf_kernel_library(
|
|||||||
srcs = ["unicode_script_op.cc"],
|
srcs = ["unicode_script_op.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"@icu//:common",
|
# "@icu//:common",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -8652,6 +8652,39 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "deepspeech_cwise_ops",
|
||||||
|
srcs = [
|
||||||
|
"cwise_op_add_1.cc",
|
||||||
|
"cwise_op_add_2.cc",
|
||||||
|
"cwise_op_less.cc",
|
||||||
|
"cwise_op_minimum.cc",
|
||||||
|
"cwise_op_mul_1.cc",
|
||||||
|
"cwise_op_rsqrt.cc",
|
||||||
|
"cwise_op_squared_difference.cc",
|
||||||
|
"cwise_op_sub.cc",
|
||||||
|
"cwise_op_sigmoid.cc",
|
||||||
|
"cwise_op_tanh.cc",
|
||||||
|
],
|
||||||
|
gpu_srcs = [
|
||||||
|
"cwise_op_gpu_add.cu.cc",
|
||||||
|
"cwise_op_gpu_less.cu.cc",
|
||||||
|
"cwise_op_gpu_minimum.cu.cc",
|
||||||
|
"cwise_op_gpu_mul.cu.cc",
|
||||||
|
"cwise_op_gpu_rsqrt.cu.cc",
|
||||||
|
"cwise_op_gpu_squared_difference.cu.cc",
|
||||||
|
"cwise_op_gpu_sub.cu.cc",
|
||||||
|
"cwise_op_gpu_sigmoid.cu.cc",
|
||||||
|
"cwise_op_gpu_tanh.cu.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":cwise_lib",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//third_party/eigen3",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# Header-only version of cwise_lib for clients that want to use the cwise_ops
|
# Header-only version of cwise_lib for clients that want to use the cwise_ops
|
||||||
# functionality in their own custom ops.
|
# functionality in their own custom ops.
|
||||||
cc_header_only_library(
|
cc_header_only_library(
|
||||||
|
@ -116,6 +116,7 @@ REGISTER_KERNEL(GPU, int16);
|
|||||||
REGISTER_KERNEL(GPU, qint16);
|
REGISTER_KERNEL(GPU, qint16);
|
||||||
REGISTER_KERNEL(GPU, quint16);
|
REGISTER_KERNEL(GPU, quint16);
|
||||||
REGISTER_KERNEL(GPU, uint32);
|
REGISTER_KERNEL(GPU, uint32);
|
||||||
|
REGISTER_KERNEL(GPU, int32);
|
||||||
REGISTER_KERNEL(GPU, qint32);
|
REGISTER_KERNEL(GPU, qint32);
|
||||||
REGISTER_KERNEL(GPU, int64);
|
REGISTER_KERNEL(GPU, int64);
|
||||||
REGISTER_KERNEL(GPU, uint64);
|
REGISTER_KERNEL(GPU, uint64);
|
||||||
|
@ -69,7 +69,7 @@ const int64 kDefaultTaskRefreshIntervalMs = 1000; // 1 second.
|
|||||||
// Dataset for reading data from the tf.data service non-deterministically.
|
// Dataset for reading data from the tf.data service non-deterministically.
|
||||||
//
|
//
|
||||||
// This dataset interleaves dataset elements produced by multiple tf.data
|
// This dataset interleaves dataset elements produced by multiple tf.data
|
||||||
// workers. We periodically query the tf.data master to determine which workers
|
// workers. We periodically query the dispatcher to determine which workers
|
||||||
// to read from (in case workers are added or removed).
|
// to read from (in case workers are added or removed).
|
||||||
class DataServiceDatasetOp::Dataset : public DatasetBase {
|
class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||||
public:
|
public:
|
||||||
@ -199,12 +199,13 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
|||||||
Status Initialize(IteratorContext* ctx) override {
|
Status Initialize(IteratorContext* ctx) override {
|
||||||
VLOG(3) << "Connecting to " << dataset()->address_
|
VLOG(3) << "Connecting to " << dataset()->address_
|
||||||
<< " in data service dataset op";
|
<< " in data service dataset op";
|
||||||
DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
|
DataServiceDispatcherClient dispatcher(dataset()->address_,
|
||||||
|
dataset()->protocol_);
|
||||||
if (dataset()->job_name_.empty()) {
|
if (dataset()->job_name_.empty()) {
|
||||||
TF_RETURN_IF_ERROR(master.CreateJob(
|
TF_RETURN_IF_ERROR(dispatcher.CreateJob(
|
||||||
dataset()->dataset_id_, dataset()->processing_mode_, &job_id_));
|
dataset()->dataset_id_, dataset()->processing_mode_, &job_id_));
|
||||||
} else {
|
} else {
|
||||||
TF_RETURN_IF_ERROR(master.GetOrCreateJob(
|
TF_RETURN_IF_ERROR(dispatcher.GetOrCreateJob(
|
||||||
dataset()->dataset_id_, dataset()->processing_mode_,
|
dataset()->dataset_id_, dataset()->processing_mode_,
|
||||||
dataset()->job_name_, iterator_index_, &job_id_));
|
dataset()->job_name_, iterator_index_, &job_id_));
|
||||||
}
|
}
|
||||||
@ -283,11 +284,12 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
|||||||
|
|
||||||
// Periodically refresh the task list.
|
// Periodically refresh the task list.
|
||||||
// Maintain one thread fetching elements for each task.
|
// Maintain one thread fetching elements for each task.
|
||||||
// TODO(aaudibert): Instead of polling, have master send updates when
|
// TODO(aaudibert): Instead of polling, have dispatcher send updates when
|
||||||
// the list of tasks changes.
|
// the list of tasks changes.
|
||||||
void TaskThreadManager(std::unique_ptr<IteratorContext> ctx) {
|
void TaskThreadManager(std::unique_ptr<IteratorContext> ctx) {
|
||||||
VLOG(3) << "Starting task thread manager";
|
VLOG(3) << "Starting task thread manager";
|
||||||
DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
|
DataServiceDispatcherClient dispatcher(dataset()->address_,
|
||||||
|
dataset()->protocol_);
|
||||||
uint64 next_check = Env::Default()->NowMicros();
|
uint64 next_check = Env::Default()->NowMicros();
|
||||||
while (true) {
|
while (true) {
|
||||||
{
|
{
|
||||||
@ -305,18 +307,19 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
UpdateTasks(&master);
|
UpdateTasks(&dispatcher);
|
||||||
UpdateWorkerThreads(ctx.get());
|
UpdateWorkerThreads(ctx.get());
|
||||||
next_check = Env::Default()->NowMicros() +
|
next_check = Env::Default()->NowMicros() +
|
||||||
dataset()->task_refresh_interval_ms_ * 1000;
|
dataset()->task_refresh_interval_ms_ * 1000;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void UpdateTasks(DataServiceMasterClient* master) LOCKS_EXCLUDED(mu_) {
|
void UpdateTasks(DataServiceDispatcherClient* dispatcher)
|
||||||
|
LOCKS_EXCLUDED(mu_) {
|
||||||
VLOG(3) << "Updating tasks";
|
VLOG(3) << "Updating tasks";
|
||||||
std::vector<TaskInfo> tasks;
|
std::vector<TaskInfo> tasks;
|
||||||
bool job_finished;
|
bool job_finished;
|
||||||
Status s = master->GetTasks(job_id_, &tasks, &job_finished);
|
Status s = dispatcher->GetTasks(job_id_, &tasks, &job_finished);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
LOG(WARNING) << "Failed to get task info for job id " << job_id_ << ": "
|
LOG(WARNING) << "Failed to get task info for job id " << job_id_ << ": "
|
||||||
<< s;
|
<< s;
|
||||||
|
@ -53,7 +53,7 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
|
|||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def));
|
ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def));
|
||||||
|
|
||||||
DataServiceMasterClient client(address, protocol);
|
DataServiceDispatcherClient client(address, protocol);
|
||||||
int64 dataset_id;
|
int64 dataset_id;
|
||||||
OP_REQUIRES_OK(ctx, client.RegisterDataset(graph_def, &dataset_id));
|
OP_REQUIRES_OK(ctx, client.RegisterDataset(graph_def, &dataset_id));
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ namespace data {
|
|||||||
|
|
||||||
// Registers a dataset with the tf.data service.
|
// Registers a dataset with the tf.data service.
|
||||||
//
|
//
|
||||||
// The address and protocol inputs are used to connect to the tf.data master.
|
// The address and protocol inputs are used to connect to the dispatcher.
|
||||||
// The external state policy attribute determines whether to ignore, warn, or
|
// The external state policy attribute determines whether to ignore, warn, or
|
||||||
// error out when the dataset contains external state.
|
// error out when the dataset contains external state.
|
||||||
// The op produces a dataset id for identifying the registered dataset.
|
// The op produces a dataset id for identifying the registered dataset.
|
||||||
|
@ -61,6 +61,8 @@ message SavedObject {
|
|||||||
SavedConstant constant = 9;
|
SavedConstant constant = 9;
|
||||||
SavedResource resource = 10;
|
SavedResource resource = 10;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
map<string, SaveableObject> saveable_objects = 11;
|
||||||
}
|
}
|
||||||
|
|
||||||
// A SavedUserObject is an object (in the object-oriented language of the
|
// A SavedUserObject is an object (in the object-oriented language of the
|
||||||
@ -162,3 +164,9 @@ message SavedResource {
|
|||||||
// device.
|
// device.
|
||||||
string device = 1;
|
string device = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message SaveableObject {
|
||||||
|
// Node ids of concrete functions for saving and loading from a checkpoint.
|
||||||
|
int32 save_function = 2;
|
||||||
|
int32 restore_function = 3;
|
||||||
|
}
|
||||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
|||||||
// Also update tensorflow/tensorflow.bzl and
|
// Also update tensorflow/tensorflow.bzl and
|
||||||
// tensorflow/tools/pip_package/setup.py
|
// tensorflow/tools/pip_package/setup.py
|
||||||
#define TF_MAJOR_VERSION 2
|
#define TF_MAJOR_VERSION 2
|
||||||
#define TF_MINOR_VERSION 2
|
#define TF_MINOR_VERSION 3
|
||||||
#define TF_PATCH_VERSION 0
|
#define TF_PATCH_VERSION 0
|
||||||
|
|
||||||
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
|
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
|
||||||
|
@ -57,7 +57,6 @@ cc_library(
|
|||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
}) + select({
|
}) + select({
|
||||||
"//tensorflow:fuchsia": [],
|
"//tensorflow:fuchsia": [],
|
||||||
"//tensorflow:windows": [],
|
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
"//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
"//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
||||||
],
|
],
|
||||||
|
@ -77,7 +77,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
|
|||||||
amount of memory used, since `distribute` won't use more than
|
amount of memory used, since `distribute` won't use more than
|
||||||
`element_size` * `max_outstanding_requests` of memory.
|
`element_size` * `max_outstanding_requests` of memory.
|
||||||
task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
|
task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
|
||||||
the master for task changes.
|
the dispatcher for task changes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if job_name is None:
|
if job_name is None:
|
||||||
@ -173,7 +173,7 @@ def _distribute(processing_mode,
|
|||||||
of memory used, since `distribute` won't use more than `element_size` *
|
of memory used, since `distribute` won't use more than `element_size` *
|
||||||
`max_outstanding_requests` of memory.
|
`max_outstanding_requests` of memory.
|
||||||
task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
|
task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the
|
||||||
master for task changes.
|
dispatcher for task changes.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dataset: A `Dataset` of the elements produced by the data service.
|
Dataset: A `Dataset` of the elements produced by the data service.
|
||||||
|
@ -19,5 +19,5 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.data.experimental.ops.data_service_ops import distribute
|
from tensorflow.python.data.experimental.ops.data_service_ops import distribute
|
||||||
from tensorflow.python.data.experimental.service.server_lib import MasterServer
|
from tensorflow.python.data.experimental.service.server_lib import DispatchServer
|
||||||
from tensorflow.python.data.experimental.service.server_lib import WorkerServer
|
from tensorflow.python.data.experimental.service.server_lib import WorkerServer
|
||||||
|
@ -24,35 +24,35 @@ from tensorflow.python.data.experimental.service import _pywrap_server_lib
|
|||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
@tf_export("data.experimental.service.MasterServer", v1=[])
|
@tf_export("data.experimental.service.DispatchServer", v1=[])
|
||||||
class MasterServer(object):
|
class DispatchServer(object):
|
||||||
"""An in-process tf.data service master server.
|
"""An in-process tf.data service dispatch server.
|
||||||
|
|
||||||
A `tf.data.experimental.service.MasterServer` coordinates a cluster of
|
A `tf.data.experimental.service.DispatchServer` coordinates a cluster of
|
||||||
`tf.data.experimental.service.WorkerServer`s. When the workers start, they
|
`tf.data.experimental.service.WorkerServer`s. When the workers start, they
|
||||||
register themselves with the master.
|
register themselves with the dispatcher.
|
||||||
|
|
||||||
>>> master = tf.data.experimental.service.MasterServer(port=0)
|
>>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
|
||||||
>>> master_address = master.target.split("://")[1]
|
>>> dispatcher_address = dispatcher.target.split("://")[1]
|
||||||
>>> worker = tf.data.experimental.service.WorkerServer(
|
>>> worker = tf.data.experimental.service.WorkerServer(
|
||||||
... port=0, master_address=master_address)
|
... port=0, dispatcher_address=dispatcher_address)
|
||||||
>>> dataset = tf.data.Dataset.range(10)
|
>>> dataset = tf.data.Dataset.range(10)
|
||||||
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
|
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
|
||||||
... processing_mode="parallel_epochs", service=master.target))
|
... processing_mode="parallel_epochs", service=dispatcher.target))
|
||||||
>>> print(list(dataset.as_numpy_iterator()))
|
>>> print(list(dataset.as_numpy_iterator()))
|
||||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||||
|
|
||||||
When starting a dedicated tf.data master process, use join() to block
|
When starting a dedicated tf.data dispatch process, use join() to block
|
||||||
indefinitely after starting up the server.
|
indefinitely after starting up the server.
|
||||||
|
|
||||||
```
|
```
|
||||||
master = tf.data.experimental.service.MasterServer(port=5050)
|
dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
|
||||||
master.join()
|
dispatcher.join()
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, port, protocol=None, start=True):
|
def __init__(self, port, protocol=None, start=True):
|
||||||
"""Creates a new master server.
|
"""Creates a new dispatch server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
port: Specifies the port to bind to.
|
port: Specifies the port to bind to.
|
||||||
@ -68,15 +68,16 @@ class MasterServer(object):
|
|||||||
if protocol is None:
|
if protocol is None:
|
||||||
protocol = "grpc"
|
protocol = "grpc"
|
||||||
self._protocol = protocol
|
self._protocol = protocol
|
||||||
self._server = _pywrap_server_lib.TF_DATA_NewMasterServer(port, protocol)
|
self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(port, protocol)
|
||||||
if start:
|
if start:
|
||||||
self._server.start()
|
self._server.start()
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Starts this server.
|
"""Starts this server.
|
||||||
|
|
||||||
>>> master = tf.data.experimental.service.MasterServer(port=0, start=False)
|
>>> dispatcher = tf.data.experimental.service.DispatchServer(port=0,
|
||||||
>>> master.start()
|
... start=False)
|
||||||
|
>>> dispatcher.start()
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
tf.errors.OpError: Or one of its subclasses if an error occurs while
|
tf.errors.OpError: Or one of its subclasses if an error occurs while
|
||||||
@ -87,11 +88,11 @@ class MasterServer(object):
|
|||||||
def join(self):
|
def join(self):
|
||||||
"""Blocks until the server has shut down.
|
"""Blocks until the server has shut down.
|
||||||
|
|
||||||
This is useful when starting a dedicated master process.
|
This is useful when starting a dedicated dispatch process.
|
||||||
|
|
||||||
```
|
```
|
||||||
master = tf.data.experimental.service.MasterServer(port=5050)
|
dispatcher = tf.data.experimental.service.DispatchServer(port=5050)
|
||||||
master.join()
|
dispatcher.join()
|
||||||
```
|
```
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -104,10 +105,10 @@ class MasterServer(object):
|
|||||||
def target(self):
|
def target(self):
|
||||||
"""Returns a target that can be used to connect to the server.
|
"""Returns a target that can be used to connect to the server.
|
||||||
|
|
||||||
>>> master = tf.data.experimental.service.MasterServer(port=0)
|
>>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
|
||||||
>>> dataset = tf.data.Dataset.range(10)
|
>>> dataset = tf.data.Dataset.range(10)
|
||||||
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
|
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
|
||||||
... processing_mode="parallel_epochs", service=master.target))
|
... processing_mode="parallel_epochs", service=dispatcher.target))
|
||||||
|
|
||||||
The returned string will be in the form protocol://address, e.g.
|
The returned string will be in the form protocol://address, e.g.
|
||||||
"grpc://localhost:5050".
|
"grpc://localhost:5050".
|
||||||
@ -136,7 +137,7 @@ class MasterServer(object):
|
|||||||
return "localhost:{0}".format(self._server.bound_port())
|
return "localhost:{0}".format(self._server.bound_port())
|
||||||
|
|
||||||
def _num_workers(self):
|
def _num_workers(self):
|
||||||
"""Returns the number of workers registered with the master."""
|
"""Returns the number of workers registered with the dispatcher."""
|
||||||
return self._server.num_workers()
|
return self._server.num_workers()
|
||||||
|
|
||||||
|
|
||||||
@ -147,15 +148,15 @@ class WorkerServer(object):
|
|||||||
A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset`
|
A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset`
|
||||||
processing for user-defined datasets, and provides the resulting elements over
|
processing for user-defined datasets, and provides the resulting elements over
|
||||||
RPC. A worker is associated with a single
|
RPC. A worker is associated with a single
|
||||||
`tf.data.experimental.service.MasterServer`.
|
`tf.data.experimental.service.DispatchServer`.
|
||||||
|
|
||||||
>>> master = tf.data.experimental.service.MasterServer(port=0)
|
>>> dispatcher = tf.data.experimental.service.DispatchServer(port=0)
|
||||||
>>> master_address = master.target.split("://")[1]
|
>>> dispatcher_address = dispatcher.target.split("://")[1]
|
||||||
>>> worker = tf.data.experimental.service.WorkerServer(
|
>>> worker = tf.data.experimental.service.WorkerServer(
|
||||||
... port=0, master_address=master_address)
|
... port=0, dispatcher_address=dispatcher_address)
|
||||||
>>> dataset = tf.data.Dataset.range(10)
|
>>> dataset = tf.data.Dataset.range(10)
|
||||||
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
|
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
|
||||||
... processing_mode="parallel_epochs", service=master.target))
|
... processing_mode="parallel_epochs", service=dispatcher.target))
|
||||||
>>> print(list(dataset.as_numpy_iterator()))
|
>>> print(list(dataset.as_numpy_iterator()))
|
||||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||||
|
|
||||||
@ -164,14 +165,14 @@ class WorkerServer(object):
|
|||||||
|
|
||||||
```
|
```
|
||||||
worker = tf.data.experimental.service.WorkerServer(
|
worker = tf.data.experimental.service.WorkerServer(
|
||||||
port=5051, master_address="grpc://localhost:5050")
|
port=5051, dispatcher_address="grpc://localhost:5050")
|
||||||
worker.join()
|
worker.join()
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
port,
|
port,
|
||||||
master_address,
|
dispatcher_address,
|
||||||
worker_address=None,
|
worker_address=None,
|
||||||
protocol=None,
|
protocol=None,
|
||||||
start=True):
|
start=True):
|
||||||
@ -180,11 +181,12 @@ class WorkerServer(object):
|
|||||||
Args:
|
Args:
|
||||||
port: Specifies the port to bind to. A value of 0 indicates that the
|
port: Specifies the port to bind to. A value of 0 indicates that the
|
||||||
worker can bind to any available port.
|
worker can bind to any available port.
|
||||||
master_address: Specifies the address of the master server.
|
dispatcher_address: Specifies the address of the dispatcher.
|
||||||
worker_address: (Optional.) Specifies the address of the worker server.
|
worker_address: (Optional.) Specifies the address of the worker server.
|
||||||
This address is passed to the master server so that the master can tell
|
This address is passed to the dispatcher so that the dispatcher can
|
||||||
clients how to connect to this worker. Defaults to `"localhost:%port%"`,
|
tell clients how to connect to this worker. Defaults to
|
||||||
where `%port%` will be replaced with the port used by the worker.
|
`"localhost:%port%"`, where `%port%` will be replaced with the port used
|
||||||
|
by the worker.
|
||||||
protocol: (Optional.) Specifies the protocol to be used by the server.
|
protocol: (Optional.) Specifies the protocol to be used by the server.
|
||||||
Acceptable values include `"grpc", "grpc+local"`. Defaults to `"grpc"`.
|
Acceptable values include `"grpc", "grpc+local"`. Defaults to `"grpc"`.
|
||||||
start: (Optional.) Boolean, indicating whether to start the server after
|
start: (Optional.) Boolean, indicating whether to start the server after
|
||||||
@ -201,7 +203,7 @@ class WorkerServer(object):
|
|||||||
|
|
||||||
self._protocol = protocol
|
self._protocol = protocol
|
||||||
self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
|
self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
|
||||||
port, protocol, master_address, worker_address)
|
port, protocol, dispatcher_address, worker_address)
|
||||||
if start:
|
if start:
|
||||||
self._server.start()
|
self._server.start()
|
||||||
|
|
||||||
@ -221,7 +223,7 @@ class WorkerServer(object):
|
|||||||
|
|
||||||
```
|
```
|
||||||
worker_server = tf.data.experimental.service.WorkerServer(
|
worker_server = tf.data.experimental.service.WorkerServer(
|
||||||
port=5051, master_address="grpc://localhost:5050")
|
port=5051, dispatcher_address="grpc://localhost:5050")
|
||||||
worker_server.join()
|
worker_server.join()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -25,68 +25,68 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
class ServerLibTest(test.TestCase):
|
class ServerLibTest(test.TestCase):
|
||||||
|
|
||||||
def testStartMaster(self):
|
def testStartDispatcher(self):
|
||||||
master = server_lib.MasterServer(0, start=False)
|
dispatcher = server_lib.DispatchServer(0, start=False)
|
||||||
master.start()
|
dispatcher.start()
|
||||||
|
|
||||||
def testMultipleStartMaster(self):
|
def testMultipleStartDispatcher(self):
|
||||||
master = server_lib.MasterServer(0, start=True)
|
dispatcher = server_lib.DispatchServer(0, start=True)
|
||||||
master.start()
|
dispatcher.start()
|
||||||
|
|
||||||
def testStartWorker(self):
|
def testStartWorker(self):
|
||||||
master = server_lib.MasterServer(0)
|
dispatcher = server_lib.DispatchServer(0)
|
||||||
worker = server_lib.WorkerServer(0, master._address, start=False)
|
worker = server_lib.WorkerServer(0, dispatcher._address, start=False)
|
||||||
worker.start()
|
worker.start()
|
||||||
|
|
||||||
def testMultipleStartWorker(self):
|
def testMultipleStartWorker(self):
|
||||||
master = server_lib.MasterServer(0)
|
dispatcher = server_lib.DispatchServer(0)
|
||||||
worker = server_lib.WorkerServer(0, master._address, start=True)
|
worker = server_lib.WorkerServer(0, dispatcher._address, start=True)
|
||||||
worker.start()
|
worker.start()
|
||||||
|
|
||||||
def testStopMaster(self):
|
def testStopDispatcher(self):
|
||||||
master = server_lib.MasterServer(0)
|
dispatcher = server_lib.DispatchServer(0)
|
||||||
master._stop()
|
dispatcher._stop()
|
||||||
master._stop()
|
dispatcher._stop()
|
||||||
|
|
||||||
def testStopWorker(self):
|
def testStopWorker(self):
|
||||||
master = server_lib.MasterServer(0)
|
dispatcher = server_lib.DispatchServer(0)
|
||||||
worker = server_lib.WorkerServer(0, master._address)
|
worker = server_lib.WorkerServer(0, dispatcher._address)
|
||||||
worker._stop()
|
worker._stop()
|
||||||
worker._stop()
|
worker._stop()
|
||||||
|
|
||||||
def testStopStartMaster(self):
|
def testStopStartDispatcher(self):
|
||||||
master = server_lib.MasterServer(0)
|
dispatcher = server_lib.DispatchServer(0)
|
||||||
master._stop()
|
dispatcher._stop()
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError, "Server cannot be started after it has been stopped"):
|
RuntimeError, "Server cannot be started after it has been stopped"):
|
||||||
master.start()
|
dispatcher.start()
|
||||||
|
|
||||||
def testStopStartWorker(self):
|
def testStopStartWorker(self):
|
||||||
master = server_lib.MasterServer(0)
|
dispatcher = server_lib.DispatchServer(0)
|
||||||
worker = server_lib.WorkerServer(0, master._address)
|
worker = server_lib.WorkerServer(0, dispatcher._address)
|
||||||
worker._stop()
|
worker._stop()
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError, "Server cannot be started after it has been stopped"):
|
RuntimeError, "Server cannot be started after it has been stopped"):
|
||||||
worker.start()
|
worker.start()
|
||||||
|
|
||||||
def testJoinMaster(self):
|
def testJoinDispatcher(self):
|
||||||
master = server_lib.MasterServer(0)
|
dispatcher = server_lib.DispatchServer(0)
|
||||||
master._stop()
|
dispatcher._stop()
|
||||||
master.join()
|
dispatcher.join()
|
||||||
|
|
||||||
def testJoinWorker(self):
|
def testJoinWorker(self):
|
||||||
master = server_lib.MasterServer(0)
|
dispatcher = server_lib.DispatchServer(0)
|
||||||
worker = server_lib.WorkerServer(0, master._address)
|
worker = server_lib.WorkerServer(0, dispatcher._address)
|
||||||
worker._stop()
|
worker._stop()
|
||||||
worker.join()
|
worker.join()
|
||||||
|
|
||||||
def testMasterNumWorkers(self):
|
def testDispatcherNumWorkers(self):
|
||||||
master = server_lib.MasterServer(0)
|
dispatcher = server_lib.DispatchServer(0)
|
||||||
self.assertEqual(0, master._num_workers())
|
self.assertEqual(0, dispatcher._num_workers())
|
||||||
worker1 = server_lib.WorkerServer(0, master._address) # pylint: disable=unused-variable
|
worker1 = server_lib.WorkerServer(0, dispatcher._address) # pylint: disable=unused-variable
|
||||||
self.assertEqual(1, master._num_workers())
|
self.assertEqual(1, dispatcher._num_workers())
|
||||||
worker2 = server_lib.WorkerServer(0, master._address) # pylint: disable=unused-variable
|
worker2 = server_lib.WorkerServer(0, dispatcher._address) # pylint: disable=unused-variable
|
||||||
self.assertEqual(2, master._num_workers())
|
self.assertEqual(2, dispatcher._num_workers())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -28,13 +28,14 @@ limitations under the License.
|
|||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
PYBIND11_MODULE(_pywrap_server_lib, m) {
|
PYBIND11_MODULE(_pywrap_server_lib, m) {
|
||||||
py::class_<tensorflow::data::MasterGrpcDataServer>(m, "MasterGrpcDataServer")
|
py::class_<tensorflow::data::DispatchGrpcDataServer>(m,
|
||||||
.def("start", &tensorflow::data::MasterGrpcDataServer::Start)
|
"DispatchGrpcDataServer")
|
||||||
.def("stop", &tensorflow::data::MasterGrpcDataServer::Stop)
|
.def("start", &tensorflow::data::DispatchGrpcDataServer::Start)
|
||||||
.def("join", &tensorflow::data::MasterGrpcDataServer::Join)
|
.def("stop", &tensorflow::data::DispatchGrpcDataServer::Stop)
|
||||||
.def("bound_port", &tensorflow::data::MasterGrpcDataServer::BoundPort)
|
.def("join", &tensorflow::data::DispatchGrpcDataServer::Join)
|
||||||
|
.def("bound_port", &tensorflow::data::DispatchGrpcDataServer::BoundPort)
|
||||||
.def("num_workers",
|
.def("num_workers",
|
||||||
[](tensorflow::data::MasterGrpcDataServer* server) -> int {
|
[](tensorflow::data::DispatchGrpcDataServer* server) -> int {
|
||||||
int num_workers;
|
int num_workers;
|
||||||
tensorflow::Status status = server->NumWorkers(&num_workers);
|
tensorflow::Status status = server->NumWorkers(&num_workers);
|
||||||
tensorflow::MaybeRaiseFromStatus(status);
|
tensorflow::MaybeRaiseFromStatus(status);
|
||||||
@ -48,12 +49,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
|
|||||||
.def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort);
|
.def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort);
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"TF_DATA_NewMasterServer",
|
"TF_DATA_NewDispatchServer",
|
||||||
[](int port, std::string protocol)
|
[](int port, std::string protocol)
|
||||||
-> std::unique_ptr<tensorflow::data::MasterGrpcDataServer> {
|
-> std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> {
|
||||||
std::unique_ptr<tensorflow::data::MasterGrpcDataServer> server;
|
std::unique_ptr<tensorflow::data::DispatchGrpcDataServer> server;
|
||||||
tensorflow::Status status =
|
tensorflow::Status status =
|
||||||
tensorflow::data::NewMasterServer(port, protocol, &server);
|
tensorflow::data::NewDispatchServer(port, protocol, &server);
|
||||||
tensorflow::MaybeRaiseFromStatus(status);
|
tensorflow::MaybeRaiseFromStatus(status);
|
||||||
return server;
|
return server;
|
||||||
},
|
},
|
||||||
@ -61,12 +62,12 @@ PYBIND11_MODULE(_pywrap_server_lib, m) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"TF_DATA_NewWorkerServer",
|
"TF_DATA_NewWorkerServer",
|
||||||
[](int port, std::string protocol, std::string master_address,
|
[](int port, std::string protocol, std::string dispatcher_address,
|
||||||
std::string worker_address)
|
std::string worker_address)
|
||||||
-> std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> {
|
-> std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> {
|
||||||
std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server;
|
std::unique_ptr<tensorflow::data::WorkerGrpcDataServer> server;
|
||||||
tensorflow::Status status = tensorflow::data::NewWorkerServer(
|
tensorflow::Status status = tensorflow::data::NewWorkerServer(
|
||||||
port, protocol, master_address, worker_address, &server);
|
port, protocol, dispatcher_address, worker_address, &server);
|
||||||
tensorflow::MaybeRaiseFromStatus(status);
|
tensorflow::MaybeRaiseFromStatus(status);
|
||||||
return server;
|
return server;
|
||||||
},
|
},
|
||||||
|
@ -59,23 +59,25 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
num_workers: The number of workers in the cluster.
|
num_workers: The number of workers in the cluster.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The address of the master.
|
The address of the dispatcher.
|
||||||
"""
|
"""
|
||||||
self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
|
self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
|
||||||
self._servers = []
|
self._servers = []
|
||||||
for _ in range(num_workers):
|
for _ in range(num_workers):
|
||||||
self._servers.append(
|
self._servers.append(
|
||||||
server_lib.WorkerServer(
|
server_lib.WorkerServer(
|
||||||
port=0, master_address=self._master._address, protocol=PROTOCOL))
|
port=0,
|
||||||
|
dispatcher_address=self._dispatcher._address,
|
||||||
|
protocol=PROTOCOL))
|
||||||
|
|
||||||
return self._master._address
|
return self._dispatcher._address
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testDistributeBasic(self):
|
def testDistributeBasic(self):
|
||||||
num_elements = 10
|
num_elements = 10
|
||||||
master_address = self.create_cluster(1)
|
dispatcher_address = self.create_cluster(1)
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds = _make_distributed_dataset(ds, master_address)
|
ds = _make_distributed_dataset(ds, dispatcher_address)
|
||||||
results = [elem.numpy() for elem in ds]
|
results = [elem.numpy() for elem in ds]
|
||||||
self.assertEqual(list(range(num_elements)), results)
|
self.assertEqual(list(range(num_elements)), results)
|
||||||
|
|
||||||
@ -83,10 +85,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
def testDifferentShuffleOrders(self):
|
def testDifferentShuffleOrders(self):
|
||||||
random_seed.set_random_seed(None)
|
random_seed.set_random_seed(None)
|
||||||
num_elements = 100
|
num_elements = 100
|
||||||
master_address = self.create_cluster(2)
|
dispatcher_address = self.create_cluster(2)
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds = ds.shuffle(num_elements)
|
ds = ds.shuffle(num_elements)
|
||||||
ds = _make_distributed_dataset(ds, master_address)
|
ds = _make_distributed_dataset(ds, dispatcher_address)
|
||||||
output = [elem.numpy() for elem in ds]
|
output = [elem.numpy() for elem in ds]
|
||||||
|
|
||||||
# The output will be two sequences of range(num_elements)
|
# The output will be two sequences of range(num_elements)
|
||||||
@ -104,9 +106,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testMultipleEpochs(self):
|
def testMultipleEpochs(self):
|
||||||
num_elements = 3
|
num_elements = 3
|
||||||
master_address = self.create_cluster(1)
|
dispatcher_address = self.create_cluster(1)
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds = _make_distributed_dataset(ds, master_address)
|
ds = _make_distributed_dataset(ds, dispatcher_address)
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
self.assertEqual(list(range(num_elements)), [elem.numpy() for elem in ds])
|
self.assertEqual(list(range(num_elements)), [elem.numpy() for elem in ds])
|
||||||
|
|
||||||
@ -114,9 +116,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
def testRepeatedDataset(self):
|
def testRepeatedDataset(self):
|
||||||
num_elements = 10
|
num_elements = 10
|
||||||
num_repetitions = 5
|
num_repetitions = 5
|
||||||
master_address = self.create_cluster(1)
|
dispatcher_address = self.create_cluster(1)
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds = _make_distributed_dataset(ds, master_address)
|
ds = _make_distributed_dataset(ds, dispatcher_address)
|
||||||
ds = ds.repeat(num_repetitions)
|
ds = ds.repeat(num_repetitions)
|
||||||
self.assertDatasetProduces(
|
self.assertDatasetProduces(
|
||||||
ds, expected_output=num_repetitions * list(range(num_elements)))
|
ds, expected_output=num_repetitions * list(range(num_elements)))
|
||||||
@ -125,12 +127,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
def testConcurrentEpoch(self):
|
def testConcurrentEpoch(self):
|
||||||
num_elements = 10
|
num_elements = 10
|
||||||
num_datasets = 3
|
num_datasets = 3
|
||||||
master_address = self.create_cluster(1)
|
dispatcher_address = self.create_cluster(1)
|
||||||
iterators = []
|
iterators = []
|
||||||
results = []
|
results = []
|
||||||
for _ in range(num_datasets):
|
for _ in range(num_datasets):
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds = _make_distributed_dataset(ds, master_address)
|
ds = _make_distributed_dataset(ds, dispatcher_address)
|
||||||
iterators.append(iter(ds))
|
iterators.append(iter(ds))
|
||||||
results.append([])
|
results.append([])
|
||||||
|
|
||||||
@ -146,9 +148,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.skipTest("Not yet implemented")
|
self.skipTest("Not yet implemented")
|
||||||
num_elements = 10
|
num_elements = 10
|
||||||
num_iterators = 3
|
num_iterators = 3
|
||||||
master_address = self.create_cluster(1)
|
dispatcher_address = self.create_cluster(1)
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds = _make_distributed_dataset(ds, master_address)
|
ds = _make_distributed_dataset(ds, dispatcher_address)
|
||||||
result = []
|
result = []
|
||||||
iterators = []
|
iterators = []
|
||||||
for _ in range(num_iterators):
|
for _ in range(num_iterators):
|
||||||
@ -170,20 +172,20 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
def testMultiWorker(self):
|
def testMultiWorker(self):
|
||||||
num_workers = 3
|
num_workers = 3
|
||||||
num_elements = 10
|
num_elements = 10
|
||||||
master_address = self.create_cluster(num_workers)
|
dispatcher_address = self.create_cluster(num_workers)
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds = _make_distributed_dataset(ds, master_address)
|
ds = _make_distributed_dataset(ds, dispatcher_address)
|
||||||
results = [elem.numpy() for elem in ds]
|
results = [elem.numpy() for elem in ds]
|
||||||
self.assertCountEqual(num_workers * list(range(num_elements)), results)
|
self.assertCountEqual(num_workers * list(range(num_elements)), results)
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testAddWorkerMidJob(self):
|
def testAddWorkerMidJob(self):
|
||||||
self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
|
self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
|
||||||
self._worker = server_lib.WorkerServer(
|
self._worker = server_lib.WorkerServer(
|
||||||
port=0, master_address=self._master._address, protocol=PROTOCOL)
|
port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
|
||||||
num_elements = 100
|
num_elements = 100
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds = _make_distributed_dataset(ds, self._master._address)
|
ds = _make_distributed_dataset(ds, self._dispatcher._address)
|
||||||
iterator = iter(ds)
|
iterator = iter(ds)
|
||||||
results = []
|
results = []
|
||||||
# Read halfway through the dataset.
|
# Read halfway through the dataset.
|
||||||
@ -191,10 +193,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
results.append(next(iterator).numpy())
|
results.append(next(iterator).numpy())
|
||||||
|
|
||||||
self._new_worker = server_lib.WorkerServer(
|
self._new_worker = server_lib.WorkerServer(
|
||||||
port=0, master_address=self._master._address, protocol=PROTOCOL)
|
port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
|
||||||
|
|
||||||
# Wait for the new worker to register with the master.
|
# Wait for the new worker to register with the dispatcher.
|
||||||
while self._master._num_workers() < 2:
|
while self._dispatcher._num_workers() < 2:
|
||||||
time.sleep(10 / 1000) # 10ms
|
time.sleep(10 / 1000) # 10ms
|
||||||
|
|
||||||
for elem in iterator:
|
for elem in iterator:
|
||||||
@ -206,12 +208,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
combinations.times(test_base.eager_only_combinations(),
|
combinations.times(test_base.eager_only_combinations(),
|
||||||
combinations.combine(use_same_port=[True, False])))
|
combinations.combine(use_same_port=[True, False])))
|
||||||
def testRestartWorker(self, use_same_port):
|
def testRestartWorker(self, use_same_port):
|
||||||
self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL)
|
self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
|
||||||
self._worker = server_lib.WorkerServer(
|
self._worker = server_lib.WorkerServer(
|
||||||
port=0, master_address=self._master._address, protocol=PROTOCOL)
|
port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
|
||||||
num_elements = 100
|
num_elements = 100
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds = _make_distributed_dataset(ds, self._master._address)
|
ds = _make_distributed_dataset(ds, self._dispatcher._address)
|
||||||
iterator = iter(ds)
|
iterator = iter(ds)
|
||||||
# Read halfway through the dataset.
|
# Read halfway through the dataset.
|
||||||
midpoint = num_elements // 2
|
midpoint = num_elements // 2
|
||||||
@ -224,7 +226,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
port = int(self._worker._address.split(":")[1])
|
port = int(self._worker._address.split(":")[1])
|
||||||
self._worker._stop()
|
self._worker._stop()
|
||||||
self._new_worker = server_lib.WorkerServer(
|
self._new_worker = server_lib.WorkerServer(
|
||||||
port=port, master_address=self._master._address, protocol=PROTOCOL)
|
port=port,
|
||||||
|
dispatcher_address=self._dispatcher._address,
|
||||||
|
protocol=PROTOCOL)
|
||||||
|
|
||||||
# There may have been some elements prefetched from the first worker
|
# There may have been some elements prefetched from the first worker
|
||||||
# before it was stopped.
|
# before it was stopped.
|
||||||
@ -259,12 +263,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
def testInsideFunction(self):
|
def testInsideFunction(self):
|
||||||
num_workers = 3
|
num_workers = 3
|
||||||
num_elements = 10
|
num_elements = 10
|
||||||
master_address = self.create_cluster(num_workers)
|
dispatcher_address = self.create_cluster(num_workers)
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
def f():
|
def f():
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds = _make_distributed_dataset(ds, master_address)
|
ds = _make_distributed_dataset(ds, dispatcher_address)
|
||||||
result = tensor_array_ops.TensorArray(
|
result = tensor_array_ops.TensorArray(
|
||||||
dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
|
dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
|
||||||
i = 0
|
i = 0
|
||||||
@ -279,10 +283,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testSharedJobName(self):
|
def testSharedJobName(self):
|
||||||
num_elements = 100
|
num_elements = 100
|
||||||
master_address = self.create_cluster(1)
|
dispatcher_address = self.create_cluster(1)
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
|
ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
|
||||||
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
|
ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
|
||||||
iter1 = iter(ds1)
|
iter1 = iter(ds1)
|
||||||
iter2 = iter(ds2)
|
iter2 = iter(ds2)
|
||||||
results = []
|
results = []
|
||||||
@ -298,20 +302,22 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testDifferentJobNames(self):
|
def testDifferentJobNames(self):
|
||||||
num_elements = 10
|
num_elements = 10
|
||||||
master_address = self.create_cluster(1)
|
dispatcher_address = self.create_cluster(1)
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name1")
|
ds1 = _make_distributed_dataset(
|
||||||
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name2")
|
ds, dispatcher_address, job_name="job_name1")
|
||||||
|
ds2 = _make_distributed_dataset(
|
||||||
|
ds, dispatcher_address, job_name="job_name2")
|
||||||
self.assertDatasetProduces(ds1, list(range(num_elements)))
|
self.assertDatasetProduces(ds1, list(range(num_elements)))
|
||||||
self.assertDatasetProduces(ds2, list(range(num_elements)))
|
self.assertDatasetProduces(ds2, list(range(num_elements)))
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testSharedJobNameMultiIteration(self):
|
def testSharedJobNameMultiIteration(self):
|
||||||
num_elements = 10
|
num_elements = 10
|
||||||
master_address = self.create_cluster(1)
|
dispatcher_address = self.create_cluster(1)
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
|
ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
|
||||||
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
|
ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
|
||||||
# iteration 1
|
# iteration 1
|
||||||
self.assertDatasetProduces(ds1, list(range(num_elements)))
|
self.assertDatasetProduces(ds1, list(range(num_elements)))
|
||||||
self.assertDatasetProduces(ds2, [])
|
self.assertDatasetProduces(ds2, [])
|
||||||
@ -323,11 +329,11 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
def testSharedJobNameRepeat(self):
|
def testSharedJobNameRepeat(self):
|
||||||
num_elements = 100
|
num_elements = 100
|
||||||
num_repetitions = 3
|
num_repetitions = 3
|
||||||
master_address = self.create_cluster(1)
|
dispatcher_address = self.create_cluster(1)
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name")
|
ds1 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
|
||||||
ds1 = ds1.repeat(num_repetitions)
|
ds1 = ds1.repeat(num_repetitions)
|
||||||
ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name")
|
ds2 = _make_distributed_dataset(ds, dispatcher_address, job_name="job_name")
|
||||||
ds2 = ds2.repeat(num_repetitions)
|
ds2 = ds2.repeat(num_repetitions)
|
||||||
results = []
|
results = []
|
||||||
iter1 = iter(ds1)
|
iter1 = iter(ds1)
|
||||||
@ -345,7 +351,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testApplyDeterminismOption(self):
|
def testApplyDeterminismOption(self):
|
||||||
elements = list(range(10))
|
elements = list(range(10))
|
||||||
master_address = self.create_cluster(1)
|
dispatcher_address = self.create_cluster(1)
|
||||||
|
|
||||||
def dataset_fn(delay_ms):
|
def dataset_fn(delay_ms):
|
||||||
|
|
||||||
@ -362,7 +368,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
opts = dataset_ops.Options()
|
opts = dataset_ops.Options()
|
||||||
opts.experimental_deterministic = False
|
opts.experimental_deterministic = False
|
||||||
ds = ds.with_options(opts)
|
ds = ds.with_options(opts)
|
||||||
ds = _make_distributed_dataset(ds, master_address)
|
ds = _make_distributed_dataset(ds, dispatcher_address)
|
||||||
return ds
|
return ds
|
||||||
|
|
||||||
self.checkDeterminism(
|
self.checkDeterminism(
|
||||||
@ -379,8 +385,8 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
options.experimental_external_state_policy = external_state_policy
|
options.experimental_external_state_policy = external_state_policy
|
||||||
ds = ds.with_options(options)
|
ds = ds.with_options(options)
|
||||||
|
|
||||||
master_address = self.create_cluster(3)
|
dispatcher_address = self.create_cluster(3)
|
||||||
ds = _make_distributed_dataset(ds, master_address)
|
ds = _make_distributed_dataset(ds, dispatcher_address)
|
||||||
next(iter(ds))
|
next(iter(ds))
|
||||||
|
|
||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
@ -400,12 +406,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testDistributeFromInterleave(self):
|
def testDistributeFromInterleave(self):
|
||||||
master_address = self.create_cluster(1)
|
dispatcher_address = self.create_cluster(1)
|
||||||
ds = dataset_ops.Dataset.range(2)
|
ds = dataset_ops.Dataset.range(2)
|
||||||
|
|
||||||
def interleave_fn(_):
|
def interleave_fn(_):
|
||||||
ds = dataset_ops.Dataset.range(2)
|
ds = dataset_ops.Dataset.range(2)
|
||||||
_make_distributed_dataset(ds, master_address)
|
_make_distributed_dataset(ds, dispatcher_address)
|
||||||
return ds
|
return ds
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
|
@ -123,6 +123,15 @@ class TPUTest(test.TestCase):
|
|||||||
result = bar() + 1
|
result = bar() + 1
|
||||||
self.assertAllEqual(result, 2)
|
self.assertAllEqual(result, 2)
|
||||||
|
|
||||||
|
def test_on_demand_op_with_dynamic_output(self):
|
||||||
|
with ops.device("/device:TPU:0"):
|
||||||
|
where_output = array_ops.where([True, False, True])
|
||||||
|
self.assertAllEqual(where_output, [[0], [2]])
|
||||||
|
|
||||||
|
with ops.device("/device:TPU:0"):
|
||||||
|
repeat_output = array_ops.repeat(math_ops.range(2), [1, 4])
|
||||||
|
self.assertAllEqual(repeat_output, [0, 1, 1, 1, 1])
|
||||||
|
|
||||||
|
|
||||||
@parameterized.named_parameters([("PackedVar", True), ("", False)])
|
@parameterized.named_parameters([("PackedVar", True), ("", False)])
|
||||||
class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
||||||
|
@ -4690,7 +4690,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
|
|||||||
labels=target, logits=output, axis=axis)
|
labels=target, logits=output, axis=axis)
|
||||||
|
|
||||||
if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
|
if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
|
||||||
output.op.type == 'Softmax'):
|
output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
|
||||||
# When softmax activation function is used for output operation, we
|
# When softmax activation function is used for output operation, we
|
||||||
# use logits from the softmax function directly to compute loss in order
|
# use logits from the softmax function directly to compute loss in order
|
||||||
# to prevent collapsing zero when training.
|
# to prevent collapsing zero when training.
|
||||||
@ -4735,7 +4735,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
|
|||||||
|
|
||||||
if (not from_logits and
|
if (not from_logits and
|
||||||
not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
|
not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
|
||||||
output.op.type == 'Softmax'):
|
output.op.type == 'Softmax') and not hasattr(output, '_keras_history'):
|
||||||
# When softmax activation function is used for output operation, we
|
# When softmax activation function is used for output operation, we
|
||||||
# use logits from the softmax function directly to compute loss in order
|
# use logits from the softmax function directly to compute loss in order
|
||||||
# to prevent collapsing zero when training.
|
# to prevent collapsing zero when training.
|
||||||
@ -4814,7 +4814,7 @@ def binary_crossentropy(target, output, from_logits=False):
|
|||||||
return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
|
return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
|
||||||
|
|
||||||
if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
|
if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and
|
||||||
output.op.type == 'Sigmoid'):
|
output.op.type == 'Sigmoid') and not hasattr(output, '_keras_history'):
|
||||||
# When sigmoid activation function is used for output operation, we
|
# When sigmoid activation function is used for output operation, we
|
||||||
# use logits from the sigmoid function directly to compute loss in order
|
# use logits from the sigmoid function directly to compute loss in order
|
||||||
# to prevent collapsing zero when training.
|
# to prevent collapsing zero when training.
|
||||||
|
@ -665,8 +665,9 @@ class Callback(object):
|
|||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
batch: Integer, index of batch within the current epoch.
|
batch: Integer, index of batch within the current epoch.
|
||||||
logs: Dict. Has keys `batch` and `size` representing the current batch
|
logs: Dict, contains the return value of `model.train_step`. Typically,
|
||||||
number and the size of the batch.
|
the values of the `Model`'s metrics are returned. Example:
|
||||||
|
`{'loss': 0.2, 'accuracy': 0.7}`.
|
||||||
"""
|
"""
|
||||||
# For backwards compatibility.
|
# For backwards compatibility.
|
||||||
self.on_batch_begin(batch, logs=logs)
|
self.on_batch_begin(batch, logs=logs)
|
||||||
@ -697,8 +698,9 @@ class Callback(object):
|
|||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
batch: Integer, index of batch within the current epoch.
|
batch: Integer, index of batch within the current epoch.
|
||||||
logs: Dict. Has keys `batch` and `size` representing the current batch
|
logs: Dict, contains the return value of `model.test_step`. Typically,
|
||||||
number and the size of the batch.
|
the values of the `Model`'s metrics are returned. Example:
|
||||||
|
`{'loss': 0.2, 'accuracy': 0.7}`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@doc_controls.for_subclass_implementers
|
@doc_controls.for_subclass_implementers
|
||||||
@ -725,8 +727,9 @@ class Callback(object):
|
|||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
batch: Integer, index of batch within the current epoch.
|
batch: Integer, index of batch within the current epoch.
|
||||||
logs: Dict. Has keys `batch` and `size` representing the current batch
|
logs: Dict, contains the return value of `model.predict_step`,
|
||||||
number and the size of the batch.
|
it typically returns a dict with a key 'outputs' containing
|
||||||
|
the model's outputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@doc_controls.for_subclass_implementers
|
@doc_controls.for_subclass_implementers
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
|||||||
import contextlib
|
import contextlib
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import zipfile
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
@ -43,6 +44,8 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
|
|||||||
def skip_fetch_failure_exception(self):
|
def skip_fetch_failure_exception(self):
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
|
except zipfile.BadZipfile as e:
|
||||||
|
self.skipTest('Data loading error: Bad magic number for file header.')
|
||||||
except Exception as e: # pylint: disable=broad-except
|
except Exception as e: # pylint: disable=broad-except
|
||||||
if 'URL fetch failure' in str(e):
|
if 'URL fetch failure' in str(e):
|
||||||
self.skipTest('URL fetch error not considered failure of the test.')
|
self.skipTest('URL fetch error not considered failure of the test.')
|
||||||
|
@ -921,7 +921,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
|||||||
# >> inputs = tf.keras.Input(10)
|
# >> inputs = tf.keras.Input(10)
|
||||||
# >> outputs = MyLayer()(inputs) # Functional construction mode.
|
# >> outputs = MyLayer()(inputs) # Functional construction mode.
|
||||||
# >> model = tf.keras.Model(inputs, outputs)
|
# >> model = tf.keras.Model(inputs, outputs)
|
||||||
if _in_functional_construction_mode(inputs, args, kwargs, input_list):
|
if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
|
||||||
return self._functional_construction_call(inputs, args, kwargs,
|
return self._functional_construction_call(inputs, args, kwargs,
|
||||||
input_list)
|
input_list)
|
||||||
|
|
||||||
@ -2891,9 +2891,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
|||||||
self._expects_training_arg = ('training' in call_fn_args or
|
self._expects_training_arg = ('training' in call_fn_args or
|
||||||
self._call_accepts_kwargs)
|
self._call_accepts_kwargs)
|
||||||
# The default training arg will be any (non-None) default specified in the
|
# The default training arg will be any (non-None) default specified in the
|
||||||
# method signature, or `False` if no non-None default is specified.
|
# method signature, or None if no value is specified.
|
||||||
self._default_training_arg = self._call_fn_arg_defaults.get(
|
self._default_training_arg = self._call_fn_arg_defaults.get(
|
||||||
'training') or False
|
'training')
|
||||||
self._expects_mask_arg = ('mask' in call_fn_args or
|
self._expects_mask_arg = ('mask' in call_fn_args or
|
||||||
self._call_accepts_kwargs)
|
self._call_accepts_kwargs)
|
||||||
|
|
||||||
@ -3205,7 +3205,7 @@ class AddMetric(Layer):
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylint: disable=unused-argument
|
def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list): # pylint: disable=unused-argument
|
||||||
"""Check the arguments to see if we are constructing a functional model."""
|
"""Check the arguments to see if we are constructing a functional model."""
|
||||||
if keras_tensor.keras_tensors_enabled():
|
if keras_tensor.keras_tensors_enabled():
|
||||||
# We are constructing a functional model if any of the inputs
|
# We are constructing a functional model if any of the inputs
|
||||||
@ -3215,7 +3215,20 @@ def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylin
|
|||||||
for tensor in nest.flatten([inputs, args, kwargs]))
|
for tensor in nest.flatten([inputs, args, kwargs]))
|
||||||
else:
|
else:
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
return all(tf_utils.is_symbolic_tensor(t) for t in input_list)
|
all_inputs_symbolic = all(
|
||||||
|
tf_utils.is_symbolic_tensor(t) for t in input_list)
|
||||||
|
if (base_layer_utils.is_subclassed(layer) and
|
||||||
|
any(tf_utils.is_symbolic_tensor(t) for t in nest.flatten(
|
||||||
|
[inputs, args, kwargs])) and not all_inputs_symbolic):
|
||||||
|
raise ValueError('It appears you are trying to construct a '
|
||||||
|
'functional model, but not all of the inputs in '
|
||||||
|
'the first positional argument of your layer call '
|
||||||
|
'are symbolic tensors. '
|
||||||
|
'(Input objects, or the output of another layer) '
|
||||||
|
'Functional models cannot correctly track custom '
|
||||||
|
'layers unless all values in the first call argument '
|
||||||
|
'are symbolic.')
|
||||||
|
return all_inputs_symbolic
|
||||||
else:
|
else:
|
||||||
return (base_layer_utils.is_in_keras_graph() or
|
return (base_layer_utils.is_in_keras_graph() or
|
||||||
all(hasattr(t, '_keras_history') for t in input_list))
|
all(hasattr(t, '_keras_history') for t in input_list))
|
||||||
|
@ -650,6 +650,17 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
|||||||
else:
|
else:
|
||||||
return self._nested_layer(inputs) * 0.5
|
return self._nested_layer(inputs) * 0.5
|
||||||
|
|
||||||
|
class CustomLayerDefaultTrainingNone(base_layer.Layer):
|
||||||
|
|
||||||
|
def __init__(self, nested_layer=None):
|
||||||
|
self._nested_layer = nested_layer or array_ops.identity
|
||||||
|
|
||||||
|
def call(self, inputs, training=None):
|
||||||
|
if training:
|
||||||
|
return self._nested_layer(inputs)
|
||||||
|
else:
|
||||||
|
return self._nested_layer(inputs) * 0.5
|
||||||
|
|
||||||
class CustomLayerDefaultTrainingFalse(base_layer.Layer):
|
class CustomLayerDefaultTrainingFalse(base_layer.Layer):
|
||||||
|
|
||||||
def __init__(self, nested_layer=None):
|
def __init__(self, nested_layer=None):
|
||||||
@ -701,21 +712,30 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
|||||||
# Outer layers/models should set the training context implicitly for all
|
# Outer layers/models should set the training context implicitly for all
|
||||||
# nested layers, respecting whatever mode the outer layer was run with.
|
# nested layers, respecting whatever mode the outer layer was run with.
|
||||||
layer = CustomLayerDefaultTrainingTrue(CustomLayerDefaultTrainingFalse())
|
layer = CustomLayerDefaultTrainingTrue(CustomLayerDefaultTrainingFalse())
|
||||||
self.assertAllEqual(layer(x), x)
|
# No outer value passed: use local defaults
|
||||||
|
self.assertAllEqual(layer(x), x * 0.25) # Use local default False
|
||||||
|
# Outer value passed: override local defaults
|
||||||
self.assertAllEqual(layer(x, training=False), x * 0.25)
|
self.assertAllEqual(layer(x, training=False), x * 0.25)
|
||||||
self.assertAllEqual(layer(x, training=True), x)
|
self.assertAllEqual(layer(x, training=True), x)
|
||||||
|
|
||||||
layer = CustomLayerDefaultTrainingFalse(CustomLayerDefaultTrainingTrue())
|
layer = CustomLayerDefaultTrainingFalse(CustomLayerDefaultTrainingTrue())
|
||||||
self.assertAllEqual(layer(x), x * 0.25)
|
# No outer value passed: use local defaults
|
||||||
|
self.assertAllEqual(layer(x), x) # Use local default True
|
||||||
|
# Outer value passed: override local defaults
|
||||||
self.assertAllEqual(layer(x, training=False), x * 0.25)
|
self.assertAllEqual(layer(x, training=False), x * 0.25)
|
||||||
self.assertAllEqual(layer(x, training=True), x)
|
self.assertAllEqual(layer(x, training=True), x)
|
||||||
|
|
||||||
# If the outer layer `call` doesn't take a training argument at all,
|
# If the outer layer `call` doesn't take a training argument at all,
|
||||||
# it'll set the nested scope as inference when no training arg is passed in.
|
# it'll set the nested scope as None when no training arg is passed in.
|
||||||
# If a training arg is passed in it won't use it directly in `call`, but
|
# If a training arg is passed in it won't use it directly in `call`, but
|
||||||
# it will set the nested training mode.
|
# it will set the nested training mode.
|
||||||
layer = CustomLayerNoTrainingArg(CustomLayerDefaultTrainingTrue())
|
layer = CustomLayerNoTrainingArg(CustomLayerDefaultTrainingTrue())
|
||||||
self.assertAllEqual(layer(x), x * 0.5)
|
self.assertAllEqual(layer(x), x) # Use local default True
|
||||||
|
self.assertAllEqual(layer(x, training=False), x * 0.5)
|
||||||
|
self.assertAllEqual(layer(x, training=True), x)
|
||||||
|
|
||||||
|
layer = CustomLayerDefaultTrainingNone(CustomLayerDefaultTrainingTrue())
|
||||||
|
self.assertAllEqual(layer(x), x) # Use local default True
|
||||||
self.assertAllEqual(layer(x, training=False), x * 0.5)
|
self.assertAllEqual(layer(x, training=False), x * 0.5)
|
||||||
self.assertAllEqual(layer(x, training=True), x)
|
self.assertAllEqual(layer(x, training=True), x)
|
||||||
|
|
||||||
|
@ -252,6 +252,9 @@ class Layer(base_layer.Layer):
|
|||||||
# might want to turn it off, like Sequential model.
|
# might want to turn it off, like Sequential model.
|
||||||
self._auto_track_sub_layers = True
|
self._auto_track_sub_layers = True
|
||||||
|
|
||||||
|
# Mark this layer as having been originally built as a tf1 layer/model
|
||||||
|
self._originally_built_as_v1 = True
|
||||||
|
|
||||||
@trackable.no_automatic_dependency_tracking
|
@trackable.no_automatic_dependency_tracking
|
||||||
@generic_utils.default
|
@generic_utils.default
|
||||||
def build(self, input_shape):
|
def build(self, input_shape):
|
||||||
@ -651,6 +654,8 @@ class Layer(base_layer.Layer):
|
|||||||
ValueError: if the layer's `call` method returns None (an invalid value).
|
ValueError: if the layer's `call` method returns None (an invalid value).
|
||||||
RuntimeError: if `super().__init__()` was not called in the constructor.
|
RuntimeError: if `super().__init__()` was not called in the constructor.
|
||||||
"""
|
"""
|
||||||
|
self._assert_built_as_v1()
|
||||||
|
|
||||||
if not hasattr(self, '_thread_local'):
|
if not hasattr(self, '_thread_local'):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'You must call `super().__init__()` in the layer constructor.')
|
'You must call `super().__init__()` in the layer constructor.')
|
||||||
@ -818,6 +823,20 @@ class Layer(base_layer.Layer):
|
|||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
def _assert_built_as_v1(self):
|
||||||
|
if not hasattr(self, '_originally_built_as_v1'):
|
||||||
|
raise ValueError(
|
||||||
|
'Your Layer or Model is in an invalid state. This can happen if you '
|
||||||
|
'are interleaving estimator/non-estimator models or '
|
||||||
|
'interleaving models/layers made in tf.compat.v1.Graph.as_default() '
|
||||||
|
'with models/layers created outside of it. '
|
||||||
|
'Converting a model to an estimator (via model_to_estimator) '
|
||||||
|
'invalidates all models/layers made before the conversion (even '
|
||||||
|
'if they were not the model converted to an estimator). '
|
||||||
|
'Similarly, making a layer or a model inside a '
|
||||||
|
'a tf.compat.v1.Graph invalidates all layers/models you previously '
|
||||||
|
'made outside of the graph.')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self):
|
def dtype(self):
|
||||||
return self._dtype_policy.variable_dtype
|
return self._dtype_policy.variable_dtype
|
||||||
|
@ -34,6 +34,7 @@ from tensorflow.python.keras import combinations
|
|||||||
from tensorflow.python.keras import initializers
|
from tensorflow.python.keras import initializers
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras import layers
|
from tensorflow.python.keras import layers
|
||||||
|
from tensorflow.python.keras import losses
|
||||||
from tensorflow.python.keras import models
|
from tensorflow.python.keras import models
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
from tensorflow.python.keras.engine import base_layer
|
from tensorflow.python.keras.engine import base_layer
|
||||||
@ -931,6 +932,72 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
|||||||
# Check that second input was correctly added to first.
|
# Check that second input was correctly added to first.
|
||||||
self.assertEqual(history.history['loss'][0], 0.0)
|
self.assertEqual(history.history['loss'][0], 0.0)
|
||||||
|
|
||||||
|
@combinations.generate(combinations.times(
|
||||||
|
combinations.keras_mode_combinations(mode='eager'),
|
||||||
|
combinations.combine(use_keras_tensors=False)))
|
||||||
|
def test_only_some_in_first_arg_derived_from_keras_layer(self):
|
||||||
|
class MyAddAll(layers.Layer):
|
||||||
|
|
||||||
|
def call(self, inputs):
|
||||||
|
x = inputs[0]
|
||||||
|
for inp in inputs[1:]:
|
||||||
|
if inp is not None:
|
||||||
|
x = x + inp
|
||||||
|
return x
|
||||||
|
|
||||||
|
input1 = input_layer_lib.Input(10)
|
||||||
|
input2 = input_layer_lib.Input(10)
|
||||||
|
layer = MyAddAll()
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'construct a functional'):
|
||||||
|
layer([0.0, input1, None, input2, None])
|
||||||
|
|
||||||
|
@combinations.generate(combinations.times(
|
||||||
|
combinations.keras_mode_combinations(mode='eager'),
|
||||||
|
combinations.combine(use_keras_tensors=True)))
|
||||||
|
def test_only_some_in_first_arg_derived_from_keras_layer_keras_tensors(self):
|
||||||
|
# This functionality is unsupported in v1 graphs
|
||||||
|
|
||||||
|
class MyAddAll(layers.Layer):
|
||||||
|
|
||||||
|
def call(self, inputs):
|
||||||
|
x = inputs[0]
|
||||||
|
for inp in inputs[1:]:
|
||||||
|
if inp is not None:
|
||||||
|
x = x + inp
|
||||||
|
return x
|
||||||
|
|
||||||
|
input1 = input_layer_lib.Input(10)
|
||||||
|
input2 = input_layer_lib.Input(10)
|
||||||
|
layer = MyAddAll()
|
||||||
|
outputs = layer([0.0, input1, None, input2, None])
|
||||||
|
model = training_lib.Model([input1, input2], outputs)
|
||||||
|
self.assertIn(layer, model.layers)
|
||||||
|
model.compile(
|
||||||
|
'sgd',
|
||||||
|
'mse',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
history = model.fit(
|
||||||
|
x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
|
||||||
|
y=10 * np.ones((10, 10)),
|
||||||
|
batch_size=2)
|
||||||
|
# Check that second input was correctly added to first.
|
||||||
|
self.assertEqual(history.history['loss'][0], 0.0)
|
||||||
|
|
||||||
|
# Check serialization.
|
||||||
|
model = training_lib.Model.from_config(
|
||||||
|
model.get_config(), custom_objects={'MyAddAll': MyAddAll})
|
||||||
|
model.compile(
|
||||||
|
'sgd',
|
||||||
|
'mse',
|
||||||
|
run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
history = model.fit(
|
||||||
|
x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
|
||||||
|
y=10 * np.ones((10, 10)),
|
||||||
|
batch_size=2)
|
||||||
|
# Check that second input was correctly added to first.
|
||||||
|
self.assertEqual(history.history['loss'][0], 0.0)
|
||||||
|
|
||||||
@combinations.generate(combinations.keras_mode_combinations())
|
@combinations.generate(combinations.keras_mode_combinations())
|
||||||
def test_call_kwarg_derived_from_keras_layer(self):
|
def test_call_kwarg_derived_from_keras_layer(self):
|
||||||
|
|
||||||
@ -1069,7 +1136,8 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
|||||||
input2 = input_layer_lib.Input(10)
|
input2 = input_layer_lib.Input(10)
|
||||||
input3 = input_layer_lib.Input(10)
|
input3 = input_layer_lib.Input(10)
|
||||||
|
|
||||||
outputs = AddAll()(
|
layer = AddAll()
|
||||||
|
outputs = layer(
|
||||||
[input1, 4 * array_ops.ones((1, 10))],
|
[input1, 4 * array_ops.ones((1, 10))],
|
||||||
x3={
|
x3={
|
||||||
'a': input2,
|
'a': input2,
|
||||||
@ -1077,6 +1145,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
|||||||
'c': 5 * array_ops.ones((1, 10))
|
'c': 5 * array_ops.ones((1, 10))
|
||||||
})
|
})
|
||||||
model = training_lib.Model([input1, input2, input3], outputs)
|
model = training_lib.Model([input1, input2, input3], outputs)
|
||||||
|
self.assertIn(layer, model.layers)
|
||||||
model.compile(
|
model.compile(
|
||||||
'sgd',
|
'sgd',
|
||||||
'mse',
|
'mse',
|
||||||
@ -1833,6 +1902,37 @@ class AddLossTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
self.assertAllClose(model.get_weights(), model2.get_weights())
|
self.assertAllClose(model.get_weights(), model2.get_weights())
|
||||||
|
|
||||||
|
def test_add_loss_crossentropy_backtracking(self):
|
||||||
|
inputs = input_layer_lib.Input((2,))
|
||||||
|
labels = input_layer_lib.Input((1,))
|
||||||
|
outputs = layers.Dense(1, activation='sigmoid')(inputs)
|
||||||
|
model = functional.Functional([inputs, labels], outputs)
|
||||||
|
model.add_loss(losses.binary_crossentropy(labels, outputs))
|
||||||
|
model.compile('adam')
|
||||||
|
x = np.random.random((2, 2))
|
||||||
|
y = np.random.random((2, 1))
|
||||||
|
model.fit([x, y])
|
||||||
|
|
||||||
|
inputs = input_layer_lib.Input((2,))
|
||||||
|
labels = input_layer_lib.Input((2,))
|
||||||
|
outputs = layers.Dense(2, activation='softmax')(inputs)
|
||||||
|
model = functional.Functional([inputs, labels], outputs)
|
||||||
|
model.add_loss(losses.categorical_crossentropy(labels, outputs))
|
||||||
|
model.compile('adam')
|
||||||
|
x = np.random.random((2, 2))
|
||||||
|
y = np.random.random((2, 2))
|
||||||
|
model.fit([x, y])
|
||||||
|
|
||||||
|
inputs = input_layer_lib.Input((2,))
|
||||||
|
labels = input_layer_lib.Input((1,), dtype='int32')
|
||||||
|
outputs = layers.Dense(2, activation='softmax')(inputs)
|
||||||
|
model = functional.Functional([inputs, labels], outputs)
|
||||||
|
model.add_loss(losses.sparse_categorical_crossentropy(labels, outputs))
|
||||||
|
model.compile('adam')
|
||||||
|
x = np.random.random((2, 2))
|
||||||
|
y = np.random.randint(0, 2, size=(2, 1))
|
||||||
|
model.fit([x, y])
|
||||||
|
|
||||||
|
|
||||||
@combinations.generate(combinations.keras_mode_combinations())
|
@combinations.generate(combinations.keras_mode_combinations())
|
||||||
class WeightAccessTest(keras_parameterized.TestCase):
|
class WeightAccessTest(keras_parameterized.TestCase):
|
||||||
@ -2116,13 +2216,13 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
# In v2, construction still works when no `training` is specified
|
# In v2, construction still works when no `training` is specified
|
||||||
# When no value passed during construction, it uses the runtime value.
|
# When no value passed during construction, it uses the local default.
|
||||||
inputs = input_layer_lib.Input(10)
|
inputs = input_layer_lib.Input(10)
|
||||||
outputs = my_layer(inputs)
|
outputs = my_layer(inputs)
|
||||||
network = functional.Functional(inputs, outputs)
|
network = functional.Functional(inputs, outputs)
|
||||||
self.assertAllEqual(network(x, training=True), _call(x, True))
|
self.assertAllEqual(network(x, training=True), _call(x, True))
|
||||||
self.assertAllEqual(network(x, training=False), _call(x, False))
|
self.assertAllEqual(network(x, training=False), _call(x, False))
|
||||||
self.assertAllEqual(network(x), _call(x, False))
|
self.assertAllEqual(network(x), _call(x, True)) # Use local default
|
||||||
|
|
||||||
# `None` value passed positionally during construction is ignored at runtime
|
# `None` value passed positionally during construction is ignored at runtime
|
||||||
inputs = input_layer_lib.Input(10)
|
inputs = input_layer_lib.Input(10)
|
||||||
@ -2131,7 +2231,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
|
|||||||
self.assertAllEqual(network(x, training=True), _call(x, True))
|
self.assertAllEqual(network(x, training=True), _call(x, True))
|
||||||
self.assertAllEqual(network(x, training=False), _call(x, False))
|
self.assertAllEqual(network(x, training=False), _call(x, False))
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
self.assertAllEqual(network(x), _call(x, False))
|
self.assertAllEqual(network(x), _call(x, True)) # Use local default
|
||||||
else:
|
else:
|
||||||
# in v1 training would have defaulted to using the `None` inside the layer
|
# in v1 training would have defaulted to using the `None` inside the layer
|
||||||
# if training is not passed at runtime
|
# if training is not passed at runtime
|
||||||
@ -2144,7 +2244,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
|
|||||||
self.assertAllEqual(network(x, training=True), _call(x, True))
|
self.assertAllEqual(network(x, training=True), _call(x, True))
|
||||||
self.assertAllEqual(network(x, training=False), _call(x, False))
|
self.assertAllEqual(network(x, training=False), _call(x, False))
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
self.assertAllEqual(network(x), _call(x, False))
|
self.assertAllEqual(network(x), _call(x, True)) # Use local default
|
||||||
else:
|
else:
|
||||||
# in v1 training would have defaulted to using the `None` inside the layer
|
# in v1 training would have defaulted to using the `None` inside the layer
|
||||||
# if training is not passed at runtime
|
# if training is not passed at runtime
|
||||||
|
@ -303,6 +303,7 @@ class Model(training_lib.Model):
|
|||||||
ValueError: In case of invalid arguments for
|
ValueError: In case of invalid arguments for
|
||||||
`optimizer`, `loss`, `metrics` or `sample_weight_mode`.
|
`optimizer`, `loss`, `metrics` or `sample_weight_mode`.
|
||||||
"""
|
"""
|
||||||
|
self._assert_built_as_v1()
|
||||||
self._run_eagerly = kwargs.pop('run_eagerly', None)
|
self._run_eagerly = kwargs.pop('run_eagerly', None)
|
||||||
self._experimental_run_tf_function = kwargs.pop(
|
self._experimental_run_tf_function = kwargs.pop(
|
||||||
'experimental_run_tf_function', True)
|
'experimental_run_tf_function', True)
|
||||||
@ -773,6 +774,7 @@ class Model(training_lib.Model):
|
|||||||
ValueError: In case of mismatch between the provided input data
|
ValueError: In case of mismatch between the provided input data
|
||||||
and what the model expects.
|
and what the model expects.
|
||||||
"""
|
"""
|
||||||
|
self._assert_built_as_v1()
|
||||||
_keras_api_gauge.get_cell('fit_v1').set(True)
|
_keras_api_gauge.get_cell('fit_v1').set(True)
|
||||||
# Legacy support
|
# Legacy support
|
||||||
if 'nb_epoch' in kwargs:
|
if 'nb_epoch' in kwargs:
|
||||||
@ -893,6 +895,7 @@ class Model(training_lib.Model):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: in case of invalid arguments.
|
ValueError: in case of invalid arguments.
|
||||||
"""
|
"""
|
||||||
|
self._assert_built_as_v1()
|
||||||
_keras_api_gauge.get_cell('evaluate_v1').set(True)
|
_keras_api_gauge.get_cell('evaluate_v1').set(True)
|
||||||
self._assert_compile_was_called()
|
self._assert_compile_was_called()
|
||||||
self._check_call_args('evaluate')
|
self._check_call_args('evaluate')
|
||||||
@ -972,6 +975,7 @@ class Model(training_lib.Model):
|
|||||||
or in case a stateful model receives a number of samples
|
or in case a stateful model receives a number of samples
|
||||||
that is not a multiple of the batch size.
|
that is not a multiple of the batch size.
|
||||||
"""
|
"""
|
||||||
|
self._assert_built_as_v1()
|
||||||
_keras_api_gauge.get_cell('predict_v1').set(True)
|
_keras_api_gauge.get_cell('predict_v1').set(True)
|
||||||
self._check_call_args('predict')
|
self._check_call_args('predict')
|
||||||
|
|
||||||
|
@ -132,8 +132,7 @@ class Embedding(Layer):
|
|||||||
# right now. Checking for the presence of GPUs to avoid complicating the
|
# right now. Checking for the presence of GPUs to avoid complicating the
|
||||||
# TPU codepaths which can handle sparse optimizers. But if we are within
|
# TPU codepaths which can handle sparse optimizers. But if we are within
|
||||||
# a tf.function, we go back the graph mode logic and rely on the placer.
|
# a tf.function, we go back the graph mode logic and rely on the placer.
|
||||||
if (context.executing_eagerly() and context.context().num_gpus() and
|
if context.executing_eagerly() and context.context().num_gpus():
|
||||||
not ops.inside_function()):
|
|
||||||
with ops.device('cpu:0'):
|
with ops.device('cpu:0'):
|
||||||
self.embeddings = self.add_weight(
|
self.embeddings = self.add_weight(
|
||||||
shape=(self.input_dim, self.output_dim),
|
shape=(self.input_dim, self.output_dim),
|
||||||
|
@ -40,9 +40,10 @@ class ImagePreprocessingDistributionTest(
|
|||||||
preprocessing_test_utils.PreprocessingLayerTest):
|
preprocessing_test_utils.PreprocessingLayerTest):
|
||||||
|
|
||||||
def test_distribution(self, distribution):
|
def test_distribution(self, distribution):
|
||||||
np_images = np.random.random((1000, 32, 32, 3)).astype(np.float32)
|
# TODO(b/159738418): large image input causes OOM in ubuntu multi gpu.
|
||||||
|
np_images = np.random.random((32, 32, 32, 3)).astype(np.float32)
|
||||||
image_dataset = dataset_ops.Dataset.from_tensor_slices(np_images).batch(
|
image_dataset = dataset_ops.Dataset.from_tensor_slices(np_images).batch(
|
||||||
32, drop_remainder=True)
|
16, drop_remainder=True)
|
||||||
|
|
||||||
with distribution.scope():
|
with distribution.scope():
|
||||||
input_data = keras.Input(shape=(32, 32, 3), dtype=dtypes.float32)
|
input_data = keras.Input(shape=(32, 32, 3), dtype=dtypes.float32)
|
||||||
@ -58,7 +59,7 @@ class ImagePreprocessingDistributionTest(
|
|||||||
output = flatten_layer(preprocessed_image)
|
output = flatten_layer(preprocessed_image)
|
||||||
cls_layer = keras.layers.Dense(units=1, activation="sigmoid")
|
cls_layer = keras.layers.Dense(units=1, activation="sigmoid")
|
||||||
output = cls_layer(output)
|
output = cls_layer(output)
|
||||||
model = keras.Model(inputs=input_data, outputs=preprocessed_image)
|
model = keras.Model(inputs=input_data, outputs=output)
|
||||||
model.compile(loss="binary_crossentropy")
|
model.compile(loss="binary_crossentropy")
|
||||||
_ = model.predict(image_dataset)
|
_ = model.predict(image_dataset)
|
||||||
|
|
||||||
|
@ -26,6 +26,7 @@ from tensorflow.python.framework import errors
|
|||||||
from tensorflow.python.framework import test_util as tf_test_util
|
from tensorflow.python.framework import test_util as tf_test_util
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
|
from tensorflow.python.keras.engine import sequential
|
||||||
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
|
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
|
||||||
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
|
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
|
||||||
from tensorflow.python.ops import gen_stateful_random_ops
|
from tensorflow.python.ops import gen_stateful_random_ops
|
||||||
@ -1273,5 +1274,38 @@ class RandomWidthTest(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(layer_1.name, layer.name)
|
self.assertEqual(layer_1.name, layer.name)
|
||||||
|
|
||||||
|
|
||||||
|
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||||
|
class LearningPhaseTest(keras_parameterized.TestCase):
|
||||||
|
|
||||||
|
def test_plain_call(self):
|
||||||
|
layer = image_preprocessing.RandomWidth(.5, seed=123)
|
||||||
|
shape = (12, 12, 3)
|
||||||
|
img = np.random.random((12,) + shape)
|
||||||
|
out = layer(img) # Default to training=True
|
||||||
|
self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||||
|
|
||||||
|
out = layer(img, training=True)
|
||||||
|
self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||||
|
|
||||||
|
out = layer(img, training=False)
|
||||||
|
self.assertEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||||
|
|
||||||
|
def test_call_in_container(self):
|
||||||
|
layer1 = image_preprocessing.RandomWidth(.5, seed=123)
|
||||||
|
layer2 = image_preprocessing.RandomHeight(.5, seed=123)
|
||||||
|
seq = sequential.Sequential([layer1, layer2])
|
||||||
|
|
||||||
|
shape = (12, 12, 3)
|
||||||
|
img = np.random.random((12,) + shape)
|
||||||
|
out = seq(img) # Default to training=True
|
||||||
|
self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||||
|
|
||||||
|
out = seq(img, training=True)
|
||||||
|
self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||||
|
|
||||||
|
out = seq(img, training=False)
|
||||||
|
self.assertEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -440,7 +440,8 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
|
|||||||
if not strategy_supports_loss_scaling():
|
if not strategy_supports_loss_scaling():
|
||||||
strategy = distribution_strategy_context.get_strategy()
|
strategy = distribution_strategy_context.get_strategy()
|
||||||
if isinstance(strategy,
|
if isinstance(strategy,
|
||||||
(tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
|
(tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
|
||||||
|
tpu_strategy.TPUStrategyV2)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Loss scaling is not supported with TPUStrategy. Loss scaling is '
|
'Loss scaling is not supported with TPUStrategy. Loss scaling is '
|
||||||
'unnecessary with TPUs, since they support bfloat16 instead of '
|
'unnecessary with TPUs, since they support bfloat16 instead of '
|
||||||
|
@ -4579,11 +4579,11 @@ def non_max_suppression_padded(boxes,
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: When set pad_to_max_output_size to False for batched input.
|
ValueError: When set pad_to_max_output_size to False for batched input.
|
||||||
"""
|
"""
|
||||||
# if no new arguments are used and no later than 2020/4/20, use the old
|
# if no new arguments are used and no later than 2020/6/23, use the old
|
||||||
# version to give us time to fix TFLite conversion
|
# version to give us time to fix TFLite conversion after the TF 2.3 release.
|
||||||
if (not sorted_input) and \
|
if (not sorted_input) and \
|
||||||
(not canonicalized_coordinates) and \
|
(not canonicalized_coordinates) and \
|
||||||
tile_size == 512 and not compat.forward_compatible(2020, 4, 20):
|
tile_size == 512 and not compat.forward_compatible(2020, 6, 23):
|
||||||
return non_max_suppression_padded_v1(
|
return non_max_suppression_padded_v1(
|
||||||
boxes, scores, max_output_size, iou_threshold, score_threshold,
|
boxes, scores, max_output_size, iou_threshold, score_threshold,
|
||||||
pad_to_max_output_size, name)
|
pad_to_max_output_size, name)
|
||||||
|
@ -1870,25 +1870,27 @@ class MutableHashTable(LookupInterface):
|
|||||||
return {
|
return {
|
||||||
"table":
|
"table":
|
||||||
functools.partial(
|
functools.partial(
|
||||||
MutableHashTable._Saveable, table=self, name=self._name)
|
MutableHashTable._Saveable, table=self, name=self._name,
|
||||||
|
table_name=self._name)
|
||||||
}
|
}
|
||||||
|
|
||||||
class _Saveable(BaseSaverBuilder.SaveableObject):
|
class _Saveable(BaseSaverBuilder.SaveableObject):
|
||||||
"""SaveableObject implementation for MutableHashTable."""
|
"""SaveableObject implementation for DenseHashTable."""
|
||||||
|
|
||||||
def __init__(self, table, name):
|
def __init__(self, table, name, table_name=None):
|
||||||
tensors = table.export()
|
tensors = table.export()
|
||||||
specs = [
|
specs = [
|
||||||
BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
|
BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
|
||||||
BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
|
BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
|
||||||
]
|
]
|
||||||
|
self.table_name = table_name or name
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
super(MutableHashTable._Saveable, self).__init__(table, specs, name)
|
super(MutableHashTable._Saveable, self).__init__(table, specs, name)
|
||||||
|
|
||||||
def restore(self, restored_tensors, restored_shapes, name=None):
|
def restore(self, restored_tensors, restored_shapes):
|
||||||
del restored_shapes # unused
|
del restored_shapes # unused
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
with ops.name_scope(name, "%s_table_restore" % self.name):
|
with ops.name_scope("%s_table_restore" % self.table_name):
|
||||||
with ops.colocate_with(self.op.resource_handle):
|
with ops.colocate_with(self.op.resource_handle):
|
||||||
return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
|
return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
|
||||||
restored_tensors[0],
|
restored_tensors[0],
|
||||||
@ -2166,25 +2168,27 @@ class DenseHashTable(LookupInterface):
|
|||||||
return {
|
return {
|
||||||
"table":
|
"table":
|
||||||
functools.partial(
|
functools.partial(
|
||||||
DenseHashTable._Saveable, table=self, name=self._name)
|
DenseHashTable._Saveable, table=self, name=self._name,
|
||||||
|
table_name=self._name)
|
||||||
}
|
}
|
||||||
|
|
||||||
class _Saveable(BaseSaverBuilder.SaveableObject):
|
class _Saveable(BaseSaverBuilder.SaveableObject):
|
||||||
"""SaveableObject implementation for DenseHashTable."""
|
"""SaveableObject implementation for DenseHashTable."""
|
||||||
|
|
||||||
def __init__(self, table, name):
|
def __init__(self, table, name, table_name=None):
|
||||||
tensors = table.export()
|
tensors = table.export()
|
||||||
specs = [
|
specs = [
|
||||||
BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
|
BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
|
||||||
BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
|
BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
|
||||||
]
|
]
|
||||||
|
self.table_name = table_name or name
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
super(DenseHashTable._Saveable, self).__init__(table, specs, name)
|
super(DenseHashTable._Saveable, self).__init__(table, specs, name)
|
||||||
|
|
||||||
def restore(self, restored_tensors, restored_shapes, name=None):
|
def restore(self, restored_tensors, restored_shapes):
|
||||||
del restored_shapes # unused
|
del restored_shapes # unused
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
with ops.name_scope(name, "%s_table_restore" % self.name):
|
with ops.name_scope("%s_table_restore" % self.table_name):
|
||||||
with ops.colocate_with(self.op.resource_handle):
|
with ops.colocate_with(self.op.resource_handle):
|
||||||
return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
|
return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
|
||||||
restored_tensors[0],
|
restored_tensors[0],
|
||||||
|
@ -45,6 +45,7 @@ from tensorflow.python.saved_model import nested_structure_coder
|
|||||||
from tensorflow.python.saved_model import revived_types
|
from tensorflow.python.saved_model import revived_types
|
||||||
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
||||||
from tensorflow.python.training.saving import checkpoint_options
|
from tensorflow.python.training.saving import checkpoint_options
|
||||||
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
from tensorflow.python.training.tracking import base
|
from tensorflow.python.training.tracking import base
|
||||||
from tensorflow.python.training.tracking import graph_view
|
from tensorflow.python.training.tracking import graph_view
|
||||||
from tensorflow.python.training.tracking import tracking
|
from tensorflow.python.training.tracking import tracking
|
||||||
@ -146,6 +147,18 @@ class Loader(object):
|
|||||||
self._setup_functions_structures()
|
self._setup_functions_structures()
|
||||||
self._setup_functions_captures()
|
self._setup_functions_captures()
|
||||||
|
|
||||||
|
self._create_saveable_object_factories()
|
||||||
|
|
||||||
|
def _create_saveable_object_factories(self):
|
||||||
|
for node_id, proto in enumerate(self._proto.nodes):
|
||||||
|
node = self.get(node_id)
|
||||||
|
node._self_saveable_object_factories = {} # pylint: disable=protected-access
|
||||||
|
for name, saveable_object_proto in proto.saveable_objects.items():
|
||||||
|
node._self_saveable_object_factories[name] = ( # pylint: disable=protected-access
|
||||||
|
saveable_object_util.restored_saved_object_factory(
|
||||||
|
self.get(saveable_object_proto.save_function),
|
||||||
|
self.get(saveable_object_proto.restore_function)))
|
||||||
|
|
||||||
def _load_edges(self):
|
def _load_edges(self):
|
||||||
"""Adds edges from objects to other objects and functions."""
|
"""Adds edges from objects to other objects and functions."""
|
||||||
for node_id, object_proto in enumerate(self._proto.nodes):
|
for node_id, object_proto in enumerate(self._proto.nodes):
|
||||||
|
@ -1795,6 +1795,22 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
|||||||
options = load_options.LoadOptions(experimental_io_device="/job:localhost")
|
options = load_options.LoadOptions(experimental_io_device="/job:localhost")
|
||||||
self.assertEqual("/job:localhost", options.experimental_io_device)
|
self.assertEqual("/job:localhost", options.experimental_io_device)
|
||||||
|
|
||||||
|
def test_load_custom_saveable_object(self, cycles):
|
||||||
|
root = tracking.AutoTrackable()
|
||||||
|
root.table = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32, -1)
|
||||||
|
root.table.insert("foo", 15)
|
||||||
|
|
||||||
|
@def_function.function(
|
||||||
|
input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
|
||||||
|
def lookup(key):
|
||||||
|
return root.table.lookup(key)
|
||||||
|
|
||||||
|
root.lookup = lookup
|
||||||
|
|
||||||
|
imported = cycle(root, cycles)
|
||||||
|
self.assertEqual(self.evaluate(imported.lookup("foo")), 15)
|
||||||
|
self.assertEqual(self.evaluate(imported.lookup("idk")), -1)
|
||||||
|
|
||||||
|
|
||||||
class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import functools
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from tensorflow.core.framework import versions_pb2
|
from tensorflow.core.framework import versions_pb2
|
||||||
@ -53,6 +54,7 @@ from tensorflow.python.saved_model import tag_constants
|
|||||||
from tensorflow.python.saved_model import utils_impl
|
from tensorflow.python.saved_model import utils_impl
|
||||||
from tensorflow.python.training.saving import checkpoint_options
|
from tensorflow.python.training.saving import checkpoint_options
|
||||||
from tensorflow.python.training.saving import functional_saver
|
from tensorflow.python.training.saving import functional_saver
|
||||||
|
from tensorflow.python.training.saving import saveable_object_util
|
||||||
from tensorflow.python.training.tracking import base
|
from tensorflow.python.training.tracking import base
|
||||||
from tensorflow.python.training.tracking import graph_view
|
from tensorflow.python.training.tracking import graph_view
|
||||||
from tensorflow.python.training.tracking import tracking
|
from tensorflow.python.training.tracking import tracking
|
||||||
@ -136,12 +138,15 @@ class _AugmentedGraphView(graph_view.ObjectGraphView):
|
|||||||
return obj._list_extra_dependencies_for_serialization( # pylint: disable=protected-access
|
return obj._list_extra_dependencies_for_serialization( # pylint: disable=protected-access
|
||||||
self._serialization_cache)
|
self._serialization_cache)
|
||||||
|
|
||||||
def list_functions(self, obj):
|
def list_functions(self, obj, extra_functions=None):
|
||||||
obj_functions = self._functions.get(obj, None)
|
obj_functions = self._functions.get(obj, None)
|
||||||
if obj_functions is None:
|
if obj_functions is None:
|
||||||
obj_functions = obj._list_functions_for_serialization( # pylint: disable=protected-access
|
obj_functions = obj._list_functions_for_serialization( # pylint: disable=protected-access
|
||||||
self._serialization_cache)
|
self._serialization_cache)
|
||||||
self._functions[obj] = obj_functions
|
self._functions[obj] = obj_functions
|
||||||
|
if extra_functions:
|
||||||
|
obj_functions = obj_functions.copy()
|
||||||
|
obj_functions.update(extra_functions)
|
||||||
return obj_functions
|
return obj_functions
|
||||||
|
|
||||||
|
|
||||||
@ -177,6 +182,12 @@ class _SaveableView(object):
|
|||||||
self.slot_variables = slot_variables
|
self.slot_variables = slot_variables
|
||||||
self.concrete_functions = []
|
self.concrete_functions = []
|
||||||
|
|
||||||
|
self.saveable_objects_for_node, all_saveable_functions = (
|
||||||
|
self._add_saveable_objects())
|
||||||
|
saveable_object_functions = {
|
||||||
|
"__SAVEABLE_FUNCTION_{}".format(n): fn
|
||||||
|
for n, fn in enumerate(all_saveable_functions)}
|
||||||
|
|
||||||
# Maps functions -> wrapped functions that capture variables
|
# Maps functions -> wrapped functions that capture variables
|
||||||
self.wrapped_functions = wrapped_functions or {}
|
self.wrapped_functions = wrapped_functions or {}
|
||||||
# Maps names of concrete functions in the object to names of wrapped
|
# Maps names of concrete functions in the object to names of wrapped
|
||||||
@ -190,7 +201,8 @@ class _SaveableView(object):
|
|||||||
nodes_without_functions = list(self.nodes)
|
nodes_without_functions = list(self.nodes)
|
||||||
seen_function_names = set()
|
seen_function_names = set()
|
||||||
for node in nodes_without_functions:
|
for node in nodes_without_functions:
|
||||||
for function in checkpoint_view.list_functions(node).values():
|
for function in checkpoint_view.list_functions(
|
||||||
|
node, saveable_object_functions).values():
|
||||||
if function not in self.node_ids:
|
if function not in self.node_ids:
|
||||||
self.node_ids[function] = len(self.nodes)
|
self.node_ids[function] = len(self.nodes)
|
||||||
self.nodes.append(function)
|
self.nodes.append(function)
|
||||||
@ -209,6 +221,25 @@ class _SaveableView(object):
|
|||||||
seen_function_names.add(concrete_function.name)
|
seen_function_names.add(concrete_function.name)
|
||||||
self.concrete_functions.append(concrete_function)
|
self.concrete_functions.append(concrete_function)
|
||||||
|
|
||||||
|
def _add_saveable_objects(self):
|
||||||
|
"""Retrieves SaveablesObjects and traces their save/restore functions."""
|
||||||
|
# Maps node -> local name -> (save function, restore function)
|
||||||
|
saveable_objects_map = object_identity.ObjectIdentityDictionary()
|
||||||
|
all_saveable_functions = []
|
||||||
|
for node in self.nodes:
|
||||||
|
if resource_variable_ops.is_resource_variable(node):
|
||||||
|
# Resource (and TPU/Mirrored) variables are automatically revived with
|
||||||
|
# their saveables defined, so there is no need to trace the save
|
||||||
|
# and restore functions.
|
||||||
|
continue
|
||||||
|
saveable_map = saveable_object_util.trace_save_restore_functions(node)
|
||||||
|
if saveable_map:
|
||||||
|
saveable_objects_map[node] = saveable_map
|
||||||
|
for save_fn, restore_fn in saveable_map.values():
|
||||||
|
all_saveable_functions.append(save_fn)
|
||||||
|
all_saveable_functions.append(restore_fn)
|
||||||
|
return saveable_objects_map, all_saveable_functions
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def root(self):
|
def root(self):
|
||||||
return self.nodes[0]
|
return self.nodes[0]
|
||||||
@ -233,6 +264,15 @@ class _SaveableView(object):
|
|||||||
child_proto.node_id = self.node_ids[ref_function]
|
child_proto.node_id = self.node_ids[ref_function]
|
||||||
child_proto.local_name = local_name
|
child_proto.local_name = local_name
|
||||||
|
|
||||||
|
if node not in self.saveable_objects_for_node:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for local_name, (save_fn, restore_fn) in (
|
||||||
|
self.saveable_objects_for_node[node].items()):
|
||||||
|
saveable_object_proto = object_proto.saveable_objects[local_name]
|
||||||
|
saveable_object_proto.save_function = self.node_ids[save_fn]
|
||||||
|
saveable_object_proto.restore_function = self.node_ids[restore_fn]
|
||||||
|
|
||||||
def map_resources(self):
|
def map_resources(self):
|
||||||
"""Makes new resource handle ops corresponding to existing resource tensors.
|
"""Makes new resource handle ops corresponding to existing resource tensors.
|
||||||
|
|
||||||
@ -605,7 +645,9 @@ def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions,
|
|||||||
# the exported graph (thus the `to_graph` argument).
|
# the exported graph (thus the `to_graph` argument).
|
||||||
saver = functional_saver.MultiDeviceSaver(
|
saver = functional_saver.MultiDeviceSaver(
|
||||||
saveable_view.checkpoint_view.frozen_saveable_objects(
|
saveable_view.checkpoint_view.frozen_saveable_objects(
|
||||||
object_map=object_map, to_graph=exported_graph))
|
object_map=object_map, to_graph=exported_graph,
|
||||||
|
call_with_mapped_captures=functools.partial(
|
||||||
|
_call_function_with_mapped_captures, resource_map=resource_map)))
|
||||||
|
|
||||||
with exported_graph.as_default():
|
with exported_graph.as_default():
|
||||||
signatures = _generate_signatures(signature_functions, resource_map)
|
signatures = _generate_signatures(signature_functions, resource_map)
|
||||||
|
@ -1169,7 +1169,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
|||||||
|
|
||||||
PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
|
PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
|
||||||
PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
|
PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
|
||||||
return py::handle(EagerTensorFromHandle(thandle));
|
|
||||||
|
PyObject* pyhandle = EagerTensorFromHandle(thandle);
|
||||||
|
return tensorflow::PyoOrThrow(pyhandle);
|
||||||
});
|
});
|
||||||
|
|
||||||
m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
|
m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
|
||||||
|
@ -265,7 +265,8 @@ class TPUEmbedding(tracking.AutoTrackable):
|
|||||||
Adam or Adagrad).
|
Adam or Adagrad).
|
||||||
"""
|
"""
|
||||||
self._strategy = distribution_strategy_context.get_strategy()
|
self._strategy = distribution_strategy_context.get_strategy()
|
||||||
self._using_tpu = isinstance(self._strategy, tpu_strategy.TPUStrategy)
|
self._using_tpu = isinstance(self._strategy, (tpu_strategy.TPUStrategy,
|
||||||
|
tpu_strategy.TPUStrategyV2))
|
||||||
self._pipeline_execution_with_tensor_core = (
|
self._pipeline_execution_with_tensor_core = (
|
||||||
pipeline_execution_with_tensor_core)
|
pipeline_execution_with_tensor_core)
|
||||||
|
|
||||||
|
@ -17,15 +17,26 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import functools
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import device as pydev
|
from tensorflow.python.framework import device as pydev
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_spec
|
||||||
|
from tensorflow.python.framework import tensor_util
|
||||||
|
from tensorflow.python.framework import type_spec
|
||||||
|
|
||||||
|
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training.saving import saveable_object
|
from tensorflow.python.training.saving import saveable_object
|
||||||
from tensorflow.python.training.tracking import base as trackable
|
from tensorflow.python.training.tracking import base as trackable
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
@ -279,7 +290,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
("Two different ResourceVariable objects with the same "
|
("Two different ResourceVariable objects with the same "
|
||||||
"shared_name '%s' were passed to the Saver. This likely means "
|
"shared_name '%s' were passed to the Saver. This likely means "
|
||||||
"that they were created in different Graphs or isolation "
|
"that they were created in different Graphs or isoWlation "
|
||||||
"contexts, and may not be checkpointed together.") %
|
"contexts, and may not be checkpointed together.") %
|
||||||
(var._shared_name,))
|
(var._shared_name,))
|
||||||
else:
|
else:
|
||||||
@ -349,3 +360,147 @@ def validate_and_slice_inputs(names_to_saveables):
|
|||||||
for converted_saveable_object in saveable_objects_for_op(op, name):
|
for converted_saveable_object in saveable_objects_for_op(op, name):
|
||||||
_add_saveable(saveables, seen_ops, converted_saveable_object)
|
_add_saveable(saveables, seen_ops, converted_saveable_object)
|
||||||
return saveables
|
return saveables
|
||||||
|
|
||||||
|
|
||||||
|
def trace_save_restore_functions(object_to_save):
|
||||||
|
"""Gathers all SaveableObjects and traces the save and restore ops."""
|
||||||
|
saveable_map = {} # Maps name -> (save function, restore function)
|
||||||
|
for name, saveable_factory in (
|
||||||
|
object_to_save._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access
|
||||||
|
if not callable(saveable_factory):
|
||||||
|
if isinstance(saveable_factory, saveable_object.SaveableObject):
|
||||||
|
logging.debug(
|
||||||
|
"Trackable {} should return callable factories, not SaveableObjects"
|
||||||
|
" in `_gather_saveables_for_checkpoint`. This could lead to "
|
||||||
|
"problems loading the SavedModel back into Python."
|
||||||
|
.format(object_to_save))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if is_factory_for_restored_saveable_object(saveable_factory):
|
||||||
|
saveable_map[name] = (saveable_factory.keywords["save_function"],
|
||||||
|
saveable_factory.keywords["restore_function"])
|
||||||
|
else:
|
||||||
|
concrete_save_fn, concrete_restore_fn = _trace_save_and_restore_function(
|
||||||
|
saveable_factory, object_to_save)
|
||||||
|
if concrete_save_fn is not None:
|
||||||
|
saveable_map[name] = (concrete_save_fn, concrete_restore_fn)
|
||||||
|
return saveable_map
|
||||||
|
|
||||||
|
|
||||||
|
def _trace_save_and_restore_function(saveable_factory, object_to_save):
|
||||||
|
"""Traces the save and restore concrete functions."""
|
||||||
|
saveables = []
|
||||||
|
|
||||||
|
@def_function.function(
|
||||||
|
input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
|
||||||
|
def save_fn(checkpoint_key):
|
||||||
|
maybe_saveable = saveable_factory(name=checkpoint_key)
|
||||||
|
if isinstance(maybe_saveable, saveable_object.SaveableObject):
|
||||||
|
maybe_saveable = [maybe_saveable]
|
||||||
|
saveables[:] = maybe_saveable
|
||||||
|
|
||||||
|
# Return list of all SaveSpecs created by the factory.
|
||||||
|
ret = []
|
||||||
|
for saveable in saveables:
|
||||||
|
for spec in saveable.specs:
|
||||||
|
ret.append({"name": spec.name, "tensor": spec.tensor,
|
||||||
|
"slice_spec": spec.slice_spec})
|
||||||
|
return ret
|
||||||
|
|
||||||
|
concrete_save_fn = save_fn.get_concrete_function()
|
||||||
|
if any(isinstance(saveable, trackable.PythonStateSaveable)
|
||||||
|
for saveable in saveables):
|
||||||
|
logging.warn(
|
||||||
|
"Note that object {} stores python values into the checkpoint. "
|
||||||
|
"These values will not be restored when loading the SavedModel "
|
||||||
|
"into python.".format(object_to_save))
|
||||||
|
return None, None
|
||||||
|
if any(isinstance(saveable, trackable.NoRestoreSaveable)
|
||||||
|
for saveable in saveables):
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
restored_type_specs = []
|
||||||
|
tensor_structure = []
|
||||||
|
for saveable in saveables:
|
||||||
|
saveable_tensor_structure = []
|
||||||
|
tensor_structure.append(saveable_tensor_structure)
|
||||||
|
for spec in saveable.specs:
|
||||||
|
restored_type_specs.append(type_spec.type_spec_from_value(spec.tensor))
|
||||||
|
saveable_tensor_structure.append(spec.name)
|
||||||
|
|
||||||
|
@def_function.function(input_signature=restored_type_specs)
|
||||||
|
def restore_fn(*restored_tensors):
|
||||||
|
structured_restored_tensors = nest.pack_sequence_as(
|
||||||
|
tensor_structure, restored_tensors)
|
||||||
|
for saveable, restored_tensors in zip(saveables,
|
||||||
|
structured_restored_tensors):
|
||||||
|
saveable.restore(restored_tensors, restored_shapes=None)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
concrete_restore_fn = restore_fn.get_concrete_function()
|
||||||
|
return concrete_save_fn, concrete_restore_fn
|
||||||
|
|
||||||
|
|
||||||
|
class RestoredSaveableObject(saveable_object.SaveableObject):
|
||||||
|
"""SaveableObject restored from SavedModel using the traced save/restore."""
|
||||||
|
|
||||||
|
def __init__(self, save_function, restore_function, name):
|
||||||
|
self.save_function = save_function
|
||||||
|
self.restore_function = restore_function
|
||||||
|
|
||||||
|
if tensor_util.is_tensor(name):
|
||||||
|
name_tensor = name
|
||||||
|
else:
|
||||||
|
with ops.init_scope():
|
||||||
|
name_tensor = constant_op.constant(name)
|
||||||
|
tensors = save_function(name_tensor)
|
||||||
|
specs = [saveable_object.SaveSpec(x["tensor"], x["slice_spec"], x["name"])
|
||||||
|
for x in tensors]
|
||||||
|
super(RestoredSaveableObject, self).__init__(None, specs, name)
|
||||||
|
|
||||||
|
def restore(self, restored_tensors, restored_shapes):
|
||||||
|
del restored_shapes # unused
|
||||||
|
return self.restore_function(
|
||||||
|
*[restored_tensors[i] for i in range(len(self.specs))])
|
||||||
|
|
||||||
|
|
||||||
|
def restored_saved_object_factory(save_function, restore_function):
|
||||||
|
return functools.partial(RestoredSaveableObject,
|
||||||
|
save_function=save_function,
|
||||||
|
restore_function=restore_function)
|
||||||
|
|
||||||
|
|
||||||
|
def create_saveable_object(factory, name, call_with_mapped_captures):
|
||||||
|
"""Creates a SaveableObject while potentially in a different graph.
|
||||||
|
|
||||||
|
When creating the frozen saver for SavedModel, the save and restore ops are
|
||||||
|
placed in a separate graph. Since RestoredSaveableObject uses tf.functions to
|
||||||
|
save and restore, the function captures must be mapped to the new graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factory: Factory method for creating the SaveableObject.
|
||||||
|
name: Checkpoint key of this SaveableObject.
|
||||||
|
call_with_mapped_captures: Helper that calls a tf.function while remapping
|
||||||
|
the captures.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a SaveableObject.
|
||||||
|
"""
|
||||||
|
if (call_with_mapped_captures is None or
|
||||||
|
not is_factory_for_restored_saveable_object(factory)):
|
||||||
|
return factory(name=name)
|
||||||
|
|
||||||
|
concrete_save_fn = factory.keywords["save_function"]
|
||||||
|
def save_fn(name):
|
||||||
|
return call_with_mapped_captures(concrete_save_fn, [name])
|
||||||
|
|
||||||
|
concrete_restore_fn = factory.keywords["restore_function"]
|
||||||
|
def restore_fn(*restored_tensors):
|
||||||
|
return call_with_mapped_captures(concrete_restore_fn, restored_tensors)
|
||||||
|
|
||||||
|
return factory(save_function=save_fn, restore_function=restore_fn, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
def is_factory_for_restored_saveable_object(factory):
|
||||||
|
return (isinstance(factory, functools.partial) and
|
||||||
|
factory.func is RestoredSaveableObject)
|
||||||
|
@ -293,9 +293,10 @@ class CheckpointPosition(object):
|
|||||||
checkpoint_key = serialized_tensor.checkpoint_key
|
checkpoint_key = serialized_tensor.checkpoint_key
|
||||||
dtype = self._checkpoint.dtype_map[checkpoint_key]
|
dtype = self._checkpoint.dtype_map[checkpoint_key]
|
||||||
base_type = dtype.base_dtype
|
base_type = dtype.base_dtype
|
||||||
|
io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
|
||||||
with ops.init_scope():
|
with ops.init_scope():
|
||||||
with ops.device("/cpu:0"):
|
with ops.device(io_device):
|
||||||
# Run the restore itself on the CPU.
|
# Run the restore itself on the io_device(CPU or specified).
|
||||||
value, = io_ops.restore_v2(
|
value, = io_ops.restore_v2(
|
||||||
prefix=self._checkpoint.save_path_tensor,
|
prefix=self._checkpoint.save_path_tensor,
|
||||||
tensor_names=[checkpoint_key],
|
tensor_names=[checkpoint_key],
|
||||||
@ -611,6 +612,12 @@ class Trackable(object):
|
|||||||
# building.
|
# building.
|
||||||
self._self_name_based_restores = set()
|
self._self_name_based_restores = set()
|
||||||
|
|
||||||
|
# Dictionary of SaveableObjects factories. This dictionary is defined when
|
||||||
|
# the object is loaded from the SavedModel. When writing a custom class,
|
||||||
|
# prefer overriding "_gather_saveables_from_checkpoint" to using this
|
||||||
|
# attribute.
|
||||||
|
self._self_saveable_object_factories = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _object_identifier(self):
|
def _object_identifier(self):
|
||||||
"""String used to identify this object in a SavedModel.
|
"""String used to identify this object in a SavedModel.
|
||||||
@ -972,7 +979,7 @@ class Trackable(object):
|
|||||||
lambda name="global_name_for_this_object":
|
lambda name="global_name_for_this_object":
|
||||||
SaveableObject(name=name, ...)}
|
SaveableObject(name=name, ...)}
|
||||||
"""
|
"""
|
||||||
return {}
|
return self._self_saveable_object_factories
|
||||||
|
|
||||||
def _list_extra_dependencies_for_serialization(self, serialization_cache):
|
def _list_extra_dependencies_for_serialization(self, serialization_cache):
|
||||||
"""Lists extra dependencies to serialize.
|
"""Lists extra dependencies to serialize.
|
||||||
|
@ -208,7 +208,7 @@ class ObjectGraphView(object):
|
|||||||
|
|
||||||
def _add_attributes_to_object_graph(
|
def _add_attributes_to_object_graph(
|
||||||
self, trackable_objects, object_graph_proto, node_ids, object_names,
|
self, trackable_objects, object_graph_proto, node_ids, object_names,
|
||||||
object_map):
|
object_map, call_with_mapped_captures):
|
||||||
"""Create SaveableObjects and corresponding SerializedTensor protos."""
|
"""Create SaveableObjects and corresponding SerializedTensor protos."""
|
||||||
named_saveable_objects = []
|
named_saveable_objects = []
|
||||||
if self._saveables_cache is None:
|
if self._saveables_cache is None:
|
||||||
@ -253,7 +253,9 @@ class ObjectGraphView(object):
|
|||||||
break
|
break
|
||||||
if saveables is None:
|
if saveables is None:
|
||||||
if callable(saveable_factory):
|
if callable(saveable_factory):
|
||||||
maybe_saveable = saveable_factory(name=attribute.checkpoint_key)
|
maybe_saveable = saveable_object_util.create_saveable_object(
|
||||||
|
saveable_factory, attribute.checkpoint_key,
|
||||||
|
call_with_mapped_captures)
|
||||||
else:
|
else:
|
||||||
maybe_saveable = saveable_factory
|
maybe_saveable = saveable_factory
|
||||||
if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
|
if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
|
||||||
@ -332,7 +334,8 @@ class ObjectGraphView(object):
|
|||||||
return object_graph_proto
|
return object_graph_proto
|
||||||
|
|
||||||
def _serialize_gathered_objects(self, trackable_objects, path_to_root,
|
def _serialize_gathered_objects(self, trackable_objects, path_to_root,
|
||||||
object_map=None):
|
object_map=None,
|
||||||
|
call_with_mapped_captures=None):
|
||||||
"""Create SaveableObjects and protos for gathered objects."""
|
"""Create SaveableObjects and protos for gathered objects."""
|
||||||
object_names = object_identity.ObjectIdentityDictionary()
|
object_names = object_identity.ObjectIdentityDictionary()
|
||||||
for obj, path in path_to_root.items():
|
for obj, path in path_to_root.items():
|
||||||
@ -354,7 +357,8 @@ class ObjectGraphView(object):
|
|||||||
object_graph_proto=object_graph_proto,
|
object_graph_proto=object_graph_proto,
|
||||||
node_ids=node_ids,
|
node_ids=node_ids,
|
||||||
object_names=object_names,
|
object_names=object_names,
|
||||||
object_map=object_map))
|
object_map=object_map,
|
||||||
|
call_with_mapped_captures=call_with_mapped_captures))
|
||||||
return named_saveable_objects, object_graph_proto, feed_additions
|
return named_saveable_objects, object_graph_proto, feed_additions
|
||||||
|
|
||||||
def serialize_object_graph(self):
|
def serialize_object_graph(self):
|
||||||
@ -382,7 +386,8 @@ class ObjectGraphView(object):
|
|||||||
return self._serialize_gathered_objects(
|
return self._serialize_gathered_objects(
|
||||||
trackable_objects, path_to_root)
|
trackable_objects, path_to_root)
|
||||||
|
|
||||||
def frozen_saveable_objects(self, object_map=None, to_graph=None):
|
def frozen_saveable_objects(self, object_map=None, to_graph=None,
|
||||||
|
call_with_mapped_captures=None):
|
||||||
"""Creates SaveableObjects with the current object graph frozen."""
|
"""Creates SaveableObjects with the current object graph frozen."""
|
||||||
trackable_objects, path_to_root = self._breadth_first_traversal()
|
trackable_objects, path_to_root = self._breadth_first_traversal()
|
||||||
if to_graph:
|
if to_graph:
|
||||||
@ -393,7 +398,8 @@ class ObjectGraphView(object):
|
|||||||
named_saveable_objects, graph_proto, _ = self._serialize_gathered_objects(
|
named_saveable_objects, graph_proto, _ = self._serialize_gathered_objects(
|
||||||
trackable_objects,
|
trackable_objects,
|
||||||
path_to_root,
|
path_to_root,
|
||||||
object_map)
|
object_map,
|
||||||
|
call_with_mapped_captures)
|
||||||
with ops.device("/cpu:0"):
|
with ops.device("/cpu:0"):
|
||||||
object_graph_tensor = constant_op.constant(
|
object_graph_tensor = constant_op.constant(
|
||||||
graph_proto.SerializeToString(), dtype=dtypes.string)
|
graph_proto.SerializeToString(), dtype=dtypes.string)
|
||||||
|
@ -59,7 +59,7 @@ load(
|
|||||||
# not contain rc or alpha, only numbers.
|
# not contain rc or alpha, only numbers.
|
||||||
# Also update tensorflow/core/public/version.h
|
# Also update tensorflow/core/public/version.h
|
||||||
# and tensorflow/tools/pip_package/setup.py
|
# and tensorflow/tools/pip_package/setup.py
|
||||||
VERSION = "2.2.0"
|
VERSION = "2.3.0"
|
||||||
VERSION_MAJOR = VERSION.split(".")[0]
|
VERSION_MAJOR = VERSION.split(".")[0]
|
||||||
|
|
||||||
# Sanitize a dependency so that it works correctly from code that includes
|
# Sanitize a dependency so that it works correctly from code that includes
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
path: "tensorflow.data.experimental.service.MasterServer"
|
path: "tensorflow.data.experimental.service.DispatchServer"
|
||||||
tf_class {
|
tf_class {
|
||||||
is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.MasterServer\'>"
|
is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.DispatchServer\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member {
|
member {
|
||||||
name: "target"
|
name: "target"
|
@ -4,7 +4,7 @@ tf_class {
|
|||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'port\', \'master_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
|
argspec: "args=[\'self\', \'port\', \'dispatcher_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "join"
|
name: "join"
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
path: "tensorflow.data.experimental.service"
|
path: "tensorflow.data.experimental.service"
|
||||||
tf_module {
|
tf_module {
|
||||||
member {
|
member {
|
||||||
name: "MasterServer"
|
name: "DispatchServer"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
member {
|
member {
|
||||||
|
@ -17,4 +17,4 @@ SET PYTHON_DIRECTORY=Python35
|
|||||||
|
|
||||||
CALL tensorflow\tools\ci_build\release\common_win.bat
|
CALL tensorflow\tools\ci_build\release\common_win.bat
|
||||||
|
|
||||||
call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_build_flags "--config=v2" --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu"
|
call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_build_flags "--config=v2 --define=no_tensorflow_py_deps=true" --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu"
|
||||||
|
@ -17,4 +17,4 @@ SET PYTHON_DIRECTORY=Python36
|
|||||||
|
|
||||||
CALL tensorflow\tools\ci_build\release\common_win.bat
|
CALL tensorflow\tools\ci_build\release\common_win.bat
|
||||||
|
|
||||||
call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_build_flags "--config=v2" --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu"
|
call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_build_flags "--config=v2 --define=no_tensorflow_py_deps=true" --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu"
|
||||||
|
@ -17,4 +17,4 @@ SET PYTHON_DIRECTORY=Python37
|
|||||||
|
|
||||||
CALL tensorflow\tools\ci_build\release\common_win.bat
|
CALL tensorflow\tools\ci_build\release\common_win.bat
|
||||||
|
|
||||||
call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_build_flags "--config=v2" --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu"
|
call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_build_flags "--config=v2 --define=no_tensorflow_py_deps=true" --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu"
|
||||||
|
@ -17,4 +17,5 @@ SET PYTHON_DIRECTORY=Python38
|
|||||||
|
|
||||||
CALL tensorflow\tools\ci_build\release\common_win.bat
|
CALL tensorflow\tools\ci_build\release\common_win.bat
|
||||||
|
|
||||||
call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_build_flags "--config=v2" --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu"
|
call tensorflow\tools\ci_build\windows\cpu\pip\run.bat --release_build --extra_build_flags "--config=v2 --define=no_tensorflow_py_deps=true" --extra_test_flags "--test_env=TF2_BEHAVIOR=1" --project_name "tensorflow_cpu"
|
||||||
|
|
||||||
|
@ -99,8 +99,8 @@ tensorflow::data::GrpcDataServerBase::Join
|
|||||||
tensorflow::data::GrpcDataServerBase::Start
|
tensorflow::data::GrpcDataServerBase::Start
|
||||||
tensorflow::data::GrpcDataServerBase::Stop
|
tensorflow::data::GrpcDataServerBase::Stop
|
||||||
tensorflow::data::GrpcDataServerBase::BoundPort
|
tensorflow::data::GrpcDataServerBase::BoundPort
|
||||||
tensorflow::data::MasterGrpcDataServer::NumWorkers
|
tensorflow::data::DispatchGrpcDataServer::NumWorkers
|
||||||
tensorflow::data::NewMasterServer
|
tensorflow::data::NewDispatchServer
|
||||||
tensorflow::data::NewWorkerServer
|
tensorflow::data::NewWorkerServer
|
||||||
|
|
||||||
[protos_all] # device_lib, dtypes
|
[protos_all] # device_lib, dtypes
|
||||||
|
@ -84,7 +84,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
keras_preprocessing \
|
keras_preprocessing \
|
||||||
matplotlib \
|
matplotlib \
|
||||||
mock \
|
mock \
|
||||||
numpy \
|
'numpy<1.19.0' \
|
||||||
scipy \
|
scipy \
|
||||||
sklearn \
|
sklearn \
|
||||||
pandas \
|
pandas \
|
||||||
|
@ -84,7 +84,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
keras_preprocessing \
|
keras_preprocessing \
|
||||||
matplotlib \
|
matplotlib \
|
||||||
mock \
|
mock \
|
||||||
numpy \
|
'numpy<1.19.0' \
|
||||||
scipy \
|
scipy \
|
||||||
sklearn \
|
sklearn \
|
||||||
pandas \
|
pandas \
|
||||||
|
@ -126,7 +126,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
keras_preprocessing \
|
keras_preprocessing \
|
||||||
matplotlib \
|
matplotlib \
|
||||||
mock \
|
mock \
|
||||||
numpy \
|
'numpy<1.19.0' \
|
||||||
scipy \
|
scipy \
|
||||||
sklearn \
|
sklearn \
|
||||||
pandas \
|
pandas \
|
||||||
|
@ -126,7 +126,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
keras_preprocessing \
|
keras_preprocessing \
|
||||||
matplotlib \
|
matplotlib \
|
||||||
mock \
|
mock \
|
||||||
numpy \
|
'numpy<1.19.0' \
|
||||||
scipy \
|
scipy \
|
||||||
sklearn \
|
sklearn \
|
||||||
pandas \
|
pandas \
|
||||||
|
@ -84,7 +84,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
keras_preprocessing \
|
keras_preprocessing \
|
||||||
matplotlib \
|
matplotlib \
|
||||||
mock \
|
mock \
|
||||||
numpy \
|
'numpy<1.19.0' \
|
||||||
scipy \
|
scipy \
|
||||||
sklearn \
|
sklearn \
|
||||||
pandas \
|
pandas \
|
||||||
@ -93,7 +93,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
enum34
|
enum34
|
||||||
|
|
||||||
# Install bazel
|
# Install bazel
|
||||||
ARG BAZEL_VERSION=3.0.0
|
ARG BAZEL_VERSION=3.1.0
|
||||||
RUN mkdir /bazel && \
|
RUN mkdir /bazel && \
|
||||||
wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
|
wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
|
||||||
wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \
|
wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \
|
||||||
|
@ -84,7 +84,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
keras_preprocessing \
|
keras_preprocessing \
|
||||||
matplotlib \
|
matplotlib \
|
||||||
mock \
|
mock \
|
||||||
numpy \
|
'numpy<1.19.0' \
|
||||||
scipy \
|
scipy \
|
||||||
sklearn \
|
sklearn \
|
||||||
pandas \
|
pandas \
|
||||||
@ -93,7 +93,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
enum34
|
enum34
|
||||||
|
|
||||||
# Install bazel
|
# Install bazel
|
||||||
ARG BAZEL_VERSION=3.0.0
|
ARG BAZEL_VERSION=3.1.0
|
||||||
RUN mkdir /bazel && \
|
RUN mkdir /bazel && \
|
||||||
wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
|
wget -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
|
||||||
wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \
|
wget -O /bazel/LICENSE.txt "https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE" && \
|
||||||
|
@ -83,7 +83,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
keras_preprocessing \
|
keras_preprocessing \
|
||||||
matplotlib \
|
matplotlib \
|
||||||
mock \
|
mock \
|
||||||
numpy \
|
'numpy<1.19.0' \
|
||||||
scipy \
|
scipy \
|
||||||
sklearn \
|
sklearn \
|
||||||
pandas \
|
pandas \
|
||||||
@ -91,7 +91,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
enum34
|
enum34
|
||||||
|
|
||||||
# Build and install bazel
|
# Build and install bazel
|
||||||
ENV BAZEL_VERSION 3.0.0
|
ENV BAZEL_VERSION 3.1.0
|
||||||
WORKDIR /
|
WORKDIR /
|
||||||
RUN mkdir /bazel && \
|
RUN mkdir /bazel && \
|
||||||
cd /bazel && \
|
cd /bazel && \
|
||||||
|
@ -83,7 +83,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
keras_preprocessing \
|
keras_preprocessing \
|
||||||
matplotlib \
|
matplotlib \
|
||||||
mock \
|
mock \
|
||||||
numpy \
|
'numpy<1.19.0' \
|
||||||
scipy \
|
scipy \
|
||||||
sklearn \
|
sklearn \
|
||||||
pandas \
|
pandas \
|
||||||
@ -91,7 +91,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
enum34
|
enum34
|
||||||
|
|
||||||
# Build and install bazel
|
# Build and install bazel
|
||||||
ENV BAZEL_VERSION 3.0.0
|
ENV BAZEL_VERSION 3.1.0
|
||||||
WORKDIR /
|
WORKDIR /
|
||||||
RUN mkdir /bazel && \
|
RUN mkdir /bazel && \
|
||||||
cd /bazel && \
|
cd /bazel && \
|
||||||
|
@ -125,7 +125,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
keras_preprocessing \
|
keras_preprocessing \
|
||||||
matplotlib \
|
matplotlib \
|
||||||
mock \
|
mock \
|
||||||
numpy \
|
'numpy<1.19.0' \
|
||||||
scipy \
|
scipy \
|
||||||
sklearn \
|
sklearn \
|
||||||
pandas \
|
pandas \
|
||||||
@ -133,7 +133,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
enum34
|
enum34
|
||||||
|
|
||||||
# Build and install bazel
|
# Build and install bazel
|
||||||
ENV BAZEL_VERSION 3.0.0
|
ENV BAZEL_VERSION 3.1.0
|
||||||
WORKDIR /
|
WORKDIR /
|
||||||
RUN mkdir /bazel && \
|
RUN mkdir /bazel && \
|
||||||
cd /bazel && \
|
cd /bazel && \
|
||||||
|
@ -125,7 +125,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
keras_preprocessing \
|
keras_preprocessing \
|
||||||
matplotlib \
|
matplotlib \
|
||||||
mock \
|
mock \
|
||||||
numpy \
|
'numpy<1.19.0' \
|
||||||
scipy \
|
scipy \
|
||||||
sklearn \
|
sklearn \
|
||||||
pandas \
|
pandas \
|
||||||
@ -133,7 +133,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
enum34
|
enum34
|
||||||
|
|
||||||
# Build and install bazel
|
# Build and install bazel
|
||||||
ENV BAZEL_VERSION 3.0.0
|
ENV BAZEL_VERSION 3.1.0
|
||||||
WORKDIR /
|
WORKDIR /
|
||||||
RUN mkdir /bazel && \
|
RUN mkdir /bazel && \
|
||||||
cd /bazel && \
|
cd /bazel && \
|
||||||
|
@ -14,7 +14,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
keras_preprocessing \
|
keras_preprocessing \
|
||||||
matplotlib \
|
matplotlib \
|
||||||
mock \
|
mock \
|
||||||
numpy \
|
'numpy<1.19.0' \
|
||||||
scipy \
|
scipy \
|
||||||
sklearn \
|
sklearn \
|
||||||
pandas \
|
pandas \
|
||||||
|
@ -13,7 +13,7 @@ RUN python3 -m pip --no-cache-dir install \
|
|||||||
keras_preprocessing \
|
keras_preprocessing \
|
||||||
matplotlib \
|
matplotlib \
|
||||||
mock \
|
mock \
|
||||||
numpy \
|
'numpy<1.19.0' \
|
||||||
scipy \
|
scipy \
|
||||||
sklearn \
|
sklearn \
|
||||||
pandas \
|
pandas \
|
||||||
|
@ -49,7 +49,7 @@ from setuptools.dist import Distribution
|
|||||||
# result for pip.
|
# result for pip.
|
||||||
# Also update tensorflow/tensorflow.bzl and
|
# Also update tensorflow/tensorflow.bzl and
|
||||||
# tensorflow/core/public/version.h
|
# tensorflow/core/public/version.h
|
||||||
_VERSION = '2.2.0'
|
_VERSION = '2.3.0'
|
||||||
|
|
||||||
REQUIRED_PACKAGES = [
|
REQUIRED_PACKAGES = [
|
||||||
'absl-py >= 0.7.0',
|
'absl-py >= 0.7.0',
|
||||||
@ -63,8 +63,8 @@ REQUIRED_PACKAGES = [
|
|||||||
'numpy >= 1.16.0, < 1.19.0',
|
'numpy >= 1.16.0, < 1.19.0',
|
||||||
'opt_einsum >= 2.3.2',
|
'opt_einsum >= 2.3.2',
|
||||||
'protobuf >= 3.9.2',
|
'protobuf >= 3.9.2',
|
||||||
'tensorboard >= 2.2.0, < 2.3.0',
|
'tensorboard >= 2.3.0, < 3',
|
||||||
'tensorflow_estimator >= 2.2.0, < 2.3.0',
|
'tensorflow_estimator >= 2.3.0, < 2.4.0',
|
||||||
'termcolor >= 1.1.0',
|
'termcolor >= 1.1.0',
|
||||||
'wrapt >= 1.11.1',
|
'wrapt >= 1.11.1',
|
||||||
'wheel >= 0.26',
|
'wheel >= 0.26',
|
||||||
|
@ -292,6 +292,26 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_http_archive(
|
||||||
|
name = "LinaroArmGcc72",
|
||||||
|
build_file = clean_dep("//third_party/toolchains/embedded/linaro-gcc72-armeabi:linaro-gcc72-armeabi.BUILD"),
|
||||||
|
strip_prefix = "gcc-linaro-7.2.1-2017.11-x86_64_arm-linux-gnueabihf/",
|
||||||
|
urls = [
|
||||||
|
"https://releases.linaro.org/components/toolchain/binaries/7.2-2017.11/arm-linux-gnueabihf/gcc-linaro-7.2.1-2017.11-x86_64_arm-linux-gnueabihf.tar.xz",
|
||||||
|
],
|
||||||
|
sha256 = "cee0087b1f1205b73996651b99acd3a926d136e71047048f1758ffcec69b1ca2",
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_http_archive(
|
||||||
|
name = "LinaroAarch64Gcc72",
|
||||||
|
build_file = clean_dep("//third_party/toolchains/embedded/linaro-gcc72-aarch64:linaro-gcc72-aarch64.BUILD"),
|
||||||
|
strip_prefix = "gcc-linaro-7.2.1-2017.11-x86_64_aarch64-linux-gnu/",
|
||||||
|
urls = [
|
||||||
|
"https://releases.linaro.org/components/toolchain/binaries/7.2-2017.11/aarch64-linux-gnu/gcc-linaro-7.2.1-2017.11-x86_64_aarch64-linux-gnu.tar.xz",
|
||||||
|
],
|
||||||
|
sha256 = "20181f828e1075f1a493947ff91e82dd578ce9f8638fbdfc39e24b62857d8f8d",
|
||||||
|
)
|
||||||
|
|
||||||
tf_http_archive(
|
tf_http_archive(
|
||||||
name = "libxsmm_archive",
|
name = "libxsmm_archive",
|
||||||
build_file = clean_dep("//third_party:libxsmm.BUILD"),
|
build_file = clean_dep("//third_party:libxsmm.BUILD"),
|
||||||
@ -409,12 +429,12 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
|||||||
tf_http_archive(
|
tf_http_archive(
|
||||||
name = "org_sqlite",
|
name = "org_sqlite",
|
||||||
build_file = clean_dep("//third_party:sqlite.BUILD"),
|
build_file = clean_dep("//third_party:sqlite.BUILD"),
|
||||||
sha256 = "f3c79bc9f4162d0b06fa9fe09ee6ccd23bb99ce310b792c5145f87fbcc30efca",
|
sha256 = "e9cec01d4519e2d49b3810615237325263fe1feaceae390ee12b4a29bd73dbe2",
|
||||||
strip_prefix = "sqlite-amalgamation-3310100",
|
strip_prefix = "sqlite-amalgamation-3320300",
|
||||||
system_build_file = clean_dep("//third_party/systemlibs:sqlite.BUILD"),
|
system_build_file = clean_dep("//third_party/systemlibs:sqlite.BUILD"),
|
||||||
urls = [
|
urls = [
|
||||||
"https://storage.googleapis.com/mirror.tensorflow.org/www.sqlite.org/2020/sqlite-amalgamation-3310100.zip",
|
"https://storage.googleapis.com/mirror.tensorflow.org/www.sqlite.org/2020/sqlite-amalgamation-3320300.zip",
|
||||||
"https://www.sqlite.org/2020/sqlite-amalgamation-3310100.zip",
|
"https://www.sqlite.org/2020/sqlite-amalgamation-3320300.zip",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
2
third_party/aws/workspace.bzl
vendored
2
third_party/aws/workspace.bzl
vendored
@ -9,7 +9,7 @@ def repo():
|
|||||||
third_party_http_archive(
|
third_party_http_archive(
|
||||||
name = "aws",
|
name = "aws",
|
||||||
urls = [
|
urls = [
|
||||||
"https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.7.336.tar.gz",
|
"https://mirror.tensorflow.orgg/github.com/aws/aws-sdk-cpp/archive/1.7.336.tar.gz",
|
||||||
"https://github.com/aws/aws-sdk-cpp/archive/1.7.336.tar.gz",
|
"https://github.com/aws/aws-sdk-cpp/archive/1.7.336.tar.gz",
|
||||||
],
|
],
|
||||||
sha256 = "758174f9788fed6cc1e266bcecb20bf738bd5ef1c3d646131c9ed15c2d6c5720",
|
sha256 = "758174f9788fed6cc1e266bcecb20bf738bd5ef1c3d646131c9ed15c2d6c5720",
|
||||||
|
@ -1,3 +1,14 @@
|
|||||||
|
--- ./absl/time/internal/cctz/include/cctz/civil_time_detail.h 2020-08-06 01:33:56.005757145 +0200
|
||||||
|
+++ ./absl/time/internal/cctz/include/cctz/civil_time_detail.h 2020-08-06 01:33:35.460579387 +0200
|
||||||
|
@@ -23,7 +23,7 @@
|
||||||
|
#include "absl/base/config.h"
|
||||||
|
|
||||||
|
// Disable constexpr support unless we are in C++14 mode.
|
||||||
|
-#if __cpp_constexpr >= 201304 || (defined(_MSC_VER) && _MSC_VER >= 1910)
|
||||||
|
+#if (!defined(NO_CONSTEXPR_FOR_YOU) && __cpp_constexpr >= 201304) || (defined(_MSC_VER) && _MSC_VER >= 1910)
|
||||||
|
#define CONSTEXPR_D constexpr // data
|
||||||
|
#define CONSTEXPR_F constexpr // function
|
||||||
|
#define CONSTEXPR_M constexpr // member
|
||||||
--- ./absl/time/internal/cctz/BUILD.bazel 2019-09-23 13:20:52.000000000 -0700
|
--- ./absl/time/internal/cctz/BUILD.bazel 2019-09-23 13:20:52.000000000 -0700
|
||||||
+++ ./absl/time/internal/cctz/BUILD.bazel.fixed 2019-09-23 13:20:48.000000000 -0700
|
+++ ./absl/time/internal/cctz/BUILD.bazel.fixed 2019-09-23 13:20:48.000000000 -0700
|
||||||
@@ -74,15 +74,6 @@
|
@@ -74,15 +74,6 @@
|
||||||
@ -301,4 +312,3 @@
|
|||||||
+ .internal_compressed_tuple::template Storage<CompressedTuple, I>::get();
|
+ .internal_compressed_tuple::template Storage<CompressedTuple, I>::get();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
5
third_party/icu/BUILD.bazel
vendored
5
third_party/icu/BUILD.bazel
vendored
@ -1,5 +1,8 @@
|
|||||||
|
# We make everything here private to make any dependencies on ICU become a build
|
||||||
|
# failure and easier/faster to track down, as it's not needed for DeepSpeech and
|
||||||
|
# causes linking problems on Windows.
|
||||||
package(
|
package(
|
||||||
default_visibility = ["//visibility:public"],
|
default_visibility = ["//visibility:private"],
|
||||||
)
|
)
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
1
third_party/py/BUILD.tpl
vendored
1
third_party/py/BUILD.tpl
vendored
@ -67,5 +67,4 @@ config_setting(
|
|||||||
)
|
)
|
||||||
|
|
||||||
%{PYTHON_INCLUDE_GENRULE}
|
%{PYTHON_INCLUDE_GENRULE}
|
||||||
%{NUMPY_INCLUDE_GENRULE}
|
|
||||||
%{PYTHON_IMPORT_LIB_GENRULE}
|
%{PYTHON_IMPORT_LIB_GENRULE}
|
||||||
|
16
third_party/py/python_configure.bzl
vendored
16
third_party/py/python_configure.bzl
vendored
@ -210,7 +210,7 @@ def _create_local_python_repository(repository_ctx):
|
|||||||
python_lib = _get_python_lib(repository_ctx, python_bin)
|
python_lib = _get_python_lib(repository_ctx, python_bin)
|
||||||
_check_python_lib(repository_ctx, python_lib)
|
_check_python_lib(repository_ctx, python_lib)
|
||||||
python_include = _get_python_include(repository_ctx, python_bin)
|
python_include = _get_python_include(repository_ctx, python_bin)
|
||||||
numpy_include = _get_numpy_include(repository_ctx, python_bin) + "/numpy"
|
# numpy_include = _get_numpy_include(repository_ctx, python_bin) + "/numpy"
|
||||||
python_include_rule = _symlink_genrule_for_dir(
|
python_include_rule = _symlink_genrule_for_dir(
|
||||||
repository_ctx,
|
repository_ctx,
|
||||||
python_include,
|
python_include,
|
||||||
@ -233,12 +233,12 @@ def _create_local_python_repository(repository_ctx):
|
|||||||
[python_import_lib_src],
|
[python_import_lib_src],
|
||||||
[python_import_lib_name],
|
[python_import_lib_name],
|
||||||
)
|
)
|
||||||
numpy_include_rule = _symlink_genrule_for_dir(
|
#numpy_include_rule = _symlink_genrule_for_dir(
|
||||||
repository_ctx,
|
# repository_ctx,
|
||||||
numpy_include,
|
# numpy_include,
|
||||||
"numpy_include/numpy",
|
# "numpy_include/numpy",
|
||||||
"numpy_include",
|
# "numpy_include",
|
||||||
)
|
#)
|
||||||
|
|
||||||
platform_constraint = ""
|
platform_constraint = ""
|
||||||
if repository_ctx.attr.platform_constraint:
|
if repository_ctx.attr.platform_constraint:
|
||||||
@ -247,7 +247,7 @@ def _create_local_python_repository(repository_ctx):
|
|||||||
"%{PYTHON_BIN_PATH}": python_bin,
|
"%{PYTHON_BIN_PATH}": python_bin,
|
||||||
"%{PYTHON_INCLUDE_GENRULE}": python_include_rule,
|
"%{PYTHON_INCLUDE_GENRULE}": python_include_rule,
|
||||||
"%{PYTHON_IMPORT_LIB_GENRULE}": python_import_lib_genrule,
|
"%{PYTHON_IMPORT_LIB_GENRULE}": python_import_lib_genrule,
|
||||||
"%{NUMPY_INCLUDE_GENRULE}": numpy_include_rule,
|
#"%{NUMPY_INCLUDE_GENRULE}": numpy_include_rule,
|
||||||
"%{PLATFORM_CONSTRAINT}": platform_constraint,
|
"%{PLATFORM_CONSTRAINT}": platform_constraint,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
2
third_party/repo.bzl
vendored
2
third_party/repo.bzl
vendored
@ -16,6 +16,8 @@
|
|||||||
|
|
||||||
_SINGLE_URL_WHITELIST = depset([
|
_SINGLE_URL_WHITELIST = depset([
|
||||||
"arm_compiler",
|
"arm_compiler",
|
||||||
|
"LinaroArmGcc72",
|
||||||
|
"LinaroAarch64Gcc72",
|
||||||
])
|
])
|
||||||
|
|
||||||
def _is_windows(ctx):
|
def _is_windows(ctx):
|
||||||
|
67
third_party/toolchains/embedded/linaro-gcc72-aarch64/BUILD
vendored
Normal file
67
third_party/toolchains/embedded/linaro-gcc72-aarch64/BUILD
vendored
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
# This is the entry point for --crosstool_top.
|
||||||
|
#
|
||||||
|
# The cc_toolchain rule used is found by:
|
||||||
|
#
|
||||||
|
# 1. Finding the appropriate toolchain in the CROSSTOOL file based on the --cpu
|
||||||
|
# and --compiler command line flags (if they exist, otherwise using the
|
||||||
|
# "default_target_cpu" / "default_toolchain" fields in the CROSSTOOL file)
|
||||||
|
# 2. Concatenating the "target_cpu" and "compiler" fields of the toolchain in
|
||||||
|
# use and using that as a key in the map in the "toolchains" attribute
|
||||||
|
package(default_visibility = ["//visibility:public"])
|
||||||
|
|
||||||
|
load(":linaro_toolchain_config.bzl", "linaro_toolchain_config")
|
||||||
|
|
||||||
|
cc_toolchain_suite(
|
||||||
|
name = "toolchain",
|
||||||
|
toolchains = {
|
||||||
|
"aarch64": ":cc-compiler-aarch64",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "empty",
|
||||||
|
srcs = [],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "gcc_linux_all_files",
|
||||||
|
srcs = [
|
||||||
|
"//third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc:tool-wrappers",
|
||||||
|
"@LinaroAarch64Gcc72//:compiler_pieces",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "gcc_linux_linker_files",
|
||||||
|
srcs = [
|
||||||
|
"//third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc:ld",
|
||||||
|
"//third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc:ar",
|
||||||
|
"@LinaroAarch64Gcc72//:compiler_pieces",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "gcc_linux_compiler_files",
|
||||||
|
srcs = [
|
||||||
|
"//third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc:gcc",
|
||||||
|
"//third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc:as",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
linaro_toolchain_config(name = "linaro_aarch64")
|
||||||
|
|
||||||
|
cc_toolchain(
|
||||||
|
name = "cc-compiler-aarch64",
|
||||||
|
all_files = ":gcc_linux_all_files",
|
||||||
|
compiler_files = ":gcc_linux_compiler_files",
|
||||||
|
toolchain_identifier = "gcc72_linaro_aarch64",
|
||||||
|
toolchain_config = ":linaro_aarch64",
|
||||||
|
dwp_files = ":empty",
|
||||||
|
dynamic_runtime_lib = ":empty",
|
||||||
|
linker_files = ":gcc_linux_linker_files",
|
||||||
|
objcopy_files = "//third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc:objcopy",
|
||||||
|
static_runtime_lib = ":empty",
|
||||||
|
strip_files = "//third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc:strip",
|
||||||
|
supports_param_files = 1,
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
)
|
79
third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc/BUILD
vendored
Normal file
79
third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc/BUILD
vendored
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
package(default_visibility = ['//third_party/toolchains/embedded/linaro-gcc72-aarch64:__pkg__'])
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = 'gcc',
|
||||||
|
srcs = [
|
||||||
|
'@LinaroAarch64Gcc72//:gcc',
|
||||||
|
'aarch64-linux-gnu-gcc',
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = 'ar',
|
||||||
|
srcs = [
|
||||||
|
'@LinaroAarch64Gcc72//:ar',
|
||||||
|
'aarch64-linux-gnu-ar',
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = 'ld',
|
||||||
|
srcs = [
|
||||||
|
'@LinaroAarch64Gcc72//:ld',
|
||||||
|
'aarch64-linux-gnu-ld',
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = 'nm',
|
||||||
|
srcs = [
|
||||||
|
'@LinaroAarch64Gcc72//:nm',
|
||||||
|
'aarch64-linux-gnu-nm',
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = 'objcopy',
|
||||||
|
srcs = [
|
||||||
|
'@LinaroAarch64Gcc72//:objcopy',
|
||||||
|
'aarch64-linux-gnu-objcopy',
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = 'objdump',
|
||||||
|
srcs = [
|
||||||
|
'@LinaroAarch64Gcc72//:objdump',
|
||||||
|
'aarch64-linux-gnu-objdump',
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = 'strip',
|
||||||
|
srcs = [
|
||||||
|
'@LinaroAarch64Gcc72//:strip',
|
||||||
|
'aarch64-linux-gnu-strip',
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = 'as',
|
||||||
|
srcs = [
|
||||||
|
'@LinaroAarch64Gcc72//:as',
|
||||||
|
'aarch64-linux-gnu-as',
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = 'tool-wrappers',
|
||||||
|
srcs = [
|
||||||
|
':gcc',
|
||||||
|
':ar',
|
||||||
|
':ld',
|
||||||
|
':nm',
|
||||||
|
':objcopy',
|
||||||
|
':objdump',
|
||||||
|
':strip',
|
||||||
|
':as',
|
||||||
|
],
|
||||||
|
)
|
5
third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc/aarch64-linux-gnu-ar
vendored
Executable file
5
third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc/aarch64-linux-gnu-ar
vendored
Executable file
@ -0,0 +1,5 @@
|
|||||||
|
#!/bin/bash --norc
|
||||||
|
|
||||||
|
exec -a aarch64-linux-gnu-ar \
|
||||||
|
external/LinaroAarch64Gcc72/bin/aarch64-linux-gnu-ar \
|
||||||
|
"$@"
|
5
third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc/aarch64-linux-gnu-as
vendored
Executable file
5
third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc/aarch64-linux-gnu-as
vendored
Executable file
@ -0,0 +1,5 @@
|
|||||||
|
#!/bin/bash --norc
|
||||||
|
|
||||||
|
exec -a aarch64-linux-gnu-as \
|
||||||
|
external/LinaroAarch64Gcc72/bin/aarch64-linux-gnu-as \
|
||||||
|
"$@"
|
5
third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc/aarch64-linux-gnu-cpp
vendored
Executable file
5
third_party/toolchains/embedded/linaro-gcc72-aarch64/gcc/aarch64-linux-gnu-cpp
vendored
Executable file
@ -0,0 +1,5 @@
|
|||||||
|
#!/bin/bash --norc
|
||||||
|
|
||||||
|
exec -a aarch64-linux-gnu-cpp \
|
||||||
|
external/LinaroAarch64Gcc72/bin/aarch64-linux-gnu-cpp \
|
||||||
|
"$@"
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user