diff --git a/tensorflow/python/tpu/device_assignment.py b/tensorflow/python/tpu/device_assignment.py index 9e805655a01..f8cb4e16266 100644 --- a/tensorflow/python/tpu/device_assignment.py +++ b/tensorflow/python/tpu/device_assignment.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import enum import math import numpy as np @@ -313,10 +314,22 @@ def _ring_3d(x_size, y_size, z_size): return ret +class DeviceOrderMode(enum.IntEnum): + """The way of determining device orders when computing device assignment.""" + # By default the mode is set to AUTO, the library will choose to form rings + # when that is possible. + AUTO = 0 + # Form rings for replicas and model-parallel cores. + RING = 1 + # Form meshes for replicas and/or model-parallel cores. + MESH = 2 + + def device_assignment(topology, computation_shape=None, computation_stride=None, - num_replicas=1): + num_replicas=1, + device_order_mode=DeviceOrderMode.AUTO): """Computes a device_assignment of a computation across a TPU topology. Attempts to choose a compact grid of cores for locality. @@ -341,6 +354,9 @@ def device_assignment(topology, TPU topology. If None, the `computation_stride` is `[1] * topology_rank`. num_replicas: The number of computation replicas to run. The replicas will be packed into the free spaces of the topology. + device_order_mode: An enum of `DeviceOrderMode` class which indicates + whether to assign devices to form rings or meshes, or let the library to + choose. Returns: A DeviceAssignment object, which describes the mapping between the logical @@ -450,6 +466,12 @@ def device_assignment(topology, computation_shape[-1] == 2 # Only handle 3D case. and np.prod(computation_stride) == 1 # Ensure no stride. and num_replicas == max_replicas) # Full replication. + + if device_order_mode != DeviceOrderMode.AUTO: + if device_order_mode == DeviceOrderMode.RING and not enable_3d_tiling: + raise ValueError("cannot assign ring order in the given topology") + enable_3d_tiling = device_order_mode == DeviceOrderMode.RING + if enable_3d_tiling: assignment = [] inner_ring = _ring_3d(computation_shape[0], computation_shape[1],