Apply tf1->tf2 name replacements to doc-strings and comments in tensorflow.

No code changes, only doc-strings and comments.

PiperOrigin-RevId: 244275767
Mark Daoust 2019-04-18 15:56:07 -07:00 committed by TensorFlower Gardener
parent 7fdf27b688
commit 18b680216e
22 changed files with 851 additions and 763 deletions

View File

@ -22,8 +22,11 @@ from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=["train.basic_train_loop"])
def basic_train_loop(supervisor, train_step_fn, args=None,
kwargs=None, master=""):
def basic_train_loop(supervisor,
train_step_fn,
args=None,
kwargs=None,
master=""):
"""Basic loop to train a model.
Calls `train_step_fn` in a loop to train a model. The function is called as:
@ -32,17 +35,18 @@ def basic_train_loop(supervisor, train_step_fn, args=None,
train_step_fn(session, *args, **kwargs)
```
It is passed a `tf.Session` in addition to `args` and `kwargs`. The function
It is passed a `tf.compat.v1.Session` in addition to `args` and `kwargs`. The
function
typically runs one training step in the session.
Args:
supervisor: `tf.train.Supervisor` to run the training services.
train_step_fn: Callable to execute one training step. Called
repeatedly as `train_step_fn(session, *args, **kwargs)`.
supervisor: `tf.compat.v1.train.Supervisor` to run the training services.
train_step_fn: Callable to execute one training step. Called repeatedly as
`train_step_fn(session, *args, **kwargs)`.
args: Optional positional arguments passed to `train_step_fn`.
kwargs: Optional keyword arguments passed to `train_step_fn`.
master: Master to use to create the training session. Defaults to
`""` which causes the session to be created in the local process.
master: Master to use to create the training session. Defaults to `""`
which causes the session to be created in the local process.
"""
if args is None:
args = []
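For orientation, a minimal sketch of driving this loop end to end; the `logdir` path and the stop-after-five-steps logic are illustrative only:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

x = tf.Variable(0.0)
train_op = tf.assign_add(x, 1.0)  # stand-in for a real optimizer step

# Supervisor handles initialization, checkpointing and session creation.
supervisor = tf.train.Supervisor(logdir="/tmp/basic_train_loop_demo")  # hypothetical dir

def train_step_fn(session):
  step_value = session.run(train_op)
  if step_value >= 5.0:
    supervisor.request_stop()  # ends basic_train_loop's internal loop

tf.train.basic_train_loop(supervisor, train_step_fn)
```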

View File

@ -42,7 +42,6 @@ from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
from tensorflow.python.util.tf_export import tf_export
_HOOKS = "hooks"
_STEPS_PER_RUN_VAR = "steps_per_run"
@ -85,8 +84,7 @@ class _HookTimer(object):
@tf_export(v1=["train.SecondOrStepTimer"])
class SecondOrStepTimer(_HookTimer):
"""Timer that triggers at most once every N seconds or once every N steps.
"""
"""Timer that triggers at most once every N seconds or once every N steps."""
def __init__(self, every_secs=None, every_steps=None):
self.reset()
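A small sketch of how this timer behaves in isolation (the step values are arbitrary):

```python
import tensorflow.compat.v1 as tf

# Trigger at most once every 100 steps; an every_secs variant also exists.
timer = tf.train.SecondOrStepTimer(every_steps=100)

for step in range(301):
  if timer.should_trigger_for_step(step):
    timer.update_last_triggered_step(step)
    print("triggered at step", step)  # fires at 0, 100, 200, 300
```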
@ -171,29 +169,33 @@ class LoggingTensorHook(session_run_hook.SessionRunHook):
seeing the logs, you might want to add the following line after your imports:
```python
tf.logging.set_verbosity(tf.logging.INFO)
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
```
Note that if `at_end` is True, `tensors` should not include any tensor
whose evaluation produces a side effect such as consuming additional inputs.
"""
def __init__(self, tensors, every_n_iter=None, every_n_secs=None,
at_end=False, formatter=None):
def __init__(self,
tensors,
every_n_iter=None,
every_n_secs=None,
at_end=False,
formatter=None):
"""Initializes a `LoggingTensorHook`.
Args:
tensors: `dict` that maps string-valued tags to tensors/tensor names,
or `iterable` of tensors/tensor names.
tensors: `dict` that maps string-valued tags to tensors/tensor names, or
`iterable` of tensors/tensor names.
every_n_iter: `int`, print the values of `tensors` once every N local
steps taken on the current worker.
steps taken on the current worker.
every_n_secs: `int` or `float`, print the values of `tensors` once every N
seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
provided.
seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
provided.
at_end: `bool` specifying whether to print the values of `tensors` at the
end of the run.
end of the run.
formatter: function, takes dict of `tag`->`Tensor` and returns a string.
If `None` uses default printing all tensors.
If `None` uses default printing all tensors.
Raises:
ValueError: if `every_n_iter` is non-positive.
@ -215,16 +217,18 @@ class LoggingTensorHook(session_run_hook.SessionRunHook):
self._tensors = tensors
self._formatter = formatter
self._timer = (
NeverTriggerTimer() if only_log_at_end else
SecondOrStepTimer(every_secs=every_n_secs, every_steps=every_n_iter))
NeverTriggerTimer() if only_log_at_end else SecondOrStepTimer(
every_secs=every_n_secs, every_steps=every_n_iter))
self._log_at_end = at_end
def begin(self):
self._timer.reset()
self._iter_count = 0
# Convert names to tensors if given
self._current_tensors = {tag: _as_graph_element(tensor)
for (tag, tensor) in self._tensors.items()}
self._current_tensors = {
tag: _as_graph_element(tensor)
for (tag, tensor) in self._tensors.items()
}
def before_run(self, run_context): # pylint: disable=unused-argument
self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
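A minimal usage sketch for this hook; the loss variable and iteration count are illustrative:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.logging.set_verbosity(tf.logging.INFO)

loss = tf.Variable(10.0, name="loss")
train_op = tf.assign_sub(loss, 0.1)

# Logs the current value of `loss` every 10 local steps.
hook = tf.train.LoggingTensorHook({"loss": loss}, every_n_iter=10)
with tf.train.MonitoredSession(hooks=[hook]) as sess:
  for _ in range(30):
    sess.run(train_op)
```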
@ -463,9 +467,10 @@ class CheckpointSaverListener(object):
...
listener = ExampleCheckpointSaverListener()
saver_hook = tf.train.CheckpointSaverHook(
saver_hook = tf.estimator.CheckpointSaverHook(
checkpoint_dir, listeners=[listener])
with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
with
tf.compat.v1.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
...
```
@ -516,9 +521,9 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
saver: `Saver` object, used for saving.
checkpoint_basename: `str`, base name for the checkpoint files.
scaffold: `Scaffold`, use to get saver object.
listeners: List of `CheckpointSaverListener` subclass instances.
Used for callbacks that run immediately before or after this hook saves
the checkpoint.
listeners: List of `CheckpointSaverListener` subclass instances. Used for
callbacks that run immediately before or after this hook saves the
checkpoint.
Raises:
ValueError: One of `save_steps` or `save_secs` should be set.
@ -531,8 +536,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
self._checkpoint_dir = checkpoint_dir
self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
self._scaffold = scaffold
self._timer = SecondOrStepTimer(every_secs=save_secs,
every_steps=save_steps)
self._timer = SecondOrStepTimer(
every_secs=save_secs, every_steps=save_steps)
self._listeners = listeners or []
self._steps_per_run = 1
@ -555,13 +560,11 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
# add variables in begin. Graph is finalized after all begin calls.
training_util.write_graph(
ops.get_default_graph().as_graph_def(add_shapes=True),
self._checkpoint_dir,
"graph.pbtxt")
self._checkpoint_dir, "graph.pbtxt")
saver_def = self._get_saver().saver_def if self._get_saver() else None
graph = ops.get_default_graph()
meta_graph_def = meta_graph.create_meta_graph_def(
graph_def=graph.as_graph_def(add_shapes=True),
saver_def=saver_def)
graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
self._summary_writer.add_graph(graph)
self._summary_writer.add_meta_graph(meta_graph_def)
# The checkpoint saved here is the state at step "global_step".
@ -573,8 +576,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
def after_run(self, run_context, run_values):
stale_global_step = run_values.results
if self._timer.should_trigger_for_step(
stale_global_step + self._steps_per_run):
if self._timer.should_trigger_for_step(stale_global_step +
self._steps_per_run):
# get the real value after train op.
global_step = run_context.session.run(self._global_step_tensor)
if self._timer.should_trigger_for_step(global_step):
@ -627,8 +630,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
elif len(savers) > 1:
raise RuntimeError(
"More than one item in collection {}. "
"Please indicate which one to use by passing it to the constructor.".
format(collection_key))
"Please indicate which one to use by passing it to the constructor."
.format(collection_key))
self._saver = savers[0]
return savers[0]
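A hedged sketch of wiring a listener into the saver hook, using the v1 `tf.train.CheckpointSaverHook` alias of the `tf.estimator.CheckpointSaverHook` named above; the checkpoint directory and step counts are hypothetical:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

class ExampleListener(tf.train.CheckpointSaverListener):

  def before_save(self, session, global_step_value):
    print("about to write a checkpoint at step", global_step_value)

  def after_save(self, session, global_step_value):
    print("checkpoint written at step", global_step_value)

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)

checkpoint_dir = "/tmp/ckpt_listener_demo"  # hypothetical
saver_hook = tf.train.CheckpointSaverHook(
    checkpoint_dir, save_steps=5, listeners=[ExampleListener()])

# save_checkpoint_secs=None disables the default saver so only saver_hook saves.
with tf.train.MonitoredTrainingSession(
    checkpoint_dir=checkpoint_dir,
    chief_only_hooks=[saver_hook],
    save_checkpoint_secs=None) as sess:
  for _ in range(10):
    sess.run(train_op)
```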
@ -647,8 +650,8 @@ class StepCounterHook(session_run_hook.SessionRunHook):
if (every_n_steps is None) == (every_n_secs is None):
raise ValueError(
"exactly one of every_n_steps and every_n_secs should be provided.")
self._timer = SecondOrStepTimer(every_steps=every_n_steps,
every_secs=every_n_secs)
self._timer = SecondOrStepTimer(
every_steps=every_n_steps, every_secs=every_n_secs)
self._summary_writer = summary_writer
self._output_dir = output_dir
@ -673,8 +676,9 @@ class StepCounterHook(session_run_hook.SessionRunHook):
def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
steps_per_sec = elapsed_steps / elapsed_time
if self._summary_writer is not None:
summary = Summary(value=[Summary.Value(
tag=self._summary_tag, simple_value=steps_per_sec)])
summary = Summary(value=[
Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
])
self._summary_writer.add_summary(summary, global_step)
logging.info("%s: %g", self._summary_tag, steps_per_sec)
@ -682,8 +686,8 @@ class StepCounterHook(session_run_hook.SessionRunHook):
_ = run_context
stale_global_step = run_values.results
if self._timer.should_trigger_for_step(
stale_global_step + self._steps_per_run):
if self._timer.should_trigger_for_step(stale_global_step +
self._steps_per_run):
# get the real value after train op.
global_step = run_context.session.run(self._global_step_tensor)
if self._timer.should_trigger_for_step(global_step):
@ -767,18 +771,18 @@ class SummarySaverHook(session_run_hook.SessionRunHook):
Args:
save_steps: `int`, save summaries every N steps. Exactly one of
`save_secs` and `save_steps` should be set.
`save_secs` and `save_steps` should be set.
save_secs: `int`, save summaries every N seconds.
output_dir: `string`, the directory to save the summaries to. Only used
if no `summary_writer` is supplied.
output_dir: `string`, the directory to save the summaries to. Only used if
no `summary_writer` is supplied.
summary_writer: `SummaryWriter`. If `None` and an `output_dir` was passed,
one will be created accordingly.
one will be created accordingly.
scaffold: `Scaffold` to get summary_op if it's not provided.
summary_op: `Tensor` of type `string` containing the serialized `Summary`
protocol buffer or a list of `Tensor`. They are most likely an output
by TF summary methods like `tf.summary.scalar` or
`tf.summary.merge_all`. It can be passed in as one tensor; if more
than one, they must be passed in as a list.
protocol buffer or a list of `Tensor`. They are most likely an output by
TF summary methods like `tf.compat.v1.summary.scalar` or
`tf.compat.v1.summary.merge_all`. It can be passed in as one tensor; if
more than one, they must be passed in as a list.
Raises:
ValueError: Exactly one of scaffold or summary_op should be set.
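A hedged usage sketch for this hook; the scalar summary and output directory are illustrative:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

global_step = tf.train.get_or_create_global_step()
loss = tf.Variable(10.0)
train_op = tf.group(tf.assign_sub(loss, 0.1), tf.assign_add(global_step, 1))

tf.summary.scalar("loss", loss)
hook = tf.train.SummarySaverHook(
    save_steps=10,
    output_dir="/tmp/summary_saver_demo",  # hypothetical
    summary_op=tf.summary.merge_all())

with tf.train.MonitoredSession(hooks=[hook]) as sess:
  for _ in range(50):
    sess.run(train_op)
```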
@ -791,8 +795,8 @@ class SummarySaverHook(session_run_hook.SessionRunHook):
self._summary_writer = summary_writer
self._output_dir = output_dir
self._scaffold = scaffold
self._timer = SecondOrStepTimer(every_secs=save_secs,
every_steps=save_steps)
self._timer = SecondOrStepTimer(
every_secs=save_secs, every_steps=save_steps)
# TODO(mdan): Throw an error if output_dir and summary_writer are None.
def begin(self):
@ -903,8 +907,9 @@ class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
self._worker_is_started = True
return None
if current_step - last_logged_step > 1000:
logging.info("Waiting for global step %d before starting training. "
"Current step is %d.", self._wait_until_step, current_step)
logging.info(
"Waiting for global step %d before starting training. "
"Current step is %d.", self._wait_until_step, current_step)
last_logged_step = current_step
time.sleep(0.5)
@ -917,8 +922,8 @@ class FinalOpsHook(session_run_hook.SessionRunHook):
"""Initializes `FinalOpHook` with ops to run at the end of the session.
Args:
final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of
names to `Tensors`.
final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
to `Tensors`.
final_ops_feed_dict: A feed dictionary to use when running
`final_ops_dict`.
"""
@ -997,14 +1002,14 @@ class ProfilerHook(session_run_hook.SessionRunHook):
Args:
save_steps: `int`, save profile traces every N steps. Exactly one of
`save_secs` and `save_steps` should be set.
`save_secs` and `save_steps` should be set.
save_secs: `int` or `float`, save profile traces every N seconds.
output_dir: `string`, the directory to save the profile traces to.
Defaults to the current directory.
Defaults to the current directory.
show_dataflow: `bool`, if True, add flow events to the trace connecting
producers and consumers of tensors.
producers and consumers of tensors.
show_memory: `bool`, if True, add object snapshot events to the trace
showing the sizes and lifetimes of tensors.
showing the sizes and lifetimes of tensors.
"""
self._output_file = os.path.join(output_dir, "timeline-{}.json")
self._file_writer = SummaryWriterCache.get(output_dir)
@ -1024,8 +1029,9 @@ class ProfilerHook(session_run_hook.SessionRunHook):
self._next_step is not None and
self._timer.should_trigger_for_step(self._next_step))
requests = {"global_step": self._global_step_tensor}
opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
if self._request_summary else None)
opts = (
config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
if self._request_summary else None)
return SessionRunArgs(requests, options=opts)
@ -1039,8 +1045,7 @@ class ProfilerHook(session_run_hook.SessionRunHook):
if self._request_summary:
global_step = run_context.session.run(self._global_step_tensor)
self._timer.update_last_triggered_step(global_step)
self._save(global_step,
self._output_file.format(global_step),
self._save(global_step, self._output_file.format(global_step),
run_values.run_metadata.step_stats)
self._file_writer.add_run_metadata(run_values.run_metadata,
"step_%d" % global_step)

View File

@ -106,7 +106,7 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
"""Replaces `tf.Variable` initializers so they load from a checkpoint file.
Values are not loaded immediately, but when the initializer is run
(typically by running a `tf.global_variables_initializer` op).
(typically by running a `tf.compat.v1.global_variables_initializer` op).
Note: This overrides default initialization ops of specified variables and
redefines dtype.
@ -139,15 +139,15 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
# -- name='old_scope_2/var3', shape=[100, 100]
# Create new model's variables
with tf.variable_scope('new_scope_1'):
var1 = tf.get_variable('var1', shape=[20, 2],
initializer=tf.zeros_initializer())
with tf.variable_scope('new_scope_2'):
var2 = tf.get_variable('var2', shape=[50, 4],
initializer=tf.zeros_initializer())
with tf.compat.v1.variable_scope('new_scope_1'):
var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
initializer=tf.compat.v1.zeros_initializer())
with tf.compat.v1.variable_scope('new_scope_2'):
var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
initializer=tf.compat.v1.zeros_initializer())
# Partition into 5 variables along the first axis.
var3 = tf.get_variable(name='var3', shape=[100, 100],
initializer=tf.zeros_initializer(),
var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
initializer=tf.compat.v1.zeros_initializer(),
partitioner=lambda shape, dtype: [5, 1])
# Initialize all variables in `new_scope_1` from `old_scope_1`.
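The `init_from_checkpoint` call itself falls outside this hunk; a separate, self-contained sketch of the API, with a hypothetical checkpoint path and a single remapped variable:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.variable_scope("new_scope_1"):
  var1 = tf.get_variable("var1", shape=[20, 2],
                         initializer=tf.zeros_initializer())

# Keys are names in the checkpoint, values are variables (or names) in the
# current graph. This rewrites var1's initializer to read from the checkpoint.
tf.train.init_from_checkpoint("/tmp/old_model/model.ckpt",  # hypothetical
                              {"old_scope_1/var1": "new_scope_1/var1"})

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())  # this is when values load
```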

View File

@ -131,9 +131,13 @@ class _ReplicaDeviceChooser(object):
@tf_export(v1=["train.replica_device_setter"])
def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
worker_device="/job:worker", merge_devices=True,
cluster=None, ps_ops=None, ps_strategy=None):
def replica_device_setter(ps_tasks=0,
ps_device="/job:ps",
worker_device="/job:worker",
merge_devices=True,
cluster=None,
ps_ops=None,
ps_strategy=None):
"""Return a `device function` to use when building a Graph for replicas.
Device Functions are used in a `with tf.device(device_function):` statement to
@ -158,7 +162,8 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
cluster_spec = {
"ps": ["ps0:2222", "ps1:2222"],
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
with tf.device(tf.train.replica_device_setter(cluster=cluster_spec)):
with
tf.device(tf.compat.v1.train.replica_device_setter(cluster=cluster_spec)):
# Build your graph
v1 = tf.Variable(...) # assigned to /job:ps/task:0
v2 = tf.Variable(...) # assigned to /job:ps/task:1
@ -218,6 +223,6 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
ps_strategy = _RoundRobinStrategy(ps_tasks)
if not six.callable(ps_strategy):
raise TypeError("ps_strategy must be callable")
chooser = _ReplicaDeviceChooser(
ps_tasks, ps_device, worker_device, merge_devices, ps_ops, ps_strategy)
chooser = _ReplicaDeviceChooser(ps_tasks, ps_device, worker_device,
merge_devices, ps_ops, ps_strategy)
return chooser.device_function
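A compact, runnable variant of the cluster example above; the addresses are placeholders and no servers need to be running just to assign devices:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

cluster_spec = {
    "ps": ["ps0:2222", "ps1:2222"],
    "worker": ["worker0:2222", "worker1:2222"]}

# Variables are placed round-robin on the ps tasks; other ops stay on the worker.
with tf.device(tf.train.replica_device_setter(cluster=cluster_spec)):
  v1 = tf.Variable([1.0])  # /job:ps/task:0
  v2 = tf.Variable([2.0])  # /job:ps/task:1
  total = v1 + v2          # /job:worker

print(v1.device, v2.device, total.device)
```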

View File

@ -65,8 +65,8 @@ def _get_latest_eval_step_value(update_ops):
"""Gets the eval step `Tensor` value after running `update_ops`.
Args:
update_ops: A list of `Tensors` or a dictionary of names to `Tensors`,
which are run before reading the eval step value.
update_ops: A list of `Tensors` or a dictionary of names to `Tensors`, which
are run before reading the eval step value.
Returns:
A `Tensor` representing the value for the evaluation step.
@ -102,21 +102,20 @@ class _MultiStepStopAfterNEvalsHook(session_run_hook.SessionRunHook):
def after_create_session(self, session, coord):
# Update number of steps to run in the first run call
if self._num_evals is None:
if self._num_evals is None:
steps = self._steps_per_run_initial_value
else:
steps = min(self._steps_per_run_initial_value, self._num_evals)
self._steps_per_run_variable.load(steps, session=session)
def before_run(self, run_context):
return session_run_hook.SessionRunArgs({
'evals_completed': self._evals_completed
})
return session_run_hook.SessionRunArgs(
{'evals_completed': self._evals_completed})
def after_run(self, run_context, run_values):
evals_completed = run_values.results['evals_completed']
# Update number of steps to run in the next iteration
if self._num_evals is None:
if self._num_evals is None:
steps = self._steps_per_run_initial_value
else:
steps = min(self._num_evals - evals_completed,
@ -147,16 +146,15 @@ class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
self._evals_completed = None
self._log_progress = log_progress
# Reduce logging frequency if there are 20 or more evaluations.
self._log_frequency = (1 if (num_evals is None or num_evals < 20)
else math.floor(num_evals / 10.))
self._log_frequency = (1 if (num_evals is None or num_evals < 20) else
math.floor(num_evals / 10.))
def _set_evals_completed_tensor(self, updated_eval_step):
self._evals_completed = updated_eval_step
def before_run(self, run_context):
return session_run_hook.SessionRunArgs({
'evals_completed': self._evals_completed
})
return session_run_hook.SessionRunArgs(
{'evals_completed': self._evals_completed})
def after_run(self, run_context, run_values):
evals_completed = run_values.results['evals_completed']
@ -205,20 +203,20 @@ def _evaluate_once(checkpoint_path,
Args:
checkpoint_path: The path to a checkpoint to use for evaluation.
master: The BNS address of the TensorFlow master.
scaffold: A tf.train.Scaffold instance for initializing variables and
restoring variables. Note that `scaffold.init_fn` is used by the function
to restore the checkpoint. If you supply a custom init_fn, then it must
also take care of restoring the model from its checkpoint.
eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
to `Tensors`, which is run until the session is requested to stop,
commonly done by a `tf.contrib.training.StopAfterNEvalsHook`.
scaffold: A tf.compat.v1.train.Scaffold instance for initializing variables
and restoring variables. Note that `scaffold.init_fn` is used by the
function to restore the checkpoint. If you supply a custom init_fn, then
it must also take care of restoring the model from its checkpoint.
eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names to
`Tensors`, which is run until the session is requested to stop, commonly
done by a `tf.contrib.training.StopAfterNEvalsHook`.
feed_dict: The feed dictionary to use when executing the `eval_ops`.
final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
to `Tensors`.
final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
evaluation loop.
config: An instance of `tf.ConfigProto` that will be used to
hooks: List of `tf.estimator.SessionRunHook` callbacks which are run inside
the evaluation loop.
config: An instance of `tf.compat.v1.ConfigProto` that will be used to
configure the `Session`. If left as `None`, the default will be used.
Returns:
@ -263,8 +261,8 @@ def _evaluate_once(checkpoint_path,
master=master,
config=config)
final_ops_hook = basic_session_run_hooks.FinalOpsHook(
final_ops, final_ops_feed_dict)
final_ops_hook = basic_session_run_hooks.FinalOpsHook(final_ops,
final_ops_feed_dict)
hooks.append(final_ops_hook)
with monitored_session.MonitoredSession(

View File

@ -1090,7 +1090,7 @@ def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False,
The `tensors_list` argument is a list of tuples of tensors, or a list of
dictionaries of tensors. Each element in the list is treated similarly
to the `tensors` argument of `tf.train.batch()`.
to the `tensors` argument of `tf.compat.v1.train.batch()`.
WARNING: This function is nondeterministic, since it starts a separate thread
for each tensor.
@ -1284,7 +1284,7 @@ def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
```python
# Creates batches of 32 images and 32 labels.
image_batch, label_batch = tf.train.shuffle_batch(
image_batch, label_batch = tf.compat.v1.train.shuffle_batch(
[single_image, single_label],
batch_size=32,
num_threads=4,
@ -1425,7 +1425,7 @@ def shuffle_batch_join(tensors_list, batch_size, capacity,
The `tensors_list` argument is a list of tuples of tensors, or a list of
dictionaries of tensors. Each element in the list is treated similarly
to the `tensors` argument of `tf.train.shuffle_batch()`.
to the `tensors` argument of `tf.compat.v1.train.shuffle_batch()`.
This version enqueues a different list of tensors in different threads.
It adds the following to the current `Graph`:

View File

@ -56,24 +56,25 @@ def exponential_decay(learning_rate,
...
global_step = tf.Variable(0, trainable=False)
starter_learning_rate = 0.1
learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
learning_rate = tf.compat.v1.train.exponential_decay(starter_learning_rate,
global_step,
100000, 0.96, staircase=True)
# Passing global_step to minimize() will increment it at each step.
learning_step = (
tf.train.GradientDescentOptimizer(learning_rate)
tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
.minimize(...my loss..., global_step=global_step)
)
```
Args:
learning_rate: A scalar `float32` or `float64` `Tensor` or a
Python number. The initial learning rate.
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
Global step to use for the decay computation. Must not be negative.
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
Must be positive. See the decay computation above.
decay_rate: A scalar `float32` or `float64` `Tensor` or a
Python number. The decay rate.
learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
The initial learning rate.
global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
step to use for the decay computation. Must not be negative.
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Must
be positive. See the decay computation above.
decay_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
The decay rate.
staircase: Boolean. If `True` decay the learning rate at discrete intervals
name: String. Optional name of the operation. Defaults to
'ExponentialDecay'.
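The "decay computation above" that these arguments refer to sits outside this hunk; as a plain-Python restatement of the documented formula (not the library code):

```python
import math

def exponential_decay_value(learning_rate, global_step, decay_steps,
                            decay_rate, staircase=False):
  """decayed_lr = learning_rate * decay_rate ** (global_step / decay_steps)."""
  p = global_step / decay_steps
  if staircase:
    p = math.floor(p)
  return learning_rate * decay_rate ** p

# Matches the docstring example: 0.1 decayed by 0.96 every 100000 steps.
print(exponential_decay_value(0.1, 250000, 100000, 0.96, staircase=True))  # 0.1 * 0.96**2
```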
@ -91,11 +92,8 @@ def exponential_decay(learning_rate,
the learning rate value across different invocations of optimizer functions.
@end_compatibility
"""
decayed_lr = learning_rate_schedule.ExponentialDecay(learning_rate,
decay_steps,
decay_rate,
staircase=staircase,
name=name)
decayed_lr = learning_rate_schedule.ExponentialDecay(
learning_rate, decay_steps, decay_rate, staircase=staircase, name=name)
if not context.executing_eagerly():
decayed_lr = decayed_lr(global_step)
else:
@ -114,7 +112,8 @@ def piecewise_constant(x, boundaries, values, name=None):
global_step = tf.Variable(0, trainable=False)
boundaries = [100000, 110000]
values = [1.0, 0.5, 0.1]
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
learning_rate = tf.compat.v1.train.piecewise_constant(global_step, boundaries,
values)
# Later, whenever we perform an optimization step, we increment global_step.
```
@ -202,27 +201,28 @@ def polynomial_decay(learning_rate,
starter_learning_rate = 0.1
end_learning_rate = 0.01
decay_steps = 10000
learning_rate = tf.train.polynomial_decay(starter_learning_rate, global_step,
learning_rate = tf.compat.v1.train.polynomial_decay(starter_learning_rate,
global_step,
decay_steps, end_learning_rate,
power=0.5)
# Passing global_step to minimize() will increment it at each step.
learning_step = (
tf.train.GradientDescentOptimizer(learning_rate)
tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
.minimize(...my loss..., global_step=global_step)
)
```
Args:
learning_rate: A scalar `float32` or `float64` `Tensor` or a
Python number. The initial learning rate.
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
Global step to use for the decay computation. Must not be negative.
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
Must be positive. See the decay computation above.
end_learning_rate: A scalar `float32` or `float64` `Tensor` or a
Python number. The minimal end learning rate.
power: A scalar `float32` or `float64` `Tensor` or a
Python number. The power of the polynomial. Defaults to linear, 1.0.
learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
The initial learning rate.
global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
step to use for the decay computation. Must not be negative.
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Must
be positive. See the decay computation above.
end_learning_rate: A scalar `float32` or `float64` `Tensor` or a Python
number. The minimal end learning rate.
power: A scalar `float32` or `float64` `Tensor` or a Python number. The
power of the polynomial. Defaults to linear, 1.0.
cycle: A boolean, whether or not it should cycle beyond decay_steps.
name: String. Optional name of the operation. Defaults to
'PolynomialDecay'.
@ -292,21 +292,22 @@ def natural_exp_decay(learning_rate,
learning_rate = 0.1
decay_steps = 5
k = 0.5
learning_rate = tf.train.natural_exp_decay(learning_rate, global_step,
learning_rate = tf.compat.v1.train.natural_exp_decay(learning_rate,
global_step,
decay_steps, k)
# Passing global_step to minimize() will increment it at each step.
learning_step = (
tf.train.GradientDescentOptimizer(learning_rate)
tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
.minimize(...my loss..., global_step=global_step)
)
```
Args:
learning_rate: A scalar `float32` or `float64` `Tensor` or a
Python number. The initial learning rate.
global_step: A Python number.
Global step to use for the decay computation. Must not be negative.
learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
The initial learning rate.
global_step: A Python number. Global step to use for the decay computation.
Must not be negative.
decay_steps: How often to apply decay.
decay_rate: A Python number. The decay rate.
staircase: Whether to apply decay in a discrete staircase, as opposed to
@ -329,7 +330,10 @@ def natural_exp_decay(learning_rate,
"""
natural_exp_rate = math_ops.exp(math_ops.negative(decay_rate))
decayed_lr = learning_rate_schedule.ExponentialDecay(
learning_rate, decay_steps, natural_exp_rate, staircase=staircase,
learning_rate,
decay_steps,
natural_exp_rate,
staircase=staircase,
name=name)
if not context.executing_eagerly():
@ -376,21 +380,22 @@ def inverse_time_decay(learning_rate,
learning_rate = 0.1
decay_steps = 1.0
decay_rate = 0.5
learning_rate = tf.train.inverse_time_decay(learning_rate, global_step,
learning_rate = tf.compat.v1.train.inverse_time_decay(learning_rate,
global_step,
decay_steps, decay_rate)
# Passing global_step to minimize() will increment it at each step.
learning_step = (
tf.train.GradientDescentOptimizer(learning_rate)
tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
.minimize(...my loss..., global_step=global_step)
)
```
Args:
learning_rate: A scalar `float32` or `float64` `Tensor` or a
Python number. The initial learning rate.
global_step: A Python number.
Global step to use for the decay computation. Must not be negative.
learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
The initial learning rate.
global_step: A Python number. Global step to use for the decay computation.
Must not be negative.
decay_steps: How often to apply decay.
decay_rate: A Python number. The decay rate.
staircase: Whether to apply decay in a discrete staircase, as opposed to
@ -412,11 +417,7 @@ def inverse_time_decay(learning_rate,
@end_compatibility
"""
decayed_lr = learning_rate_schedule.InverseTimeDecay(
learning_rate,
decay_steps,
decay_rate,
staircase=staircase,
name=name)
learning_rate, decay_steps, decay_rate, staircase=staircase, name=name)
if not context.executing_eagerly():
decayed_lr = decayed_lr(global_step)
@ -455,13 +456,14 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
Args:
learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
The initial learning rate.
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
Global step to use for the decay computation.
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
Number of steps to decay over.
alpha: A scalar `float32` or `float64` Tensor or a Python number.
Minimum learning rate value as a fraction of learning_rate.
global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
step to use for the decay computation.
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number
of steps to decay over.
alpha: A scalar `float32` or `float64` Tensor or a Python number. Minimum
learning rate value as a fraction of learning_rate.
name: String. Optional name of the operation. Defaults to 'CosineDecay'.
Returns:
A scalar `Tensor` of the same type as `learning_rate`. The decayed
learning rate.
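The cosine decay computation referenced by these arguments is likewise outside the hunk; a plain-Python restatement of the documented formula:

```python
import math

def cosine_decay_value(learning_rate, global_step, decay_steps, alpha=0.0):
  """learning_rate scaled by a cosine ramp from 1 down to alpha."""
  step = min(global_step, decay_steps)
  cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
  return learning_rate * ((1.0 - alpha) * cosine + alpha)

print(cosine_decay_value(0.1, 500, 1000))  # halfway through the decay: 0.05
```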
@ -519,17 +521,18 @@ def cosine_decay_restarts(learning_rate,
Args:
learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
The initial learning rate.
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
Global step to use for the decay computation.
global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
step to use for the decay computation.
first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
Number of steps to decay over.
t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
Used to derive the number of iterations in the i-th period
t_mul: A scalar `float32` or `float64` `Tensor` or a Python number. Used to
derive the number of iterations in the i-th period
m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
Used to derive the initial learning rate of the i-th period:
alpha: A scalar `float32` or `float64` Tensor or a Python number.
Minimum learning rate value as a fraction of the learning_rate.
alpha: A scalar `float32` or `float64` Tensor or a Python number. Minimum
learning rate value as a fraction of the learning_rate.
name: String. Optional name of the operation. Defaults to 'SGDRDecay'.
Returns:
A scalar `Tensor` of the same type as `learning_rate`. The decayed
learning rate.
@ -602,16 +605,17 @@ def linear_cosine_decay(learning_rate,
Args:
learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
The initial learning rate.
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
Global step to use for the decay computation.
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
Number of steps to decay over.
num_periods: Number of periods in the cosine part of the decay.
See computation above.
global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
step to use for the decay computation.
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number
of steps to decay over.
num_periods: Number of periods in the cosine part of the decay. See
computation above.
alpha: See computation above.
beta: See computation above.
name: String. Optional name of the operation. Defaults to
'LinearCosineDecay'.
Returns:
A scalar `Tensor` of the same type as `learning_rate`. The decayed
learning rate.
@ -690,18 +694,19 @@ def noisy_linear_cosine_decay(learning_rate,
Args:
learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
The initial learning rate.
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
Global step to use for the decay computation.
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
Number of steps to decay over.
global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
step to use for the decay computation.
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number
of steps to decay over.
initial_variance: initial variance for the noise. See computation above.
variance_decay: decay for the noise's variance. See computation above.
num_periods: Number of periods in the cosine part of the decay.
See computation above.
num_periods: Number of periods in the cosine part of the decay. See
computation above.
alpha: See computation above.
beta: See computation above.
name: String. Optional name of the operation. Defaults to
'NoisyLinearCosineDecay'.
Returns:
A scalar `Tensor` of the same type as `learning_rate`. The decayed
learning rate.

View File

@ -77,7 +77,8 @@ class Scaffold(object):
The following pieces are directly accessible as attributes of the `Scaffold`
object:
* `saver`: A `tf.train.Saver` object taking care of saving the variables.
* `saver`: A `tf.compat.v1.train.Saver` object taking care of saving the
variables.
Picked from and stored into the `SAVERS` collection in the graph by default.
* `init_op`: An op to run to initialize the variables. Picked from and
stored into the `INIT_OP` collection in the graph by default.
@ -133,9 +134,9 @@ class Scaffold(object):
local_init_op: Optional op to initialize local variables.
summary_op: Optional op to gather all summaries. Must return a scalar
string tensor containing a serialized `Summary` proto.
saver: Optional `tf.train.Saver` object to use to save and restore
variables. May also be a `tf.train.Checkpoint` object, in which case
object-based checkpoints are saved. This will also load some
saver: Optional `tf.compat.v1.train.Saver` object to use to save and
restore variables. May also be a `tf.train.Checkpoint` object, in which
case object-based checkpoints are saved. This will also load some
object-based checkpoints saved from elsewhere, but that loading may be
fragile since it uses fixed keys rather than performing a full
graph-based match. For example if a variable has two paths from the
@ -199,8 +200,9 @@ class Scaffold(object):
resources.report_uninitialized_resources()
], 0)
self._ready_op = Scaffold.get_or_default(
'ready_op', ops.GraphKeys.READY_OP, default_ready_op)
self._ready_op = Scaffold.get_or_default('ready_op',
ops.GraphKeys.READY_OP,
default_ready_op)
if self._ready_for_local_init_op is None:
def default_ready_for_local_init_op():
@ -219,8 +221,9 @@ class Scaffold(object):
'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
Scaffold.default_local_init_op)
if self._summary_op is None:
self._summary_op = Scaffold.get_or_default(
'summary_op', ops.GraphKeys.SUMMARY_OP, summary.merge_all)
self._summary_op = Scaffold.get_or_default('summary_op',
ops.GraphKeys.SUMMARY_OP,
summary.merge_all)
# pylint: disable=g-long-lambda
if self._saver is None:
self._saver = training_saver._get_saver_or_default() # pylint: disable=protected-access
@ -292,7 +295,8 @@ class Scaffold(object):
This op is used during session initialization when a Scaffold is
initialized without specifying the local_init_op arg. It includes
`tf.local_variables_initializer`, `tf.tables_initializer`, and also
`tf.compat.v1.local_variables_initializer`,
`tf.compat.v1.tables_initializer`, and also
initializes local session resources.
Returns:
@ -435,7 +439,8 @@ def MonitoredTrainingSession(
For a chief, this utility sets proper session initializer/restorer. It also
creates hooks related to checkpoint and summary saving. For workers, this
utility sets proper session creator which waits for the chief to
initialize/restore. Please check `tf.train.MonitoredSession` for more
initialize/restore. Please check `tf.compat.v1.train.MonitoredSession` for
more
information.
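A minimal single-process ("chief") sketch of this utility; the checkpoint directory and step limit are illustrative:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

global_step = tf.train.get_or_create_global_step()
loss = tf.Variable(10.0)
train_op = tf.group(tf.assign_sub(loss, 0.1), tf.assign_add(global_step, 1))

with tf.train.MonitoredTrainingSession(
    checkpoint_dir="/tmp/mts_demo",  # hypothetical
    hooks=[tf.train.StopAtStepHook(last_step=100)]) as sess:
  while not sess.should_stop():
    sess.run(train_op)
```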
@ -464,8 +469,9 @@ def MonitoredTrainingSession(
to disk using a default summary saver. If both `save_summaries_steps` and
`save_summaries_secs` are set to `None`, then the default summary saver
isn't used. Default not enabled.
config: an instance of `tf.ConfigProto` proto used to configure the session.
It's the `config` argument of constructor of `tf.Session`.
config: an instance of `tf.compat.v1.ConfigProto` proto used to configure
the session. It's the `config` argument of constructor of
`tf.compat.v1.Session`.
stop_grace_period_secs: Number of seconds given to threads to stop after
`close()` has been called.
log_step_count_steps: The frequency, in number of global steps, that the
@ -591,7 +597,7 @@ class SessionCreator(object):
@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
"""Creates a tf.Session for a chief."""
"""Creates a tf.compat.v1.Session for a chief."""
def __init__(self,
scaffold=None,
@ -643,7 +649,7 @@ class ChiefSessionCreator(SessionCreator):
@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
"""Creates a tf.Session for a worker."""
"""Creates a tf.compat.v1.Session for a worker."""
def __init__(self,
scaffold=None,
@ -757,8 +763,9 @@ class _MonitoredSession(object):
`step_fn` will be returned from `run_step_fn`, unless a stop is
requested. In that case, the next `should_stop` call will return True.
Example usage: ```python
with tf.Graph().as_default(): c = tf.placeholder(dtypes.float32) v =
tf.add(c, 4.0) w = tf.add(c, 0.5)
with tf.Graph().as_default(): c =
tf.compat.v1.placeholder(dtypes.float32) v = tf.add(c, 4.0) w =
tf.add(c, 0.5)
def step_fn(step_context):
a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
if a <= 4.5: step_context.request_stop()
@ -808,7 +815,7 @@ class _MonitoredSession(object):
"""Initializes the `step_context` argument for a `step_fn` invocation.
Args:
session: An instance of `tf.Session`.
session: An instance of `tf.compat.v1.Session`.
run_with_hooks_fn: A function for running fetches and hooks.
"""
self._session = session
@ -901,13 +908,13 @@ class _MonitoredSession(object):
return self._coordinated_creator.tf_sess is None
def _tf_sess(self):
"""Return underlying tf.Session object.
"""Return underlying tf.compat.v1.Session object.
Warning: accessing the returned object in user code is likely to cause races
or "flaky tests".
Returns:
A tf.Session object.
A tf.compat.v1.Session object.
"""
return self._coordinated_creator.tf_sess
@ -955,7 +962,7 @@ class MonitoredSession(_MonitoredSession):
* suppresses `OutOfRange` error which indicates that all inputs have been
processed if the monitored_session is used as a context
How to set `tf.Session` arguments:
How to set `tf.compat.v1.Session` arguments:
* In most cases you can set session arguments as follows:
@ -973,7 +980,8 @@ class MonitoredSession(_MonitoredSession):
See `MonitoredTrainingSession` for an example usage based on chief or worker.
Note: This is not a `tf.Session`. For example, it cannot do the following:
Note: This is not a `tf.compat.v1.Session`. For example, it cannot do the
following:
* it cannot be set as default session.
* it cannot be sent to saver.save.
@ -1004,14 +1012,15 @@ class SingularMonitoredSession(_MonitoredSession):
"""Session-like object that handles initialization, restoring, and hooks.
Please note that this utility is not recommended for distributed settings.
For distributed settings, please use `tf.train.MonitoredSession`. The
For distributed settings, please use `tf.compat.v1.train.MonitoredSession`.
The
differences between `MonitoredSession` and `SingularMonitoredSession` are:
* `MonitoredSession` handles `AbortedError` and `UnavailableError` for
distributed settings, but `SingularMonitoredSession` does not.
* `MonitoredSession` can be created in `chief` or `worker` modes.
`SingularMonitoredSession` is always created as `chief`.
* You can access the raw `tf.Session` object used by
* You can access the raw `tf.compat.v1.Session` object used by
`SingularMonitoredSession`, whereas in MonitoredSession the raw session is
private. This can be used:
- To `run` without hooks.
@ -1093,7 +1102,7 @@ class SingularMonitoredSession(_MonitoredSession):
class _WrappedSession(object):
"""Wrapper around a `tf.Session`.
"""Wrapper around a `tf.compat.v1.Session`.
This wrapper is used as a base class for various session wrappers
that provide additional functionality such as monitoring, coordination,
@ -1108,7 +1117,8 @@ class _WrappedSession(object):
"""Creates a `_WrappedSession`.
Args:
sess: A `tf.Session` or `_WrappedSession` object. The wrapped session.
sess: A `tf.compat.v1.Session` or `_WrappedSession` object. The wrapped
session.
"""
self._sess = sess
self._wrapped_is_stoppable = isinstance(self._sess, _WrappedSession)
@ -1293,7 +1303,7 @@ class _CoordinatedSession(_WrappedSession):
"""Create a new `_CoordinatedSession`.
Args:
sess: A `tf.Session` object. The wrapped session.
sess: A `tf.compat.v1.Session` object. The wrapped session.
coord: A `tf.train.Coordinator` object.
stop_grace_period_secs: Number of seconds given to threads to stop after
`close()` has been called.
@ -1364,7 +1374,7 @@ class _HookedSession(_WrappedSession):
"""Initializes a _HookedSession object.
Args:
sess: A `tf.Session` or a `_WrappedSession` object.
sess: A `tf.compat.v1.Session` or a `_WrappedSession` object.
hooks: An iterable of `SessionRunHook` objects.
"""

View File

@ -56,9 +56,9 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
E.g.:
```
with tf.variable_scope('scope1'):
with tf.variable_scope('scope2'):
var = tf.get_variable('foo')
with tf.compat.v1.variable_scope('scope1'):
with tf.compat.v1.variable_scope('scope2'):
var = tf.compat.v1.get_variable('foo')
update_1 = tf.assign_moving_average(var, 0.0, 1.0)
update_2 = tf.assign_moving_average(var, 0.0, 0.9)
@ -73,12 +73,13 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
decay: A float Tensor or float value. The moving average decay.
zero_debias: A python bool. If true, assume the variable is 0-initialized
and unbias it, as in https://arxiv.org/abs/1412.6980. See docstring in
`_zero_debias` for more details.
`_zero_debias` for more details.
name: Optional name of the returned operation.
Returns:
A tensor which if evaluated will compute and return the new moving average.
"""
def update_fn(v, value, decay=decay):
decay = ops.convert_to_tensor(1.0 - decay, name="decay")
if decay.dtype != v.dtype.base_dtype:
@ -96,8 +97,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
# In a replica context, we update variable using the mean of value across
# replicas.
def merge_fn(strategy, v, value):
value = strategy.extended.reduce_to(
ds_reduce_util.ReduceOp.MEAN, value, v)
value = strategy.extended.reduce_to(ds_reduce_util.ReduceOp.MEAN, value,
v)
return strategy.extended.update(v, update_fn, args=(value,))
return replica_context.merge_call(merge_fn, args=(variable, value))
@ -124,15 +125,15 @@ def weighted_moving_average(value,
Args:
value: A numeric `Tensor`.
decay: A float `Tensor` or float value. The moving average decay.
weight: `Tensor` that keeps the current value of a weight.
Shape should be able to multiply `value`.
weight: `Tensor` that keeps the current value of a weight. Shape should be
able to multiply `value`.
truediv: Boolean, if `True`, dividing by `moving_average(weight)` is
floating point division. If `False`, use division implied by dtypes.
collections: List of graph collections keys to add the internal variables
`value * weight` and `weight` to.
Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
name: Optional name of the returned operation.
Defaults to "WeightedMovingAvg".
`value * weight` and `weight` to. Defaults to
`[GraphKeys.GLOBAL_VARIABLES]`.
name: Optional name of the returned operation. Defaults to
"WeightedMovingAvg".
Returns:
An Operation that updates and returns the weighted moving average.
@ -203,27 +204,34 @@ def _zero_debias(unbiased_var, value, decay):
tensor will also update the shadow variables appropriately.
"""
with variable_scope.variable_scope(
unbiased_var.name[:-len(":0")], values=[unbiased_var,
value, decay]) as scope:
unbiased_var.name[:-len(":0")], values=[unbiased_var, value,
decay]) as scope:
with ops.colocate_with(unbiased_var):
with ops.init_scope():
biased_initializer = init_ops.zeros_initializer(
dtype=unbiased_var.dtype)(unbiased_var.get_shape())
dtype=unbiased_var.dtype)(
unbiased_var.get_shape())
local_step_initializer = init_ops.zeros_initializer()
def _maybe_get_unique(name):
"""Get name for a unique variable, if not `reuse=True`."""
if variable_scope.get_variable_scope().reuse:
return name
vs_vars = [x.op.name for x in
variable_scope.get_variable_scope().global_variables()]
vs_vars = [
x.op.name
for x in variable_scope.get_variable_scope().global_variables()
]
full_name = variable_scope.get_variable_scope().name + "/" + name
if full_name not in vs_vars: return name
if full_name not in vs_vars:
return name
idx = 1
while full_name + ("_%d" % idx) in vs_vars:
idx += 1
return name + ("_%d" % idx)
biased_var = variable_scope.get_variable(
_maybe_get_unique("biased"), initializer=biased_initializer,
_maybe_get_unique("biased"),
initializer=biased_initializer,
trainable=False)
local_step = variable_scope.get_variable(
_maybe_get_unique("local_step"),
@ -233,18 +241,17 @@ def _zero_debias(unbiased_var, value, decay):
trainable=False)
# Get an update ops for both shadow variables.
update_biased = state_ops.assign_sub(biased_var,
(biased_var - value) * decay,
name=scope.name)
update_biased = state_ops.assign_sub(
biased_var, (biased_var - value) * decay, name=scope.name)
update_local_step = local_step.assign_add(1)
# Compute the value of the delta to update the unbiased EMA. Make sure to
# use the new values of the biased variable and the local step.
with ops.control_dependencies([update_biased, update_local_step]):
# This function gets `1 - decay`, so use `1.0 - decay` in the exponent.
unbiased_ema_delta = (unbiased_var - biased_var.read_value() /
(1 - math_ops.pow(
1.0 - decay, local_step.read_value())))
unbiased_ema_delta = (
unbiased_var - biased_var.read_value() /
(1 - math_ops.pow(1.0 - decay, local_step.read_value())))
return unbiased_ema_delta
@ -315,7 +322,7 @@ class ExponentialMovingAverage(object):
for a given variable.
* Build a model normally but load the checkpoint files to evaluate by using
the shadow variable names. For this use the `average_name()` method. See
the `tf.train.Saver` for more
the `tf.compat.v1.train.Saver` for more
information on restoring saved variables.
Example of restoring the shadow variable values:
@ -324,13 +331,17 @@ class ExponentialMovingAverage(object):
# Create a Saver that loads variables from their saved shadow values.
shadow_var0_name = ema.average_name(var0)
shadow_var1_name = ema.average_name(var1)
saver = tf.train.Saver({shadow_var0_name: var0, shadow_var1_name: var1})
saver = tf.compat.v1.train.Saver({shadow_var0_name: var0, shadow_var1_name:
var1})
saver.restore(...checkpoint filename...)
# var0 and var1 now hold the moving average values
```
"""
def __init__(self, decay, num_updates=None, zero_debias=False,
def __init__(self,
decay,
num_updates=None,
zero_debias=False,
name="ExponentialMovingAverage"):
"""Creates a new ExponentialMovingAverage object.
@ -376,7 +387,7 @@ class ExponentialMovingAverage(object):
shadow variables are created with `trainable=False` and added to the
`GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to
`tf.global_variables()`.
`tf.compat.v1.global_variables()`.
Returns an op that updates all shadow variables from the current value of
their associated variables.
@ -386,8 +397,8 @@ class ExponentialMovingAverage(object):
be called in a loop.
Args:
var_list: A list of Variable or Tensor objects. The variables
and Tensors must be of types bfloat16, float16, float32, or float64.
var_list: A list of Variable or Tensor objects. The variables and Tensors
must be of types bfloat16, float16, float32, or float64.
Returns:
An Operation that updates the moving averages.
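A minimal sketch of `apply()` paired with a training op, following the control-dependency pattern this class documents; the variable and decay are illustrative:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

var = tf.Variable(0.0)
update_var = tf.assign_add(var, 1.0)  # stand-in for an optimizer step

ema = tf.train.ExponentialMovingAverage(decay=0.9)
with tf.control_dependencies([update_var]):
  train_op = ema.apply([var])  # also refreshes the shadow variable

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(3):
    sess.run(train_op)
  print(sess.run([var, ema.average(var)]))  # raw value vs. its moving average
```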
@ -417,10 +428,11 @@ class ExponentialMovingAverage(object):
# tensors, we rely on the existing device allocation mechanism.
with ops.init_scope():
if isinstance(var, variables.Variable):
avg = slot_creator.create_slot(var,
var.initialized_value(),
self.name,
colocate_with_primary=True)
avg = slot_creator.create_slot(
var,
var.initialized_value(),
self.name,
colocate_with_primary=True)
# NOTE(mrry): We only add `tf.Variable` objects to the
# `MOVING_AVERAGE_VARIABLES` collection.
ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
@ -428,9 +440,9 @@ class ExponentialMovingAverage(object):
avg = slot_creator.create_zeros_slot(
var,
self.name,
colocate_with_primary=(var.op.type in ["Variable",
"VariableV2",
"VarHandleOp"]))
colocate_with_primary=(var.op.type in [
"Variable", "VariableV2", "VarHandleOp"
]))
if self._zero_debias:
zero_debias_true.add(avg)
self._averages[var] = avg
@ -438,16 +450,16 @@ class ExponentialMovingAverage(object):
with ops.name_scope(self.name) as scope:
decay = ops.convert_to_tensor(self._decay, name="decay")
if self._num_updates is not None:
num_updates = math_ops.cast(self._num_updates,
dtypes.float32,
name="num_updates")
num_updates = math_ops.cast(
self._num_updates, dtypes.float32, name="num_updates")
decay = math_ops.minimum(decay,
(1.0 + num_updates) / (10.0 + num_updates))
updates = []
for var in var_list:
zero_debias = self._averages[var] in zero_debias_true
updates.append(assign_moving_average(
self._averages[var], var, decay, zero_debias=zero_debias))
updates.append(
assign_moving_average(
self._averages[var], var, decay, zero_debias=zero_debias))
return control_flow_ops.group(*updates, name=scope)
def average(self, var):
@ -472,7 +484,7 @@ class ExponentialMovingAverage(object):
To restore variables, you have to know the name of the shadow variables.
That name and the original variable can then be passed to a `Saver()` object
to restore the variable from the moving average value with:
`saver = tf.train.Saver({ema.average_name(var): var})`
`saver = tf.compat.v1.train.Saver({ema.average_name(var): var})`
`average_name()` can be called whether or not `apply()` has been called.
@ -499,7 +511,7 @@ class ExponentialMovingAverage(object):
```python
variables_to_restore = ema.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
saver = tf.compat.v1.train.Saver(variables_to_restore)
```
Below is an example of such mapping:

View File

@ -434,14 +434,14 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
Raises:
ValueError: if `sess` is None and there isn't any default session.
TypeError: if `sess` is not a `tf.Session` object.
TypeError: if `sess` is not a `tf.compat.v1.Session` object.
Returns:
A list of threads.
Raises:
RuntimeError: If called with eager execution enabled.
ValueError: If called without a default `tf.Session` registered.
ValueError: If called without a default `tf.compat.v1.Session` registered.
@compatibility(eager)
Not compatible with eager execution. To ingest data under eager execution,

View File

@ -56,7 +56,6 @@ from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
# TODO(allenl): Remove these aliases once all users are migrated off.
get_checkpoint_state = checkpoint_management.get_checkpoint_state
update_checkpoint_state = checkpoint_management.update_checkpoint_state
@ -174,13 +173,11 @@ class BaseSaverBuilder(object):
tensors = []
for spec in saveable.specs:
tensors.append(
io_ops.restore_v2(
filename_tensor,
[spec.name],
[spec.slice_spec],
[spec.dtype])[0])
io_ops.restore_v2(filename_tensor, [spec.name], [spec.slice_spec],
[spec.dtype])[0])
return tensors
# pylint: enable=unused-argument
def sharded_filename(self, filename_tensor, shard, num_shards):
@ -217,8 +214,8 @@ class BaseSaverBuilder(object):
from each device.
Args:
checkpoint_prefix: scalar String Tensor. Interpreted *NOT AS A
FILENAME*, but as a prefix of a V2 checkpoint;
checkpoint_prefix: scalar String Tensor. Interpreted *NOT AS A FILENAME*,
but as a prefix of a V2 checkpoint;
per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
returned by _GroupByDevices().
@ -319,8 +316,8 @@ class BaseSaverBuilder(object):
saveables: A list of SaveableObject objects.
restore_sequentially: True if we want to restore variables sequentially
within a shard.
reshape: True if we want to reshape loaded tensors to the shape of
the corresponding variable.
reshape: True if we want to reshape loaded tensors to the shape of the
corresponding variable.
preferred_shard: Shard to open first when loading a sharded file.
name: Name for the returned op.
@ -361,12 +358,12 @@ class BaseSaverBuilder(object):
Args:
filename_tensor: Tensor for the path of the file to load.
per_device: A list of (device, SaveableObject) pairs, as
returned by _GroupByDevices().
per_device: A list of (device, SaveableObject) pairs, as returned by
_GroupByDevices().
restore_sequentially: True if we want to restore variables sequentially
within a shard.
reshape: True if we want to reshape loaded tensors to the shape of
the corresponding variable.
reshape: True if we want to reshape loaded tensors to the shape of the
corresponding variable.
Returns:
An Operation that restores the variables.
@ -424,14 +421,13 @@ class BaseSaverBuilder(object):
Args:
names_to_saveables: A dictionary mapping name to a Variable or
SaveableObject. Each name will be associated with the
corresponding variable in the checkpoint.
reshape: If True, allow restoring parameters from a checkpoint
where the parameters have a different shape. This is
only needed when you try to restore from a Dist-Belief checkpoint,
and only sometimes.
sharded: If True, shard the checkpoints, one per device that has
Variable nodes.
SaveableObject. Each name will be associated with the corresponding
variable in the checkpoint.
reshape: If True, allow restoring parameters from a checkpoint where
the parameters have a different shape. This is only needed when you try
to restore from a Dist-Belief checkpoint, and only sometimes.
sharded: If True, shard the checkpoints, one per device that has Variable
nodes.
max_to_keep: Maximum number of checkpoints to keep. As new checkpoints
are created, old ones are deleted. If None or 0, no checkpoints are
deleted from the filesystem but only the last one is kept in the
@ -597,8 +593,8 @@ def _get_saver_or_default():
if len(savers) > 1:
raise RuntimeError(
"More than one item in collection {}. "
"Please indicate which one to use by passing it to the constructor.".
format(collection_key))
"Please indicate which one to use by passing it to the constructor."
.format(collection_key))
return savers[0]
saver = Saver(sharded=True, allow_empty=True)
if saver is not None:
@ -662,9 +658,9 @@ class Saver(object):
```python
...
# Create a saver.
saver = tf.train.Saver(...variables...)
saver = tf.compat.v1.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.Session()
sess = tf.compat.v1.Session()
for step in xrange(1000000):
sess.run(..training_op..)
if step % 1000 == 0:
@ -717,13 +713,13 @@ class Saver(object):
v2 = tf.Variable(..., name='v2')
# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})
saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2})
# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
saver = tf.compat.v1.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]})
```
The optional `reshape` argument, if `True`, allows restoring a variable from
@ -738,35 +734,33 @@ class Saver(object):
var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
names to `SaveableObject`s. If `None`, defaults to the list of all
saveable objects.
reshape: If `True`, allows restoring parameters from a checkpoint
where the variables have a different shape.
reshape: If `True`, allows restoring parameters from a checkpoint where
the variables have a different shape.
sharded: If `True`, shard the checkpoints, one per device.
max_to_keep: Maximum number of recent checkpoints to keep.
Defaults to 5.
keep_checkpoint_every_n_hours: How often to keep checkpoints.
Defaults to 10,000 hours.
max_to_keep: Maximum number of recent checkpoints to keep. Defaults to 5.
keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults to
10,000 hours.
name: String. Optional name to use as a prefix when adding operations.
restore_sequentially: A `Bool`, which if true, causes restore of different
variables to happen sequentially within each device. This can lower
memory usage when restoring very large models.
saver_def: Optional `SaverDef` proto to use instead of running the
builder. This is only useful for specialty code that wants to recreate
a `Saver` object for a previously built `Graph` that had a `Saver`.
The `saver_def` proto should be the one returned by the
`as_saver_def()` call of the `Saver` that was created for that `Graph`.
builder. This is only useful for specialty code that wants to recreate a
`Saver` object for a previously built `Graph` that had a `Saver`. The
`saver_def` proto should be the one returned by the `as_saver_def()`
call of the `Saver` that was created for that `Graph`.
builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
Defaults to `BulkSaverBuilder()`.
defer_build: If `True`, defer adding the save and restore ops to the
`build()` call. In that case `build()` should be called before
finalizing the graph or using the saver.
allow_empty: If `False` (default) raise an error if there are no
variables in the graph. Otherwise, construct the saver anyway and make
it a no-op.
allow_empty: If `False` (default) raise an error if there are no variables
in the graph. Otherwise, construct the saver anyway and make it a no-op.
write_version: controls what format to use when saving checkpoints. It
also affects certain filepath matching logic. The V2 format is the
recommended choice: it is much more optimized than V1 in terms of
memory required and latency incurred during restore. Regardless of
this flag, the Saver is able to restore from both V2 and V1 checkpoints.
recommended choice: it is much more optimized than V1 in terms of memory
required and latency incurred during restore. Regardless of this
flag, the Saver is able to restore from both V2 and V1 checkpoints.
pad_step_number: if True, pads the global step number in the checkpoint
filepaths to some fixed width (8 by default). This is turned off by
default.
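A minimal construction sketch tying several of these arguments together (a sketch assuming the graph-mode `tf.compat.v1` API documented here; the variable and values are placeholders):
```python
import tensorflow as tf

with tf.Graph().as_default():
  v = tf.compat.v1.get_variable(
      "v", shape=[2], initializer=tf.compat.v1.zeros_initializer())
  # Keep the five most recent checkpoints, plus one checkpoint every two
  # hours, and restore variables one at a time to limit peak memory.
  saver = tf.compat.v1.train.Saver(
      var_list=[v],
      max_to_keep=5,
      keep_checkpoint_every_n_hours=2,
      restore_sequentially=True,
      allow_empty=False)
```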
@ -877,7 +871,8 @@ class Saver(object):
name=self._name,
restore_sequentially=self._restore_sequentially,
filename=checkpoint_path,
build_save=build_save, build_restore=build_restore)
build_save=build_save,
build_restore=build_restore)
elif self.saver_def and self._name:
# Since self._name is used as a name_scope by builder(), we are
# overloading the use of this field to represent the "import_scope" as
@ -997,8 +992,8 @@ class Saver(object):
saver_def.filename_tensor_name, export_scope)
saver_def.save_tensor_name = ops.strip_name_scope(
saver_def.save_tensor_name, export_scope)
saver_def.restore_op_name = ops.strip_name_scope(
saver_def.restore_op_name, export_scope)
saver_def.restore_op_name = ops.strip_name_scope(saver_def.restore_op_name,
export_scope)
return saver_def
@staticmethod
@ -1092,14 +1087,13 @@ class Saver(object):
Args:
sess: A Session to use to save the variables.
save_path: String. Prefix of filenames created for the checkpoint.
global_step: If provided the global step number is appended to
`save_path` to create the checkpoint filenames. The optional argument
can be a `Tensor`, a `Tensor` name or an integer.
global_step: If provided the global step number is appended to `save_path`
to create the checkpoint filenames. The optional argument can be a
`Tensor`, a `Tensor` name or an integer.
latest_filename: Optional name for the protocol buffer file that will
contains the list of most recent checkpoints. That file,
kept in the same directory as the checkpoint files, is automatically
managed by the saver to keep track of recent checkpoints. Defaults to
'checkpoint'.
contains the list of most recent checkpoints. That file, kept in the
same directory as the checkpoint files, is automatically managed by the
saver to keep track of recent checkpoints. Defaults to 'checkpoint'.
meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
write_meta_graph: `Boolean` indicating whether or not to write the meta
graph file.
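A hedged sketch of the `save()` call with a step-suffixed prefix (paths are placeholders; graph mode as in the surrounding examples):
```python
import tensorflow as tf

with tf.Graph().as_default():
  v = tf.compat.v1.get_variable(
      "v", shape=[2], initializer=tf.compat.v1.zeros_initializer())
  saver = tf.compat.v1.train.Saver([v])
  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    # Appending global_step produces files such as /tmp/model.ckpt-100.*
    # and updates the 'checkpoint' state file in the same directory.
    prefix = saver.save(sess, "/tmp/model.ckpt", global_step=100)
    print(prefix)  # /tmp/model.ckpt-100
```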
@ -1107,7 +1101,8 @@ class Saver(object):
`CheckpointStateProto`.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
[Stripping Default-Valued
Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
which in the same directory of save_path and with `_debug` added before
the file extension. This is only enabled when `write_meta_graph` is
@ -1151,8 +1146,7 @@ class Saver(object):
checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
else:
checkpoint_file = save_path
if os.path.basename(
save_path) == latest_filename and not self._sharded:
if os.path.basename(save_path) == latest_filename and not self._sharded:
# Guard against collision between data file and checkpoint state file.
raise ValueError(
"'latest_filename' collides with 'save_path': '%s' and '%s'" %
@ -1197,7 +1191,8 @@ class Saver(object):
if not context.executing_eagerly():
with sess.graph.as_default():
self.export_meta_graph(
meta_graph_filename, strip_default_attrs=strip_default_attrs,
meta_graph_filename,
strip_default_attrs=strip_default_attrs,
save_debug_info=save_debug_info)
if self._is_empty:
@ -1225,11 +1220,12 @@ class Saver(object):
clear_devices: Whether or not to clear the device field for an `Operation`
or `Tensor` during export.
clear_extraneous_savers: Remove any Saver-related information from the
graph (both Save/Restore ops and SaverDefs) that are not associated
with this Saver.
graph (both Save/Restore ops and SaverDefs) that are not associated with
this Saver.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
[Stripping Default-Valued
Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
which in the same directory of filename and with `_debug` added before
the file extension.
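A small sketch of exporting only the `MetaGraphDef` with the options above (the filename is a placeholder):
```python
import tensorflow as tf

with tf.Graph().as_default():
  v = tf.compat.v1.get_variable(
      "v", shape=[2], initializer=tf.compat.v1.zeros_initializer())
  saver = tf.compat.v1.train.Saver([v])
  # Write the graph structure and saver definition without any variable
  # values; clearing devices keeps the file portable across machines.
  saver.export_meta_graph(
      "/tmp/model_graph.meta",
      clear_devices=True,
      strip_default_attrs=True)
```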
@ -1274,8 +1270,8 @@ class Saver(object):
raise ValueError("Can't load save_path when it is None.")
if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)):
raise ValueError("The passed save_path is not a valid checkpoint: "
+ compat.as_text(save_path))
raise ValueError("The passed save_path is not a valid checkpoint: " +
compat.as_text(save_path))
logging.info("Restoring parameters from %s", compat.as_text(save_path))
try:
@ -1330,13 +1326,15 @@ class Saver(object):
key: One of the GraphKeys or user-defined string.
export_scope: Optional `string`. Name scope to remove.
"""
meta_graph.add_collection_def(meta_graph_def, key,
export_scope=export_scope)
meta_graph.add_collection_def(
meta_graph_def, key, export_scope=export_scope)
@tf_export(v1=["train.import_meta_graph"])
def import_meta_graph(meta_graph_or_file, clear_devices=False,
import_scope=None, **kwargs):
def import_meta_graph(meta_graph_or_file,
clear_devices=False,
import_scope=None,
**kwargs):
"""Recreates a Graph saved in a `MetaGraphDef` proto.
This function takes a `MetaGraphDef` protocol buffer as input. If
@ -1358,10 +1356,10 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
```Python
...
# Create a saver.
saver = tf.train.Saver(...variables...)
saver = tf.compat.v1.train.Saver(...variables...)
# Remember the training_op we want to run by adding it to a collection.
tf.add_to_collection('train_op', train_op)
sess = tf.Session()
tf.compat.v1.add_to_collection('train_op', train_op)
sess = tf.compat.v1.Session()
for step in xrange(1000000):
sess.run(train_op)
if step % 1000 == 0:
@ -1374,12 +1372,13 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
the model from scratch.
```Python
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
with tf.compat.v1.Session() as sess:
new_saver =
tf.compat.v1.train.import_meta_graph('my-save-dir/my-model-10000.meta')
new_saver.restore(sess, 'my-save-dir/my-model-10000')
# tf.get_collection() returns a list. In this example we only want the
# first one.
train_op = tf.get_collection('train_op')[0]
# tf.compat.v1.get_collection() returns a list. In this example we only want
# the first one.
train_op = tf.compat.v1.get_collection('train_op')[0]
for step in xrange(1000000):
sess.run(train_op)
```
@ -1393,14 +1392,14 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
```Python
# Saving contents and operations.
v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
v1 = tf.compat.v1.placeholder(tf.float32, name="v1")
v2 = tf.compat.v1.placeholder(tf.float32, name="v2")
v3 = tf.mul(v1, v2)
vx = tf.Variable(10.0, name="vx")
v4 = tf.add(v3, vx, name="v4")
saver = tf.train.Saver([vx])
sess = tf.Session()
sess.run(tf.initialize_all_variables())
saver = tf.compat.v1.train.Saver([vx])
sess = tf.compat.v1.Session()
sess.run(tf.compat.v1.initialize_all_variables())
sess.run(vx.assign(tf.add(vx, vx)))
result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
print(result)
@ -1411,8 +1410,8 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
```Python
# Restoring variables and running operations.
saver = tf.train.import_meta_graph("./model_ex1.meta")
sess = tf.Session()
saver = tf.compat.v1.train.import_meta_graph("./model_ex1.meta")
sess = tf.compat.v1.Session()
saver.restore(sess, "./model_ex1")
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
print(result)
@ -1441,13 +1440,16 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
execution is enabled.
@end_compatibility
""" # pylint: disable=g-doc-exception
return _import_meta_graph_with_return_elements(
meta_graph_or_file, clear_devices, import_scope, **kwargs)[0]
return _import_meta_graph_with_return_elements(meta_graph_or_file,
clear_devices, import_scope,
**kwargs)[0]
def _import_meta_graph_with_return_elements(
meta_graph_or_file, clear_devices=False, import_scope=None,
return_elements=None, **kwargs):
def _import_meta_graph_with_return_elements(meta_graph_or_file,
clear_devices=False,
import_scope=None,
return_elements=None,
**kwargs):
"""Import MetaGraph, and return both a saver and returned elements."""
if context.executing_eagerly():
raise RuntimeError("Exporting/importing meta graphs is not supported when "
@ -1466,13 +1468,13 @@ def _import_meta_graph_with_return_elements(
return_elements=return_elements,
**kwargs))
saver = _create_saver_from_imported_meta_graph(
meta_graph_def, import_scope, imported_vars)
saver = _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
imported_vars)
return saver, imported_return_elements
def _create_saver_from_imported_meta_graph(
meta_graph_def, import_scope, imported_vars):
def _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
imported_vars):
"""Return a saver for restoring variable values to an imported MetaGraph."""
if meta_graph_def.HasField("saver_def"):
# Infer the scope that is prepended by `import_scoped_meta_graph`.
@ -1510,7 +1512,9 @@ def export_meta_graph(filename=None,
save_debug_info=False,
**kwargs):
# pylint: disable=line-too-long
"""Returns `MetaGraphDef` proto. Optionally writes it to filename.
"""Returns `MetaGraphDef` proto.
Optionally writes it to filename.
This function exports the graph, saver, and collection objects into
`MetaGraphDef` protocol buffer with the intention of it being imported
@ -1518,29 +1522,29 @@ def export_meta_graph(filename=None,
a subgraph.
Args:
filename: Optional filename including the path for writing the
generated `MetaGraphDef` protocol buffer.
filename: Optional filename including the path for writing the generated
`MetaGraphDef` protocol buffer.
meta_info_def: `MetaInfoDef` protocol buffer.
graph_def: `GraphDef` protocol buffer.
saver_def: `SaverDef` protocol buffer.
collection_list: List of string keys to collect.
as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
graph: The `Graph` to export. If `None`, use the default graph.
export_scope: Optional `string`. Name scope under which to extract
the subgraph. The scope name will be striped from the node definitions
for easy import later into new name scopes. If `None`, the whole graph
is exported. graph_def and export_scope cannot both be specified.
export_scope: Optional `string`. Name scope under which to extract the
subgraph. The scope name will be stripped from the node definitions for
easy import later into new name scopes. If `None`, the whole graph is
exported. graph_def and export_scope cannot both be specified.
clear_devices: Whether or not to clear the device field for an `Operation`
or `Tensor` during export.
clear_extraneous_savers: Remove any Saver-related information from the
graph (both Save/Restore ops and SaverDefs) that are not associated
with the provided SaverDef.
clear_extraneous_savers: Remove any Saver-related information from the graph
(both Save/Restore ops and SaverDefs) that are not associated with the
provided SaverDef.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
which in the same directory of filename and with `_debug` added before
the file extend.
which is in the same directory as filename and with `_debug` added before the
file extension.
**kwargs: Optional keyed arguments.
Returns:
@ -1603,10 +1607,8 @@ def object_graph_key_mapping(checkpoint_path):
Dictionary mapping tensor names to checkpoint keys.
"""
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
object_graph_string = reader.get_tensor(
trackable.OBJECT_GRAPH_PROTO_KEY)
object_graph_proto = (
trackable_object_graph_pb2.TrackableObjectGraph())
object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string)
names_to_keys = {}
for node in object_graph_proto.nodes:
@ -1615,9 +1617,11 @@ def object_graph_key_mapping(checkpoint_path):
return names_to_keys
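A hedged usage sketch of the helper above, assuming an object-based checkpoint already exists at a placeholder path (the function is module-level in this file, not public API):
```python
from tensorflow.python.training import saver as saver_lib

# Inspect how attribute names in the object graph map to checkpoint keys.
names_to_keys = saver_lib.object_graph_key_mapping("/tmp/ckpt/model")
for tensor_name, checkpoint_key in sorted(names_to_keys.items()):
  print(tensor_name, "->", checkpoint_key)
```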
def saver_from_object_based_checkpoint(
checkpoint_path, var_list=None, builder=None, names_to_keys=None,
cached_saver=None):
def saver_from_object_based_checkpoint(checkpoint_path,
var_list=None,
builder=None,
names_to_keys=None,
cached_saver=None):
"""Return a `Saver` which reads from an object-based checkpoint.
This function validates that all variables in the variables list are remapped
@ -1659,8 +1663,8 @@ def saver_from_object_based_checkpoint(
try:
names_to_keys = object_graph_key_mapping(checkpoint_path)
except errors.NotFoundError:
raise ValueError("Checkpoint in %s not an object-based checkpoint."
% checkpoint_path)
raise ValueError("Checkpoint in %s not an object-based checkpoint." %
checkpoint_path)
if var_list is None:
var_list = variables._all_saveable_objects() # pylint: disable=protected-access
if builder is None:
@ -1677,7 +1681,8 @@ def saver_from_object_based_checkpoint(
extra_names = previous_names - current_names
intersecting_names = previous_names.intersection(current_names)
raise errors.NotFoundError(
None, None,
None,
None,
message=(
"\n\nExisting variables not in the checkpoint: %s\n\n"
"Variables names when this checkpoint was written which don't "
@ -1695,9 +1700,9 @@ def saver_from_object_based_checkpoint(
"existed, and if variable names have changed you may need to "
"make this a dictionary with the old names as keys. If you're "
"using an Estimator, you'll need to return a tf.train.Saver "
"inside a tf.train.Scaffold from your model_fn.")
% (", ".join(sorted(missing_names)), ", ".join(sorted(extra_names)),
len(intersecting_names)))
"inside a tf.train.Scaffold from your model_fn.") %
(", ".join(sorted(missing_names)), ", ".join(
sorted(extra_names)), len(intersecting_names)))
for saveable in saveables:
for spec in saveable.specs:
spec.name = names_to_keys[spec.name]
View File
@ -32,21 +32,19 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
"""Creates a `tf.train.ServerDef` protocol buffer.
Args:
server_or_cluster_def: A `tf.train.ServerDef` or
`tf.train.ClusterDef` protocol buffer, or a
`tf.train.ClusterSpec` object, describing the server to be
defined and/or the cluster of which it is a member.
job_name: (Optional.) Specifies the name of the job of which the server
is a member. Defaults to the value in `server_or_cluster_def`, if
specified.
server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
protocol buffer, or a `tf.train.ClusterSpec` object, describing the server
to be defined and/or the cluster of which it is a member.
job_name: (Optional.) Specifies the name of the job of which the server is a
member. Defaults to the value in `server_or_cluster_def`, if specified.
task_index: (Optional.) Specifies the task index of the server in its job.
Defaults to the value in `server_or_cluster_def`, if specified. Otherwise
defaults to 0 if the server's job has only one task.
protocol: (Optional.) Specifies the protocol to be used by the server.
Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value
in `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
config: (Options.) A `tf.ConfigProto` that specifies default configuration
options for all sessions that run on this server.
Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value in
`server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
configuration options for all sessions that run on this server.
Returns:
A `tf.train.ServerDef`.
@ -88,7 +86,9 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
server_def = tensorflow_server_pb2.ServerDef(
cluster=cluster_spec.as_cluster_def(),
job_name=job_name, task_index=task_index, protocol=protocol)
job_name=job_name,
task_index=task_index,
protocol=protocol)
if config is not None:
server_def.default_session_config.MergeFrom(config)
return server_def
@ -99,8 +99,8 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
class Server(object):
"""An in-process TensorFlow server, for use in distributed training.
A `tf.train.Server` instance encapsulates a set of devices and a
`tf.Session` target that
A `tf.distribute.Server` instance encapsulates a set of devices and a
`tf.compat.v1.Session` target that
can participate in distributed training. A server belongs to a
cluster (specified by a `tf.train.ClusterSpec`), and
corresponds to a particular task in a named job. The server can
@ -120,31 +120,30 @@ class Server(object):
override any information provided in `server_or_cluster_def`.
Args:
server_or_cluster_def: A `tf.train.ServerDef` or
`tf.train.ClusterDef` protocol buffer, or a
`tf.train.ClusterSpec` object, describing the server to be
created and/or the cluster of which it is a member.
job_name: (Optional.) Specifies the name of the job of which the server
is a member. Defaults to the value in `server_or_cluster_def`, if
server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
protocol buffer, or a `tf.train.ClusterSpec` object, describing the
server to be created and/or the cluster of which it is a member.
job_name: (Optional.) Specifies the name of the job of which the server is
a member. Defaults to the value in `server_or_cluster_def`, if
specified.
task_index: (Optional.) Specifies the task index of the server in its
job. Defaults to the value in `server_or_cluster_def`, if specified.
task_index: (Optional.) Specifies the task index of the server in its job.
Defaults to the value in `server_or_cluster_def`, if specified.
Otherwise defaults to 0 if the server's job has only one task.
protocol: (Optional.) Specifies the protocol to be used by the server.
Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the
value in `server_or_cluster_def`, if specified. Otherwise defaults to
Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value
in `server_or_cluster_def`, if specified. Otherwise defaults to
`"grpc"`.
config: (Options.) A `tf.ConfigProto` that specifies default
config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
configuration options for all sessions that run on this server.
start: (Optional.) Boolean, indicating whether to start the server
after creating it. Defaults to `True`.
start: (Optional.) Boolean, indicating whether to start the server after
creating it. Defaults to `True`.
Raises:
tf.errors.OpError: Or one of its subclasses if an error occurs while
creating the TensorFlow server.
"""
self._server_def = _make_server_def(server_or_cluster_def,
job_name, task_index, protocol, config)
self._server_def = _make_server_def(server_or_cluster_def, job_name,
task_index, protocol, config)
self._server = c_api.TF_NewServer(self._server_def.SerializeToString())
if start:
self.start()
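A hedged end-to-end sketch of the constructor arguments above (cluster addresses are placeholders):
```python
import tensorflow as tf

cluster = tf.train.ClusterSpec({
    "ps": ["ps0.example.com:2222"],
    "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
})
# Start the task "/job:worker/task:0"; its `target` can be handed to a
# `tf.compat.v1.Session` running on this host.
server = tf.distribute.Server(
    cluster, job_name="worker", task_index=0, protocol="grpc", start=True)
server.join()  # Block forever, serving requests from other tasks.
```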
@ -195,15 +194,15 @@ class Server(object):
@property
def target(self):
"""Returns the target for a `tf.Session` to connect to this server.
"""Returns the target for a `tf.compat.v1.Session` to connect to this server.
To create a
`tf.Session` that
`tf.compat.v1.Session` that
connects to this server, use the following snippet:
```python
server = tf.train.Server(...)
with tf.Session(server.target):
server = tf.distribute.Server(...)
with tf.compat.v1.Session(server.target):
# ...
```
@ -217,22 +216,24 @@ class Server(object):
"""Creates a new single-process cluster running on the local host.
This method is a convenience wrapper for creating a
`tf.train.Server` with a `tf.train.ServerDef` that specifies a
`tf.distribute.Server` with a `tf.train.ServerDef` that specifies a
single-process cluster containing a single task in a job called
`"local"`.
Args:
config: (Options.) A `tf.ConfigProto` that specifies default
config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
configuration options for all sessions that run on this server.
start: (Optional.) Boolean, indicating whether to start the server after
creating it. Defaults to `True`.
Returns:
A local `tf.train.Server`.
A local `tf.distribute.Server`.
"""
# Specifying port 0 means that the OS will choose a free port for the
# server.
return Server({"local": ["localhost:0"]}, protocol="grpc", config=config,
return Server({"local": ["localhost:0"]},
protocol="grpc",
config=config,
start=start)
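For quick single-process testing, the wrapper above can be used roughly like this (a sketch assuming the `tf.compat.v1` session API):
```python
import tensorflow as tf

server = tf.distribute.Server.create_local_server()
with tf.Graph().as_default():
  c = tf.constant(42)
  with tf.compat.v1.Session(server.target) as sess:
    print(sess.run(c))  # Executes against the in-process gRPC server.
```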
@ -242,7 +243,7 @@ class ClusterSpec(object):
A `tf.train.ClusterSpec` represents the set of processes that
participate in a distributed TensorFlow computation. Every
`tf.train.Server` is constructed in a particular cluster.
`tf.distribute.Server` is constructed in a particular cluster.
To create a cluster with two jobs and five tasks, you specify the
mapping from job names to lists of network addresses (typically
@ -272,10 +273,9 @@ class ClusterSpec(object):
"""Creates a `ClusterSpec`.
Args:
cluster: A dictionary mapping one or more job names to (i) a
list of network addresses, or (ii) a dictionary mapping integer
task indices to network addresses; or a `tf.train.ClusterDef`
protocol buffer.
cluster: A dictionary mapping one or more job names to (i) a list of
network addresses, or (ii) a dictionary mapping integer task indices to
network addresses; or a `tf.train.ClusterDef` protocol buffer.
Raises:
TypeError: If `cluster` is not a dictionary mapping strings to lists
@ -298,14 +298,16 @@ class ClusterSpec(object):
self._cluster_spec = {}
for job_def in self._cluster_def.job:
self._cluster_spec[job_def.name] = {
i: t for i, t in job_def.tasks.items()}
i: t for i, t in job_def.tasks.items()
}
elif isinstance(cluster, ClusterSpec):
self._cluster_def = cluster_pb2.ClusterDef()
self._cluster_def.MergeFrom(cluster.as_cluster_def())
self._cluster_spec = {}
for job_def in self._cluster_def.job:
self._cluster_spec[job_def.name] = {
i: t for i, t in job_def.tasks.items()}
i: t for i, t in job_def.tasks.items()
}
else:
raise TypeError("`cluster` must be a dictionary mapping one or more "
"job names to lists of network addresses, or a "
@ -326,7 +328,8 @@ class ClusterSpec(object):
def __str__(self):
key_values = self.as_dict()
string_items = [
repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)]
repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)
]
return "ClusterSpec({" + ", ".join(string_items) + "})"
def as_dict(self):
@ -427,8 +430,8 @@ class ClusterSpec(object):
try:
return job[task_index]
except KeyError:
raise ValueError("No task with index %r in job %r"
% (task_index, job_name))
raise ValueError("No task with index %r in job %r" %
(task_index, job_name))
def job_tasks(self, job_name):
"""Returns a mapping from task ID to address in the given job.
@ -482,6 +485,6 @@ class ClusterSpec(object):
try:
task_address = compat.as_bytes(task_address)
except TypeError:
raise TypeError(
"Task address %r must be bytes or unicode" % task_address)
raise TypeError("Task address %r must be bytes or unicode" %
task_address)
job_def.tasks[i] = task_address

View File
self.assertAllEqual(2.0, sess.run(v1))
def _useRPCConfig(self):
"""Return a `tf.ConfigProto` that ensures we use the RPC stack for tests.
"""Return a `tf.compat.v1.ConfigProto` that ensures we use the RPC stack for tests.
This configuration ensures that we continue to exercise the gRPC
stack when testing, rather than using the in-process optimization,
@ -115,7 +115,7 @@ class GrpcServerTest(test.TestCase):
master in the same process.
Returns:
A `tf.ConfigProto`.
A `tf.compat.v1.ConfigProto`.
"""
return config_pb2.ConfigProto(rpc_options=config_pb2.RPCOptions(
use_rpc_for_inprocess_master=True))
View File
@ -174,8 +174,8 @@ class SessionManagerTest(test.TestCase):
self.assertFalse(initialized)
sess.run(v.initializer)
self.assertEquals(1, sess.run(v))
saver.save(sess,
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
saver.save(sess, os.path.join(checkpoint_dir,
"recover_session_checkpoint"))
self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
self._test_recovered_variable(
checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
@ -202,9 +202,9 @@ class SessionManagerTest(test.TestCase):
def testInitWithNoneLocalInitOpError(self):
# Creating a SessionManager with a None local_init_op but
# non-None ready_for_local_init_op raises ValueError
with self.assertRaisesRegexp(ValueError,
"If you pass a ready_for_local_init_op "
"you must also pass a local_init_op "):
with self.assertRaisesRegexp(
ValueError, "If you pass a ready_for_local_init_op "
"you must also pass a local_init_op "):
session_manager.SessionManager(
ready_for_local_init_op=variables.report_uninitialized_variables(
variables.global_variables()),
@ -231,8 +231,8 @@ class SessionManagerTest(test.TestCase):
self.assertFalse(initialized)
sess.run(v.initializer)
self.assertEquals(1, sess.run(v))
saver.save(sess,
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
saver.save(sess, os.path.join(checkpoint_dir,
"recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
v = variables.VariableV1(2, name="v")
@ -266,7 +266,7 @@ class SessionManagerTest(test.TestCase):
@test_util.run_v1_only("b/120545219")
def testRecoverSessionWithReadyForLocalInitOpFailsToReadyLocal(self):
# We use ready_for_local_init_op=tf.report_uninitialized_variables(),
# We use ready_for_local_init_op=report_uninitialized_variables(),
# which causes recover_session to not run local_init_op, and to return
# initialized=False
@ -290,8 +290,8 @@ class SessionManagerTest(test.TestCase):
self.assertFalse(initialized)
sess.run(v.initializer)
self.assertEquals(1, sess.run(v))
saver.save(sess,
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
saver.save(sess, os.path.join(checkpoint_dir,
"recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
v = variables.VariableV1(2, name="v")
@ -780,8 +780,8 @@ class ObsoleteSessionManagerTest(test.TestCase):
self.assertFalse(initialized)
sess.run(v.initializer)
self.assertEquals(1, sess.run(v))
saver.save(sess,
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
saver.save(sess, os.path.join(checkpoint_dir,
"recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
v = variables.VariableV1(2, name="v")
View File
@ -68,7 +68,7 @@ look at following code:
Above user code leads to following execution:
call hooks.begin()
sess = tf.Session()
sess = tf.compat.v1.Session()
call hooks.after_create_session()
while not stop is requested:
call hooks.before_run()
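A minimal hook that just logs each stage of the lifecycle sketched above (a sketch assuming the `tf.compat.v1` hook and monitored-session APIs):
```python
import tensorflow as tf

class LoggingLifecycleHook(tf.compat.v1.train.SessionRunHook):
  """Prints a line at each stage of the hook lifecycle."""

  def begin(self):
    print("begin: called once before the session is created")

  def after_create_session(self, session, coord):
    print("after_create_session: session is ready")

  def before_run(self, run_context):
    print("before_run: about to call session.run()")
    return None  # No extra fetches or feeds requested.

  def after_run(self, run_context, run_values):
    print("after_run: run() finished")

with tf.Graph().as_default():
  step = tf.compat.v1.train.get_or_create_global_step()
  train_op = tf.compat.v1.assign_add(step, 1)
  with tf.compat.v1.train.MonitoredSession(
      hooks=[LoggingLifecycleHook()]) as sess:
    sess.run(train_op)
```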
View File
@ -56,9 +56,9 @@ class SummaryWriter(_FileWriter):
```python
...create a graph...
# Launch the graph in a session.
sess = tf.Session()
sess = tf.compat.v1.Session()
# Create a summary writer, add the 'graph' to the event file.
writer = tf.summary.FileWriter(<some-directory>, sess.graph)
writer = tf.compat.v1.summary.FileWriter(<some-directory>, sess.graph)
```
The other arguments to the constructor control the asynchronous writes to
View File
@ -45,7 +45,7 @@ class Supervisor(object):
"""A training helper that checkpoints models and computes summaries.
This class is deprecated. Please use
`tf.train.MonitoredTrainingSession` instead.
`tf.compat.v1.train.MonitoredTrainingSession` instead.
The Supervisor is a small wrapper around a `Coordinator`, a `Saver`,
and a `SessionManager` that takes care of common needs of TensorFlow
@ -97,7 +97,7 @@ class Supervisor(object):
# or job_def.name, or job_def.tasks. It's entirely up to the end user.
# But there can be only one *chief*.
is_chief = (server_def.task_index == 0)
server = tf.train.Server(server_def)
server = tf.distribute.Server(server_def)
with tf.Graph().as_default():
...add operations to the graph...
@ -140,7 +140,7 @@ class Supervisor(object):
* Specifying `'grpc://hostname:port'` requests a session that uses
the RPC interface to a specific host, and also allows the in-process
master to access remote tensorflow workers. Often, it is
appropriate to pass `server.target` (for some `tf.train.Server`
appropriate to pass `server.target` (for some `tf.distribute.Server`
named `server`).
#### Advanced use
@ -237,17 +237,16 @@ class Supervisor(object):
ready_op: 1-D string `Tensor`. This tensor is evaluated by supervisors in
`prepare_or_wait_for_session()` to check if the model is ready to use.
The model is considered ready if it returns an empty array. Defaults to
the tensor returned from `tf.report_uninitialized_variables()` If
`None`, the model is not checked for readiness.
the tensor returned from `tf.compat.v1.report_uninitialized_variables()`
If `None`, the model is not checked for readiness.
ready_for_local_init_op: 1-D string `Tensor`. This tensor is evaluated by
supervisors in `prepare_or_wait_for_session()` to check if the model is
ready to run the local_init_op.
The model is considered ready if it returns an empty array. Defaults to
`None`. If `None`, the model is not checked for readiness before running
local_init_op.
is_chief: If True, create a chief supervisor in charge of initializing
and restoring the model. If False, create a supervisor that relies
on a chief supervisor for inits and restore.
ready to run the local_init_op. The model is considered ready if it
returns an empty array. Defaults to `None`. If `None`, the model is not
checked for readiness before running local_init_op.
is_chief: If True, create a chief supervisor in charge of initializing and
restoring the model. If False, create a supervisor that relies on a
chief supervisor for inits and restore.
init_op: `Operation`. Used by chief supervisors to initialize the model
when it can not be recovered. Defaults to an `Operation` that
initializes all global variables. If `None`, no initialization is done
@ -255,20 +254,19 @@ class Supervisor(object):
init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
This feed dictionary will be used when `init_op` is evaluated.
local_init_op: `Operation`. Used by all supervisors to run initializations
that should run for every new supervisor instance. By default these
are table initializers and initializers for local variables.
If `None`, no further per supervisor-instance initialization is
done automatically.
that should run for every new supervisor instance. By default these are
table initializers and initializers for local variables. If `None`, no
further per supervisor-instance initialization is done automatically.
logdir: A string. Optional path to a directory where to checkpoint the
model and log events for the visualizer. Used by chief supervisors.
The directory will be created if it does not exist.
summary_op: An `Operation` that returns a Summary for the event logs.
Used by chief supervisors if a `logdir` was specified. Defaults to the
model and log events for the visualizer. Used by chief supervisors. The
directory will be created if it does not exist.
summary_op: An `Operation` that returns a Summary for the event logs. Used
by chief supervisors if a `logdir` was specified. Defaults to the
operation returned from summary.merge_all(). If `None`, summaries are
not computed automatically.
saver: A Saver object. Used by chief supervisors if a `logdir` was
specified. Defaults to the saved returned by Saver().
If `None`, the model is not saved automatically.
specified. Defaults to the saver returned by Saver(). If `None`, the
model is not saved automatically.
global_step: An integer Tensor of size 1 that counts steps. The value
from 'global_step' is used in summaries and checkpoint filenames.
Default to the op named 'global_step' in the graph if it exists, is of
@ -280,20 +278,20 @@ class Supervisor(object):
disable summaries.
save_model_secs: Number of seconds between the creation of model
checkpoints. Defaults to 600 seconds. Pass 0 to disable checkpoints.
recovery_wait_secs: Number of seconds between checks that the model
is ready. Used by supervisors when waiting for a chief supervisor
to initialize or restore the model. Defaults to 30 seconds.
recovery_wait_secs: Number of seconds between checks that the model is
ready. Used by supervisors when waiting for a chief supervisor to
initialize or restore the model. Defaults to 30 seconds.
stop_grace_secs: Grace period, in seconds, given to running threads to
stop when `stop()` is called. Defaults to 120 seconds.
checkpoint_basename: The basename for checkpoint saving.
session_manager: `SessionManager`, which manages Session creation and
recovery. If it is `None`, a default `SessionManager` will be created
with the set of arguments passed in for backwards compatibility.
summary_writer: `SummaryWriter` to use or `USE_DEFAULT`. Can be `None`
to indicate that no summaries should be written.
init_fn: Optional callable used to initialize the model. Called
after the optional `init_op` is called. The callable must accept one
argument, the session being initialized.
summary_writer: `SummaryWriter` to use or `USE_DEFAULT`. Can be `None` to
indicate that no summaries should be written.
init_fn: Optional callable used to initialize the model. Called after the
optional `init_op` is called. The callable must accept one argument,
the session being initialized.
local_init_run_options: RunOptions to be passed as the SessionManager
local_init_run_options parameter.
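A hedged sketch combining a few of the constructor arguments above (the logdir is a placeholder; the API is deprecated in favor of MonitoredTrainingSession, as noted earlier):
```python
import tensorflow as tf

with tf.Graph().as_default():
  step = tf.compat.v1.train.get_or_create_global_step()
  train_op = tf.compat.v1.assign_add(step, 1)
  sv = tf.compat.v1.train.Supervisor(
      is_chief=True,
      logdir="/tmp/sv_logs",     # Checkpoints and event files go here.
      global_step=step,
      save_model_secs=600,       # Checkpoint every 10 minutes.
      save_summaries_secs=120)   # Summaries every 2 minutes.
  with sv.managed_session("") as sess:
    while not sv.should_stop():
      if sess.run(step) >= 10:
        sv.request_stop()
      else:
        sess.run(train_op)
```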
@ -397,12 +395,11 @@ class Supervisor(object):
"""Initializes ready_op.
Args:
ready_op: `Tensor` to check if the model is initialized.
If it's set to USE_DEFAULT, creates an op that checks all
the variables are initialized.
ready_op: `Tensor` to check if the model is initialized. If it's set to
USE_DEFAULT, creates an op that checks all the variables are
initialized.
ready_for_local_init_op: `Tensor` to check if the model is ready to run
local_init_op.
If it's set to USE_DEFAULT, creates an op that checks all
local_init_op. If it's set to USE_DEFAULT, creates an op that checks all
the global variables are initialized.
"""
if ready_op is Supervisor.USE_DEFAULT:
@ -440,9 +437,9 @@ class Supervisor(object):
Args:
local_init_op: `Operation` run for every new supervisor instance. If set
to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP
collection. If the collection is empty, create an op that initializes
all local variables and all tables.
to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP
collection. If the collection is empty, create an op that initializes
all local variables and all tables.
"""
if local_init_op is Supervisor.USE_DEFAULT:
local_init_op = self._get_first_op_from_collection(
@ -461,8 +458,8 @@ class Supervisor(object):
"""Initializes saver.
Args:
saver: A `Saver` object. If set to USE_DEFAULT, create one that
saves all the variables.
saver: A `Saver` object. If set to USE_DEFAULT, create one that saves all
the variables.
"""
if saver is Supervisor.USE_DEFAULT:
saver = self._get_first_op_from_collection(ops.GraphKeys.SAVERS)
@ -475,8 +472,8 @@ class Supervisor(object):
"""Initializes summary_op.
Args:
summary_op: An Operation that returns a Summary for the event logs.
If set to USE_DEFAULT, create an op that merges all the summaries.
summary_op: An Operation that returns a Summary for the event logs. If set
to USE_DEFAULT, create an op that merges all the summaries.
"""
if summary_op is Supervisor.USE_DEFAULT:
summary_op = self._get_first_op_from_collection(ops.GraphKeys.SUMMARY_OP)
@ -490,8 +487,8 @@ class Supervisor(object):
"""Initializes global_step.
Args:
global_step: An integer Tensor of size 1 that counts steps. If
set to USE_DEFAULT, creates global_step tensor.
global_step: An integer Tensor of size 1 that counts steps. If set to
USE_DEFAULT, creates global_step tensor.
"""
if global_step is Supervisor.USE_DEFAULT:
global_step = self._get_first_op_from_collection(
@ -630,8 +627,9 @@ class Supervisor(object):
"""Writes graph_def to `logdir` and adds it to summary if applicable."""
assert self._is_chief
if self._logdir:
training_util.write_graph(self._graph.as_graph_def(add_shapes=True),
self._logdir, "graph.pbtxt")
training_util.write_graph(
self._graph.as_graph_def(add_shapes=True), self._logdir,
"graph.pbtxt")
if self._summary_writer and not self._graph_added_to_summary:
self._summary_writer.add_graph(self._graph)
self._summary_writer.add_meta_graph(self._meta_graph_def)
@ -675,8 +673,7 @@ class Supervisor(object):
# if there is no step value.
current_step = training_util.global_step(sess, self._global_step)
self._summary_writer.add_session_log(
SessionLog(status=SessionLog.START),
current_step)
SessionLog(status=SessionLog.START), current_step)
threads = []
if self._save_summaries_secs and self._summary_writer:
@ -690,7 +687,9 @@ class Supervisor(object):
t.start()
return threads
def prepare_or_wait_for_session(self, master="", config=None,
def prepare_or_wait_for_session(self,
master="",
config=None,
wait_for_checkpoint=False,
max_wait_secs=7200,
start_standard_services=True):
@ -702,10 +701,10 @@ class Supervisor(object):
manager to start the standard services.
Args:
master: name of the TensorFlow master to use. See the `tf.Session`
constructor for how this is interpreted.
config: Optional ConfigProto proto used to configure the session,
which is passed as-is to create the session.
master: name of the TensorFlow master to use. See the
`tf.compat.v1.Session` constructor for how this is interpreted.
config: Optional ConfigProto proto used to configure the session, which is
passed as-is to create the session.
wait_for_checkpoint: Whether we should wait for the availability of a
checkpoint before creating Session. Defaults to False.
max_wait_secs: Maximum time to wait for the session to become available.
@ -724,18 +723,22 @@ class Supervisor(object):
if self._is_chief:
sess = self._session_manager.prepare_session(
master, init_op=self.init_op, saver=self.saver,
checkpoint_dir=self._logdir, wait_for_checkpoint=wait_for_checkpoint,
max_wait_secs=max_wait_secs, config=config,
init_feed_dict=self._init_feed_dict, init_fn=self._init_fn)
master,
init_op=self.init_op,
saver=self.saver,
checkpoint_dir=self._logdir,
wait_for_checkpoint=wait_for_checkpoint,
max_wait_secs=max_wait_secs,
config=config,
init_feed_dict=self._init_feed_dict,
init_fn=self._init_fn)
self._write_graph()
if start_standard_services:
logging.info("Starting standard services.")
self.start_standard_services(sess)
else:
sess = self._session_manager.wait_for_session(master,
config=config,
max_wait_secs=max_wait_secs)
sess = self._session_manager.wait_for_session(
master, config=config, max_wait_secs=max_wait_secs)
if start_standard_services:
logging.info("Starting queue runners.")
self.start_queue_runners(sess)
@ -772,8 +775,8 @@ class Supervisor(object):
queue_runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
threads = []
for qr in queue_runners:
threads.extend(qr.create_threads(sess, coord=self._coord, daemon=True,
start=True))
threads.extend(
qr.create_threads(sess, coord=self._coord, daemon=True, start=True))
return threads
def loop(self, timer_interval_secs, target, args=None, kwargs=None):
@ -795,8 +798,12 @@ class Supervisor(object):
Returns:
The started thread.
"""
looper = coordinator.LooperThread(self._coord, timer_interval_secs,
target=target, args=args, kwargs=kwargs)
looper = coordinator.LooperThread(
self._coord,
timer_interval_secs,
target=target,
args=args,
kwargs=kwargs)
looper.start()
return looper
@ -812,13 +819,13 @@ class Supervisor(object):
threads: Optional list of threads to join with the coordinator. If
`None`, defaults to the threads running the standard services, the
threads started for `QueueRunners`, and the threads started by the
`loop()` method. To wait on additional threads, pass the
list in this parameter.
`loop()` method. To wait on additional threads, pass the list in this
parameter.
close_summary_writer: Whether to close the `summary_writer`. Defaults to
`True` if the summary writer was created by the supervisor, `False`
otherwise.
ignore_live_threads: If `True` ignores threads that remain running after
a grace period when joining threads via the coordinator, instead of
ignore_live_threads: If `True` ignores threads that remain running after a
grace period when joining threads via the coordinator, instead of
raising a RuntimeError.
"""
self._coord.request_stop()
@ -926,7 +933,9 @@ class Supervisor(object):
# pylint: disable=g-doc-return-or-yield,broad-except
@contextlib.contextmanager
def managed_session(self, master="", config=None,
def managed_session(self,
master="",
config=None,
start_standard_services=True,
close_summary_writer=True):
"""Returns a context manager for a managed session.
@ -940,7 +949,7 @@ class Supervisor(object):
```python
def train():
sv = tf.train.Supervisor(...)
sv = tf.compat.v1.train.Supervisor(...)
with sv.managed_session(<master>) as sess:
for step in xrange(..):
if sv.should_stop():
@ -973,14 +982,14 @@ class Supervisor(object):
the training loop and are considered normal termination.
Args:
master: name of the TensorFlow master to use. See the `tf.Session`
constructor for how this is interpreted.
config: Optional `ConfigProto` proto used to configure the session.
Passed as-is to create the session.
start_standard_services: Whether to start the standard services,
such as checkpoint, summary and step counter.
close_summary_writer: Whether to close the summary writer when
closing the session. Defaults to True.
master: name of the TensorFlow master to use. See the
`tf.compat.v1.Session` constructor for how this is interpreted.
config: Optional `ConfigProto` proto used to configure the session. Passed
as-is to create the session.
start_standard_services: Whether to start the standard services, such as
checkpoint, summary and step counter.
close_summary_writer: Whether to close the summary writer when closing the
session. Defaults to True.
Returns:
A context manager that yields a `Session` restored from the latest
@ -989,7 +998,8 @@ class Supervisor(object):
"""
try:
sess = self.prepare_or_wait_for_session(
master=master, config=config,
master=master,
config=config,
start_standard_services=start_standard_services)
yield sess
except Exception as e:
@ -1011,6 +1021,7 @@ class Supervisor(object):
except Exception:
# Silently ignore exceptions raised by close().
pass
# pylint: enable=g-doc-return-or-yield,broad-except
@ -1030,8 +1041,8 @@ class SVSummaryThread(coordinator.LooperThread):
def run_loop(self):
if self._sv.global_step is not None:
summary_strs, global_step = self._sess.run([self._sv.summary_op,
self._sv.global_step])
summary_strs, global_step = self._sess.run(
[self._sv.summary_op, self._sv.global_step])
else:
summary_strs = self._sess.run(self._sv.summary_op)
global_step = None
@ -1063,8 +1074,7 @@ class SVStepCounterThread(coordinator.LooperThread):
def start_loop(self):
self._last_time = time.time()
self._last_step = training_util.global_step(
self._sess, self._step_counter)
self._last_step = training_util.global_step(self._sess, self._step_counter)
def run_loop(self):
# Count the steps.
@ -1080,12 +1090,13 @@ class SVStepCounterThread(coordinator.LooperThread):
steps_per_sec = added_steps / elapsed_time
else:
steps_per_sec = float("inf")
summary = Summary(value=[Summary.Value(tag=self._summary_tag,
simple_value=steps_per_sec)])
summary = Summary(value=[
Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
])
if self._sv.summary_writer:
self._sv.summary_writer.add_summary(summary, current_step)
logging.log_first_n(logging.INFO, "%s: %g", 10,
self._summary_tag, steps_per_sec)
logging.log_first_n(logging.INFO, "%s: %g", 10, self._summary_tag,
steps_per_sec)
class SVTimerCheckpointThread(coordinator.LooperThread):
@ -1104,13 +1115,13 @@ class SVTimerCheckpointThread(coordinator.LooperThread):
def run_loop(self):
logging.info("Saving checkpoint to path %s", self._sv.save_path)
self._sv.saver.save(self._sess, self._sv.save_path,
global_step=self._sv.global_step)
self._sv.saver.save(
self._sess, self._sv.save_path, global_step=self._sv.global_step)
if self._sv.summary_writer and self._sv.global_step is not None:
current_step = training_util.global_step(self._sess, self._sv.global_step)
self._sv.summary_writer.add_session_log(
SessionLog(status=SessionLog.CHECKPOINT,
checkpoint_path=self._sv.save_path),
SessionLog(
status=SessionLog.CHECKPOINT, checkpoint_path=self._sv.save_path),
current_step)
View File
@ -110,7 +110,7 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
# Note that if you want to have 2 backup replicas, you can change
# total_num_replicas=52 and make sure this number matches how many physical
# replicas you started in your job.
opt = tf.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=50,
opt = tf.compat.v1.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=50,
total_num_replicas=50)
# Some models have startup_delays to help stabilize the model but when using
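A hedged construction sketch expanding on the snippet above (values are small placeholders; the per-worker session wiring is omitted):
```python
import tensorflow as tf

with tf.Graph().as_default():
  global_step = tf.compat.v1.train.get_or_create_global_step()
  w = tf.compat.v1.get_variable(
      "w", shape=[], initializer=tf.compat.v1.zeros_initializer())
  loss = tf.square(w - 3.0)

  base_opt = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  # Aggregate gradients from 2 replicas before each variable update.
  opt = tf.compat.v1.train.SyncReplicasOptimizer(
      base_opt, replicas_to_aggregate=2, total_num_replicas=2)
  train_op = opt.minimize(loss, global_step=global_step)

  # Every worker passes this hook to its monitored session; the chief's
  # hook additionally initializes the token queue used for coordination.
  sync_hook = opt.make_session_run_hook(is_chief=True)
```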
View File
@ -38,11 +38,9 @@ from tensorflow.python.util import nest
from tensorflow.python.util import serialization
from tensorflow.python.util import tf_decorator
# Key where the object graph proto is saved in a TensorBundle
OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"
# A key indicating a variable's value in an object's checkpointed Tensors
# (Trackable._gather_saveables_for_checkpoint). If this is the only key and
# the object has no dependencies, then its value may be restored on object
@ -74,8 +72,7 @@ class CheckpointInitialValue(ops.Tensor):
"""
def __init__(self, checkpoint_position, shape=None):
self.wrapped_value = checkpoint_position.value_tensors()[
VARIABLE_VALUE_KEY]
self.wrapped_value = checkpoint_position.value_tensors()[VARIABLE_VALUE_KEY]
if shape:
# We need to set the static shape information on the initializer if
# possible so we don't get a variable with an unknown shape.
@ -97,8 +94,8 @@ class NoRestoreSaveable(saveable_object.SaveableObject):
"""Embeds a tensor in a checkpoint with no restore ops."""
def __init__(self, tensor, name, dtype=None, device=None):
spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype,
device=device)
spec = saveable_object.SaveSpec(
tensor, "", name, dtype=dtype, device=device)
super(NoRestoreSaveable, self).__init__(tensor, [spec], name)
def restore(self, restored_tensors, restored_shapes):
@ -123,7 +120,8 @@ class PythonStateSaveable(saveable_object.SaveableObject):
"""Create a new `SaveableObject` which freezes current state as a constant.
Used when executing eagerly to embed the current state as a constant, or
when creating a static tf.train.Saver with the frozen current Python state.
when creating a static tf.compat.v1.train.Saver with the frozen current
Python state.
Returns:
A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
@ -140,24 +138,26 @@ class PythonStringStateSaveable(PythonStateSaveable):
Args:
name: The checkpoint key to write to.
state_callback: A function taking no arguments which returns a
string. This function is run every time a checkpoint is written.
state_callback: A function taking no arguments which returns a string.
This function is run every time a checkpoint is written.
restore_callback: A function taking a Python string, used to restore
state. Optional; defaults to doing nothing, in which case it is ignored
by status assertions such as assert_consumed().
"""
self._has_trivial_state_callback = (restore_callback is None)
def _state_callback_wrapper():
with ops.init_scope():
return state_callback()
self._state_callback = _state_callback_wrapper
self._restore_callback = restore_callback
with ops.device("/cpu:0"):
self._save_string = constant_op.constant("", dtype=dtypes.string)
spec = saveable_object.SaveSpec(
self._save_string, "", name, dtype=dtypes.string)
super(PythonStringStateSaveable, self).__init__(
self._save_string, [spec], name)
super(PythonStringStateSaveable, self).__init__(self._save_string, [spec],
name)
@property
def optional_restore(self):
@ -170,8 +170,10 @@ class PythonStringStateSaveable(PythonStateSaveable):
def freeze(self):
"""Create a frozen `SaveableObject` which saves the current state."""
def _constant_state():
return constant_op.constant(self._state_callback(), dtype=dtypes.string)
return NoRestoreSaveable(
tensor=_constant_state,
dtype=dtypes.string,
@ -217,6 +219,7 @@ class CheckpointPosition(object):
Args:
trackable: The object to record a correspondence for.
Returns:
True if this is a new assignment, False if this object has already been
mapped to a checkpointed `Object` proto.
@ -263,21 +266,21 @@ class CheckpointPosition(object):
# consistent (if the dependency DAG is not a tree then there are
# multiple paths to the same object).
if current_assignment is not trackable:
logging.warning(
("Inconsistent references when loading the checkpoint into this "
"object graph. Either the Trackable object references in the "
"Python program have changed in an incompatible way, or the "
"checkpoint was generated in an incompatible program.\n\nTwo "
"checkpoint references resolved to different objects (%s and %s).")
% (current_assignment, trackable))
logging.warning((
"Inconsistent references when loading the checkpoint into this "
"object graph. Either the Trackable object references in the "
"Python program have changed in an incompatible way, or the "
"checkpoint was generated in an incompatible program.\n\nTwo "
"checkpoint references resolved to different objects (%s and %s)."),
current_assignment, trackable)
return False # Not a new assignment
def is_simple_variable(self):
"""Determine whether this value is restorable with a Tensor initializer."""
attributes = self.object_proto.attributes
return (len(attributes) == 1
and attributes[0].name == VARIABLE_VALUE_KEY
and not self.object_proto.children)
return (len(attributes) == 1 and
attributes[0].name == VARIABLE_VALUE_KEY and
not self.object_proto.children)
def value_tensors(self):
"""Create value `Tensor`s for this object's attributes.
@ -335,8 +338,9 @@ class CheckpointPosition(object):
# If we've already created and cached a SaveableObject for this
# attribute, we can re-use it to avoid re-creating some ops when graph
# building.
saveable_list = saveables_cache.get(
self.trackable, {}).get(serialized_tensor.name, (None,))
saveable_list = saveables_cache.get(self.trackable,
{}).get(serialized_tensor.name,
(None,))
if len(saveable_list) == 1:
# Almost every attribute will have exactly one SaveableObject.
saveable, = saveable_list
@ -370,8 +374,8 @@ class CheckpointPosition(object):
else:
saveable = saveable_factory
if saveables_cache is not None:
saveables_cache.setdefault(
self.trackable, {})[serialized_tensor.name] = [saveable]
saveables_cache.setdefault(self.trackable,
{})[serialized_tensor.name] = [saveable]
if isinstance(saveable, PythonStateSaveable):
python_saveables.append(saveable)
else:
@ -388,11 +392,10 @@ class CheckpointPosition(object):
A list of operations when graph building, or an empty list when executing
eagerly.
"""
(restore_ops,
tensor_saveables,
(restore_ops, tensor_saveables,
python_saveables) = self._gather_ops_or_named_saveables()
restore_ops.extend(self._checkpoint.restore_saveables(
tensor_saveables, python_saveables))
restore_ops.extend(
self._checkpoint.restore_saveables(tensor_saveables, python_saveables))
return restore_ops
@property
@ -416,13 +419,11 @@ class CheckpointPosition(object):
_DeferredSlotVariableRestoration = collections.namedtuple(
"_DeferredSlotVariableRestoration",
[
"_DeferredSlotVariableRestoration", [
"original_variable",
"slot_variable_id",
"slot_name",
]
)
])
_SlotVariableRestoration = collections.namedtuple(
"_SlotVariableRestoration",
@ -446,6 +447,7 @@ def no_automatic_dependency_tracking(method):
Args:
method: The method to decorate.
Returns:
A decorated method which sets and un-sets automatic dependency tracking for
the object the method is called on (not thread safe).
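The general shape of such a decorator, as a generic Python sketch of the toggle-a-flag-around-the-call pattern (the attribute name is hypothetical and not the actual implementation):
```python
import functools

def no_tracking(method):
  """Runs `method` with dependency tracking temporarily disabled."""

  @functools.wraps(method)
  def wrapper(self, *args, **kwargs):
    previous = getattr(self, "_tracking_enabled", True)
    self._tracking_enabled = False       # Disable tracking for this call.
    try:
      return method(self, *args, **kwargs)
    finally:
      self._tracking_enabled = previous  # Restore the prior setting.

  return wrapper
```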
@ -595,16 +597,21 @@ class Trackable(object):
Args:
name: The local name of the dependency.
Returns:
A `Trackable` object, or `None` if no dependency by this name was
found.
"""
return self._self_unconditional_dependency_names.get(name, None)
def _add_variable_with_custom_getter(
self, name, shape=None, dtype=dtypes.float32,
initializer=None, getter=None, overwrite=False,
**kwargs_for_getter):
def _add_variable_with_custom_getter(self,
name,
shape=None,
dtype=dtypes.float32,
initializer=None,
getter=None,
overwrite=False,
**kwargs_for_getter):
"""Restore-on-create for a variable be saved with this `Trackable`.
If the user has requested that this object or another `Trackable` which
@ -640,11 +647,9 @@ class Trackable(object):
name=name, shape=shape)
else:
checkpoint_initializer = None
if (checkpoint_initializer is not None
and not (
isinstance(initializer, CheckpointInitialValue)
and (initializer.restore_uid
> checkpoint_initializer.restore_uid))):
if (checkpoint_initializer is not None and
not (isinstance(initializer, CheckpointInitialValue) and
(initializer.restore_uid > checkpoint_initializer.restore_uid))):
# If multiple Trackable objects are "creating" the same variable
# via the magic of custom getters, the one with the highest restore UID
# (the one called last) has to make the final initializer. If another
@ -654,7 +659,10 @@ class Trackable(object):
initializer = checkpoint_initializer
shape = None
new_variable = getter(
name=name, shape=shape, dtype=dtype, initializer=initializer,
name=name,
shape=shape,
dtype=dtype,
initializer=initializer,
**kwargs_for_getter)
# If we set an initializer and the variable processed it, tracking will not
@ -662,8 +670,7 @@ class Trackable(object):
# is a non-trivial restoration queued, it will handle that. This also
# handles slot variables.
if not overwrite or isinstance(new_variable, Trackable):
return self._track_trackable(new_variable, name=name,
overwrite=overwrite)
return self._track_trackable(new_variable, name=name, overwrite=overwrite)
else:
# TODO(allenl): Some variable types are not yet supported. Remove this
# fallback once all get_variable() return types are Trackable.
@ -681,6 +688,7 @@ class Trackable(object):
name: The object-local name of the dependency holding the variable's
value.
shape: The shape of the variable being loaded into.
Returns:
A callable for use as a variable's initializer/initial_value, or None if
one should not be set (either because there was no variable with this name
@ -718,8 +726,8 @@ class Trackable(object):
Args:
trackable: A `Trackable` which this object depends on.
name: A local name for `trackable`, used for loading checkpoints into
the correct objects.
name: A local name for `trackable`, used for loading checkpoints into the
correct objects.
overwrite: Boolean, whether silently replacing dependencies is OK. Used
for __setattr__, where throwing an error on attribute reassignment would
be inappropriate.
@ -734,13 +742,11 @@ class Trackable(object):
"""
self._maybe_initialize_trackable()
if not isinstance(trackable, Trackable):
raise TypeError(
("Trackable._track_trackable() passed type %s, not a "
"Trackable.") % (type(trackable),))
raise TypeError(("Trackable._track_trackable() passed type %s, not a "
"Trackable.") % (type(trackable),))
new_reference = TrackableReference(name=name, ref=trackable)
current_object = self._lookup_dependency(name)
if (current_object is not None
and current_object is not trackable):
if (current_object is not None and current_object is not trackable):
if not overwrite:
raise ValueError(
("Called Trackable._track_trackable() with name='%s', "
@ -755,8 +761,7 @@ class Trackable(object):
index] = new_reference
elif current_object is None:
self._self_unconditional_checkpoint_dependencies.append(new_reference)
self._handle_deferred_dependencies(
name=name, trackable=trackable)
self._handle_deferred_dependencies(name=name, trackable=trackable)
self._self_unconditional_dependency_names[name] = trackable
return trackable
@ -780,8 +785,7 @@ class Trackable(object):
Args:
name: The name of the dependency within this object (`self`), used to
match `trackable` with values saved in a checkpoint.
trackable: The Trackable object to restore (inheriting from
`Trackable`).
trackable: The Trackable object to restore (inheriting from `Trackable`).
"""
self._maybe_initialize_trackable()
trackable._maybe_initialize_trackable() # pylint: disable=protected-access
@ -809,15 +813,15 @@ class Trackable(object):
restore_ops = []
while visit_queue:
current_position = visit_queue.popleft()
restore_ops.extend(nest.flatten(
current_position.trackable # pylint: disable=protected-access
._single_restoration_from_checkpoint_position(
checkpoint_position=current_position,
visit_queue=visit_queue)))
restore_ops.extend(
nest.flatten(current_position.trackable # pylint: disable=protected-access
._single_restoration_from_checkpoint_position(
checkpoint_position=current_position,
visit_queue=visit_queue)))
return restore_ops
def _single_restoration_from_checkpoint_position(
self, checkpoint_position, visit_queue):
def _single_restoration_from_checkpoint_position(self, checkpoint_position,
visit_queue):
"""Restore this object, and either queue its dependencies or defer them."""
self._maybe_initialize_trackable()
checkpoint = checkpoint_position.checkpoint
@ -831,14 +835,13 @@ class Trackable(object):
restore_ops = ()
for child in checkpoint_position.object_proto.children:
child_position = CheckpointPosition(
checkpoint=checkpoint,
proto_id=child.node_id)
checkpoint=checkpoint, proto_id=child.node_id)
local_object = self._lookup_dependency(child.local_name)
if local_object is None:
# We don't yet have a dependency registered with this name. Save it
# in case we do.
self._deferred_dependencies.setdefault(child.local_name, []).append(
child_position)
self._deferred_dependencies.setdefault(child.local_name,
[]).append(child_position)
else:
if child_position.bind_object(trackable=local_object):
# This object's correspondence is new, so dependencies need to be
@ -853,7 +856,8 @@ class Trackable(object):
Keys in the returned dictionary are local to this object and in a separate
namespace from dependencies. Values may either be `SaveableObject` factories
or variables easily converted to `SaveableObject`s (as in `tf.train.Saver`'s
or variables easily converted to `SaveableObject`s (as in
`tf.compat.v1.train.Saver`'s
`var_list` constructor argument).
`SaveableObjects` have a name set, which Trackable needs to generate
@ -861,7 +865,8 @@ class Trackable(object):
should return a dictionary of callables which take `name` arguments and
return `SaveableObjects` with that name.
If this object may also be passed to the global-name-based `tf.train.Saver`,
If this object may also be passed to the global-name-based
`tf.compat.v1.train.Saver`,
the returned callables should have a default value for their name argument
(i.e. be callable with no arguments).
@ -884,6 +889,7 @@ class Trackable(object):
except NotImplementedError:
return {}
weak_self = weakref.ref(self)
def _state_callback():
"""Serializes `self.get_config()` for saving."""
dereferenced_self = weak_self()
@ -898,9 +904,12 @@ class Trackable(object):
return ""
else:
return ""
return {OBJECT_CONFIG_JSON_KEY: functools.partial(
PythonStringStateSaveable,
state_callback=_state_callback)}
return {
OBJECT_CONFIG_JSON_KEY:
functools.partial(
PythonStringStateSaveable, state_callback=_state_callback)
}
def _list_functions_for_serialization(self):
"""Lists the functions of this trackable to serialize.
@ -61,8 +61,8 @@ class _CheckpointRestoreCoordinator(object):
"""Specify the checkpoint being loaded.
Args:
object_graph_proto: The TrackableObjectGraph protocol buffer
associated with this checkpoint.
object_graph_proto: The TrackableObjectGraph protocol buffer associated
with this checkpoint.
save_path: A string, the path to the checkpoint, as returned by
`tf.train.latest_checkpoint`.
save_path_tensor: A string `Tensor` which contains or will be fed the save
@ -142,12 +142,10 @@ class _CheckpointRestoreCoordinator(object):
"""
restore_ops = []
# Eagerly run restorations for Python state.
reader = pywrap_tensorflow.NewCheckpointReader(
self.save_path_string)
reader = pywrap_tensorflow.NewCheckpointReader(self.save_path_string)
for saveable in python_saveables:
spec_names = [spec.name for spec in saveable.specs]
saveable.python_restore(
[reader.get_tensor(name) for name in spec_names])
saveable.python_restore([reader.get_tensor(name) for name in spec_names])
# If we have new SaveableObjects, extract and cache restore ops.
if tensor_saveables:
@ -205,14 +203,13 @@ class _NameBasedRestoreCoordinator(object):
# whether it's optional to restore it. If it's optional we don't need
# to make assertions fail.
if not saveable_factory("").optional_restore:
self.unused_attributes.setdefault(trackable, []).append(
attribute_name)
self.unused_attributes.setdefault(trackable,
[]).append(attribute_name)
continue
else:
saveable = saveable_factory
names_to_saveables = saveable_object_util.op_list_to_dict(
[saveable],
convert_variable_to_tensor=False)
[saveable], convert_variable_to_tensor=False)
for name, op in names_to_saveables.items():
for saveable_object in saveable_object_util.saveable_objects_for_op(
op=op, name=name):
@ -224,8 +221,7 @@ class _NameBasedRestoreCoordinator(object):
# run_restore_ops/initialize_or_restore on the status object for name-based
# checkpoints.
assert context.executing_eagerly()
for saveable in self.globally_named_object_attributes(
trackable):
for saveable in self.globally_named_object_attributes(trackable):
restored_tensors = []
tensor_missing = False
for spec in saveable.specs:
@ -248,14 +244,18 @@ class _NameBasedRestoreCoordinator(object):
# Ignores values missing from the checkpoint, as with object-based
# restore. Status assertions can be used to check exact matches,
# although it's unlikely to ever happen for name-based checkpoints.
saveable.restore(restored_tensors=restored_tensors,
restored_shapes=None)
saveable.restore(
restored_tensors=restored_tensors, restored_shapes=None)
# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange
# or consolidating the implementation with get_variable.
def _default_getter(name, shape, dtype, initializer=None,
partition_info=None, **kwargs):
def _default_getter(name,
shape,
dtype,
initializer=None,
partition_info=None,
**kwargs):
"""A pared-down version of get_variable which does not reuse variables."""
dtype = dtypes.as_dtype(dtype)
shape_object = tensor_shape.as_shape(shape)
@ -263,7 +263,9 @@ def _default_getter(name, shape, dtype, initializer=None,
if initializer is None:
initializer, initializing_from_value = (
variable_scope._get_default_variable_store()._get_default_initializer( # pylint: disable=protected-access
name=name, shape=shape_object, dtype=dtype))
name=name,
shape=shape_object,
dtype=dtype))
else:
initializing_from_value = not callable(initializer)
# Same logic as get_variable
@ -276,24 +278,33 @@ def _default_getter(name, shape, dtype, initializer=None,
# Instantiate initializer if provided initializer is a type object.
if isinstance(initializer, type(init_ops.Initializer)):
initializer = initializer(dtype=dtype)
def initial_value():
return initializer(
shape_object.as_list(), dtype=dtype, partition_info=partition_info)
return variables.VariableV1(
initial_value=initial_value,
name=name,
dtype=variable_dtype,
use_resource=True,
**kwargs
)
**kwargs)
def add_variable(trackable, name, shape=None, dtype=dtypes.float32,
initializer=None, trainable=True):
def add_variable(trackable,
name,
shape=None,
dtype=dtypes.float32,
initializer=None,
trainable=True):
"""Add a variable to a Trackable with no scope influence."""
return trackable._add_variable_with_custom_getter( # pylint: disable=protected-access
name=name, shape=shape, dtype=dtype,
initializer=initializer, getter=_default_getter, trainable=trainable)
name=name,
shape=shape,
dtype=dtype,
initializer=initializer,
getter=_default_getter,
trainable=trainable)
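A small sketch of this helper in use, assuming eager execution and this module's own namespace (`base`, `dtypes`, and `add_variable` as defined or imported above):

```python
obj = base.Trackable()
counter = add_variable(
    obj, name="counter", initializer=0, dtype=dtypes.int64, trainable=False)
# The variable is now a named checkpoint dependency of `obj`.
```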
def object_metadata(save_path):
@ -313,6 +324,7 @@ def object_metadata(save_path):
Args:
save_path: The path to the checkpoint, as returned by `save` or
`tf.train.latest_checkpoint`.
Returns:
A parsed `tf.contrib.checkpoint.TrackableObjectGraph` protocol buffer.
Raises:
@ -320,16 +332,14 @@ def object_metadata(save_path):
"""
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
try:
object_graph_string = reader.get_tensor(
base.OBJECT_GRAPH_PROTO_KEY)
object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
except errors_impl.NotFoundError:
raise ValueError(
('The specified checkpoint "%s" does not appear to be object-based (it '
'is missing the key "%s"). Likely it was created with a name-based '
'saver and does not contain an object dependency graph.') % (
save_path, base.OBJECT_GRAPH_PROTO_KEY))
object_graph_proto = (
trackable_object_graph_pb2.TrackableObjectGraph())
"saver and does not contain an object dependency graph.") %
(save_path, base.OBJECT_GRAPH_PROTO_KEY))
object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string)
return object_graph_proto
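For `object_metadata`, a minimal usage sketch (eager execution assumed; the save directory is hypothetical):

```python
import tensorflow as tf

ckpt = tf.train.Checkpoint(step=tf.Variable(1, dtype=tf.int64))
path = ckpt.save("/tmp/object_metadata_demo/ckpt")

graph_proto = object_metadata(path)
for node in graph_proto.nodes:
  # Print the named edges leaving each object in the saved dependency graph.
  print([child.local_name for child in node.children])
```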
@ -343,8 +353,8 @@ def list_objects(root_trackable):
(i.e. if they would be saved with a checkpoint).
Args:
root_trackable: A `Trackable` object whose dependencies should be
flattened.
root_trackable: A `Trackable` object whose dependencies should be flattened.
Returns:
A flat list of objects.
"""
@ -362,12 +372,16 @@ def gather_initializers(root_trackable):
Args:
root_trackable: A `Trackable` object to gather initializers for.
Returns:
A list of initialization ops.
"""
trackable_objects = list_objects(root_trackable)
return [c.initializer for c in trackable_objects
if hasattr(c, "initializer") and c.initializer is not None]
return [
c.initializer
for c in trackable_objects
if hasattr(c, "initializer") and c.initializer is not None
]
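A rough graph-building sketch of `gather_initializers` (tf.compat.v1 endpoints assumed; in eager mode there is nothing to initialize):

```python
import tensorflow as tf

with tf.Graph().as_default():
  root = tf.train.Checkpoint(v=tf.compat.v1.get_variable("v", shape=[]))
  init_ops = gather_initializers(root)  # initializers for all dependencies
  with tf.compat.v1.Session() as sess:
    sess.run(init_ops)
    print(sess.run(root.v))
```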
@tf_contextlib.contextmanager
@ -380,7 +394,7 @@ def capture_dependencies(template):
object to add dependencies on variables created in a block of code which is
not aware of object-based saving (and instead uses variable names
heavily). This is how `Template` objects add dependencies on variables and
sub-`Template`s. Where possible, use `tf.make_template` directly.
sub-`Template`s. Where possible, use `tf.compat.v1.make_template` directly.
Args:
template: The `Template` object to register dependencies with.
@ -390,8 +404,11 @@ def capture_dependencies(template):
"""
name_prefix = template.variable_scope.name
def _trackable_custom_creator(next_creator, name, initial_value,
trackable_parent=None, **kwargs):
def _trackable_custom_creator(next_creator,
name,
initial_value,
trackable_parent=None,
**kwargs):
"""A variable creation hook which adds Trackable dependencies.
Set for example during a `Template`'s first wrapped function
@ -415,21 +432,20 @@ def capture_dependencies(template):
initial_value: See `variable_scope.variable_creator_scope`. Taken
explicitly so the argument can be re-named and used with
`Trackable._add_variable_with_custom_getter`.
trackable_parent: If not None, a more deeply nested trackable
object and its name prefix which were passed to `capture_dependencies`
to add a dependency on (rather than depending on the variable directly).
trackable_parent: If not None, a more deeply nested trackable object and
its name prefix which were passed to `capture_dependencies` to add a
dependency on (rather than depending on the variable directly).
**kwargs: Passed through to the next creator.
Returns:
The output of `next_creator`: the fetched/created variable object.
"""
def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
# we don't want to propagate.
return next_creator(
initial_value=initializer,
name=name,
**inner_kwargs)
return next_creator(initial_value=initializer, name=name, **inner_kwargs)
if name is not None and name.startswith(name_prefix):
scope_stripped_name = name[len(name_prefix) + 1:]
if not trackable_parent:
@ -450,8 +466,10 @@ def capture_dependencies(template):
name=parent_name_prefix[len(name_prefix) + 1:],
overwrite=True)
return next_creator(
name=name, initial_value=initial_value,
trackable_parent=(template, name_prefix), **kwargs)
name=name,
initial_value=initial_value,
trackable_parent=(template, name_prefix),
**kwargs)
with variable_scope.variable_creator_scope(_trackable_custom_creator):
yield
@ -490,9 +508,8 @@ def streaming_restore(status, session=None):
"""When graph building, runs restore ops as soon as they come in.
Args:
status: A _LoadStatus objects from an object-based saver's
restore(). Streaming restore from name-based checkpoints is not currently
supported.
status: A _LoadStatus object from an object-based saver's restore().
Streaming restore from name-based checkpoints is not currently supported.
session: A session to run new restore ops in.
"""
if context.executing_eagerly():
@ -553,13 +570,13 @@ class CheckpointLoadStatus(_LoadStatus):
if self._checkpoint.slot_restorations:
# Sanity check; this collection should be clear if everything has been
# restored.
raise AssertionError("Unresolved slot restorations: %s" % (
self._checkpoint.slot_restorations,))
raise AssertionError("Unresolved slot restorations: %s" %
(self._checkpoint.slot_restorations,))
if self._checkpoint.unused_attributes:
raise AssertionError(
("Unused attributes in these objects (the attributes exist in the "
"checkpoint but not in the objects): %s") % (
list(self._checkpoint.unused_attributes.items()),))
"checkpoint but not in the objects): %s") %
(list(self._checkpoint.unused_attributes.items()),))
return self
def assert_existing_objects_matched(self):
@ -581,10 +598,10 @@ class CheckpointLoadStatus(_LoadStatus):
"""
for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
if (trackable is not None
and trackable._update_uid < self._checkpoint.restore_uid): # pylint: disable=protected-access
raise AssertionError(
"Object not assigned a value from checkpoint: %s" % (node,))
if (trackable is not None and
trackable._update_uid < self._checkpoint.restore_uid): # pylint: disable=protected-access
raise AssertionError("Object not assigned a value from checkpoint: %s" %
(node,))
for trackable_object in self._graph_view.list_objects():
# Remove data structures that do not contain any variables from
# restoration checks.
@ -594,14 +611,14 @@ class CheckpointLoadStatus(_LoadStatus):
continue
self._checkpoint.all_python_objects.add(trackable_object)
unused_python_objects = (
object_identity.ObjectIdentitySet(self._checkpoint.all_python_objects)
- object_identity.ObjectIdentitySet(
object_identity.ObjectIdentitySet(self._checkpoint.all_python_objects) -
object_identity.ObjectIdentitySet(
self._checkpoint.object_by_proto_id.values()))
if unused_python_objects:
raise AssertionError(
("Some Python objects were not bound to checkpointed values, likely "
"due to changes in the Python program: %s")
% (list(unused_python_objects),))
"due to changes in the Python program: %s") %
(list(unused_python_objects),))
return self
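For illustration, a small eager-mode sketch of how these load-status assertions are typically chained (the directory is hypothetical):

```python
import tensorflow as tf

ckpt = tf.train.Checkpoint(v=tf.Variable(1.0))
save_path = ckpt.save("/tmp/load_status_demo/ckpt")

restored = tf.train.Checkpoint(v=tf.Variable(0.0))
status = restored.restore(save_path)
status.assert_existing_objects_matched()  # every current object found a value
status.assert_consumed()                  # and every checkpointed value was used
```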
def assert_nontrivial_match(self):
@ -610,8 +627,7 @@ class CheckpointLoadStatus(_LoadStatus):
self._checkpoint.all_python_objects.add(trackable_object)
if len(self._checkpoint.object_by_proto_id) <= 1:
unused_python_objects = (
object_identity.ObjectIdentitySet(
self._checkpoint.all_python_objects)
object_identity.ObjectIdentitySet(self._checkpoint.all_python_objects)
- object_identity.ObjectIdentitySet(
self._checkpoint.object_by_proto_id.values()))
if unused_python_objects:
@ -622,8 +638,8 @@ class CheckpointLoadStatus(_LoadStatus):
"checkpointed value: %s") % (list(unused_python_objects),))
else:
raise AssertionError(
"Nothing to load. No dependencies have been added to %s yet." % (
self._graph_view.root,))
"Nothing to load. No dependencies have been added to %s yet." %
(self._graph_view.root,))
return self
def run_restore_ops(self, session=None):
@ -760,8 +776,8 @@ class NameBasedSaverStatus(_LoadStatus):
unused_attributes = dict(self._checkpoint.unused_attributes)
if unused_attributes:
raise AssertionError(
"Some objects had attributes which were not restored: %s"
% (unused_attributes,))
"Some objects had attributes which were not restored: %s" %
(unused_attributes,))
for trackable in self._graph_view.list_objects():
# pylint: disable=protected-access
trackable._maybe_initialize_trackable()
@ -799,12 +815,11 @@ class NameBasedSaverStatus(_LoadStatus):
continue
# pylint: enable=protected-access
saveable_objects.extend(
self._checkpoint.globally_named_object_attributes(
trackable))
self._checkpoint.globally_named_object_attributes(trackable))
return saveable_objects
def run_restore_ops(self, session=None):
"""Load the name-based training checkpoint using a new `tf.train.Saver`."""
"""Load the name-based checkpoint using a new `tf.compat.v1.train.Saver`."""
if context.executing_eagerly():
return # Nothing to do, variables are restored on creation.
if session is None:
@ -840,7 +855,8 @@ class TrackableSaver(object):
"""Saves and restores a `Trackable` object and its dependencies.
See `Trackable` for details of dependency management. `Saver` wraps
`tf.train.Saver` for saving, including extra information about the graph of
`tf.compat.v1.train.Saver` for saving, including extra information about the
graph of
dependencies between Python objects. When restoring, it uses this information
about the save-time dependency graph to more robustly match objects with their
checkpointed values. When executing eagerly, it supports restoring variables
@ -851,7 +867,8 @@ class TrackableSaver(object):
checkpoint was written. To avoid breaking existing checkpoints when modifying
a class, dependency names (the names of attributes to which `Trackable`
objects are assigned) may not change. These names are local to objects, in
contrast to the `Variable.name`-based save/restore from `tf.train.Saver`, and
contrast to the `Variable.name`-based save/restore from
`tf.compat.v1.train.Saver`, and
so allow additional program transformations.
"""
@ -877,8 +894,7 @@ class TrackableSaver(object):
self._restore_op_cache = {}
self._graph_view = graph_view
def _gather_saveables(
self, object_graph_tensor=None):
def _gather_saveables(self, object_graph_tensor=None):
"""Wraps _serialize_object_graph to include the object graph proto."""
(named_saveable_objects, graph_proto,
feed_additions) = self._graph_view.serialize_object_graph()
@ -892,14 +908,12 @@ class TrackableSaver(object):
assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
named_saveable_objects.append(
base.NoRestoreSaveable(
tensor=object_graph_tensor,
name=base.OBJECT_GRAPH_PROTO_KEY))
tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY))
return named_saveable_objects, graph_proto, feed_additions
def _save_cached_when_graph_building(
self,
file_prefix,
object_graph_tensor=None):
def _save_cached_when_graph_building(self,
file_prefix,
object_graph_tensor=None):
"""Create or retrieve save ops.
Args:
@ -921,8 +935,7 @@ class TrackableSaver(object):
# save() is called so they pick up new Tensors passed to their
# constructors. That means the Saver needs to be copied with a new
# var_list.
or context.executing_eagerly()
or ops.inside_function()):
or context.executing_eagerly() or ops.inside_function()):
saver = functional_saver.MultiDeviceSaver(named_saveable_objects)
save_op = saver.save(file_prefix)
with ops.device("/cpu:0"):
@ -954,8 +967,8 @@ class TrackableSaver(object):
The full path to the checkpoint.
"""
feed_dict = {}
use_session = (not context.executing_eagerly()
and not ops.inside_function())
use_session = (not context.executing_eagerly() and
not ops.inside_function())
if checkpoint_number:
file_prefix = "%s-%d" % (file_prefix, checkpoint_number)
if use_session:
@ -976,8 +989,7 @@ class TrackableSaver(object):
file_io.recursive_create_dir(os.path.dirname(file_prefix))
save_path, new_feed_additions = self._save_cached_when_graph_building(
file_prefix=file_prefix_tensor,
object_graph_tensor=object_graph_tensor)
file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor)
if new_feed_additions:
feed_dict.update(new_feed_additions)
if not use_session:
@ -1024,7 +1036,7 @@ class TrackableSaver(object):
If the checkpoint has not been consumed completely, then the list of restore
ops will grow as more objects are added to the dependency graph.
Name-based `tf.train.Saver` checkpoints can be loaded using this
Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this
method. There is no deferred loading, and names are used to match
variables. No restore ops are created/run until `run_restore_ops()` or
`initialize_or_restore()` are called on the returned status object, even
@ -1035,9 +1047,9 @@ class TrackableSaver(object):
save_path: The path to the checkpoint, as returned by `save` or
`tf.train.latest_checkpoint`. If None (as when there is no latest
checkpoint for `tf.train.latest_checkpoint` to return), returns an
object which may run initializers for objects in the dependency
graph. If the checkpoint was written by the name-based `tf.train.Saver`,
names are used to match variables.
object which may run initializers for objects in the dependency graph.
If the checkpoint was written by the name-based
`tf.compat.v1.train.Saver`, names are used to match variables.
Returns:
A load status object, which can be used to make assertions about the
@ -1057,8 +1069,7 @@ class TrackableSaver(object):
else:
dtype_map = reader.get_variable_to_dtype_map()
try:
object_graph_string = reader.get_tensor(
base.OBJECT_GRAPH_PROTO_KEY)
object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
except errors_impl.NotFoundError:
# The object graph proto does not exist in this checkpoint. Try the
# name-based compatibility mode.
@ -1069,8 +1080,7 @@ class TrackableSaver(object):
# pylint: disable=protected-access
existing_trackable._maybe_initialize_trackable()
existing_trackable._name_based_restores.add(restore_coordinator)
existing_trackable._name_based_attribute_restore(
restore_coordinator)
existing_trackable._name_based_attribute_restore(restore_coordinator)
# pylint: enable=protected-access
return NameBasedSaverStatus(
restore_coordinator, graph_view=self._graph_view)
@ -1085,8 +1095,7 @@ class TrackableSaver(object):
with ops.device("/cpu:0"):
file_prefix_tensor = constant_op.constant(save_path)
file_prefix_feed_dict = None
object_graph_proto = (
trackable_object_graph_pb2.TrackableObjectGraph())
object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string)
checkpoint = _CheckpointRestoreCoordinator(
object_graph_proto=object_graph_proto,
@ -1094,8 +1103,8 @@ class TrackableSaver(object):
save_path_tensor=file_prefix_tensor,
restore_op_cache=self._restore_op_cache,
graph_view=self._graph_view)
base.CheckpointPosition(checkpoint=checkpoint, proto_id=0).restore(
self._graph_view.root)
base.CheckpointPosition(
checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)
load_status = CheckpointLoadStatus(
checkpoint,
graph_view=self._graph_view,
@ -1104,7 +1113,7 @@ class TrackableSaver(object):
def frozen_saver(root_trackable):
"""Creates a static `tf.train.Saver` from a trackable object.
"""Creates a static `tf.compat.v1.train.Saver` from a trackable object.
The returned `Saver` saves object-based checkpoints, but these checkpoints
will no longer reflect structural changes to the object graph, only changes to
@ -1135,9 +1144,9 @@ def saver_with_op_caching(obj):
saveables_cache = None
else:
saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()
return TrackableSaver(graph_view_lib.ObjectGraphView(
weakref.ref(obj),
saveables_cache=saveables_cache))
return TrackableSaver(
graph_view_lib.ObjectGraphView(
weakref.ref(obj), saveables_cache=saveables_cache))
# Mentions graph building / Sessions. The v2 version is below.
@ -1146,7 +1155,7 @@ class CheckpointV1(tracking.AutoTrackable):
"""Groups trackable objects, saving and restoring them.
`Checkpoint`'s constructor accepts keyword arguments whose values are types
that contain trackable state, such as `tf.train.Optimizer`
that contain trackable state, such as `tf.compat.v1.train.Optimizer`
implementations, `tf.Variable`, `tf.keras.Layer` implementations, or
`tf.keras.Model` implementations. It saves these values with a checkpoint, and
maintains a `save_counter` for numbering checkpoints.
@ -1164,7 +1173,7 @@ class CheckpointV1(tracking.AutoTrackable):
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
train_op = optimizer.minimize( ... )
status.assert_consumed() # Optional sanity checks.
with tf.Session() as session:
with tf.compat.v1.Session() as session:
# Use the Session to restore variables, or initialize them if
# tf.train.latest_checkpoint returned None.
status.initialize_or_restore(session)
@ -1179,7 +1188,7 @@ class CheckpointV1(tracking.AutoTrackable):
import tensorflow as tf
import os
tf.enable_eager_execution()
tf.compat.v1.enable_eager_execution()
checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
@ -1193,13 +1202,14 @@ class CheckpointV1(tracking.AutoTrackable):
```
`Checkpoint.save` and `Checkpoint.restore` write and read object-based
checkpoints, in contrast to `tf.train.Saver` which writes and reads
checkpoints, in contrast to `tf.compat.v1.train.Saver` which writes and reads
`variable.name` based checkpoints. Object-based checkpointing saves a graph of
dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s,
etc.) with named edges, and this graph is used to match variables when
restoring a checkpoint. It can be more robust to changes in the Python
program, and helps to support restore-on-create for variables when executing
eagerly. Prefer `tf.train.Checkpoint` over `tf.train.Saver` for new code.
eagerly. Prefer `tf.train.Checkpoint` over `tf.compat.v1.train.Saver` for new
code.
`Checkpoint` objects have dependencies on the objects passed as keyword
arguments to their constructors, and each dependency is given a name that is
@ -1244,6 +1254,7 @@ class CheckpointV1(tracking.AutoTrackable):
Args:
**kwargs: Keyword arguments are set as attributes of this object, and are
saved with the checkpoint. Values must be trackable objects.
Raises:
ValueError: If objects in `kwargs` are not trackable.
"""
@ -1269,8 +1280,12 @@ class CheckpointV1(tracking.AutoTrackable):
# add_variable creates a dependency named "save_counter"; NoDependency
# prevents creating a second dependency named "_save_counter".
self._save_counter = data_structures.NoDependency(
add_variable(self, name="save_counter", initializer=0,
dtype=dtypes.int64, trainable=False))
add_variable(
self,
name="save_counter",
initializer=0,
dtype=dtypes.int64,
trainable=False))
def write(self, file_prefix, session=None):
"""Writes a training checkpoint.
@ -1294,9 +1309,7 @@ class CheckpointV1(tracking.AutoTrackable):
Returns:
The full path to the checkpoint (i.e. `file_prefix`).
"""
output = self._saver.save(
file_prefix=file_prefix,
session=session)
output = self._saver.save(file_prefix=file_prefix, session=session)
if tensor_util.is_tensor(output):
if context.executing_eagerly():
return compat.as_str(output.numpy())
@ -1370,8 +1383,8 @@ class CheckpointV1(tracking.AutoTrackable):
checkpoint_number = session.run(self._save_assign_op)
else:
checkpoint_number = assign_op.numpy()
file_path = self.write("%s-%d" % (file_prefix, checkpoint_number),
session=session)
file_path = self.write(
"%s-%d" % (file_prefix, checkpoint_number), session=session)
checkpoint_management.update_checkpoint_state_internal(
save_dir=os.path.dirname(file_prefix),
model_checkpoint_path=file_path,
@ -1417,7 +1430,7 @@ class CheckpointV1(tracking.AutoTrackable):
If the checkpoint has not been consumed completely, then the list of restore
ops will grow as more objects are added to the dependency graph.
Name-based `tf.train.Saver` checkpoints can be loaded using this
Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this
method. Names are used to match variables. No restore ops are created/run
until `run_restore_ops()` or `initialize_or_restore()` are called on the
returned status object when graph building, but there is restore-on-creation
@ -1428,9 +1441,9 @@ class CheckpointV1(tracking.AutoTrackable):
save_path: The path to the checkpoint, as returned by `save` or
`tf.train.latest_checkpoint`. If None (as when there is no latest
checkpoint for `tf.train.latest_checkpoint` to return), returns an
object which may run initializers for objects in the dependency
graph. If the checkpoint was written by the name-based `tf.train.Saver`,
names are used to match variables.
object which may run initializers for objects in the dependency graph.
If the checkpoint was written by the name-based
`tf.compat.v1.train.Saver`, names are used to match variables.
Returns:
A load status object, which can be used to make assertions about the
@ -1453,7 +1466,8 @@ class CheckpointV1(tracking.AutoTrackable):
built, and so has not created any variables, will pass this assertion
but fail `assert_consumed`. Useful when loading part of a larger
checkpoint into a new Python program, e.g. a training checkpoint with
a `tf.train.Optimizer` was saved but only the state required for
a `tf.compat.v1.train.Optimizer` was saved but only the state required
for
inference is being loaded. This method returns the status object, and
so may be chained with `initialize_or_restore` or `run_restore_ops`.
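A short graph-building sketch of the restore-or-initialize pattern described above (tf.compat.v1 endpoints; the checkpoint directory is hypothetical):

```python
import tensorflow as tf

with tf.Graph().as_default():
  v = tf.compat.v1.get_variable("v", shape=[], dtype=tf.float32)
  checkpoint = tf.compat.v1.train.Checkpoint(v=v)
  status = checkpoint.restore(tf.train.latest_checkpoint("/tmp/training_ckpts"))
  with tf.compat.v1.Session() as session:
    # Runs restore ops if a checkpoint was found, otherwise runs initializers.
    status.initialize_or_restore(session)
```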
@ -1488,7 +1502,7 @@ class Checkpoint(tracking.AutoTrackable):
"""Groups trackable objects, saving and restoring them.
`Checkpoint`'s constructor accepts keyword arguments whose values are types
that contain trackable state, such as `tf.train.Optimizer`
that contain trackable state, such as `tf.keras.optimizers.Optimizer`
implementations, `tf.Variable`, `tf.keras.Layer` implementations, or
`tf.keras.Model` implementations. It saves these values with a checkpoint, and
maintains a `save_counter` for numbering checkpoints.
@ -1511,7 +1525,8 @@ class Checkpoint(tracking.AutoTrackable):
```
`Checkpoint.save` and `Checkpoint.restore` write and read object-based
checkpoints, in contrast to TensorFlow 1.x's `tf.train.Saver` which writes and
checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which
writes and
reads `variable.name` based checkpoints. Object-based checkpointing saves a
graph of dependencies between Python objects (`Layer`s, `Optimizer`s,
`Variable`s, etc.) with named edges, and this graph is used to match variables
@ -1561,6 +1576,7 @@ class Checkpoint(tracking.AutoTrackable):
Args:
**kwargs: Keyword arguments are set as attributes of this object, and are
saved with the checkpoint. Values must be trackable objects.
Raises:
ValueError: If objects in `kwargs` are not trackable.
"""
@ -1586,8 +1602,12 @@ class Checkpoint(tracking.AutoTrackable):
# add_variable creates a dependency named "save_counter"; NoDependency
# prevents creating a second dependency named "_save_counter".
self._save_counter = data_structures.NoDependency(
add_variable(self, name="save_counter", initializer=0,
dtype=dtypes.int64, trainable=False))
add_variable(
self,
name="save_counter",
initializer=0,
dtype=dtypes.int64,
trainable=False))
def write(self, file_prefix):
"""Writes a training checkpoint.
@ -1608,8 +1628,7 @@ class Checkpoint(tracking.AutoTrackable):
Returns:
The full path to the checkpoint (i.e. `file_prefix`).
"""
output = self._saver.save(
file_prefix=file_prefix)
output = self._saver.save(file_prefix=file_prefix)
if tensor_util.is_tensor(output):
if context.executing_eagerly():
return compat.as_str(output.numpy())
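A brief eager-mode sketch contrasting `write` with `save` (paths hypothetical):

```python
import tensorflow as tf

ckpt = tf.train.Checkpoint(step=tf.Variable(0))
path = ckpt.write("/tmp/write_demo/ckpt")    # exact prefix, no save counter and
                                             # no checkpoint-state bookkeeping
ckpt.restore(path)

numbered = ckpt.save("/tmp/save_demo/ckpt")  # appends "-<save_counter>" and
                                             # updates the "checkpoint" state file
```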
@ -1711,7 +1730,8 @@ class Checkpoint(tracking.AutoTrackable):
were not found in the checkpoint, or if any checkpointed values do not have
a matching Python object.
Name-based `tf.train.Saver` checkpoints from TensorFlow 1.x can be loaded
Name-based `tf.compat.v1.train.Saver` checkpoints from TensorFlow 1.x can be
loaded
using this method. Names are used to match variables. Re-encode name-based
checkpoints using `tf.train.Checkpoint.save` as soon as possible.
@ -1719,9 +1739,9 @@ class Checkpoint(tracking.AutoTrackable):
save_path: The path to the checkpoint, as returned by `save` or
`tf.train.latest_checkpoint`. If None (as when there is no latest
checkpoint for `tf.train.latest_checkpoint` to return), returns an
object which may run initializers for objects in the dependency
graph. If the checkpoint was written by the name-based `tf.train.Saver`,
names are used to match variables.
object which may run initializers for objects in the dependency graph.
If the checkpoint was written by the name-based
`tf.compat.v1.train.Saver`, names are used to match variables.
Returns:
A load status object, which can be used to make assertions about the
@ -1744,7 +1764,8 @@ class Checkpoint(tracking.AutoTrackable):
built, and so has not created any variables, will pass this assertion
but fail `assert_consumed`. Useful when loading part of a larger
checkpoint into a new Python program, e.g. a training checkpoint with
a `tf.train.Optimizer` was saved but only the state required for
a `tf.compat.v1.train.Optimizer` was saved but only the state required
for
inference is being loaded. This method returns the status object, and
so may be chained with other assertions.
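A small eager-mode sketch of the restore-on-create behavior mentioned above (paths hypothetical):

```python
import tensorflow as tf

ckpt = tf.train.Checkpoint(v=tf.Variable(3.0))
path = ckpt.save("/tmp/deferred_demo/ckpt")

new_root = tf.train.Checkpoint()
status = new_root.restore(path)   # nothing matches yet
new_root.v = tf.Variable(0.0)     # value is restored as soon as it is created
print(new_root.v.numpy())         # -> 3.0
```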
@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for training."""
from __future__ import absolute_import
from __future__ import division
@ -34,7 +33,6 @@ from tensorflow.python.util.tf_export import tf_export
# collection keys.
GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache'
# TODO(drpng): remove this after legacy uses are resolved.
write_graph = graph_io.write_graph
@ -47,11 +45,12 @@ def global_step(sess, global_step_tensor):
# Create a variable to hold the global_step.
global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
# Create a session.
sess = tf.Session()
sess = tf.compat.v1.Session()
# Initialize the variable
sess.run(global_step_tensor.initializer)
# Get the variable value.
print('global_step: %s' % tf.train.global_step(sess, global_step_tensor))
print('global_step: %s' % tf.compat.v1.train.global_step(sess,
global_step_tensor))
global_step: 10
```
@ -109,8 +108,8 @@ def create_global_step(graph=None):
"""Create global step tensor in graph.
Args:
graph: The graph in which to create the global step tensor. If missing,
use default graph.
graph: The graph in which to create the global step tensor. If missing, use
default graph.
Returns:
Global step tensor.
@ -130,8 +129,9 @@ def create_global_step(graph=None):
initializer=init_ops.zeros_initializer(),
trainable=False,
aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
ops.GraphKeys.GLOBAL_STEP])
collections=[
ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP
])
# Create in proper graph and base name_scope.
with graph.as_default() as g, g.name_scope(None):
return variable_scope.get_variable(
@ -141,8 +141,7 @@ def create_global_step(graph=None):
initializer=init_ops.zeros_initializer(),
trainable=False,
aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
ops.GraphKeys.GLOBAL_STEP])
collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])
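For reference, a short graph-building sketch of the global-step helpers (tf.compat.v1 endpoints assumed):

```python
import tensorflow as tf

with tf.Graph().as_default():
  step = tf.compat.v1.train.get_or_create_global_step()
  increment_step = tf.compat.v1.assign_add(step, 1)
  with tf.compat.v1.Session() as sess:
    sess.run(step.initializer)
    sess.run(increment_step)
    print(tf.compat.v1.train.global_step(sess, step))  # -> 1
```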
@tf_export(v1=['train.get_or_create_global_step'])
@ -173,9 +172,8 @@ def assert_global_step(global_step_tensor):
if not (isinstance(global_step_tensor, variables.Variable) or
isinstance(global_step_tensor, ops.Tensor) or
resource_variable_ops.is_resource_variable(global_step_tensor)):
raise TypeError(
'Existing "global_step" must be a Variable or Tensor: %s.' %
global_step_tensor)
raise TypeError('Existing "global_step" must be a Variable or Tensor: %s.' %
global_step_tensor)
if not global_step_tensor.dtype.base_dtype.is_integer:
raise TypeError('Existing "global_step" does not have integer type: %s' %
@ -49,8 +49,8 @@ class VocabInfo(
VocabInfo to warm-start.
Attributes:
new_vocab: [Required] A path to the new vocabulary file (used with the
model to be trained).
new_vocab: [Required] A path to the new vocabulary file (used with the model
to be trained).
new_vocab_size: [Required] An integer indicating how many entries of the new
vocabulary will be used in training.
num_oov_buckets: [Required] An integer indicating how many OOV buckets are
@ -76,7 +76,7 @@ class VocabInfo(
num_oov_buckets=1,
old_vocab='pretrained_embeddings_vocab',
old_vocab_size=10000,
backup_initializer=tf.truncated_normal_initializer(
backup_initializer=tf.compat.v1.truncated_normal_initializer(
mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
axis=0)
@ -86,7 +86,7 @@ class VocabInfo(
num_oov_buckets=0, # No OOV for classes.
old_vocab='old_class_vocab',
old_vocab_size=8,
backup_initializer=tf.glorot_uniform_initializer(),
backup_initializer=tf.compat.v1.glorot_uniform_initializer(),
axis=1)
softmax_output_layer_bias_vocab_info = tf.VocabInfo(
@ -95,7 +95,7 @@ class VocabInfo(
num_oov_buckets=0, # No OOV for classes.
old_vocab='old_class_vocab',
old_vocab_size=8,
backup_initializer=tf.zeros_initializer(),
backup_initializer=tf.compat.v1.zeros_initializer(),
axis=0)
Currently, only axis=0 and axis=1 are supported.
@ -255,8 +255,7 @@ def _warm_start_var_with_vocab(var,
partition_info = None
if slice_info:
partition_info = variable_scope._PartitionInfo(
full_shape=slice_info.full_shape,
var_offset=slice_info.var_offset)
full_shape=slice_info.full_shape, var_offset=slice_info.var_offset)
if axis == 0:
new_row_vocab_size = current_vocab_size
@ -301,6 +300,8 @@ def _warm_start_var_with_vocab(var,
new_init_val = ops.convert_to_tensor(
init(shape=v_shape, partition_info=partition_info))
v._initializer_op = state_ops.assign(v, new_init_val)
# pylint: enable=protected-access
@ -314,7 +315,8 @@ def _get_grouped_variables(vars_to_warm_start):
vars_to_warm_start: One of the following:
- A regular expression (string) that captures which variables to
warm-start (see tf.get_collection). This expression will only consider
warm-start (see tf.compat.v1.get_collection). This expression will
only consider
variables in the TRAINABLE_VARIABLES collection.
- A list of Variables to warm-start.
- A list of strings, each representing a full variable name to warm-start.
@ -330,14 +332,13 @@ def _get_grouped_variables(vars_to_warm_start):
# Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
# everything (in TRAINABLE_VARIABLES) here.
list_of_vars = ops.get_collection(
ops.GraphKeys.TRAINABLE_VARIABLES,
scope=vars_to_warm_start)
ops.GraphKeys.TRAINABLE_VARIABLES, scope=vars_to_warm_start)
elif isinstance(vars_to_warm_start, list):
if all(isinstance(v, str) for v in vars_to_warm_start):
list_of_vars = []
for v in vars_to_warm_start:
list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
scope=v)
list_of_vars += ops.get_collection(
ops.GraphKeys.GLOBAL_VARIABLES, scope=v)
elif all(checkpoint_utils._is_variable(v) for v in vars_to_warm_start): # pylint: disable=protected-access
list_of_vars = vars_to_warm_start
else:
@ -377,16 +378,16 @@ def warm_start(ckpt_to_initialize_from,
vars_to_warm_start: [Optional] One of the following:
- A regular expression (string) that captures which variables to
warm-start (see tf.get_collection). This expression will only consider
variables in the TRAINABLE_VARIABLES collection -- if you need to
warm-start non_TRAINABLE vars (such as optimizer accumulators or batch
norm statistics), please use the below option.
warm-start (see tf.compat.v1.get_collection). This expression will only
consider variables in the TRAINABLE_VARIABLES collection -- if you need
to warm-start non_TRAINABLE vars (such as optimizer accumulators or
batch norm statistics), please use the below option.
- A list of Variables to warm-start. If you do not have access to the
`Variable` objects at the call site, please use the below option.
- A list of strings, each a regex scope provided to tf.get_collection with
GLOBAL_VARIABLES (please see tf.get_collection). For backwards
compatibility reasons, this is separate from the single-string argument
type.
- A list of strings, each a regex scope provided to
tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see
tf.compat.v1.get_collection). For backwards compatibility reasons,
this is separate from the single-string argument type.
- `None`, in which case only variables specified in
`var_name_to_vocab_info` will be warm-started.
@ -404,6 +405,7 @@ def warm_start(ckpt_to_initialize_from,
effect on the set of variables that is warm-started, and only controls
name mapping (use `vars_to_warm_start` for controlling what variables to
warm-start).
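A hedged sketch of typical `warm_start` usage (the checkpoint path and scope regex are hypothetical):

```python
import tensorflow as tf

# Before initializers run, re-use all trainable variables under the "dense"
# scope from a previously trained model; everything else keeps its normal
# initialization.
tf.compat.v1.train.warm_start(
    ckpt_to_initialize_from="/tmp/previous_model",
    vars_to_warm_start="dense.*")
```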
Raises:
ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
configuration for variable names that are not used. This is to ensure