Remove computation_shape from tpu_config.py comments.

PiperOrigin-RevId: 209673552
This commit is contained in:
Ruoxin Sang 2018-08-21 15:45:12 -07:00 committed by TensorFlower Gardener
parent daf992961e
commit 7989b2bc99

View File

@ -65,7 +65,7 @@ class TPUConfig(
The number of model replicas in the system. For non-model-parallelism
case, this number equals the total number of TPU cores. For
model-parallelism, the total number of TPU cores equals
product(computation_shape) * num_shards.
num_cores_per_replica * num_shards.
num_cores_per_replica: Defaults to `None`, which disables model parallelism.
An integer which describes the number of TPU cores per model replica. This
is required by model-parallelism which enables partitioning
@ -103,7 +103,7 @@ class TPUConfig(
input mode.
Raises:
ValueError: If `computation_shape` or `computation_shape` are invalid.
ValueError: If `num_cores_per_replica` is not 1, 2, 4 or 8.
"""
def __new__(cls,
@ -137,7 +137,7 @@ class TPUConfig(
raise ValueError(
'input_partition_dims requires setting num_cores_per_replica.')
# Parse computation_shape
# Check num_cores_per_replica
if num_cores_per_replica is not None:
if num_cores_per_replica not in [1, 2, 4, 8]:
raise ValueError(