[tf.data service] Export distribute transformation
PiperOrigin-RevId: 313620607
Change-Id: I4f43f44ac7a8a9b27226b7d63aef374f3c8a178d

Parent: c197499613
Commit: b0fcb6c18c
@@ -13,5 +13,6 @@ py_library(
         "//tensorflow/python:util",
         "//tensorflow/python/data/experimental/ops:dataset_ops",
         "//tensorflow/python/data/experimental/ops:iterator_ops",
+        "//tensorflow/python/data/experimental/service",
     ],
 )
@@ -53,6 +53,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
 @@copy_to_device
 @@dense_to_ragged_batch
 @@dense_to_sparse_batch
+@@distribute
 @@enumerate_dataset
 @@from_variant
 @@get_next_as_optional
@@ -89,6 +90,7 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=unused-import
+from tensorflow.python.data.experimental import service
 from tensorflow.python.data.experimental.ops.batching import dense_to_ragged_batch
 from tensorflow.python.data.experimental.ops.batching import dense_to_sparse_batch
 from tensorflow.python.data.experimental.ops.batching import map_and_batch
@@ -150,4 +152,9 @@ from tensorflow.python.framework.type_spec import TypeSpec as Structure
 # pylint: enable=unused-import
 
 from tensorflow.python.util.all_util import remove_undocumented
-remove_undocumented(__name__)
+
+_allowed_symbols = [
+    "service",
+]
+
+remove_undocumented(__name__, _allowed_symbols)
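Note: the switch from `remove_undocumented(__name__)` to an explicit allowlist is what keeps the new `service` submodule importable. `remove_undocumented` strips public module attributes that are neither named as `@@`-symbols in the module docstring nor explicitly allowed, and `service` is a module rather than a documented symbol. A simplified sketch of those semantics, not the real `all_util` implementation:

```python
def remove_undocumented_sketch(module_dict, docstring, allowed_symbols=()):
  """Simplified stand-in for the tensorflow.python.util.all_util behavior."""
  for name in list(module_dict):
    if name.startswith("_"):
      continue  # private names are left alone
    documented = ("@@" + name) in docstring
    if not documented and name not in allowed_symbols:
      del module_dict[name]
```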
@@ -23,11 +23,13 @@ import six
 
 from tensorflow.python import tf2
 from tensorflow.python.data.experimental.ops import compression_ops
+from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
 from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
 
 
 class ProcessingMode(object):
@@ -240,11 +242,18 @@ def _distribute(processing_mode,
     # to limit memory usage.
     dataset = dataset.map(
         lambda x: compression_ops.uncompress(x, output_spec=uncompressed_spec))
+
+    # Disable autosharding for shared jobs.
+    if job_name:
+      options = dataset_ops.Options()
+      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
+      dataset = dataset.with_options(options)
     return dataset
 
   return _apply_fn
 
 
+@tf_export("data.experimental.service.distribute")
 def distribute(processing_mode,
                service,
                job_name=None,
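The added block above is the functional core of the change: when a `job_name` is supplied, autosharding is turned off so that splitting is left entirely to the shared tf.data service job. A minimal standalone sketch of the same options pattern, written against the public TF 2.x API instead of the internal `dataset_ops` module:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(5)

# Build an Options object, turn off autosharding, and attach it to the
# dataset, mirroring what _distribute now does for shared jobs.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.OFF)
dataset = dataset.with_options(options)
```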
@@ -289,32 +298,65 @@ def distribute(processing_mode,
   executed locally.
 
   The `job_name` argument allows jobs to be shared across multiple
-  datasets. Instead of each dataset creating its own job, all datasets with the
-  same `job_name` will consume from the same job. A new job will
-  be created for each iteration of the dataset (with each repetition of
-  `Dataset.repeat` counting as a new iteration). The following example
-  demonstrates shared iteration, with the assumption that the tf.data service is
-  running with a single worker.
+  datasets. Instead of each dataset creating its own job, all
+  datasets with the same `job_name` will consume from the same job. A new job
+  will be created for each iteration of the dataset (with each repetition of
+  `Dataset.repeat` counting as a new iteration). Suppose two training workers
+  (in either a single client or multi-client setup) iterate over the below
+  dataset, and there is a single tf.data worker:
 
   ```
   range5_dataset = tf.data.Dataset.range(5)
-  dataset1 = range5_dataset.apply(tf.data.experimental.service.distribute(
-      "parallel_epochs", "my_job_name", "grpc://dataservice:5000"))
-  dataset2 = range5_dataset.apply(tf.data.experimental.service.distribute(
-      "parallel_epochs", "my_job_name", "grpc://dataservice:5000"))
-  iter_1_1 = iter(dataset1)
-  iter_1_2 = iter(dataset1)
-  iter_2_1 = iter(dataset2)
-  iter_2_2 = iter(dataset2)
-  print(next(iter_1_1))  # Prints "0"
-  # iter_1_2 consumes from the same job as iter_1_1
-  print(next(iter_1_2))  # Prints "1"
-  # iter_2_1 consumes from a new job
-  print(next(iter_2_1))  # Prints "0"
-  # iter_2_2 consumes from the same job as iter_2_1
-  print(next(iter_2_2))  # Prints "1"
+  dataset = range5_dataset.apply(tf.data.experimental.service.distribute(
+      "parallel_epochs", "grpc://dataservice:5000", job_name="my_job_name"))
+  for iteration in range(3):
+    print(list(dataset))
   ```
 
+  The elements of each job will be split between the two processes, with
+  elements being consumed by the processes on a first-come first-served basis.
+  One possible result is that process 1 prints
+
+  ```
+  [0, 2, 4]
+  [0, 1, 3]
+  [1]
+  ```
+
+  and process 2 prints
+
+  ```
+  [1, 3]
+  [2, 4]
+  [0, 2, 3, 4]
+  ```
+
+  Job names must not be re-used across different training jobs within the
+  lifetime of the tf.data service. In general, the tf.data service is expected
+  to live for the duration of a single training job.
+  To use the tf.data service with multiple training jobs, make sure to use
+  different job names to avoid conflicts. For example, suppose a training job
+  calls `distribute` with `job_name="job"` and reads until end of input. If
+  another independent job connects to the same tf.data service and tries to read
+  from `job_name="job"`, it will immediately receive end of input, without
+  getting any data.
+
+  **Keras and Distribution Strategies**
+
+  The dataset produced by the `distribute` transformation can be passed to
+  Keras' `Model.fit` or Distribution Strategy's
+  `tf.distribute.Strategy.experimental_distribute_dataset` like any other
+  `tf.data.Dataset`. We recommend setting a `job_name` on the call to
+  `distribute` so that if there are multiple workers, they read data from the
+  same job. Note that the autosharding normally performed by
+  `experimental_distribute_dataset` will be disabled when setting a `job_name`,
+  since sharing the job already results in splitting data across the workers.
+  When using a shared job, data will be dynamically balanced across workers, so
+  that they reach end of input about the same time. This results in better
+  worker utilization than with autosharding, where each worker processes an
+  independent set of files, and some workers may run out of data earlier than
+  others.
+
   Args:
     processing_mode: A string specifying the policy for how data should be
       processed by tf.data workers. Currently, the only supported value is
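Since the new docstring recommends pairing `distribute` with a distribution strategy, a short end-to-end sketch may help; the service address and job name below are placeholders, and a real setup needs a running tf.data service dispatcher behind them:

```python
import tensorflow as tf

SERVICE = "grpc://dataservice:5000"  # placeholder address

strategy = tf.distribute.MirroredStrategy()

dataset = tf.data.Dataset.range(100).batch(10)
dataset = dataset.apply(tf.data.experimental.service.distribute(
    "parallel_epochs", SERVICE, job_name="shared_training_job"))

# Because job_name is set, distribute disables autosharding and the shared
# job balances elements across workers dynamically instead.
dist_dataset = strategy.experimental_distribute_dataset(dataset)
```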
tensorflow/python/data/experimental/service/BUILD (new file, 15 lines)
@@ -0,0 +1,15 @@
+package(
+    default_visibility = ["//tensorflow:internal"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "service",
+    srcs = ["__init__.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python/data/experimental/ops:data_service_ops",
+    ],
+)
tensorflow/python/data/experimental/service/__init__.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for using the tf.data service."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops.data_service_ops import distribute
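With the package in place, the transformation is reachable both through the internal module path that this file re-exports and, once the `tf_export` decorator takes effect, through the public namespace:

```python
# Internal path, as re-exported by the new __init__.py:
from tensorflow.python.data.experimental.service import distribute

# Public path exposed by @tf_export("data.experimental.service.distribute"):
import tensorflow as tf
transform = tf.data.experimental.service.distribute  # same function
```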
@@ -16,6 +16,7 @@ TENSORFLOW_API_INIT_FILES = [
     "config/threading/__init__.py",
     "data/__init__.py",
     "data/experimental/__init__.py",
+    "data/experimental/service/__init__.py",
     "debugging/__init__.py",
     "debugging/experimental/__init__.py",
     "distribute/__init__.py",
@@ -16,6 +16,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
     "config/threading/__init__.py",
     "data/__init__.py",
     "data/experimental/__init__.py",
+    "data/experimental/service/__init__.py",
     "debugging/__init__.py",
     "debugging/experimental/__init__.py",
     "distribute/__init__.py",
@@ -80,6 +80,10 @@ tf_module {
     name: "UNKNOWN_CARDINALITY"
     mtype: "<type \'int\'>"
   }
+  member {
+    name: "service"
+    mtype: "<type \'module\'>"
+  }
   member_method {
     name: "Counter"
     argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
@@ -0,0 +1,7 @@
+path: "tensorflow.data.experimental.service"
+tf_module {
+  member_method {
+    name: "distribute"
+    argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+}
@@ -68,6 +68,10 @@ tf_module {
     name: "UNKNOWN_CARDINALITY"
     mtype: "<type \'int\'>"
   }
+  member {
+    name: "service"
+    mtype: "<type \'module\'>"
+  }
   member_method {
     name: "Counter"
     argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
@@ -0,0 +1,7 @@
+path: "tensorflow.data.experimental.service"
+tf_module {
+  member_method {
+    name: "distribute"
+    argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+}
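The golden files for v1 and v2 pin down the exported signature. Spelled out as Python, with the two `None` defaults matching the recorded argspec:

```python
def distribute(processing_mode, service, job_name=None,
               max_outstanding_requests=None):
  ...  # body elided; see data_service_ops.py
```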