Remove computation_shape from tpu_config.py comments.

PiperOrigin-RevId: 209673552
This commit is contained in:
Ruoxin Sang 2018-08-21 15:45:12 -07:00 committed by TensorFlower Gardener
parent daf992961e
commit 7989b2bc99

View File

@ -65,7 +65,7 @@ class TPUConfig(
The number of model replicas in the system. For non-model-parallelism The number of model replicas in the system. For non-model-parallelism
case, this number equals the total number of TPU cores. For case, this number equals the total number of TPU cores. For
model-parallelism, the total number of TPU cores equals model-parallelism, the total number of TPU cores equals
product(computation_shape) * num_shards. num_cores_per_replica * num_shards.
num_cores_per_replica: Defaults to `None`, which disables model parallelism. num_cores_per_replica: Defaults to `None`, which disables model parallelism.
An integer which describes the number of TPU cores per model replica. This An integer which describes the number of TPU cores per model replica. This
is required by model-parallelism which enables partitioning is required by model-parallelism which enables partitioning
@ -103,7 +103,7 @@ class TPUConfig(
input mode. input mode.
Raises: Raises:
ValueError: If `computation_shape` or `computation_shape` are invalid. ValueError: If `num_cores_per_replica` is not 1, 2, 4 or 8.
""" """
def __new__(cls, def __new__(cls,
@ -137,7 +137,7 @@ class TPUConfig(
raise ValueError( raise ValueError(
'input_partition_dims requires setting num_cores_per_replica.') 'input_partition_dims requires setting num_cores_per_replica.')
# Parse computation_shape # Check num_cores_per_replica
if num_cores_per_replica is not None: if num_cores_per_replica is not None:
if num_cores_per_replica not in [1, 2, 4, 8]: if num_cores_per_replica not in [1, 2, 4, 8]:
raise ValueError( raise ValueError(