Apply tf1->tf2 name replaces to doc-strings and comments in tensorflow.

No code changes, only doc-strings and comments.

PiperOrigin-RevId: 244275767

This commit is contained in:
parent 7fdf27b688
commit 18b680216e

tensorflow/python/training:
  basic_loops.py basic_session_run_hooks.py checkpoint_utils.py
  device_setter.py evaluation.py input.py learning_rate_decay.py
  monitored_session.py moving_averages.py queue_runner_impl.py saver.py
  server_lib.py server_lib_test.py session_manager_test.py
  session_run_hook.py summary_io.py supervisor.py sync_replicas_optimizer.py
  tracking/ training_util.py warm_starting_util.py

basic_loops.py
@@ -22,8 +22,11 @@ from tensorflow.python.util.tf_export import tf_export
 @tf_export(v1=["train.basic_train_loop"])
-def basic_train_loop(supervisor, train_step_fn, args=None,
-                     kwargs=None, master=""):
+def basic_train_loop(supervisor,
+                     train_step_fn,
+                     args=None,
+                     kwargs=None,
+                     master=""):
   """Basic loop to train a model.

   Calls `train_step_fn` in a loop to train a model. The function is called as:
@@ -32,17 +35,18 @@ def basic_train_loop(supervisor, train_step_fn, args=None,
   train_step_fn(session, *args, **kwargs)
   ```

-  It is passed a `tf.Session` in addition to `args` and `kwargs`. The function
+  It is passed a `tf.compat.v1.Session` in addition to `args` and `kwargs`. The
+  function
   typically runs one training step in the session.

   Args:
-    supervisor: `tf.train.Supervisor` to run the training services.
-    train_step_fn: Callable to execute one training step. Called
-      repeatedly as `train_step_fn(session, *args **kwargs)`.
+    supervisor: `tf.compat.v1.train.Supervisor` to run the training services.
+    train_step_fn: Callable to execute one training step. Called repeatedly as
+      `train_step_fn(session, *args **kwargs)`.
     args: Optional positional arguments passed to `train_step_fn`.
     kwargs: Optional keyword arguments passed to `train_step_fn`.
-    master: Master to use to create the training session. Defaults to
-      `""` which causes the session to be created in the local process.
+    master: Master to use to create the training session. Defaults to `""`
+      which causes the session to be created in the local process.
   """
   if args is None:
     args = []
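For readers skimming the diff, the renamed API above is easiest to see in runnable form. The following is a minimal, hypothetical sketch (not taken from this change; the log directory, variable, and stopping condition are illustrative) of driving `basic_train_loop` through the `tf.compat.v1` namespace:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # the Supervisor API is graph-mode only

step = tf.Variable(0, name="step")   # illustrative state so the Supervisor has something to save
increment = tf.assign_add(step, 1)

def train_step_fn(session, stop_after):
  # One training step; `session` is the tf.compat.v1.Session created by the
  # Supervisor, as the updated docstring describes.
  if session.run(increment) >= stop_after:
    supervisor.request_stop()

supervisor = tf.train.Supervisor(logdir="/tmp/train_logs")  # illustrative path
tf.train.basic_train_loop(supervisor, train_step_fn, args=[10])
```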
basic_session_run_hooks.py
@@ -42,7 +42,6 @@ from tensorflow.python.training.session_run_hook import SessionRunArgs
 from tensorflow.python.training.summary_io import SummaryWriterCache
 from tensorflow.python.util.tf_export import tf_export

-
 _HOOKS = "hooks"
 _STEPS_PER_RUN_VAR = "steps_per_run"

@@ -85,8 +84,7 @@ class _HookTimer(object):

 @tf_export(v1=["train.SecondOrStepTimer"])
 class SecondOrStepTimer(_HookTimer):
-  """Timer that triggers at most once every N seconds or once every N steps.
-  """
+  """Timer that triggers at most once every N seconds or once every N steps."""

   def __init__(self, every_secs=None, every_steps=None):
     self.reset()
@@ -171,29 +169,33 @@ class LoggingTensorHook(session_run_hook.SessionRunHook):
   seeing the logs, you might want to add the following line after your imports:

   ```python
-    tf.logging.set_verbosity(tf.logging.INFO)
+    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
   ```

   Note that if `at_end` is True, `tensors` should not include any tensor
   whose evaluation produces a side effect such as consuming additional inputs.
   """
|
||||
|
||||
def __init__(self, tensors, every_n_iter=None, every_n_secs=None,
|
||||
at_end=False, formatter=None):
|
||||
def __init__(self,
|
||||
tensors,
|
||||
every_n_iter=None,
|
||||
every_n_secs=None,
|
||||
at_end=False,
|
||||
formatter=None):
|
||||
"""Initializes a `LoggingTensorHook`.
|
||||
|
||||
Args:
|
||||
tensors: `dict` that maps string-valued tags to tensors/tensor names,
|
||||
or `iterable` of tensors/tensor names.
|
||||
tensors: `dict` that maps string-valued tags to tensors/tensor names, or
|
||||
`iterable` of tensors/tensor names.
|
||||
every_n_iter: `int`, print the values of `tensors` once every N local
|
||||
steps taken on the current worker.
|
||||
steps taken on the current worker.
|
||||
every_n_secs: `int` or `float`, print the values of `tensors` once every N
|
||||
seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
|
||||
provided.
|
||||
seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
|
||||
provided.
|
||||
at_end: `bool` specifying whether to print the values of `tensors` at the
|
||||
end of the run.
|
||||
end of the run.
|
||||
formatter: function, takes dict of `tag`->`Tensor` and returns a string.
|
||||
If `None` uses default printing all tensors.
|
||||
If `None` uses default printing all tensors.
|
||||
|
||||
Raises:
|
||||
ValueError: if `every_n_iter` is non-positive.
|
||||
@ -215,16 +217,18 @@ class LoggingTensorHook(session_run_hook.SessionRunHook):
|
||||
self._tensors = tensors
|
||||
self._formatter = formatter
|
||||
self._timer = (
|
||||
NeverTriggerTimer() if only_log_at_end else
|
||||
SecondOrStepTimer(every_secs=every_n_secs, every_steps=every_n_iter))
|
||||
NeverTriggerTimer() if only_log_at_end else SecondOrStepTimer(
|
||||
every_secs=every_n_secs, every_steps=every_n_iter))
|
||||
self._log_at_end = at_end
|
||||
|
||||
def begin(self):
|
||||
self._timer.reset()
|
||||
self._iter_count = 0
|
||||
# Convert names to tensors if given
|
||||
self._current_tensors = {tag: _as_graph_element(tensor)
|
||||
for (tag, tensor) in self._tensors.items()}
|
||||
self._current_tensors = {
|
||||
tag: _as_graph_element(tensor)
|
||||
for (tag, tensor) in self._tensors.items()
|
||||
}
|
||||
|
||||
def before_run(self, run_context): # pylint: disable=unused-argument
|
||||
self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
|
||||
@@ -463,9 +467,10 @@ class CheckpointSaverListener(object):

   ...
   listener = ExampleCheckpointSaverListener()
-  saver_hook = tf.train.CheckpointSaverHook(
+  saver_hook = tf.estimator.CheckpointSaverHook(
       checkpoint_dir, listeners=[listener])
-  with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
+  with
+  tf.compat.v1.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
     ...
   ```

@ -516,9 +521,9 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
|
||||
saver: `Saver` object, used for saving.
|
||||
checkpoint_basename: `str`, base name for the checkpoint files.
|
||||
scaffold: `Scaffold`, use to get saver object.
|
||||
listeners: List of `CheckpointSaverListener` subclass instances.
|
||||
Used for callbacks that run immediately before or after this hook saves
|
||||
the checkpoint.
|
||||
listeners: List of `CheckpointSaverListener` subclass instances. Used for
|
||||
callbacks that run immediately before or after this hook saves the
|
||||
checkpoint.
|
||||
|
||||
Raises:
|
||||
ValueError: One of `save_steps` or `save_secs` should be set.
|
||||
@ -531,8 +536,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
|
||||
self._checkpoint_dir = checkpoint_dir
|
||||
self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
|
||||
self._scaffold = scaffold
|
||||
self._timer = SecondOrStepTimer(every_secs=save_secs,
|
||||
every_steps=save_steps)
|
||||
self._timer = SecondOrStepTimer(
|
||||
every_secs=save_secs, every_steps=save_steps)
|
||||
self._listeners = listeners or []
|
||||
self._steps_per_run = 1
|
||||
|
||||
@ -555,13 +560,11 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
|
||||
# add variables in begin. Graph is finalized after all begin calls.
|
||||
training_util.write_graph(
|
||||
ops.get_default_graph().as_graph_def(add_shapes=True),
|
||||
self._checkpoint_dir,
|
||||
"graph.pbtxt")
|
||||
self._checkpoint_dir, "graph.pbtxt")
|
||||
saver_def = self._get_saver().saver_def if self._get_saver() else None
|
||||
graph = ops.get_default_graph()
|
||||
meta_graph_def = meta_graph.create_meta_graph_def(
|
||||
graph_def=graph.as_graph_def(add_shapes=True),
|
||||
saver_def=saver_def)
|
||||
graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
|
||||
self._summary_writer.add_graph(graph)
|
||||
self._summary_writer.add_meta_graph(meta_graph_def)
|
||||
# The checkpoint saved here is the state at step "global_step".
|
||||
@ -573,8 +576,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
|
||||
|
||||
def after_run(self, run_context, run_values):
|
||||
stale_global_step = run_values.results
|
||||
if self._timer.should_trigger_for_step(
|
||||
stale_global_step + self._steps_per_run):
|
||||
if self._timer.should_trigger_for_step(stale_global_step +
|
||||
self._steps_per_run):
|
||||
# get the real value after train op.
|
||||
global_step = run_context.session.run(self._global_step_tensor)
|
||||
if self._timer.should_trigger_for_step(global_step):
|
||||
@ -627,8 +630,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
|
||||
elif len(savers) > 1:
|
||||
raise RuntimeError(
|
||||
"More than one item in collection {}. "
|
||||
"Please indicate which one to use by passing it to the constructor.".
|
||||
format(collection_key))
|
||||
"Please indicate which one to use by passing it to the constructor."
|
||||
.format(collection_key))
|
||||
|
||||
self._saver = savers[0]
|
||||
return savers[0]
|
||||
@ -647,8 +650,8 @@ class StepCounterHook(session_run_hook.SessionRunHook):
|
||||
if (every_n_steps is None) == (every_n_secs is None):
|
||||
raise ValueError(
|
||||
"exactly one of every_n_steps and every_n_secs should be provided.")
|
||||
self._timer = SecondOrStepTimer(every_steps=every_n_steps,
|
||||
every_secs=every_n_secs)
|
||||
self._timer = SecondOrStepTimer(
|
||||
every_steps=every_n_steps, every_secs=every_n_secs)
|
||||
|
||||
self._summary_writer = summary_writer
|
||||
self._output_dir = output_dir
|
||||
@ -673,8 +676,9 @@ class StepCounterHook(session_run_hook.SessionRunHook):
|
||||
def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
|
||||
steps_per_sec = elapsed_steps / elapsed_time
|
||||
if self._summary_writer is not None:
|
||||
summary = Summary(value=[Summary.Value(
|
||||
tag=self._summary_tag, simple_value=steps_per_sec)])
|
||||
summary = Summary(value=[
|
||||
Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
|
||||
])
|
||||
self._summary_writer.add_summary(summary, global_step)
|
||||
logging.info("%s: %g", self._summary_tag, steps_per_sec)
|
||||
|
||||
@ -682,8 +686,8 @@ class StepCounterHook(session_run_hook.SessionRunHook):
|
||||
_ = run_context
|
||||
|
||||
stale_global_step = run_values.results
|
||||
if self._timer.should_trigger_for_step(
|
||||
stale_global_step + self._steps_per_run):
|
||||
if self._timer.should_trigger_for_step(stale_global_step +
|
||||
self._steps_per_run):
|
||||
# get the real value after train op.
|
||||
global_step = run_context.session.run(self._global_step_tensor)
|
||||
if self._timer.should_trigger_for_step(global_step):
|
||||
@ -767,18 +771,18 @@ class SummarySaverHook(session_run_hook.SessionRunHook):
|
||||
|
||||
Args:
|
||||
save_steps: `int`, save summaries every N steps. Exactly one of
|
||||
`save_secs` and `save_steps` should be set.
|
||||
`save_secs` and `save_steps` should be set.
|
||||
save_secs: `int`, save summaries every N seconds.
|
||||
output_dir: `string`, the directory to save the summaries to. Only used
|
||||
if no `summary_writer` is supplied.
|
||||
output_dir: `string`, the directory to save the summaries to. Only used if
|
||||
no `summary_writer` is supplied.
|
||||
summary_writer: `SummaryWriter`. If `None` and an `output_dir` was passed,
|
||||
one will be created accordingly.
|
||||
one will be created accordingly.
|
||||
scaffold: `Scaffold` to get summary_op if it's not provided.
|
||||
summary_op: `Tensor` of type `string` containing the serialized `Summary`
|
||||
protocol buffer or a list of `Tensor`. They are most likely an output
|
||||
by TF summary methods like `tf.summary.scalar` or
|
||||
`tf.summary.merge_all`. It can be passed in as one tensor; if more
|
||||
than one, they must be passed in as a list.
|
||||
protocol buffer or a list of `Tensor`. They are most likely an output by
|
||||
TF summary methods like `tf.compat.v1.summary.scalar` or
|
||||
`tf.compat.v1.summary.merge_all`. It can be passed in as one tensor; if
|
||||
more than one, they must be passed in as a list.
|
||||
|
||||
Raises:
|
||||
ValueError: Exactly one of scaffold or summary_op should be set.
|
||||
@ -791,8 +795,8 @@ class SummarySaverHook(session_run_hook.SessionRunHook):
|
||||
self._summary_writer = summary_writer
|
||||
self._output_dir = output_dir
|
||||
self._scaffold = scaffold
|
||||
self._timer = SecondOrStepTimer(every_secs=save_secs,
|
||||
every_steps=save_steps)
|
||||
self._timer = SecondOrStepTimer(
|
||||
every_secs=save_secs, every_steps=save_steps)
|
||||
# TODO(mdan): Throw an error if output_dir and summary_writer are None.
|
||||
|
||||
def begin(self):
|
||||
@ -903,8 +907,9 @@ class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
|
||||
self._worker_is_started = True
|
||||
return None
|
||||
if current_step - last_logged_step > 1000:
|
||||
logging.info("Waiting for global step %d before starting training. "
|
||||
"Current step is %d.", self._wait_until_step, current_step)
|
||||
logging.info(
|
||||
"Waiting for global step %d before starting training. "
|
||||
"Current step is %d.", self._wait_until_step, current_step)
|
||||
last_logged_step = current_step
|
||||
time.sleep(0.5)
|
||||
|
||||
@ -917,8 +922,8 @@ class FinalOpsHook(session_run_hook.SessionRunHook):
|
||||
"""Initializes `FinalOpHook` with ops to run at the end of the session.
|
||||
|
||||
Args:
|
||||
final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of
|
||||
names to `Tensors`.
|
||||
final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
|
||||
to `Tensors`.
|
||||
final_ops_feed_dict: A feed dictionary to use when running
|
||||
`final_ops_dict`.
|
||||
"""
|
||||
@ -997,14 +1002,14 @@ class ProfilerHook(session_run_hook.SessionRunHook):
|
||||
|
||||
Args:
|
||||
save_steps: `int`, save profile traces every N steps. Exactly one of
|
||||
`save_secs` and `save_steps` should be set.
|
||||
`save_secs` and `save_steps` should be set.
|
||||
save_secs: `int` or `float`, save profile traces every N seconds.
|
||||
output_dir: `string`, the directory to save the profile traces to.
|
||||
Defaults to the current directory.
|
||||
Defaults to the current directory.
|
||||
show_dataflow: `bool`, if True, add flow events to the trace connecting
|
||||
producers and consumers of tensors.
|
||||
producers and consumers of tensors.
|
||||
show_memory: `bool`, if True, add object snapshot events to the trace
|
||||
showing the sizes and lifetimes of tensors.
|
||||
showing the sizes and lifetimes of tensors.
|
||||
"""
|
||||
self._output_file = os.path.join(output_dir, "timeline-{}.json")
|
||||
self._file_writer = SummaryWriterCache.get(output_dir)
|
||||
@ -1024,8 +1029,9 @@ class ProfilerHook(session_run_hook.SessionRunHook):
|
||||
self._next_step is not None and
|
||||
self._timer.should_trigger_for_step(self._next_step))
|
||||
requests = {"global_step": self._global_step_tensor}
|
||||
opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
|
||||
if self._request_summary else None)
|
||||
opts = (
|
||||
config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
|
||||
if self._request_summary else None)
|
||||
|
||||
return SessionRunArgs(requests, options=opts)
|
||||
|
||||
@ -1039,8 +1045,7 @@ class ProfilerHook(session_run_hook.SessionRunHook):
|
||||
if self._request_summary:
|
||||
global_step = run_context.session.run(self._global_step_tensor)
|
||||
self._timer.update_last_triggered_step(global_step)
|
||||
self._save(global_step,
|
||||
self._output_file.format(global_step),
|
||||
self._save(global_step, self._output_file.format(global_step),
|
||||
run_values.run_metadata.step_stats)
|
||||
self._file_writer.add_run_metadata(run_values.run_metadata,
|
||||
"step_%d" % global_step)
|
||||
checkpoint_utils.py
@@ -106,7 +106,7 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
   """Replaces `tf.Variable` initializers so they load from a checkpoint file.

   Values are not loaded immediately, but when the initializer is run
-  (typically by running a `tf.global_variables_initializer` op).
+  (typically by running a `tf.compat.v1.global_variables_initializer` op).

   Note: This overrides default initialization ops of specified variables and
   redefines dtype.
@@ -139,15 +139,15 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
   # -- name='old_scope_2/var3', shape=[100, 100]

   # Create new model's variables
-  with tf.variable_scope('new_scope_1'):
-    var1 = tf.get_variable('var1', shape=[20, 2],
-                           initializer=tf.zeros_initializer())
-  with tf.variable_scope('new_scope_2'):
-    var2 = tf.get_variable('var2', shape=[50, 4],
-                           initializer=tf.zeros_initializer())
+  with tf.compat.v1.variable_scope('new_scope_1'):
+    var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
+                           initializer=tf.compat.v1.zeros_initializer())
+  with tf.compat.v1.variable_scope('new_scope_2'):
+    var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
+                           initializer=tf.compat.v1.zeros_initializer())
     # Partition into 5 variables along the first axis.
-    var3 = tf.get_variable(name='var3', shape=[100, 100],
-                           initializer=tf.zeros_initializer(),
+    var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
+                           initializer=tf.compat.v1.zeros_initializer(),
                            partitioner=lambda shape, dtype: [5, 1])

   # Initialize all variables in `new_scope_1` from `old_scope_1`.
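As orientation for the renamed example above, here is a small, self-contained sketch (not part of the diff; the checkpoint path and variable names are illustrative) of how `init_from_checkpoint` is typically called through the `tf.compat.v1` namespace:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph-mode variables, as in the docstring example

with tf.variable_scope('new_scope_1'):
  var1 = tf.get_variable('var1', shape=[20, 2],
                         initializer=tf.zeros_initializer())

# Map a variable name stored in the checkpoint to the freshly created variable.
tf.train.init_from_checkpoint('/tmp/model_dir',               # illustrative path
                              {'old_scope_1/var1': var1})

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())  # loads var1 from the checkpoint
```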
device_setter.py
@@ -131,9 +131,13 @@ class _ReplicaDeviceChooser(object):


 @tf_export(v1=["train.replica_device_setter"])
-def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
-                          worker_device="/job:worker", merge_devices=True,
-                          cluster=None, ps_ops=None, ps_strategy=None):
+def replica_device_setter(ps_tasks=0,
+                          ps_device="/job:ps",
+                          worker_device="/job:worker",
+                          merge_devices=True,
+                          cluster=None,
+                          ps_ops=None,
+                          ps_strategy=None):
   """Return a `device function` to use when building a Graph for replicas.

   Device Functions are used in `with tf.device(device_function):` statement to
@@ -158,7 +162,8 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
   cluster_spec = {
       "ps": ["ps0:2222", "ps1:2222"],
       "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
-  with tf.device(tf.train.replica_device_setter(cluster=cluster_spec)):
+  with
+  tf.device(tf.compat.v1.train.replica_device_setter(cluster=cluster_spec)):
     # Build your graph
     v1 = tf.Variable(...)  # assigned to /job:ps/task:0
     v2 = tf.Variable(...)  # assigned to /job:ps/task:1
@@ -218,6 +223,6 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
     ps_strategy = _RoundRobinStrategy(ps_tasks)
   if not six.callable(ps_strategy):
     raise TypeError("ps_strategy must be callable")
-  chooser = _ReplicaDeviceChooser(
-      ps_tasks, ps_device, worker_device, merge_devices, ps_ops, ps_strategy)
+  chooser = _ReplicaDeviceChooser(ps_tasks, ps_device, worker_device,
+                                  merge_devices, ps_ops, ps_strategy)
   return chooser.device_function

evaluation.py
@ -65,8 +65,8 @@ def _get_latest_eval_step_value(update_ops):
|
||||
"""Gets the eval step `Tensor` value after running `update_ops`.
|
||||
|
||||
Args:
|
||||
update_ops: A list of `Tensors` or a dictionary of names to `Tensors`,
|
||||
which are run before reading the eval step value.
|
||||
update_ops: A list of `Tensors` or a dictionary of names to `Tensors`, which
|
||||
are run before reading the eval step value.
|
||||
|
||||
Returns:
|
||||
A `Tensor` representing the value for the evaluation step.
|
||||
@ -102,21 +102,20 @@ class _MultiStepStopAfterNEvalsHook(session_run_hook.SessionRunHook):
|
||||
|
||||
def after_create_session(self, session, coord):
|
||||
# Update number of steps to run in the first run call
|
||||
if self._num_evals is None:
|
||||
if self._num_evals is None:
|
||||
steps = self._steps_per_run_initial_value
|
||||
else:
|
||||
steps = min(self._steps_per_run_initial_value, self._num_evals)
|
||||
self._steps_per_run_variable.load(steps, session=session)
|
||||
|
||||
def before_run(self, run_context):
|
||||
return session_run_hook.SessionRunArgs({
|
||||
'evals_completed': self._evals_completed
|
||||
})
|
||||
return session_run_hook.SessionRunArgs(
|
||||
{'evals_completed': self._evals_completed})
|
||||
|
||||
def after_run(self, run_context, run_values):
|
||||
evals_completed = run_values.results['evals_completed']
|
||||
# Update number of steps to run in the next iteration
|
||||
if self._num_evals is None:
|
||||
if self._num_evals is None:
|
||||
steps = self._steps_per_run_initial_value
|
||||
else:
|
||||
steps = min(self._num_evals - evals_completed,
|
||||
@ -147,16 +146,15 @@ class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
|
||||
self._evals_completed = None
|
||||
self._log_progress = log_progress
|
||||
# Reduce logging frequency if there are 20 or more evaluations.
|
||||
self._log_frequency = (1 if (num_evals is None or num_evals < 20)
|
||||
else math.floor(num_evals / 10.))
|
||||
self._log_frequency = (1 if (num_evals is None or num_evals < 20) else
|
||||
math.floor(num_evals / 10.))
|
||||
|
||||
def _set_evals_completed_tensor(self, updated_eval_step):
|
||||
self._evals_completed = updated_eval_step
|
||||
|
||||
def before_run(self, run_context):
|
||||
return session_run_hook.SessionRunArgs({
|
||||
'evals_completed': self._evals_completed
|
||||
})
|
||||
return session_run_hook.SessionRunArgs(
|
||||
{'evals_completed': self._evals_completed})
|
||||
|
||||
def after_run(self, run_context, run_values):
|
||||
evals_completed = run_values.results['evals_completed']
|
||||
@ -205,20 +203,20 @@ def _evaluate_once(checkpoint_path,
|
||||
Args:
|
||||
checkpoint_path: The path to a checkpoint to use for evaluation.
|
||||
master: The BNS address of the TensorFlow master.
|
||||
scaffold: An tf.train.Scaffold instance for initializing variables and
|
||||
restoring variables. Note that `scaffold.init_fn` is used by the function
|
||||
to restore the checkpoint. If you supply a custom init_fn, then it must
|
||||
also take care of restoring the model from its checkpoint.
|
||||
eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
|
||||
to `Tensors`, which is run until the session is requested to stop,
|
||||
commonly done by a `tf.contrib.training.StopAfterNEvalsHook`.
|
||||
scaffold: An tf.compat.v1.train.Scaffold instance for initializing variables
|
||||
and restoring variables. Note that `scaffold.init_fn` is used by the
|
||||
function to restore the checkpoint. If you supply a custom init_fn, then
|
||||
it must also take care of restoring the model from its checkpoint.
|
||||
eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names to
|
||||
`Tensors`, which is run until the session is requested to stop, commonly
|
||||
done by a `tf.contrib.training.StopAfterNEvalsHook`.
|
||||
feed_dict: The feed dictionary to use when executing the `eval_ops`.
|
||||
final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
|
||||
to `Tensors`.
|
||||
final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
|
||||
hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
|
||||
evaluation loop.
|
||||
config: An instance of `tf.ConfigProto` that will be used to
|
||||
hooks: List of `tf.estimator.SessionRunHook` callbacks which are run inside
|
||||
the evaluation loop.
|
||||
config: An instance of `tf.compat.v1.ConfigProto` that will be used to
|
||||
configure the `Session`. If left as `None`, the default will be used.
|
||||
|
||||
Returns:
|
||||
@ -263,8 +261,8 @@ def _evaluate_once(checkpoint_path,
|
||||
master=master,
|
||||
config=config)
|
||||
|
||||
final_ops_hook = basic_session_run_hooks.FinalOpsHook(
|
||||
final_ops, final_ops_feed_dict)
|
||||
final_ops_hook = basic_session_run_hooks.FinalOpsHook(final_ops,
|
||||
final_ops_feed_dict)
|
||||
hooks.append(final_ops_hook)
|
||||
|
||||
with monitored_session.MonitoredSession(
|
||||
input.py
@@ -1090,7 +1090,7 @@ def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False,

   The `tensors_list` argument is a list of tuples of tensors, or a list of
   dictionaries of tensors. Each element in the list is treated similarly
-  to the `tensors` argument of `tf.train.batch()`.
+  to the `tensors` argument of `tf.compat.v1.train.batch()`.

   WARNING: This function is nondeterministic, since it starts a separate thread
   for each tensor.
@@ -1284,7 +1284,7 @@ def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,

   ```python
   # Creates batches of 32 images and 32 labels.
-  image_batch, label_batch = tf.train.shuffle_batch(
+  image_batch, label_batch = tf.compat.v1.train.shuffle_batch(
         [single_image, single_label],
         batch_size=32,
         num_threads=4,
@@ -1425,7 +1425,7 @@ def shuffle_batch_join(tensors_list, batch_size, capacity,

   The `tensors_list` argument is a list of tuples of tensors, or a list of
   dictionaries of tensors. Each element in the list is treated similarly
-  to the `tensors` argument of `tf.train.shuffle_batch()`.
+  to the `tensors` argument of `tf.compat.v1.train.shuffle_batch()`.

   This version enqueues a different list of tensors in different threads.
   It adds the following to the current `Graph`:
learning_rate_decay.py
@@ -56,24 +56,25 @@ def exponential_decay(learning_rate,
   ...
   global_step = tf.Variable(0, trainable=False)
   starter_learning_rate = 0.1
-  learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
+  learning_rate = tf.compat.v1.train.exponential_decay(starter_learning_rate,
+  global_step,
                                              100000, 0.96, staircase=True)
   # Passing global_step to minimize() will increment it at each step.
   learning_step = (
-      tf.train.GradientDescentOptimizer(learning_rate)
+      tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
       .minimize(...my loss..., global_step=global_step)
   )
   ```

   Args:
-    learning_rate: A scalar `float32` or `float64` `Tensor` or a
-      Python number. The initial learning rate.
-    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
-      Global step to use for the decay computation. Must not be negative.
-    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
-      Must be positive. See the decay computation above.
-    decay_rate: A scalar `float32` or `float64` `Tensor` or a
-      Python number. The decay rate.
+    learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
+      The initial learning rate.
+    global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
+      step to use for the decay computation. Must not be negative.
+    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Must
+      be positive. See the decay computation above.
+    decay_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
+      The decay rate.
     staircase: Boolean. If `True` decay the learning rate at discrete intervals
     name: String. Optional name of the operation. Defaults to
       'ExponentialDecay'.
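To make the renamed snippet above easy to try end to end, here is a small hypothetical consolidation of the docstring example under `tf.compat.v1` (the loss and optimizer wiring are illustrative, not part of the diff):

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

global_step = tf.Variable(0, trainable=False)
starter_learning_rate = 0.1
learning_rate = tf.train.exponential_decay(      # tf is already the v1 namespace here
    starter_learning_rate, global_step, 100000, 0.96, staircase=True)

# Passing global_step to minimize() will increment it at each step.
loss = tf.square(tf.get_variable("w", initializer=1.0))  # illustrative loss
learning_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
    loss, global_step=global_step)
```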
@ -91,11 +92,8 @@ def exponential_decay(learning_rate,
|
||||
the learning rate value across different invocations of optimizer functions.
|
||||
@end_compatibility
|
||||
"""
|
||||
decayed_lr = learning_rate_schedule.ExponentialDecay(learning_rate,
|
||||
decay_steps,
|
||||
decay_rate,
|
||||
staircase=staircase,
|
||||
name=name)
|
||||
decayed_lr = learning_rate_schedule.ExponentialDecay(
|
||||
learning_rate, decay_steps, decay_rate, staircase=staircase, name=name)
|
||||
if not context.executing_eagerly():
|
||||
decayed_lr = decayed_lr(global_step)
|
||||
else:
|
||||
@ -114,7 +112,8 @@ def piecewise_constant(x, boundaries, values, name=None):
|
||||
global_step = tf.Variable(0, trainable=False)
|
||||
boundaries = [100000, 110000]
|
||||
values = [1.0, 0.5, 0.1]
|
||||
learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
|
||||
learning_rate = tf.compat.v1.train.piecewise_constant(global_step, boundaries,
|
||||
values)
|
||||
|
||||
# Later, whenever we perform an optimization step, we increment global_step.
|
||||
```
|
||||
@ -202,27 +201,28 @@ def polynomial_decay(learning_rate,
|
||||
starter_learning_rate = 0.1
|
||||
end_learning_rate = 0.01
|
||||
decay_steps = 10000
|
||||
learning_rate = tf.train.polynomial_decay(starter_learning_rate, global_step,
|
||||
learning_rate = tf.compat.v1.train.polynomial_decay(starter_learning_rate,
|
||||
global_step,
|
||||
decay_steps, end_learning_rate,
|
||||
power=0.5)
|
||||
# Passing global_step to minimize() will increment it at each step.
|
||||
learning_step = (
|
||||
tf.train.GradientDescentOptimizer(learning_rate)
|
||||
tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
|
||||
.minimize(...my loss..., global_step=global_step)
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
learning_rate: A scalar `float32` or `float64` `Tensor` or a
|
||||
Python number. The initial learning rate.
|
||||
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
|
||||
Global step to use for the decay computation. Must not be negative.
|
||||
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
|
||||
Must be positive. See the decay computation above.
|
||||
end_learning_rate: A scalar `float32` or `float64` `Tensor` or a
|
||||
Python number. The minimal end learning rate.
|
||||
power: A scalar `float32` or `float64` `Tensor` or a
|
||||
Python number. The power of the polynomial. Defaults to linear, 1.0.
|
||||
learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
|
||||
The initial learning rate.
|
||||
global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
|
||||
step to use for the decay computation. Must not be negative.
|
||||
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Must
|
||||
be positive. See the decay computation above.
|
||||
end_learning_rate: A scalar `float32` or `float64` `Tensor` or a Python
|
||||
number. The minimal end learning rate.
|
||||
power: A scalar `float32` or `float64` `Tensor` or a Python number. The
|
||||
power of the polynomial. Defaults to linear, 1.0.
|
||||
cycle: A boolean, whether or not it should cycle beyond decay_steps.
|
||||
name: String. Optional name of the operation. Defaults to
|
||||
'PolynomialDecay'.
|
||||
@ -292,21 +292,22 @@ def natural_exp_decay(learning_rate,
|
||||
learning_rate = 0.1
|
||||
decay_steps = 5
|
||||
k = 0.5
|
||||
learning_rate = tf.train.natural_exp_decay(learning_rate, global_step,
|
||||
learning_rate = tf.compat.v1.train.natural_exp_decay(learning_rate,
|
||||
global_step,
|
||||
decay_steps, k)
|
||||
|
||||
# Passing global_step to minimize() will increment it at each step.
|
||||
learning_step = (
|
||||
tf.train.GradientDescentOptimizer(learning_rate)
|
||||
tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
|
||||
.minimize(...my loss..., global_step=global_step)
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
learning_rate: A scalar `float32` or `float64` `Tensor` or a
|
||||
Python number. The initial learning rate.
|
||||
global_step: A Python number.
|
||||
Global step to use for the decay computation. Must not be negative.
|
||||
learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
|
||||
The initial learning rate.
|
||||
global_step: A Python number. Global step to use for the decay computation.
|
||||
Must not be negative.
|
||||
decay_steps: How often to apply decay.
|
||||
decay_rate: A Python number. The decay rate.
|
||||
staircase: Whether to apply decay in a discrete staircase, as opposed to
|
||||
@ -329,7 +330,10 @@ def natural_exp_decay(learning_rate,
|
||||
"""
|
||||
natural_exp_rate = math_ops.exp(math_ops.negative(decay_rate))
|
||||
decayed_lr = learning_rate_schedule.ExponentialDecay(
|
||||
learning_rate, decay_steps, natural_exp_rate, staircase=staircase,
|
||||
learning_rate,
|
||||
decay_steps,
|
||||
natural_exp_rate,
|
||||
staircase=staircase,
|
||||
name=name)
|
||||
|
||||
if not context.executing_eagerly():
|
||||
@ -376,21 +380,22 @@ def inverse_time_decay(learning_rate,
|
||||
learning_rate = 0.1
|
||||
decay_steps = 1.0
|
||||
decay_rate = 0.5
|
||||
learning_rate = tf.train.inverse_time_decay(learning_rate, global_step,
|
||||
learning_rate = tf.compat.v1.train.inverse_time_decay(learning_rate,
|
||||
global_step,
|
||||
decay_steps, decay_rate)
|
||||
|
||||
# Passing global_step to minimize() will increment it at each step.
|
||||
learning_step = (
|
||||
tf.train.GradientDescentOptimizer(learning_rate)
|
||||
tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
|
||||
.minimize(...my loss..., global_step=global_step)
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
learning_rate: A scalar `float32` or `float64` `Tensor` or a
|
||||
Python number. The initial learning rate.
|
||||
global_step: A Python number.
|
||||
Global step to use for the decay computation. Must not be negative.
|
||||
learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
|
||||
The initial learning rate.
|
||||
global_step: A Python number. Global step to use for the decay computation.
|
||||
Must not be negative.
|
||||
decay_steps: How often to apply decay.
|
||||
decay_rate: A Python number. The decay rate.
|
||||
staircase: Whether to apply decay in a discrete staircase, as opposed to
|
||||
@ -412,11 +417,7 @@ def inverse_time_decay(learning_rate,
|
||||
@end_compatibility
|
||||
"""
|
||||
decayed_lr = learning_rate_schedule.InverseTimeDecay(
|
||||
learning_rate,
|
||||
decay_steps,
|
||||
decay_rate,
|
||||
staircase=staircase,
|
||||
name=name)
|
||||
learning_rate, decay_steps, decay_rate, staircase=staircase, name=name)
|
||||
|
||||
if not context.executing_eagerly():
|
||||
decayed_lr = decayed_lr(global_step)
|
||||
@ -455,13 +456,14 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
|
||||
Args:
|
||||
learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
|
||||
The initial learning rate.
|
||||
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
|
||||
Global step to use for the decay computation.
|
||||
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
|
||||
Number of steps to decay over.
|
||||
alpha: A scalar `float32` or `float64` Tensor or a Python number.
|
||||
Minimum learning rate value as a fraction of learning_rate.
|
||||
global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
|
||||
step to use for the decay computation.
|
||||
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number
|
||||
of steps to decay over.
|
||||
alpha: A scalar `float32` or `float64` Tensor or a Python number. Minimum
|
||||
learning rate value as a fraction of learning_rate.
|
||||
name: String. Optional name of the operation. Defaults to 'CosineDecay'.
|
||||
|
||||
Returns:
|
||||
A scalar `Tensor` of the same type as `learning_rate`. The decayed
|
||||
learning rate.
|
||||
@ -519,17 +521,18 @@ def cosine_decay_restarts(learning_rate,
|
||||
Args:
|
||||
learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
|
||||
The initial learning rate.
|
||||
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
|
||||
Global step to use for the decay computation.
|
||||
global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
|
||||
step to use for the decay computation.
|
||||
first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
|
||||
Number of steps to decay over.
|
||||
t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
|
||||
Used to derive the number of iterations in the i-th period
|
||||
t_mul: A scalar `float32` or `float64` `Tensor` or a Python number. Used to
|
||||
derive the number of iterations in the i-th period
|
||||
m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
|
||||
Used to derive the initial learning rate of the i-th period:
|
||||
alpha: A scalar `float32` or `float64` Tensor or a Python number.
|
||||
Minimum learning rate value as a fraction of the learning_rate.
|
||||
alpha: A scalar `float32` or `float64` Tensor or a Python number. Minimum
|
||||
learning rate value as a fraction of the learning_rate.
|
||||
name: String. Optional name of the operation. Defaults to 'SGDRDecay'.
|
||||
|
||||
Returns:
|
||||
A scalar `Tensor` of the same type as `learning_rate`. The decayed
|
||||
learning rate.
|
||||
@ -602,16 +605,17 @@ def linear_cosine_decay(learning_rate,
|
||||
Args:
|
||||
learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
|
||||
The initial learning rate.
|
||||
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
|
||||
Global step to use for the decay computation.
|
||||
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
|
||||
Number of steps to decay over.
|
||||
num_periods: Number of periods in the cosine part of the decay.
|
||||
See computation above.
|
||||
global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
|
||||
step to use for the decay computation.
|
||||
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number
|
||||
of steps to decay over.
|
||||
num_periods: Number of periods in the cosine part of the decay. See
|
||||
computation above.
|
||||
alpha: See computation above.
|
||||
beta: See computation above.
|
||||
name: String. Optional name of the operation. Defaults to
|
||||
'LinearCosineDecay'.
|
||||
|
||||
Returns:
|
||||
A scalar `Tensor` of the same type as `learning_rate`. The decayed
|
||||
learning rate.
|
||||
@ -690,18 +694,19 @@ def noisy_linear_cosine_decay(learning_rate,
|
||||
Args:
|
||||
learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
|
||||
The initial learning rate.
|
||||
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
|
||||
Global step to use for the decay computation.
|
||||
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
|
||||
Number of steps to decay over.
|
||||
global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
|
||||
step to use for the decay computation.
|
||||
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number
|
||||
of steps to decay over.
|
||||
initial_variance: initial variance for the noise. See computation above.
|
||||
variance_decay: decay for the noise's variance. See computation above.
|
||||
num_periods: Number of periods in the cosine part of the decay.
|
||||
See computation above.
|
||||
num_periods: Number of periods in the cosine part of the decay. See
|
||||
computation above.
|
||||
alpha: See computation above.
|
||||
beta: See computation above.
|
||||
name: String. Optional name of the operation. Defaults to
|
||||
'NoisyLinearCosineDecay'.
|
||||
|
||||
Returns:
|
||||
A scalar `Tensor` of the same type as `learning_rate`. The decayed
|
||||
learning rate.
|
||||
monitored_session.py
@ -77,7 +77,8 @@ class Scaffold(object):
|
||||
The following pieces are directly accessible as attributes of the `Scaffold`
|
||||
object:
|
||||
|
||||
* `saver`: A `tf.train.Saver` object taking care of saving the variables.
|
||||
* `saver`: A `tf.compat.v1.train.Saver` object taking care of saving the
|
||||
variables.
|
||||
Picked from and stored into the `SAVERS` collection in the graph by default.
|
||||
* `init_op`: An op to run to initialize the variables. Picked from and
|
||||
stored into the `INIT_OP` collection in the graph by default.
|
||||
@ -133,9 +134,9 @@ class Scaffold(object):
|
||||
local_init_op: Optional op to initialize local variables.
|
||||
summary_op: Optional op to gather all summaries. Must return a scalar
|
||||
string tensor containing a serialized `Summary` proto.
|
||||
saver: Optional `tf.train.Saver` object to use to save and restore
|
||||
variables. May also be a `tf.train.Checkpoint` object, in which case
|
||||
object-based checkpoints are saved. This will also load some
|
||||
saver: Optional `tf.compat.v1.train.Saver` object to use to save and
|
||||
restore variables. May also be a `tf.train.Checkpoint` object, in which
|
||||
case object-based checkpoints are saved. This will also load some
|
||||
object-based checkpoints saved from elsewhere, but that loading may be
|
||||
fragile since it uses fixed keys rather than performing a full
|
||||
graph-based match. For example if a variable has two paths from the
|
||||
@ -199,8 +200,9 @@ class Scaffold(object):
|
||||
resources.report_uninitialized_resources()
|
||||
], 0)
|
||||
|
||||
self._ready_op = Scaffold.get_or_default(
|
||||
'ready_op', ops.GraphKeys.READY_OP, default_ready_op)
|
||||
self._ready_op = Scaffold.get_or_default('ready_op',
|
||||
ops.GraphKeys.READY_OP,
|
||||
default_ready_op)
|
||||
if self._ready_for_local_init_op is None:
|
||||
|
||||
def default_ready_for_local_init_op():
|
||||
@ -219,8 +221,9 @@ class Scaffold(object):
|
||||
'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
|
||||
Scaffold.default_local_init_op)
|
||||
if self._summary_op is None:
|
||||
self._summary_op = Scaffold.get_or_default(
|
||||
'summary_op', ops.GraphKeys.SUMMARY_OP, summary.merge_all)
|
||||
self._summary_op = Scaffold.get_or_default('summary_op',
|
||||
ops.GraphKeys.SUMMARY_OP,
|
||||
summary.merge_all)
|
||||
# pylint: disable=g-long-lambda
|
||||
if self._saver is None:
|
||||
self._saver = training_saver._get_saver_or_default() # pylint: disable=protected-access
|
||||
@ -292,7 +295,8 @@ class Scaffold(object):
|
||||
|
||||
This op is used during session initialization when a Scaffold is
|
||||
initialized without specifying the local_init_op arg. It includes
|
||||
`tf.local_variables_initializer`, `tf.tables_initializer`, and also
|
||||
`tf.compat.v1.local_variables_initializer`,
|
||||
`tf.compat.v1.tables_initializer`, and also
|
||||
initializes local session resources.
|
||||
|
||||
Returns:
|
||||
@@ -435,7 +439,8 @@ def MonitoredTrainingSession(
   For a chief, this utility sets proper session initializer/restorer. It also
   creates hooks related to checkpoint and summary saving. For workers, this
   utility sets proper session creator which waits for the chief to
-  initialize/restore. Please check `tf.train.MonitoredSession` for more
+  initialize/restore. Please check `tf.compat.v1.train.MonitoredSession` for
+  more
   information.


@ -464,8 +469,9 @@ def MonitoredTrainingSession(
|
||||
to disk using a default summary saver. If both `save_summaries_steps` and
|
||||
`save_summaries_secs` are set to `None`, then the default summary saver
|
||||
isn't used. Default not enabled.
|
||||
config: an instance of `tf.ConfigProto` proto used to configure the session.
|
||||
It's the `config` argument of constructor of `tf.Session`.
|
||||
config: an instance of `tf.compat.v1.ConfigProto` proto used to configure
|
||||
the session. It's the `config` argument of constructor of
|
||||
`tf.compat.v1.Session`.
|
||||
stop_grace_period_secs: Number of seconds given to threads to stop after
|
||||
`close()` has been called.
|
||||
log_step_count_steps: The frequency, in number of global steps, that the
|
||||
@ -591,7 +597,7 @@ class SessionCreator(object):
|
||||
|
||||
@tf_export(v1=['train.ChiefSessionCreator'])
|
||||
class ChiefSessionCreator(SessionCreator):
|
||||
"""Creates a tf.Session for a chief."""
|
||||
"""Creates a tf.compat.v1.Session for a chief."""
|
||||
|
||||
def __init__(self,
|
||||
scaffold=None,
|
||||
@ -643,7 +649,7 @@ class ChiefSessionCreator(SessionCreator):
|
||||
|
||||
@tf_export(v1=['train.WorkerSessionCreator'])
|
||||
class WorkerSessionCreator(SessionCreator):
|
||||
"""Creates a tf.Session for a worker."""
|
||||
"""Creates a tf.compat.v1.Session for a worker."""
|
||||
|
||||
def __init__(self,
|
||||
scaffold=None,
|
||||
@ -757,8 +763,9 @@ class _MonitoredSession(object):
|
||||
`step_fn` will be returned from `run_step_fn`, unless a stop is
|
||||
requested. In that case, the next `should_stop` call will return True.
|
||||
Example usage: ```python
|
||||
with tf.Graph().as_default(): c = tf.placeholder(dtypes.float32) v =
|
||||
tf.add(c, 4.0) w = tf.add(c, 0.5)
|
||||
with tf.Graph().as_default(): c =
|
||||
tf.compat.v1.placeholder(dtypes.float32) v = tf.add(c, 4.0) w =
|
||||
tf.add(c, 0.5)
|
||||
def step_fn(step_context):
|
||||
a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
|
||||
if a <= 4.5: step_context.request_stop()
|
||||
@ -808,7 +815,7 @@ class _MonitoredSession(object):
|
||||
"""Initializes the `step_context` argument for a `step_fn` invocation.
|
||||
|
||||
Args:
|
||||
session: An instance of `tf.Session`.
|
||||
session: An instance of `tf.compat.v1.Session`.
|
||||
run_with_hooks_fn: A function for running fetches and hooks.
|
||||
"""
|
||||
self._session = session
|
||||
@ -901,13 +908,13 @@ class _MonitoredSession(object):
|
||||
return self._coordinated_creator.tf_sess is None
|
||||
|
||||
def _tf_sess(self):
|
||||
"""Return underlying tf.Session object.
|
||||
"""Return underlying tf.compat.v1.Session object.
|
||||
|
||||
Warning: accessing the returned object in user code is likely to cause races
|
||||
or "flaky tests".
|
||||
|
||||
Returns:
|
||||
A tf.Session object.
|
||||
A tf.compat.v1.Session object.
|
||||
"""
|
||||
return self._coordinated_creator.tf_sess
|
||||
|
||||
@ -955,7 +962,7 @@ class MonitoredSession(_MonitoredSession):
|
||||
* suppresses `OutOfRange` error which indicates that all inputs have been
|
||||
processed if the monitored_session is used as a context
|
||||
|
||||
How to set `tf.Session` arguments:
|
||||
How to set `tf.compat.v1.Session` arguments:
|
||||
|
||||
* In most cases you can set session arguments as follows:
|
||||
|
||||
@ -973,7 +980,8 @@ class MonitoredSession(_MonitoredSession):
|
||||
|
||||
See `MonitoredTrainingSession` for an example usage based on chief or worker.
|
||||
|
||||
Note: This is not a `tf.Session`. For example, it cannot do following:
|
||||
Note: This is not a `tf.compat.v1.Session`. For example, it cannot do
|
||||
following:
|
||||
|
||||
* it cannot be set as default session.
|
||||
* it cannot be sent to saver.save.
|
||||
@ -1004,14 +1012,15 @@ class SingularMonitoredSession(_MonitoredSession):
|
||||
"""Session-like object that handles initialization, restoring, and hooks.
|
||||
|
||||
Please note that this utility is not recommended for distributed settings.
|
||||
For distributed settings, please use `tf.train.MonitoredSession`. The
|
||||
For distributed settings, please use `tf.compat.v1.train.MonitoredSession`.
|
||||
The
|
||||
differences between `MonitoredSession` and `SingularMonitoredSession` are:
|
||||
|
||||
* `MonitoredSession` handles `AbortedError` and `UnavailableError` for
|
||||
distributed settings, but `SingularMonitoredSession` does not.
|
||||
* `MonitoredSession` can be created in `chief` or `worker` modes.
|
||||
`SingularMonitoredSession` is always created as `chief`.
|
||||
* You can access the raw `tf.Session` object used by
|
||||
* You can access the raw `tf.compat.v1.Session` object used by
|
||||
`SingularMonitoredSession`, whereas in MonitoredSession the raw session is
|
||||
private. This can be used:
|
||||
- To `run` without hooks.
|
||||
@ -1093,7 +1102,7 @@ class SingularMonitoredSession(_MonitoredSession):
|
||||
|
||||
|
||||
class _WrappedSession(object):
|
||||
"""Wrapper around a `tf.Session`.
|
||||
"""Wrapper around a `tf.compat.v1.Session`.
|
||||
|
||||
This wrapper is used as a base class for various session wrappers
|
||||
that provide additional functionality such as monitoring, coordination,
|
||||
@ -1108,7 +1117,8 @@ class _WrappedSession(object):
|
||||
"""Creates a `_WrappedSession`.
|
||||
|
||||
Args:
|
||||
sess: A `tf.Session` or `_WrappedSession` object. The wrapped session.
|
||||
sess: A `tf.compat.v1.Session` or `_WrappedSession` object. The wrapped
|
||||
session.
|
||||
"""
|
||||
self._sess = sess
|
||||
self._wrapped_is_stoppable = isinstance(self._sess, _WrappedSession)
|
||||
@ -1293,7 +1303,7 @@ class _CoordinatedSession(_WrappedSession):
|
||||
"""Create a new `_CoordinatedSession`.
|
||||
|
||||
Args:
|
||||
sess: A `tf.Session` object. The wrapped session.
|
||||
sess: A `tf.compat.v1.Session` object. The wrapped session.
|
||||
coord: A `tf.train.Coordinator` object.
|
||||
stop_grace_period_secs: Number of seconds given to threads to stop after
|
||||
`close()` has been called.
|
||||
@ -1364,7 +1374,7 @@ class _HookedSession(_WrappedSession):
|
||||
"""Initializes a _HookedSession object.
|
||||
|
||||
Args:
|
||||
sess: A `tf.Session` or a `_WrappedSession` object.
|
||||
sess: A `tf.compat.v1.Session` or a `_WrappedSession` object.
|
||||
hooks: An iterable of `SessionRunHook' objects.
|
||||
"""
|
||||
|
||||
moving_averages.py
@@ -56,9 +56,9 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
   E.g.:

   ```
-    with tf.variable_scope('scope1'):
-      with tf.variable_scope('scope2'):
-        var = tf.get_variable('foo')
+    with tf.compat.v1.variable_scope('scope1'):
+      with tf.compat.v1.variable_scope('scope2'):
+        var = tf.compat.v1.get_variable('foo')
         update_1 = tf.assign_moving_average(var, 0.0, 1.0)
         update_2 = tf.assign_moving_average(var, 0.0, 0.9)

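For context on the helper documented here, a minimal sketch (not part of the diff; variable names are illustrative) of calling the internal `assign_moving_average` routine with the `tf.compat.v1` names the docstring now uses:

```python
import tensorflow.compat.v1 as tf
from tensorflow.python.training import moving_averages

tf.disable_v2_behavior()

with tf.variable_scope('scope1'):
  with tf.variable_scope('scope2'):
    var = tf.get_variable('foo', initializer=0.0)

# Decays `var` toward the given value and returns the update tensor.
update_1 = moving_averages.assign_moving_average(var, value=0.0, decay=0.9)
```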
||||
@ -73,12 +73,13 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
|
||||
decay: A float Tensor or float value. The moving average decay.
|
||||
zero_debias: A python bool. If true, assume the variable is 0-initialized
|
||||
and unbias it, as in https://arxiv.org/abs/1412.6980. See docstring in
|
||||
`_zero_debias` for more details.
|
||||
`_zero_debias` for more details.
|
||||
name: Optional name of the returned operation.
|
||||
|
||||
Returns:
|
||||
A tensor which if evaluated will compute and return the new moving average.
|
||||
"""
|
||||
|
||||
def update_fn(v, value, decay=decay):
|
||||
decay = ops.convert_to_tensor(1.0 - decay, name="decay")
|
||||
if decay.dtype != v.dtype.base_dtype:
|
||||
@ -96,8 +97,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
|
||||
# In a replica context, we update variable using the mean of value across
|
||||
# replicas.
|
||||
def merge_fn(strategy, v, value):
|
||||
value = strategy.extended.reduce_to(
|
||||
ds_reduce_util.ReduceOp.MEAN, value, v)
|
||||
value = strategy.extended.reduce_to(ds_reduce_util.ReduceOp.MEAN, value,
|
||||
v)
|
||||
return strategy.extended.update(v, update_fn, args=(value,))
|
||||
|
||||
return replica_context.merge_call(merge_fn, args=(variable, value))
|
||||
@ -124,15 +125,15 @@ def weighted_moving_average(value,
|
||||
Args:
|
||||
value: A numeric `Tensor`.
|
||||
decay: A float `Tensor` or float value. The moving average decay.
|
||||
weight: `Tensor` that keeps the current value of a weight.
|
||||
Shape should be able to multiply `value`.
|
||||
weight: `Tensor` that keeps the current value of a weight. Shape should be
|
||||
able to multiply `value`.
|
||||
truediv: Boolean, if `True`, dividing by `moving_average(weight)` is
|
||||
floating point division. If `False`, use division implied by dtypes.
|
||||
collections: List of graph collections keys to add the internal variables
|
||||
`value * weight` and `weight` to.
|
||||
Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
|
||||
name: Optional name of the returned operation.
|
||||
Defaults to "WeightedMovingAvg".
|
||||
`value * weight` and `weight` to. Defaults to
|
||||
`[GraphKeys.GLOBAL_VARIABLES]`.
|
||||
name: Optional name of the returned operation. Defaults to
|
||||
"WeightedMovingAvg".
|
||||
|
||||
Returns:
|
||||
An Operation that updates and returns the weighted moving average.
|
||||
@ -203,27 +204,34 @@ def _zero_debias(unbiased_var, value, decay):
|
||||
tensor will also update the shadow variables appropriately.
|
||||
"""
|
||||
with variable_scope.variable_scope(
|
||||
unbiased_var.name[:-len(":0")], values=[unbiased_var,
|
||||
value, decay]) as scope:
|
||||
unbiased_var.name[:-len(":0")], values=[unbiased_var, value,
|
||||
decay]) as scope:
|
||||
with ops.colocate_with(unbiased_var):
|
||||
with ops.init_scope():
|
||||
biased_initializer = init_ops.zeros_initializer(
|
||||
dtype=unbiased_var.dtype)(unbiased_var.get_shape())
|
||||
dtype=unbiased_var.dtype)(
|
||||
unbiased_var.get_shape())
|
||||
local_step_initializer = init_ops.zeros_initializer()
|
||||
|
||||
def _maybe_get_unique(name):
|
||||
"""Get name for a unique variable, if not `reuse=True`."""
|
||||
if variable_scope.get_variable_scope().reuse:
|
||||
return name
|
||||
vs_vars = [x.op.name for x in
|
||||
variable_scope.get_variable_scope().global_variables()]
|
||||
vs_vars = [
|
||||
x.op.name
|
||||
for x in variable_scope.get_variable_scope().global_variables()
|
||||
]
|
||||
full_name = variable_scope.get_variable_scope().name + "/" + name
|
||||
if full_name not in vs_vars: return name
|
||||
if full_name not in vs_vars:
|
||||
return name
|
||||
idx = 1
|
||||
while full_name + ("_%d" % idx) in vs_vars:
|
||||
idx += 1
|
||||
return name + ("_%d" % idx)
|
||||
|
||||
biased_var = variable_scope.get_variable(
|
||||
_maybe_get_unique("biased"), initializer=biased_initializer,
|
||||
_maybe_get_unique("biased"),
|
||||
initializer=biased_initializer,
|
||||
trainable=False)
|
||||
local_step = variable_scope.get_variable(
|
||||
_maybe_get_unique("local_step"),
|
||||
@ -233,18 +241,17 @@ def _zero_debias(unbiased_var, value, decay):
|
||||
trainable=False)
|
||||
|
||||
# Get an update ops for both shadow variables.
|
||||
update_biased = state_ops.assign_sub(biased_var,
|
||||
(biased_var - value) * decay,
|
||||
name=scope.name)
|
||||
update_biased = state_ops.assign_sub(
|
||||
biased_var, (biased_var - value) * decay, name=scope.name)
|
||||
update_local_step = local_step.assign_add(1)
|
||||
|
||||
# Compute the value of the delta to update the unbiased EMA. Make sure to
|
||||
# use the new values of the biased variable and the local step.
|
||||
with ops.control_dependencies([update_biased, update_local_step]):
|
||||
# This function gets `1 - decay`, so use `1.0 - decay` in the exponent.
|
||||
unbiased_ema_delta = (unbiased_var - biased_var.read_value() /
|
||||
(1 - math_ops.pow(
|
||||
1.0 - decay, local_step.read_value())))
|
||||
unbiased_ema_delta = (
|
||||
unbiased_var - biased_var.read_value() /
|
||||
(1 - math_ops.pow(1.0 - decay, local_step.read_value())))
|
||||
|
||||
return unbiased_ema_delta

@ -315,7 +322,7 @@ class ExponentialMovingAverage(object):
for a given variable.
* Build a model normally but load the checkpoint files to evaluate by using
the shadow variable names. For this use the `average_name()` method. See
the `tf.train.Saver` for more
the `tf.compat.v1.train.Saver` for more
information on restoring saved variables.

Example of restoring the shadow variable values:
@ -324,13 +331,17 @@ class ExponentialMovingAverage(object):
# Create a Saver that loads variables from their saved shadow values.
shadow_var0_name = ema.average_name(var0)
shadow_var1_name = ema.average_name(var1)
saver = tf.train.Saver({shadow_var0_name: var0, shadow_var1_name: var1})
saver = tf.compat.v1.train.Saver({shadow_var0_name: var0, shadow_var1_name:
var1})
saver.restore(...checkpoint filename...)
# var0 and var1 now hold the moving average values
```
"""

def __init__(self, decay, num_updates=None, zero_debias=False,
def __init__(self,
decay,
num_updates=None,
zero_debias=False,
name="ExponentialMovingAverage"):
"""Creates a new ExponentialMovingAverage object.

@ -376,7 +387,7 @@ class ExponentialMovingAverage(object):

shadow variables are created with `trainable=False` and added to the
`GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to
`tf.global_variables()`.
`tf.compat.v1.global_variables()`.

Returns an op that updates all shadow variables from the current value of
their associated variables.
@ -386,8 +397,8 @@ class ExponentialMovingAverage(object):
be called in a loop.

Args:
var_list: A list of Variable or Tensor objects. The variables
and Tensors must be of types bfloat16, float16, float32, or float64.
var_list: A list of Variable or Tensor objects. The variables and Tensors
must be of types bfloat16, float16, float32, or float64.

Returns:
An Operation that updates the moving averages.
@ -417,10 +428,11 @@ class ExponentialMovingAverage(object):
# tensors, we rely on the existing device allocation mechanism.
with ops.init_scope():
if isinstance(var, variables.Variable):
avg = slot_creator.create_slot(var,
var.initialized_value(),
self.name,
colocate_with_primary=True)
avg = slot_creator.create_slot(
var,
var.initialized_value(),
self.name,
colocate_with_primary=True)
# NOTE(mrry): We only add `tf.Variable` objects to the
# `MOVING_AVERAGE_VARIABLES` collection.
ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
@ -428,9 +440,9 @@ class ExponentialMovingAverage(object):
avg = slot_creator.create_zeros_slot(
var,
self.name,
colocate_with_primary=(var.op.type in ["Variable",
"VariableV2",
"VarHandleOp"]))
colocate_with_primary=(var.op.type in [
"Variable", "VariableV2", "VarHandleOp"
]))
if self._zero_debias:
zero_debias_true.add(avg)
self._averages[var] = avg
@ -438,16 +450,16 @@ class ExponentialMovingAverage(object):
with ops.name_scope(self.name) as scope:
decay = ops.convert_to_tensor(self._decay, name="decay")
if self._num_updates is not None:
num_updates = math_ops.cast(self._num_updates,
dtypes.float32,
name="num_updates")
num_updates = math_ops.cast(
self._num_updates, dtypes.float32, name="num_updates")
decay = math_ops.minimum(decay,
(1.0 + num_updates) / (10.0 + num_updates))
updates = []
for var in var_list:
zero_debias = self._averages[var] in zero_debias_true
updates.append(assign_moving_average(
self._averages[var], var, decay, zero_debias=zero_debias))
updates.append(
assign_moving_average(
self._averages[var], var, decay, zero_debias=zero_debias))
return control_flow_ops.group(*updates, name=scope)
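
When `num_updates` is supplied, `apply()` caps the decay at `(1 + num_updates) / (10 + num_updates)`, so early averages track the variables closely and the configured decay only takes over later. A quick numeric sketch of that ramp (values rounded):

```python
decay = 0.999
for num_updates in (0, 10, 100, 1000, 10000):
  effective = min(decay, (1.0 + num_updates) / (10.0 + num_updates))
  print(num_updates, round(effective, 4))
# 0 0.1, 10 0.55, 100 0.9182, 1000 0.9911, 10000 0.999
```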

def average(self, var):
@ -472,7 +484,7 @@ class ExponentialMovingAverage(object):
To restore variables, you have to know the name of the shadow variables.
That name and the original variable can then be passed to a `Saver()` object
to restore the variable from the moving average value with:
`saver = tf.train.Saver({ema.average_name(var): var})`
`saver = tf.compat.v1.train.Saver({ema.average_name(var): var})`

`average_name()` can be called whether or not `apply()` has been called.

@ -499,7 +511,7 @@ class ExponentialMovingAverage(object):

```python
variables_to_restore = ema.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
saver = tf.compat.v1.train.Saver(variables_to_restore)
```

Below is an example of such mapping:

@ -434,14 +434,14 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True,

Raises:
ValueError: if `sess` is None and there isn't any default session.
TypeError: if `sess` is not a `tf.Session` object.
TypeError: if `sess` is not a `tf.compat.v1.Session` object.

Returns:
A list of threads.

Raises:
RuntimeError: If called with eager execution enabled.
ValueError: If called without a default `tf.Session` registered.
ValueError: If called without a default `tf.compat.v1.Session` registered.

@compatibility(eager)
Not compatible with eager execution. To ingest data under eager execution,

@ -56,7 +56,6 @@ from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


# TODO(allenl): Remove these aliases once all users are migrated off.
get_checkpoint_state = checkpoint_management.get_checkpoint_state
update_checkpoint_state = checkpoint_management.update_checkpoint_state
@ -174,13 +173,11 @@ class BaseSaverBuilder(object):
tensors = []
for spec in saveable.specs:
tensors.append(
io_ops.restore_v2(
filename_tensor,
[spec.name],
[spec.slice_spec],
[spec.dtype])[0])
io_ops.restore_v2(filename_tensor, [spec.name], [spec.slice_spec],
[spec.dtype])[0])

return tensors

# pylint: enable=unused-argument

def sharded_filename(self, filename_tensor, shard, num_shards):
@ -217,8 +214,8 @@ class BaseSaverBuilder(object):
from each device.

Args:
checkpoint_prefix: scalar String Tensor. Interpreted *NOT AS A
FILENAME*, but as a prefix of a V2 checkpoint;
checkpoint_prefix: scalar String Tensor. Interpreted *NOT AS A FILENAME*,
but as a prefix of a V2 checkpoint;
per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
returned by _GroupByDevices().

@ -319,8 +316,8 @@ class BaseSaverBuilder(object):
saveables: A list of SaveableObject objects.
restore_sequentially: True if we want to restore variables sequentially
within a shard.
reshape: True if we want to reshape loaded tensors to the shape of
the corresponding variable.
reshape: True if we want to reshape loaded tensors to the shape of the
corresponding variable.
preferred_shard: Shard to open first when loading a sharded file.
name: Name for the returned op.

@ -361,12 +358,12 @@ class BaseSaverBuilder(object):

Args:
filename_tensor: Tensor for the path of the file to load.
per_device: A list of (device, SaveableObject) pairs, as
returned by _GroupByDevices().
per_device: A list of (device, SaveableObject) pairs, as returned by
_GroupByDevices().
restore_sequentially: True if we want to restore variables sequentially
within a shard.
reshape: True if we want to reshape loaded tensors to the shape of
the corresponding variable.
reshape: True if we want to reshape loaded tensors to the shape of the
corresponding variable.

Returns:
An Operation that restores the variables.
@ -424,14 +421,13 @@ class BaseSaverBuilder(object):

Args:
names_to_saveables: A dictionary mapping name to a Variable or
SaveableObject. Each name will be associated with the
corresponding variable in the checkpoint.
reshape: If True, allow restoring parameters from a checkpoint
that where the parameters have a different shape. This is
only needed when you try to restore from a Dist-Belief checkpoint,
and only some times.
sharded: If True, shard the checkpoints, one per device that has
Variable nodes.
SaveableObject. Each name will be associated with the corresponding
variable in the checkpoint.
reshape: If True, allow restoring parameters from a checkpoint that where
the parameters have a different shape. This is only needed when you try
to restore from a Dist-Belief checkpoint, and only some times.
sharded: If True, shard the checkpoints, one per device that has Variable
nodes.
max_to_keep: Maximum number of checkpoints to keep. As new checkpoints
are created, old ones are deleted. If None or 0, no checkpoints are
deleted from the filesystem but only the last one is kept in the
@ -597,8 +593,8 @@ def _get_saver_or_default():
if len(savers) > 1:
raise RuntimeError(
"More than one item in collection {}. "
"Please indicate which one to use by passing it to the constructor.".
format(collection_key))
"Please indicate which one to use by passing it to the constructor."
.format(collection_key))
return savers[0]
saver = Saver(sharded=True, allow_empty=True)
if saver is not None:
@ -662,9 +658,9 @@ class Saver(object):
```python
...
# Create a saver.
saver = tf.train.Saver(...variables...)
saver = tf.compat.v1.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.Session()
sess = tf.compat.v1.Session()
for step in xrange(1000000):
sess.run(..training_op..)
if step % 1000 == 0:
@ -717,13 +713,13 @@ class Saver(object):
v2 = tf.Variable(..., name='v2')

# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})
saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2})

# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
saver = tf.compat.v1.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]})
```

The optional `reshape` argument, if `True`, allows restoring a variable from
@ -738,35 +734,33 @@ class Saver(object):
var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
names to `SaveableObject`s. If `None`, defaults to the list of all
saveable objects.
reshape: If `True`, allows restoring parameters from a checkpoint
where the variables have a different shape.
reshape: If `True`, allows restoring parameters from a checkpoint where
the variables have a different shape.
sharded: If `True`, shard the checkpoints, one per device.
max_to_keep: Maximum number of recent checkpoints to keep.
Defaults to 5.
keep_checkpoint_every_n_hours: How often to keep checkpoints.
Defaults to 10,000 hours.
max_to_keep: Maximum number of recent checkpoints to keep. Defaults to 5.
keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults to
10,000 hours.
name: String. Optional name to use as a prefix when adding operations.
restore_sequentially: A `Bool`, which if true, causes restore of different
variables to happen sequentially within each device. This can lower
memory usage when restoring very large models.
saver_def: Optional `SaverDef` proto to use instead of running the
builder. This is only useful for specialty code that wants to recreate
a `Saver` object for a previously built `Graph` that had a `Saver`.
The `saver_def` proto should be the one returned by the
`as_saver_def()` call of the `Saver` that was created for that `Graph`.
builder. This is only useful for specialty code that wants to recreate a
`Saver` object for a previously built `Graph` that had a `Saver`. The
`saver_def` proto should be the one returned by the `as_saver_def()`
call of the `Saver` that was created for that `Graph`.
builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
Defaults to `BulkSaverBuilder()`.
defer_build: If `True`, defer adding the save and restore ops to the
`build()` call. In that case `build()` should be called before
finalizing the graph or using the saver.
allow_empty: If `False` (default) raise an error if there are no
variables in the graph. Otherwise, construct the saver anyway and make
it a no-op.
allow_empty: If `False` (default) raise an error if there are no variables
in the graph. Otherwise, construct the saver anyway and make it a no-op.
write_version: controls what format to use when saving checkpoints. It
also affects certain filepath matching logic. The V2 format is the
recommended choice: it is much more optimized than V1 in terms of
memory required and latency incurred during restore. Regardless of
this flag, the Saver is able to restore from both V2 and V1 checkpoints.
recommended choice: it is much more optimized than V1 in terms of memory
required and latency incurred during restore. Regardless of this
flag, the Saver is able to restore from both V2 and V1 checkpoints.
pad_step_number: if True, pads the global step number in the checkpoint
filepaths to some fixed width (8 by default). This is turned off by
default.
@ -877,7 +871,8 @@ class Saver(object):
name=self._name,
restore_sequentially=self._restore_sequentially,
filename=checkpoint_path,
build_save=build_save, build_restore=build_restore)
build_save=build_save,
build_restore=build_restore)
elif self.saver_def and self._name:
# Since self._name is used as a name_scope by builder(), we are
# overloading the use of this field to represent the "import_scope" as
@ -997,8 +992,8 @@ class Saver(object):
saver_def.filename_tensor_name, export_scope)
saver_def.save_tensor_name = ops.strip_name_scope(
saver_def.save_tensor_name, export_scope)
saver_def.restore_op_name = ops.strip_name_scope(
saver_def.restore_op_name, export_scope)
saver_def.restore_op_name = ops.strip_name_scope(saver_def.restore_op_name,
export_scope)
return saver_def

@staticmethod
@ -1092,14 +1087,13 @@ class Saver(object):
Args:
sess: A Session to use to save the variables.
save_path: String. Prefix of filenames created for the checkpoint.
global_step: If provided the global step number is appended to
`save_path` to create the checkpoint filenames. The optional argument
can be a `Tensor`, a `Tensor` name or an integer.
global_step: If provided the global step number is appended to `save_path`
to create the checkpoint filenames. The optional argument can be a
`Tensor`, a `Tensor` name or an integer.
latest_filename: Optional name for the protocol buffer file that will
contains the list of most recent checkpoints. That file,
kept in the same directory as the checkpoint files, is automatically
managed by the saver to keep track of recent checkpoints. Defaults to
'checkpoint'.
contains the list of most recent checkpoints. That file, kept in the
same directory as the checkpoint files, is automatically managed by the
saver to keep track of recent checkpoints. Defaults to 'checkpoint'.
meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
write_meta_graph: `Boolean` indicating whether or not to write the meta
graph file.
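
The `global_step` behaviour described in the `save()` args above appends the step to the prefix; a small sketch with placeholder paths (assumes v1-style graph mode):

```python
import os
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
os.makedirs("ckpt", exist_ok=True)  # placeholder directory for the sketch

v = tf.compat.v1.get_variable(
    "v", shape=[], initializer=tf.compat.v1.zeros_initializer())
saver = tf.compat.v1.train.Saver()
with tf.compat.v1.Session() as sess:
  sess.run(tf.compat.v1.global_variables_initializer())
  # Appends the step to the prefix and updates the 'checkpoint' state file
  # kept in the same directory.
  path = saver.save(sess, "ckpt/model", global_step=1000)
  print(path)  # ckpt/model-1000
```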
@ -1107,7 +1101,8 @@ class Saver(object):
`CheckpointStateProto`.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
[Stripping Default-Valued
Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
which in the same directory of save_path and with `_debug` added before
the file extension. This is only enabled when `write_meta_graph` is
@ -1151,8 +1146,7 @@ class Saver(object):
checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
else:
checkpoint_file = save_path
if os.path.basename(
save_path) == latest_filename and not self._sharded:
if os.path.basename(save_path) == latest_filename and not self._sharded:
# Guard against collision between data file and checkpoint state file.
raise ValueError(
"'latest_filename' collides with 'save_path': '%s' and '%s'" %
@ -1197,7 +1191,8 @@ class Saver(object):
if not context.executing_eagerly():
with sess.graph.as_default():
self.export_meta_graph(
meta_graph_filename, strip_default_attrs=strip_default_attrs,
meta_graph_filename,
strip_default_attrs=strip_default_attrs,
save_debug_info=save_debug_info)

if self._is_empty:
@ -1225,11 +1220,12 @@ class Saver(object):
clear_devices: Whether or not to clear the device field for an `Operation`
or `Tensor` during export.
clear_extraneous_savers: Remove any Saver-related information from the
graph (both Save/Restore ops and SaverDefs) that are not associated
with this Saver.
graph (both Save/Restore ops and SaverDefs) that are not associated with
this Saver.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
[Stripping Default-Valued
Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
which in the same directory of filename and with `_debug` added before
the file extension.
@ -1274,8 +1270,8 @@ class Saver(object):
raise ValueError("Can't load save_path when it is None.")

if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)):
raise ValueError("The passed save_path is not a valid checkpoint: "
+ compat.as_text(save_path))
raise ValueError("The passed save_path is not a valid checkpoint: " +
compat.as_text(save_path))

logging.info("Restoring parameters from %s", compat.as_text(save_path))
try:
@ -1330,13 +1326,15 @@ class Saver(object):
key: One of the GraphKeys or user-defined string.
export_scope: Optional `string`. Name scope to remove.
"""
meta_graph.add_collection_def(meta_graph_def, key,
export_scope=export_scope)
meta_graph.add_collection_def(
meta_graph_def, key, export_scope=export_scope)


@tf_export(v1=["train.import_meta_graph"])
def import_meta_graph(meta_graph_or_file, clear_devices=False,
import_scope=None, **kwargs):
def import_meta_graph(meta_graph_or_file,
clear_devices=False,
import_scope=None,
**kwargs):
"""Recreates a Graph saved in a `MetaGraphDef` proto.

This function takes a `MetaGraphDef` protocol buffer as input. If
@ -1358,10 +1356,10 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
```Python
...
# Create a saver.
saver = tf.train.Saver(...variables...)
saver = tf.compat.v1.train.Saver(...variables...)
# Remember the training_op we want to run by adding it to a collection.
tf.add_to_collection('train_op', train_op)
sess = tf.Session()
tf.compat.v1.add_to_collection('train_op', train_op)
sess = tf.compat.v1.Session()
for step in xrange(1000000):
sess.run(train_op)
if step % 1000 == 0:
@ -1374,12 +1372,13 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
the model from scratch.

```Python
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
with tf.compat.v1.Session() as sess:
new_saver =
tf.compat.v1.train.import_meta_graph('my-save-dir/my-model-10000.meta')
new_saver.restore(sess, 'my-save-dir/my-model-10000')
# tf.get_collection() returns a list. In this example we only want the
# first one.
train_op = tf.get_collection('train_op')[0]
# tf.compat.v1.get_collection() returns a list. In this example we only want
# the first one.
train_op = tf.compat.v1.get_collection('train_op')[0]
for step in xrange(1000000):
sess.run(train_op)
```
@ -1393,14 +1392,14 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,

```Python
# Saving contents and operations.
v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
v1 = tf.compat.v1.placeholder(tf.float32, name="v1")
v2 = tf.compat.v1.placeholder(tf.float32, name="v2")
v3 = tf.mul(v1, v2)
vx = tf.Variable(10.0, name="vx")
v4 = tf.add(v3, vx, name="v4")
saver = tf.train.Saver([vx])
sess = tf.Session()
sess.run(tf.initialize_all_variables())
saver = tf.compat.v1.train.Saver([vx])
sess = tf.compat.v1.Session()
sess.run(tf.compat.v1.initialize_all_variables())
sess.run(vx.assign(tf.add(vx, vx)))
result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
print(result)
@ -1411,8 +1410,8 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,

```Python
# Restoring variables and running operations.
saver = tf.train.import_meta_graph("./model_ex1.meta")
sess = tf.Session()
saver = tf.compat.v1.train.import_meta_graph("./model_ex1.meta")
sess = tf.compat.v1.Session()
saver.restore(sess, "./model_ex1")
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
print(result)
@ -1441,13 +1440,16 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
execution is enabled.
@end_compatibility
""" # pylint: disable=g-doc-exception
return _import_meta_graph_with_return_elements(
meta_graph_or_file, clear_devices, import_scope, **kwargs)[0]
return _import_meta_graph_with_return_elements(meta_graph_or_file,
clear_devices, import_scope,
**kwargs)[0]


def _import_meta_graph_with_return_elements(
meta_graph_or_file, clear_devices=False, import_scope=None,
return_elements=None, **kwargs):
def _import_meta_graph_with_return_elements(meta_graph_or_file,
clear_devices=False,
import_scope=None,
return_elements=None,
**kwargs):
"""Import MetaGraph, and return both a saver and returned elements."""
if context.executing_eagerly():
raise RuntimeError("Exporting/importing meta graphs is not supported when "
@ -1466,13 +1468,13 @@ def _import_meta_graph_with_return_elements(
return_elements=return_elements,
**kwargs))

saver = _create_saver_from_imported_meta_graph(
meta_graph_def, import_scope, imported_vars)
saver = _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
imported_vars)
return saver, imported_return_elements


def _create_saver_from_imported_meta_graph(
meta_graph_def, import_scope, imported_vars):
def _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
imported_vars):
"""Return a saver for restoring variable values to an imported MetaGraph."""
if meta_graph_def.HasField("saver_def"):
# Infer the scope that is prepended by `import_scoped_meta_graph`.
@ -1510,7 +1512,9 @@ def export_meta_graph(filename=None,
save_debug_info=False,
**kwargs):
# pylint: disable=line-too-long
"""Returns `MetaGraphDef` proto. Optionally writes it to filename.
"""Returns `MetaGraphDef` proto.

Optionally writes it to filename.

This function exports the graph, saver, and collection objects into
`MetaGraphDef` protocol buffer with the intention of it being imported
@ -1518,29 +1522,29 @@ def export_meta_graph(filename=None,
a subgraph.

Args:
filename: Optional filename including the path for writing the
generated `MetaGraphDef` protocol buffer.
filename: Optional filename including the path for writing the generated
`MetaGraphDef` protocol buffer.
meta_info_def: `MetaInfoDef` protocol buffer.
graph_def: `GraphDef` protocol buffer.
saver_def: `SaverDef` protocol buffer.
collection_list: List of string keys to collect.
as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
graph: The `Graph` to export. If `None`, use the default graph.
export_scope: Optional `string`. Name scope under which to extract
the subgraph. The scope name will be striped from the node definitions
for easy import later into new name scopes. If `None`, the whole graph
is exported. graph_def and export_scope cannot both be specified.
export_scope: Optional `string`. Name scope under which to extract the
subgraph. The scope name will be striped from the node definitions for
easy import later into new name scopes. If `None`, the whole graph is
exported. graph_def and export_scope cannot both be specified.
clear_devices: Whether or not to clear the device field for an `Operation`
or `Tensor` during export.
clear_extraneous_savers: Remove any Saver-related information from the
graph (both Save/Restore ops and SaverDefs) that are not associated
with the provided SaverDef.
clear_extraneous_savers: Remove any Saver-related information from the graph
(both Save/Restore ops and SaverDefs) that are not associated with the
provided SaverDef.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
which in the same directory of filename and with `_debug` added before
the file extend.
which in the same directory of filename and with `_debug` added before the
file extend.
**kwargs: Optional keyed arguments.

Returns:
@ -1603,10 +1607,8 @@ def object_graph_key_mapping(checkpoint_path):
Dictionary mapping tensor names to checkpoint keys.
"""
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
object_graph_string = reader.get_tensor(
trackable.OBJECT_GRAPH_PROTO_KEY)
object_graph_proto = (
trackable_object_graph_pb2.TrackableObjectGraph())
object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string)
names_to_keys = {}
for node in object_graph_proto.nodes:
@ -1615,9 +1617,11 @@ def object_graph_key_mapping(checkpoint_path):
return names_to_keys


def saver_from_object_based_checkpoint(
checkpoint_path, var_list=None, builder=None, names_to_keys=None,
cached_saver=None):
def saver_from_object_based_checkpoint(checkpoint_path,
var_list=None,
builder=None,
names_to_keys=None,
cached_saver=None):
"""Return a `Saver` which reads from an object-based checkpoint.

This function validates that all variables in the variables list are remapped
@ -1659,8 +1663,8 @@ def saver_from_object_based_checkpoint(
try:
names_to_keys = object_graph_key_mapping(checkpoint_path)
except errors.NotFoundError:
raise ValueError("Checkpoint in %s not an object-based checkpoint."
% checkpoint_path)
raise ValueError("Checkpoint in %s not an object-based checkpoint." %
checkpoint_path)
if var_list is None:
var_list = variables._all_saveable_objects() # pylint: disable=protected-access
if builder is None:
@ -1677,7 +1681,8 @@ def saver_from_object_based_checkpoint(
extra_names = previous_names - current_names
intersecting_names = previous_names.intersection(current_names)
raise errors.NotFoundError(
None, None,
None,
None,
message=(
"\n\nExisting variables not in the checkpoint: %s\n\n"
"Variables names when this checkpoint was written which don't "
@ -1695,9 +1700,9 @@ def saver_from_object_based_checkpoint(
"existed, and if variable names have changed you may need to "
"make this a dictionary with the old names as keys. If you're "
"using an Estimator, you'll need to return a tf.train.Saver "
"inside a tf.train.Scaffold from your model_fn.")
% (", ".join(sorted(missing_names)), ", ".join(sorted(extra_names)),
len(intersecting_names)))
"inside a tf.train.Scaffold from your model_fn.") %
(", ".join(sorted(missing_names)), ", ".join(
sorted(extra_names)), len(intersecting_names)))
for saveable in saveables:
for spec in saveable.specs:
spec.name = names_to_keys[spec.name]

@ -32,21 +32,19 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
"""Creates a `tf.train.ServerDef` protocol buffer.

Args:
server_or_cluster_def: A `tf.train.ServerDef` or
`tf.train.ClusterDef` protocol buffer, or a
`tf.train.ClusterSpec` object, describing the server to be
defined and/or the cluster of which it is a member.
job_name: (Optional.) Specifies the name of the job of which the server
is a member. Defaults to the value in `server_or_cluster_def`, if
specified.
server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
protocol buffer, or a `tf.train.ClusterSpec` object, describing the server
to be defined and/or the cluster of which it is a member.
job_name: (Optional.) Specifies the name of the job of which the server is a
member. Defaults to the value in `server_or_cluster_def`, if specified.
task_index: (Optional.) Specifies the task index of the server in its job.
Defaults to the value in `server_or_cluster_def`, if specified. Otherwise
defaults to 0 if the server's job has only one task.
protocol: (Optional.) Specifies the protocol to be used by the server.
Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value
in `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
config: (Options.) A `tf.ConfigProto` that specifies default configuration
options for all sessions that run on this server.
Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value in
`server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
config: (Options.) A `tf.compat.v1.ConfigProto` that specifies default
configuration options for all sessions that run on this server.

Returns:
A `tf.train.ServerDef`.
@ -88,7 +86,9 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,

server_def = tensorflow_server_pb2.ServerDef(
cluster=cluster_spec.as_cluster_def(),
job_name=job_name, task_index=task_index, protocol=protocol)
job_name=job_name,
task_index=task_index,
protocol=protocol)
if config is not None:
server_def.default_session_config.MergeFrom(config)
return server_def
@ -99,8 +99,8 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
class Server(object):
"""An in-process TensorFlow server, for use in distributed training.

A `tf.train.Server` instance encapsulates a set of devices and a
`tf.Session` target that
A `tf.distribute.Server` instance encapsulates a set of devices and a
`tf.compat.v1.Session` target that
can participate in distributed training. A server belongs to a
cluster (specified by a `tf.train.ClusterSpec`), and
corresponds to a particular task in a named job. The server can
@ -120,31 +120,30 @@ class Server(object):
override any information provided in `server_or_cluster_def`.

Args:
server_or_cluster_def: A `tf.train.ServerDef` or
`tf.train.ClusterDef` protocol buffer, or a
`tf.train.ClusterSpec` object, describing the server to be
created and/or the cluster of which it is a member.
job_name: (Optional.) Specifies the name of the job of which the server
is a member. Defaults to the value in `server_or_cluster_def`, if
server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
protocol buffer, or a `tf.train.ClusterSpec` object, describing the
server to be created and/or the cluster of which it is a member.
job_name: (Optional.) Specifies the name of the job of which the server is
a member. Defaults to the value in `server_or_cluster_def`, if
specified.
task_index: (Optional.) Specifies the task index of the server in its
job. Defaults to the value in `server_or_cluster_def`, if specified.
task_index: (Optional.) Specifies the task index of the server in its job.
Defaults to the value in `server_or_cluster_def`, if specified.
Otherwise defaults to 0 if the server's job has only one task.
protocol: (Optional.) Specifies the protocol to be used by the server.
Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the
value in `server_or_cluster_def`, if specified. Otherwise defaults to
Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value
in `server_or_cluster_def`, if specified. Otherwise defaults to
`"grpc"`.
config: (Options.) A `tf.ConfigProto` that specifies default
config: (Options.) A `tf.compat.v1.ConfigProto` that specifies default
configuration options for all sessions that run on this server.
start: (Optional.) Boolean, indicating whether to start the server
after creating it. Defaults to `True`.
start: (Optional.) Boolean, indicating whether to start the server after
creating it. Defaults to `True`.

Raises:
tf.errors.OpError: Or one of its subclasses if an error occurs while
creating the TensorFlow server.
"""
self._server_def = _make_server_def(server_or_cluster_def,
job_name, task_index, protocol, config)
self._server_def = _make_server_def(server_or_cluster_def, job_name,
task_index, protocol, config)
self._server = c_api.TF_NewServer(self._server_def.SerializeToString())
if start:
self.start()
@ -195,15 +194,15 @@ class Server(object):

@property
def target(self):
"""Returns the target for a `tf.Session` to connect to this server.
"""Returns the target for a `tf.compat.v1.Session` to connect to this server.

To create a
`tf.Session` that
`tf.compat.v1.Session` that
connects to this server, use the following snippet:

```python
server = tf.train.Server(...)
with tf.Session(server.target):
server = tf.distribute.Server(...)
with tf.compat.v1.Session(server.target):
# ...
```

@ -217,22 +216,24 @@ class Server(object):
"""Creates a new single-process cluster running on the local host.

This method is a convenience wrapper for creating a
`tf.train.Server` with a `tf.train.ServerDef` that specifies a
`tf.distribute.Server` with a `tf.train.ServerDef` that specifies a
single-process cluster containing a single task in a job called
`"local"`.

Args:
config: (Options.) A `tf.ConfigProto` that specifies default
config: (Options.) A `tf.compat.v1.ConfigProto` that specifies default
configuration options for all sessions that run on this server.
start: (Optional.) Boolean, indicating whether to start the server after
creating it. Defaults to `True`.

Returns:
A local `tf.train.Server`.
A local `tf.distribute.Server`.
"""
# Specifying port 0 means that the OS will choose a free port for the
# server.
return Server({"local": ["localhost:0"]}, protocol="grpc", config=config,
return Server({"local": ["localhost:0"]},
protocol="grpc",
config=config,
start=start)
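
As a usage sketch of the convenience wrapper documented above (assuming TF 2.x with the v1 compatibility layer and graph mode):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # sketch assumes v1-style graph mode

# Start an in-process, single-task server and connect a v1 Session to it.
server = tf.distribute.Server.create_local_server()
with tf.compat.v1.Session(server.target) as sess:
  print(sess.run(tf.constant(42)))  # 42
```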


@ -242,7 +243,7 @@ class ClusterSpec(object):

A `tf.train.ClusterSpec` represents the set of processes that
participate in a distributed TensorFlow computation. Every
`tf.train.Server` is constructed in a particular cluster.
`tf.distribute.Server` is constructed in a particular cluster.

To create a cluster with two jobs and five tasks, you specify the
mapping from job names to lists of network addresses (typically
@ -272,10 +273,9 @@ class ClusterSpec(object):
"""Creates a `ClusterSpec`.

Args:
cluster: A dictionary mapping one or more job names to (i) a
list of network addresses, or (ii) a dictionary mapping integer
task indices to network addresses; or a `tf.train.ClusterDef`
protocol buffer.
cluster: A dictionary mapping one or more job names to (i) a list of
network addresses, or (ii) a dictionary mapping integer task indices to
network addresses; or a `tf.train.ClusterDef` protocol buffer.

Raises:
TypeError: If `cluster` is not a dictionary mapping strings to lists
@ -298,14 +298,16 @@ class ClusterSpec(object):
self._cluster_spec = {}
for job_def in self._cluster_def.job:
self._cluster_spec[job_def.name] = {
i: t for i, t in job_def.tasks.items()}
i: t for i, t in job_def.tasks.items()
}
elif isinstance(cluster, ClusterSpec):
self._cluster_def = cluster_pb2.ClusterDef()
self._cluster_def.MergeFrom(cluster.as_cluster_def())
self._cluster_spec = {}
for job_def in self._cluster_def.job:
self._cluster_spec[job_def.name] = {
i: t for i, t in job_def.tasks.items()}
i: t for i, t in job_def.tasks.items()
}
else:
raise TypeError("`cluster` must be a dictionary mapping one or more "
"job names to lists of network addresses, or a "
@ -326,7 +328,8 @@ class ClusterSpec(object):
def __str__(self):
key_values = self.as_dict()
string_items = [
repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)]
repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)
]
return "ClusterSpec({" + ", ".join(string_items) + "})"
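
The `ClusterSpec` constructor described above accepts the job-name to address-list mapping directly; a minimal sketch with placeholder host:port values:

```python
import tensorflow as tf

cluster = tf.train.ClusterSpec({
    "worker": ["worker0:2222", "worker1:2222", "worker2:2222"],
    "ps": ["ps0:2222", "ps1:2222"],
})
print(cluster.num_tasks("worker"))  # 3
print(cluster.job_tasks("ps"))      # ['ps0:2222', 'ps1:2222']
```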

def as_dict(self):
@ -427,8 +430,8 @@ class ClusterSpec(object):
try:
return job[task_index]
except KeyError:
raise ValueError("No task with index %r in job %r"
% (task_index, job_name))
raise ValueError("No task with index %r in job %r" %
(task_index, job_name))

def job_tasks(self, job_name):
"""Returns a mapping from task ID to address in the given job.
@ -482,6 +485,6 @@ class ClusterSpec(object):
try:
task_address = compat.as_bytes(task_address)
except TypeError:
raise TypeError(
"Task address %r must be bytes or unicode" % task_address)
raise TypeError("Task address %r must be bytes or unicode" %
task_address)
job_def.tasks[i] = task_address

@ -107,7 +107,7 @@ class GrpcServerTest(test.TestCase):
self.assertAllEqual(2.0, sess.run(v1))

def _useRPCConfig(self):
"""Return a `tf.ConfigProto` that ensures we use the RPC stack for tests.
"""Return a `tf.compat.v1.ConfigProto` that ensures we use the RPC stack for tests.

This configuration ensures that we continue to exercise the gRPC
stack when testing, rather than using the in-process optimization,
@ -115,7 +115,7 @@ class GrpcServerTest(test.TestCase):
master in the same process.

Returns:
A `tf.ConfigProto`.
A `tf.compat.v1.ConfigProto`.
"""
return config_pb2.ConfigProto(rpc_options=config_pb2.RPCOptions(
use_rpc_for_inprocess_master=True))

@ -174,8 +174,8 @@ class SessionManagerTest(test.TestCase):
self.assertFalse(initialized)
sess.run(v.initializer)
self.assertEquals(1, sess.run(v))
saver.save(sess,
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
saver.save(sess, os.path.join(checkpoint_dir,
"recover_session_checkpoint"))
self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
self._test_recovered_variable(
checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
@ -202,9 +202,9 @@ class SessionManagerTest(test.TestCase):
def testInitWithNoneLocalInitOpError(self):
# Creating a SessionManager with a None local_init_op but
# non-None ready_for_local_init_op raises ValueError
with self.assertRaisesRegexp(ValueError,
"If you pass a ready_for_local_init_op "
"you must also pass a local_init_op "):
with self.assertRaisesRegexp(
ValueError, "If you pass a ready_for_local_init_op "
"you must also pass a local_init_op "):
session_manager.SessionManager(
ready_for_local_init_op=variables.report_uninitialized_variables(
variables.global_variables()),
@ -231,8 +231,8 @@ class SessionManagerTest(test.TestCase):
self.assertFalse(initialized)
sess.run(v.initializer)
self.assertEquals(1, sess.run(v))
saver.save(sess,
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
saver.save(sess, os.path.join(checkpoint_dir,
"recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
v = variables.VariableV1(2, name="v")
@ -266,7 +266,7 @@ class SessionManagerTest(test.TestCase):

@test_util.run_v1_only("b/120545219")
def testRecoverSessionWithReadyForLocalInitOpFailsToReadyLocal(self):
# We use ready_for_local_init_op=tf.report_uninitialized_variables(),
# We use ready_for_local_init_op=report_uninitialized_variables(),
# which causes recover_session to not run local_init_op, and to return
# initialized=False

@ -290,8 +290,8 @@ class SessionManagerTest(test.TestCase):
self.assertFalse(initialized)
sess.run(v.initializer)
self.assertEquals(1, sess.run(v))
saver.save(sess,
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
saver.save(sess, os.path.join(checkpoint_dir,
"recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
v = variables.VariableV1(2, name="v")
@ -780,8 +780,8 @@ class ObsoleteSessionManagerTest(test.TestCase):
self.assertFalse(initialized)
sess.run(v.initializer)
self.assertEquals(1, sess.run(v))
saver.save(sess,
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
saver.save(sess, os.path.join(checkpoint_dir,
"recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
v = variables.VariableV1(2, name="v")

@ -68,7 +68,7 @@ look at following code:

Above user code leads to following execution:
call hooks.begin()
sess = tf.Session()
sess = tf.compat.v1.Session()
call hooks.after_create_session()
while not stop is requested:
call hooks.before_run()

@ -56,9 +56,9 @@ class SummaryWriter(_FileWriter):
```python
...create a graph...
# Launch the graph in a session.
sess = tf.Session()
sess = tf.compat.v1.Session()
# Create a summary writer, add the 'graph' to the event file.
writer = tf.summary.FileWriter(<some-directory>, sess.graph)
writer = tf.compat.v1.summary.FileWriter(<some-directory>, sess.graph)
```

The other arguments to the constructor control the asynchronous writes to

@ -45,7 +45,7 @@ class Supervisor(object):
"""A training helper that checkpoints models and computes summaries.

This class is deprecated. Please use
`tf.train.MonitoredTrainingSession` instead.
`tf.compat.v1.train.MonitoredTrainingSession` instead.
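
A minimal sketch of the recommended replacement (placeholder checkpoint directory; assumes the v1 compatibility layer and graph mode):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()
step = tf.compat.v1.train.get_or_create_global_step()
train_op = tf.compat.v1.assign_add(step, 1)  # stand-in for a real training step

with tf.compat.v1.train.MonitoredTrainingSession(
    checkpoint_dir="/tmp/supervisor_replacement",
    hooks=[tf.compat.v1.train.StopAtStepHook(last_step=10)]) as sess:
  while not sess.should_stop():
    sess.run(train_op)
```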

The Supervisor is a small wrapper around a `Coordinator`, a `Saver`,
and a `SessionManager` that takes care of common needs of TensorFlow
@ -97,7 +97,7 @@ class Supervisor(object):
# or job_def.name, or job_def.tasks. It's entirely up to the end user.
# But there can be only one *chief*.
is_chief = (server_def.task_index == 0)
server = tf.train.Server(server_def)
server = tf.distribute.Server(server_def)

with tf.Graph().as_default():
...add operations to the graph...
@ -140,7 +140,7 @@ class Supervisor(object):
* Specifying `'grpc://hostname:port'` requests a session that uses
the RPC interface to a specific host, and also allows the in-process
master to access remote tensorflow workers. Often, it is
appropriate to pass `server.target` (for some `tf.train.Server`
appropriate to pass `server.target` (for some `tf.distribute.Server`
named `server).

#### Advanced use
@ -237,17 +237,16 @@ class Supervisor(object):
ready_op: 1-D string `Tensor`. This tensor is evaluated by supervisors in
`prepare_or_wait_for_session()` to check if the model is ready to use.
The model is considered ready if it returns an empty array. Defaults to
the tensor returned from `tf.report_uninitialized_variables()` If
`None`, the model is not checked for readiness.
the tensor returned from `tf.compat.v1.report_uninitialized_variables()`
If `None`, the model is not checked for readiness.
ready_for_local_init_op: 1-D string `Tensor`. This tensor is evaluated by
supervisors in `prepare_or_wait_for_session()` to check if the model is
ready to run the local_init_op.
The model is considered ready if it returns an empty array. Defaults to
`None`. If `None`, the model is not checked for readiness before running
local_init_op.
is_chief: If True, create a chief supervisor in charge of initializing
and restoring the model. If False, create a supervisor that relies
on a chief supervisor for inits and restore.
ready to run the local_init_op. The model is considered ready if it
returns an empty array. Defaults to `None`. If `None`, the model is not
checked for readiness before running local_init_op.
is_chief: If True, create a chief supervisor in charge of initializing and
restoring the model. If False, create a supervisor that relies on a
chief supervisor for inits and restore.
init_op: `Operation`. Used by chief supervisors to initialize the model
when it can not be recovered. Defaults to an `Operation` that
initializes all global variables. If `None`, no initialization is done
@ -255,20 +254,19 @@ class Supervisor(object):
init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
This feed dictionary will be used when `init_op` is evaluated.
local_init_op: `Operation`. Used by all supervisors to run initializations
that should run for every new supervisor instance. By default these
are table initializers and initializers for local variables.
If `None`, no further per supervisor-instance initialization is
done automatically.
that should run for every new supervisor instance. By default these are
table initializers and initializers for local variables. If `None`, no
further per supervisor-instance initialization is done automatically.
logdir: A string. Optional path to a directory where to checkpoint the
model and log events for the visualizer. Used by chief supervisors.
The directory will be created if it does not exist.
summary_op: An `Operation` that returns a Summary for the event logs.
Used by chief supervisors if a `logdir` was specified. Defaults to the
model and log events for the visualizer. Used by chief supervisors. The
directory will be created if it does not exist.
summary_op: An `Operation` that returns a Summary for the event logs. Used
by chief supervisors if a `logdir` was specified. Defaults to the
operation returned from summary.merge_all(). If `None`, summaries are
not computed automatically.
saver: A Saver object. Used by chief supervisors if a `logdir` was
specified. Defaults to the saved returned by Saver().
If `None`, the model is not saved automatically.
specified. Defaults to the saved returned by Saver(). If `None`, the
model is not saved automatically.
global_step: An integer Tensor of size 1 that counts steps. The value
from 'global_step' is used in summaries and checkpoint filenames.
Default to the op named 'global_step' in the graph if it exists, is of
@ -280,20 +278,20 @@ class Supervisor(object):
disable summaries.
save_model_secs: Number of seconds between the creation of model
checkpoints. Defaults to 600 seconds. Pass 0 to disable checkpoints.
recovery_wait_secs: Number of seconds between checks that the model
is ready. Used by supervisors when waiting for a chief supervisor
to initialize or restore the model. Defaults to 30 seconds.
recovery_wait_secs: Number of seconds between checks that the model is
ready. Used by supervisors when waiting for a chief supervisor to
initialize or restore the model. Defaults to 30 seconds.
stop_grace_secs: Grace period, in seconds, given to running threads to
stop when `stop()` is called. Defaults to 120 seconds.
checkpoint_basename: The basename for checkpoint saving.
session_manager: `SessionManager`, which manages Session creation and
recovery. If it is `None`, a default `SessionManager` will be created
with the set of arguments passed in for backwards compatibility.
summary_writer: `SummaryWriter` to use or `USE_DEFAULT`. Can be `None`
to indicate that no summaries should be written.
init_fn: Optional callable used to initialize the model. Called
after the optional `init_op` is called. The callable must accept one
argument, the session being initialized.
summary_writer: `SummaryWriter` to use or `USE_DEFAULT`. Can be `None` to
indicate that no summaries should be written.
init_fn: Optional callable used to initialize the model. Called after the
optional `init_op` is called. The callable must accept one argument,
the session being initialized.
local_init_run_options: RunOptions to be passed as the SessionManager
local_init_run_options parameter.
|
||||
|
||||
@ -397,12 +395,11 @@ class Supervisor(object):
|
||||
"""Initializes ready_op.
|
||||
|
||||
Args:
|
||||
ready_op: `Tensor` to check if the model is initialized.
|
||||
If it's set to USE_DEFAULT, creates an op that checks all
|
||||
the variables are initialized.
|
||||
ready_op: `Tensor` to check if the model is initialized. If it's set to
|
||||
USE_DEFAULT, creates an op that checks all the variables are
|
||||
initialized.
|
||||
ready_for_local_init_op: `Tensor` to check if the model is ready to run
|
||||
local_init_op.
|
||||
If it's set to USE_DEFAULT, creates an op that checks all
|
||||
local_init_op. If it's set to USE_DEFAULT, creates an op that checks all
|
||||
the global variables are initialized.
|
||||
"""
|
||||
if ready_op is Supervisor.USE_DEFAULT:
|
||||
@ -440,9 +437,9 @@ class Supervisor(object):
|
||||
|
||||
Args:
|
||||
local_init_op: `Operation` run for every new supervisor instance. If set
|
||||
to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP
|
||||
collection. If the collection is empty, create an op that initializes
|
||||
all local variables and all tables.
|
||||
to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP
|
||||
collection. If the collection is empty, create an op that initializes
|
||||
all local variables and all tables.
|
||||
"""
|
||||
if local_init_op is Supervisor.USE_DEFAULT:
|
||||
local_init_op = self._get_first_op_from_collection(
|
||||
@ -461,8 +458,8 @@ class Supervisor(object):
|
||||
"""Initializes saver.
|
||||
|
||||
Args:
|
||||
saver: A `Saver` object. If set to USE_DEFAULT, create one that
|
||||
saves all the variables.
|
||||
saver: A `Saver` object. If set to USE_DEFAULT, create one that saves all
|
||||
the variables.
|
||||
"""
|
||||
if saver is Supervisor.USE_DEFAULT:
|
||||
saver = self._get_first_op_from_collection(ops.GraphKeys.SAVERS)
|
||||
@ -475,8 +472,8 @@ class Supervisor(object):
|
||||
"""Initializes summary_op.
|
||||
|
||||
Args:
|
||||
summary_op: An Operation that returns a Summary for the event logs.
|
||||
If set to USE_DEFAULT, create an op that merges all the summaries.
|
||||
summary_op: An Operation that returns a Summary for the event logs. If set
|
||||
to USE_DEFAULT, create an op that merges all the summaries.
|
||||
"""
|
||||
if summary_op is Supervisor.USE_DEFAULT:
|
||||
summary_op = self._get_first_op_from_collection(ops.GraphKeys.SUMMARY_OP)
|
||||
@ -490,8 +487,8 @@ class Supervisor(object):
|
||||
"""Initializes global_step.
|
||||
|
||||
Args:
|
||||
global_step: An integer Tensor of size 1 that counts steps. If
|
||||
set to USE_DEFAULT, creates global_step tensor.
|
||||
global_step: An integer Tensor of size 1 that counts steps. If set to
|
||||
USE_DEFAULT, creates global_step tensor.
|
||||
"""
|
||||
if global_step is Supervisor.USE_DEFAULT:
|
||||
global_step = self._get_first_op_from_collection(
|
||||
@ -630,8 +627,9 @@ class Supervisor(object):
|
||||
"""Writes graph_def to `logdir` and adds it to summary if applicable."""
|
||||
assert self._is_chief
|
||||
if self._logdir:
|
||||
training_util.write_graph(self._graph.as_graph_def(add_shapes=True),
|
||||
self._logdir, "graph.pbtxt")
|
||||
training_util.write_graph(
|
||||
self._graph.as_graph_def(add_shapes=True), self._logdir,
|
||||
"graph.pbtxt")
|
||||
if self._summary_writer and not self._graph_added_to_summary:
|
||||
self._summary_writer.add_graph(self._graph)
|
||||
self._summary_writer.add_meta_graph(self._meta_graph_def)
|
||||
@ -675,8 +673,7 @@ class Supervisor(object):
|
||||
# if there is no step value.
|
||||
current_step = training_util.global_step(sess, self._global_step)
|
||||
self._summary_writer.add_session_log(
|
||||
SessionLog(status=SessionLog.START),
|
||||
current_step)
|
||||
SessionLog(status=SessionLog.START), current_step)
|
||||
|
||||
threads = []
|
||||
if self._save_summaries_secs and self._summary_writer:
|
||||
@ -690,7 +687,9 @@ class Supervisor(object):
|
||||
t.start()
|
||||
return threads
|
||||
|
||||
def prepare_or_wait_for_session(self, master="", config=None,
|
||||
def prepare_or_wait_for_session(self,
|
||||
master="",
|
||||
config=None,
|
||||
wait_for_checkpoint=False,
|
||||
max_wait_secs=7200,
|
||||
start_standard_services=True):
|
||||
@ -702,10 +701,10 @@ class Supervisor(object):
manager to start the standard services.

Args:
master: name of the TensorFlow master to use. See the `tf.Session`
constructor for how this is interpreted.
config: Optional ConfigProto proto used to configure the session,
which is passed as-is to create the session.
master: name of the TensorFlow master to use. See the
`tf.compat.v1.Session` constructor for how this is interpreted.
config: Optional ConfigProto proto used to configure the session, which is
passed as-is to create the session.
wait_for_checkpoint: Whether we should wait for the availability of a
checkpoint before creating Session. Defaults to False.
max_wait_secs: Maximum time to wait for the session to become available.
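A hedged sketch of the chief/worker split this method supports (the `is_chief` flag, ConfigProto options and master address are illustrative assumptions, not part of this change):

```python
import tensorflow as tf

is_chief = True  # whether this task should initialize the model
sv = tf.compat.v1.train.Supervisor(logdir="/tmp/sv_logs", is_chief=is_chief)
config = tf.compat.v1.ConfigProto(allow_soft_placement=True)

# The chief runs the init ops or restores a checkpoint; non-chief workers
# block in wait_for_session() until the chief reports the model as ready.
sess = sv.prepare_or_wait_for_session(
    master="",            # "" means in-process master; a grpc:// target in a cluster
    config=config,
    wait_for_checkpoint=False,
    max_wait_secs=7200)
```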
@ -724,18 +723,22 @@ class Supervisor(object):
if self._is_chief:
|
||||
sess = self._session_manager.prepare_session(
|
||||
master, init_op=self.init_op, saver=self.saver,
|
||||
checkpoint_dir=self._logdir, wait_for_checkpoint=wait_for_checkpoint,
|
||||
max_wait_secs=max_wait_secs, config=config,
|
||||
init_feed_dict=self._init_feed_dict, init_fn=self._init_fn)
|
||||
master,
|
||||
init_op=self.init_op,
|
||||
saver=self.saver,
|
||||
checkpoint_dir=self._logdir,
|
||||
wait_for_checkpoint=wait_for_checkpoint,
|
||||
max_wait_secs=max_wait_secs,
|
||||
config=config,
|
||||
init_feed_dict=self._init_feed_dict,
|
||||
init_fn=self._init_fn)
|
||||
self._write_graph()
|
||||
if start_standard_services:
|
||||
logging.info("Starting standard services.")
|
||||
self.start_standard_services(sess)
|
||||
else:
|
||||
sess = self._session_manager.wait_for_session(master,
|
||||
config=config,
|
||||
max_wait_secs=max_wait_secs)
|
||||
sess = self._session_manager.wait_for_session(
|
||||
master, config=config, max_wait_secs=max_wait_secs)
|
||||
if start_standard_services:
|
||||
logging.info("Starting queue runners.")
|
||||
self.start_queue_runners(sess)
|
||||
@ -772,8 +775,8 @@ class Supervisor(object):
|
||||
queue_runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
|
||||
threads = []
|
||||
for qr in queue_runners:
|
||||
threads.extend(qr.create_threads(sess, coord=self._coord, daemon=True,
|
||||
start=True))
|
||||
threads.extend(
|
||||
qr.create_threads(sess, coord=self._coord, daemon=True, start=True))
|
||||
return threads
|
||||
|
||||
def loop(self, timer_interval_secs, target, args=None, kwargs=None):
@ -795,8 +798,12 @@ class Supervisor(object):
Returns:
The started thread.
"""
looper = coordinator.LooperThread(self._coord, timer_interval_secs,
|
||||
target=target, args=args, kwargs=kwargs)
|
||||
looper = coordinator.LooperThread(
|
||||
self._coord,
|
||||
timer_interval_secs,
|
||||
target=target,
|
||||
args=args,
|
||||
kwargs=kwargs)
|
||||
looper.start()
|
||||
return looper
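A short usage sketch (the reporting callback and the 60-second interval are illustrative assumptions):

```python
# Assumes `sv` is a tf.compat.v1.train.Supervisor and `sess` an open session.
def report_progress(session):
  print("global step:", session.run(sv.global_step))

# Runs report_progress(sess) roughly every 60 seconds until sv.stop() is
# called; the returned LooperThread is joined by the coordinator on stop.
looper = sv.loop(60, report_progress, args=(sess,))
```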
@ -812,13 +819,13 @@ class Supervisor(object):
|
||||
threads: Optional list of threads to join with the coordinator. If
|
||||
`None`, defaults to the threads running the standard services, the
|
||||
threads started for `QueueRunners`, and the threads started by the
|
||||
`loop()` method. To wait on additional threads, pass the
|
||||
list in this parameter.
|
||||
`loop()` method. To wait on additional threads, pass the list in this
|
||||
parameter.
|
||||
close_summary_writer: Whether to close the `summary_writer`. Defaults to
|
||||
`True` if the summary writer was created by the supervisor, `False`
|
||||
otherwise.
|
||||
ignore_live_threads: If `True` ignores threads that remain running after
|
||||
a grace period when joining threads via the coordinator, instead of
|
||||
ignore_live_threads: If `True` ignores threads that remain running after a
|
||||
grace period when joining threads via the coordinator, instead of
|
||||
raising a RuntimeError.
|
||||
"""
|
||||
self._coord.request_stop()
|
||||
@ -926,7 +933,9 @@ class Supervisor(object):
|
||||
|
||||
# pylint: disable=g-doc-return-or-yield,broad-except
|
||||
@contextlib.contextmanager
|
||||
def managed_session(self, master="", config=None,
|
||||
def managed_session(self,
|
||||
master="",
|
||||
config=None,
|
||||
start_standard_services=True,
|
||||
close_summary_writer=True):
|
||||
"""Returns a context manager for a managed session.
|
||||
@ -940,7 +949,7 @@ class Supervisor(object):
|
||||
|
||||
```python
|
||||
def train():
|
||||
sv = tf.train.Supervisor(...)
|
||||
sv = tf.compat.v1.train.Supervisor(...)
|
||||
with sv.managed_session(<master>) as sess:
|
||||
for step in xrange(..):
|
||||
if sv.should_stop():
|
||||
@ -973,14 +982,14 @@ class Supervisor(object):
|
||||
the training loop and are considered normal termination.
|
||||
|
||||
Args:
|
||||
master: name of the TensorFlow master to use. See the `tf.Session`
|
||||
constructor for how this is interpreted.
|
||||
config: Optional `ConfigProto` proto used to configure the session.
|
||||
Passed as-is to create the session.
|
||||
start_standard_services: Whether to start the standard services,
|
||||
such as checkpoint, summary and step counter.
|
||||
close_summary_writer: Whether to close the summary writer when
|
||||
closing the session. Defaults to True.
|
||||
master: name of the TensorFlow master to use. See the
|
||||
`tf.compat.v1.Session` constructor for how this is interpreted.
|
||||
config: Optional `ConfigProto` proto used to configure the session. Passed
|
||||
as-is to create the session.
|
||||
start_standard_services: Whether to start the standard services, such as
|
||||
checkpoint, summary and step counter.
|
||||
close_summary_writer: Whether to close the summary writer when closing the
|
||||
session. Defaults to True.
|
||||
|
||||
Returns:
|
||||
A context manager that yields a `Session` restored from the latest
|
||||
@ -989,7 +998,8 @@ class Supervisor(object):
|
||||
"""
|
||||
try:
|
||||
sess = self.prepare_or_wait_for_session(
|
||||
master=master, config=config,
|
||||
master=master,
|
||||
config=config,
|
||||
start_standard_services=start_standard_services)
|
||||
yield sess
|
||||
except Exception as e:
|
||||
@ -1011,6 +1021,7 @@ class Supervisor(object):
|
||||
except Exception:
|
||||
# Silently ignore exceptions raised by close().
|
||||
pass
|
||||
|
||||
# pylint: enable=g-doc-return-or-yield,broad-except
|
||||
|
||||
|
||||
@ -1030,8 +1041,8 @@ class SVSummaryThread(coordinator.LooperThread):
|
||||
|
||||
def run_loop(self):
|
||||
if self._sv.global_step is not None:
|
||||
summary_strs, global_step = self._sess.run([self._sv.summary_op,
|
||||
self._sv.global_step])
|
||||
summary_strs, global_step = self._sess.run(
|
||||
[self._sv.summary_op, self._sv.global_step])
|
||||
else:
|
||||
summary_strs = self._sess.run(self._sv.summary_op)
|
||||
global_step = None
|
||||
@ -1063,8 +1074,7 @@ class SVStepCounterThread(coordinator.LooperThread):
|
||||
|
||||
def start_loop(self):
|
||||
self._last_time = time.time()
|
||||
self._last_step = training_util.global_step(
|
||||
self._sess, self._step_counter)
|
||||
self._last_step = training_util.global_step(self._sess, self._step_counter)
|
||||
|
||||
def run_loop(self):
|
||||
# Count the steps.
|
||||
@ -1080,12 +1090,13 @@ class SVStepCounterThread(coordinator.LooperThread):
|
||||
steps_per_sec = added_steps / elapsed_time
|
||||
else:
|
||||
steps_per_sec = float("inf")
|
||||
summary = Summary(value=[Summary.Value(tag=self._summary_tag,
|
||||
simple_value=steps_per_sec)])
|
||||
summary = Summary(value=[
|
||||
Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
|
||||
])
|
||||
if self._sv.summary_writer:
|
||||
self._sv.summary_writer.add_summary(summary, current_step)
|
||||
logging.log_first_n(logging.INFO, "%s: %g", 10,
|
||||
self._summary_tag, steps_per_sec)
|
||||
logging.log_first_n(logging.INFO, "%s: %g", 10, self._summary_tag,
|
||||
steps_per_sec)
|
||||
|
||||
|
||||
class SVTimerCheckpointThread(coordinator.LooperThread):
|
||||
@ -1104,13 +1115,13 @@ class SVTimerCheckpointThread(coordinator.LooperThread):
|
||||
|
||||
def run_loop(self):
|
||||
logging.info("Saving checkpoint to path %s", self._sv.save_path)
|
||||
self._sv.saver.save(self._sess, self._sv.save_path,
|
||||
global_step=self._sv.global_step)
|
||||
self._sv.saver.save(
|
||||
self._sess, self._sv.save_path, global_step=self._sv.global_step)
|
||||
if self._sv.summary_writer and self._sv.global_step is not None:
|
||||
current_step = training_util.global_step(self._sess, self._sv.global_step)
|
||||
self._sv.summary_writer.add_session_log(
|
||||
SessionLog(status=SessionLog.CHECKPOINT,
|
||||
checkpoint_path=self._sv.save_path),
|
||||
SessionLog(
|
||||
status=SessionLog.CHECKPOINT, checkpoint_path=self._sv.save_path),
|
||||
current_step)
|
||||
|
||||
|
||||
|
@ -110,7 +110,7 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
|
||||
# Note that if you want to have 2 backup replicas, you can change
|
||||
# total_num_replicas=52 and make sure this number matches how many physical
|
||||
# replicas you started in your job.
|
||||
opt = tf.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=50,
|
||||
opt = tf.compat.v1.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=50,
|
||||
total_num_replicas=50)
|
||||
|
||||
# Some models have startup_delays to help stabilize the model but when using
|
||||
|
@ -38,11 +38,9 @@ from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import serialization
|
||||
from tensorflow.python.util import tf_decorator
|
||||
|
||||
|
||||
# Key where the object graph proto is saved in a TensorBundle
|
||||
OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"
|
||||
|
||||
|
||||
# A key indicating a variable's value in an object's checkpointed Tensors
|
||||
# (Trackable._gather_saveables_for_checkpoint). If this is the only key and
|
||||
# the object has no dependencies, then its value may be restored on object
|
||||
@ -74,8 +72,7 @@ class CheckpointInitialValue(ops.Tensor):
|
||||
"""
|
||||
|
||||
def __init__(self, checkpoint_position, shape=None):
|
||||
self.wrapped_value = checkpoint_position.value_tensors()[
|
||||
VARIABLE_VALUE_KEY]
|
||||
self.wrapped_value = checkpoint_position.value_tensors()[VARIABLE_VALUE_KEY]
|
||||
if shape:
|
||||
# We need to set the static shape information on the initializer if
|
||||
# possible so we don't get a variable with an unknown shape.
|
||||
@ -97,8 +94,8 @@ class NoRestoreSaveable(saveable_object.SaveableObject):
|
||||
"""Embeds a tensor in a checkpoint with no restore ops."""
|
||||
|
||||
def __init__(self, tensor, name, dtype=None, device=None):
|
||||
spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype,
|
||||
device=device)
|
||||
spec = saveable_object.SaveSpec(
|
||||
tensor, "", name, dtype=dtype, device=device)
|
||||
super(NoRestoreSaveable, self).__init__(tensor, [spec], name)
|
||||
|
||||
def restore(self, restored_tensors, restored_shapes):
|
||||
@ -123,7 +120,8 @@ class PythonStateSaveable(saveable_object.SaveableObject):
|
||||
"""Create a new `SaveableObject` which freezes current state as a constant.
|
||||
|
||||
Used when executing eagerly to embed the current state as a constant, or
|
||||
when creating a static tf.train.Saver with the frozen current Python state.
|
||||
when creating a static tf.compat.v1.train.Saver with the frozen current
|
||||
Python state.
|
||||
|
||||
Returns:
|
||||
A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
|
||||
@ -140,24 +138,26 @@ class PythonStringStateSaveable(PythonStateSaveable):
|
||||
|
||||
Args:
|
||||
name: The checkpoint key to write to.
|
||||
state_callback: A function taking no arguments which returns a
|
||||
string. This function is run every time a checkpoint is written.
|
||||
state_callback: A function taking no arguments which returns a string.
|
||||
This function is run every time a checkpoint is written.
|
||||
restore_callback: A function taking a Python string, used to restore
|
||||
state. Optional; defaults to doing nothing, in which case it is ignored
|
||||
by status assertions such as assert_consumed().
|
||||
"""
|
||||
self._has_trivial_state_callback = (restore_callback is None)
|
||||
|
||||
def _state_callback_wrapper():
|
||||
with ops.init_scope():
|
||||
return state_callback()
|
||||
|
||||
self._state_callback = _state_callback_wrapper
|
||||
self._restore_callback = restore_callback
|
||||
with ops.device("/cpu:0"):
|
||||
self._save_string = constant_op.constant("", dtype=dtypes.string)
|
||||
spec = saveable_object.SaveSpec(
|
||||
self._save_string, "", name, dtype=dtypes.string)
|
||||
super(PythonStringStateSaveable, self).__init__(
|
||||
self._save_string, [spec], name)
|
||||
super(PythonStringStateSaveable, self).__init__(self._save_string, [spec],
|
||||
name)
|
||||
|
||||
@property
|
||||
def optional_restore(self):
|
||||
@ -170,8 +170,10 @@ class PythonStringStateSaveable(PythonStateSaveable):
|
||||
|
||||
def freeze(self):
|
||||
"""Create a frozen `SaveableObject` which saves the current state."""
|
||||
|
||||
def _constant_state():
|
||||
return constant_op.constant(self._state_callback(), dtype=dtypes.string)
|
||||
|
||||
return NoRestoreSaveable(
|
||||
tensor=_constant_state,
|
||||
dtype=dtypes.string,
|
||||
@ -217,6 +219,7 @@ class CheckpointPosition(object):
|
||||
|
||||
Args:
|
||||
trackable: The object to record a correspondence for.
|
||||
|
||||
Returns:
|
||||
True if this is a new assignment, False if this object has already been
|
||||
mapped to a checkpointed `Object` proto.
|
||||
@ -263,21 +266,21 @@ class CheckpointPosition(object):
|
||||
# consistent (if the dependency DAG is not a tree then there are
|
||||
# multiple paths to the same object).
|
||||
if current_assignment is not trackable:
|
||||
logging.warning(
|
||||
("Inconsistent references when loading the checkpoint into this "
|
||||
"object graph. Either the Trackable object references in the "
|
||||
"Python program have changed in an incompatible way, or the "
|
||||
"checkpoint was generated in an incompatible program.\n\nTwo "
|
||||
"checkpoint references resolved to different objects (%s and %s).")
|
||||
% (current_assignment, trackable))
|
||||
logging.warning((
|
||||
"Inconsistent references when loading the checkpoint into this "
|
||||
"object graph. Either the Trackable object references in the "
|
||||
"Python program have changed in an incompatible way, or the "
|
||||
"checkpoint was generated in an incompatible program.\n\nTwo "
|
||||
"checkpoint references resolved to different objects (%s and %s)."),
|
||||
current_assignment, trackable)
|
||||
return False # Not a new assignment
|
||||
|
||||
def is_simple_variable(self):
|
||||
"""Determine whether this value is restorable with a Tensor initializer."""
|
||||
attributes = self.object_proto.attributes
|
||||
return (len(attributes) == 1
|
||||
and attributes[0].name == VARIABLE_VALUE_KEY
|
||||
and not self.object_proto.children)
|
||||
return (len(attributes) == 1 and
|
||||
attributes[0].name == VARIABLE_VALUE_KEY and
|
||||
not self.object_proto.children)
|
||||
|
||||
def value_tensors(self):
|
||||
"""Create value `Tensor`s for this object's attributes.
|
||||
@ -335,8 +338,9 @@ class CheckpointPosition(object):
|
||||
# If we've already created and cached a SaveableObject for this
|
||||
# attribute, we can re-use it to avoid re-creating some ops when graph
|
||||
# building.
|
||||
saveable_list = saveables_cache.get(
|
||||
self.trackable, {}).get(serialized_tensor.name, (None,))
|
||||
saveable_list = saveables_cache.get(self.trackable,
|
||||
{}).get(serialized_tensor.name,
|
||||
(None,))
|
||||
if len(saveable_list) == 1:
|
||||
# Almost every attribute will have exactly one SaveableObject.
|
||||
saveable, = saveable_list
|
||||
@ -370,8 +374,8 @@ class CheckpointPosition(object):
|
||||
else:
|
||||
saveable = saveable_factory
|
||||
if saveables_cache is not None:
|
||||
saveables_cache.setdefault(
|
||||
self.trackable, {})[serialized_tensor.name] = [saveable]
|
||||
saveables_cache.setdefault(self.trackable,
|
||||
{})[serialized_tensor.name] = [saveable]
|
||||
if isinstance(saveable, PythonStateSaveable):
|
||||
python_saveables.append(saveable)
|
||||
else:
|
||||
@ -388,11 +392,10 @@ class CheckpointPosition(object):
|
||||
A list of operations when graph building, or an empty list when executing
|
||||
eagerly.
|
||||
"""
|
||||
(restore_ops,
|
||||
tensor_saveables,
|
||||
(restore_ops, tensor_saveables,
|
||||
python_saveables) = self._gather_ops_or_named_saveables()
|
||||
restore_ops.extend(self._checkpoint.restore_saveables(
|
||||
tensor_saveables, python_saveables))
|
||||
restore_ops.extend(
|
||||
self._checkpoint.restore_saveables(tensor_saveables, python_saveables))
|
||||
return restore_ops
|
||||
|
||||
@property
|
||||
@ -416,13 +419,11 @@ class CheckpointPosition(object):
|
||||
|
||||
|
||||
_DeferredSlotVariableRestoration = collections.namedtuple(
|
||||
"_DeferredSlotVariableRestoration",
|
||||
[
|
||||
"_DeferredSlotVariableRestoration", [
|
||||
"original_variable",
|
||||
"slot_variable_id",
|
||||
"slot_name",
|
||||
]
|
||||
)
|
||||
])
|
||||
|
||||
_SlotVariableRestoration = collections.namedtuple(
|
||||
"_SlotVariableRestoration",
|
||||
@ -446,6 +447,7 @@ def no_automatic_dependency_tracking(method):

Args:
method: The method to decorate.

Returns:
A decorated method which sets and un-sets automatic dependency tracking for
the object the method is called on (not thread safe).
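A hedged sketch of how such a decorator is typically applied inside the library (the class and the internal import paths are illustrative assumptions, not part of this diff):

```python
from tensorflow.python.training.tracking import base as trackable_base
from tensorflow.python.training.tracking import tracking


class ScratchHolder(tracking.AutoTrackable):

  @trackable_base.no_automatic_dependency_tracking
  def _build_scratch_state(self):
    # Attributes assigned while the decorated method runs are not recorded as
    # checkpoint dependencies, so this temporary buffer is never saved.
    self._scratch = []
```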
@ -595,16 +597,21 @@ class Trackable(object):
|
||||
|
||||
Args:
|
||||
name: The local name of the dependency.
|
||||
|
||||
Returns:
|
||||
A `Trackable` object, or `None` if no dependency by this name was
|
||||
found.
|
||||
"""
|
||||
return self._self_unconditional_dependency_names.get(name, None)
|
||||
|
||||
def _add_variable_with_custom_getter(
|
||||
self, name, shape=None, dtype=dtypes.float32,
|
||||
initializer=None, getter=None, overwrite=False,
|
||||
**kwargs_for_getter):
|
||||
def _add_variable_with_custom_getter(self,
|
||||
name,
|
||||
shape=None,
|
||||
dtype=dtypes.float32,
|
||||
initializer=None,
|
||||
getter=None,
|
||||
overwrite=False,
|
||||
**kwargs_for_getter):
|
||||
"""Restore-on-create for a variable be saved with this `Trackable`.
|
||||
|
||||
If the user has requested that this object or another `Trackable` which
|
||||
@ -640,11 +647,9 @@ class Trackable(object):
|
||||
name=name, shape=shape)
|
||||
else:
|
||||
checkpoint_initializer = None
|
||||
if (checkpoint_initializer is not None
|
||||
and not (
|
||||
isinstance(initializer, CheckpointInitialValue)
|
||||
and (initializer.restore_uid
|
||||
> checkpoint_initializer.restore_uid))):
|
||||
if (checkpoint_initializer is not None and
|
||||
not (isinstance(initializer, CheckpointInitialValue) and
|
||||
(initializer.restore_uid > checkpoint_initializer.restore_uid))):
|
||||
# If multiple Trackable objects are "creating" the same variable
|
||||
# via the magic of custom getters, the one with the highest restore UID
|
||||
# (the one called last) has to make the final initializer. If another
|
||||
@ -654,7 +659,10 @@ class Trackable(object):
|
||||
initializer = checkpoint_initializer
|
||||
shape = None
|
||||
new_variable = getter(
|
||||
name=name, shape=shape, dtype=dtype, initializer=initializer,
|
||||
name=name,
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
initializer=initializer,
|
||||
**kwargs_for_getter)
|
||||
|
||||
# If we set an initializer and the variable processed it, tracking will not
|
||||
@ -662,8 +670,7 @@ class Trackable(object):
|
||||
# is a non-trivial restoration queued, it will handle that. This also
|
||||
# handles slot variables.
|
||||
if not overwrite or isinstance(new_variable, Trackable):
|
||||
return self._track_trackable(new_variable, name=name,
|
||||
overwrite=overwrite)
|
||||
return self._track_trackable(new_variable, name=name, overwrite=overwrite)
|
||||
else:
|
||||
# TODO(allenl): Some variable types are not yet supported. Remove this
|
||||
# fallback once all get_variable() return types are Trackable.
|
||||
@ -681,6 +688,7 @@ class Trackable(object):
|
||||
name: The object-local name of the dependency holding the variable's
|
||||
value.
|
||||
shape: The shape of the variable being loaded into.
|
||||
|
||||
Returns:
|
||||
An callable for use as a variable's initializer/initial_value, or None if
|
||||
one should not be set (either because there was no variable with this name
|
||||
@ -718,8 +726,8 @@ class Trackable(object):
|
||||
|
||||
Args:
|
||||
trackable: A `Trackable` which this object depends on.
|
||||
name: A local name for `trackable`, used for loading checkpoints into
|
||||
the correct objects.
|
||||
name: A local name for `trackable`, used for loading checkpoints into the
|
||||
correct objects.
|
||||
overwrite: Boolean, whether silently replacing dependencies is OK. Used
|
||||
for __setattr__, where throwing an error on attribute reassignment would
|
||||
be inappropriate.
|
||||
@ -734,13 +742,11 @@ class Trackable(object):
|
||||
"""
|
||||
self._maybe_initialize_trackable()
|
||||
if not isinstance(trackable, Trackable):
|
||||
raise TypeError(
|
||||
("Trackable._track_trackable() passed type %s, not a "
|
||||
"Trackable.") % (type(trackable),))
|
||||
raise TypeError(("Trackable._track_trackable() passed type %s, not a "
|
||||
"Trackable.") % (type(trackable),))
|
||||
new_reference = TrackableReference(name=name, ref=trackable)
|
||||
current_object = self._lookup_dependency(name)
|
||||
if (current_object is not None
|
||||
and current_object is not trackable):
|
||||
if (current_object is not None and current_object is not trackable):
|
||||
if not overwrite:
|
||||
raise ValueError(
|
||||
("Called Trackable._track_trackable() with name='%s', "
|
||||
@ -755,8 +761,7 @@ class Trackable(object):
|
||||
index] = new_reference
|
||||
elif current_object is None:
|
||||
self._self_unconditional_checkpoint_dependencies.append(new_reference)
|
||||
self._handle_deferred_dependencies(
|
||||
name=name, trackable=trackable)
|
||||
self._handle_deferred_dependencies(name=name, trackable=trackable)
|
||||
self._self_unconditional_dependency_names[name] = trackable
|
||||
return trackable
|
||||
|
||||
@ -780,8 +785,7 @@ class Trackable(object):
|
||||
Args:
|
||||
name: The name of the dependency within this object (`self`), used to
|
||||
match `trackable` with values saved in a checkpoint.
|
||||
trackable: The Trackable object to restore (inheriting from
|
||||
`Trackable`).
|
||||
trackable: The Trackable object to restore (inheriting from `Trackable`).
|
||||
"""
|
||||
self._maybe_initialize_trackable()
|
||||
trackable._maybe_initialize_trackable() # pylint: disable=protected-access
|
||||
@ -809,15 +813,15 @@ class Trackable(object):
|
||||
restore_ops = []
|
||||
while visit_queue:
|
||||
current_position = visit_queue.popleft()
|
||||
restore_ops.extend(nest.flatten(
|
||||
current_position.trackable # pylint: disable=protected-access
|
||||
._single_restoration_from_checkpoint_position(
|
||||
checkpoint_position=current_position,
|
||||
visit_queue=visit_queue)))
|
||||
restore_ops.extend(
|
||||
nest.flatten(current_position.trackable # pylint: disable=protected-access
|
||||
._single_restoration_from_checkpoint_position(
|
||||
checkpoint_position=current_position,
|
||||
visit_queue=visit_queue)))
|
||||
return restore_ops
|
||||
|
||||
def _single_restoration_from_checkpoint_position(
|
||||
self, checkpoint_position, visit_queue):
|
||||
def _single_restoration_from_checkpoint_position(self, checkpoint_position,
|
||||
visit_queue):
|
||||
"""Restore this object, and either queue its dependencies or defer them."""
|
||||
self._maybe_initialize_trackable()
|
||||
checkpoint = checkpoint_position.checkpoint
|
||||
@ -831,14 +835,13 @@ class Trackable(object):
|
||||
restore_ops = ()
|
||||
for child in checkpoint_position.object_proto.children:
|
||||
child_position = CheckpointPosition(
|
||||
checkpoint=checkpoint,
|
||||
proto_id=child.node_id)
|
||||
checkpoint=checkpoint, proto_id=child.node_id)
|
||||
local_object = self._lookup_dependency(child.local_name)
|
||||
if local_object is None:
|
||||
# We don't yet have a dependency registered with this name. Save it
|
||||
# in case we do.
|
||||
self._deferred_dependencies.setdefault(child.local_name, []).append(
|
||||
child_position)
|
||||
self._deferred_dependencies.setdefault(child.local_name,
|
||||
[]).append(child_position)
|
||||
else:
|
||||
if child_position.bind_object(trackable=local_object):
|
||||
# This object's correspondence is new, so dependencies need to be
|
||||
@ -853,7 +856,8 @@ class Trackable(object):
|
||||
|
||||
Keys in the returned dictionary are local to this object and in a separate
|
||||
namespace from dependencies. Values may either be `SaveableObject` factories
|
||||
or variables easily converted to `SaveableObject`s (as in `tf.train.Saver`'s
|
||||
or variables easily converted to `SaveableObject`s (as in
|
||||
`tf.compat.v1.train.Saver`'s
|
||||
`var_list` constructor argument).
|
||||
|
||||
`SaveableObjects` have a name set, which Trackable needs to generate
|
||||
@ -861,7 +865,8 @@ class Trackable(object):
|
||||
should return a dictionary of callables which take `name` arguments and
|
||||
return `SaveableObjects` with that name.
|
||||
|
||||
If this object may also be passed to the global-name-based `tf.train.Saver`,
|
||||
If this object may also be passed to the global-name-based
|
||||
`tf.compat.v1.train.Saver`,
|
||||
the returned callables should have a default value for their name argument
|
||||
(i.e. be callable with no arguments).
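A hedged sketch of an override that follows this contract (the wrapper class and the single saved variable are illustrative assumptions):

```python
import tensorflow as tf
from tensorflow.python.training.tracking import base as trackable_base


class WrapsOneValue(trackable_base.Trackable):
  """Illustrative Trackable exposing a single value for checkpointing."""

  def __init__(self):
    self._value = tf.Variable(0.0, name="wrapped_value")

  def _gather_saveables_for_checkpoint(self):
    # Variables are accepted directly, just as in a `var_list` passed to
    # `tf.compat.v1.train.Saver`; a callable taking `name` would also work.
    return {"value": self._value}
```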
|
||||
|
||||
@ -884,6 +889,7 @@ class Trackable(object):
|
||||
except NotImplementedError:
|
||||
return {}
|
||||
weak_self = weakref.ref(self)
|
||||
|
||||
def _state_callback():
|
||||
"""Serializes `self.get_config()` for saving."""
|
||||
dereferenced_self = weak_self()
|
||||
@ -898,9 +904,12 @@ class Trackable(object):
|
||||
return ""
|
||||
else:
|
||||
return ""
|
||||
return {OBJECT_CONFIG_JSON_KEY: functools.partial(
|
||||
PythonStringStateSaveable,
|
||||
state_callback=_state_callback)}
|
||||
|
||||
return {
|
||||
OBJECT_CONFIG_JSON_KEY:
|
||||
functools.partial(
|
||||
PythonStringStateSaveable, state_callback=_state_callback)
|
||||
}
|
||||
|
||||
def _list_functions_for_serialization(self):
|
||||
"""Lists the functions of this trackable to serialize.
|
||||
|
@ -61,8 +61,8 @@ class _CheckpointRestoreCoordinator(object):
|
||||
"""Specify the checkpoint being loaded.
|
||||
|
||||
Args:
|
||||
object_graph_proto: The TrackableObjectGraph protocol buffer
|
||||
associated with this checkpoint.
|
||||
object_graph_proto: The TrackableObjectGraph protocol buffer associated
|
||||
with this checkpoint.
|
||||
save_path: A string, the path to the checkpoint, as returned by
|
||||
`tf.train.latest_checkpoint`.
|
||||
save_path_tensor: A string `Tensor` which contains or will be fed the save
|
||||
@ -142,12 +142,10 @@ class _CheckpointRestoreCoordinator(object):
|
||||
"""
|
||||
restore_ops = []
|
||||
# Eagerly run restorations for Python state.
|
||||
reader = pywrap_tensorflow.NewCheckpointReader(
|
||||
self.save_path_string)
|
||||
reader = pywrap_tensorflow.NewCheckpointReader(self.save_path_string)
|
||||
for saveable in python_saveables:
|
||||
spec_names = [spec.name for spec in saveable.specs]
|
||||
saveable.python_restore(
|
||||
[reader.get_tensor(name) for name in spec_names])
|
||||
saveable.python_restore([reader.get_tensor(name) for name in spec_names])
|
||||
|
||||
# If we have new SaveableObjects, extract and cache restore ops.
|
||||
if tensor_saveables:
|
||||
@ -205,14 +203,13 @@ class _NameBasedRestoreCoordinator(object):
|
||||
# whether it's optional to restore it. If it's optional we don't need
|
||||
# to make assertions fail.
|
||||
if not saveable_factory("").optional_restore:
|
||||
self.unused_attributes.setdefault(trackable, []).append(
|
||||
attribute_name)
|
||||
self.unused_attributes.setdefault(trackable,
|
||||
[]).append(attribute_name)
|
||||
continue
|
||||
else:
|
||||
saveable = saveable_factory
|
||||
names_to_saveables = saveable_object_util.op_list_to_dict(
|
||||
[saveable],
|
||||
convert_variable_to_tensor=False)
|
||||
[saveable], convert_variable_to_tensor=False)
|
||||
for name, op in names_to_saveables.items():
|
||||
for saveable_object in saveable_object_util.saveable_objects_for_op(
|
||||
op=op, name=name):
|
||||
@ -224,8 +221,7 @@ class _NameBasedRestoreCoordinator(object):
|
||||
# run_restore_ops/initialize_or_restore on the status object for name-based
|
||||
# checkpoints.
|
||||
assert context.executing_eagerly()
|
||||
for saveable in self.globally_named_object_attributes(
|
||||
trackable):
|
||||
for saveable in self.globally_named_object_attributes(trackable):
|
||||
restored_tensors = []
|
||||
tensor_missing = False
|
||||
for spec in saveable.specs:
|
||||
@ -248,14 +244,18 @@ class _NameBasedRestoreCoordinator(object):
|
||||
# Ignores values missing from the checkpoint, as with object-based
|
||||
# restore. Status assertions can be used to check exact matches,
|
||||
# although it's unlikely to ever happen for name-based checkpoints.
|
||||
saveable.restore(restored_tensors=restored_tensors,
|
||||
restored_shapes=None)
|
||||
saveable.restore(
|
||||
restored_tensors=restored_tensors, restored_shapes=None)
|
||||
|
||||
|
||||
# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange
|
||||
# or consolidating the implementation with get_variable.
|
||||
def _default_getter(name, shape, dtype, initializer=None,
|
||||
partition_info=None, **kwargs):
|
||||
def _default_getter(name,
|
||||
shape,
|
||||
dtype,
|
||||
initializer=None,
|
||||
partition_info=None,
|
||||
**kwargs):
|
||||
"""A pared-down version of get_variable which does not reuse variables."""
|
||||
dtype = dtypes.as_dtype(dtype)
|
||||
shape_object = tensor_shape.as_shape(shape)
|
||||
@ -263,7 +263,9 @@ def _default_getter(name, shape, dtype, initializer=None,
|
||||
if initializer is None:
|
||||
initializer, initializing_from_value = (
|
||||
variable_scope._get_default_variable_store()._get_default_initializer( # pylint: disable=protected-access
|
||||
name=name, shape=shape_object, dtype=dtype))
|
||||
name=name,
|
||||
shape=shape_object,
|
||||
dtype=dtype))
|
||||
else:
|
||||
initializing_from_value = not callable(initializer)
|
||||
# Same logic as get_variable
|
||||
@ -276,24 +278,33 @@ def _default_getter(name, shape, dtype, initializer=None,
|
||||
# Instantiate initializer if provided initializer is a type object.
|
||||
if isinstance(initializer, type(init_ops.Initializer)):
|
||||
initializer = initializer(dtype=dtype)
|
||||
|
||||
def initial_value():
|
||||
return initializer(
|
||||
shape_object.as_list(), dtype=dtype, partition_info=partition_info)
|
||||
|
||||
return variables.VariableV1(
|
||||
initial_value=initial_value,
|
||||
name=name,
|
||||
dtype=variable_dtype,
|
||||
use_resource=True,
|
||||
**kwargs
|
||||
)
|
||||
**kwargs)
|
||||
|
||||
|
||||
def add_variable(trackable, name, shape=None, dtype=dtypes.float32,
|
||||
initializer=None, trainable=True):
|
||||
def add_variable(trackable,
|
||||
name,
|
||||
shape=None,
|
||||
dtype=dtypes.float32,
|
||||
initializer=None,
|
||||
trainable=True):
|
||||
"""Add a variable to a Trackable with no scope influence."""
|
||||
return trackable._add_variable_with_custom_getter( # pylint: disable=protected-access
|
||||
name=name, shape=shape, dtype=dtype,
|
||||
initializer=initializer, getter=_default_getter, trainable=trainable)
|
||||
name=name,
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
initializer=initializer,
|
||||
getter=_default_getter,
|
||||
trainable=trainable)
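A hedged sketch of this helper in use, mirroring the `save_counter` pattern that appears later in this file (the class name and the internal import paths are illustrative assumptions):

```python
from tensorflow.python.framework import dtypes
from tensorflow.python.training.tracking import base as trackable_base
from tensorflow.python.training.tracking import util as trackable_utils


class Counter(trackable_base.Trackable):
  """Illustrative Trackable that owns one checkpointed variable."""

  def __init__(self):
    # Attaches a dependency named "count" without consulting any variable
    # scope, so checkpoint naming stays purely object-based.
    self.count = trackable_utils.add_variable(
        self, name="count", initializer=0, dtype=dtypes.int64, trainable=False)
```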
|
||||
|
||||
|
||||
def object_metadata(save_path):
|
||||
@ -313,6 +324,7 @@ def object_metadata(save_path):
|
||||
Args:
|
||||
save_path: The path to the checkpoint, as returned by `save` or
|
||||
`tf.train.latest_checkpoint`.
|
||||
|
||||
Returns:
|
||||
A parsed `tf.contrib.checkpoint.TrackableObjectGraph` protocol buffer.
|
||||
Raises:
|
||||
@ -320,16 +332,14 @@ def object_metadata(save_path):
|
||||
"""
|
||||
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
|
||||
try:
|
||||
object_graph_string = reader.get_tensor(
|
||||
base.OBJECT_GRAPH_PROTO_KEY)
|
||||
object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
|
||||
except errors_impl.NotFoundError:
|
||||
raise ValueError(
|
||||
('The specified checkpoint "%s" does not appear to be object-based (it '
|
||||
'is missing the key "%s"). Likely it was created with a name-based '
|
||||
'saver and does not contain an object dependency graph.') % (
|
||||
save_path, base.OBJECT_GRAPH_PROTO_KEY))
|
||||
object_graph_proto = (
|
||||
trackable_object_graph_pb2.TrackableObjectGraph())
|
||||
"saver and does not contain an object dependency graph.") %
|
||||
(save_path, base.OBJECT_GRAPH_PROTO_KEY))
|
||||
object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
|
||||
object_graph_proto.ParseFromString(object_graph_string)
|
||||
return object_graph_proto
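A hedged usage sketch (the checkpoint directory and the internal import path are illustrative assumptions):

```python
import tensorflow as tf
from tensorflow.python.training.tracking import util as trackable_utils

ckpt_path = tf.train.latest_checkpoint("/tmp/tf_ckpts")  # assumed directory
if ckpt_path:
  proto = trackable_utils.object_metadata(ckpt_path)
  # Each node describes one object in the saved dependency graph; children
  # carry the local attribute names used for matching at restore time.
  for node in proto.nodes:
    print([child.local_name for child in node.children])
```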
@ -343,8 +353,8 @@ def list_objects(root_trackable):
|
||||
(i.e. if they would be saved with a checkpoint).
|
||||
|
||||
Args:
|
||||
root_trackable: A `Trackable` object whose dependencies should be
|
||||
flattened.
|
||||
root_trackable: A `Trackable` object whose dependencies should be flattened.
|
||||
|
||||
Returns:
|
||||
A flat list of objects.
|
||||
"""
|
||||
@ -362,12 +372,16 @@ def gather_initializers(root_trackable):
|
||||
|
||||
Args:
|
||||
root_trackable: A `Trackable` object to gather initializers for.
|
||||
|
||||
Returns:
|
||||
A list of initialization ops.
|
||||
"""
|
||||
trackable_objects = list_objects(root_trackable)
|
||||
return [c.initializer for c in trackable_objects
|
||||
if hasattr(c, "initializer") and c.initializer is not None]
|
||||
return [
|
||||
c.initializer
|
||||
for c in trackable_objects
|
||||
if hasattr(c, "initializer") and c.initializer is not None
|
||||
]
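A hedged graph-mode sketch (the object layout and the internal import paths are illustrative assumptions; under eager execution variables are initialized on creation and this call is unnecessary):

```python
import tensorflow as tf
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util as trackable_utils

root = tracking.AutoTrackable()
root.bias = tf.Variable(0.5)

# Collect the initializers of every object reachable from `root` and run them
# in a single call.
with tf.compat.v1.Session() as sess:
  sess.run(trackable_utils.gather_initializers(root))
```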
@tf_contextlib.contextmanager
|
||||
@ -380,7 +394,7 @@ def capture_dependencies(template):
|
||||
object to add dependencies on variables created in a block of code which is
|
||||
not aware of object-based saving (and instead uses variable names
|
||||
heavily). This is how `Template` objects add dependencies on variables and
|
||||
sub-`Template`s. Where possible, use `tf.make_template` directly.
|
||||
sub-`Template`s. Where possible, use `tf.compat.v1.make_template` directly.
|
||||
|
||||
Args:
|
||||
template: The `Template` object to register dependencies with.
|
||||
@ -390,8 +404,11 @@ def capture_dependencies(template):
|
||||
"""
|
||||
name_prefix = template.variable_scope.name
|
||||
|
||||
def _trackable_custom_creator(next_creator, name, initial_value,
|
||||
trackable_parent=None, **kwargs):
|
||||
def _trackable_custom_creator(next_creator,
|
||||
name,
|
||||
initial_value,
|
||||
trackable_parent=None,
|
||||
**kwargs):
|
||||
"""A variable creation hook which adds Trackable dependencies.
|
||||
|
||||
Set for example during a `Template`'s first wrapped function
|
||||
@ -415,21 +432,20 @@ def capture_dependencies(template):
|
||||
initial_value: See `variable_scope.variable_creator_scope`. Taken
|
||||
explicitly so the argument can be re-named and used with
|
||||
`Trackable._add_variable_with_custom_getter`.
|
||||
trackable_parent: If not None, a more deeply nested trackable
|
||||
object and its name prefix which were passed to `capture_dependencies`
|
||||
to add a dependency on (rather than depending on the variable directly).
|
||||
trackable_parent: If not None, a more deeply nested trackable object and
|
||||
its name prefix which were passed to `capture_dependencies` to add a
|
||||
dependency on (rather than depending on the variable directly).
|
||||
**kwargs: Passed through to the next creator.
|
||||
|
||||
Returns:
|
||||
The output of `next_creator`: the fetched/created variable object.
|
||||
"""
|
||||
|
||||
def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
|
||||
inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
|
||||
# we don't want to propagate.
|
||||
return next_creator(
|
||||
initial_value=initializer,
|
||||
name=name,
|
||||
**inner_kwargs)
|
||||
return next_creator(initial_value=initializer, name=name, **inner_kwargs)
|
||||
|
||||
if name is not None and name.startswith(name_prefix):
|
||||
scope_stripped_name = name[len(name_prefix) + 1:]
|
||||
if not trackable_parent:
|
||||
@ -450,8 +466,10 @@ def capture_dependencies(template):
|
||||
name=parent_name_prefix[len(name_prefix) + 1:],
|
||||
overwrite=True)
|
||||
return next_creator(
|
||||
name=name, initial_value=initial_value,
|
||||
trackable_parent=(template, name_prefix), **kwargs)
|
||||
name=name,
|
||||
initial_value=initial_value,
|
||||
trackable_parent=(template, name_prefix),
|
||||
**kwargs)
|
||||
|
||||
with variable_scope.variable_creator_scope(_trackable_custom_creator):
|
||||
yield
|
||||
@ -490,9 +508,8 @@ def streaming_restore(status, session=None):
|
||||
"""When graph building, runs restore ops as soon as they come in.
|
||||
|
||||
Args:
|
||||
status: A _LoadStatus objects from an object-based saver's
|
||||
restore(). Streaming restore from name-based checkpoints is not currently
|
||||
supported.
|
||||
status: A _LoadStatus objects from an object-based saver's restore().
|
||||
Streaming restore from name-based checkpoints is not currently supported.
|
||||
session: A session to run new restore ops in.
|
||||
"""
|
||||
if context.executing_eagerly():
|
||||
@ -553,13 +570,13 @@ class CheckpointLoadStatus(_LoadStatus):
|
||||
if self._checkpoint.slot_restorations:
|
||||
# Sanity check; this collection should be clear if everything has been
|
||||
# restored.
|
||||
raise AssertionError("Unresolved slot restorations: %s" % (
|
||||
self._checkpoint.slot_restorations,))
|
||||
raise AssertionError("Unresolved slot restorations: %s" %
|
||||
(self._checkpoint.slot_restorations,))
|
||||
if self._checkpoint.unused_attributes:
|
||||
raise AssertionError(
|
||||
("Unused attributes in these objects (the attributes exist in the "
|
||||
"checkpoint but not in the objects): %s") % (
|
||||
list(self._checkpoint.unused_attributes.items()),))
|
||||
"checkpoint but not in the objects): %s") %
|
||||
(list(self._checkpoint.unused_attributes.items()),))
|
||||
return self
|
||||
|
||||
def assert_existing_objects_matched(self):
|
||||
@ -581,10 +598,10 @@ class CheckpointLoadStatus(_LoadStatus):
|
||||
"""
|
||||
for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
|
||||
trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
|
||||
if (trackable is not None
|
||||
and trackable._update_uid < self._checkpoint.restore_uid): # pylint: disable=protected-access
|
||||
raise AssertionError(
|
||||
"Object not assigned a value from checkpoint: %s" % (node,))
|
||||
if (trackable is not None and
|
||||
trackable._update_uid < self._checkpoint.restore_uid): # pylint: disable=protected-access
|
||||
raise AssertionError("Object not assigned a value from checkpoint: %s" %
|
||||
(node,))
|
||||
for trackable_object in self._graph_view.list_objects():
|
||||
# Remove data structures that do not contain any variables from
|
||||
# restoration checks.
|
||||
@ -594,14 +611,14 @@ class CheckpointLoadStatus(_LoadStatus):
|
||||
continue
|
||||
self._checkpoint.all_python_objects.add(trackable_object)
|
||||
unused_python_objects = (
|
||||
object_identity.ObjectIdentitySet(self._checkpoint.all_python_objects)
|
||||
- object_identity.ObjectIdentitySet(
|
||||
object_identity.ObjectIdentitySet(self._checkpoint.all_python_objects) -
|
||||
object_identity.ObjectIdentitySet(
|
||||
self._checkpoint.object_by_proto_id.values()))
|
||||
if unused_python_objects:
|
||||
raise AssertionError(
|
||||
("Some Python objects were not bound to checkpointed values, likely "
|
||||
"due to changes in the Python program: %s")
|
||||
% (list(unused_python_objects),))
|
||||
"due to changes in the Python program: %s") %
|
||||
(list(unused_python_objects),))
|
||||
return self
|
||||
|
||||
def assert_nontrivial_match(self):
|
||||
@ -610,8 +627,7 @@ class CheckpointLoadStatus(_LoadStatus):
|
||||
self._checkpoint.all_python_objects.add(trackable_object)
|
||||
if len(self._checkpoint.object_by_proto_id) <= 1:
|
||||
unused_python_objects = (
|
||||
object_identity.ObjectIdentitySet(
|
||||
self._checkpoint.all_python_objects)
|
||||
object_identity.ObjectIdentitySet(self._checkpoint.all_python_objects)
|
||||
- object_identity.ObjectIdentitySet(
|
||||
self._checkpoint.object_by_proto_id.values()))
|
||||
if unused_python_objects:
|
||||
@ -622,8 +638,8 @@ class CheckpointLoadStatus(_LoadStatus):
|
||||
"checkpointed value: %s") % (list(unused_python_objects),))
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Nothing to load. No dependencies have been added to %s yet." % (
|
||||
self._graph_view.root,))
|
||||
"Nothing to load. No dependencies have been added to %s yet." %
|
||||
(self._graph_view.root,))
|
||||
return self
|
||||
|
||||
def run_restore_ops(self, session=None):
|
||||
@ -760,8 +776,8 @@ class NameBasedSaverStatus(_LoadStatus):
|
||||
unused_attributes = dict(self._checkpoint.unused_attributes)
|
||||
if unused_attributes:
|
||||
raise AssertionError(
|
||||
"Some objects had attributes which were not restored: %s"
|
||||
% (unused_attributes,))
|
||||
"Some objects had attributes which were not restored: %s" %
|
||||
(unused_attributes,))
|
||||
for trackable in self._graph_view.list_objects():
|
||||
# pylint: disable=protected-access
|
||||
trackable._maybe_initialize_trackable()
|
||||
@ -799,12 +815,11 @@ class NameBasedSaverStatus(_LoadStatus):
|
||||
continue
|
||||
# pylint: enable=protected-access
|
||||
saveable_objects.extend(
|
||||
self._checkpoint.globally_named_object_attributes(
|
||||
trackable))
|
||||
self._checkpoint.globally_named_object_attributes(trackable))
|
||||
return saveable_objects
|
||||
|
||||
def run_restore_ops(self, session=None):
|
||||
"""Load the name-based training checkpoint using a new `tf.train.Saver`."""
|
||||
"""Load the name-based checkpoint using a new `tf.compat.v1.train.Saver`."""
|
||||
if context.executing_eagerly():
|
||||
return # Nothing to do, variables are restored on creation.
|
||||
if session is None:
|
||||
@ -840,7 +855,8 @@ class TrackableSaver(object):
|
||||
"""Saves and restores a `Trackable` object and its dependencies.
|
||||
|
||||
See `Trackable` for details of dependency management. `Saver` wraps
|
||||
`tf.train.Saver` for saving, including extra information about the graph of
|
||||
`tf.compat.v1.train.Saver` for saving, including extra information about the
|
||||
graph of
|
||||
dependencies between Python objects. When restoring, it uses this information
|
||||
about the save-time dependency graph to more robustly match objects with their
|
||||
checkpointed values. When executing eagerly, it supports restoring variables
|
||||
@ -851,7 +867,8 @@ class TrackableSaver(object):
|
||||
checkpoint was written. To avoid breaking existing checkpoints when modifying
|
||||
a class, dependency names (the names of attributes to which `Trackable`
|
||||
objects are assigned) may not change. These names are local to objects, in
|
||||
contrast to the `Variable.name`-based save/restore from `tf.train.Saver`, and
|
||||
contrast to the `Variable.name`-based save/restore from
|
||||
`tf.compat.v1.train.Saver`, and
|
||||
so allow additional program transformations.
|
||||
"""
|
||||
|
||||
@ -877,8 +894,7 @@ class TrackableSaver(object):
|
||||
self._restore_op_cache = {}
|
||||
self._graph_view = graph_view
|
||||
|
||||
def _gather_saveables(
|
||||
self, object_graph_tensor=None):
|
||||
def _gather_saveables(self, object_graph_tensor=None):
|
||||
"""Wraps _serialize_object_graph to include the object graph proto."""
|
||||
(named_saveable_objects, graph_proto,
|
||||
feed_additions) = self._graph_view.serialize_object_graph()
|
||||
@ -892,14 +908,12 @@ class TrackableSaver(object):
|
||||
assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
|
||||
named_saveable_objects.append(
|
||||
base.NoRestoreSaveable(
|
||||
tensor=object_graph_tensor,
|
||||
name=base.OBJECT_GRAPH_PROTO_KEY))
|
||||
tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY))
|
||||
return named_saveable_objects, graph_proto, feed_additions
|
||||
|
||||
def _save_cached_when_graph_building(
|
||||
self,
|
||||
file_prefix,
|
||||
object_graph_tensor=None):
|
||||
def _save_cached_when_graph_building(self,
|
||||
file_prefix,
|
||||
object_graph_tensor=None):
|
||||
"""Create or retrieve save ops.
|
||||
|
||||
Args:
|
||||
@ -921,8 +935,7 @@ class TrackableSaver(object):
|
||||
# save() is called so they pick up new Tensors passed to their
|
||||
# constructors. That means the Saver needs to be copied with a new
|
||||
# var_list.
|
||||
or context.executing_eagerly()
|
||||
or ops.inside_function()):
|
||||
or context.executing_eagerly() or ops.inside_function()):
|
||||
saver = functional_saver.MultiDeviceSaver(named_saveable_objects)
|
||||
save_op = saver.save(file_prefix)
|
||||
with ops.device("/cpu:0"):
|
||||
@ -954,8 +967,8 @@ class TrackableSaver(object):
|
||||
The full path to the checkpoint.
|
||||
"""
|
||||
feed_dict = {}
|
||||
use_session = (not context.executing_eagerly()
|
||||
and not ops.inside_function())
|
||||
use_session = (not context.executing_eagerly() and
|
||||
not ops.inside_function())
|
||||
if checkpoint_number:
|
||||
file_prefix = "%s-%d" % (file_prefix, checkpoint_number)
|
||||
if use_session:
|
||||
@ -976,8 +989,7 @@ class TrackableSaver(object):
|
||||
|
||||
file_io.recursive_create_dir(os.path.dirname(file_prefix))
|
||||
save_path, new_feed_additions = self._save_cached_when_graph_building(
|
||||
file_prefix=file_prefix_tensor,
|
||||
object_graph_tensor=object_graph_tensor)
|
||||
file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor)
|
||||
if new_feed_additions:
|
||||
feed_dict.update(new_feed_additions)
|
||||
if not use_session:
|
||||
@ -1024,7 +1036,7 @@ class TrackableSaver(object):
|
||||
If the checkpoint has not been consumed completely, then the list of restore
|
||||
ops will grow as more objects are added to the dependency graph.
|
||||
|
||||
Name-based `tf.train.Saver` checkpoints can be loaded using this
|
||||
Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this
|
||||
method. There is no deferred loading, and names are used to match
|
||||
variables. No restore ops are created/run until `run_restore_ops()` or
|
||||
`initialize_or_restore()` are called on the returned status object, even
|
||||
@ -1035,9 +1047,9 @@ class TrackableSaver(object):
|
||||
save_path: The path to the checkpoint, as returned by `save` or
|
||||
`tf.train.latest_checkpoint`. If None (as when there is no latest
|
||||
checkpoint for `tf.train.latest_checkpoint` to return), returns an
|
||||
object which may run initializers for objects in the dependency
|
||||
graph. If the checkpoint was written by the name-based `tf.train.Saver`,
|
||||
names are used to match variables.
|
||||
object which may run initializers for objects in the dependency graph.
|
||||
If the checkpoint was written by the name-based
|
||||
`tf.compat.v1.train.Saver`, names are used to match variables.
|
||||
|
||||
Returns:
|
||||
A load status object, which can be used to make assertions about the
|
||||
@ -1057,8 +1069,7 @@ class TrackableSaver(object):
|
||||
else:
|
||||
dtype_map = reader.get_variable_to_dtype_map()
|
||||
try:
|
||||
object_graph_string = reader.get_tensor(
|
||||
base.OBJECT_GRAPH_PROTO_KEY)
|
||||
object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
|
||||
except errors_impl.NotFoundError:
|
||||
# The object graph proto does not exist in this checkpoint. Try the
|
||||
# name-based compatibility mode.
|
||||
@ -1069,8 +1080,7 @@ class TrackableSaver(object):
|
||||
# pylint: disable=protected-access
|
||||
existing_trackable._maybe_initialize_trackable()
|
||||
existing_trackable._name_based_restores.add(restore_coordinator)
|
||||
existing_trackable._name_based_attribute_restore(
|
||||
restore_coordinator)
|
||||
existing_trackable._name_based_attribute_restore(restore_coordinator)
|
||||
# pylint: enable=protected-access
|
||||
return NameBasedSaverStatus(
|
||||
restore_coordinator, graph_view=self._graph_view)
|
||||
@ -1085,8 +1095,7 @@ class TrackableSaver(object):
|
||||
with ops.device("/cpu:0"):
|
||||
file_prefix_tensor = constant_op.constant(save_path)
|
||||
file_prefix_feed_dict = None
|
||||
object_graph_proto = (
|
||||
trackable_object_graph_pb2.TrackableObjectGraph())
|
||||
object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
|
||||
object_graph_proto.ParseFromString(object_graph_string)
|
||||
checkpoint = _CheckpointRestoreCoordinator(
|
||||
object_graph_proto=object_graph_proto,
|
||||
@ -1094,8 +1103,8 @@ class TrackableSaver(object):
|
||||
save_path_tensor=file_prefix_tensor,
|
||||
restore_op_cache=self._restore_op_cache,
|
||||
graph_view=self._graph_view)
|
||||
base.CheckpointPosition(checkpoint=checkpoint, proto_id=0).restore(
|
||||
self._graph_view.root)
|
||||
base.CheckpointPosition(
|
||||
checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)
|
||||
load_status = CheckpointLoadStatus(
|
||||
checkpoint,
|
||||
graph_view=self._graph_view,
|
||||
@ -1104,7 +1113,7 @@ class TrackableSaver(object):
|
||||
|
||||
|
||||
def frozen_saver(root_trackable):
|
||||
"""Creates a static `tf.train.Saver` from a trackable object.
|
||||
"""Creates a static `tf.compat.v1.train.Saver` from a trackable object.
|
||||
|
||||
The returned `Saver` saves object-based checkpoints, but these checkpoints
|
||||
will no longer reflect structural changes to the object graph, only changes to
|
||||
@ -1135,9 +1144,9 @@ def saver_with_op_caching(obj):
|
||||
saveables_cache = None
|
||||
else:
|
||||
saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()
|
||||
return TrackableSaver(graph_view_lib.ObjectGraphView(
|
||||
weakref.ref(obj),
|
||||
saveables_cache=saveables_cache))
|
||||
return TrackableSaver(
|
||||
graph_view_lib.ObjectGraphView(
|
||||
weakref.ref(obj), saveables_cache=saveables_cache))
|
||||
|
||||
|
||||
# Mentions graph building / Sessions. The v2 version is below.
|
||||
@ -1146,7 +1155,7 @@ class CheckpointV1(tracking.AutoTrackable):
|
||||
"""Groups trackable objects, saving and restoring them.
|
||||
|
||||
`Checkpoint`'s constructor accepts keyword arguments whose values are types
|
||||
that contain trackable state, such as `tf.train.Optimizer`
|
||||
that contain trackable state, such as `tf.compat.v1.train.Optimizer`
|
||||
implementations, `tf.Variable`, `tf.keras.Layer` implementations, or
|
||||
`tf.keras.Model` implementations. It saves these values with a checkpoint, and
|
||||
maintains a `save_counter` for numbering checkpoints.
|
||||
@ -1164,7 +1173,7 @@ class CheckpointV1(tracking.AutoTrackable):
|
||||
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
|
||||
train_op = optimizer.minimize( ... )
|
||||
status.assert_consumed() # Optional sanity checks.
|
||||
with tf.Session() as session:
|
||||
with tf.compat.v1.Session() as session:
|
||||
# Use the Session to restore variables, or initialize them if
|
||||
# tf.train.latest_checkpoint returned None.
|
||||
status.initialize_or_restore(session)
|
||||
@ -1179,7 +1188,7 @@ class CheckpointV1(tracking.AutoTrackable):
|
||||
import tensorflow as tf
|
||||
import os
|
||||
|
||||
tf.enable_eager_execution()
|
||||
tf.compat.v1.enable_eager_execution()
|
||||
|
||||
checkpoint_directory = "/tmp/training_checkpoints"
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
@ -1193,13 +1202,14 @@ class CheckpointV1(tracking.AutoTrackable):
|
||||
```
|
||||
|
||||
`Checkpoint.save` and `Checkpoint.restore` write and read object-based
|
||||
checkpoints, in contrast to `tf.train.Saver` which writes and reads
|
||||
checkpoints, in contrast to `tf.compat.v1.train.Saver` which writes and reads
|
||||
`variable.name` based checkpoints. Object-based checkpointing saves a graph of
|
||||
dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s,
|
||||
etc.) with named edges, and this graph is used to match variables when
|
||||
restoring a checkpoint. It can be more robust to changes in the Python
|
||||
program, and helps to support restore-on-create for variables when executing
eagerly. Prefer `tf.train.Checkpoint` over `tf.train.Saver` for new code.
eagerly. Prefer `tf.train.Checkpoint` over `tf.compat.v1.train.Saver` for new
code.

`Checkpoint` objects have dependencies on the objects passed as keyword
arguments to their constructors, and each dependency is given a name that is
@ -1244,6 +1254,7 @@ class CheckpointV1(tracking.AutoTrackable):
Args:
**kwargs: Keyword arguments are set as attributes of this object, and are
saved with the checkpoint. Values must be trackable objects.

Raises:
ValueError: If objects in `kwargs` are not trackable.
"""
@ -1269,8 +1280,12 @@ class CheckpointV1(tracking.AutoTrackable):
# add_variable creates a dependency named "save_counter"; NoDependency
# prevents creating a second dependency named "_save_counter".
self._save_counter = data_structures.NoDependency(
add_variable(self, name="save_counter", initializer=0,
dtype=dtypes.int64, trainable=False))
add_variable(
self,
name="save_counter",
initializer=0,
dtype=dtypes.int64,
trainable=False))

def write(self, file_prefix, session=None):
"""Writes a training checkpoint.
@ -1294,9 +1309,7 @@ class CheckpointV1(tracking.AutoTrackable):
Returns:
The full path to the checkpoint (i.e. `file_prefix`).
"""
output = self._saver.save(
file_prefix=file_prefix,
session=session)
output = self._saver.save(file_prefix=file_prefix, session=session)
if tensor_util.is_tensor(output):
if context.executing_eagerly():
return compat.as_str(output.numpy())
@ -1370,8 +1383,8 @@ class CheckpointV1(tracking.AutoTrackable):
checkpoint_number = session.run(self._save_assign_op)
else:
checkpoint_number = assign_op.numpy()
file_path = self.write("%s-%d" % (file_prefix, checkpoint_number),
session=session)
file_path = self.write(
"%s-%d" % (file_prefix, checkpoint_number), session=session)
checkpoint_management.update_checkpoint_state_internal(
save_dir=os.path.dirname(file_prefix),
model_checkpoint_path=file_path,
@ -1417,7 +1430,7 @@ class CheckpointV1(tracking.AutoTrackable):
If the checkpoint has not been consumed completely, then the list of restore
ops will grow as more objects are added to the dependency graph.

Name-based `tf.train.Saver` checkpoints can be loaded using this
Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this
method. Names are used to match variables. No restore ops are created/run
until `run_restore_ops()` or `initialize_or_restore()` are called on the
returned status object when graph building, but there is restore-on-creation
@ -1428,9 +1441,9 @@ class CheckpointV1(tracking.AutoTrackable):
save_path: The path to the checkpoint, as returned by `save` or
`tf.train.latest_checkpoint`. If None (as when there is no latest
checkpoint for `tf.train.latest_checkpoint` to return), returns an
object which may run initializers for objects in the dependency
graph. If the checkpoint was written by the name-based `tf.train.Saver`,
names are used to match variables.
object which may run initializers for objects in the dependency graph.
If the checkpoint was written by the name-based
`tf.compat.v1.train.Saver`, names are used to match variables.

Returns:
A load status object, which can be used to make assertions about the
@ -1453,7 +1466,8 @@ class CheckpointV1(tracking.AutoTrackable):
built, and so has not created any variables, will pass this assertion
but fail `assert_consumed`. Useful when loading part of a larger
checkpoint into a new Python program, e.g. a training checkpoint with
a `tf.train.Optimizer` was saved but only the state required for
a `tf.compat.v1.train.Optimizer` was saved but only the state required
for
inference is being loaded. This method returns the status object, and
so may be chained with `initialize_or_restore` or `run_restore_ops`.
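For orientation, a minimal sketch of the graph-building restore flow this docstring describes; the checkpoint directory and the `my_var` dependency name below are illustrative assumptions, not part of this file:

```python
import tensorflow.compat.v1 as tf

# Assumes TF1-style graph building (in TF 2.x, call
# tf.compat.v1.disable_eager_execution() first).
my_var = tf.Variable(0.0, name="my_var")
ckpt = tf.train.Checkpoint(my_var=my_var)  # tf.compat.v1.train.Checkpoint
status = ckpt.restore(tf.train.latest_checkpoint("/tmp/ckpt_dir"))
with tf.Session() as sess:
  # Runs restore ops if a checkpoint was found, variable initializers otherwise.
  status.initialize_or_restore(sess)
```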

@ -1488,7 +1502,7 @@ class Checkpoint(tracking.AutoTrackable):
"""Groups trackable objects, saving and restoring them.

`Checkpoint`'s constructor accepts keyword arguments whose values are types
that contain trackable state, such as `tf.train.Optimizer`
that contain trackable state, such as `tf.keras.optimizers.Optimizer`
implementations, `tf.Variable`, `tf.keras.Layer` implementations, or
`tf.keras.Model` implementations. It saves these values with a checkpoint, and
maintains a `save_counter` for numbering checkpoints.
@ -1511,7 +1525,8 @@ class Checkpoint(tracking.AutoTrackable):
```

`Checkpoint.save` and `Checkpoint.restore` write and read object-based
checkpoints, in contrast to TensorFlow 1.x's `tf.train.Saver` which writes and
checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which
writes and
reads `variable.name` based checkpoints. Object-based checkpointing saves a
graph of dependencies between Python objects (`Layer`s, `Optimizer`s,
`Variable`s, etc.) with named edges, and this graph is used to match variables
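A minimal sketch of the object-based save/restore flow described here; the directory name and layer shapes are illustrative assumptions:

```python
import tensorflow as tf

opt = tf.keras.optimizers.Adam(0.1)
net = tf.keras.layers.Dense(1)
net(tf.ones([1, 3]))  # Build the layer so its variables exist.
ckpt = tf.train.Checkpoint(optimizer=opt, model=net)

save_path = ckpt.save('./tf_ckpts/ckpt')  # e.g. './tf_ckpts/ckpt-1'
status = ckpt.restore(save_path)
status.assert_existing_objects_matched()  # Every live variable was matched.
```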
@ -1561,6 +1576,7 @@ class Checkpoint(tracking.AutoTrackable):
Args:
**kwargs: Keyword arguments are set as attributes of this object, and are
saved with the checkpoint. Values must be trackable objects.

Raises:
ValueError: If objects in `kwargs` are not trackable.
"""
@ -1586,8 +1602,12 @@ class Checkpoint(tracking.AutoTrackable):
# add_variable creates a dependency named "save_counter"; NoDependency
# prevents creating a second dependency named "_save_counter".
self._save_counter = data_structures.NoDependency(
add_variable(self, name="save_counter", initializer=0,
dtype=dtypes.int64, trainable=False))
add_variable(
self,
name="save_counter",
initializer=0,
dtype=dtypes.int64,
trainable=False))

def write(self, file_prefix):
"""Writes a training checkpoint.
@ -1608,8 +1628,7 @@ class Checkpoint(tracking.AutoTrackable):
Returns:
The full path to the checkpoint (i.e. `file_prefix`).
"""
output = self._saver.save(
file_prefix=file_prefix)
output = self._saver.save(file_prefix=file_prefix)
if tensor_util.is_tensor(output):
if context.executing_eagerly():
return compat.as_str(output.numpy())
@ -1711,7 +1730,8 @@ class Checkpoint(tracking.AutoTrackable):
were not found in the checkpoint, or if any checkpointed values do not have
a matching Python object.

Name-based `tf.train.Saver` checkpoints from TensorFlow 1.x can be loaded
Name-based `tf.compat.v1.train.Saver` checkpoints from TensorFlow 1.x can be
loaded
using this method. Names are used to match variables. Re-encode name-based
checkpoints using `tf.train.Checkpoint.save` as soon as possible.
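A hedged sketch of that re-encoding step; the paths and the variable name `v` are hypothetical:

```python
import tensorflow as tf

v = tf.Variable(tf.zeros([10]), name='v')
ckpt = tf.train.Checkpoint(v=v)
ckpt.restore('/tmp/name_based/model.ckpt')  # Variables matched by name.
ckpt.save('/tmp/object_based/ckpt')         # Re-saved as object-based.
```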

@ -1719,9 +1739,9 @@ class Checkpoint(tracking.AutoTrackable):
save_path: The path to the checkpoint, as returned by `save` or
`tf.train.latest_checkpoint`. If None (as when there is no latest
checkpoint for `tf.train.latest_checkpoint` to return), returns an
object which may run initializers for objects in the dependency
graph. If the checkpoint was written by the name-based `tf.train.Saver`,
names are used to match variables.
object which may run initializers for objects in the dependency graph.
If the checkpoint was written by the name-based
`tf.compat.v1.train.Saver`, names are used to match variables.

Returns:
A load status object, which can be used to make assertions about the
@ -1744,7 +1764,8 @@ class Checkpoint(tracking.AutoTrackable):
built, and so has not created any variables, will pass this assertion
but fail `assert_consumed`. Useful when loading part of a larger
checkpoint into a new Python program, e.g. a training checkpoint with
a `tf.train.Optimizer` was saved but only the state required for
a `tf.compat.v1.train.Optimizer` was saved but only the state required
for
inference is being loaded. This method returns the status object, and
so may be chained with other assertions.
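A hedged sketch of that partial-loading pattern; it assumes a training checkpoint written earlier as `tf.train.Checkpoint(model=..., optimizer=...)` and stored under a hypothetical `./train_ckpts` directory:

```python
import tensorflow as tf

net = tf.keras.layers.Dense(2)
net(tf.ones([1, 4]))  # Build the layer so its variables exist.
eval_ckpt = tf.train.Checkpoint(model=net)  # No optimizer attached here.
status = eval_ckpt.restore(tf.train.latest_checkpoint('./train_ckpts'))
status.assert_existing_objects_matched()  # Passes: the model variables were
                                          # restored; unused optimizer slots in
                                          # the checkpoint are simply ignored.
```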

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utility functions for training."""
from __future__ import absolute_import
from __future__ import division
@ -34,7 +33,6 @@ from tensorflow.python.util.tf_export import tf_export
# collection keys.
GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache'


# TODO(drpng): remove this after legacy uses are resolved.
write_graph = graph_io.write_graph

@ -47,11 +45,12 @@ def global_step(sess, global_step_tensor):
# Create a variable to hold the global_step.
global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
# Create a session.
sess = tf.Session()
sess = tf.compat.v1.Session()
# Initialize the variable
sess.run(global_step_tensor.initializer)
# Get the variable value.
print('global_step: %s' % tf.train.global_step(sess, global_step_tensor))
print('global_step: %s' % tf.compat.v1.train.global_step(sess,
global_step_tensor))

global_step: 10
```
@ -109,8 +108,8 @@ def create_global_step(graph=None):
"""Create global step tensor in graph.

Args:
graph: The graph in which to create the global step tensor. If missing,
use default graph.
graph: The graph in which to create the global step tensor. If missing, use
default graph.

Returns:
Global step tensor.
@ -130,8 +129,9 @@ def create_global_step(graph=None):
initializer=init_ops.zeros_initializer(),
trainable=False,
aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
ops.GraphKeys.GLOBAL_STEP])
collections=[
ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP
])
# Create in proper graph and base name_scope.
with graph.as_default() as g, g.name_scope(None):
return variable_scope.get_variable(
@ -141,8 +141,7 @@ def create_global_step(graph=None):
initializer=init_ops.zeros_initializer(),
trainable=False,
aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
ops.GraphKeys.GLOBAL_STEP])
collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])


@tf_export(v1=['train.get_or_create_global_step'])
@ -173,9 +172,8 @@ def assert_global_step(global_step_tensor):
if not (isinstance(global_step_tensor, variables.Variable) or
isinstance(global_step_tensor, ops.Tensor) or
resource_variable_ops.is_resource_variable(global_step_tensor)):
raise TypeError(
'Existing "global_step" must be a Variable or Tensor: %s.' %
global_step_tensor)
raise TypeError('Existing "global_step" must be a Variable or Tensor: %s.' %
global_step_tensor)

if not global_step_tensor.dtype.base_dtype.is_integer:
raise TypeError('Existing "global_step" does not have integer type: %s' %

@ -49,8 +49,8 @@ class VocabInfo(
VocabInfo to warm-start.

Attributes:
new_vocab: [Required] A path to the new vocabulary file (used with the
model to be trained).
new_vocab: [Required] A path to the new vocabulary file (used with the model
to be trained).
new_vocab_size: [Required] An integer indicating how many entries of the new
vocabulary will used in training.
num_oov_buckets: [Required] An integer indicating how many OOV buckets are
@ -76,7 +76,7 @@ class VocabInfo(
num_oov_buckets=1,
old_vocab='pretrained_embeddings_vocab',
old_vocab_size=10000,
backup_initializer=tf.truncated_normal_initializer(
backup_initializer=tf.compat.v1.truncated_normal_initializer(
mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
axis=0)

@ -86,7 +86,7 @@ class VocabInfo(
num_oov_buckets=0,  # No OOV for classes.
old_vocab='old_class_vocab',
old_vocab_size=8,
backup_initializer=tf.glorot_uniform_initializer(),
backup_initializer=tf.compat.v1.glorot_uniform_initializer(),
axis=1)

softmax_output_layer_bias_vocab_info = tf.VocabInfo(
@ -95,7 +95,7 @@ class VocabInfo(
num_oov_buckets=0,  # No OOV for classes.
old_vocab='old_class_vocab',
old_vocab_size=8,
backup_initializer=tf.zeros_initializer(),
backup_initializer=tf.compat.v1.zeros_initializer(),
axis=0)

Currently, only axis=0 and axis=1 are supported.
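A hedged end-to-end sketch of pairing a `VocabInfo` with `tf.compat.v1.train.warm_start`; the paths, vocabulary sizes, and the embedding variable name below are illustrative assumptions:

```python
import tensorflow.compat.v1 as tf

vocab_info = tf.train.VocabInfo(
    new_vocab='new_vocab.txt',
    new_vocab_size=10000,
    num_oov_buckets=1,
    old_vocab='old_vocab.txt',
    old_vocab_size=8000,
    backup_initializer=tf.zeros_initializer(),
    axis=0)

# Call during graph building, before variables are initialized.
tf.train.warm_start(
    ckpt_to_initialize_from='/tmp/pretrained_model',
    vars_to_warm_start='.*embeddings.*',
    var_name_to_vocab_info={
        'input_layer/embedding/embedding_weights': vocab_info,
    })
```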
@ -255,8 +255,7 @@ def _warm_start_var_with_vocab(var,
partition_info = None
if slice_info:
partition_info = variable_scope._PartitionInfo(
full_shape=slice_info.full_shape,
var_offset=slice_info.var_offset)
full_shape=slice_info.full_shape, var_offset=slice_info.var_offset)

if axis == 0:
new_row_vocab_size = current_vocab_size
@ -301,6 +300,8 @@ def _warm_start_var_with_vocab(var,
new_init_val = ops.convert_to_tensor(
init(shape=v_shape, partition_info=partition_info))
v._initializer_op = state_ops.assign(v, new_init_val)


# pylint: enable=protected-access


@ -314,7 +315,8 @@ def _get_grouped_variables(vars_to_warm_start):
vars_to_warm_start: One of the following:

- A regular expression (string) that captures which variables to
warm-start (see tf.get_collection). This expression will only consider
warm-start (see tf.compat.v1.get_collection). This expression will
only consider
variables in the TRAINABLE_VARIABLES collection.
- A list of Variables to warm-start.
- A list of strings, each representing a full variable name to warm-start.
@ -330,14 +332,13 @@ def _get_grouped_variables(vars_to_warm_start):
# Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
# everything (in TRAINABLE_VARIABLES) here.
list_of_vars = ops.get_collection(
ops.GraphKeys.TRAINABLE_VARIABLES,
scope=vars_to_warm_start)
ops.GraphKeys.TRAINABLE_VARIABLES, scope=vars_to_warm_start)
elif isinstance(vars_to_warm_start, list):
if all(isinstance(v, str) for v in vars_to_warm_start):
list_of_vars = []
for v in vars_to_warm_start:
list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
scope=v)
list_of_vars += ops.get_collection(
ops.GraphKeys.GLOBAL_VARIABLES, scope=v)
elif all(checkpoint_utils._is_variable(v) for v in vars_to_warm_start):  # pylint: disable=protected-access
list_of_vars = vars_to_warm_start
else:
@ -377,16 +378,16 @@ def warm_start(ckpt_to_initialize_from,
vars_to_warm_start: [Optional] One of the following:

- A regular expression (string) that captures which variables to
warm-start (see tf.get_collection). This expression will only consider
variables in the TRAINABLE_VARIABLES collection -- if you need to
warm-start non_TRAINABLE vars (such as optimizer accumulators or batch
norm statistics), please use the below option.
warm-start (see tf.compat.v1.get_collection). This expression will only
consider variables in the TRAINABLE_VARIABLES collection -- if you need
to warm-start non_TRAINABLE vars (such as optimizer accumulators or
batch norm statistics), please use the below option.
- A list of Variables to warm-start. If you do not have access to the
`Variable` objects at the call site, please use the below option.
- A list of strings, each a regex scope provided to tf.get_collection with
GLOBAL_VARIABLES (please see tf.get_collection). For backwards
compatibility reasons, this is separate from the single-string argument
type.
- A list of strings, each a regex scope provided to
tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see
tf.compat.v1.get_collection). For backwards compatibility reasons,
this is separate from the single-string argument type.
- `None`, in which case only variables specified in
`var_name_to_vocab_info` will be warm-started.

@ -404,6 +405,7 @@ def warm_start(ckpt_to_initialize_from,
effect on the set of variables that is warm-started, and only controls
name mapping (use `vars_to_warm_start` for controlling what variables to
warm-start).

Raises:
ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
configuration for variable names that are not used. This is to ensure