tf.contrib.AsyncCheckpointSaverHook creates a thread to run the export asynchronously, which currently breaks when exporting a TPU graph. Making TpuContext thread-local keeps the training thread's sharding state from leaking into that export thread.

PiperOrigin-RevId: 258184145

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Helper library for functions used during TPU compilation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import threading


class TpuContext(threading.local):
  """A context object holding state about the TPU computation being built.

  Subclassing `threading.local` makes the state per-thread, so a computation
  being built on one thread is not visible to other threads (e.g. the thread
  AsyncCheckpointSaverHook spawns to run an export).
  """

  def __init__(self):
    """Creates a new TpuContext."""
    self._number_of_shards = None

  @property
  def number_of_shards(self):
    """The shard count of the computation being built, or None if none is."""
    return self._number_of_shards

  def set_number_of_shards(self, number_of_shards):
    self._number_of_shards = number_of_shards


# The TPU context holds the number of shards when a sharded computation is
# being built, or None if no computation is being built.
_current_tpu_context = TpuContext()
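

# A minimal sketch of the thread-local behavior (illustrative only, not part
# of the library): a shard count set on one thread is not observed by another
# thread, because each thread gets its own TpuContext state.
#
#   _current_tpu_context.set_number_of_shards(2)
#   seen = []
#   t = threading.Thread(
#       target=lambda: seen.append(_current_tpu_context.number_of_shards))
#   t.start()
#   t.join()
#   assert seen == [None]  # The new thread sees a fresh, empty context.
#   _current_tpu_context.set_number_of_shards(None)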


@contextlib.contextmanager
def tpu_shard_context(number_of_shards):
  """Sets the shard count of the current thread's TPU computation."""
  if _current_tpu_context.number_of_shards is not None:
    raise NotImplementedError("tpu_shard_context cannot be nested.")
  try:
    _current_tpu_context.set_number_of_shards(number_of_shards)
    yield
  finally:
    _current_tpu_context.set_number_of_shards(None)


def get_tpu_context():
  """Returns the calling thread's TpuContext."""
  return _current_tpu_context
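

# A minimal usage sketch (illustrative only): inside the `with` block the
# shard count is visible through get_tpu_context(); on exit it is reset.
#
#   with tpu_shard_context(8):
#     assert get_tpu_context().number_of_shards == 8
#   assert get_tpu_context().number_of_shards is None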


# Decorator for a TPU computation func passed to tpu.rewrite(). If the func
# contains an embedded training loop, trace tools will generate step markers
# for each iteration of that loop.
def on_device_training_loop(func):
  # The value for this attribute comes from xla.DebugOptions.StepMarkerLocation.
  setattr(func, "step_marker_location", "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP")
  return func
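

# A minimal usage sketch (illustrative only; `computation` is a hypothetical
# name): decorate the function before passing it to tpu.rewrite() so the
# profiler emits a step marker per iteration of the on-device loop.
#
#   @on_device_training_loop
#   def computation():
#     ...  # body containing an on-device training loop
#
#   tpu.rewrite(computation)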