Remove computation_shape
from tpu_config.py comments.
PiperOrigin-RevId: 209673552
This commit is contained in:
parent
daf992961e
commit
7989b2bc99
@ -65,7 +65,7 @@ class TPUConfig(
|
||||
The number of model replicas in the system. For non-model-parallelism
|
||||
case, this number equals the total number of TPU cores. For
|
||||
model-parallelism, the total number of TPU cores equals
|
||||
product(computation_shape) * num_shards.
|
||||
num_cores_per_replica * num_shards.
|
||||
num_cores_per_replica: Defaults to `None`, which disables model parallelism.
|
||||
An integer which describes the number of TPU cores per model replica. This
|
||||
is required by model-parallelism which enables partitioning
|
||||
@ -103,7 +103,7 @@ class TPUConfig(
|
||||
input mode.
|
||||
|
||||
Raises:
|
||||
ValueError: If `computation_shape` or `computation_shape` are invalid.
|
||||
ValueError: If `num_cores_per_replica` is not 1, 2, 4 or 8.
|
||||
"""
|
||||
|
||||
def __new__(cls,
|
||||
@ -137,7 +137,7 @@ class TPUConfig(
|
||||
raise ValueError(
|
||||
'input_partition_dims requires setting num_cores_per_replica.')
|
||||
|
||||
# Parse computation_shape
|
||||
# Check num_cores_per_replica
|
||||
if num_cores_per_replica is not None:
|
||||
if num_cores_per_replica not in [1, 2, 4, 8]:
|
||||
raise ValueError(
|
||||
|
Loading…
Reference in New Issue
Block a user