diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py
index f86a7869ec4..43c0d1a40c3 100644
--- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py
+++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py
@@ -53,10 +53,10 @@ class TpuBackend(xla_client.Backend):
     if worker == 'local' or 'local://' in worker:
       # We usually want to cache for local backends to prevent double
       # initialization, except where `force` == True.
-      if force:
-        return TpuBackend(_tpu_client.TpuClient.Get(worker))
       if worker == 'local':
         worker = 'local://'
+      if force:
+        return TpuBackend(_tpu_client.TpuClient.Get(worker))
       if TpuBackend._local_backend is None:
         logging.info('Starting the local TPU driver.')
         TpuBackend._local_backend = TpuBackend(
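
A minimal sketch of the control-flow change, assuming the hunk sits inside TpuBackend's factory path that takes a `worker` string and a `force` flag. It is not the actual tpu_client.py; the helper name `resolve_worker` and the returned tuples are illustrative only. The point it demonstrates: by moving the `force` check below the normalization step, a forced client is constructed with the canonical 'local://' address rather than the raw shorthand 'local'.

def resolve_worker(worker, force):
  # Hypothetical stand-in for TpuBackend's worker handling; returns a
  # (path-taken, worker-string) pair instead of building a real client.
  if worker == 'local' or 'local://' in worker:
    # Normalize the shorthand before any client is constructed.
    if worker == 'local':
      worker = 'local://'
    # After the patch, the forced path sees the normalized address,
    # matching what the cached (non-forced) path would use.
    if force:
      return ('forced-new-client', worker)
    return ('cached-client', worker)
  return ('remote-client', worker)

# With the old ordering, the forced path would have received 'local';
# with the reordered code it receives 'local://'.
assert resolve_worker('local', force=True) == ('forced-new-client', 'local://')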