Add an option to assign devices without forming a ring

PiperOrigin-RevId: 331193260
Change-Id: I1f0fc9a35e0949dfb2e7c5cd4062adb240a8801b
This commit is contained in:
Yuanzhong Xu 2020-09-11 11:56:15 -07:00 committed by TensorFlower Gardener
parent 6d56f6a3d2
commit 705f81d42a

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import enum
import math
import numpy as np
@ -313,10 +314,22 @@ def _ring_3d(x_size, y_size, z_size):
return ret
class DeviceOrderMode(enum.IntEnum):
  """The way of determining device orders when computing device assignment.

  Passed as the `device_order_mode` argument of `device_assignment`, whose
  default is `AUTO`.
  """

  # Default mode: the library decides, and forms rings when that is possible.
  AUTO = 0
  # Form rings for replicas and model-parallel cores.
  RING = 1
  # Form meshes for replicas and/or model-parallel cores.
  MESH = 2
def device_assignment(topology,
computation_shape=None,
computation_stride=None,
num_replicas=1):
num_replicas=1,
device_order_mode=DeviceOrderMode.AUTO):
"""Computes a device_assignment of a computation across a TPU topology.
Attempts to choose a compact grid of cores for locality.
@ -341,6 +354,9 @@ def device_assignment(topology,
TPU topology. If None, the `computation_stride` is `[1] * topology_rank`.
num_replicas: The number of computation replicas to run. The replicas will
be packed into the free spaces of the topology.
device_order_mode: A `DeviceOrderMode` enum value which indicates whether
to assign devices to form rings or meshes, or to let the library
choose.
Returns:
A DeviceAssignment object, which describes the mapping between the logical
@ -450,6 +466,12 @@ def device_assignment(topology,
computation_shape[-1] == 2 # Only handle 3D case.
and np.prod(computation_stride) == 1 # Ensure no stride.
and num_replicas == max_replicas) # Full replication.
if device_order_mode != DeviceOrderMode.AUTO:
if device_order_mode == DeviceOrderMode.RING and not enable_3d_tiling:
raise ValueError("cannot assign ring order in the given topology")
enable_3d_tiling = device_order_mode == DeviceOrderMode.RING
if enable_3d_tiling:
assignment = []
inner_ring = _ring_3d(computation_shape[0], computation_shape[1],