Merge pull request #40808 from geetachavan1/cherrypicks_X2P52

[CherryPick:r2.3] Allow tf.distribute.TPUStrategy to be used with TPUEmbedding API and ensure that LossScaleOptimizer properly rejects it.
This commit is contained in:
Goldie Gadde 2020-06-25 21:39:08 -07:00 committed by GitHub
commit 9cc469ac21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 2 deletions

View File

@ -440,7 +440,8 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
if not strategy_supports_loss_scaling(): if not strategy_supports_loss_scaling():
strategy = distribution_strategy_context.get_strategy() strategy = distribution_strategy_context.get_strategy()
if isinstance(strategy, if isinstance(strategy,
(tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)): (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
tpu_strategy.TPUStrategyV2)):
raise ValueError( raise ValueError(
'Loss scaling is not supported with TPUStrategy. Loss scaling is ' 'Loss scaling is not supported with TPUStrategy. Loss scaling is '
'unnecessary with TPUs, since they support bfloat16 instead of ' 'unnecessary with TPUs, since they support bfloat16 instead of '

View File

@ -265,7 +265,8 @@ class TPUEmbedding(tracking.AutoTrackable):
Adam or Adagrad). Adam or Adagrad).
""" """
self._strategy = distribution_strategy_context.get_strategy() self._strategy = distribution_strategy_context.get_strategy()
self._using_tpu = isinstance(self._strategy, tpu_strategy.TPUStrategy) self._using_tpu = isinstance(self._strategy, (tpu_strategy.TPUStrategy,
tpu_strategy.TPUStrategyV2))
self._pipeline_execution_with_tensor_core = ( self._pipeline_execution_with_tensor_core = (
pipeline_execution_with_tensor_core) pipeline_execution_with_tensor_core)