STT-tensorflow/tensorflow/python/distribute
2020-11-18 17:23:10 -08:00

TensorFlow Distribute Libraries

Overview

tf.distribute.Strategy is a TensorFlow API to distribute training across multiple GPUs, multiple machines or TPUs. Using this API, users can distribute their existing models and training code with minimal code changes.

It can be used with TensorFlow's high-level APIs, tf.keras and tf.estimator, with only a couple of lines of code changed. This works because the underlying components of TensorFlow are strategy-aware, including variables, layers, models, optimizers, metrics, summaries, and checkpoints.
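
The following is a minimal sketch of what "strategy-aware" means in practice. It assumes TensorFlow 2.x and uses MirroredStrategy, which falls back to a single CPU replica when no GPU is visible. A variable created under strategy.scope() is managed by the strategy. strategy.run calls a function once per replica, and strategy.reduce combines the per-replica results. The function name replica_fn is illustrative only.

```python
import tensorflow as tf

# MirroredStrategy uses every visible GPU, or a single CPU replica if none.
strategy = tf.distribute.MirroredStrategy()
num_replicas = strategy.num_replicas_in_sync

# Variables created under the scope are mirrored across all replicas.
with strategy.scope():
  v = tf.Variable(1.0)

def replica_fn():
  # Illustrative per-replica computation: this replica's id plus the variable.
  ctx = tf.distribute.get_replica_context()
  return tf.cast(ctx.replica_id_in_sync_group, tf.float32) + v

# Run replica_fn once per replica, then sum the per-replica results.
per_replica = strategy.run(replica_fn)
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
print(num_replicas, float(total))
```

With N replicas, the reduced value is 0 + 1 + ... + (N - 1) from the replica ids, plus N copies of the variable's value 1.0.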

Documentation

Distributed Training Guide

Distributed Training With Keras Tutorial

Distributed Training With Custom Training Loops Tutorial

Multi-worker Training With Keras Tutorial

Multi-worker Training With Estimator Tutorial

Save and Load with Distribution Strategy

Simple Examples

Using compile and fit with GPUs.

import tensorflow as tf

# Create the strategy instance. It automatically detects all visible GPUs
# (and falls back to the CPU if there are none).
mirrored_strategy = tf.distribute.MirroredStrategy()

# Create and compile the Keras model under strategy.scope().
with mirrored_strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
  model.compile(loss='mse', optimizer='sgd')

# Call model.fit and model.evaluate as before.
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(10)
model.fit(dataset, epochs=2)
model.evaluate(dataset)
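
In the example above, batch(10) sets the global batch size. The strategy splits each batch across its num_replicas_in_sync replicas. The sketch below assumes TensorFlow 2.x and makes that split visible with experimental_distribute_dataset. The names per_replica_batch and global_batch are illustrative. Scaling the global batch with the replica count, as done here, is one common choice because it keeps each replica's batch size fixed.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
n = strategy.num_replicas_in_sync

# Illustrative choice: keep 10 examples per replica and scale the global batch.
per_replica_batch = 10
global_batch = per_replica_batch * n

# repeat() makes the dataset infinite, so the first batch is always full.
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat().batch(global_batch)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Inspect the first distributed batch: one local tensor per replica.
features, _ = next(iter(dist_dataset))
sizes = [int(t.shape[0]) for t in strategy.experimental_local_results(features)]
print(global_batch, sizes)
```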

Custom training loop with TPUs.

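import tensorflow as tf

# Connect to the TPU and initialize it. tpu='' is an assumption that works for
# auto-detected TPUs; otherwise pass your TPU's name or address ('local' on a
# Cloud TPU VM).
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

# Create the strategy instance.
tpu_strategy = tf.distribute.TPUStrategy(resolver)

# Create the Keras model under strategy.scope(). Building it here creates its
# variables under the strategy rather than inside the replicated step.
with tpu_strategy.scope():
  model = tf.keras.layers.Dense(1, name="dense")
  model.build(input_shape=(None, 1))

# Create the custom training loop body as a tf.function.
@tf.function
def train_step(iterator):
  def step_fn(inputs):
    images, targets = inputs
    with tf.GradientTape() as tape:
      outputs = model(images)
      loss = tf.reduce_sum(outputs - targets)
    grads = tape.gradient(loss, model.variables)
    return grads

  return tpu_strategy.run(
      step_fn, args=(next(iterator),))

# Run the loop body once on a dataset. drop_remainder=True keeps batch shapes
# static, which is generally recommended on TPUs.
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(
    10, drop_remainder=True)
input_iterator = iter(tpu_strategy.experimental_distribute_dataset(dataset))
train_step(input_iterator)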
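
The TPU example only computes gradients and never applies them. Below is a hedged sketch of a complete training step. It swaps in MirroredStrategy so it runs on GPUs or a single CPU. It also adds several assumptions that are not part of the original example:

- an SGD optimizer;
- a squared-error loss;
- tf.nn.compute_average_loss, which scales each replica's loss by the global batch size.

Per-replica losses therefore add up to the global mean. The optimizer sums gradients across replicas when it applies them.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 10

# Model and optimizer variables are created under the strategy scope.
with strategy.scope():
  model = tf.keras.layers.Dense(1, name="dense")
  model.build(input_shape=(None, 1))
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(
    GLOBAL_BATCH_SIZE)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def train_step(inputs):
  def step_fn(inputs):
    features, targets = inputs
    with tf.GradientTape() as tape:
      outputs = model(features)
      # Squared error per example, averaged over the *global* batch.
      per_example_loss = tf.reduce_sum(tf.square(outputs - targets), axis=-1)
      loss = tf.nn.compute_average_loss(
          per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

  per_replica_losses = strategy.run(step_fn, args=(inputs,))
  # Summing the scaled per-replica losses gives the global mean loss.
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

losses = [float(train_step(batch)) for batch in dist_dataset]
print(losses)
```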

Testing

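Tests in this directory should cover all distribution strategies to ensure feature parity. To do this, parameterize each test with the combinations.generate decorator (from combinations.py) over the strategy combinations defined in strategy_combinations.py.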