Apply tf1->tf2 name replacements to docstrings and comments in TensorFlow.

No code changes; only docstrings and comments are affected.

PiperOrigin-RevId: 244275767
Mark Daoust 2019-04-18 15:56:07 -07:00 committed by TensorFlower Gardener
parent 7fdf27b688
commit 18b680216e
22 changed files with 851 additions and 763 deletions
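
For context before the per-file hunks, the substitution pattern applied throughout this commit looks like the sketch below. The mapping is a small sample drawn from the docstrings in the diff (not the complete list of renamed symbols), and `apply_renames` is only an illustrative helper, not the tool that produced the commit.

```python
# Illustrative sketch of the TF1 -> tf.compat.v1 docstring renames in this commit.
# Sample mapping only; the full set of replacements is visible in the hunks below.
TF1_TO_COMPAT_V1 = {
    "tf.Session": "tf.compat.v1.Session",
    "tf.train.Supervisor": "tf.compat.v1.train.Supervisor",
    "tf.logging.set_verbosity": "tf.compat.v1.logging.set_verbosity",
    "tf.summary.scalar": "tf.compat.v1.summary.scalar",
    "tf.global_variables_initializer": "tf.compat.v1.global_variables_initializer",
    "tf.variable_scope": "tf.compat.v1.variable_scope",
    "tf.get_variable": "tf.compat.v1.get_variable",
}


def apply_renames(text):
    """Returns `text` with TF1 symbol names replaced by their compat.v1 forms."""
    for old, new in TF1_TO_COMPAT_V1.items():
        text = text.replace(old, new)
    return text


# Example, using a docstring line from basic_train_loop below:
print(apply_renames("It is passed a `tf.Session` in addition to `args` and `kwargs`."))
# -> It is passed a `tf.compat.v1.Session` in addition to `args` and `kwargs`.
```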


@@ -22,8 +22,11 @@ from tensorflow.python.util.tf_export import tf_export
 @tf_export(v1=["train.basic_train_loop"])
-def basic_train_loop(supervisor, train_step_fn, args=None,
-kwargs=None, master=""):
+def basic_train_loop(supervisor,
+train_step_fn,
+args=None,
+kwargs=None,
+master=""):
 """Basic loop to train a model.
 Calls `train_step_fn` in a loop to train a model. The function is called as:
@@ -32,17 +35,18 @@ def basic_train_loop(supervisor, train_step_fn, args=None,
 train_step_fn(session, *args, **kwargs)
 ```
-It is passed a `tf.Session` in addition to `args` and `kwargs`. The function
+It is passed a `tf.compat.v1.Session` in addition to `args` and `kwargs`. The
+function
 typically runs one training step in the session.
 Args:
-supervisor: `tf.train.Supervisor` to run the training services.
-train_step_fn: Callable to execute one training step. Called
-repeatedly as `train_step_fn(session, *args **kwargs)`.
+supervisor: `tf.compat.v1.train.Supervisor` to run the training services.
+train_step_fn: Callable to execute one training step. Called repeatedly as
+`train_step_fn(session, *args **kwargs)`.
 args: Optional positional arguments passed to `train_step_fn`.
 kwargs: Optional keyword arguments passed to `train_step_fn`.
-master: Master to use to create the training session. Defaults to
-`""` which causes the session to be created in the local process.
+master: Master to use to create the training session. Defaults to `""`
+which causes the session to be created in the local process.
 """
 if args is None:
 args = []


@@ -42,7 +42,6 @@ from tensorflow.python.training.session_run_hook import SessionRunArgs
 from tensorflow.python.training.summary_io import SummaryWriterCache
 from tensorflow.python.util.tf_export import tf_export
 _HOOKS = "hooks"
 _STEPS_PER_RUN_VAR = "steps_per_run"
@@ -85,8 +84,7 @@ class _HookTimer(object):
 @tf_export(v1=["train.SecondOrStepTimer"])
 class SecondOrStepTimer(_HookTimer):
-"""Timer that triggers at most once every N seconds or once every N steps.
-"""
+"""Timer that triggers at most once every N seconds or once every N steps."""
 def __init__(self, every_secs=None, every_steps=None):
 self.reset()
@@ -171,29 +169,33 @@ class LoggingTensorHook(session_run_hook.SessionRunHook):
 seeing the logs, you might want to add the following line after your imports:
 ```python
-tf.logging.set_verbosity(tf.logging.INFO)
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
 ```
 Note that if `at_end` is True, `tensors` should not include any tensor
 whose evaluation produces a side effect such as consuming additional inputs.
 """
-def __init__(self, tensors, every_n_iter=None, every_n_secs=None,
-at_end=False, formatter=None):
+def __init__(self,
+tensors,
+every_n_iter=None,
+every_n_secs=None,
+at_end=False,
+formatter=None):
 """Initializes a `LoggingTensorHook`.
 Args:
-tensors: `dict` that maps string-valued tags to tensors/tensor names,
-or `iterable` of tensors/tensor names.
+tensors: `dict` that maps string-valued tags to tensors/tensor names, or
+`iterable` of tensors/tensor names.
 every_n_iter: `int`, print the values of `tensors` once every N local
 steps taken on the current worker.
 every_n_secs: `int` or `float`, print the values of `tensors` once every N
 seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
 provided.
 at_end: `bool` specifying whether to print the values of `tensors` at the
 end of the run.
 formatter: function, takes dict of `tag`->`Tensor` and returns a string.
 If `None` uses default printing all tensors.
 Raises:
 ValueError: if `every_n_iter` is non-positive.
@@ -215,16 +217,18 @@ class LoggingTensorHook(session_run_hook.SessionRunHook):
 self._tensors = tensors
 self._formatter = formatter
 self._timer = (
-NeverTriggerTimer() if only_log_at_end else
-SecondOrStepTimer(every_secs=every_n_secs, every_steps=every_n_iter))
+NeverTriggerTimer() if only_log_at_end else SecondOrStepTimer(
+every_secs=every_n_secs, every_steps=every_n_iter))
 self._log_at_end = at_end
 def begin(self):
 self._timer.reset()
 self._iter_count = 0
 # Convert names to tensors if given
-self._current_tensors = {tag: _as_graph_element(tensor)
-for (tag, tensor) in self._tensors.items()}
+self._current_tensors = {
+tag: _as_graph_element(tensor)
+for (tag, tensor) in self._tensors.items()
+}
 def before_run(self, run_context):  # pylint: disable=unused-argument
 self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
@@ -463,9 +467,10 @@ class CheckpointSaverListener(object):
 ...
 listener = ExampleCheckpointSaverListener()
-saver_hook = tf.train.CheckpointSaverHook(
+saver_hook = tf.estimator.CheckpointSaverHook(
 checkpoint_dir, listeners=[listener])
-with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
+with
+tf.compat.v1.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
 ...
 ```
@@ -516,9 +521,9 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
 saver: `Saver` object, used for saving.
 checkpoint_basename: `str`, base name for the checkpoint files.
 scaffold: `Scaffold`, use to get saver object.
-listeners: List of `CheckpointSaverListener` subclass instances.
-Used for callbacks that run immediately before or after this hook saves
-the checkpoint.
+listeners: List of `CheckpointSaverListener` subclass instances. Used for
+callbacks that run immediately before or after this hook saves the
+checkpoint.
 Raises:
 ValueError: One of `save_steps` or `save_secs` should be set.
@@ -531,8 +536,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
 self._checkpoint_dir = checkpoint_dir
 self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
 self._scaffold = scaffold
-self._timer = SecondOrStepTimer(every_secs=save_secs,
-every_steps=save_steps)
+self._timer = SecondOrStepTimer(
+every_secs=save_secs, every_steps=save_steps)
 self._listeners = listeners or []
 self._steps_per_run = 1
@@ -555,13 +560,11 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
 # add variables in begin. Graph is finalized after all begin calls.
 training_util.write_graph(
 ops.get_default_graph().as_graph_def(add_shapes=True),
-self._checkpoint_dir,
-"graph.pbtxt")
+self._checkpoint_dir, "graph.pbtxt")
 saver_def = self._get_saver().saver_def if self._get_saver() else None
 graph = ops.get_default_graph()
 meta_graph_def = meta_graph.create_meta_graph_def(
-graph_def=graph.as_graph_def(add_shapes=True),
-saver_def=saver_def)
+graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
 self._summary_writer.add_graph(graph)
 self._summary_writer.add_meta_graph(meta_graph_def)
 # The checkpoint saved here is the state at step "global_step".
@@ -573,8 +576,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
 def after_run(self, run_context, run_values):
 stale_global_step = run_values.results
-if self._timer.should_trigger_for_step(
-stale_global_step + self._steps_per_run):
+if self._timer.should_trigger_for_step(stale_global_step +
+self._steps_per_run):
 # get the real value after train op.
 global_step = run_context.session.run(self._global_step_tensor)
 if self._timer.should_trigger_for_step(global_step):
@@ -627,8 +630,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
 elif len(savers) > 1:
 raise RuntimeError(
 "More than one item in collection {}. "
-"Please indicate which one to use by passing it to the constructor.".
-format(collection_key))
+"Please indicate which one to use by passing it to the constructor."
+.format(collection_key))
 self._saver = savers[0]
 return savers[0]
@@ -647,8 +650,8 @@ class StepCounterHook(session_run_hook.SessionRunHook):
 if (every_n_steps is None) == (every_n_secs is None):
 raise ValueError(
 "exactly one of every_n_steps and every_n_secs should be provided.")
-self._timer = SecondOrStepTimer(every_steps=every_n_steps,
-every_secs=every_n_secs)
+self._timer = SecondOrStepTimer(
+every_steps=every_n_steps, every_secs=every_n_secs)
 self._summary_writer = summary_writer
 self._output_dir = output_dir
@@ -673,8 +676,9 @@ class StepCounterHook(session_run_hook.SessionRunHook):
 def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
 steps_per_sec = elapsed_steps / elapsed_time
 if self._summary_writer is not None:
-summary = Summary(value=[Summary.Value(
-tag=self._summary_tag, simple_value=steps_per_sec)])
+summary = Summary(value=[
+Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
+])
 self._summary_writer.add_summary(summary, global_step)
 logging.info("%s: %g", self._summary_tag, steps_per_sec)
@@ -682,8 +686,8 @@ class StepCounterHook(session_run_hook.SessionRunHook):
 _ = run_context
 stale_global_step = run_values.results
-if self._timer.should_trigger_for_step(
-stale_global_step + self._steps_per_run):
+if self._timer.should_trigger_for_step(stale_global_step +
+self._steps_per_run):
 # get the real value after train op.
 global_step = run_context.session.run(self._global_step_tensor)
 if self._timer.should_trigger_for_step(global_step):
@@ -767,18 +771,18 @@ class SummarySaverHook(session_run_hook.SessionRunHook):
 Args:
 save_steps: `int`, save summaries every N steps. Exactly one of
 `save_secs` and `save_steps` should be set.
 save_secs: `int`, save summaries every N seconds.
-output_dir: `string`, the directory to save the summaries to. Only used
-if no `summary_writer` is supplied.
+output_dir: `string`, the directory to save the summaries to. Only used if
+no `summary_writer` is supplied.
 summary_writer: `SummaryWriter`. If `None` and an `output_dir` was passed,
 one will be created accordingly.
 scaffold: `Scaffold` to get summary_op if it's not provided.
 summary_op: `Tensor` of type `string` containing the serialized `Summary`
-protocol buffer or a list of `Tensor`. They are most likely an output
-by TF summary methods like `tf.summary.scalar` or
-`tf.summary.merge_all`. It can be passed in as one tensor; if more
-than one, they must be passed in as a list.
+protocol buffer or a list of `Tensor`. They are most likely an output by
+TF summary methods like `tf.compat.v1.summary.scalar` or
+`tf.compat.v1.summary.merge_all`. It can be passed in as one tensor; if
+more than one, they must be passed in as a list.
 Raises:
 ValueError: Exactly one of scaffold or summary_op should be set.
@@ -791,8 +795,8 @@ class SummarySaverHook(session_run_hook.SessionRunHook):
 self._summary_writer = summary_writer
 self._output_dir = output_dir
 self._scaffold = scaffold
-self._timer = SecondOrStepTimer(every_secs=save_secs,
-every_steps=save_steps)
+self._timer = SecondOrStepTimer(
+every_secs=save_secs, every_steps=save_steps)
 # TODO(mdan): Throw an error if output_dir and summary_writer are None.
 def begin(self):
@@ -903,8 +907,9 @@ class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
 self._worker_is_started = True
 return None
 if current_step - last_logged_step > 1000:
-logging.info("Waiting for global step %d before starting training. "
-"Current step is %d.", self._wait_until_step, current_step)
+logging.info(
+"Waiting for global step %d before starting training. "
+"Current step is %d.", self._wait_until_step, current_step)
 last_logged_step = current_step
 time.sleep(0.5)
@@ -917,8 +922,8 @@ class FinalOpsHook(session_run_hook.SessionRunHook):
 """Initializes `FinalOpHook` with ops to run at the end of the session.
 Args:
-final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of
-names to `Tensors`.
+final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
+to `Tensors`.
 final_ops_feed_dict: A feed dictionary to use when running
 `final_ops_dict`.
 """
@@ -997,14 +1002,14 @@ class ProfilerHook(session_run_hook.SessionRunHook):
 Args:
 save_steps: `int`, save profile traces every N steps. Exactly one of
 `save_secs` and `save_steps` should be set.
 save_secs: `int` or `float`, save profile traces every N seconds.
 output_dir: `string`, the directory to save the profile traces to.
 Defaults to the current directory.
 show_dataflow: `bool`, if True, add flow events to the trace connecting
 producers and consumers of tensors.
 show_memory: `bool`, if True, add object snapshot events to the trace
 showing the sizes and lifetimes of tensors.
 """
 self._output_file = os.path.join(output_dir, "timeline-{}.json")
 self._file_writer = SummaryWriterCache.get(output_dir)
@@ -1024,8 +1029,9 @@ class ProfilerHook(session_run_hook.SessionRunHook):
 self._next_step is not None and
 self._timer.should_trigger_for_step(self._next_step))
 requests = {"global_step": self._global_step_tensor}
-opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
-if self._request_summary else None)
+opts = (
+config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
+if self._request_summary else None)
 return SessionRunArgs(requests, options=opts)
@@ -1039,8 +1045,7 @@ class ProfilerHook(session_run_hook.SessionRunHook):
 if self._request_summary:
 global_step = run_context.session.run(self._global_step_tensor)
 self._timer.update_last_triggered_step(global_step)
-self._save(global_step,
-self._output_file.format(global_step),
+self._save(global_step, self._output_file.format(global_step),
 run_values.run_metadata.step_stats)
 self._file_writer.add_run_metadata(run_values.run_metadata,
 "step_%d" % global_step)


@@ -106,7 +106,7 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
 """Replaces `tf.Variable` initializers so they load from a checkpoint file.
 Values are not loaded immediately, but when the initializer is run
-(typically by running a `tf.global_variables_initializer` op).
+(typically by running a `tf.compat.v1.global_variables_initializer` op).
 Note: This overrides default initialization ops of specified variables and
 redefines dtype.
@@ -139,15 +139,15 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
 # -- name='old_scope_2/var3', shape=[100, 100]
 # Create new model's variables
-with tf.variable_scope('new_scope_1'):
-var1 = tf.get_variable('var1', shape=[20, 2],
-initializer=tf.zeros_initializer())
-with tf.variable_scope('new_scope_2'):
-var2 = tf.get_variable('var2', shape=[50, 4],
-initializer=tf.zeros_initializer())
+with tf.compat.v1.variable_scope('new_scope_1'):
+var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
+initializer=tf.compat.v1.zeros_initializer())
+with tf.compat.v1.variable_scope('new_scope_2'):
+var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
+initializer=tf.compat.v1.zeros_initializer())
 # Partition into 5 variables along the first axis.
-var3 = tf.get_variable(name='var3', shape=[100, 100],
-initializer=tf.zeros_initializer(),
+var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
+initializer=tf.compat.v1.zeros_initializer(),
 partitioner=lambda shape, dtype: [5, 1])
 # Initialize all variables in `new_scope_1` from `old_scope_1`.


@@ -131,9 +131,13 @@ class _ReplicaDeviceChooser(object):
 @tf_export(v1=["train.replica_device_setter"])
-def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
-worker_device="/job:worker", merge_devices=True,
-cluster=None, ps_ops=None, ps_strategy=None):
+def replica_device_setter(ps_tasks=0,
+ps_device="/job:ps",
+worker_device="/job:worker",
+merge_devices=True,
+cluster=None,
+ps_ops=None,
+ps_strategy=None):
 """Return a `device function` to use when building a Graph for replicas.
 Device Functions are used in `with tf.device(device_function):` statement to
@@ -158,7 +162,8 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
 cluster_spec = {
 "ps": ["ps0:2222", "ps1:2222"],
 "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
-with tf.device(tf.train.replica_device_setter(cluster=cluster_spec)):
+with
+tf.device(tf.compat.v1.train.replica_device_setter(cluster=cluster_spec)):
 # Build your graph
 v1 = tf.Variable(...)  # assigned to /job:ps/task:0
 v2 = tf.Variable(...)  # assigned to /job:ps/task:1
@@ -218,6 +223,6 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
 ps_strategy = _RoundRobinStrategy(ps_tasks)
 if not six.callable(ps_strategy):
 raise TypeError("ps_strategy must be callable")
-chooser = _ReplicaDeviceChooser(
-ps_tasks, ps_device, worker_device, merge_devices, ps_ops, ps_strategy)
+chooser = _ReplicaDeviceChooser(ps_tasks, ps_device, worker_device,
+merge_devices, ps_ops, ps_strategy)
 return chooser.device_function


@@ -65,8 +65,8 @@ def _get_latest_eval_step_value(update_ops):
 """Gets the eval step `Tensor` value after running `update_ops`.
 Args:
-update_ops: A list of `Tensors` or a dictionary of names to `Tensors`,
-which are run before reading the eval step value.
+update_ops: A list of `Tensors` or a dictionary of names to `Tensors`, which
+are run before reading the eval step value.
 Returns:
 A `Tensor` representing the value for the evaluation step.
@@ -102,21 +102,20 @@ class _MultiStepStopAfterNEvalsHook(session_run_hook.SessionRunHook):
 def after_create_session(self, session, coord):
 # Update number of steps to run in the first run call
 if self._num_evals is None:
 steps = self._steps_per_run_initial_value
 else:
 steps = min(self._steps_per_run_initial_value, self._num_evals)
 self._steps_per_run_variable.load(steps, session=session)
 def before_run(self, run_context):
-return session_run_hook.SessionRunArgs({
-'evals_completed': self._evals_completed
-})
+return session_run_hook.SessionRunArgs(
+{'evals_completed': self._evals_completed})
 def after_run(self, run_context, run_values):
 evals_completed = run_values.results['evals_completed']
 # Update number of steps to run in the next iteration
 if self._num_evals is None:
 steps = self._steps_per_run_initial_value
 else:
 steps = min(self._num_evals - evals_completed,
@@ -147,16 +146,15 @@ class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
 self._evals_completed = None
 self._log_progress = log_progress
 # Reduce logging frequency if there are 20 or more evaluations.
-self._log_frequency = (1 if (num_evals is None or num_evals < 20)
-else math.floor(num_evals / 10.))
+self._log_frequency = (1 if (num_evals is None or num_evals < 20) else
+math.floor(num_evals / 10.))
 def _set_evals_completed_tensor(self, updated_eval_step):
 self._evals_completed = updated_eval_step
 def before_run(self, run_context):
-return session_run_hook.SessionRunArgs({
-'evals_completed': self._evals_completed
-})
+return session_run_hook.SessionRunArgs(
+{'evals_completed': self._evals_completed})
 def after_run(self, run_context, run_values):
 evals_completed = run_values.results['evals_completed']
@@ -205,20 +203,20 @@ def _evaluate_once(checkpoint_path,
 Args:
 checkpoint_path: The path to a checkpoint to use for evaluation.
 master: The BNS address of the TensorFlow master.
-scaffold: An tf.train.Scaffold instance for initializing variables and
-restoring variables. Note that `scaffold.init_fn` is used by the function
-to restore the checkpoint. If you supply a custom init_fn, then it must
-also take care of restoring the model from its checkpoint.
-eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
-to `Tensors`, which is run until the session is requested to stop,
-commonly done by a `tf.contrib.training.StopAfterNEvalsHook`.
+scaffold: An tf.compat.v1.train.Scaffold instance for initializing variables
+and restoring variables. Note that `scaffold.init_fn` is used by the
+function to restore the checkpoint. If you supply a custom init_fn, then
+it must also take care of restoring the model from its checkpoint.
+eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names to
+`Tensors`, which is run until the session is requested to stop, commonly
+done by a `tf.contrib.training.StopAfterNEvalsHook`.
 feed_dict: The feed dictionary to use when executing the `eval_ops`.
 final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
 to `Tensors`.
 final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
-hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
-evaluation loop.
-config: An instance of `tf.ConfigProto` that will be used to
+hooks: List of `tf.estimator.SessionRunHook` callbacks which are run inside
+the evaluation loop.
+config: An instance of `tf.compat.v1.ConfigProto` that will be used to
 configure the `Session`. If left as `None`, the default will be used.
 Returns:
@@ -263,8 +261,8 @@ def _evaluate_once(checkpoint_path,
 master=master,
 config=config)
-final_ops_hook = basic_session_run_hooks.FinalOpsHook(
-final_ops, final_ops_feed_dict)
+final_ops_hook = basic_session_run_hooks.FinalOpsHook(final_ops,
+final_ops_feed_dict)
 hooks.append(final_ops_hook)
 with monitored_session.MonitoredSession(


@@ -1090,7 +1090,7 @@ def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False,
 The `tensors_list` argument is a list of tuples of tensors, or a list of
 dictionaries of tensors. Each element in the list is treated similarly
-to the `tensors` argument of `tf.train.batch()`.
+to the `tensors` argument of `tf.compat.v1.train.batch()`.
 WARNING: This function is nondeterministic, since it starts a separate thread
 for each tensor.
@@ -1284,7 +1284,7 @@ def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
 ```python
 # Creates batches of 32 images and 32 labels.
-image_batch, label_batch = tf.train.shuffle_batch(
+image_batch, label_batch = tf.compat.v1.train.shuffle_batch(
 [single_image, single_label],
 batch_size=32,
 num_threads=4,
@@ -1425,7 +1425,7 @@ def shuffle_batch_join(tensors_list, batch_size, capacity,
 The `tensors_list` argument is a list of tuples of tensors, or a list of
 dictionaries of tensors. Each element in the list is treated similarly
-to the `tensors` argument of `tf.train.shuffle_batch()`.
+to the `tensors` argument of `tf.compat.v1.train.shuffle_batch()`.
 This version enqueues a different list of tensors in different threads.
 It adds the following to the current `Graph`:


@@ -56,24 +56,25 @@ def exponential_decay(learning_rate,
 ...
 global_step = tf.Variable(0, trainable=False)
 starter_learning_rate = 0.1
-learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
+learning_rate = tf.compat.v1.train.exponential_decay(starter_learning_rate,
+global_step,
 100000, 0.96, staircase=True)
 # Passing global_step to minimize() will increment it at each step.
 learning_step = (
-tf.train.GradientDescentOptimizer(learning_rate)
+tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
 .minimize(...my loss..., global_step=global_step)
 )
 ```
 Args:
-learning_rate: A scalar `float32` or `float64` `Tensor` or a
-Python number. The initial learning rate.
-global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
-Global step to use for the decay computation. Must not be negative.
-decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
-Must be positive. See the decay computation above.
-decay_rate: A scalar `float32` or `float64` `Tensor` or a
-Python number. The decay rate.
+learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
+The initial learning rate.
+global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
+step to use for the decay computation. Must not be negative.
+decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Must
+be positive. See the decay computation above.
+decay_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
+The decay rate.
 staircase: Boolean. If `True` decay the learning rate at discrete intervals
 name: String. Optional name of the operation. Defaults to
 'ExponentialDecay'.
@@ -91,11 +92,8 @@ def exponential_decay(learning_rate,
 the learning rate value across different invocations of optimizer functions.
 @end_compatibility
 """
-decayed_lr = learning_rate_schedule.ExponentialDecay(learning_rate,
-decay_steps,
-decay_rate,
-staircase=staircase,
-name=name)
+decayed_lr = learning_rate_schedule.ExponentialDecay(
+learning_rate, decay_steps, decay_rate, staircase=staircase, name=name)
 if not context.executing_eagerly():
 decayed_lr = decayed_lr(global_step)
 else:
@@ -114,7 +112,8 @@ def piecewise_constant(x, boundaries, values, name=None):
 global_step = tf.Variable(0, trainable=False)
 boundaries = [100000, 110000]
 values = [1.0, 0.5, 0.1]
-learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
+learning_rate = tf.compat.v1.train.piecewise_constant(global_step, boundaries,
+values)
 # Later, whenever we perform an optimization step, we increment global_step.
 ```
@@ -202,27 +201,28 @@ def polynomial_decay(learning_rate,
 starter_learning_rate = 0.1
 end_learning_rate = 0.01
 decay_steps = 10000
-learning_rate = tf.train.polynomial_decay(starter_learning_rate, global_step,
+learning_rate = tf.compat.v1.train.polynomial_decay(starter_learning_rate,
+global_step,
 decay_steps, end_learning_rate,
 power=0.5)
 # Passing global_step to minimize() will increment it at each step.
 learning_step = (
-tf.train.GradientDescentOptimizer(learning_rate)
+tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
 .minimize(...my loss..., global_step=global_step)
 )
 ```
 Args:
-learning_rate: A scalar `float32` or `float64` `Tensor` or a
-Python number. The initial learning rate.
-global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
-Global step to use for the decay computation. Must not be negative.
-decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
-Must be positive. See the decay computation above.
-end_learning_rate: A scalar `float32` or `float64` `Tensor` or a
-Python number. The minimal end learning rate.
-power: A scalar `float32` or `float64` `Tensor` or a
-Python number. The power of the polynomial. Defaults to linear, 1.0.
+learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
+The initial learning rate.
+global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
+step to use for the decay computation. Must not be negative.
+decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Must
+be positive. See the decay computation above.
+end_learning_rate: A scalar `float32` or `float64` `Tensor` or a Python
+number. The minimal end learning rate.
+power: A scalar `float32` or `float64` `Tensor` or a Python number. The
+power of the polynomial. Defaults to linear, 1.0.
 cycle: A boolean, whether or not it should cycle beyond decay_steps.
 name: String. Optional name of the operation. Defaults to
 'PolynomialDecay'.
@@ -292,21 +292,22 @@ def natural_exp_decay(learning_rate,
 learning_rate = 0.1
 decay_steps = 5
 k = 0.5
-learning_rate = tf.train.natural_exp_decay(learning_rate, global_step,
+learning_rate = tf.compat.v1.train.natural_exp_decay(learning_rate,
+global_step,
 decay_steps, k)
 # Passing global_step to minimize() will increment it at each step.
 learning_step = (
-tf.train.GradientDescentOptimizer(learning_rate)
+tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
 .minimize(...my loss..., global_step=global_step)
 )
 ```
 Args:
-learning_rate: A scalar `float32` or `float64` `Tensor` or a
-Python number. The initial learning rate.
-global_step: A Python number.
-Global step to use for the decay computation. Must not be negative.
+learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
+The initial learning rate.
+global_step: A Python number. Global step to use for the decay computation.
+Must not be negative.
 decay_steps: How often to apply decay.
 decay_rate: A Python number. The decay rate.
 staircase: Whether to apply decay in a discrete staircase, as opposed to
@@ -329,7 +330,10 @@ def natural_exp_decay(learning_rate,
 """
 natural_exp_rate = math_ops.exp(math_ops.negative(decay_rate))
 decayed_lr = learning_rate_schedule.ExponentialDecay(
-learning_rate, decay_steps, natural_exp_rate, staircase=staircase,
+learning_rate,
+decay_steps,
+natural_exp_rate,
+staircase=staircase,
 name=name)
 if not context.executing_eagerly():
@@ -376,21 +380,22 @@ def inverse_time_decay(learning_rate,
 learning_rate = 0.1
 decay_steps = 1.0
 decay_rate = 0.5
-learning_rate = tf.train.inverse_time_decay(learning_rate, global_step,
+learning_rate = tf.compat.v1.train.inverse_time_decay(learning_rate,
+global_step,
 decay_steps, decay_rate)
 # Passing global_step to minimize() will increment it at each step.
 learning_step = (
-tf.train.GradientDescentOptimizer(learning_rate)
+tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
 .minimize(...my loss..., global_step=global_step)
 )
 ```
 Args:
-learning_rate: A scalar `float32` or `float64` `Tensor` or a
-Python number. The initial learning rate.
-global_step: A Python number.
-Global step to use for the decay computation. Must not be negative.
+learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
+The initial learning rate.
+global_step: A Python number. Global step to use for the decay computation.
+Must not be negative.
 decay_steps: How often to apply decay.
 decay_rate: A Python number. The decay rate.
 staircase: Whether to apply decay in a discrete staircase, as opposed to
@@ -412,11 +417,7 @@ def inverse_time_decay(learning_rate,
 @end_compatibility
 """
 decayed_lr = learning_rate_schedule.InverseTimeDecay(
-learning_rate,
-decay_steps,
-decay_rate,
-staircase=staircase,
-name=name)
+learning_rate, decay_steps, decay_rate, staircase=staircase, name=name)
 if not context.executing_eagerly():
 decayed_lr = decayed_lr(global_step)
@@ -455,13 +456,14 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
 Args:
 learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
 The initial learning rate.
-global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
-Global step to use for the decay computation.
-decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
-Number of steps to decay over.
-alpha: A scalar `float32` or `float64` Tensor or a Python number.
-Minimum learning rate value as a fraction of learning_rate.
+global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
+step to use for the decay computation.
+decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number
+of steps to decay over.
+alpha: A scalar `float32` or `float64` Tensor or a Python number. Minimum
+learning rate value as a fraction of learning_rate.
 name: String. Optional name of the operation. Defaults to 'CosineDecay'.
 Returns:
 A scalar `Tensor` of the same type as `learning_rate`. The decayed
 learning rate.
@@ -519,17 +521,18 @@ def cosine_decay_restarts(learning_rate,
 Args:
 learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
 The initial learning rate.
-global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
-Global step to use for the decay computation.
+global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
+step to use for the decay computation.
 first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
 Number of steps to decay over.
-t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
-Used to derive the number of iterations in the i-th period
+t_mul: A scalar `float32` or `float64` `Tensor` or a Python number. Used to
+derive the number of iterations in the i-th period
 m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
 Used to derive the initial learning rate of the i-th period:
-alpha: A scalar `float32` or `float64` Tensor or a Python number.
-Minimum learning rate value as a fraction of the learning_rate.
+alpha: A scalar `float32` or `float64` Tensor or a Python number. Minimum
+learning rate value as a fraction of the learning_rate.
 name: String. Optional name of the operation. Defaults to 'SGDRDecay'.
 Returns:
 A scalar `Tensor` of the same type as `learning_rate`. The decayed
 learning rate.
@@ -602,16 +605,17 @@ def linear_cosine_decay(learning_rate,
 Args:
 learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
 The initial learning rate.
-global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
-Global step to use for the decay computation.
-decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
-Number of steps to decay over.
-num_periods: Number of periods in the cosine part of the decay.
-See computation above.
+global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
+step to use for the decay computation.
+decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number
+of steps to decay over.
+num_periods: Number of periods in the cosine part of the decay. See
+computation above.
 alpha: See computation above.
 beta: See computation above.
 name: String. Optional name of the operation. Defaults to
 'LinearCosineDecay'.
 Returns:
 A scalar `Tensor` of the same type as `learning_rate`. The decayed
 learning rate.
@@ -690,18 +694,19 @@ def noisy_linear_cosine_decay(learning_rate,
 Args:
 learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
 The initial learning rate.
-global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
-Global step to use for the decay computation.
-decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
-Number of steps to decay over.
+global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
+step to use for the decay computation.
+decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number
+of steps to decay over.
 initial_variance: initial variance for the noise. See computation above.
 variance_decay: decay for the noise's variance. See computation above.
-num_periods: Number of periods in the cosine part of the decay.
-See computation above.
+num_periods: Number of periods in the cosine part of the decay. See
+computation above.
 alpha: See computation above.
 beta: See computation above.
 name: String. Optional name of the operation. Defaults to
 'NoisyLinearCosineDecay'.
 Returns:
 A scalar `Tensor` of the same type as `learning_rate`. The decayed
 learning rate.


@ -77,7 +77,8 @@ class Scaffold(object):
The following pieces are directly accessible as attributes of the `Scaffold` The following pieces are directly accessible as attributes of the `Scaffold`
object: object:
* `saver`: A `tf.train.Saver` object taking care of saving the variables. * `saver`: A `tf.compat.v1.train.Saver` object taking care of saving the
variables.
Picked from and stored into the `SAVERS` collection in the graph by default. Picked from and stored into the `SAVERS` collection in the graph by default.
* `init_op`: An op to run to initialize the variables. Picked from and * `init_op`: An op to run to initialize the variables. Picked from and
stored into the `INIT_OP` collection in the graph by default. stored into the `INIT_OP` collection in the graph by default.
@ -133,9 +134,9 @@ class Scaffold(object):
local_init_op: Optional op to initialize local variables. local_init_op: Optional op to initialize local variables.
summary_op: Optional op to gather all summaries. Must return a scalar summary_op: Optional op to gather all summaries. Must return a scalar
string tensor containing a serialized `Summary` proto. string tensor containing a serialized `Summary` proto.
saver: Optional `tf.train.Saver` object to use to save and restore saver: Optional `tf.compat.v1.train.Saver` object to use to save and
variables. May also be a `tf.train.Checkpoint` object, in which case restore variables. May also be a `tf.train.Checkpoint` object, in which
object-based checkpoints are saved. This will also load some case object-based checkpoints are saved. This will also load some
object-based checkpoints saved from elsewhere, but that loading may be object-based checkpoints saved from elsewhere, but that loading may be
fragile since it uses fixed keys rather than performing a full fragile since it uses fixed keys rather than performing a full
graph-based match. For example if a variable has two paths from the graph-based match. For example if a variable has two paths from the
@ -199,8 +200,9 @@ class Scaffold(object):
resources.report_uninitialized_resources() resources.report_uninitialized_resources()
], 0) ], 0)
self._ready_op = Scaffold.get_or_default( self._ready_op = Scaffold.get_or_default('ready_op',
'ready_op', ops.GraphKeys.READY_OP, default_ready_op) ops.GraphKeys.READY_OP,
default_ready_op)
if self._ready_for_local_init_op is None: if self._ready_for_local_init_op is None:
def default_ready_for_local_init_op(): def default_ready_for_local_init_op():
@ -219,8 +221,9 @@ class Scaffold(object):
'local_init_op', ops.GraphKeys.LOCAL_INIT_OP, 'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
Scaffold.default_local_init_op) Scaffold.default_local_init_op)
if self._summary_op is None: if self._summary_op is None:
self._summary_op = Scaffold.get_or_default( self._summary_op = Scaffold.get_or_default('summary_op',
'summary_op', ops.GraphKeys.SUMMARY_OP, summary.merge_all) ops.GraphKeys.SUMMARY_OP,
summary.merge_all)
# pylint: disable=g-long-lambda # pylint: disable=g-long-lambda
if self._saver is None: if self._saver is None:
self._saver = training_saver._get_saver_or_default() # pylint: disable=protected-access self._saver = training_saver._get_saver_or_default() # pylint: disable=protected-access
@ -292,7 +295,8 @@ class Scaffold(object):
This op is used during session initialization when a Scaffold is This op is used during session initialization when a Scaffold is
initialized without specifying the local_init_op arg. It includes initialized without specifying the local_init_op arg. It includes
`tf.local_variables_initializer`, `tf.tables_initializer`, and also `tf.compat.v1.local_variables_initializer`,
`tf.compat.v1.tables_initializer`, and also
initializes local session resources. initializes local session resources.
Returns: Returns:
@ -435,7 +439,8 @@ def MonitoredTrainingSession(
For a chief, this utility sets proper session initializer/restorer. It also For a chief, this utility sets proper session initializer/restorer. It also
creates hooks related to checkpoint and summary saving. For workers, this creates hooks related to checkpoint and summary saving. For workers, this
utility sets proper session creator which waits for the chief to utility sets proper session creator which waits for the chief to
initialize/restore. Please check `tf.train.MonitoredSession` for more initialize/restore. Please check `tf.compat.v1.train.MonitoredSession` for
more
information. information.
@ -464,8 +469,9 @@ def MonitoredTrainingSession(
to disk using a default summary saver. If both `save_summaries_steps` and to disk using a default summary saver. If both `save_summaries_steps` and
`save_summaries_secs` are set to `None`, then the default summary saver `save_summaries_secs` are set to `None`, then the default summary saver
isn't used. Default not enabled. isn't used. Default not enabled.
config: an instance of `tf.ConfigProto` proto used to configure the session. config: an instance of `tf.compat.v1.ConfigProto` proto used to configure
It's the `config` argument of constructor of `tf.Session`. the session. It's the `config` argument of constructor of
`tf.compat.v1.Session`.
stop_grace_period_secs: Number of seconds given to threads to stop after stop_grace_period_secs: Number of seconds given to threads to stop after
`close()` has been called. `close()` has been called.
log_step_count_steps: The frequency, in number of global steps, that the log_step_count_steps: The frequency, in number of global steps, that the
@ -591,7 +597,7 @@ class SessionCreator(object):
@tf_export(v1=['train.ChiefSessionCreator']) @tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator): class ChiefSessionCreator(SessionCreator):
"""Creates a tf.Session for a chief.""" """Creates a tf.compat.v1.Session for a chief."""
def __init__(self, def __init__(self,
scaffold=None, scaffold=None,
@ -643,7 +649,7 @@ class ChiefSessionCreator(SessionCreator):
@tf_export(v1=['train.WorkerSessionCreator']) @tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator): class WorkerSessionCreator(SessionCreator):
"""Creates a tf.Session for a worker.""" """Creates a tf.compat.v1.Session for a worker."""
def __init__(self, def __init__(self,
scaffold=None, scaffold=None,
@ -757,8 +763,9 @@ class _MonitoredSession(object):
`step_fn` will be returned from `run_step_fn`, unless a stop is `step_fn` will be returned from `run_step_fn`, unless a stop is
requested. In that case, the next `should_stop` call will return True. requested. In that case, the next `should_stop` call will return True.
Example usage: ```python Example usage: ```python
with tf.Graph().as_default(): c = tf.placeholder(dtypes.float32) v = with tf.Graph().as_default(): c =
tf.add(c, 4.0) w = tf.add(c, 0.5) tf.compat.v1.placeholder(dtypes.float32) v = tf.add(c, 4.0) w =
tf.add(c, 0.5)
def step_fn(step_context): def step_fn(step_context):
a = step_context.session.run(fetches=v, feed_dict={c: 0.5}) a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
if a <= 4.5: step_context.request_stop() if a <= 4.5: step_context.request_stop()
@ -808,7 +815,7 @@ class _MonitoredSession(object):
"""Initializes the `step_context` argument for a `step_fn` invocation. """Initializes the `step_context` argument for a `step_fn` invocation.
Args: Args:
session: An instance of `tf.Session`. session: An instance of `tf.compat.v1.Session`.
run_with_hooks_fn: A function for running fetches and hooks. run_with_hooks_fn: A function for running fetches and hooks.
""" """
self._session = session self._session = session
@ -901,13 +908,13 @@ class _MonitoredSession(object):
return self._coordinated_creator.tf_sess is None return self._coordinated_creator.tf_sess is None
def _tf_sess(self): def _tf_sess(self):
"""Return underlying tf.Session object. """Return underlying tf.compat.v1.Session object.
Warning: accessing the returned object in user code is likely to cause races Warning: accessing the returned object in user code is likely to cause races
or "flaky tests". or "flaky tests".
Returns: Returns:
A tf.Session object. A tf.compat.v1.Session object.
""" """
return self._coordinated_creator.tf_sess return self._coordinated_creator.tf_sess
@ -955,7 +962,7 @@ class MonitoredSession(_MonitoredSession):
* suppresses `OutOfRange` error which indicates that all inputs have been * suppresses `OutOfRange` error which indicates that all inputs have been
processed if the monitored_session is used as a context processed if the monitored_session is used as a context
How to set `tf.Session` arguments: How to set `tf.compat.v1.Session` arguments:
* In most cases you can set session arguments as follows: * In most cases you can set session arguments as follows:
@ -973,7 +980,8 @@ class MonitoredSession(_MonitoredSession):
See `MonitoredTrainingSession` for an example usage based on chief or worker. See `MonitoredTrainingSession` for an example usage based on chief or worker.
Note: This is not a `tf.Session`. For example, it cannot do following: Note: This is not a `tf.compat.v1.Session`. For example, it cannot do
following:
* it cannot be set as default session. * it cannot be set as default session.
* it cannot be sent to saver.save. * it cannot be sent to saver.save.
@ -1004,14 +1012,15 @@ class SingularMonitoredSession(_MonitoredSession):
"""Session-like object that handles initialization, restoring, and hooks. """Session-like object that handles initialization, restoring, and hooks.
Please note that this utility is not recommended for distributed settings. Please note that this utility is not recommended for distributed settings.
For distributed settings, please use `tf.train.MonitoredSession`. The For distributed settings, please use `tf.compat.v1.train.MonitoredSession`.
The
differences between `MonitoredSession` and `SingularMonitoredSession` are: differences between `MonitoredSession` and `SingularMonitoredSession` are:
* `MonitoredSession` handles `AbortedError` and `UnavailableError` for * `MonitoredSession` handles `AbortedError` and `UnavailableError` for
distributed settings, but `SingularMonitoredSession` does not. distributed settings, but `SingularMonitoredSession` does not.
* `MonitoredSession` can be created in `chief` or `worker` modes. * `MonitoredSession` can be created in `chief` or `worker` modes.
`SingularMonitoredSession` is always created as `chief`. `SingularMonitoredSession` is always created as `chief`.
* You can access the raw `tf.Session` object used by * You can access the raw `tf.compat.v1.Session` object used by
`SingularMonitoredSession`, whereas in MonitoredSession the raw session is `SingularMonitoredSession`, whereas in MonitoredSession the raw session is
private. This can be used: private. This can be used:
- To `run` without hooks. - To `run` without hooks.
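A minimal sketch of the non-distributed usage, including the raw-session access mentioned above (the variable and ops are illustrative only):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # only needed when running under TF 2.x

v = tf.compat.v1.get_variable("v", initializer=0.0)
increment = tf.compat.v1.assign_add(v, 1.0)

with tf.compat.v1.train.SingularMonitoredSession() as sess:
  sess.run(increment)           # goes through the (empty) hook machinery
  raw = sess.raw_session()      # the underlying tf.compat.v1.Session
  print(raw.run(v))             # usable e.g. for saver.save(raw, ...) or hook-free runs
```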
@ -1093,7 +1102,7 @@ class SingularMonitoredSession(_MonitoredSession):
class _WrappedSession(object): class _WrappedSession(object):
"""Wrapper around a `tf.Session`. """Wrapper around a `tf.compat.v1.Session`.
This wrapper is used as a base class for various session wrappers This wrapper is used as a base class for various session wrappers
that provide additional functionality such as monitoring, coordination, that provide additional functionality such as monitoring, coordination,
@ -1108,7 +1117,8 @@ class _WrappedSession(object):
"""Creates a `_WrappedSession`. """Creates a `_WrappedSession`.
Args: Args:
sess: A `tf.Session` or `_WrappedSession` object. The wrapped session. sess: A `tf.compat.v1.Session` or `_WrappedSession` object. The wrapped
session.
""" """
self._sess = sess self._sess = sess
self._wrapped_is_stoppable = isinstance(self._sess, _WrappedSession) self._wrapped_is_stoppable = isinstance(self._sess, _WrappedSession)
@ -1293,7 +1303,7 @@ class _CoordinatedSession(_WrappedSession):
"""Create a new `_CoordinatedSession`. """Create a new `_CoordinatedSession`.
Args: Args:
sess: A `tf.Session` object. The wrapped session. sess: A `tf.compat.v1.Session` object. The wrapped session.
coord: A `tf.train.Coordinator` object. coord: A `tf.train.Coordinator` object.
stop_grace_period_secs: Number of seconds given to threads to stop after stop_grace_period_secs: Number of seconds given to threads to stop after
`close()` has been called. `close()` has been called.
@ -1364,7 +1374,7 @@ class _HookedSession(_WrappedSession):
"""Initializes a _HookedSession object. """Initializes a _HookedSession object.
Args: Args:
sess: A `tf.Session` or a `_WrappedSession` object. sess: A `tf.compat.v1.Session` or a `_WrappedSession` object.
hooks: An iterable of `SessionRunHook` objects. hooks: An iterable of `SessionRunHook` objects.
""" """
View File
@ -56,9 +56,9 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
E.g.: E.g.:
``` ```
with tf.variable_scope('scope1'): with tf.compat.v1.variable_scope('scope1'):
with tf.variable_scope('scope2'): with tf.compat.v1.variable_scope('scope2'):
var = tf.get_variable('foo') var = tf.compat.v1.get_variable('foo')
update_1 = tf.assign_moving_average(var, 0.0, 1.0) update_1 = tf.assign_moving_average(var, 0.0, 1.0)
update_2 = tf.assign_moving_average(var, 0.0, 0.9) update_2 = tf.assign_moving_average(var, 0.0, 0.9)
@ -73,12 +73,13 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
decay: A float Tensor or float value. The moving average decay. decay: A float Tensor or float value. The moving average decay.
zero_debias: A python bool. If true, assume the variable is 0-initialized zero_debias: A python bool. If true, assume the variable is 0-initialized
and unbias it, as in https://arxiv.org/abs/1412.6980. See docstring in and unbias it, as in https://arxiv.org/abs/1412.6980. See docstring in
`_zero_debias` for more details. `_zero_debias` for more details.
name: Optional name of the returned operation. name: Optional name of the returned operation.
Returns: Returns:
A tensor which if evaluated will compute and return the new moving average. A tensor which if evaluated will compute and return the new moving average.
""" """
def update_fn(v, value, decay=decay): def update_fn(v, value, decay=decay):
decay = ops.convert_to_tensor(1.0 - decay, name="decay") decay = ops.convert_to_tensor(1.0 - decay, name="decay")
if decay.dtype != v.dtype.base_dtype: if decay.dtype != v.dtype.base_dtype:
@ -96,8 +97,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
# In a replica context, we update variable using the mean of value across # In a replica context, we update variable using the mean of value across
# replicas. # replicas.
def merge_fn(strategy, v, value): def merge_fn(strategy, v, value):
value = strategy.extended.reduce_to( value = strategy.extended.reduce_to(ds_reduce_util.ReduceOp.MEAN, value,
ds_reduce_util.ReduceOp.MEAN, value, v) v)
return strategy.extended.update(v, update_fn, args=(value,)) return strategy.extended.update(v, update_fn, args=(value,))
return replica_context.merge_call(merge_fn, args=(variable, value)) return replica_context.merge_call(merge_fn, args=(variable, value))
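A minimal usage sketch of `assign_moving_average`, assuming TF1-style graph mode (the running-loss variable and the fed values are illustrative):

```python
import tensorflow as tf
from tensorflow.python.training import moving_averages

tf.compat.v1.disable_eager_execution()  # only needed when running under TF 2.x

avg_loss = tf.compat.v1.get_variable("avg_loss", initializer=0.0, trainable=False)
loss = tf.compat.v1.placeholder(tf.float32, shape=[])
# Each run moves `avg_loss` a fraction (1 - decay) of the way towards `loss`.
update_avg = moving_averages.assign_moving_average(
    avg_loss, loss, decay=0.9, zero_debias=False)

with tf.compat.v1.Session() as sess:
  sess.run(tf.compat.v1.global_variables_initializer())
  for value in [1.0, 0.5, 0.25]:
    print(sess.run(update_avg, feed_dict={loss: value}))
```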
@ -124,15 +125,15 @@ def weighted_moving_average(value,
Args: Args:
value: A numeric `Tensor`. value: A numeric `Tensor`.
decay: A float `Tensor` or float value. The moving average decay. decay: A float `Tensor` or float value. The moving average decay.
weight: `Tensor` that keeps the current value of a weight. weight: `Tensor` that keeps the current value of a weight. Shape should be
Shape should be able to multiply `value`. able to multiply `value`.
truediv: Boolean, if `True`, dividing by `moving_average(weight)` is truediv: Boolean, if `True`, dividing by `moving_average(weight)` is
floating point division. If `False`, use division implied by dtypes. floating point division. If `False`, use division implied by dtypes.
collections: List of graph collections keys to add the internal variables collections: List of graph collections keys to add the internal variables
`value * weight` and `weight` to. `value * weight` and `weight` to. Defaults to
Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. `[GraphKeys.GLOBAL_VARIABLES]`.
name: Optional name of the returned operation. name: Optional name of the returned operation. Defaults to
Defaults to "WeightedMovingAvg". "WeightedMovingAvg".
Returns: Returns:
An Operation that updates and returns the weighted moving average. An Operation that updates and returns the weighted moving average.
@ -203,27 +204,34 @@ def _zero_debias(unbiased_var, value, decay):
tensor will also update the shadow variables appropriately. tensor will also update the shadow variables appropriately.
""" """
with variable_scope.variable_scope( with variable_scope.variable_scope(
unbiased_var.name[:-len(":0")], values=[unbiased_var, unbiased_var.name[:-len(":0")], values=[unbiased_var, value,
value, decay]) as scope: decay]) as scope:
with ops.colocate_with(unbiased_var): with ops.colocate_with(unbiased_var):
with ops.init_scope(): with ops.init_scope():
biased_initializer = init_ops.zeros_initializer( biased_initializer = init_ops.zeros_initializer(
dtype=unbiased_var.dtype)(unbiased_var.get_shape()) dtype=unbiased_var.dtype)(
unbiased_var.get_shape())
local_step_initializer = init_ops.zeros_initializer() local_step_initializer = init_ops.zeros_initializer()
def _maybe_get_unique(name): def _maybe_get_unique(name):
"""Get name for a unique variable, if not `reuse=True`.""" """Get name for a unique variable, if not `reuse=True`."""
if variable_scope.get_variable_scope().reuse: if variable_scope.get_variable_scope().reuse:
return name return name
vs_vars = [x.op.name for x in vs_vars = [
variable_scope.get_variable_scope().global_variables()] x.op.name
for x in variable_scope.get_variable_scope().global_variables()
]
full_name = variable_scope.get_variable_scope().name + "/" + name full_name = variable_scope.get_variable_scope().name + "/" + name
if full_name not in vs_vars: return name if full_name not in vs_vars:
return name
idx = 1 idx = 1
while full_name + ("_%d" % idx) in vs_vars: while full_name + ("_%d" % idx) in vs_vars:
idx += 1 idx += 1
return name + ("_%d" % idx) return name + ("_%d" % idx)
biased_var = variable_scope.get_variable( biased_var = variable_scope.get_variable(
_maybe_get_unique("biased"), initializer=biased_initializer, _maybe_get_unique("biased"),
initializer=biased_initializer,
trainable=False) trainable=False)
local_step = variable_scope.get_variable( local_step = variable_scope.get_variable(
_maybe_get_unique("local_step"), _maybe_get_unique("local_step"),
@ -233,18 +241,17 @@ def _zero_debias(unbiased_var, value, decay):
trainable=False) trainable=False)
# Get an update ops for both shadow variables. # Get an update ops for both shadow variables.
update_biased = state_ops.assign_sub(biased_var, update_biased = state_ops.assign_sub(
(biased_var - value) * decay, biased_var, (biased_var - value) * decay, name=scope.name)
name=scope.name)
update_local_step = local_step.assign_add(1) update_local_step = local_step.assign_add(1)
# Compute the value of the delta to update the unbiased EMA. Make sure to # Compute the value of the delta to update the unbiased EMA. Make sure to
# use the new values of the biased variable and the local step. # use the new values of the biased variable and the local step.
with ops.control_dependencies([update_biased, update_local_step]): with ops.control_dependencies([update_biased, update_local_step]):
# This function gets `1 - decay`, so use `1.0 - decay` in the exponent. # This function gets `1 - decay`, so use `1.0 - decay` in the exponent.
unbiased_ema_delta = (unbiased_var - biased_var.read_value() / unbiased_ema_delta = (
(1 - math_ops.pow( unbiased_var - biased_var.read_value() /
1.0 - decay, local_step.read_value()))) (1 - math_ops.pow(1.0 - decay, local_step.read_value())))
return unbiased_ema_delta return unbiased_ema_delta
@ -315,7 +322,7 @@ class ExponentialMovingAverage(object):
for a given variable. for a given variable.
* Build a model normally but load the checkpoint files to evaluate by using * Build a model normally but load the checkpoint files to evaluate by using
the shadow variable names. For this use the `average_name()` method. See the shadow variable names. For this use the `average_name()` method. See
the `tf.train.Saver` for more the `tf.compat.v1.train.Saver` for more
information on restoring saved variables. information on restoring saved variables.
Example of restoring the shadow variable values: Example of restoring the shadow variable values:
@ -324,13 +331,17 @@ class ExponentialMovingAverage(object):
# Create a Saver that loads variables from their saved shadow values. # Create a Saver that loads variables from their saved shadow values.
shadow_var0_name = ema.average_name(var0) shadow_var0_name = ema.average_name(var0)
shadow_var1_name = ema.average_name(var1) shadow_var1_name = ema.average_name(var1)
saver = tf.train.Saver({shadow_var0_name: var0, shadow_var1_name: var1}) saver = tf.compat.v1.train.Saver({shadow_var0_name: var0, shadow_var1_name:
var1})
saver.restore(...checkpoint filename...) saver.restore(...checkpoint filename...)
# var0 and var1 now hold the moving average values # var0 and var1 now hold the moving average values
``` ```
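For completeness, a minimal sketch of the training-side setup that produces such shadow variables in the first place (the loss and optimizer below are placeholders, not part of this change):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # only needed when running under TF 2.x

var0 = tf.compat.v1.get_variable("var0", initializer=1.0)
var1 = tf.compat.v1.get_variable("var1", initializer=2.0)
loss = tf.square(var0) + tf.square(var1)
opt_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(loss)

ema = tf.compat.v1.train.ExponentialMovingAverage(decay=0.999)
# Apply the optimizer first, then update the shadow (moving-average) variables.
with tf.control_dependencies([opt_op]):
  train_op = ema.apply([var0, var1])

# Shadow variable names, e.g. "var0/ExponentialMovingAverage", usable with Saver.
print(ema.average_name(var0), ema.average_name(var1))
```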
""" """
def __init__(self, decay, num_updates=None, zero_debias=False, def __init__(self,
decay,
num_updates=None,
zero_debias=False,
name="ExponentialMovingAverage"): name="ExponentialMovingAverage"):
"""Creates a new ExponentialMovingAverage object. """Creates a new ExponentialMovingAverage object.
@ -376,7 +387,7 @@ class ExponentialMovingAverage(object):
shadow variables are created with `trainable=False` and added to the shadow variables are created with `trainable=False` and added to the
`GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to `GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to
`tf.global_variables()`. `tf.compat.v1.global_variables()`.
Returns an op that updates all shadow variables from the current value of Returns an op that updates all shadow variables from the current value of
their associated variables. their associated variables.
@ -386,8 +397,8 @@ class ExponentialMovingAverage(object):
be called in a loop. be called in a loop.
Args: Args:
var_list: A list of Variable or Tensor objects. The variables var_list: A list of Variable or Tensor objects. The variables and Tensors
and Tensors must be of types bfloat16, float16, float32, or float64. must be of types bfloat16, float16, float32, or float64.
Returns: Returns:
An Operation that updates the moving averages. An Operation that updates the moving averages.
@ -417,10 +428,11 @@ class ExponentialMovingAverage(object):
# tensors, we rely on the existing device allocation mechanism. # tensors, we rely on the existing device allocation mechanism.
with ops.init_scope(): with ops.init_scope():
if isinstance(var, variables.Variable): if isinstance(var, variables.Variable):
avg = slot_creator.create_slot(var, avg = slot_creator.create_slot(
var.initialized_value(), var,
self.name, var.initialized_value(),
colocate_with_primary=True) self.name,
colocate_with_primary=True)
# NOTE(mrry): We only add `tf.Variable` objects to the # NOTE(mrry): We only add `tf.Variable` objects to the
# `MOVING_AVERAGE_VARIABLES` collection. # `MOVING_AVERAGE_VARIABLES` collection.
ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var) ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
@ -428,9 +440,9 @@ class ExponentialMovingAverage(object):
avg = slot_creator.create_zeros_slot( avg = slot_creator.create_zeros_slot(
var, var,
self.name, self.name,
colocate_with_primary=(var.op.type in ["Variable", colocate_with_primary=(var.op.type in [
"VariableV2", "Variable", "VariableV2", "VarHandleOp"
"VarHandleOp"])) ]))
if self._zero_debias: if self._zero_debias:
zero_debias_true.add(avg) zero_debias_true.add(avg)
self._averages[var] = avg self._averages[var] = avg
@ -438,16 +450,16 @@ class ExponentialMovingAverage(object):
with ops.name_scope(self.name) as scope: with ops.name_scope(self.name) as scope:
decay = ops.convert_to_tensor(self._decay, name="decay") decay = ops.convert_to_tensor(self._decay, name="decay")
if self._num_updates is not None: if self._num_updates is not None:
num_updates = math_ops.cast(self._num_updates, num_updates = math_ops.cast(
dtypes.float32, self._num_updates, dtypes.float32, name="num_updates")
name="num_updates")
decay = math_ops.minimum(decay, decay = math_ops.minimum(decay,
(1.0 + num_updates) / (10.0 + num_updates)) (1.0 + num_updates) / (10.0 + num_updates))
updates = [] updates = []
for var in var_list: for var in var_list:
zero_debias = self._averages[var] in zero_debias_true zero_debias = self._averages[var] in zero_debias_true
updates.append(assign_moving_average( updates.append(
self._averages[var], var, decay, zero_debias=zero_debias)) assign_moving_average(
self._averages[var], var, decay, zero_debias=zero_debias))
return control_flow_ops.group(*updates, name=scope) return control_flow_ops.group(*updates, name=scope)
def average(self, var): def average(self, var):
@ -472,7 +484,7 @@ class ExponentialMovingAverage(object):
To restore variables, you have to know the name of the shadow variables. To restore variables, you have to know the name of the shadow variables.
That name and the original variable can then be passed to a `Saver()` object That name and the original variable can then be passed to a `Saver()` object
to restore the variable from the moving average value with: to restore the variable from the moving average value with:
`saver = tf.train.Saver({ema.average_name(var): var})` `saver = tf.compat.v1.train.Saver({ema.average_name(var): var})`
`average_name()` can be called whether or not `apply()` has been called. `average_name()` can be called whether or not `apply()` has been called.
@ -499,7 +511,7 @@ class ExponentialMovingAverage(object):
```python ```python
variables_to_restore = ema.variables_to_restore() variables_to_restore = ema.variables_to_restore()
saver = tf.train.Saver(variables_to_restore) saver = tf.compat.v1.train.Saver(variables_to_restore)
``` ```
Below is an example of such mapping: Below is an example of such mapping:
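A small self-contained sketch that prints such a mapping (the variable names here are hypothetical, not taken from this change):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # only needed when running under TF 2.x

weights = tf.compat.v1.get_variable("weights", initializer=[1.0, 2.0])
global_step = tf.compat.v1.train.get_or_create_global_step()

ema = tf.compat.v1.train.ExponentialMovingAverage(decay=0.999)
ema.apply([weights])

# Shadow names map to the averaged variables; everything else maps to itself,
# e.g. {'weights/ExponentialMovingAverage': weights, 'global_step': global_step}.
print(ema.variables_to_restore())
```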
View File
@ -434,14 +434,14 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
Raises: Raises:
ValueError: if `sess` is None and there isn't any default session. ValueError: if `sess` is None and there isn't any default session.
TypeError: if `sess` is not a `tf.Session` object. TypeError: if `sess` is not a `tf.compat.v1.Session` object.
Returns: Returns:
A list of threads. A list of threads.
Raises: Raises:
RuntimeError: If called with eager execution enabled. RuntimeError: If called with eager execution enabled.
ValueError: If called without a default `tf.Session` registered. ValueError: If called without a default `tf.compat.v1.Session` registered.
@compatibility(eager) @compatibility(eager)
Not compatible with eager execution. To ingest data under eager execution, Not compatible with eager execution. To ingest data under eager execution,
View File
@ -56,7 +56,6 @@ from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
# TODO(allenl): Remove these aliases once all users are migrated off. # TODO(allenl): Remove these aliases once all users are migrated off.
get_checkpoint_state = checkpoint_management.get_checkpoint_state get_checkpoint_state = checkpoint_management.get_checkpoint_state
update_checkpoint_state = checkpoint_management.update_checkpoint_state update_checkpoint_state = checkpoint_management.update_checkpoint_state
@ -174,13 +173,11 @@ class BaseSaverBuilder(object):
tensors = [] tensors = []
for spec in saveable.specs: for spec in saveable.specs:
tensors.append( tensors.append(
io_ops.restore_v2( io_ops.restore_v2(filename_tensor, [spec.name], [spec.slice_spec],
filename_tensor, [spec.dtype])[0])
[spec.name],
[spec.slice_spec],
[spec.dtype])[0])
return tensors return tensors
# pylint: enable=unused-argument # pylint: enable=unused-argument
def sharded_filename(self, filename_tensor, shard, num_shards): def sharded_filename(self, filename_tensor, shard, num_shards):
@ -217,8 +214,8 @@ class BaseSaverBuilder(object):
from each device. from each device.
Args: Args:
checkpoint_prefix: scalar String Tensor. Interpreted *NOT AS A checkpoint_prefix: scalar String Tensor. Interpreted *NOT AS A FILENAME*,
FILENAME*, but as a prefix of a V2 checkpoint; but as a prefix of a V2 checkpoint;
per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
returned by _GroupByDevices(). returned by _GroupByDevices().
@ -319,8 +316,8 @@ class BaseSaverBuilder(object):
saveables: A list of SaveableObject objects. saveables: A list of SaveableObject objects.
restore_sequentially: True if we want to restore variables sequentially restore_sequentially: True if we want to restore variables sequentially
within a shard. within a shard.
reshape: True if we want to reshape loaded tensors to the shape of reshape: True if we want to reshape loaded tensors to the shape of the
the corresponding variable. corresponding variable.
preferred_shard: Shard to open first when loading a sharded file. preferred_shard: Shard to open first when loading a sharded file.
name: Name for the returned op. name: Name for the returned op.
@ -361,12 +358,12 @@ class BaseSaverBuilder(object):
Args: Args:
filename_tensor: Tensor for the path of the file to load. filename_tensor: Tensor for the path of the file to load.
per_device: A list of (device, SaveableObject) pairs, as per_device: A list of (device, SaveableObject) pairs, as returned by
returned by _GroupByDevices(). _GroupByDevices().
restore_sequentially: True if we want to restore variables sequentially restore_sequentially: True if we want to restore variables sequentially
within a shard. within a shard.
reshape: True if we want to reshape loaded tensors to the shape of reshape: True if we want to reshape loaded tensors to the shape of the
the corresponding variable. corresponding variable.
Returns: Returns:
An Operation that restores the variables. An Operation that restores the variables.
@ -424,14 +421,13 @@ class BaseSaverBuilder(object):
Args: Args:
names_to_saveables: A dictionary mapping name to a Variable or names_to_saveables: A dictionary mapping name to a Variable or
SaveableObject. Each name will be associated with the SaveableObject. Each name will be associated with the corresponding
corresponding variable in the checkpoint. variable in the checkpoint.
reshape: If True, allow restoring parameters from a checkpoint reshape: If True, allow restoring parameters from a checkpoint where
where the parameters have a different shape. This is the parameters have a different shape. This is only needed when you try
only needed when you try to restore from a Dist-Belief checkpoint, to restore from a Dist-Belief checkpoint, and only sometimes.
and only sometimes. sharded: If True, shard the checkpoints, one per device that has Variable
sharded: If True, shard the checkpoints, one per device that has nodes.
Variable nodes.
max_to_keep: Maximum number of checkpoints to keep. As new checkpoints max_to_keep: Maximum number of checkpoints to keep. As new checkpoints
are created, old ones are deleted. If None or 0, no checkpoints are are created, old ones are deleted. If None or 0, no checkpoints are
deleted from the filesystem but only the last one is kept in the deleted from the filesystem but only the last one is kept in the
@ -597,8 +593,8 @@ def _get_saver_or_default():
if len(savers) > 1: if len(savers) > 1:
raise RuntimeError( raise RuntimeError(
"More than one item in collection {}. " "More than one item in collection {}. "
"Please indicate which one to use by passing it to the constructor.". "Please indicate which one to use by passing it to the constructor."
format(collection_key)) .format(collection_key))
return savers[0] return savers[0]
saver = Saver(sharded=True, allow_empty=True) saver = Saver(sharded=True, allow_empty=True)
if saver is not None: if saver is not None:
@ -662,9 +658,9 @@ class Saver(object):
```python ```python
... ...
# Create a saver. # Create a saver.
saver = tf.train.Saver(...variables...) saver = tf.compat.v1.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps. # Launch the graph and train, saving the model every 1,000 steps.
sess = tf.Session() sess = tf.compat.v1.Session()
for step in xrange(1000000): for step in xrange(1000000):
sess.run(..training_op..) sess.run(..training_op..)
if step % 1000 == 0: if step % 1000 == 0:
@ -717,13 +713,13 @@ class Saver(object):
v2 = tf.Variable(..., name='v2') v2 = tf.Variable(..., name='v2')
# Pass the variables as a dict: # Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2}) saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2})
# Or pass them as a list. # Or pass them as a list.
saver = tf.train.Saver([v1, v2]) saver = tf.compat.v1.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names # Passing a list is equivalent to passing a dict with the variable op names
# as keys: # as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]}) saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]})
``` ```
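A matching restore-side sketch (the checkpoint directory below is a placeholder path, not from this change):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # only needed when running under TF 2.x

v1 = tf.compat.v1.get_variable("v1", initializer=1.0)
saver = tf.compat.v1.train.Saver()

with tf.compat.v1.Session() as sess:
  ckpt = tf.train.latest_checkpoint("/tmp/my-model")  # None if nothing saved yet
  if ckpt:
    saver.restore(sess, ckpt)   # restored variables need no initializer
  else:
    sess.run(tf.compat.v1.global_variables_initializer())
```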
The optional `reshape` argument, if `True`, allows restoring a variable from The optional `reshape` argument, if `True`, allows restoring a variable from
@ -738,35 +734,33 @@ class Saver(object):
var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
names to `SaveableObject`s. If `None`, defaults to the list of all names to `SaveableObject`s. If `None`, defaults to the list of all
saveable objects. saveable objects.
reshape: If `True`, allows restoring parameters from a checkpoint reshape: If `True`, allows restoring parameters from a checkpoint where
where the variables have a different shape. the variables have a different shape.
sharded: If `True`, shard the checkpoints, one per device. sharded: If `True`, shard the checkpoints, one per device.
max_to_keep: Maximum number of recent checkpoints to keep. max_to_keep: Maximum number of recent checkpoints to keep. Defaults to 5.
Defaults to 5. keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults to
keep_checkpoint_every_n_hours: How often to keep checkpoints. 10,000 hours.
Defaults to 10,000 hours.
name: String. Optional name to use as a prefix when adding operations. name: String. Optional name to use as a prefix when adding operations.
restore_sequentially: A `Bool`, which if true, causes restore of different restore_sequentially: A `Bool`, which if true, causes restore of different
variables to happen sequentially within each device. This can lower variables to happen sequentially within each device. This can lower
memory usage when restoring very large models. memory usage when restoring very large models.
saver_def: Optional `SaverDef` proto to use instead of running the saver_def: Optional `SaverDef` proto to use instead of running the
builder. This is only useful for specialty code that wants to recreate builder. This is only useful for specialty code that wants to recreate a
a `Saver` object for a previously built `Graph` that had a `Saver`. `Saver` object for a previously built `Graph` that had a `Saver`. The
The `saver_def` proto should be the one returned by the `saver_def` proto should be the one returned by the `as_saver_def()`
`as_saver_def()` call of the `Saver` that was created for that `Graph`. call of the `Saver` that was created for that `Graph`.
builder: Optional `SaverBuilder` to use if a `saver_def` was not provided. builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
Defaults to `BulkSaverBuilder()`. Defaults to `BulkSaverBuilder()`.
defer_build: If `True`, defer adding the save and restore ops to the defer_build: If `True`, defer adding the save and restore ops to the
`build()` call. In that case `build()` should be called before `build()` call. In that case `build()` should be called before
finalizing the graph or using the saver. finalizing the graph or using the saver.
allow_empty: If `False` (default) raise an error if there are no allow_empty: If `False` (default) raise an error if there are no variables
variables in the graph. Otherwise, construct the saver anyway and make in the graph. Otherwise, construct the saver anyway and make it a no-op.
it a no-op.
write_version: controls what format to use when saving checkpoints. It write_version: controls what format to use when saving checkpoints. It
also affects certain filepath matching logic. The V2 format is the also affects certain filepath matching logic. The V2 format is the
recommended choice: it is much more optimized than V1 in terms of recommended choice: it is much more optimized than V1 in terms of memory
memory required and latency incurred during restore. Regardless of required and latency incurred during restore. Regardless of this
this flag, the Saver is able to restore from both V2 and V1 checkpoints. flag, the Saver is able to restore from both V2 and V1 checkpoints.
pad_step_number: if True, pads the global step number in the checkpoint pad_step_number: if True, pads the global step number in the checkpoint
filepaths to some fixed width (8 by default). This is turned off by filepaths to some fixed width (8 by default). This is turned off by
default. default.
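A short sketch of the checkpoint-retention arguments described above (`max_to_keep`, `keep_checkpoint_every_n_hours`); the variable and values are illustrative:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # only needed when running under TF 2.x

w = tf.compat.v1.get_variable("w", initializer=0.0)

# Keep the 5 most recent checkpoints, plus one checkpoint every 2 hours.
saver = tf.compat.v1.train.Saver(
    [w], max_to_keep=5, keep_checkpoint_every_n_hours=2)
```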
@ -877,7 +871,8 @@ class Saver(object):
name=self._name, name=self._name,
restore_sequentially=self._restore_sequentially, restore_sequentially=self._restore_sequentially,
filename=checkpoint_path, filename=checkpoint_path,
build_save=build_save, build_restore=build_restore) build_save=build_save,
build_restore=build_restore)
elif self.saver_def and self._name: elif self.saver_def and self._name:
# Since self._name is used as a name_scope by builder(), we are # Since self._name is used as a name_scope by builder(), we are
# overloading the use of this field to represent the "import_scope" as # overloading the use of this field to represent the "import_scope" as
@ -997,8 +992,8 @@ class Saver(object):
saver_def.filename_tensor_name, export_scope) saver_def.filename_tensor_name, export_scope)
saver_def.save_tensor_name = ops.strip_name_scope( saver_def.save_tensor_name = ops.strip_name_scope(
saver_def.save_tensor_name, export_scope) saver_def.save_tensor_name, export_scope)
saver_def.restore_op_name = ops.strip_name_scope( saver_def.restore_op_name = ops.strip_name_scope(saver_def.restore_op_name,
saver_def.restore_op_name, export_scope) export_scope)
return saver_def return saver_def
@staticmethod @staticmethod
@ -1092,14 +1087,13 @@ class Saver(object):
Args: Args:
sess: A Session to use to save the variables. sess: A Session to use to save the variables.
save_path: String. Prefix of filenames created for the checkpoint. save_path: String. Prefix of filenames created for the checkpoint.
global_step: If provided the global step number is appended to global_step: If provided the global step number is appended to `save_path`
`save_path` to create the checkpoint filenames. The optional argument to create the checkpoint filenames. The optional argument can be a
can be a `Tensor`, a `Tensor` name or an integer. `Tensor`, a `Tensor` name or an integer.
latest_filename: Optional name for the protocol buffer file that will latest_filename: Optional name for the protocol buffer file that will
contain the list of most recent checkpoints. That file, contain the list of most recent checkpoints. That file, kept in the
kept in the same directory as the checkpoint files, is automatically same directory as the checkpoint files, is automatically managed by the
managed by the saver to keep track of recent checkpoints. Defaults to saver to keep track of recent checkpoints. Defaults to 'checkpoint'.
'checkpoint'.
meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'. meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
write_meta_graph: `Boolean` indicating whether or not to write the meta write_meta_graph: `Boolean` indicating whether or not to write the meta
graph file. graph file.
@ -1107,7 +1101,8 @@ class Saver(object):
`CheckpointStateProto`. `CheckpointStateProto`.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). [Stripping Default-Valued
Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
save_debug_info: If `True`, save the GraphDebugInfo to a separate file, save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
which is in the same directory as save_path and with `_debug` added before which is in the same directory as save_path and with `_debug` added before
the file extension. This is only enabled when `write_meta_graph` is the file extension. This is only enabled when `write_meta_graph` is
@ -1151,8 +1146,7 @@ class Saver(object):
checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step)) checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
else: else:
checkpoint_file = save_path checkpoint_file = save_path
if os.path.basename( if os.path.basename(save_path) == latest_filename and not self._sharded:
save_path) == latest_filename and not self._sharded:
# Guard against collision between data file and checkpoint state file. # Guard against collision between data file and checkpoint state file.
raise ValueError( raise ValueError(
"'latest_filename' collides with 'save_path': '%s' and '%s'" % "'latest_filename' collides with 'save_path': '%s' and '%s'" %
@ -1197,7 +1191,8 @@ class Saver(object):
if not context.executing_eagerly(): if not context.executing_eagerly():
with sess.graph.as_default(): with sess.graph.as_default():
self.export_meta_graph( self.export_meta_graph(
meta_graph_filename, strip_default_attrs=strip_default_attrs, meta_graph_filename,
strip_default_attrs=strip_default_attrs,
save_debug_info=save_debug_info) save_debug_info=save_debug_info)
if self._is_empty: if self._is_empty:
@ -1225,11 +1220,12 @@ class Saver(object):
clear_devices: Whether or not to clear the device field for an `Operation` clear_devices: Whether or not to clear the device field for an `Operation`
or `Tensor` during export. or `Tensor` during export.
clear_extraneous_savers: Remove any Saver-related information from the clear_extraneous_savers: Remove any Saver-related information from the
graph (both Save/Restore ops and SaverDefs) that are not associated graph (both Save/Restore ops and SaverDefs) that are not associated with
with this Saver. this Saver.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). [Stripping Default-Valued
Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
save_debug_info: If `True`, save the GraphDebugInfo to a separate file, save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
which is in the same directory as filename and with `_debug` added before which is in the same directory as filename and with `_debug` added before
the file extension. the file extension.
@ -1274,8 +1270,8 @@ class Saver(object):
raise ValueError("Can't load save_path when it is None.") raise ValueError("Can't load save_path when it is None.")
if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)): if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)):
raise ValueError("The passed save_path is not a valid checkpoint: " raise ValueError("The passed save_path is not a valid checkpoint: " +
+ compat.as_text(save_path)) compat.as_text(save_path))
logging.info("Restoring parameters from %s", compat.as_text(save_path)) logging.info("Restoring parameters from %s", compat.as_text(save_path))
try: try:
@ -1330,13 +1326,15 @@ class Saver(object):
key: One of the GraphKeys or user-defined string. key: One of the GraphKeys or user-defined string.
export_scope: Optional `string`. Name scope to remove. export_scope: Optional `string`. Name scope to remove.
""" """
meta_graph.add_collection_def(meta_graph_def, key, meta_graph.add_collection_def(
export_scope=export_scope) meta_graph_def, key, export_scope=export_scope)
@tf_export(v1=["train.import_meta_graph"]) @tf_export(v1=["train.import_meta_graph"])
def import_meta_graph(meta_graph_or_file, clear_devices=False, def import_meta_graph(meta_graph_or_file,
import_scope=None, **kwargs): clear_devices=False,
import_scope=None,
**kwargs):
"""Recreates a Graph saved in a `MetaGraphDef` proto. """Recreates a Graph saved in a `MetaGraphDef` proto.
This function takes a `MetaGraphDef` protocol buffer as input. If This function takes a `MetaGraphDef` protocol buffer as input. If
@ -1358,10 +1356,10 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
```Python ```Python
... ...
# Create a saver. # Create a saver.
saver = tf.train.Saver(...variables...) saver = tf.compat.v1.train.Saver(...variables...)
# Remember the training_op we want to run by adding it to a collection. # Remember the training_op we want to run by adding it to a collection.
tf.add_to_collection('train_op', train_op) tf.compat.v1.add_to_collection('train_op', train_op)
sess = tf.Session() sess = tf.compat.v1.Session()
for step in xrange(1000000): for step in xrange(1000000):
sess.run(train_op) sess.run(train_op)
if step % 1000 == 0: if step % 1000 == 0:
@ -1374,12 +1372,13 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
the model from scratch. the model from scratch.
```Python ```Python
with tf.Session() as sess: with tf.compat.v1.Session() as sess:
new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta') new_saver =
tf.compat.v1.train.import_meta_graph('my-save-dir/my-model-10000.meta')
new_saver.restore(sess, 'my-save-dir/my-model-10000') new_saver.restore(sess, 'my-save-dir/my-model-10000')
# tf.get_collection() returns a list. In this example we only want the # tf.compat.v1.get_collection() returns a list. In this example we only want
# first one. # the first one.
train_op = tf.get_collection('train_op')[0] train_op = tf.compat.v1.get_collection('train_op')[0]
for step in xrange(1000000): for step in xrange(1000000):
sess.run(train_op) sess.run(train_op)
``` ```
@ -1393,14 +1392,14 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
```Python ```Python
# Saving contents and operations. # Saving contents and operations.
v1 = tf.placeholder(tf.float32, name="v1") v1 = tf.compat.v1.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2") v2 = tf.compat.v1.placeholder(tf.float32, name="v2")
v3 = tf.mul(v1, v2) v3 = tf.mul(v1, v2)
vx = tf.Variable(10.0, name="vx") vx = tf.Variable(10.0, name="vx")
v4 = tf.add(v3, vx, name="v4") v4 = tf.add(v3, vx, name="v4")
saver = tf.train.Saver([vx]) saver = tf.compat.v1.train.Saver([vx])
sess = tf.Session() sess = tf.compat.v1.Session()
sess.run(tf.initialize_all_variables()) sess.run(tf.compat.v1.initialize_all_variables())
sess.run(vx.assign(tf.add(vx, vx))) sess.run(vx.assign(tf.add(vx, vx)))
result = sess.run(v4, feed_dict={v1:12.0, v2:3.3}) result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
print(result) print(result)
@ -1411,8 +1410,8 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
```Python ```Python
# Restoring variables and running operations. # Restoring variables and running operations.
saver = tf.train.import_meta_graph("./model_ex1.meta") saver = tf.compat.v1.train.import_meta_graph("./model_ex1.meta")
sess = tf.Session() sess = tf.compat.v1.Session()
saver.restore(sess, "./model_ex1") saver.restore(sess, "./model_ex1")
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3}) result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
print(result) print(result)
@ -1441,13 +1440,16 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
execution is enabled. execution is enabled.
@end_compatibility @end_compatibility
""" # pylint: disable=g-doc-exception """ # pylint: disable=g-doc-exception
return _import_meta_graph_with_return_elements( return _import_meta_graph_with_return_elements(meta_graph_or_file,
meta_graph_or_file, clear_devices, import_scope, **kwargs)[0] clear_devices, import_scope,
**kwargs)[0]
def _import_meta_graph_with_return_elements( def _import_meta_graph_with_return_elements(meta_graph_or_file,
meta_graph_or_file, clear_devices=False, import_scope=None, clear_devices=False,
return_elements=None, **kwargs): import_scope=None,
return_elements=None,
**kwargs):
"""Import MetaGraph, and return both a saver and returned elements.""" """Import MetaGraph, and return both a saver and returned elements."""
if context.executing_eagerly(): if context.executing_eagerly():
raise RuntimeError("Exporting/importing meta graphs is not supported when " raise RuntimeError("Exporting/importing meta graphs is not supported when "
@ -1466,13 +1468,13 @@ def _import_meta_graph_with_return_elements(
return_elements=return_elements, return_elements=return_elements,
**kwargs)) **kwargs))
saver = _create_saver_from_imported_meta_graph( saver = _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
meta_graph_def, import_scope, imported_vars) imported_vars)
return saver, imported_return_elements return saver, imported_return_elements
def _create_saver_from_imported_meta_graph( def _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
meta_graph_def, import_scope, imported_vars): imported_vars):
"""Return a saver for restoring variable values to an imported MetaGraph.""" """Return a saver for restoring variable values to an imported MetaGraph."""
if meta_graph_def.HasField("saver_def"): if meta_graph_def.HasField("saver_def"):
# Infer the scope that is prepended by `import_scoped_meta_graph`. # Infer the scope that is prepended by `import_scoped_meta_graph`.
@ -1510,7 +1512,9 @@ def export_meta_graph(filename=None,
save_debug_info=False, save_debug_info=False,
**kwargs): **kwargs):
# pylint: disable=line-too-long # pylint: disable=line-too-long
"""Returns `MetaGraphDef` proto. Optionally writes it to filename. """Returns `MetaGraphDef` proto.
Optionally writes it to filename.
This function exports the graph, saver, and collection objects into This function exports the graph, saver, and collection objects into
`MetaGraphDef` protocol buffer with the intention of it being imported `MetaGraphDef` protocol buffer with the intention of it being imported
@ -1518,29 +1522,29 @@ def export_meta_graph(filename=None,
a subgraph. a subgraph.
Args: Args:
filename: Optional filename including the path for writing the filename: Optional filename including the path for writing the generated
generated `MetaGraphDef` protocol buffer. `MetaGraphDef` protocol buffer.
meta_info_def: `MetaInfoDef` protocol buffer. meta_info_def: `MetaInfoDef` protocol buffer.
graph_def: `GraphDef` protocol buffer. graph_def: `GraphDef` protocol buffer.
saver_def: `SaverDef` protocol buffer. saver_def: `SaverDef` protocol buffer.
collection_list: List of string keys to collect. collection_list: List of string keys to collect.
as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto. as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
graph: The `Graph` to export. If `None`, use the default graph. graph: The `Graph` to export. If `None`, use the default graph.
export_scope: Optional `string`. Name scope under which to extract export_scope: Optional `string`. Name scope under which to extract the
the subgraph. The scope name will be stripped from the node definitions subgraph. The scope name will be stripped from the node definitions for
for easy import later into new name scopes. If `None`, the whole graph easy import later into new name scopes. If `None`, the whole graph is
is exported. graph_def and export_scope cannot both be specified. exported. graph_def and export_scope cannot both be specified.
clear_devices: Whether or not to clear the device field for an `Operation` clear_devices: Whether or not to clear the device field for an `Operation`
or `Tensor` during export. or `Tensor` during export.
clear_extraneous_savers: Remove any Saver-related information from the clear_extraneous_savers: Remove any Saver-related information from the graph
graph (both Save/Restore ops and SaverDefs) that are not associated (both Save/Restore ops and SaverDefs) that are not associated with the
with the provided SaverDef. provided SaverDef.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
save_debug_info: If `True`, save the GraphDebugInfo to a separate file, save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
which in the same directory of filename and with `_debug` added before which in the same directory of filename and with `_debug` added before the
the file extend. file extend.
**kwargs: Optional keyed arguments. **kwargs: Optional keyed arguments.
Returns: Returns:
@ -1603,10 +1607,8 @@ def object_graph_key_mapping(checkpoint_path):
Dictionary mapping tensor names to checkpoint keys. Dictionary mapping tensor names to checkpoint keys.
""" """
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
object_graph_string = reader.get_tensor( object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
trackable.OBJECT_GRAPH_PROTO_KEY) object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
object_graph_proto = (
trackable_object_graph_pb2.TrackableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string) object_graph_proto.ParseFromString(object_graph_string)
names_to_keys = {} names_to_keys = {}
for node in object_graph_proto.nodes: for node in object_graph_proto.nodes:
@ -1615,9 +1617,11 @@ def object_graph_key_mapping(checkpoint_path):
return names_to_keys return names_to_keys
def saver_from_object_based_checkpoint( def saver_from_object_based_checkpoint(checkpoint_path,
checkpoint_path, var_list=None, builder=None, names_to_keys=None, var_list=None,
cached_saver=None): builder=None,
names_to_keys=None,
cached_saver=None):
"""Return a `Saver` which reads from an object-based checkpoint. """Return a `Saver` which reads from an object-based checkpoint.
This function validates that all variables in the variables list are remapped This function validates that all variables in the variables list are remapped
@ -1659,8 +1663,8 @@ def saver_from_object_based_checkpoint(
try: try:
names_to_keys = object_graph_key_mapping(checkpoint_path) names_to_keys = object_graph_key_mapping(checkpoint_path)
except errors.NotFoundError: except errors.NotFoundError:
raise ValueError("Checkpoint in %s not an object-based checkpoint." raise ValueError("Checkpoint in %s not an object-based checkpoint." %
% checkpoint_path) checkpoint_path)
if var_list is None: if var_list is None:
var_list = variables._all_saveable_objects() # pylint: disable=protected-access var_list = variables._all_saveable_objects() # pylint: disable=protected-access
if builder is None: if builder is None:
@ -1677,7 +1681,8 @@ def saver_from_object_based_checkpoint(
extra_names = previous_names - current_names extra_names = previous_names - current_names
intersecting_names = previous_names.intersection(current_names) intersecting_names = previous_names.intersection(current_names)
raise errors.NotFoundError( raise errors.NotFoundError(
None, None, None,
None,
message=( message=(
"\n\nExisting variables not in the checkpoint: %s\n\n" "\n\nExisting variables not in the checkpoint: %s\n\n"
"Variables names when this checkpoint was written which don't " "Variables names when this checkpoint was written which don't "
@ -1695,9 +1700,9 @@ def saver_from_object_based_checkpoint(
"existed, and if variable names have changed you may need to " "existed, and if variable names have changed you may need to "
"make this a dictionary with the old names as keys. If you're " "make this a dictionary with the old names as keys. If you're "
"using an Estimator, you'll need to return a tf.train.Saver " "using an Estimator, you'll need to return a tf.train.Saver "
"inside a tf.train.Scaffold from your model_fn.") "inside a tf.train.Scaffold from your model_fn.") %
% (", ".join(sorted(missing_names)), ", ".join(sorted(extra_names)), (", ".join(sorted(missing_names)), ", ".join(
len(intersecting_names))) sorted(extra_names)), len(intersecting_names)))
for saveable in saveables: for saveable in saveables:
for spec in saveable.specs: for spec in saveable.specs:
spec.name = names_to_keys[spec.name] spec.name = names_to_keys[spec.name]
View File
@ -32,21 +32,19 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
"""Creates a `tf.train.ServerDef` protocol buffer. """Creates a `tf.train.ServerDef` protocol buffer.
Args: Args:
server_or_cluster_def: A `tf.train.ServerDef` or server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
`tf.train.ClusterDef` protocol buffer, or a protocol buffer, or a `tf.train.ClusterSpec` object, describing the server
`tf.train.ClusterSpec` object, describing the server to be to be defined and/or the cluster of which it is a member.
defined and/or the cluster of which it is a member. job_name: (Optional.) Specifies the name of the job of which the server is a
job_name: (Optional.) Specifies the name of the job of which the server member. Defaults to the value in `server_or_cluster_def`, if specified.
is a member. Defaults to the value in `server_or_cluster_def`, if
specified.
task_index: (Optional.) Specifies the task index of the server in its job. task_index: (Optional.) Specifies the task index of the server in its job.
Defaults to the value in `server_or_cluster_def`, if specified. Otherwise Defaults to the value in `server_or_cluster_def`, if specified. Otherwise
defaults to 0 if the server's job has only one task. defaults to 0 if the server's job has only one task.
protocol: (Optional.) Specifies the protocol to be used by the server. protocol: (Optional.) Specifies the protocol to be used by the server.
Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value in
in `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`. `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
config: (Optional.) A `tf.ConfigProto` that specifies default configuration config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
options for all sessions that run on this server. configuration options for all sessions that run on this server.
Returns: Returns:
A `tf.train.ServerDef`. A `tf.train.ServerDef`.
@ -88,7 +86,9 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
server_def = tensorflow_server_pb2.ServerDef( server_def = tensorflow_server_pb2.ServerDef(
cluster=cluster_spec.as_cluster_def(), cluster=cluster_spec.as_cluster_def(),
job_name=job_name, task_index=task_index, protocol=protocol) job_name=job_name,
task_index=task_index,
protocol=protocol)
if config is not None: if config is not None:
server_def.default_session_config.MergeFrom(config) server_def.default_session_config.MergeFrom(config)
return server_def return server_def
@ -99,8 +99,8 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
class Server(object): class Server(object):
"""An in-process TensorFlow server, for use in distributed training. """An in-process TensorFlow server, for use in distributed training.
A `tf.train.Server` instance encapsulates a set of devices and a A `tf.distribute.Server` instance encapsulates a set of devices and a
`tf.Session` target that `tf.compat.v1.Session` target that
can participate in distributed training. A server belongs to a can participate in distributed training. A server belongs to a
cluster (specified by a `tf.train.ClusterSpec`), and cluster (specified by a `tf.train.ClusterSpec`), and
corresponds to a particular task in a named job. The server can corresponds to a particular task in a named job. The server can
@ -120,31 +120,30 @@ class Server(object):
override any information provided in `server_or_cluster_def`. override any information provided in `server_or_cluster_def`.
Args: Args:
server_or_cluster_def: A `tf.train.ServerDef` or server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
`tf.train.ClusterDef` protocol buffer, or a protocol buffer, or a `tf.train.ClusterSpec` object, describing the
`tf.train.ClusterSpec` object, describing the server to be server to be created and/or the cluster of which it is a member.
created and/or the cluster of which it is a member. job_name: (Optional.) Specifies the name of the job of which the server is
job_name: (Optional.) Specifies the name of the job of which the server a member. Defaults to the value in `server_or_cluster_def`, if
is a member. Defaults to the value in `server_or_cluster_def`, if
specified. specified.
task_index: (Optional.) Specifies the task index of the server in its task_index: (Optional.) Specifies the task index of the server in its job.
job. Defaults to the value in `server_or_cluster_def`, if specified. Defaults to the value in `server_or_cluster_def`, if specified.
Otherwise defaults to 0 if the server's job has only one task. Otherwise defaults to 0 if the server's job has only one task.
protocol: (Optional.) Specifies the protocol to be used by the server. protocol: (Optional.) Specifies the protocol to be used by the server.
Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value
value in `server_or_cluster_def`, if specified. Otherwise defaults to in `server_or_cluster_def`, if specified. Otherwise defaults to
`"grpc"`. `"grpc"`.
config: (Optional.) A `tf.ConfigProto` that specifies default config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
configuration options for all sessions that run on this server. configuration options for all sessions that run on this server.
start: (Optional.) Boolean, indicating whether to start the server start: (Optional.) Boolean, indicating whether to start the server after
after creating it. Defaults to `True`. creating it. Defaults to `True`.
Raises: Raises:
tf.errors.OpError: Or one of its subclasses if an error occurs while tf.errors.OpError: Or one of its subclasses if an error occurs while
creating the TensorFlow server. creating the TensorFlow server.
""" """
self._server_def = _make_server_def(server_or_cluster_def, self._server_def = _make_server_def(server_or_cluster_def, job_name,
job_name, task_index, protocol, config) task_index, protocol, config)
self._server = c_api.TF_NewServer(self._server_def.SerializeToString()) self._server = c_api.TF_NewServer(self._server_def.SerializeToString())
if start: if start:
self.start() self.start()
@ -195,15 +194,15 @@ class Server(object):
@property @property
def target(self): def target(self):
"""Returns the target for a `tf.Session` to connect to this server. """Returns the target for a `tf.compat.v1.Session` to connect to this server.
To create a To create a
`tf.Session` that `tf.compat.v1.Session` that
connects to this server, use the following snippet: connects to this server, use the following snippet:
```python ```python
server = tf.train.Server(...) server = tf.distribute.Server(...)
with tf.Session(server.target): with tf.compat.v1.Session(server.target):
# ... # ...
``` ```
@ -217,22 +216,24 @@ class Server(object):
"""Creates a new single-process cluster running on the local host. """Creates a new single-process cluster running on the local host.
This method is a convenience wrapper for creating a This method is a convenience wrapper for creating a
`tf.train.Server` with a `tf.train.ServerDef` that specifies a `tf.distribute.Server` with a `tf.train.ServerDef` that specifies a
single-process cluster containing a single task in a job called single-process cluster containing a single task in a job called
`"local"`. `"local"`.
Args: Args:
config: (Optional.) A `tf.ConfigProto` that specifies default config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
configuration options for all sessions that run on this server. configuration options for all sessions that run on this server.
start: (Optional.) Boolean, indicating whether to start the server after start: (Optional.) Boolean, indicating whether to start the server after
creating it. Defaults to `True`. creating it. Defaults to `True`.
Returns: Returns:
A local `tf.train.Server`. A local `tf.distribute.Server`.
""" """
# Specifying port 0 means that the OS will choose a free port for the # Specifying port 0 means that the OS will choose a free port for the
# server. # server.
return Server({"local": ["localhost:0"]}, protocol="grpc", config=config, return Server({"local": ["localhost:0"]},
protocol="grpc",
config=config,
start=start) start=start)
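A minimal sketch of the single-process case (illustrative only):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # only needed when running under TF 2.x

# Start an in-process server on a free port and connect a session to it.
server = tf.distribute.Server.create_local_server()
with tf.compat.v1.Session(server.target) as sess:
  print(sess.run(tf.constant(42)))
```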
@ -242,7 +243,7 @@ class ClusterSpec(object):
A `tf.train.ClusterSpec` represents the set of processes that A `tf.train.ClusterSpec` represents the set of processes that
participate in a distributed TensorFlow computation. Every participate in a distributed TensorFlow computation. Every
`tf.train.Server` is constructed in a particular cluster. `tf.distribute.Server` is constructed in a particular cluster.
To create a cluster with two jobs and five tasks, you specify the To create a cluster with two jobs and five tasks, you specify the
mapping from job names to lists of network addresses (typically mapping from job names to lists of network addresses (typically
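For illustration, one possible two-job, five-task spec of the kind described above; the hostnames and ports are placeholders:

```python
import tensorflow as tf

cluster = tf.train.ClusterSpec({
    "worker": ["worker0.example.com:2222",
               "worker1.example.com:2222",
               "worker2.example.com:2222"],
    "ps": ["ps0.example.com:2222",
           "ps1.example.com:2222"],
})

print(cluster.jobs)                    # job names defined in the spec
print(cluster.num_tasks("worker"))     # 3
print(cluster.task_address("ps", 1))   # 'ps1.example.com:2222'
print(cluster.as_dict())               # plain-dict view of the cluster
```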
@ -272,10 +273,9 @@ class ClusterSpec(object):
"""Creates a `ClusterSpec`. """Creates a `ClusterSpec`.
Args: Args:
cluster: A dictionary mapping one or more job names to (i) a cluster: A dictionary mapping one or more job names to (i) a list of
list of network addresses, or (ii) a dictionary mapping integer network addresses, or (ii) a dictionary mapping integer task indices to
task indices to network addresses; or a `tf.train.ClusterDef` network addresses; or a `tf.train.ClusterDef` protocol buffer.
protocol buffer.
Raises: Raises:
TypeError: If `cluster` is not a dictionary mapping strings to lists TypeError: If `cluster` is not a dictionary mapping strings to lists
@ -298,14 +298,16 @@ class ClusterSpec(object):
self._cluster_spec = {} self._cluster_spec = {}
for job_def in self._cluster_def.job: for job_def in self._cluster_def.job:
self._cluster_spec[job_def.name] = { self._cluster_spec[job_def.name] = {
i: t for i, t in job_def.tasks.items()} i: t for i, t in job_def.tasks.items()
}
elif isinstance(cluster, ClusterSpec): elif isinstance(cluster, ClusterSpec):
self._cluster_def = cluster_pb2.ClusterDef() self._cluster_def = cluster_pb2.ClusterDef()
self._cluster_def.MergeFrom(cluster.as_cluster_def()) self._cluster_def.MergeFrom(cluster.as_cluster_def())
self._cluster_spec = {} self._cluster_spec = {}
for job_def in self._cluster_def.job: for job_def in self._cluster_def.job:
self._cluster_spec[job_def.name] = { self._cluster_spec[job_def.name] = {
i: t for i, t in job_def.tasks.items()} i: t for i, t in job_def.tasks.items()
}
else: else:
raise TypeError("`cluster` must be a dictionary mapping one or more " raise TypeError("`cluster` must be a dictionary mapping one or more "
"job names to lists of network addresses, or a " "job names to lists of network addresses, or a "
@ -326,7 +328,8 @@ class ClusterSpec(object):
def __str__(self): def __str__(self):
key_values = self.as_dict() key_values = self.as_dict()
string_items = [ string_items = [
repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)] repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)
]
return "ClusterSpec({" + ", ".join(string_items) + "})" return "ClusterSpec({" + ", ".join(string_items) + "})"
def as_dict(self): def as_dict(self):
@ -427,8 +430,8 @@ class ClusterSpec(object):
try: try:
return job[task_index] return job[task_index]
except KeyError: except KeyError:
raise ValueError("No task with index %r in job %r" raise ValueError("No task with index %r in job %r" %
% (task_index, job_name)) (task_index, job_name))
def job_tasks(self, job_name): def job_tasks(self, job_name):
"""Returns a mapping from task ID to address in the given job. """Returns a mapping from task ID to address in the given job.
@ -482,6 +485,6 @@ class ClusterSpec(object):
try: try:
task_address = compat.as_bytes(task_address) task_address = compat.as_bytes(task_address)
except TypeError: except TypeError:
raise TypeError( raise TypeError("Task address %r must be bytes or unicode" %
"Task address %r must be bytes or unicode" % task_address) task_address)
job_def.tasks[i] = task_address job_def.tasks[i] = task_address
@ -107,7 +107,7 @@ class GrpcServerTest(test.TestCase):
self.assertAllEqual(2.0, sess.run(v1)) self.assertAllEqual(2.0, sess.run(v1))
def _useRPCConfig(self): def _useRPCConfig(self):
"""Return a `tf.ConfigProto` that ensures we use the RPC stack for tests. """Return a `tf.compat.v1.ConfigProto` that ensures we use the RPC stack for tests.
This configuration ensures that we continue to exercise the gRPC This configuration ensures that we continue to exercise the gRPC
stack when testing, rather than using the in-process optimization, stack when testing, rather than using the in-process optimization,
@ -115,7 +115,7 @@ class GrpcServerTest(test.TestCase):
master in the same process. master in the same process.
Returns: Returns:
A `tf.ConfigProto`. A `tf.compat.v1.ConfigProto`.
""" """
return config_pb2.ConfigProto(rpc_options=config_pb2.RPCOptions( return config_pb2.ConfigProto(rpc_options=config_pb2.RPCOptions(
use_rpc_for_inprocess_master=True)) use_rpc_for_inprocess_master=True))
@ -174,8 +174,8 @@ class SessionManagerTest(test.TestCase):
self.assertFalse(initialized) self.assertFalse(initialized)
sess.run(v.initializer) sess.run(v.initializer)
self.assertEquals(1, sess.run(v)) self.assertEquals(1, sess.run(v))
saver.save(sess, saver.save(sess, os.path.join(checkpoint_dir,
os.path.join(checkpoint_dir, "recover_session_checkpoint")) "recover_session_checkpoint"))
self._test_recovered_variable(checkpoint_dir=checkpoint_dir) self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
self._test_recovered_variable( self._test_recovered_variable(
checkpoint_filename_with_path=checkpoint_management.latest_checkpoint( checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
@ -202,9 +202,9 @@ class SessionManagerTest(test.TestCase):
def testInitWithNoneLocalInitOpError(self): def testInitWithNoneLocalInitOpError(self):
# Creating a SessionManager with a None local_init_op but # Creating a SessionManager with a None local_init_op but
# non-None ready_for_local_init_op raises ValueError # non-None ready_for_local_init_op raises ValueError
with self.assertRaisesRegexp(ValueError, with self.assertRaisesRegexp(
"If you pass a ready_for_local_init_op " ValueError, "If you pass a ready_for_local_init_op "
"you must also pass a local_init_op "): "you must also pass a local_init_op "):
session_manager.SessionManager( session_manager.SessionManager(
ready_for_local_init_op=variables.report_uninitialized_variables( ready_for_local_init_op=variables.report_uninitialized_variables(
variables.global_variables()), variables.global_variables()),
@ -231,8 +231,8 @@ class SessionManagerTest(test.TestCase):
self.assertFalse(initialized) self.assertFalse(initialized)
sess.run(v.initializer) sess.run(v.initializer)
self.assertEquals(1, sess.run(v)) self.assertEquals(1, sess.run(v))
saver.save(sess, saver.save(sess, os.path.join(checkpoint_dir,
os.path.join(checkpoint_dir, "recover_session_checkpoint")) "recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover. # Create a new Graph and SessionManager and recover.
with ops.Graph().as_default(): with ops.Graph().as_default():
v = variables.VariableV1(2, name="v") v = variables.VariableV1(2, name="v")
@ -266,7 +266,7 @@ class SessionManagerTest(test.TestCase):
@test_util.run_v1_only("b/120545219") @test_util.run_v1_only("b/120545219")
def testRecoverSessionWithReadyForLocalInitOpFailsToReadyLocal(self): def testRecoverSessionWithReadyForLocalInitOpFailsToReadyLocal(self):
# We use ready_for_local_init_op=tf.report_uninitialized_variables(), # We use ready_for_local_init_op=report_uninitialized_variables(),
# which causes recover_session to not run local_init_op, and to return # which causes recover_session to not run local_init_op, and to return
# initialized=False # initialized=False
@ -290,8 +290,8 @@ class SessionManagerTest(test.TestCase):
self.assertFalse(initialized) self.assertFalse(initialized)
sess.run(v.initializer) sess.run(v.initializer)
self.assertEquals(1, sess.run(v)) self.assertEquals(1, sess.run(v))
saver.save(sess, saver.save(sess, os.path.join(checkpoint_dir,
os.path.join(checkpoint_dir, "recover_session_checkpoint")) "recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover. # Create a new Graph and SessionManager and recover.
with ops.Graph().as_default(): with ops.Graph().as_default():
v = variables.VariableV1(2, name="v") v = variables.VariableV1(2, name="v")
@ -780,8 +780,8 @@ class ObsoleteSessionManagerTest(test.TestCase):
self.assertFalse(initialized) self.assertFalse(initialized)
sess.run(v.initializer) sess.run(v.initializer)
self.assertEquals(1, sess.run(v)) self.assertEquals(1, sess.run(v))
saver.save(sess, saver.save(sess, os.path.join(checkpoint_dir,
os.path.join(checkpoint_dir, "recover_session_checkpoint")) "recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover. # Create a new Graph and SessionManager and recover.
with ops.Graph().as_default(): with ops.Graph().as_default():
v = variables.VariableV1(2, name="v") v = variables.VariableV1(2, name="v")
@ -68,7 +68,7 @@ look at following code:
Above user code leads to following execution: Above user code leads to following execution:
call hooks.begin() call hooks.begin()
sess = tf.Session() sess = tf.compat.v1.Session()
call hooks.after_create_session() call hooks.after_create_session()
while not stop is requested: while not stop is requested:
call hooks.before_run() call hooks.before_run()
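A minimal sketch of a custom hook that plugs into the callback order above; the class name and the logging interval are illustrative:

```python
import tensorflow as tf

class StepLoggerHook(tf.compat.v1.train.SessionRunHook):
  """Counts calls to `run()` and logs every 100th one."""

  def begin(self):
    self._calls = 0

  def before_run(self, run_context):
    self._calls += 1

  def after_run(self, run_context, run_values):
    if self._calls % 100 == 0:
      print("completed %d calls to run()" % self._calls)
```

Such a hook would typically be passed via the `hooks` argument of `tf.compat.v1.train.MonitoredTrainingSession` or `MonitoredSession`.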
@ -56,9 +56,9 @@ class SummaryWriter(_FileWriter):
```python ```python
...create a graph... ...create a graph...
# Launch the graph in a session. # Launch the graph in a session.
sess = tf.Session() sess = tf.compat.v1.Session()
# Create a summary writer, add the 'graph' to the event file. # Create a summary writer, add the 'graph' to the event file.
writer = tf.summary.FileWriter(<some-directory>, sess.graph) writer = tf.compat.v1.summary.FileWriter(<some-directory>, sess.graph)
``` ```
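Extending that snippet, a hedged sketch of writing a few scalar summaries and closing the writer; the log directory, tag, and fed values are placeholders:

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
  loss = tf.compat.v1.placeholder(tf.float32, shape=[], name="loss")
  loss_summary = tf.compat.v1.summary.scalar("loss", loss)

writer = tf.compat.v1.summary.FileWriter("/tmp/example_logs", g)
with tf.compat.v1.Session(graph=g) as sess:
  for step in range(3):
    summ = sess.run(loss_summary, feed_dict={loss: 1.0 / (step + 1)})
    writer.add_summary(summ, global_step=step)
writer.close()
```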
The other arguments to the constructor control the asynchronous writes to The other arguments to the constructor control the asynchronous writes to
@ -45,7 +45,7 @@ class Supervisor(object):
"""A training helper that checkpoints models and computes summaries. """A training helper that checkpoints models and computes summaries.
This class is deprecated. Please use This class is deprecated. Please use
`tf.train.MonitoredTrainingSession` instead. `tf.compat.v1.train.MonitoredTrainingSession` instead.
The Supervisor is a small wrapper around a `Coordinator`, a `Saver`, The Supervisor is a small wrapper around a `Coordinator`, a `Saver`,
and a `SessionManager` that takes care of common needs of TensorFlow and a `SessionManager` that takes care of common needs of TensorFlow
@ -97,7 +97,7 @@ class Supervisor(object):
# or job_def.name, or job_def.tasks. It's entirely up to the end user. # or job_def.name, or job_def.tasks. It's entirely up to the end user.
# But there can be only one *chief*. # But there can be only one *chief*.
is_chief = (server_def.task_index == 0) is_chief = (server_def.task_index == 0)
server = tf.train.Server(server_def) server = tf.distribute.Server(server_def)
with tf.Graph().as_default(): with tf.Graph().as_default():
...add operations to the graph... ...add operations to the graph...
@ -140,7 +140,7 @@ class Supervisor(object):
* Specifying `'grpc://hostname:port'` requests a session that uses * Specifying `'grpc://hostname:port'` requests a session that uses
the RPC interface to a specific host, and also allows the in-process the RPC interface to a specific host, and also allows the in-process
master to access remote tensorflow workers. Often, it is master to access remote tensorflow workers. Often, it is
appropriate to pass `server.target` (for some `tf.train.Server` appropriate to pass `server.target` (for some `tf.distribute.Server`
named `server`). named `server`).
#### Advanced use #### Advanced use
@ -237,17 +237,16 @@ class Supervisor(object):
ready_op: 1-D string `Tensor`. This tensor is evaluated by supervisors in ready_op: 1-D string `Tensor`. This tensor is evaluated by supervisors in
`prepare_or_wait_for_session()` to check if the model is ready to use. `prepare_or_wait_for_session()` to check if the model is ready to use.
The model is considered ready if it returns an empty array. Defaults to The model is considered ready if it returns an empty array. Defaults to
the tensor returned from `tf.report_uninitialized_variables()` If the tensor returned from `tf.compat.v1.report_uninitialized_variables()`
`None`, the model is not checked for readiness. If `None`, the model is not checked for readiness.
ready_for_local_init_op: 1-D string `Tensor`. This tensor is evaluated by ready_for_local_init_op: 1-D string `Tensor`. This tensor is evaluated by
supervisors in `prepare_or_wait_for_session()` to check if the model is supervisors in `prepare_or_wait_for_session()` to check if the model is
ready to run the local_init_op. ready to run the local_init_op. The model is considered ready if it
The model is considered ready if it returns an empty array. Defaults to returns an empty array. Defaults to `None`. If `None`, the model is not
`None`. If `None`, the model is not checked for readiness before running checked for readiness before running local_init_op.
local_init_op. is_chief: If True, create a chief supervisor in charge of initializing and
is_chief: If True, create a chief supervisor in charge of initializing restoring the model. If False, create a supervisor that relies on a
and restoring the model. If False, create a supervisor that relies chief supervisor for inits and restore.
on a chief supervisor for inits and restore.
init_op: `Operation`. Used by chief supervisors to initialize the model init_op: `Operation`. Used by chief supervisors to initialize the model
when it can not be recovered. Defaults to an `Operation` that when it can not be recovered. Defaults to an `Operation` that
initializes all global variables. If `None`, no initialization is done initializes all global variables. If `None`, no initialization is done
@ -255,20 +254,19 @@ class Supervisor(object):
init_feed_dict: A dictionary that maps `Tensor` objects to feed values. init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
This feed dictionary will be used when `init_op` is evaluated. This feed dictionary will be used when `init_op` is evaluated.
local_init_op: `Operation`. Used by all supervisors to run initializations local_init_op: `Operation`. Used by all supervisors to run initializations
that should run for every new supervisor instance. By default these that should run for every new supervisor instance. By default these are
are table initializers and initializers for local variables. table initializers and initializers for local variables. If `None`, no
If `None`, no further per supervisor-instance initialization is further per supervisor-instance initialization is done automatically.
done automatically.
logdir: A string. Optional path to a directory where to checkpoint the logdir: A string. Optional path to a directory where to checkpoint the
model and log events for the visualizer. Used by chief supervisors. model and log events for the visualizer. Used by chief supervisors. The
The directory will be created if it does not exist. directory will be created if it does not exist.
summary_op: An `Operation` that returns a Summary for the event logs. summary_op: An `Operation` that returns a Summary for the event logs. Used
Used by chief supervisors if a `logdir` was specified. Defaults to the by chief supervisors if a `logdir` was specified. Defaults to the
operation returned from summary.merge_all(). If `None`, summaries are operation returned from summary.merge_all(). If `None`, summaries are
not computed automatically. not computed automatically.
saver: A Saver object. Used by chief supervisors if a `logdir` was saver: A Saver object. Used by chief supervisors if a `logdir` was
specified. Defaults to the saver returned by Saver(). specified. Defaults to the saver returned by Saver(). If `None`, the
If `None`, the model is not saved automatically. model is not saved automatically.
global_step: An integer Tensor of size 1 that counts steps. The value global_step: An integer Tensor of size 1 that counts steps. The value
from 'global_step' is used in summaries and checkpoint filenames. from 'global_step' is used in summaries and checkpoint filenames.
Default to the op named 'global_step' in the graph if it exists, is of Default to the op named 'global_step' in the graph if it exists, is of
@ -280,20 +278,20 @@ class Supervisor(object):
disable summaries. disable summaries.
save_model_secs: Number of seconds between the creation of model save_model_secs: Number of seconds between the creation of model
checkpoints. Defaults to 600 seconds. Pass 0 to disable checkpoints. checkpoints. Defaults to 600 seconds. Pass 0 to disable checkpoints.
recovery_wait_secs: Number of seconds between checks that the model recovery_wait_secs: Number of seconds between checks that the model is
is ready. Used by supervisors when waiting for a chief supervisor ready. Used by supervisors when waiting for a chief supervisor to
to initialize or restore the model. Defaults to 30 seconds. initialize or restore the model. Defaults to 30 seconds.
stop_grace_secs: Grace period, in seconds, given to running threads to stop_grace_secs: Grace period, in seconds, given to running threads to
stop when `stop()` is called. Defaults to 120 seconds. stop when `stop()` is called. Defaults to 120 seconds.
checkpoint_basename: The basename for checkpoint saving. checkpoint_basename: The basename for checkpoint saving.
session_manager: `SessionManager`, which manages Session creation and session_manager: `SessionManager`, which manages Session creation and
recovery. If it is `None`, a default `SessionManager` will be created recovery. If it is `None`, a default `SessionManager` will be created
with the set of arguments passed in for backwards compatibility. with the set of arguments passed in for backwards compatibility.
summary_writer: `SummaryWriter` to use or `USE_DEFAULT`. Can be `None` summary_writer: `SummaryWriter` to use or `USE_DEFAULT`. Can be `None` to
to indicate that no summaries should be written. indicate that no summaries should be written.
init_fn: Optional callable used to initialize the model. Called init_fn: Optional callable used to initialize the model. Called after the
after the optional `init_op` is called. The callable must accept one optional `init_op` is called. The callable must accept one argument,
argument, the session being initialized. the session being initialized.
local_init_run_options: RunOptions to be passed as the SessionManager local_init_run_options: RunOptions to be passed as the SessionManager
local_init_run_options parameter. local_init_run_options parameter.
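Putting the most common of these arguments together, a minimal chief-only sketch (the log directory, save interval, and stopping condition are illustrative):

```python
import tensorflow as tf

with tf.Graph().as_default():
  global_step = tf.compat.v1.train.get_or_create_global_step()
  train_op = tf.compat.v1.assign_add(global_step, 1)

  sv = tf.compat.v1.train.Supervisor(logdir="/tmp/supervisor_demo",
                                     save_model_secs=60)
  with sv.managed_session("") as sess:
    while not sv.should_stop():
      step = sess.run(train_op)
      if step >= 1000:
        sv.request_stop()
```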
@ -397,12 +395,11 @@ class Supervisor(object):
"""Initializes ready_op. """Initializes ready_op.
Args: Args:
ready_op: `Tensor` to check if the model is initialized. ready_op: `Tensor` to check if the model is initialized. If it's set to
If it's set to USE_DEFAULT, creates an op that checks all USE_DEFAULT, creates an op that checks all the variables are
the variables are initialized. initialized.
ready_for_local_init_op: `Tensor` to check if the model is ready to run ready_for_local_init_op: `Tensor` to check if the model is ready to run
local_init_op. local_init_op. If it's set to USE_DEFAULT, creates an op that checks all
If it's set to USE_DEFAULT, creates an op that checks all
the global variables are initialized. the global variables are initialized.
""" """
if ready_op is Supervisor.USE_DEFAULT: if ready_op is Supervisor.USE_DEFAULT:
@ -440,9 +437,9 @@ class Supervisor(object):
Args: Args:
local_init_op: `Operation` run for every new supervisor instance. If set local_init_op: `Operation` run for every new supervisor instance. If set
to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP
collection. If the collection is empty, create an op that initializes collection. If the collection is empty, create an op that initializes
all local variables and all tables. all local variables and all tables.
""" """
if local_init_op is Supervisor.USE_DEFAULT: if local_init_op is Supervisor.USE_DEFAULT:
local_init_op = self._get_first_op_from_collection( local_init_op = self._get_first_op_from_collection(
@ -461,8 +458,8 @@ class Supervisor(object):
"""Initializes saver. """Initializes saver.
Args: Args:
saver: A `Saver` object. If set to USE_DEFAULT, create one that saver: A `Saver` object. If set to USE_DEFAULT, create one that saves all
saves all the variables. the variables.
""" """
if saver is Supervisor.USE_DEFAULT: if saver is Supervisor.USE_DEFAULT:
saver = self._get_first_op_from_collection(ops.GraphKeys.SAVERS) saver = self._get_first_op_from_collection(ops.GraphKeys.SAVERS)
@ -475,8 +472,8 @@ class Supervisor(object):
"""Initializes summary_op. """Initializes summary_op.
Args: Args:
summary_op: An Operation that returns a Summary for the event logs. summary_op: An Operation that returns a Summary for the event logs. If set
If set to USE_DEFAULT, create an op that merges all the summaries. to USE_DEFAULT, create an op that merges all the summaries.
""" """
if summary_op is Supervisor.USE_DEFAULT: if summary_op is Supervisor.USE_DEFAULT:
summary_op = self._get_first_op_from_collection(ops.GraphKeys.SUMMARY_OP) summary_op = self._get_first_op_from_collection(ops.GraphKeys.SUMMARY_OP)
@ -490,8 +487,8 @@ class Supervisor(object):
"""Initializes global_step. """Initializes global_step.
Args: Args:
global_step: An integer Tensor of size 1 that counts steps. If global_step: An integer Tensor of size 1 that counts steps. If set to
set to USE_DEFAULT, creates global_step tensor. USE_DEFAULT, creates global_step tensor.
""" """
if global_step is Supervisor.USE_DEFAULT: if global_step is Supervisor.USE_DEFAULT:
global_step = self._get_first_op_from_collection( global_step = self._get_first_op_from_collection(
@ -630,8 +627,9 @@ class Supervisor(object):
"""Writes graph_def to `logdir` and adds it to summary if applicable.""" """Writes graph_def to `logdir` and adds it to summary if applicable."""
assert self._is_chief assert self._is_chief
if self._logdir: if self._logdir:
training_util.write_graph(self._graph.as_graph_def(add_shapes=True), training_util.write_graph(
self._logdir, "graph.pbtxt") self._graph.as_graph_def(add_shapes=True), self._logdir,
"graph.pbtxt")
if self._summary_writer and not self._graph_added_to_summary: if self._summary_writer and not self._graph_added_to_summary:
self._summary_writer.add_graph(self._graph) self._summary_writer.add_graph(self._graph)
self._summary_writer.add_meta_graph(self._meta_graph_def) self._summary_writer.add_meta_graph(self._meta_graph_def)
@ -675,8 +673,7 @@ class Supervisor(object):
# if there is no step value. # if there is no step value.
current_step = training_util.global_step(sess, self._global_step) current_step = training_util.global_step(sess, self._global_step)
self._summary_writer.add_session_log( self._summary_writer.add_session_log(
SessionLog(status=SessionLog.START), SessionLog(status=SessionLog.START), current_step)
current_step)
threads = [] threads = []
if self._save_summaries_secs and self._summary_writer: if self._save_summaries_secs and self._summary_writer:
@ -690,7 +687,9 @@ class Supervisor(object):
t.start() t.start()
return threads return threads
def prepare_or_wait_for_session(self, master="", config=None, def prepare_or_wait_for_session(self,
master="",
config=None,
wait_for_checkpoint=False, wait_for_checkpoint=False,
max_wait_secs=7200, max_wait_secs=7200,
start_standard_services=True): start_standard_services=True):
@ -702,10 +701,10 @@ class Supervisor(object):
manager to start the standard services. manager to start the standard services.
Args: Args:
master: name of the TensorFlow master to use. See the `tf.Session` master: name of the TensorFlow master to use. See the
constructor for how this is interpreted. `tf.compat.v1.Session` constructor for how this is interpreted.
config: Optional ConfigProto proto used to configure the session, config: Optional ConfigProto proto used to configure the session, which is
which is passed as-is to create the session. passed as-is to create the session.
wait_for_checkpoint: Whether we should wait for the availability of a wait_for_checkpoint: Whether we should wait for the availability of a
checkpoint before creating Session. Defaults to False. checkpoint before creating Session. Defaults to False.
max_wait_secs: Maximum time to wait for the session to become available. max_wait_secs: Maximum time to wait for the session to become available.
@ -724,18 +723,22 @@ class Supervisor(object):
if self._is_chief: if self._is_chief:
sess = self._session_manager.prepare_session( sess = self._session_manager.prepare_session(
master, init_op=self.init_op, saver=self.saver, master,
checkpoint_dir=self._logdir, wait_for_checkpoint=wait_for_checkpoint, init_op=self.init_op,
max_wait_secs=max_wait_secs, config=config, saver=self.saver,
init_feed_dict=self._init_feed_dict, init_fn=self._init_fn) checkpoint_dir=self._logdir,
wait_for_checkpoint=wait_for_checkpoint,
max_wait_secs=max_wait_secs,
config=config,
init_feed_dict=self._init_feed_dict,
init_fn=self._init_fn)
self._write_graph() self._write_graph()
if start_standard_services: if start_standard_services:
logging.info("Starting standard services.") logging.info("Starting standard services.")
self.start_standard_services(sess) self.start_standard_services(sess)
else: else:
sess = self._session_manager.wait_for_session(master, sess = self._session_manager.wait_for_session(
config=config, master, config=config, max_wait_secs=max_wait_secs)
max_wait_secs=max_wait_secs)
if start_standard_services: if start_standard_services:
logging.info("Starting queue runners.") logging.info("Starting queue runners.")
self.start_queue_runners(sess) self.start_queue_runners(sess)
@ -772,8 +775,8 @@ class Supervisor(object):
queue_runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS) queue_runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
threads = [] threads = []
for qr in queue_runners: for qr in queue_runners:
threads.extend(qr.create_threads(sess, coord=self._coord, daemon=True, threads.extend(
start=True)) qr.create_threads(sess, coord=self._coord, daemon=True, start=True))
return threads return threads
def loop(self, timer_interval_secs, target, args=None, kwargs=None): def loop(self, timer_interval_secs, target, args=None, kwargs=None):
@ -795,8 +798,12 @@ class Supervisor(object):
Returns: Returns:
The started thread. The started thread.
""" """
looper = coordinator.LooperThread(self._coord, timer_interval_secs, looper = coordinator.LooperThread(
target=target, args=args, kwargs=kwargs) self._coord,
timer_interval_secs,
target=target,
args=args,
kwargs=kwargs)
looper.start() looper.start()
return looper return looper
@ -812,13 +819,13 @@ class Supervisor(object):
threads: Optional list of threads to join with the coordinator. If threads: Optional list of threads to join with the coordinator. If
`None`, defaults to the threads running the standard services, the `None`, defaults to the threads running the standard services, the
threads started for `QueueRunners`, and the threads started by the threads started for `QueueRunners`, and the threads started by the
`loop()` method. To wait on additional threads, pass the `loop()` method. To wait on additional threads, pass the list in this
list in this parameter. parameter.
close_summary_writer: Whether to close the `summary_writer`. Defaults to close_summary_writer: Whether to close the `summary_writer`. Defaults to
`True` if the summary writer was created by the supervisor, `False` `True` if the summary writer was created by the supervisor, `False`
otherwise. otherwise.
ignore_live_threads: If `True` ignores threads that remain running after ignore_live_threads: If `True` ignores threads that remain running after a
a grace period when joining threads via the coordinator, instead of grace period when joining threads via the coordinator, instead of
raising a RuntimeError. raising a RuntimeError.
""" """
self._coord.request_stop() self._coord.request_stop()
@ -926,7 +933,9 @@ class Supervisor(object):
# pylint: disable=g-doc-return-or-yield,broad-except # pylint: disable=g-doc-return-or-yield,broad-except
@contextlib.contextmanager @contextlib.contextmanager
def managed_session(self, master="", config=None, def managed_session(self,
master="",
config=None,
start_standard_services=True, start_standard_services=True,
close_summary_writer=True): close_summary_writer=True):
"""Returns a context manager for a managed session. """Returns a context manager for a managed session.
@ -940,7 +949,7 @@ class Supervisor(object):
```python ```python
def train(): def train():
sv = tf.train.Supervisor(...) sv = tf.compat.v1.train.Supervisor(...)
with sv.managed_session(<master>) as sess: with sv.managed_session(<master>) as sess:
for step in xrange(..): for step in xrange(..):
if sv.should_stop(): if sv.should_stop():
@ -973,14 +982,14 @@ class Supervisor(object):
the training loop and are considered normal termination. the training loop and are considered normal termination.
Args: Args:
master: name of the TensorFlow master to use. See the `tf.Session` master: name of the TensorFlow master to use. See the
constructor for how this is interpreted. `tf.compat.v1.Session` constructor for how this is interpreted.
config: Optional `ConfigProto` proto used to configure the session. config: Optional `ConfigProto` proto used to configure the session. Passed
Passed as-is to create the session. as-is to create the session.
start_standard_services: Whether to start the standard services, start_standard_services: Whether to start the standard services, such as
such as checkpoint, summary and step counter. checkpoint, summary and step counter.
close_summary_writer: Whether to close the summary writer when close_summary_writer: Whether to close the summary writer when closing the
closing the session. Defaults to True. session. Defaults to True.
Returns: Returns:
A context manager that yields a `Session` restored from the latest A context manager that yields a `Session` restored from the latest
@ -989,7 +998,8 @@ class Supervisor(object):
""" """
try: try:
sess = self.prepare_or_wait_for_session( sess = self.prepare_or_wait_for_session(
master=master, config=config, master=master,
config=config,
start_standard_services=start_standard_services) start_standard_services=start_standard_services)
yield sess yield sess
except Exception as e: except Exception as e:
@ -1011,6 +1021,7 @@ class Supervisor(object):
except Exception: except Exception:
# Silently ignore exceptions raised by close(). # Silently ignore exceptions raised by close().
pass pass
# pylint: enable=g-doc-return-or-yield,broad-except # pylint: enable=g-doc-return-or-yield,broad-except
@ -1030,8 +1041,8 @@ class SVSummaryThread(coordinator.LooperThread):
def run_loop(self): def run_loop(self):
if self._sv.global_step is not None: if self._sv.global_step is not None:
summary_strs, global_step = self._sess.run([self._sv.summary_op, summary_strs, global_step = self._sess.run(
self._sv.global_step]) [self._sv.summary_op, self._sv.global_step])
else: else:
summary_strs = self._sess.run(self._sv.summary_op) summary_strs = self._sess.run(self._sv.summary_op)
global_step = None global_step = None
@ -1063,8 +1074,7 @@ class SVStepCounterThread(coordinator.LooperThread):
def start_loop(self): def start_loop(self):
self._last_time = time.time() self._last_time = time.time()
self._last_step = training_util.global_step( self._last_step = training_util.global_step(self._sess, self._step_counter)
self._sess, self._step_counter)
def run_loop(self): def run_loop(self):
# Count the steps. # Count the steps.
@ -1080,12 +1090,13 @@ class SVStepCounterThread(coordinator.LooperThread):
steps_per_sec = added_steps / elapsed_time steps_per_sec = added_steps / elapsed_time
else: else:
steps_per_sec = float("inf") steps_per_sec = float("inf")
summary = Summary(value=[Summary.Value(tag=self._summary_tag, summary = Summary(value=[
simple_value=steps_per_sec)]) Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
])
if self._sv.summary_writer: if self._sv.summary_writer:
self._sv.summary_writer.add_summary(summary, current_step) self._sv.summary_writer.add_summary(summary, current_step)
logging.log_first_n(logging.INFO, "%s: %g", 10, logging.log_first_n(logging.INFO, "%s: %g", 10, self._summary_tag,
self._summary_tag, steps_per_sec) steps_per_sec)
class SVTimerCheckpointThread(coordinator.LooperThread): class SVTimerCheckpointThread(coordinator.LooperThread):
@ -1104,13 +1115,13 @@ class SVTimerCheckpointThread(coordinator.LooperThread):
def run_loop(self): def run_loop(self):
logging.info("Saving checkpoint to path %s", self._sv.save_path) logging.info("Saving checkpoint to path %s", self._sv.save_path)
self._sv.saver.save(self._sess, self._sv.save_path, self._sv.saver.save(
global_step=self._sv.global_step) self._sess, self._sv.save_path, global_step=self._sv.global_step)
if self._sv.summary_writer and self._sv.global_step is not None: if self._sv.summary_writer and self._sv.global_step is not None:
current_step = training_util.global_step(self._sess, self._sv.global_step) current_step = training_util.global_step(self._sess, self._sv.global_step)
self._sv.summary_writer.add_session_log( self._sv.summary_writer.add_session_log(
SessionLog(status=SessionLog.CHECKPOINT, SessionLog(
checkpoint_path=self._sv.save_path), status=SessionLog.CHECKPOINT, checkpoint_path=self._sv.save_path),
current_step) current_step)
@ -110,7 +110,7 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
# Note that if you want to have 2 backup replicas, you can change # Note that if you want to have 2 backup replicas, you can change
# total_num_replicas=52 and make sure this number matches how many physical # total_num_replicas=52 and make sure this number matches how many physical
# replicas you started in your job. # replicas you started in your job.
opt = tf.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=50, opt = tf.compat.v1.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=50,
total_num_replicas=50) total_num_replicas=50)
# Some models have startup_delays to help stabilize the model but when using # Some models have startup_delays to help stabilize the model but when using
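Continuing the snippet above, a hedged sketch of wiring the wrapped optimizer into a monitored session; `opt` and `is_chief` are the names from the surrounding example, `training_op` is assumed to come from `opt.minimize(...)`, and the master string is left at its default:

```python
# `opt` and `is_chief` come from the surrounding example.
sync_replicas_hook = opt.make_session_run_hook(is_chief)

with tf.compat.v1.train.MonitoredTrainingSession(
    is_chief=is_chief, hooks=[sync_replicas_hook]) as sess:
  while not sess.should_stop():
    sess.run(training_op)  # `training_op` assumed from opt.minimize(...)
```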
@ -38,11 +38,9 @@ from tensorflow.python.util import nest
from tensorflow.python.util import serialization from tensorflow.python.util import serialization
from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_decorator
# Key where the object graph proto is saved in a TensorBundle # Key where the object graph proto is saved in a TensorBundle
OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH" OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"
# A key indicating a variable's value in an object's checkpointed Tensors # A key indicating a variable's value in an object's checkpointed Tensors
# (Trackable._gather_saveables_for_checkpoint). If this is the only key and # (Trackable._gather_saveables_for_checkpoint). If this is the only key and
# the object has no dependencies, then its value may be restored on object # the object has no dependencies, then its value may be restored on object
@ -74,8 +72,7 @@ class CheckpointInitialValue(ops.Tensor):
""" """
def __init__(self, checkpoint_position, shape=None): def __init__(self, checkpoint_position, shape=None):
self.wrapped_value = checkpoint_position.value_tensors()[ self.wrapped_value = checkpoint_position.value_tensors()[VARIABLE_VALUE_KEY]
VARIABLE_VALUE_KEY]
if shape: if shape:
# We need to set the static shape information on the initializer if # We need to set the static shape information on the initializer if
# possible so we don't get a variable with an unknown shape. # possible so we don't get a variable with an unknown shape.
@ -97,8 +94,8 @@ class NoRestoreSaveable(saveable_object.SaveableObject):
"""Embeds a tensor in a checkpoint with no restore ops.""" """Embeds a tensor in a checkpoint with no restore ops."""
def __init__(self, tensor, name, dtype=None, device=None): def __init__(self, tensor, name, dtype=None, device=None):
spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype, spec = saveable_object.SaveSpec(
device=device) tensor, "", name, dtype=dtype, device=device)
super(NoRestoreSaveable, self).__init__(tensor, [spec], name) super(NoRestoreSaveable, self).__init__(tensor, [spec], name)
def restore(self, restored_tensors, restored_shapes): def restore(self, restored_tensors, restored_shapes):
@ -123,7 +120,8 @@ class PythonStateSaveable(saveable_object.SaveableObject):
"""Create a new `SaveableObject` which freezes current state as a constant. """Create a new `SaveableObject` which freezes current state as a constant.
Used when executing eagerly to embed the current state as a constant, or Used when executing eagerly to embed the current state as a constant, or
when creating a static tf.train.Saver with the frozen current Python state. when creating a static tf.compat.v1.train.Saver with the frozen current
Python state.
Returns: Returns:
A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
@ -140,24 +138,26 @@ class PythonStringStateSaveable(PythonStateSaveable):
Args: Args:
name: The checkpoint key to write to. name: The checkpoint key to write to.
state_callback: A function taking no arguments which returns a state_callback: A function taking no arguments which returns a string.
string. This function is run every time a checkpoint is written. This function is run every time a checkpoint is written.
restore_callback: A function taking a Python string, used to restore restore_callback: A function taking a Python string, used to restore
state. Optional; defaults to doing nothing, in which case it is ignored state. Optional; defaults to doing nothing, in which case it is ignored
by status assertions such as assert_consumed(). by status assertions such as assert_consumed().
""" """
self._has_trivial_state_callback = (restore_callback is None) self._has_trivial_state_callback = (restore_callback is None)
def _state_callback_wrapper(): def _state_callback_wrapper():
with ops.init_scope(): with ops.init_scope():
return state_callback() return state_callback()
self._state_callback = _state_callback_wrapper self._state_callback = _state_callback_wrapper
self._restore_callback = restore_callback self._restore_callback = restore_callback
with ops.device("/cpu:0"): with ops.device("/cpu:0"):
self._save_string = constant_op.constant("", dtype=dtypes.string) self._save_string = constant_op.constant("", dtype=dtypes.string)
spec = saveable_object.SaveSpec( spec = saveable_object.SaveSpec(
self._save_string, "", name, dtype=dtypes.string) self._save_string, "", name, dtype=dtypes.string)
super(PythonStringStateSaveable, self).__init__( super(PythonStringStateSaveable, self).__init__(self._save_string, [spec],
self._save_string, [spec], name) name)
@property @property
def optional_restore(self): def optional_restore(self):
@ -170,8 +170,10 @@ class PythonStringStateSaveable(PythonStateSaveable):
def freeze(self): def freeze(self):
"""Create a frozen `SaveableObject` which saves the current state.""" """Create a frozen `SaveableObject` which saves the current state."""
def _constant_state(): def _constant_state():
return constant_op.constant(self._state_callback(), dtype=dtypes.string) return constant_op.constant(self._state_callback(), dtype=dtypes.string)
return NoRestoreSaveable( return NoRestoreSaveable(
tensor=_constant_state, tensor=_constant_state,
dtype=dtypes.string, dtype=dtypes.string,
@ -217,6 +219,7 @@ class CheckpointPosition(object):
Args: Args:
trackable: The object to record a correspondence for. trackable: The object to record a correspondence for.
Returns: Returns:
True if this is a new assignment, False if this object has already been True if this is a new assignment, False if this object has already been
mapped to a checkpointed `Object` proto. mapped to a checkpointed `Object` proto.
@ -263,21 +266,21 @@ class CheckpointPosition(object):
# consistent (if the dependency DAG is not a tree then there are # consistent (if the dependency DAG is not a tree then there are
# multiple paths to the same object). # multiple paths to the same object).
if current_assignment is not trackable: if current_assignment is not trackable:
logging.warning( logging.warning((
("Inconsistent references when loading the checkpoint into this " "Inconsistent references when loading the checkpoint into this "
"object graph. Either the Trackable object references in the " "object graph. Either the Trackable object references in the "
"Python program have changed in an incompatible way, or the " "Python program have changed in an incompatible way, or the "
"checkpoint was generated in an incompatible program.\n\nTwo " "checkpoint was generated in an incompatible program.\n\nTwo "
"checkpoint references resolved to different objects (%s and %s).") "checkpoint references resolved to different objects (%s and %s)."),
% (current_assignment, trackable)) current_assignment, trackable)
return False # Not a new assignment return False # Not a new assignment
def is_simple_variable(self): def is_simple_variable(self):
"""Determine whether this value is restorable with a Tensor initializer.""" """Determine whether this value is restorable with a Tensor initializer."""
attributes = self.object_proto.attributes attributes = self.object_proto.attributes
return (len(attributes) == 1 return (len(attributes) == 1 and
and attributes[0].name == VARIABLE_VALUE_KEY attributes[0].name == VARIABLE_VALUE_KEY and
and not self.object_proto.children) not self.object_proto.children)
def value_tensors(self): def value_tensors(self):
"""Create value `Tensor`s for this object's attributes. """Create value `Tensor`s for this object's attributes.
@ -335,8 +338,9 @@ class CheckpointPosition(object):
# If we've already created and cached a SaveableObject for this # If we've already created and cached a SaveableObject for this
# attribute, we can re-use it to avoid re-creating some ops when graph # attribute, we can re-use it to avoid re-creating some ops when graph
# building. # building.
saveable_list = saveables_cache.get( saveable_list = saveables_cache.get(self.trackable,
self.trackable, {}).get(serialized_tensor.name, (None,)) {}).get(serialized_tensor.name,
(None,))
if len(saveable_list) == 1: if len(saveable_list) == 1:
# Almost every attribute will have exactly one SaveableObject. # Almost every attribute will have exactly one SaveableObject.
saveable, = saveable_list saveable, = saveable_list
@ -370,8 +374,8 @@ class CheckpointPosition(object):
else: else:
saveable = saveable_factory saveable = saveable_factory
if saveables_cache is not None: if saveables_cache is not None:
saveables_cache.setdefault( saveables_cache.setdefault(self.trackable,
self.trackable, {})[serialized_tensor.name] = [saveable] {})[serialized_tensor.name] = [saveable]
if isinstance(saveable, PythonStateSaveable): if isinstance(saveable, PythonStateSaveable):
python_saveables.append(saveable) python_saveables.append(saveable)
else: else:
@ -388,11 +392,10 @@ class CheckpointPosition(object):
A list of operations when graph building, or an empty list when executing A list of operations when graph building, or an empty list when executing
eagerly. eagerly.
""" """
(restore_ops, (restore_ops, tensor_saveables,
tensor_saveables,
python_saveables) = self._gather_ops_or_named_saveables() python_saveables) = self._gather_ops_or_named_saveables()
restore_ops.extend(self._checkpoint.restore_saveables( restore_ops.extend(
tensor_saveables, python_saveables)) self._checkpoint.restore_saveables(tensor_saveables, python_saveables))
return restore_ops return restore_ops
@property @property
@ -416,13 +419,11 @@ class CheckpointPosition(object):
_DeferredSlotVariableRestoration = collections.namedtuple( _DeferredSlotVariableRestoration = collections.namedtuple(
"_DeferredSlotVariableRestoration", "_DeferredSlotVariableRestoration", [
[
"original_variable", "original_variable",
"slot_variable_id", "slot_variable_id",
"slot_name", "slot_name",
] ])
)
_SlotVariableRestoration = collections.namedtuple( _SlotVariableRestoration = collections.namedtuple(
"_SlotVariableRestoration", "_SlotVariableRestoration",
@ -446,6 +447,7 @@ def no_automatic_dependency_tracking(method):
Args: Args:
method: The method to decorate. method: The method to decorate.
Returns: Returns:
A decorated method which sets and un-sets automatic dependency tracking for A decorated method which sets and un-sets automatic dependency tracking for
the object the method is called on (not thread safe). the object the method is called on (not thread safe).
@ -595,16 +597,21 @@ class Trackable(object):
Args: Args:
name: The local name of the dependency. name: The local name of the dependency.
Returns: Returns:
A `Trackable` object, or `None` if no dependency by this name was A `Trackable` object, or `None` if no dependency by this name was
found. found.
""" """
return self._self_unconditional_dependency_names.get(name, None) return self._self_unconditional_dependency_names.get(name, None)
def _add_variable_with_custom_getter( def _add_variable_with_custom_getter(self,
self, name, shape=None, dtype=dtypes.float32, name,
initializer=None, getter=None, overwrite=False, shape=None,
**kwargs_for_getter): dtype=dtypes.float32,
initializer=None,
getter=None,
overwrite=False,
**kwargs_for_getter):
"""Restore-on-create for a variable be saved with this `Trackable`. """Restore-on-create for a variable be saved with this `Trackable`.
If the user has requested that this object or another `Trackable` which If the user has requested that this object or another `Trackable` which
@ -640,11 +647,9 @@ class Trackable(object):
name=name, shape=shape) name=name, shape=shape)
else: else:
checkpoint_initializer = None checkpoint_initializer = None
if (checkpoint_initializer is not None if (checkpoint_initializer is not None and
and not ( not (isinstance(initializer, CheckpointInitialValue) and
isinstance(initializer, CheckpointInitialValue) (initializer.restore_uid > checkpoint_initializer.restore_uid))):
and (initializer.restore_uid
> checkpoint_initializer.restore_uid))):
# If multiple Trackable objects are "creating" the same variable # If multiple Trackable objects are "creating" the same variable
# via the magic of custom getters, the one with the highest restore UID # via the magic of custom getters, the one with the highest restore UID
# (the one called last) has to make the final initializer. If another # (the one called last) has to make the final initializer. If another
@ -654,7 +659,10 @@ class Trackable(object):
initializer = checkpoint_initializer initializer = checkpoint_initializer
shape = None shape = None
new_variable = getter( new_variable = getter(
name=name, shape=shape, dtype=dtype, initializer=initializer, name=name,
shape=shape,
dtype=dtype,
initializer=initializer,
**kwargs_for_getter) **kwargs_for_getter)
# If we set an initializer and the variable processed it, tracking will not # If we set an initializer and the variable processed it, tracking will not
@ -662,8 +670,7 @@ class Trackable(object):
# is a non-trivial restoration queued, it will handle that. This also # is a non-trivial restoration queued, it will handle that. This also
# handles slot variables. # handles slot variables.
if not overwrite or isinstance(new_variable, Trackable): if not overwrite or isinstance(new_variable, Trackable):
return self._track_trackable(new_variable, name=name, return self._track_trackable(new_variable, name=name, overwrite=overwrite)
overwrite=overwrite)
else: else:
# TODO(allenl): Some variable types are not yet supported. Remove this # TODO(allenl): Some variable types are not yet supported. Remove this
# fallback once all get_variable() return types are Trackable. # fallback once all get_variable() return types are Trackable.
@ -681,6 +688,7 @@ class Trackable(object):
name: The object-local name of the dependency holding the variable's name: The object-local name of the dependency holding the variable's
value. value.
shape: The shape of the variable being loaded into. shape: The shape of the variable being loaded into.
Returns: Returns:
A callable for use as a variable's initializer/initial_value, or None if A callable for use as a variable's initializer/initial_value, or None if
one should not be set (either because there was no variable with this name one should not be set (either because there was no variable with this name
@ -718,8 +726,8 @@ class Trackable(object):
Args: Args:
trackable: A `Trackable` which this object depends on. trackable: A `Trackable` which this object depends on.
name: A local name for `trackable`, used for loading checkpoints into name: A local name for `trackable`, used for loading checkpoints into the
the correct objects. correct objects.
overwrite: Boolean, whether silently replacing dependencies is OK. Used overwrite: Boolean, whether silently replacing dependencies is OK. Used
for __setattr__, where throwing an error on attribute reassignment would for __setattr__, where throwing an error on attribute reassignment would
be inappropriate. be inappropriate.
@ -734,13 +742,11 @@ class Trackable(object):
""" """
self._maybe_initialize_trackable() self._maybe_initialize_trackable()
if not isinstance(trackable, Trackable): if not isinstance(trackable, Trackable):
raise TypeError( raise TypeError(("Trackable._track_trackable() passed type %s, not a "
("Trackable._track_trackable() passed type %s, not a " "Trackable.") % (type(trackable),))
"Trackable.") % (type(trackable),))
new_reference = TrackableReference(name=name, ref=trackable) new_reference = TrackableReference(name=name, ref=trackable)
current_object = self._lookup_dependency(name) current_object = self._lookup_dependency(name)
if (current_object is not None if (current_object is not None and current_object is not trackable):
and current_object is not trackable):
if not overwrite: if not overwrite:
raise ValueError( raise ValueError(
("Called Trackable._track_trackable() with name='%s', " ("Called Trackable._track_trackable() with name='%s', "
@ -755,8 +761,7 @@ class Trackable(object):
index] = new_reference index] = new_reference
elif current_object is None: elif current_object is None:
self._self_unconditional_checkpoint_dependencies.append(new_reference) self._self_unconditional_checkpoint_dependencies.append(new_reference)
self._handle_deferred_dependencies( self._handle_deferred_dependencies(name=name, trackable=trackable)
name=name, trackable=trackable)
self._self_unconditional_dependency_names[name] = trackable self._self_unconditional_dependency_names[name] = trackable
return trackable return trackable
@ -780,8 +785,7 @@ class Trackable(object):
Args: Args:
name: The name of the dependency within this object (`self`), used to name: The name of the dependency within this object (`self`), used to
match `trackable` with values saved in a checkpoint. match `trackable` with values saved in a checkpoint.
trackable: The Trackable object to restore (inheriting from trackable: The Trackable object to restore (inheriting from `Trackable`).
`Trackable`).
""" """
self._maybe_initialize_trackable() self._maybe_initialize_trackable()
trackable._maybe_initialize_trackable() # pylint: disable=protected-access trackable._maybe_initialize_trackable() # pylint: disable=protected-access
@ -809,15 +813,15 @@ class Trackable(object):
restore_ops = [] restore_ops = []
while visit_queue: while visit_queue:
current_position = visit_queue.popleft() current_position = visit_queue.popleft()
restore_ops.extend(nest.flatten( restore_ops.extend(
current_position.trackable # pylint: disable=protected-access nest.flatten(current_position.trackable # pylint: disable=protected-access
._single_restoration_from_checkpoint_position( ._single_restoration_from_checkpoint_position(
checkpoint_position=current_position, checkpoint_position=current_position,
visit_queue=visit_queue))) visit_queue=visit_queue)))
return restore_ops return restore_ops
def _single_restoration_from_checkpoint_position( def _single_restoration_from_checkpoint_position(self, checkpoint_position,
self, checkpoint_position, visit_queue): visit_queue):
"""Restore this object, and either queue its dependencies or defer them.""" """Restore this object, and either queue its dependencies or defer them."""
self._maybe_initialize_trackable() self._maybe_initialize_trackable()
checkpoint = checkpoint_position.checkpoint checkpoint = checkpoint_position.checkpoint
@ -831,14 +835,13 @@ class Trackable(object):
restore_ops = () restore_ops = ()
for child in checkpoint_position.object_proto.children: for child in checkpoint_position.object_proto.children:
child_position = CheckpointPosition( child_position = CheckpointPosition(
checkpoint=checkpoint, checkpoint=checkpoint, proto_id=child.node_id)
proto_id=child.node_id)
local_object = self._lookup_dependency(child.local_name) local_object = self._lookup_dependency(child.local_name)
if local_object is None: if local_object is None:
# We don't yet have a dependency registered with this name. Save it # We don't yet have a dependency registered with this name. Save it
# in case we do. # in case we do.
self._deferred_dependencies.setdefault(child.local_name, []).append( self._deferred_dependencies.setdefault(child.local_name,
child_position) []).append(child_position)
else: else:
if child_position.bind_object(trackable=local_object): if child_position.bind_object(trackable=local_object):
# This object's correspondence is new, so dependencies need to be # This object's correspondence is new, so dependencies need to be
@ -853,7 +856,8 @@ class Trackable(object):
Keys in the returned dictionary are local to this object and in a separate Keys in the returned dictionary are local to this object and in a separate
namespace from dependencies. Values may either be `SaveableObject` factories namespace from dependencies. Values may either be `SaveableObject` factories
or variables easily converted to `SaveableObject`s (as in `tf.train.Saver`'s or variables easily converted to `SaveableObject`s (as in
`tf.compat.v1.train.Saver`'s
`var_list` constructor argument). `var_list` constructor argument).
`SaveableObjects` have a name set, which Trackable needs to generate `SaveableObjects` have a name set, which Trackable needs to generate
@ -861,7 +865,8 @@ class Trackable(object):
should return a dictionary of callables which take `name` arguments and should return a dictionary of callables which take `name` arguments and
return `SaveableObjects` with that name. return `SaveableObjects` with that name.
If this object may also be passed to the global-name-based `tf.train.Saver`, If this object may also be passed to the global-name-based
`tf.compat.v1.train.Saver`,
the returned callables should have a default value for their name argument the returned callables should have a default value for their name argument
(i.e. be callable with no arguments). (i.e. be callable with no arguments).
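For reference, a hedged sketch of the name-based `var_list` form mentioned above; the checkpoint key, variable, and save path are illustrative:

```python
import tensorflow as tf

with tf.Graph().as_default():
  v = tf.compat.v1.get_variable(
      "v", shape=[], dtype=tf.float32,
      initializer=tf.compat.v1.zeros_initializer())
  # Keys of `var_list` become the names under which tensors are saved.
  saver = tf.compat.v1.train.Saver(var_list={"my_checkpoint_key": v})
  with tf.compat.v1.Session() as sess:
    sess.run(v.initializer)
    saver.save(sess, "/tmp/var_list_demo")
```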
@ -884,6 +889,7 @@ class Trackable(object):
except NotImplementedError: except NotImplementedError:
return {} return {}
weak_self = weakref.ref(self) weak_self = weakref.ref(self)
def _state_callback(): def _state_callback():
"""Serializes `self.get_config()` for saving.""" """Serializes `self.get_config()` for saving."""
dereferenced_self = weak_self() dereferenced_self = weak_self()
@ -898,9 +904,12 @@ class Trackable(object):
return "" return ""
else: else:
return "" return ""
return {OBJECT_CONFIG_JSON_KEY: functools.partial(
PythonStringStateSaveable, return {
state_callback=_state_callback)} OBJECT_CONFIG_JSON_KEY:
functools.partial(
PythonStringStateSaveable, state_callback=_state_callback)
}
def _list_functions_for_serialization(self): def _list_functions_for_serialization(self):
"""Lists the functions of this trackable to serialize. """Lists the functions of this trackable to serialize.
@ -61,8 +61,8 @@ class _CheckpointRestoreCoordinator(object):
"""Specify the checkpoint being loaded. """Specify the checkpoint being loaded.
Args: Args:
object_graph_proto: The TrackableObjectGraph protocol buffer object_graph_proto: The TrackableObjectGraph protocol buffer associated
associated with this checkpoint. with this checkpoint.
save_path: A string, the path to the checkpoint, as returned by save_path: A string, the path to the checkpoint, as returned by
`tf.train.latest_checkpoint`. `tf.train.latest_checkpoint`.
save_path_tensor: A string `Tensor` which contains or will be fed the save save_path_tensor: A string `Tensor` which contains or will be fed the save
@ -142,12 +142,10 @@ class _CheckpointRestoreCoordinator(object):
""" """
restore_ops = [] restore_ops = []
# Eagerly run restorations for Python state. # Eagerly run restorations for Python state.
reader = pywrap_tensorflow.NewCheckpointReader( reader = pywrap_tensorflow.NewCheckpointReader(self.save_path_string)
self.save_path_string)
for saveable in python_saveables: for saveable in python_saveables:
spec_names = [spec.name for spec in saveable.specs] spec_names = [spec.name for spec in saveable.specs]
saveable.python_restore( saveable.python_restore([reader.get_tensor(name) for name in spec_names])
[reader.get_tensor(name) for name in spec_names])
# If we have new SaveableObjects, extract and cache restore ops. # If we have new SaveableObjects, extract and cache restore ops.
if tensor_saveables: if tensor_saveables:
@ -205,14 +203,13 @@ class _NameBasedRestoreCoordinator(object):
# whether it's optional to restore it. If it's optional we don't need # whether it's optional to restore it. If it's optional we don't need
# to make assertions fail. # to make assertions fail.
if not saveable_factory("").optional_restore: if not saveable_factory("").optional_restore:
self.unused_attributes.setdefault(trackable, []).append( self.unused_attributes.setdefault(trackable,
attribute_name) []).append(attribute_name)
continue continue
else: else:
saveable = saveable_factory saveable = saveable_factory
names_to_saveables = saveable_object_util.op_list_to_dict( names_to_saveables = saveable_object_util.op_list_to_dict(
[saveable], [saveable], convert_variable_to_tensor=False)
convert_variable_to_tensor=False)
for name, op in names_to_saveables.items(): for name, op in names_to_saveables.items():
for saveable_object in saveable_object_util.saveable_objects_for_op( for saveable_object in saveable_object_util.saveable_objects_for_op(
op=op, name=name): op=op, name=name):
@ -224,8 +221,7 @@ class _NameBasedRestoreCoordinator(object):
# run_restore_ops/initialize_or_restore on the status object for name-based # run_restore_ops/initialize_or_restore on the status object for name-based
# checkpoints. # checkpoints.
assert context.executing_eagerly() assert context.executing_eagerly()
for saveable in self.globally_named_object_attributes( for saveable in self.globally_named_object_attributes(trackable):
trackable):
restored_tensors = [] restored_tensors = []
tensor_missing = False tensor_missing = False
for spec in saveable.specs: for spec in saveable.specs:
@ -248,14 +244,18 @@ class _NameBasedRestoreCoordinator(object):
# Ignores values missing from the checkpoint, as with object-based # Ignores values missing from the checkpoint, as with object-based
# restore. Status assertions can be used to check exact matches, # restore. Status assertions can be used to check exact matches,
# although it's unlikely to ever happen for name-based checkpoints. # although it's unlikely to ever happen for name-based checkpoints.
saveable.restore(restored_tensors=restored_tensors, saveable.restore(
restored_shapes=None) restored_tensors=restored_tensors, restored_shapes=None)
# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange # TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange
# or consolidating the implementation with get_variable. # or consolidating the implementation with get_variable.
def _default_getter(name, shape, dtype, initializer=None, def _default_getter(name,
partition_info=None, **kwargs): shape,
dtype,
initializer=None,
partition_info=None,
**kwargs):
"""A pared-down version of get_variable which does not reuse variables.""" """A pared-down version of get_variable which does not reuse variables."""
dtype = dtypes.as_dtype(dtype) dtype = dtypes.as_dtype(dtype)
shape_object = tensor_shape.as_shape(shape) shape_object = tensor_shape.as_shape(shape)
@ -263,7 +263,9 @@ def _default_getter(name, shape, dtype, initializer=None,
if initializer is None: if initializer is None:
initializer, initializing_from_value = ( initializer, initializing_from_value = (
variable_scope._get_default_variable_store()._get_default_initializer( # pylint: disable=protected-access variable_scope._get_default_variable_store()._get_default_initializer( # pylint: disable=protected-access
name=name, shape=shape_object, dtype=dtype)) name=name,
shape=shape_object,
dtype=dtype))
else: else:
initializing_from_value = not callable(initializer) initializing_from_value = not callable(initializer)
# Same logic as get_variable # Same logic as get_variable
@ -276,24 +278,33 @@ def _default_getter(name, shape, dtype, initializer=None,
# Instantiate initializer if provided initializer is a type object. # Instantiate initializer if provided initializer is a type object.
if isinstance(initializer, type(init_ops.Initializer)): if isinstance(initializer, type(init_ops.Initializer)):
initializer = initializer(dtype=dtype) initializer = initializer(dtype=dtype)
def initial_value(): def initial_value():
return initializer( return initializer(
shape_object.as_list(), dtype=dtype, partition_info=partition_info) shape_object.as_list(), dtype=dtype, partition_info=partition_info)
return variables.VariableV1( return variables.VariableV1(
initial_value=initial_value, initial_value=initial_value,
name=name, name=name,
dtype=variable_dtype, dtype=variable_dtype,
use_resource=True, use_resource=True,
**kwargs **kwargs)
)
def add_variable(trackable, name, shape=None, dtype=dtypes.float32, def add_variable(trackable,
initializer=None, trainable=True): name,
shape=None,
dtype=dtypes.float32,
initializer=None,
trainable=True):
"""Add a variable to a Trackable with no scope influence.""" """Add a variable to a Trackable with no scope influence."""
return trackable._add_variable_with_custom_getter( # pylint: disable=protected-access return trackable._add_variable_with_custom_getter( # pylint: disable=protected-access
name=name, shape=shape, dtype=dtype, name=name,
initializer=initializer, getter=_default_getter, trainable=trainable) shape=shape,
dtype=dtype,
initializer=initializer,
getter=_default_getter,
trainable=trainable)
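As a rough, hedged sketch of how these helpers fit together (assuming this module's `add_variable` is importable and using a `tf.train.Checkpoint` as the trackable owner): `add_variable` routes creation through `_default_getter`, so the resulting resource variable is tracked as a named dependency of the object without being influenced by the surrounding variable scope.

```python
import tensorflow as tf

owner = tf.train.Checkpoint()
weights = add_variable(
    owner,
    name="weights",
    shape=[3, 3],
    initializer=tf.compat.v1.zeros_initializer())
# `owner` now has a checkpoint dependency named "weights" on the new variable,
# so saving `owner` (e.g. `owner.save('/tmp/demo/ckpt')`) includes it.
```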
def object_metadata(save_path): def object_metadata(save_path):
@ -313,6 +324,7 @@ def object_metadata(save_path):
Args: Args:
save_path: The path to the checkpoint, as returned by `save` or save_path: The path to the checkpoint, as returned by `save` or
`tf.train.latest_checkpoint`. `tf.train.latest_checkpoint`.
Returns: Returns:
A parsed `tf.contrib.checkpoint.TrackableObjectGraph` protocol buffer. A parsed `tf.contrib.checkpoint.TrackableObjectGraph` protocol buffer.
Raises: Raises:
@ -320,16 +332,14 @@ def object_metadata(save_path):
""" """
reader = pywrap_tensorflow.NewCheckpointReader(save_path) reader = pywrap_tensorflow.NewCheckpointReader(save_path)
try: try:
object_graph_string = reader.get_tensor( object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
base.OBJECT_GRAPH_PROTO_KEY)
except errors_impl.NotFoundError: except errors_impl.NotFoundError:
raise ValueError( raise ValueError(
('The specified checkpoint "%s" does not appear to be object-based (it ' ('The specified checkpoint "%s" does not appear to be object-based (it '
'is missing the key "%s"). Likely it was created with a name-based ' 'is missing the key "%s"). Likely it was created with a name-based '
'saver and does not contain an object dependency graph.') % ( "saver and does not contain an object dependency graph.") %
save_path, base.OBJECT_GRAPH_PROTO_KEY)) (save_path, base.OBJECT_GRAPH_PROTO_KEY))
object_graph_proto = ( object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
trackable_object_graph_pb2.TrackableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string) object_graph_proto.ParseFromString(object_graph_string)
return object_graph_proto return object_graph_proto
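A small usage sketch for `object_metadata`, assuming an object-based checkpoint already exists at the (hypothetical) path below:

```python
metadata = object_metadata('/tmp/training_checkpoints/ckpt-1')
# Each node in the proto records its named dependencies on other nodes.
for node_id, node in enumerate(metadata.nodes):
  for child in node.children:
    print(node_id, '--', child.local_name, '-->', child.node_id)
```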
@ -343,8 +353,8 @@ def list_objects(root_trackable):
(i.e. if they would be saved with a checkpoint). (i.e. if they would be saved with a checkpoint).
Args: Args:
root_trackable: A `Trackable` object whose dependencies should be root_trackable: A `Trackable` object whose dependencies should be flattened.
flattened.
Returns: Returns:
A flat list of objects. A flat list of objects.
""" """
@ -362,12 +372,16 @@ def gather_initializers(root_trackable):
Args: Args:
root_trackable: A `Trackable` object to gather initializers for. root_trackable: A `Trackable` object to gather initializers for.
Returns: Returns:
A list of initialization ops. A list of initialization ops.
""" """
trackable_objects = list_objects(root_trackable) trackable_objects = list_objects(root_trackable)
return [c.initializer for c in trackable_objects return [
if hasattr(c, "initializer") and c.initializer is not None] c.initializer
for c in trackable_objects
if hasattr(c, "initializer") and c.initializer is not None
]
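A minimal graph-mode sketch for `gather_initializers`, assuming a trackable root built with `tf.train.Checkpoint`:

```python
import tensorflow as tf

root = tf.train.Checkpoint(v=tf.Variable(1.0))
init_ops = gather_initializers(root)  # initializers for every reachable variable
with tf.compat.v1.Session() as session:
  session.run(init_ops)
  print(session.run(root.v))  # 1.0
```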
@tf_contextlib.contextmanager @tf_contextlib.contextmanager
@ -380,7 +394,7 @@ def capture_dependencies(template):
object to add dependencies on variables created in a block of code which is object to add dependencies on variables created in a block of code which is
not aware of object-based saving (and instead uses variable names not aware of object-based saving (and instead uses variable names
heavily). This is how `Template` objects add dependencies on variables and heavily). This is how `Template` objects add dependencies on variables and
sub-`Template`s. Where possible, use `tf.make_template` directly. sub-`Template`s. Where possible, use `tf.compat.v1.make_template` directly.
Args: Args:
template: The `Template` object to register dependencies with. template: The `Template` object to register dependencies with.
@ -390,8 +404,11 @@ def capture_dependencies(template):
""" """
name_prefix = template.variable_scope.name name_prefix = template.variable_scope.name
def _trackable_custom_creator(next_creator, name, initial_value, def _trackable_custom_creator(next_creator,
trackable_parent=None, **kwargs): name,
initial_value,
trackable_parent=None,
**kwargs):
"""A variable creation hook which adds Trackable dependencies. """A variable creation hook which adds Trackable dependencies.
Set for example during a `Template`'s first wrapped function Set for example during a `Template`'s first wrapped function
@ -415,21 +432,20 @@ def capture_dependencies(template):
initial_value: See `variable_scope.variable_creator_scope`. Taken initial_value: See `variable_scope.variable_creator_scope`. Taken
explicitly so the argument can be re-named and used with explicitly so the argument can be re-named and used with
`Trackable._add_variable_with_custom_getter`. `Trackable._add_variable_with_custom_getter`.
trackable_parent: If not None, a more deeply nested trackable trackable_parent: If not None, a more deeply nested trackable object and
object and its name prefix which were passed to `capture_dependencies` its name prefix which were passed to `capture_dependencies` to add a
to add a dependency on (rather than depending on the variable directly). dependency on (rather than depending on the variable directly).
**kwargs: Passed through to the next creator. **kwargs: Passed through to the next creator.
Returns: Returns:
The output of `next_creator`: the fetched/created variable object. The output of `next_creator`: the fetched/created variable object.
""" """
def _call_next_creator_renaming_initializer(initializer, **inner_kwargs): def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
# we don't want to propagate. # we don't want to propagate.
return next_creator( return next_creator(initial_value=initializer, name=name, **inner_kwargs)
initial_value=initializer,
name=name,
**inner_kwargs)
if name is not None and name.startswith(name_prefix): if name is not None and name.startswith(name_prefix):
scope_stripped_name = name[len(name_prefix) + 1:] scope_stripped_name = name[len(name_prefix) + 1:]
if not trackable_parent: if not trackable_parent:
@ -450,8 +466,10 @@ def capture_dependencies(template):
name=parent_name_prefix[len(name_prefix) + 1:], name=parent_name_prefix[len(name_prefix) + 1:],
overwrite=True) overwrite=True)
return next_creator( return next_creator(
name=name, initial_value=initial_value, name=name,
trackable_parent=(template, name_prefix), **kwargs) initial_value=initial_value,
trackable_parent=(template, name_prefix),
**kwargs)
with variable_scope.variable_creator_scope(_trackable_custom_creator): with variable_scope.variable_creator_scope(_trackable_custom_creator):
yield yield
@ -490,9 +508,8 @@ def streaming_restore(status, session=None):
"""When graph building, runs restore ops as soon as they come in. """When graph building, runs restore ops as soon as they come in.
Args: Args:
status: A _LoadStatus object from an object-based saver's status: A _LoadStatus object from an object-based saver's restore().
restore(). Streaming restore from name-based checkpoints is not currently Streaming restore from name-based checkpoints is not currently supported.
supported.
session: A session to run new restore ops in. session: A session to run new restore ops in.
""" """
if context.executing_eagerly(): if context.executing_eagerly():
@ -553,13 +570,13 @@ class CheckpointLoadStatus(_LoadStatus):
if self._checkpoint.slot_restorations: if self._checkpoint.slot_restorations:
# Sanity check; this collection should be clear if everything has been # Sanity check; this collection should be clear if everything has been
# restored. # restored.
raise AssertionError("Unresolved slot restorations: %s" % ( raise AssertionError("Unresolved slot restorations: %s" %
self._checkpoint.slot_restorations,)) (self._checkpoint.slot_restorations,))
if self._checkpoint.unused_attributes: if self._checkpoint.unused_attributes:
raise AssertionError( raise AssertionError(
("Unused attributes in these objects (the attributes exist in the " ("Unused attributes in these objects (the attributes exist in the "
"checkpoint but not in the objects): %s") % ( "checkpoint but not in the objects): %s") %
list(self._checkpoint.unused_attributes.items()),)) (list(self._checkpoint.unused_attributes.items()),))
return self return self
def assert_existing_objects_matched(self): def assert_existing_objects_matched(self):
@ -581,10 +598,10 @@ class CheckpointLoadStatus(_LoadStatus):
""" """
for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes): for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
trackable = self._checkpoint.object_by_proto_id.get(node_id, None) trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
if (trackable is not None if (trackable is not None and
and trackable._update_uid < self._checkpoint.restore_uid): # pylint: disable=protected-access trackable._update_uid < self._checkpoint.restore_uid): # pylint: disable=protected-access
raise AssertionError( raise AssertionError("Object not assigned a value from checkpoint: %s" %
"Object not assigned a value from checkpoint: %s" % (node,)) (node,))
for trackable_object in self._graph_view.list_objects(): for trackable_object in self._graph_view.list_objects():
# Remove data structures that do not contain any variables from # Remove data structures that do not contain any variables from
# restoration checks. # restoration checks.
@ -594,14 +611,14 @@ class CheckpointLoadStatus(_LoadStatus):
continue continue
self._checkpoint.all_python_objects.add(trackable_object) self._checkpoint.all_python_objects.add(trackable_object)
unused_python_objects = ( unused_python_objects = (
object_identity.ObjectIdentitySet(self._checkpoint.all_python_objects) object_identity.ObjectIdentitySet(self._checkpoint.all_python_objects) -
- object_identity.ObjectIdentitySet( object_identity.ObjectIdentitySet(
self._checkpoint.object_by_proto_id.values())) self._checkpoint.object_by_proto_id.values()))
if unused_python_objects: if unused_python_objects:
raise AssertionError( raise AssertionError(
("Some Python objects were not bound to checkpointed values, likely " ("Some Python objects were not bound to checkpointed values, likely "
"due to changes in the Python program: %s") "due to changes in the Python program: %s") %
% (list(unused_python_objects),)) (list(unused_python_objects),))
return self return self
def assert_nontrivial_match(self): def assert_nontrivial_match(self):
@ -610,8 +627,7 @@ class CheckpointLoadStatus(_LoadStatus):
self._checkpoint.all_python_objects.add(trackable_object) self._checkpoint.all_python_objects.add(trackable_object)
if len(self._checkpoint.object_by_proto_id) <= 1: if len(self._checkpoint.object_by_proto_id) <= 1:
unused_python_objects = ( unused_python_objects = (
object_identity.ObjectIdentitySet( object_identity.ObjectIdentitySet(self._checkpoint.all_python_objects)
self._checkpoint.all_python_objects)
- object_identity.ObjectIdentitySet( - object_identity.ObjectIdentitySet(
self._checkpoint.object_by_proto_id.values())) self._checkpoint.object_by_proto_id.values()))
if unused_python_objects: if unused_python_objects:
@ -622,8 +638,8 @@ class CheckpointLoadStatus(_LoadStatus):
"checkpointed value: %s") % (list(unused_python_objects),)) "checkpointed value: %s") % (list(unused_python_objects),))
else: else:
raise AssertionError( raise AssertionError(
"Nothing to load. No dependencies have been added to %s yet." % ( "Nothing to load. No dependencies have been added to %s yet." %
self._graph_view.root,)) (self._graph_view.root,))
return self return self
def run_restore_ops(self, session=None): def run_restore_ops(self, session=None):
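A hedged sketch of chaining these assertions after a restore, assuming eager execution and an illustrative checkpoint directory:

```python
import tensorflow as tf

tf.compat.v1.enable_eager_execution()

checkpoint = tf.train.Checkpoint(v=tf.Variable(0.0))
save_path = checkpoint.save('/tmp/status_demo/ckpt')

restored = tf.train.Checkpoint(v=tf.Variable(42.0))
status = restored.restore(save_path)
status.assert_existing_objects_matched()  # every object created so far matched
status.assert_consumed()  # nothing in the checkpoint is left unused
```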
@ -760,8 +776,8 @@ class NameBasedSaverStatus(_LoadStatus):
unused_attributes = dict(self._checkpoint.unused_attributes) unused_attributes = dict(self._checkpoint.unused_attributes)
if unused_attributes: if unused_attributes:
raise AssertionError( raise AssertionError(
"Some objects had attributes which were not restored: %s" "Some objects had attributes which were not restored: %s" %
% (unused_attributes,)) (unused_attributes,))
for trackable in self._graph_view.list_objects(): for trackable in self._graph_view.list_objects():
# pylint: disable=protected-access # pylint: disable=protected-access
trackable._maybe_initialize_trackable() trackable._maybe_initialize_trackable()
@ -799,12 +815,11 @@ class NameBasedSaverStatus(_LoadStatus):
continue continue
# pylint: enable=protected-access # pylint: enable=protected-access
saveable_objects.extend( saveable_objects.extend(
self._checkpoint.globally_named_object_attributes( self._checkpoint.globally_named_object_attributes(trackable))
trackable))
return saveable_objects return saveable_objects
def run_restore_ops(self, session=None): def run_restore_ops(self, session=None):
"""Load the name-based training checkpoint using a new `tf.train.Saver`.""" """Load the name-based checkpoint using a new `tf.compat.v1.train.Saver`."""
if context.executing_eagerly(): if context.executing_eagerly():
return # Nothing to do, variables are restored on creation. return # Nothing to do, variables are restored on creation.
if session is None: if session is None:
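When graph building, loading a checkpoint written by a name-based `tf.compat.v1.train.Saver` through the object-based API might look like the following sketch (the checkpoint path is hypothetical):

```python
import tensorflow as tf

v = tf.compat.v1.get_variable('v', shape=[])
checkpoint = tf.train.Checkpoint(v=v)
status = checkpoint.restore('/tmp/name_based/model.ckpt')
with tf.compat.v1.Session() as session:
  status.run_restore_ops(session)  # matches variables by name, then restores
```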
@ -840,7 +855,8 @@ class TrackableSaver(object):
"""Saves and restores a `Trackable` object and its dependencies. """Saves and restores a `Trackable` object and its dependencies.
See `Trackable` for details of dependency management. `Saver` wraps See `Trackable` for details of dependency management. `Saver` wraps
`tf.train.Saver` for saving, including extra information about the graph of `tf.compat.v1.train.Saver` for saving, including extra information about the
graph of
dependencies between Python objects. When restoring, it uses this information dependencies between Python objects. When restoring, it uses this information
about the save-time dependency graph to more robustly match objects with their about the save-time dependency graph to more robustly match objects with their
checkpointed values. When executing eagerly, it supports restoring variables checkpointed values. When executing eagerly, it supports restoring variables
@ -851,7 +867,8 @@ class TrackableSaver(object):
checkpoint was written. To avoid breaking existing checkpoints when modifying checkpoint was written. To avoid breaking existing checkpoints when modifying
a class, dependency names (the names of attributes to which `Trackable` a class, dependency names (the names of attributes to which `Trackable`
objects are assigned) may not change. These names are local to objects, in objects are assigned) may not change. These names are local to objects, in
contrast to the `Variable.name`-based save/restore from `tf.train.Saver`, and contrast to the `Variable.name`-based save/restore from
`tf.compat.v1.train.Saver`, and
so allow additional program transformations. so allow additional program transformations.
""" """
@ -877,8 +894,7 @@ class TrackableSaver(object):
self._restore_op_cache = {} self._restore_op_cache = {}
self._graph_view = graph_view self._graph_view = graph_view
def _gather_saveables( def _gather_saveables(self, object_graph_tensor=None):
self, object_graph_tensor=None):
"""Wraps _serialize_object_graph to include the object graph proto.""" """Wraps _serialize_object_graph to include the object graph proto."""
(named_saveable_objects, graph_proto, (named_saveable_objects, graph_proto,
feed_additions) = self._graph_view.serialize_object_graph() feed_additions) = self._graph_view.serialize_object_graph()
@ -892,14 +908,12 @@ class TrackableSaver(object):
assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
named_saveable_objects.append( named_saveable_objects.append(
base.NoRestoreSaveable( base.NoRestoreSaveable(
tensor=object_graph_tensor, tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY))
name=base.OBJECT_GRAPH_PROTO_KEY))
return named_saveable_objects, graph_proto, feed_additions return named_saveable_objects, graph_proto, feed_additions
def _save_cached_when_graph_building( def _save_cached_when_graph_building(self,
self, file_prefix,
file_prefix, object_graph_tensor=None):
object_graph_tensor=None):
"""Create or retrieve save ops. """Create or retrieve save ops.
Args: Args:
@ -921,8 +935,7 @@ class TrackableSaver(object):
# save() is called so they pick up new Tensors passed to their # save() is called so they pick up new Tensors passed to their
# constructors. That means the Saver needs to be copied with a new # constructors. That means the Saver needs to be copied with a new
# var_list. # var_list.
or context.executing_eagerly() or context.executing_eagerly() or ops.inside_function()):
or ops.inside_function()):
saver = functional_saver.MultiDeviceSaver(named_saveable_objects) saver = functional_saver.MultiDeviceSaver(named_saveable_objects)
save_op = saver.save(file_prefix) save_op = saver.save(file_prefix)
with ops.device("/cpu:0"): with ops.device("/cpu:0"):
@ -954,8 +967,8 @@ class TrackableSaver(object):
The full path to the checkpoint. The full path to the checkpoint.
""" """
feed_dict = {} feed_dict = {}
use_session = (not context.executing_eagerly() use_session = (not context.executing_eagerly() and
and not ops.inside_function()) not ops.inside_function())
if checkpoint_number: if checkpoint_number:
file_prefix = "%s-%d" % (file_prefix, checkpoint_number) file_prefix = "%s-%d" % (file_prefix, checkpoint_number)
if use_session: if use_session:
@ -976,8 +989,7 @@ class TrackableSaver(object):
file_io.recursive_create_dir(os.path.dirname(file_prefix)) file_io.recursive_create_dir(os.path.dirname(file_prefix))
save_path, new_feed_additions = self._save_cached_when_graph_building( save_path, new_feed_additions = self._save_cached_when_graph_building(
file_prefix=file_prefix_tensor, file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor)
object_graph_tensor=object_graph_tensor)
if new_feed_additions: if new_feed_additions:
feed_dict.update(new_feed_additions) feed_dict.update(new_feed_additions)
if not use_session: if not use_session:
@ -1024,7 +1036,7 @@ class TrackableSaver(object):
If the checkpoint has not been consumed completely, then the list of restore If the checkpoint has not been consumed completely, then the list of restore
ops will grow as more objects are added to the dependency graph. ops will grow as more objects are added to the dependency graph.
Name-based `tf.train.Saver` checkpoints can be loaded using this Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this
method. There is no deferred loading, and names are used to match method. There is no deferred loading, and names are used to match
variables. No restore ops are created/run until `run_restore_ops()` or variables. No restore ops are created/run until `run_restore_ops()` or
`initialize_or_restore()` are called on the returned status object, even `initialize_or_restore()` are called on the returned status object, even
@ -1035,9 +1047,9 @@ class TrackableSaver(object):
save_path: The path to the checkpoint, as returned by `save` or save_path: The path to the checkpoint, as returned by `save` or
`tf.train.latest_checkpoint`. If None (as when there is no latest `tf.train.latest_checkpoint`. If None (as when there is no latest
checkpoint for `tf.train.latest_checkpoint` to return), returns an checkpoint for `tf.train.latest_checkpoint` to return), returns an
object which may run initializers for objects in the dependency object which may run initializers for objects in the dependency graph.
graph. If the checkpoint was written by the name-based `tf.train.Saver`, If the checkpoint was written by the name-based
names are used to match variables. `tf.compat.v1.train.Saver`, names are used to match variables.
Returns: Returns:
A load status object, which can be used to make assertions about the A load status object, which can be used to make assertions about the
@ -1057,8 +1069,7 @@ class TrackableSaver(object):
else: else:
dtype_map = reader.get_variable_to_dtype_map() dtype_map = reader.get_variable_to_dtype_map()
try: try:
object_graph_string = reader.get_tensor( object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
base.OBJECT_GRAPH_PROTO_KEY)
except errors_impl.NotFoundError: except errors_impl.NotFoundError:
# The object graph proto does not exist in this checkpoint. Try the # The object graph proto does not exist in this checkpoint. Try the
# name-based compatibility mode. # name-based compatibility mode.
@ -1069,8 +1080,7 @@ class TrackableSaver(object):
# pylint: disable=protected-access # pylint: disable=protected-access
existing_trackable._maybe_initialize_trackable() existing_trackable._maybe_initialize_trackable()
existing_trackable._name_based_restores.add(restore_coordinator) existing_trackable._name_based_restores.add(restore_coordinator)
existing_trackable._name_based_attribute_restore( existing_trackable._name_based_attribute_restore(restore_coordinator)
restore_coordinator)
# pylint: enable=protected-access # pylint: enable=protected-access
return NameBasedSaverStatus( return NameBasedSaverStatus(
restore_coordinator, graph_view=self._graph_view) restore_coordinator, graph_view=self._graph_view)
@ -1085,8 +1095,7 @@ class TrackableSaver(object):
with ops.device("/cpu:0"): with ops.device("/cpu:0"):
file_prefix_tensor = constant_op.constant(save_path) file_prefix_tensor = constant_op.constant(save_path)
file_prefix_feed_dict = None file_prefix_feed_dict = None
object_graph_proto = ( object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
trackable_object_graph_pb2.TrackableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string) object_graph_proto.ParseFromString(object_graph_string)
checkpoint = _CheckpointRestoreCoordinator( checkpoint = _CheckpointRestoreCoordinator(
object_graph_proto=object_graph_proto, object_graph_proto=object_graph_proto,
@ -1094,8 +1103,8 @@ class TrackableSaver(object):
save_path_tensor=file_prefix_tensor, save_path_tensor=file_prefix_tensor,
restore_op_cache=self._restore_op_cache, restore_op_cache=self._restore_op_cache,
graph_view=self._graph_view) graph_view=self._graph_view)
base.CheckpointPosition(checkpoint=checkpoint, proto_id=0).restore( base.CheckpointPosition(
self._graph_view.root) checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)
load_status = CheckpointLoadStatus( load_status = CheckpointLoadStatus(
checkpoint, checkpoint,
graph_view=self._graph_view, graph_view=self._graph_view,
@ -1104,7 +1113,7 @@ class TrackableSaver(object):
def frozen_saver(root_trackable): def frozen_saver(root_trackable):
"""Creates a static `tf.train.Saver` from a trackable object. """Creates a static `tf.compat.v1.train.Saver` from a trackable object.
The returned `Saver` saves object-based checkpoints, but these checkpoints The returned `Saver` saves object-based checkpoints, but these checkpoints
will no longer reflect structural changes to the object graph, only changes to will no longer reflect structural changes to the object graph, only changes to
@ -1135,9 +1144,9 @@ def saver_with_op_caching(obj):
saveables_cache = None saveables_cache = None
else: else:
saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary() saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()
return TrackableSaver(graph_view_lib.ObjectGraphView( return TrackableSaver(
weakref.ref(obj), graph_view_lib.ObjectGraphView(
saveables_cache=saveables_cache)) weakref.ref(obj), saveables_cache=saveables_cache))
# Mentions graph building / Sessions. The v2 version is below. # Mentions graph building / Sessions. The v2 version is below.
@ -1146,7 +1155,7 @@ class CheckpointV1(tracking.AutoTrackable):
"""Groups trackable objects, saving and restoring them. """Groups trackable objects, saving and restoring them.
`Checkpoint`'s constructor accepts keyword arguments whose values are types `Checkpoint`'s constructor accepts keyword arguments whose values are types
that contain trackable state, such as `tf.train.Optimizer` that contain trackable state, such as `tf.compat.v1.train.Optimizer`
implementations, `tf.Variable`, `tf.keras.Layer` implementations, or implementations, `tf.Variable`, `tf.keras.Layer` implementations, or
`tf.keras.Model` implementations. It saves these values with a checkpoint, and `tf.keras.Model` implementations. It saves these values with a checkpoint, and
maintains a `save_counter` for numbering checkpoints. maintains a `save_counter` for numbering checkpoints.
@ -1164,7 +1173,7 @@ class CheckpointV1(tracking.AutoTrackable):
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
train_op = optimizer.minimize( ... ) train_op = optimizer.minimize( ... )
status.assert_consumed() # Optional sanity checks. status.assert_consumed() # Optional sanity checks.
with tf.Session() as session: with tf.compat.v1.Session() as session:
# Use the Session to restore variables, or initialize them if # Use the Session to restore variables, or initialize them if
# tf.train.latest_checkpoint returned None. # tf.train.latest_checkpoint returned None.
status.initialize_or_restore(session) status.initialize_or_restore(session)
@ -1179,7 +1188,7 @@ class CheckpointV1(tracking.AutoTrackable):
import tensorflow as tf import tensorflow as tf
import os import os
tf.enable_eager_execution() tf.compat.v1.enable_eager_execution()
checkpoint_directory = "/tmp/training_checkpoints" checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
@ -1193,13 +1202,14 @@ class CheckpointV1(tracking.AutoTrackable):
``` ```
`Checkpoint.save` and `Checkpoint.restore` write and read object-based `Checkpoint.save` and `Checkpoint.restore` write and read object-based
checkpoints, in contrast to `tf.train.Saver` which writes and reads checkpoints, in contrast to `tf.compat.v1.train.Saver` which writes and reads
`variable.name` based checkpoints. Object-based checkpointing saves a graph of `variable.name` based checkpoints. Object-based checkpointing saves a graph of
dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s, dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s,
etc.) with named edges, and this graph is used to match variables when etc.) with named edges, and this graph is used to match variables when
restoring a checkpoint. It can be more robust to changes in the Python restoring a checkpoint. It can be more robust to changes in the Python
program, and helps to support restore-on-create for variables when executing program, and helps to support restore-on-create for variables when executing
eagerly. Prefer `tf.train.Checkpoint` over `tf.train.Saver` for new code. eagerly. Prefer `tf.train.Checkpoint` over `tf.compat.v1.train.Saver` for new
code.
`Checkpoint` objects have dependencies on the objects passed as keyword `Checkpoint` objects have dependencies on the objects passed as keyword
arguments to their constructors, and each dependency is given a name that is arguments to their constructors, and each dependency is given a name that is
@ -1244,6 +1254,7 @@ class CheckpointV1(tracking.AutoTrackable):
Args: Args:
**kwargs: Keyword arguments are set as attributes of this object, and are **kwargs: Keyword arguments are set as attributes of this object, and are
saved with the checkpoint. Values must be trackable objects. saved with the checkpoint. Values must be trackable objects.
Raises: Raises:
ValueError: If objects in `kwargs` are not trackable. ValueError: If objects in `kwargs` are not trackable.
""" """
@ -1269,8 +1280,12 @@ class CheckpointV1(tracking.AutoTrackable):
# add_variable creates a dependency named "save_counter"; NoDependency # add_variable creates a dependency named "save_counter"; NoDependency
# prevents creating a second dependency named "_save_counter". # prevents creating a second dependency named "_save_counter".
self._save_counter = data_structures.NoDependency( self._save_counter = data_structures.NoDependency(
add_variable(self, name="save_counter", initializer=0, add_variable(
dtype=dtypes.int64, trainable=False)) self,
name="save_counter",
initializer=0,
dtype=dtypes.int64,
trainable=False))
def write(self, file_prefix, session=None): def write(self, file_prefix, session=None):
"""Writes a training checkpoint. """Writes a training checkpoint.
@ -1294,9 +1309,7 @@ class CheckpointV1(tracking.AutoTrackable):
Returns: Returns:
The full path to the checkpoint (i.e. `file_prefix`). The full path to the checkpoint (i.e. `file_prefix`).
""" """
output = self._saver.save( output = self._saver.save(file_prefix=file_prefix, session=session)
file_prefix=file_prefix,
session=session)
if tensor_util.is_tensor(output): if tensor_util.is_tensor(output):
if context.executing_eagerly(): if context.executing_eagerly():
return compat.as_str(output.numpy()) return compat.as_str(output.numpy())
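A brief sketch of `write` versus `save` (paths are illustrative, eager execution assumed):

```python
import tensorflow as tf

tf.compat.v1.enable_eager_execution()

checkpoint = tf.train.Checkpoint(v=tf.Variable(1.0))

# `write` saves at exactly this prefix and does no checkpoint bookkeeping.
exact_path = checkpoint.write('/tmp/ckpt_demo/manual')

# `save` appends the save counter and updates the checkpoint metadata, so
# `tf.train.latest_checkpoint('/tmp/ckpt_demo')` can find the result later.
numbered_path = checkpoint.save('/tmp/ckpt_demo/auto')

checkpoint.restore(exact_path)  # a `write` output is restored by its exact path
```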
@ -1370,8 +1383,8 @@ class CheckpointV1(tracking.AutoTrackable):
checkpoint_number = session.run(self._save_assign_op) checkpoint_number = session.run(self._save_assign_op)
else: else:
checkpoint_number = assign_op.numpy() checkpoint_number = assign_op.numpy()
file_path = self.write("%s-%d" % (file_prefix, checkpoint_number), file_path = self.write(
session=session) "%s-%d" % (file_prefix, checkpoint_number), session=session)
checkpoint_management.update_checkpoint_state_internal( checkpoint_management.update_checkpoint_state_internal(
save_dir=os.path.dirname(file_prefix), save_dir=os.path.dirname(file_prefix),
model_checkpoint_path=file_path, model_checkpoint_path=file_path,
@ -1417,7 +1430,7 @@ class CheckpointV1(tracking.AutoTrackable):
If the checkpoint has not been consumed completely, then the list of restore If the checkpoint has not been consumed completely, then the list of restore
ops will grow as more objects are added to the dependency graph. ops will grow as more objects are added to the dependency graph.
Name-based `tf.train.Saver` checkpoints can be loaded using this Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this
method. Names are used to match variables. No restore ops are created/run method. Names are used to match variables. No restore ops are created/run
until `run_restore_ops()` or `initialize_or_restore()` are called on the until `run_restore_ops()` or `initialize_or_restore()` are called on the
returned status object when graph building, but there is restore-on-creation returned status object when graph building, but there is restore-on-creation
@ -1428,9 +1441,9 @@ class CheckpointV1(tracking.AutoTrackable):
save_path: The path to the checkpoint, as returned by `save` or save_path: The path to the checkpoint, as returned by `save` or
`tf.train.latest_checkpoint`. If None (as when there is no latest `tf.train.latest_checkpoint`. If None (as when there is no latest
checkpoint for `tf.train.latest_checkpoint` to return), returns an checkpoint for `tf.train.latest_checkpoint` to return), returns an
object which may run initializers for objects in the dependency object which may run initializers for objects in the dependency graph.
graph. If the checkpoint was written by the name-based `tf.train.Saver`, If the checkpoint was written by the name-based
names are used to match variables. `tf.compat.v1.train.Saver`, names are used to match variables.
Returns: Returns:
A load status object, which can be used to make assertions about the A load status object, which can be used to make assertions about the
@ -1453,7 +1466,8 @@ class CheckpointV1(tracking.AutoTrackable):
built, and so has not created any variables, will pass this assertion built, and so has not created any variables, will pass this assertion
but fail `assert_consumed`. Useful when loading part of a larger but fail `assert_consumed`. Useful when loading part of a larger
checkpoint into a new Python program, e.g. a training checkpoint with checkpoint into a new Python program, e.g. a training checkpoint with
a `tf.train.Optimizer` was saved but only the state required for a `tf.compat.v1.train.Optimizer` was saved but only the state required
for
inference is being loaded. This method returns the status object, and inference is being loaded. This method returns the status object, and
so may be chained with `initialize_or_restore` or `run_restore_ops`. so may be chained with `initialize_or_restore` or `run_restore_ops`.
@ -1488,7 +1502,7 @@ class Checkpoint(tracking.AutoTrackable):
"""Groups trackable objects, saving and restoring them. """Groups trackable objects, saving and restoring them.
`Checkpoint`'s constructor accepts keyword arguments whose values are types `Checkpoint`'s constructor accepts keyword arguments whose values are types
that contain trackable state, such as `tf.train.Optimizer` that contain trackable state, such as `tf.keras.optimizers.Optimizer`
implementations, `tf.Variable`, `tf.keras.Layer` implementations, or implementations, `tf.Variable`, `tf.keras.Layer` implementations, or
`tf.keras.Model` implementations. It saves these values with a checkpoint, and `tf.keras.Model` implementations. It saves these values with a checkpoint, and
maintains a `save_counter` for numbering checkpoints. maintains a `save_counter` for numbering checkpoints.
@ -1511,7 +1525,8 @@ class Checkpoint(tracking.AutoTrackable):
``` ```
`Checkpoint.save` and `Checkpoint.restore` write and read object-based `Checkpoint.save` and `Checkpoint.restore` write and read object-based
checkpoints, in contrast to TensorFlow 1.x's `tf.train.Saver` which writes and checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which
writes and
reads `variable.name` based checkpoints. Object-based checkpointing saves a reads `variable.name` based checkpoints. Object-based checkpointing saves a
graph of dependencies between Python objects (`Layer`s, `Optimizer`s, graph of dependencies between Python objects (`Layer`s, `Optimizer`s,
`Variable`s, etc.) with named edges, and this graph is used to match variables `Variable`s, etc.) with named edges, and this graph is used to match variables
@ -1561,6 +1576,7 @@ class Checkpoint(tracking.AutoTrackable):
Args: Args:
**kwargs: Keyword arguments are set as attributes of this object, and are **kwargs: Keyword arguments are set as attributes of this object, and are
saved with the checkpoint. Values must be trackable objects. saved with the checkpoint. Values must be trackable objects.
Raises: Raises:
ValueError: If objects in `kwargs` are not trackable. ValueError: If objects in `kwargs` are not trackable.
""" """
@ -1586,8 +1602,12 @@ class Checkpoint(tracking.AutoTrackable):
# add_variable creates a dependency named "save_counter"; NoDependency # add_variable creates a dependency named "save_counter"; NoDependency
# prevents creating a second dependency named "_save_counter". # prevents creating a second dependency named "_save_counter".
self._save_counter = data_structures.NoDependency( self._save_counter = data_structures.NoDependency(
add_variable(self, name="save_counter", initializer=0, add_variable(
dtype=dtypes.int64, trainable=False)) self,
name="save_counter",
initializer=0,
dtype=dtypes.int64,
trainable=False))
def write(self, file_prefix): def write(self, file_prefix):
"""Writes a training checkpoint. """Writes a training checkpoint.
@ -1608,8 +1628,7 @@ class Checkpoint(tracking.AutoTrackable):
Returns: Returns:
The full path to the checkpoint (i.e. `file_prefix`). The full path to the checkpoint (i.e. `file_prefix`).
""" """
output = self._saver.save( output = self._saver.save(file_prefix=file_prefix)
file_prefix=file_prefix)
if tensor_util.is_tensor(output): if tensor_util.is_tensor(output):
if context.executing_eagerly(): if context.executing_eagerly():
return compat.as_str(output.numpy()) return compat.as_str(output.numpy())
@ -1711,7 +1730,8 @@ class Checkpoint(tracking.AutoTrackable):
were not found in the checkpoint, or if any checkpointed values do not have were not found in the checkpoint, or if any checkpointed values do not have
a matching Python object. a matching Python object.
Name-based `tf.train.Saver` checkpoints from TensorFlow 1.x can be loaded Name-based `tf.compat.v1.train.Saver` checkpoints from TensorFlow 1.x can be
loaded
using this method. Names are used to match variables. Re-encode name-based using this method. Names are used to match variables. Re-encode name-based
checkpoints using `tf.train.Checkpoint.save` as soon as possible. checkpoints using `tf.train.Checkpoint.save` as soon as possible.
@ -1719,9 +1739,9 @@ class Checkpoint(tracking.AutoTrackable):
save_path: The path to the checkpoint, as returned by `save` or save_path: The path to the checkpoint, as returned by `save` or
`tf.train.latest_checkpoint`. If None (as when there is no latest `tf.train.latest_checkpoint`. If None (as when there is no latest
checkpoint for `tf.train.latest_checkpoint` to return), returns an checkpoint for `tf.train.latest_checkpoint` to return), returns an
object which may run initializers for objects in the dependency object which may run initializers for objects in the dependency graph.
graph. If the checkpoint was written by the name-based `tf.train.Saver`, If the checkpoint was written by the name-based
names are used to match variables. `tf.compat.v1.train.Saver`, names are used to match variables.
Returns: Returns:
A load status object, which can be used to make assertions about the A load status object, which can be used to make assertions about the
@ -1744,7 +1764,8 @@ class Checkpoint(tracking.AutoTrackable):
built, and so has not created any variables, will pass this assertion built, and so has not created any variables, will pass this assertion
but fail `assert_consumed`. Useful when loading part of a larger but fail `assert_consumed`. Useful when loading part of a larger
checkpoint into a new Python program, e.g. a training checkpoint with checkpoint into a new Python program, e.g. a training checkpoint with
a `tf.train.Optimizer` was saved but only the state required for a `tf.compat.v1.train.Optimizer` was saved but only the state required
for
inference is being loaded. This method returns the status object, and inference is being loaded. This method returns the status object, and
so may be chained with other assertions. so may be chained with other assertions.
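For the inference-only case described above, a hedged sketch (the `make_model` helper, which is assumed to create its variables eagerly, and the paths are hypothetical):

```python
import tensorflow as tf

# Training program: saves both the model and the optimizer state.
model = make_model()
optimizer = tf.keras.optimizers.Adam()
tf.train.Checkpoint(model=model, optimizer=optimizer).save('/tmp/train_ckpt/ckpt')

# Inference program: only the model is restored; optimizer slots stay unused.
inference_model = make_model()
status = tf.train.Checkpoint(model=inference_model).restore(
    tf.train.latest_checkpoint('/tmp/train_ckpt'))
status.assert_existing_objects_matched()  # passes; assert_consumed() would not
```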
View File
@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Utility functions for training.""" """Utility functions for training."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -34,7 +33,6 @@ from tensorflow.python.util.tf_export import tf_export
# collection keys. # collection keys.
GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache' GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache'
# TODO(drpng): remove this after legacy uses are resolved. # TODO(drpng): remove this after legacy uses are resolved.
write_graph = graph_io.write_graph write_graph = graph_io.write_graph
@ -47,11 +45,12 @@ def global_step(sess, global_step_tensor):
# Create a variable to hold the global_step. # Create a variable to hold the global_step.
global_step_tensor = tf.Variable(10, trainable=False, name='global_step') global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
# Create a session. # Create a session.
sess = tf.Session() sess = tf.compat.v1.Session()
# Initialize the variable # Initialize the variable
sess.run(global_step_tensor.initializer) sess.run(global_step_tensor.initializer)
# Get the variable value. # Get the variable value.
print('global_step: %s' % tf.train.global_step(sess, global_step_tensor)) print('global_step: %s' % tf.compat.v1.train.global_step(sess,
global_step_tensor))
global_step: 10 global_step: 10
``` ```
@ -109,8 +108,8 @@ def create_global_step(graph=None):
"""Create global step tensor in graph. """Create global step tensor in graph.
Args: Args:
graph: The graph in which to create the global step tensor. If missing, graph: The graph in which to create the global step tensor. If missing, use
use default graph. default graph.
Returns: Returns:
Global step tensor. Global step tensor.
@ -130,8 +129,9 @@ def create_global_step(graph=None):
initializer=init_ops.zeros_initializer(), initializer=init_ops.zeros_initializer(),
trainable=False, trainable=False,
aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA, aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
collections=[ops.GraphKeys.GLOBAL_VARIABLES, collections=[
ops.GraphKeys.GLOBAL_STEP]) ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP
])
# Create in proper graph and base name_scope. # Create in proper graph and base name_scope.
with graph.as_default() as g, g.name_scope(None): with graph.as_default() as g, g.name_scope(None):
return variable_scope.get_variable( return variable_scope.get_variable(
@ -141,8 +141,7 @@ def create_global_step(graph=None):
initializer=init_ops.zeros_initializer(), initializer=init_ops.zeros_initializer(),
trainable=False, trainable=False,
aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA, aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
collections=[ops.GraphKeys.GLOBAL_VARIABLES, collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])
ops.GraphKeys.GLOBAL_STEP])
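A short graph-mode sketch tying the global step to an optimizer (the loss below is a stand-in for a real model loss):

```python
import tensorflow as tf

global_step = tf.compat.v1.train.get_or_create_global_step()
weights = tf.Variable([1.0, 2.0])
loss = tf.reduce_sum(tf.square(weights))
optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.1)
train_op = optimizer.minimize(loss, global_step=global_step)  # increments the step

with tf.compat.v1.Session() as sess:
  sess.run(tf.compat.v1.global_variables_initializer())
  sess.run(train_op)
  print(tf.compat.v1.train.global_step(sess, global_step))  # 1
```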
@tf_export(v1=['train.get_or_create_global_step']) @tf_export(v1=['train.get_or_create_global_step'])
@ -173,9 +172,8 @@ def assert_global_step(global_step_tensor):
if not (isinstance(global_step_tensor, variables.Variable) or if not (isinstance(global_step_tensor, variables.Variable) or
isinstance(global_step_tensor, ops.Tensor) or isinstance(global_step_tensor, ops.Tensor) or
resource_variable_ops.is_resource_variable(global_step_tensor)): resource_variable_ops.is_resource_variable(global_step_tensor)):
raise TypeError( raise TypeError('Existing "global_step" must be a Variable or Tensor: %s.' %
'Existing "global_step" must be a Variable or Tensor: %s.' % global_step_tensor)
global_step_tensor)
if not global_step_tensor.dtype.base_dtype.is_integer: if not global_step_tensor.dtype.base_dtype.is_integer:
raise TypeError('Existing "global_step" does not have integer type: %s' % raise TypeError('Existing "global_step" does not have integer type: %s' %
View File
@ -49,8 +49,8 @@ class VocabInfo(
VocabInfo to warm-start. VocabInfo to warm-start.
Attributes: Attributes:
new_vocab: [Required] A path to the new vocabulary file (used with the new_vocab: [Required] A path to the new vocabulary file (used with the model
model to be trained). to be trained).
new_vocab_size: [Required] An integer indicating how many entries of the new new_vocab_size: [Required] An integer indicating how many entries of the new
vocabulary will be used in training. vocabulary will be used in training.
num_oov_buckets: [Required] An integer indicating how many OOV buckets are num_oov_buckets: [Required] An integer indicating how many OOV buckets are
@ -76,7 +76,7 @@ class VocabInfo(
num_oov_buckets=1, num_oov_buckets=1,
old_vocab='pretrained_embeddings_vocab', old_vocab='pretrained_embeddings_vocab',
old_vocab_size=10000, old_vocab_size=10000,
backup_initializer=tf.truncated_normal_initializer( backup_initializer=tf.compat.v1.truncated_normal_initializer(
mean=0.0, stddev=(1 / math.sqrt(embedding_dim))), mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
axis=0) axis=0)
@ -86,7 +86,7 @@ class VocabInfo(
num_oov_buckets=0, # No OOV for classes. num_oov_buckets=0, # No OOV for classes.
old_vocab='old_class_vocab', old_vocab='old_class_vocab',
old_vocab_size=8, old_vocab_size=8,
backup_initializer=tf.glorot_uniform_initializer(), backup_initializer=tf.compat.v1.glorot_uniform_initializer(),
axis=1) axis=1)
softmax_output_layer_bias_vocab_info = tf.VocabInfo( softmax_output_layer_bias_vocab_info = tf.VocabInfo(
@ -95,7 +95,7 @@ class VocabInfo(
num_oov_buckets=0, # No OOV for classes. num_oov_buckets=0, # No OOV for classes.
old_vocab='old_class_vocab', old_vocab='old_class_vocab',
old_vocab_size=8, old_vocab_size=8,
backup_initializer=tf.zeros_initializer(), backup_initializer=tf.compat.v1.zeros_initializer(),
axis=0) axis=0)
Currently, only axis=0 and axis=1 are supported. Currently, only axis=0 and axis=1 are supported.
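A hedged end-to-end sketch of passing a `VocabInfo` to `warm_start` (all file paths and the embedding variable name are hypothetical):

```python
import math
import tensorflow as tf

embedding_dim = 16
vocab_info = tf.compat.v1.train.VocabInfo(
    new_vocab='new_vocab.txt',
    new_vocab_size=100000,
    num_oov_buckets=1,
    old_vocab='old_vocab.txt',
    old_vocab_size=80000,
    backup_initializer=tf.compat.v1.truncated_normal_initializer(
        mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
    axis=0)

tf.compat.v1.train.warm_start(
    ckpt_to_initialize_from='/tmp/prev_model',
    vars_to_warm_start='.*embedding_weights.*',
    var_name_to_vocab_info={
        'input_layer/words_embedding/embedding_weights': vocab_info})
```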
@ -255,8 +255,7 @@ def _warm_start_var_with_vocab(var,
partition_info = None partition_info = None
if slice_info: if slice_info:
partition_info = variable_scope._PartitionInfo( partition_info = variable_scope._PartitionInfo(
full_shape=slice_info.full_shape, full_shape=slice_info.full_shape, var_offset=slice_info.var_offset)
var_offset=slice_info.var_offset)
if axis == 0: if axis == 0:
new_row_vocab_size = current_vocab_size new_row_vocab_size = current_vocab_size
@ -301,6 +300,8 @@ def _warm_start_var_with_vocab(var,
new_init_val = ops.convert_to_tensor( new_init_val = ops.convert_to_tensor(
init(shape=v_shape, partition_info=partition_info)) init(shape=v_shape, partition_info=partition_info))
v._initializer_op = state_ops.assign(v, new_init_val) v._initializer_op = state_ops.assign(v, new_init_val)
# pylint: enable=protected-access # pylint: enable=protected-access
@ -314,7 +315,8 @@ def _get_grouped_variables(vars_to_warm_start):
vars_to_warm_start: One of the following: vars_to_warm_start: One of the following:
- A regular expression (string) that captures which variables to - A regular expression (string) that captures which variables to
warm-start (see tf.get_collection). This expression will only consider warm-start (see tf.compat.v1.get_collection). This expression will
only consider
variables in the TRAINABLE_VARIABLES collection. variables in the TRAINABLE_VARIABLES collection.
- A list of Variables to warm-start. - A list of Variables to warm-start.
- A list of strings, each representing a full variable name to warm-start. - A list of strings, each representing a full variable name to warm-start.
@ -330,14 +332,13 @@ def _get_grouped_variables(vars_to_warm_start):
# Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
# everything (in TRAINABLE_VARIABLES) here. # everything (in TRAINABLE_VARIABLES) here.
list_of_vars = ops.get_collection( list_of_vars = ops.get_collection(
ops.GraphKeys.TRAINABLE_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES, scope=vars_to_warm_start)
scope=vars_to_warm_start)
elif isinstance(vars_to_warm_start, list): elif isinstance(vars_to_warm_start, list):
if all(isinstance(v, str) for v in vars_to_warm_start): if all(isinstance(v, str) for v in vars_to_warm_start):
list_of_vars = [] list_of_vars = []
for v in vars_to_warm_start: for v in vars_to_warm_start:
list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, list_of_vars += ops.get_collection(
scope=v) ops.GraphKeys.GLOBAL_VARIABLES, scope=v)
elif all(checkpoint_utils._is_variable(v) for v in vars_to_warm_start): # pylint: disable=protected-access elif all(checkpoint_utils._is_variable(v) for v in vars_to_warm_start): # pylint: disable=protected-access
list_of_vars = vars_to_warm_start list_of_vars = vars_to_warm_start
else: else:
@ -377,16 +378,16 @@ def warm_start(ckpt_to_initialize_from,
vars_to_warm_start: [Optional] One of the following: vars_to_warm_start: [Optional] One of the following:
- A regular expression (string) that captures which variables to - A regular expression (string) that captures which variables to
warm-start (see tf.get_collection). This expression will only consider warm-start (see tf.compat.v1.get_collection). This expression will only
variables in the TRAINABLE_VARIABLES collection -- if you need to consider variables in the TRAINABLE_VARIABLES collection -- if you need
warm-start non_TRAINABLE vars (such as optimizer accumulators or batch to warm-start non_TRAINABLE vars (such as optimizer accumulators or
norm statistics), please use the below option. batch norm statistics), please use the below option.
- A list of Variables to warm-start. If you do not have access to the - A list of Variables to warm-start. If you do not have access to the
`Variable` objects at the call site, please use the below option. `Variable` objects at the call site, please use the below option.
- A list of strings, each a regex scope provided to tf.get_collection with - A list of strings, each a regex scope provided to
GLOBAL_VARIABLES (please see tf.get_collection). For backwards tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see
compatibility reasons, this is separate from the single-string argument tf.compat.v1.get_collection). For backwards compatibility reasons,
type. this is separate from the single-string argument type.
- `None`, in which case only variables specified in - `None`, in which case only variables specified in
`var_name_to_vocab_info` will be warm-started. `var_name_to_vocab_info` will be warm-started.
@ -404,6 +405,7 @@ def warm_start(ckpt_to_initialize_from,
effect on the set of variables that is warm-started, and only controls effect on the set of variables that is warm-started, and only controls
name mapping (use `vars_to_warm_start` for controlling what variables to name mapping (use `vars_to_warm_start` for controlling what variables to
warm-start). warm-start).
Raises: Raises:
ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
configuration for variable names that are not used. This is to ensure configuration for variable names that are not used. This is to ensure