Remove num_cores from the TPUStrategy constructor, as that argument is neither supported nor correct.
PiperOrigin-RevId: 226396428
This commit is contained in:
parent 0c31fca446
commit 795f9cbf91
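With num_cores gone, the core count and topology are always auto-detected from the TPU system metadata. Below is a minimal usage sketch of the new constructor signature; it assumes the TF 1.x contrib module paths this file shipped under at the time (tf.contrib.distribute and tf.contrib.cluster_resolver) and uses a placeholder TPU name.

    import tensorflow as tf

    # "my-tpu" is a placeholder; point this at a real TPU name or address.
    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu="my-tpu")

    # The constructor no longer accepts num_cores; cores and topology are
    # auto-detected from the TPU system metadata instead.
    strategy = tf.contrib.distribute.TPUStrategy(
        tpu_cluster_resolver=resolver,
        steps_per_run=100)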
@@ -126,8 +126,7 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
   def __init__(self,
                tpu_cluster_resolver=None,
-               steps_per_run=None,
-               num_cores=None):
+               steps_per_run=None):
     """Initializes the TPUStrategy object.
 
     Args:
@@ -138,11 +137,9 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
         metrics, summaries etc.
         This parameter is only used when Distribution Strategy is used with
         estimator or keras.
-      num_cores: Number of cores to use on the TPU. If None specified, then
-        auto-detect the cores and topology of the TPU system.
     """
     super(TPUStrategy, self).__init__(TPUExtended(
-        self, tpu_cluster_resolver, steps_per_run, num_cores))
+        self, tpu_cluster_resolver, steps_per_run))
 
   @property
   def steps_per_run(self):
@@ -161,8 +158,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
   def __init__(self,
                container_strategy,
                tpu_cluster_resolver=None,
-               steps_per_run=None,
-               num_cores=None):
+               steps_per_run=None):
     super(TPUExtended, self).__init__(container_strategy)
 
     if tpu_cluster_resolver is None:
@@ -175,8 +171,6 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
 
     self._tpu_cluster_resolver = tpu_cluster_resolver
     self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
-    # TODO(sourabhbajaj): Change this from num_cores to metadata_override
-    self._num_cores_override = num_cores
 
     # TODO(jhseu): Switch to DeviceAssignment to support pods and model
     # parallelism.
@@ -570,7 +564,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
 
   @property
   def _num_replicas_in_sync(self):
-    return self._num_cores_override or self._tpu_metadata.num_cores
+    return self._tpu_metadata.num_cores
 
   @property
   def experimental_between_graph(self):
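Since the _num_cores_override fallback is gone, the replica count the strategy reports now always mirrors the auto-detected TPU topology. A short illustrative check, continuing the hypothetical strategy object from the sketch above and assuming the public num_replicas_in_sync property that delegates to the _num_replicas_in_sync shown in the last hunk:

    # Reports the metadata-derived core count (e.g. 8 on a single Cloud TPU
    # device); exact values depend on the attached TPU topology.
    print("Replicas in sync:", strategy.num_replicas_in_sync)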