From b878ae39183019c6ed0dfa35338c04d2b5adf145 Mon Sep 17 00:00:00 2001 From: Bruce Fontaine Date: Wed, 24 Jun 2020 19:00:55 -0700 Subject: [PATCH] Allow tf.distribute.TPUStrategy to be used with TPUEmbedding API and ensure that LossScaleOptimizer properly rejects it. PiperOrigin-RevId: 318186211 Change-Id: Id3b9cb8288e5d28ddbaec97d5b35627ab35bc08d --- .../keras/mixed_precision/experimental/loss_scale_optimizer.py | 3 ++- tensorflow/python/tpu/tpu_embedding_v2.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py index 72d9d0390fd..d7de7c41208 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py +++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py @@ -440,7 +440,8 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): if not strategy_supports_loss_scaling(): strategy = distribution_strategy_context.get_strategy() if isinstance(strategy, - (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)): + (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1, + tpu_strategy.TPUStrategyV2)): raise ValueError( 'Loss scaling is not supported with TPUStrategy. Loss scaling is ' 'unnecessary with TPUs, since they support bfloat16 instead of ' diff --git a/tensorflow/python/tpu/tpu_embedding_v2.py b/tensorflow/python/tpu/tpu_embedding_v2.py index 6db67fea367..eea2dea53c2 100644 --- a/tensorflow/python/tpu/tpu_embedding_v2.py +++ b/tensorflow/python/tpu/tpu_embedding_v2.py @@ -265,7 +265,8 @@ class TPUEmbedding(tracking.AutoTrackable): Adam or Adagrad). """ self._strategy = distribution_strategy_context.get_strategy() - self._using_tpu = isinstance(self._strategy, tpu_strategy.TPUStrategy) + self._using_tpu = isinstance(self._strategy, (tpu_strategy.TPUStrategy, + tpu_strategy.TPUStrategyV2)) self._pipeline_execution_with_tensor_core = ( pipeline_execution_with_tensor_core)