Merge pull request #40808 from geetachavan1/cherrypicks_X2P52
[CherryPick:r2.3] Allow tf.distribute.TPUStrategy to be used with TPUEmbedding API and ensure that LossScaleOptimizer properly rejects it.
commit 9cc469ac21
@@ -440,7 +440,8 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
     if not strategy_supports_loss_scaling():
       strategy = distribution_strategy_context.get_strategy()
       if isinstance(strategy,
-                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
+                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
+                     tpu_strategy.TPUStrategyV2)):
         raise ValueError(
             'Loss scaling is not supported with TPUStrategy. Loss scaling is '
             'unnecessary with TPUs, since they support bfloat16 instead of '
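For context, a minimal usage sketch of what the hunk above enforces. It assumes a reachable TPU worker and uses the public TF 2.3-era symbols (tf.distribute.TPUStrategy for the V2 strategy, tf.keras.mixed_precision.experimental.LossScaleOptimizer); depending on where this guard is invoked, the ValueError may surface at construction or on first use:

    import tensorflow as tf

    # Assumes a reachable TPU worker; tpu='' resolves a local/Colab TPU.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)  # TPUStrategyV2 internally

    with strategy.scope():
      try:
        # With this change the V2 strategy is rejected like V1/legacy
        # TPUStrategy, hitting the ValueError from the hunk above.
        opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
            tf.keras.optimizers.SGD(), loss_scale='dynamic')
      except ValueError as e:
        print(e)  # 'Loss scaling is not supported with TPUStrategy. ...'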
@@ -265,7 +265,8 @@ class TPUEmbedding(tracking.AutoTrackable):
       Adam or Adagrad).
     """
     self._strategy = distribution_strategy_context.get_strategy()
-    self._using_tpu = isinstance(self._strategy, tpu_strategy.TPUStrategy)
+    self._using_tpu = isinstance(self._strategy, (tpu_strategy.TPUStrategy,
+                                                  tpu_strategy.TPUStrategyV2))
     self._pipeline_execution_with_tensor_core = (
         pipeline_execution_with_tensor_core)
 
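Similarly, a sketch of the embedding side this cherry-pick unblocks. The table and feature configuration are illustrative, and the tf.tpu.experimental.embedding export is an assumption based on the tpu_embedding_v2 API of this release line (whose constructor still took a global batch_size); `strategy` is the TPUStrategy from the sketch above:

    table = tf.tpu.experimental.embedding.TableConfig(
        vocabulary_size=1024, dim=16, name='video')  # illustrative sizes
    feature_config = {
        'watched': tf.tpu.experimental.embedding.FeatureConfig(table=table)}

    with strategy.scope():
      # Before this change, isinstance(strategy, TPUStrategy) was False for
      # the V2 strategy, so _using_tpu was False and the CPU fallback ran.
      embedding = tf.tpu.experimental.embedding.TPUEmbedding(
          feature_config=feature_config,
          batch_size=128,  # global batch size; an __init__ arg in this release
          optimizer=tf.tpu.experimental.embedding.Adagrad(learning_rate=0.1))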