Add an option to assign devices without forming a ring
PiperOrigin-RevId: 331193260 Change-Id: I1f0fc9a35e0949dfb2e7c5cd4062adb240a8801b
This commit is contained in:
parent
6d56f6a3d2
commit
705f81d42a
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import enum
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
@ -313,10 +314,22 @@ def _ring_3d(x_size, y_size, z_size):
|
||||
return ret
|
||||
|
||||
|
||||
class DeviceOrderMode(enum.IntEnum):
  """Selects how devices are ordered when a device assignment is computed.

  Pass one of these members as the `device_order_mode` argument of
  `device_assignment`.
  """

  # Default mode: the library decides, and forms rings whenever that is
  # possible.
  AUTO = 0

  # Order replicas and model-parallel cores so that they form rings.
  RING = 1

  # Order replicas and/or model-parallel cores so that they form meshes.
  MESH = 2
|
||||
|
||||
|
||||
def device_assignment(topology,
|
||||
computation_shape=None,
|
||||
computation_stride=None,
|
||||
num_replicas=1):
|
||||
num_replicas=1,
|
||||
device_order_mode=DeviceOrderMode.AUTO):
|
||||
"""Computes a device_assignment of a computation across a TPU topology.
|
||||
|
||||
Attempts to choose a compact grid of cores for locality.
|
||||
@ -341,6 +354,9 @@ def device_assignment(topology,
|
||||
TPU topology. If None, the `computation_stride` is `[1] * topology_rank`.
|
||||
num_replicas: The number of computation replicas to run. The replicas will
|
||||
be packed into the free spaces of the topology.
|
||||
device_order_mode: An enum of `DeviceOrderMode` class which indicates
|
||||
whether to assign devices to form rings or meshes, or let the library
|
||||
choose.
|
||||
|
||||
Returns:
|
||||
A DeviceAssignment object, which describes the mapping between the logical
|
||||
@ -450,6 +466,12 @@ def device_assignment(topology,
|
||||
computation_shape[-1] == 2 # Only handle 3D case.
|
||||
and np.prod(computation_stride) == 1 # Ensure no stride.
|
||||
and num_replicas == max_replicas) # Full replication.
|
||||
|
||||
if device_order_mode != DeviceOrderMode.AUTO:
|
||||
if device_order_mode == DeviceOrderMode.RING and not enable_3d_tiling:
|
||||
raise ValueError("cannot assign ring order in the given topology")
|
||||
enable_3d_tiling = device_order_mode == DeviceOrderMode.RING
|
||||
|
||||
if enable_3d_tiling:
|
||||
assignment = []
|
||||
inner_ring = _ring_3d(computation_shape[0], computation_shape[1],
|
||||
|
Loading…
Reference in New Issue
Block a user