Apply TF1->TF2 name replacements to docstrings and comments in TensorFlow.
No code changes; docstrings and comments only.

PiperOrigin-RevId: 244275767
This commit is contained in:
parent 7fdf27b688
commit 18b680216e
@@ -22,8 +22,11 @@ from tensorflow.python.util.tf_export import tf_export


 @tf_export(v1=["train.basic_train_loop"])
-def basic_train_loop(supervisor, train_step_fn, args=None,
-                     kwargs=None, master=""):
+def basic_train_loop(supervisor,
+                     train_step_fn,
+                     args=None,
+                     kwargs=None,
+                     master=""):
   """Basic loop to train a model.

   Calls `train_step_fn` in a loop to train a model. The function is called as:
@@ -32,17 +35,18 @@ def basic_train_loop(supervisor, train_step_fn, args=None,
   train_step_fn(session, *args, **kwargs)
   ```

-  It is passed a `tf.Session` in addition to `args` and `kwargs`. The function
+  It is passed a `tf.compat.v1.Session` in addition to `args` and `kwargs`. The
+  function
   typically runs one training step in the session.

   Args:
-    supervisor: `tf.train.Supervisor` to run the training services.
-    train_step_fn: Callable to execute one training step. Called
-      repeatedly as `train_step_fn(session, *args **kwargs)`.
+    supervisor: `tf.compat.v1.train.Supervisor` to run the training services.
+    train_step_fn: Callable to execute one training step. Called repeatedly as
+      `train_step_fn(session, *args **kwargs)`.
     args: Optional positional arguments passed to `train_step_fn`.
     kwargs: Optional keyword arguments passed to `train_step_fn`.
-    master: Master to use to create the training session. Defaults to
-      `""` which causes the session to be created in the local process.
+    master: Master to use to create the training session. Defaults to `""`
+      which causes the session to be created in the local process.
   """
   if args is None:
     args = []
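For orientation, a minimal sketch of the renamed entry point in use. Everything below (the toy quadratic loss, the logdir, the stop condition) is an illustrative assumption, not part of this commit:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # needed if this runs under a TF 2.x install

x = tf.Variable(2.0)
loss = tf.square(x)
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

sv = tf.train.Supervisor(logdir="/tmp/basic_train_loop_demo")

def train_step_fn(session):
  # Called repeatedly as train_step_fn(session, *args, **kwargs);
  # session is the tf.compat.v1.Session the docstring mentions.
  _, loss_value = session.run([train_op, loss])
  if loss_value < 1e-4:
    sv.request_stop()  # ends the loop once the toy loss converges

tf.train.basic_train_loop(sv, train_step_fn)
```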
@@ -42,7 +42,6 @@ from tensorflow.python.training.session_run_hook import SessionRunArgs
 from tensorflow.python.training.summary_io import SummaryWriterCache
 from tensorflow.python.util.tf_export import tf_export

-
 _HOOKS = "hooks"
 _STEPS_PER_RUN_VAR = "steps_per_run"

@@ -85,8 +84,7 @@ class _HookTimer(object):

 @tf_export(v1=["train.SecondOrStepTimer"])
 class SecondOrStepTimer(_HookTimer):
-  """Timer that triggers at most once every N seconds or once every N steps.
-  """
+  """Timer that triggers at most once every N seconds or once every N steps."""

   def __init__(self, every_secs=None, every_steps=None):
     self.reset()
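A quick sketch of the timer semantics that one-line docstring summarizes; the step values are arbitrary:

```python
import tensorflow.compat.v1 as tf

timer = tf.train.SecondOrStepTimer(every_steps=100)
print(timer.should_trigger_for_step(1))    # True: nothing has triggered yet
timer.update_last_triggered_step(1)
print(timer.should_trigger_for_step(50))   # False: only 49 steps elapsed
print(timer.should_trigger_for_step(101))  # True: a full 100 steps elapsed
```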
@@ -171,29 +169,33 @@ class LoggingTensorHook(session_run_hook.SessionRunHook):
   seeing the logs, you might want to add the following line after your imports:

   ```python
-  tf.logging.set_verbosity(tf.logging.INFO)
+  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
   ```

   Note that if `at_end` is True, `tensors` should not include any tensor
   whose evaluation produces a side effect such as consuming additional inputs.
   """

-  def __init__(self, tensors, every_n_iter=None, every_n_secs=None,
-               at_end=False, formatter=None):
+  def __init__(self,
+               tensors,
+               every_n_iter=None,
+               every_n_secs=None,
+               at_end=False,
+               formatter=None):
     """Initializes a `LoggingTensorHook`.

     Args:
-      tensors: `dict` that maps string-valued tags to tensors/tensor names,
-        or `iterable` of tensors/tensor names.
+      tensors: `dict` that maps string-valued tags to tensors/tensor names, or
+        `iterable` of tensors/tensor names.
       every_n_iter: `int`, print the values of `tensors` once every N local
         steps taken on the current worker.
       every_n_secs: `int` or `float`, print the values of `tensors` once every N
         seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
         provided.
       at_end: `bool` specifying whether to print the values of `tensors` at the
         end of the run.
       formatter: function, takes dict of `tag`->`Tensor` and returns a string.
         If `None` uses default printing all tensors.

     Raises:
       ValueError: if `every_n_iter` is non-positive.
@@ -215,16 +217,18 @@ class LoggingTensorHook(session_run_hook.SessionRunHook):
     self._tensors = tensors
     self._formatter = formatter
     self._timer = (
-        NeverTriggerTimer() if only_log_at_end else
-        SecondOrStepTimer(every_secs=every_n_secs, every_steps=every_n_iter))
+        NeverTriggerTimer() if only_log_at_end else SecondOrStepTimer(
+            every_secs=every_n_secs, every_steps=every_n_iter))
     self._log_at_end = at_end

   def begin(self):
     self._timer.reset()
     self._iter_count = 0
     # Convert names to tensors if given
-    self._current_tensors = {tag: _as_graph_element(tensor)
-                             for (tag, tensor) in self._tensors.items()}
+    self._current_tensors = {
+        tag: _as_graph_element(tensor)
+        for (tag, tensor) in self._tensors.items()
+    }

   def before_run(self, run_context):  # pylint: disable=unused-argument
     self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
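Putting the renamed `tf.compat.v1.logging` call and the reformatted constructor together, a runnable sketch; the counter graph is a stand-in, not part of the commit:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
tf.logging.set_verbosity(tf.logging.INFO)  # the line the docstring recommends

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)

hook = tf.train.LoggingTensorHook({"step": global_step}, every_n_iter=10)
with tf.train.MonitoredSession(hooks=[hook]) as sess:
  for _ in range(30):
    sess.run(train_op)  # logs "step = ..." every 10 local steps
```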
@@ -463,9 +467,10 @@ class CheckpointSaverListener(object):

   ...
   listener = ExampleCheckpointSaverListener()
-  saver_hook = tf.train.CheckpointSaverHook(
+  saver_hook = tf.estimator.CheckpointSaverHook(
       checkpoint_dir, listeners=[listener])
-  with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
+  with
+  tf.compat.v1.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
     ...
   ```

@@ -516,9 +521,9 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
       saver: `Saver` object, used for saving.
       checkpoint_basename: `str`, base name for the checkpoint files.
       scaffold: `Scaffold`, use to get saver object.
-      listeners: List of `CheckpointSaverListener` subclass instances.
-        Used for callbacks that run immediately before or after this hook saves
-        the checkpoint.
+      listeners: List of `CheckpointSaverListener` subclass instances. Used for
+        callbacks that run immediately before or after this hook saves the
+        checkpoint.

     Raises:
       ValueError: One of `save_steps` or `save_secs` should be set.
@@ -531,8 +536,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
     self._checkpoint_dir = checkpoint_dir
     self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
     self._scaffold = scaffold
-    self._timer = SecondOrStepTimer(every_secs=save_secs,
-                                    every_steps=save_steps)
+    self._timer = SecondOrStepTimer(
+        every_secs=save_secs, every_steps=save_steps)
     self._listeners = listeners or []
     self._steps_per_run = 1

@@ -555,13 +560,11 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
     # add variables in begin. Graph is finalized after all begin calls.
     training_util.write_graph(
         ops.get_default_graph().as_graph_def(add_shapes=True),
-        self._checkpoint_dir,
-        "graph.pbtxt")
+        self._checkpoint_dir, "graph.pbtxt")
     saver_def = self._get_saver().saver_def if self._get_saver() else None
     graph = ops.get_default_graph()
     meta_graph_def = meta_graph.create_meta_graph_def(
-        graph_def=graph.as_graph_def(add_shapes=True),
-        saver_def=saver_def)
+        graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
     self._summary_writer.add_graph(graph)
     self._summary_writer.add_meta_graph(meta_graph_def)
     # The checkpoint saved here is the state at step "global_step".
@@ -573,8 +576,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):

   def after_run(self, run_context, run_values):
     stale_global_step = run_values.results
-    if self._timer.should_trigger_for_step(
-        stale_global_step + self._steps_per_run):
+    if self._timer.should_trigger_for_step(stale_global_step +
+                                           self._steps_per_run):
       # get the real value after train op.
       global_step = run_context.session.run(self._global_step_tensor)
       if self._timer.should_trigger_for_step(global_step):
@@ -627,8 +630,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
     elif len(savers) > 1:
       raise RuntimeError(
           "More than one item in collection {}. "
-          "Please indicate which one to use by passing it to the constructor.".
-          format(collection_key))
+          "Please indicate which one to use by passing it to the constructor."
+          .format(collection_key))

     self._saver = savers[0]
     return savers[0]
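A minimal sketch of the hook whose internals the hunks above reformat; the checkpoint directory and step counts are assumptions:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)

saver_hook = tf.train.CheckpointSaverHook(
    checkpoint_dir="/tmp/ckpt_demo", save_steps=50)
with tf.train.SingularMonitoredSession(hooks=[saver_hook]) as sess:
  for _ in range(200):
    sess.run(train_op)  # a checkpoint lands every 50 global steps
```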
@@ -647,8 +650,8 @@ class StepCounterHook(session_run_hook.SessionRunHook):
     if (every_n_steps is None) == (every_n_secs is None):
       raise ValueError(
           "exactly one of every_n_steps and every_n_secs should be provided.")
-    self._timer = SecondOrStepTimer(every_steps=every_n_steps,
-                                    every_secs=every_n_secs)
+    self._timer = SecondOrStepTimer(
+        every_steps=every_n_steps, every_secs=every_n_secs)

     self._summary_writer = summary_writer
     self._output_dir = output_dir
@@ -673,8 +676,9 @@ class StepCounterHook(session_run_hook.SessionRunHook):
   def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
     steps_per_sec = elapsed_steps / elapsed_time
     if self._summary_writer is not None:
-      summary = Summary(value=[Summary.Value(
-          tag=self._summary_tag, simple_value=steps_per_sec)])
+      summary = Summary(value=[
+          Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
+      ])
       self._summary_writer.add_summary(summary, global_step)
     logging.info("%s: %g", self._summary_tag, steps_per_sec)

@@ -682,8 +686,8 @@ class StepCounterHook(session_run_hook.SessionRunHook):
     _ = run_context

     stale_global_step = run_values.results
-    if self._timer.should_trigger_for_step(
-        stale_global_step + self._steps_per_run):
+    if self._timer.should_trigger_for_step(stale_global_step +
+                                           self._steps_per_run):
       # get the real value after train op.
       global_step = run_context.session.run(self._global_step_tensor)
       if self._timer.should_trigger_for_step(global_step):
@@ -767,18 +771,18 @@ class SummarySaverHook(session_run_hook.SessionRunHook):

     Args:
       save_steps: `int`, save summaries every N steps. Exactly one of
         `save_secs` and `save_steps` should be set.
       save_secs: `int`, save summaries every N seconds.
-      output_dir: `string`, the directory to save the summaries to. Only used
-        if no `summary_writer` is supplied.
+      output_dir: `string`, the directory to save the summaries to. Only used if
+        no `summary_writer` is supplied.
       summary_writer: `SummaryWriter`. If `None` and an `output_dir` was passed,
         one will be created accordingly.
       scaffold: `Scaffold` to get summary_op if it's not provided.
       summary_op: `Tensor` of type `string` containing the serialized `Summary`
-        protocol buffer or a list of `Tensor`. They are most likely an output
-        by TF summary methods like `tf.summary.scalar` or
-        `tf.summary.merge_all`. It can be passed in as one tensor; if more
-        than one, they must be passed in as a list.
+        protocol buffer or a list of `Tensor`. They are most likely an output by
+        TF summary methods like `tf.compat.v1.summary.scalar` or
+        `tf.compat.v1.summary.merge_all`. It can be passed in as one tensor; if
+        more than one, they must be passed in as a list.

     Raises:
       ValueError: Exactly one of scaffold or summary_op should be set.
@@ -791,8 +795,8 @@ class SummarySaverHook(session_run_hook.SessionRunHook):
     self._summary_writer = summary_writer
     self._output_dir = output_dir
     self._scaffold = scaffold
-    self._timer = SecondOrStepTimer(every_secs=save_secs,
-                                    every_steps=save_steps)
+    self._timer = SecondOrStepTimer(
+        every_secs=save_secs, every_steps=save_steps)
     # TODO(mdan): Throw an error if output_dir and summary_writer are None.

   def begin(self):
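A sketch of the hook with the renamed summary methods from the docstring above (here reached through the `tf.compat.v1` alias); the scalar summary and directory are illustrative:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)
tf.summary.scalar("global_step", global_step)  # i.e. tf.compat.v1.summary.scalar

summary_hook = tf.train.SummarySaverHook(
    save_steps=10,
    output_dir="/tmp/summary_demo",
    summary_op=tf.summary.merge_all())
with tf.train.MonitoredSession(hooks=[summary_hook]) as sess:
  for _ in range(50):
    sess.run(train_op)
```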
@@ -903,8 +907,9 @@ class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
         self._worker_is_started = True
         return None
       if current_step - last_logged_step > 1000:
-        logging.info("Waiting for global step %d before starting training. "
-                     "Current step is %d.", self._wait_until_step, current_step)
+        logging.info(
+            "Waiting for global step %d before starting training. "
+            "Current step is %d.", self._wait_until_step, current_step)
         last_logged_step = current_step
       time.sleep(0.5)

@@ -917,8 +922,8 @@ class FinalOpsHook(session_run_hook.SessionRunHook):
     """Initializes `FinalOpHook` with ops to run at the end of the session.

     Args:
-      final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of
-        names to `Tensors`.
+      final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
+        to `Tensors`.
       final_ops_feed_dict: A feed dictionary to use when running
         `final_ops_dict`.
     """
@@ -997,14 +1002,14 @@ class ProfilerHook(session_run_hook.SessionRunHook):

     Args:
       save_steps: `int`, save profile traces every N steps. Exactly one of
-          `save_secs` and `save_steps` should be set.
+        `save_secs` and `save_steps` should be set.
       save_secs: `int` or `float`, save profile traces every N seconds.
       output_dir: `string`, the directory to save the profile traces to.
-          Defaults to the current directory.
+        Defaults to the current directory.
       show_dataflow: `bool`, if True, add flow events to the trace connecting
-          producers and consumers of tensors.
+        producers and consumers of tensors.
       show_memory: `bool`, if True, add object snapshot events to the trace
-          showing the sizes and lifetimes of tensors.
+        showing the sizes and lifetimes of tensors.
     """
     self._output_file = os.path.join(output_dir, "timeline-{}.json")
     self._file_writer = SummaryWriterCache.get(output_dir)
@@ -1024,8 +1029,9 @@ class ProfilerHook(session_run_hook.SessionRunHook):
         self._next_step is not None and
         self._timer.should_trigger_for_step(self._next_step))
     requests = {"global_step": self._global_step_tensor}
-    opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
-            if self._request_summary else None)
+    opts = (
+        config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
+        if self._request_summary else None)

     return SessionRunArgs(requests, options=opts)

@@ -1039,8 +1045,7 @@ class ProfilerHook(session_run_hook.SessionRunHook):
       if self._request_summary:
         global_step = run_context.session.run(self._global_step_tensor)
         self._timer.update_last_triggered_step(global_step)
-        self._save(global_step,
-                   self._output_file.format(global_step),
+        self._save(global_step, self._output_file.format(global_step),
                    run_values.run_metadata.step_stats)
         self._file_writer.add_run_metadata(run_values.run_metadata,
                                            "step_%d" % global_step)
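A usage sketch for the profiler hook reformatted above; the output directory and step counts are assumptions:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)

# Writes timeline-<step>.json Chrome trace files into output_dir.
hook = tf.train.ProfilerHook(save_steps=25, output_dir="/tmp/profile_demo")
with tf.train.MonitoredSession(hooks=[hook]) as sess:
  for _ in range(100):
    sess.run(train_op)
```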
@@ -106,7 +106,7 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
   """Replaces `tf.Variable` initializers so they load from a checkpoint file.

   Values are not loaded immediately, but when the initializer is run
-  (typically by running a `tf.global_variables_initializer` op).
+  (typically by running a `tf.compat.v1.global_variables_initializer` op).

   Note: This overrides default initialization ops of specified variables and
   redefines dtype.
@@ -139,15 +139,15 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
   # -- name='old_scope_2/var3', shape=[100, 100]

   # Create new model's variables
-  with tf.variable_scope('new_scope_1'):
-    var1 = tf.get_variable('var1', shape=[20, 2],
-                           initializer=tf.zeros_initializer())
-  with tf.variable_scope('new_scope_2'):
-    var2 = tf.get_variable('var2', shape=[50, 4],
-                           initializer=tf.zeros_initializer())
+  with tf.compat.v1.variable_scope('new_scope_1'):
+    var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
+                           initializer=tf.compat.v1.zeros_initializer())
+  with tf.compat.v1.variable_scope('new_scope_2'):
+    var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
+                           initializer=tf.compat.v1.zeros_initializer())
     # Partition into 5 variables along the first axis.
-    var3 = tf.get_variable(name='var3', shape=[100, 100],
-                           initializer=tf.zeros_initializer(),
+    var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
+                           initializer=tf.compat.v1.zeros_initializer(),
                            partitioner=lambda shape, dtype: [5, 1])

   # Initialize all variables in `new_scope_1` from `old_scope_1`.
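A condensed, runnable variant of the docstring example above; the checkpoint path is a placeholder and must point at a real checkpoint written with the old scope names:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

with tf.variable_scope("new_scope_1"):
  var1 = tf.get_variable("var1", shape=[20, 2],
                         initializer=tf.zeros_initializer())

# Map everything under old_scope_1/ in the checkpoint onto new_scope_1/.
tf.train.init_from_checkpoint("/tmp/model_dir/",
                              {"old_scope_1/": "new_scope_1/"})

with tf.Session() as sess:
  # The overridden initializer now reads var1's value from the checkpoint.
  sess.run(tf.global_variables_initializer())
```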
@@ -131,9 +131,13 @@ class _ReplicaDeviceChooser(object):


 @tf_export(v1=["train.replica_device_setter"])
-def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
-                          worker_device="/job:worker", merge_devices=True,
-                          cluster=None, ps_ops=None, ps_strategy=None):
+def replica_device_setter(ps_tasks=0,
+                          ps_device="/job:ps",
+                          worker_device="/job:worker",
+                          merge_devices=True,
+                          cluster=None,
+                          ps_ops=None,
+                          ps_strategy=None):
   """Return a `device function` to use when building a Graph for replicas.

   Device Functions are used in `with tf.device(device_function):` statement to
@@ -158,7 +162,8 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
   cluster_spec = {
       "ps": ["ps0:2222", "ps1:2222"],
       "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
-  with tf.device(tf.train.replica_device_setter(cluster=cluster_spec)):
+  with
+  tf.device(tf.compat.v1.train.replica_device_setter(cluster=cluster_spec)):
     # Build your graph
     v1 = tf.Variable(...)  # assigned to /job:ps/task:0
     v2 = tf.Variable(...)  # assigned to /job:ps/task:1
@@ -218,6 +223,6 @@ def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
     ps_strategy = _RoundRobinStrategy(ps_tasks)
   if not six.callable(ps_strategy):
     raise TypeError("ps_strategy must be callable")
-  chooser = _ReplicaDeviceChooser(
-      ps_tasks, ps_device, worker_device, merge_devices, ps_ops, ps_strategy)
+  chooser = _ReplicaDeviceChooser(ps_tasks, ps_device, worker_device,
+                                  merge_devices, ps_ops, ps_strategy)
   return chooser.device_function
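A runnable sketch of the docstring example, with the cluster spec taken from the diff; binding the setter to a local name also sidesteps the awkward mid-expression line wrap the reflowed docstring was forced into:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

cluster_spec = {
    "ps": ["ps0:2222", "ps1:2222"],
    "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
setter = tf.train.replica_device_setter(cluster=cluster_spec)
with tf.device(setter):
  v1 = tf.Variable([0.0])  # round-robined to /job:ps/task:0
  v2 = tf.Variable([0.0])  # round-robined to /job:ps/task:1
  print(v1.device, v2.device)
```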
@@ -65,8 +65,8 @@ def _get_latest_eval_step_value(update_ops):
   """Gets the eval step `Tensor` value after running `update_ops`.

   Args:
-    update_ops: A list of `Tensors` or a dictionary of names to `Tensors`,
-      which are run before reading the eval step value.
+    update_ops: A list of `Tensors` or a dictionary of names to `Tensors`, which
+      are run before reading the eval step value.

   Returns:
     A `Tensor` representing the value for the evaluation step.
@@ -102,21 +102,20 @@ class _MultiStepStopAfterNEvalsHook(session_run_hook.SessionRunHook):

   def after_create_session(self, session, coord):
     # Update number of steps to run in the first run call
     if self._num_evals is None:
       steps = self._steps_per_run_initial_value
     else:
       steps = min(self._steps_per_run_initial_value, self._num_evals)
     self._steps_per_run_variable.load(steps, session=session)

   def before_run(self, run_context):
-    return session_run_hook.SessionRunArgs({
-        'evals_completed': self._evals_completed
-    })
+    return session_run_hook.SessionRunArgs(
+        {'evals_completed': self._evals_completed})

   def after_run(self, run_context, run_values):
     evals_completed = run_values.results['evals_completed']
     # Update number of steps to run in the next iteration
     if self._num_evals is None:
       steps = self._steps_per_run_initial_value
     else:
       steps = min(self._num_evals - evals_completed,
@@ -147,16 +146,15 @@ class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
     self._evals_completed = None
     self._log_progress = log_progress
     # Reduce logging frequency if there are 20 or more evaluations.
-    self._log_frequency = (1 if (num_evals is None or num_evals < 20)
-                           else math.floor(num_evals / 10.))
+    self._log_frequency = (1 if (num_evals is None or num_evals < 20) else
+                           math.floor(num_evals / 10.))

   def _set_evals_completed_tensor(self, updated_eval_step):
     self._evals_completed = updated_eval_step

   def before_run(self, run_context):
-    return session_run_hook.SessionRunArgs({
-        'evals_completed': self._evals_completed
-    })
+    return session_run_hook.SessionRunArgs(
+        {'evals_completed': self._evals_completed})

   def after_run(self, run_context, run_values):
     evals_completed = run_values.results['evals_completed']
@@ -205,20 +203,20 @@ def _evaluate_once(checkpoint_path,
   Args:
     checkpoint_path: The path to a checkpoint to use for evaluation.
     master: The BNS address of the TensorFlow master.
-    scaffold: An tf.train.Scaffold instance for initializing variables and
-      restoring variables. Note that `scaffold.init_fn` is used by the function
-      to restore the checkpoint. If you supply a custom init_fn, then it must
-      also take care of restoring the model from its checkpoint.
-    eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
-      to `Tensors`, which is run until the session is requested to stop,
-      commonly done by a `tf.contrib.training.StopAfterNEvalsHook`.
+    scaffold: An tf.compat.v1.train.Scaffold instance for initializing variables
+      and restoring variables. Note that `scaffold.init_fn` is used by the
+      function to restore the checkpoint. If you supply a custom init_fn, then
+      it must also take care of restoring the model from its checkpoint.
+    eval_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names to
+      `Tensors`, which is run until the session is requested to stop, commonly
+      done by a `tf.contrib.training.StopAfterNEvalsHook`.
     feed_dict: The feed dictionary to use when executing the `eval_ops`.
     final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
       to `Tensors`.
     final_ops_feed_dict: A feed dictionary to use when evaluating `final_ops`.
-    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
-      evaluation loop.
-    config: An instance of `tf.ConfigProto` that will be used to
+    hooks: List of `tf.estimator.SessionRunHook` callbacks which are run inside
+      the evaluation loop.
+    config: An instance of `tf.compat.v1.ConfigProto` that will be used to
       configure the `Session`. If left as `None`, the default will be used.

   Returns:
@@ -263,8 +261,8 @@ def _evaluate_once(checkpoint_path,
       master=master,
       config=config)

-  final_ops_hook = basic_session_run_hooks.FinalOpsHook(
-      final_ops, final_ops_feed_dict)
+  final_ops_hook = basic_session_run_hooks.FinalOpsHook(final_ops,
+                                                        final_ops_feed_dict)
   hooks.append(final_ops_hook)

   with monitored_session.MonitoredSession(
@@ -1090,7 +1090,7 @@ def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False,

   The `tensors_list` argument is a list of tuples of tensors, or a list of
   dictionaries of tensors. Each element in the list is treated similarly
-  to the `tensors` argument of `tf.train.batch()`.
+  to the `tensors` argument of `tf.compat.v1.train.batch()`.

   WARNING: This function is nondeterministic, since it starts a separate thread
   for each tensor.
@@ -1284,7 +1284,7 @@ def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,

   ```python
   # Creates batches of 32 images and 32 labels.
-  image_batch, label_batch = tf.train.shuffle_batch(
+  image_batch, label_batch = tf.compat.v1.train.shuffle_batch(
       [single_image, single_label],
       batch_size=32,
       num_threads=4,
@@ -1425,7 +1425,7 @@ def shuffle_batch_join(tensors_list, batch_size, capacity,

   The `tensors_list` argument is a list of tuples of tensors, or a list of
   dictionaries of tensors. Each element in the list is treated similarly
   to the `tensors` argument of `tf.train.shuffle_batch()`.
-  to the `tensors` argument of `tf.train.shuffle_batch()`.
+  to the `tensors` argument of `tf.compat.v1.train.shuffle_batch()`.

   This version enqueues a different list of tensors in different threads.
   It adds the following to the current `Graph`:
@@ -56,24 +56,25 @@ def exponential_decay(learning_rate,
   ...
   global_step = tf.Variable(0, trainable=False)
   starter_learning_rate = 0.1
-  learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
+  learning_rate = tf.compat.v1.train.exponential_decay(starter_learning_rate,
+  global_step,
                                              100000, 0.96, staircase=True)
   # Passing global_step to minimize() will increment it at each step.
   learning_step = (
-      tf.train.GradientDescentOptimizer(learning_rate)
+      tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
       .minimize(...my loss..., global_step=global_step)
   )
   ```

   Args:
-    learning_rate: A scalar `float32` or `float64` `Tensor` or a
-      Python number. The initial learning rate.
-    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
-      Global step to use for the decay computation. Must not be negative.
-    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
-      Must be positive. See the decay computation above.
-    decay_rate: A scalar `float32` or `float64` `Tensor` or a
-      Python number. The decay rate.
+    learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
+      The initial learning rate.
+    global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
+      step to use for the decay computation. Must not be negative.
+    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Must
+      be positive. See the decay computation above.
+    decay_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
+      The decay rate.
     staircase: Boolean. If `True` decay the learning rate at discrete intervals
     name: String. Optional name of the operation. Defaults to
       'ExponentialDecay'.
@@ -91,11 +92,8 @@ def exponential_decay(learning_rate,
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  decayed_lr = learning_rate_schedule.ExponentialDecay(learning_rate,
-                                                       decay_steps,
-                                                       decay_rate,
-                                                       staircase=staircase,
-                                                       name=name)
+  decayed_lr = learning_rate_schedule.ExponentialDecay(
+      learning_rate, decay_steps, decay_rate, staircase=staircase, name=name)
   if not context.executing_eagerly():
     decayed_lr = decayed_lr(global_step)
   else:
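The schedule computes decayed_lr = learning_rate * decay_rate ** (global_step / decay_steps), with integer division when staircase=True. A numeric check under the docstring's own constants:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

global_step = tf.Variable(100000, trainable=False, dtype=tf.int64)
lr = tf.train.exponential_decay(0.1, global_step, 100000, 0.96, staircase=True)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(lr))  # 0.1 * 0.96 ** (100000 // 100000) = 0.096
```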
@@ -114,7 +112,8 @@ def piecewise_constant(x, boundaries, values, name=None):
   global_step = tf.Variable(0, trainable=False)
   boundaries = [100000, 110000]
   values = [1.0, 0.5, 0.1]
-  learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
+  learning_rate = tf.compat.v1.train.piecewise_constant(global_step, boundaries,
+  values)

   # Later, whenever we perform an optimization step, we increment global_step.
   ```
@@ -202,27 +201,28 @@ def polynomial_decay(learning_rate,
   starter_learning_rate = 0.1
   end_learning_rate = 0.01
   decay_steps = 10000
-  learning_rate = tf.train.polynomial_decay(starter_learning_rate, global_step,
+  learning_rate = tf.compat.v1.train.polynomial_decay(starter_learning_rate,
+  global_step,
                                             decay_steps, end_learning_rate,
                                             power=0.5)
   # Passing global_step to minimize() will increment it at each step.
   learning_step = (
-      tf.train.GradientDescentOptimizer(learning_rate)
+      tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
       .minimize(...my loss..., global_step=global_step)
   )
   ```

   Args:
-    learning_rate: A scalar `float32` or `float64` `Tensor` or a
-      Python number. The initial learning rate.
-    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
-      Global step to use for the decay computation. Must not be negative.
-    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
-      Must be positive. See the decay computation above.
-    end_learning_rate: A scalar `float32` or `float64` `Tensor` or a
-      Python number. The minimal end learning rate.
-    power: A scalar `float32` or `float64` `Tensor` or a
-      Python number. The power of the polynomial. Defaults to linear, 1.0.
+    learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
+      The initial learning rate.
+    global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
+      step to use for the decay computation. Must not be negative.
+    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Must
+      be positive. See the decay computation above.
+    end_learning_rate: A scalar `float32` or `float64` `Tensor` or a Python
+      number. The minimal end learning rate.
+    power: A scalar `float32` or `float64` `Tensor` or a Python number. The
+      power of the polynomial. Defaults to linear, 1.0.
     cycle: A boolean, whether or not it should cycle beyond decay_steps.
     name: String. Optional name of the operation. Defaults to
       'PolynomialDecay'.
@@ -292,21 +292,22 @@ def natural_exp_decay(learning_rate,
   learning_rate = 0.1
   decay_steps = 5
   k = 0.5
-  learning_rate = tf.train.natural_exp_decay(learning_rate, global_step,
+  learning_rate = tf.compat.v1.train.natural_exp_decay(learning_rate,
+  global_step,
                                              decay_steps, k)

   # Passing global_step to minimize() will increment it at each step.
   learning_step = (
-      tf.train.GradientDescentOptimizer(learning_rate)
+      tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
       .minimize(...my loss..., global_step=global_step)
   )
   ```

   Args:
-    learning_rate: A scalar `float32` or `float64` `Tensor` or a
-      Python number. The initial learning rate.
-    global_step: A Python number.
-      Global step to use for the decay computation. Must not be negative.
+    learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
+      The initial learning rate.
+    global_step: A Python number. Global step to use for the decay computation.
+      Must not be negative.
     decay_steps: How often to apply decay.
     decay_rate: A Python number. The decay rate.
     staircase: Whether to apply decay in a discrete staircase, as opposed to
@@ -329,7 +330,10 @@ def natural_exp_decay(learning_rate,
   """
   natural_exp_rate = math_ops.exp(math_ops.negative(decay_rate))
   decayed_lr = learning_rate_schedule.ExponentialDecay(
-      learning_rate, decay_steps, natural_exp_rate, staircase=staircase,
+      learning_rate,
+      decay_steps,
+      natural_exp_rate,
+      staircase=staircase,
       name=name)

   if not context.executing_eagerly():
@@ -376,21 +380,22 @@ def inverse_time_decay(learning_rate,
   learning_rate = 0.1
   decay_steps = 1.0
   decay_rate = 0.5
-  learning_rate = tf.train.inverse_time_decay(learning_rate, global_step,
+  learning_rate = tf.compat.v1.train.inverse_time_decay(learning_rate,
+  global_step,
                                               decay_steps, decay_rate)

   # Passing global_step to minimize() will increment it at each step.
   learning_step = (
-      tf.train.GradientDescentOptimizer(learning_rate)
+      tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
       .minimize(...my loss..., global_step=global_step)
   )
   ```

   Args:
-    learning_rate: A scalar `float32` or `float64` `Tensor` or a
-      Python number. The initial learning rate.
-    global_step: A Python number.
-      Global step to use for the decay computation. Must not be negative.
+    learning_rate: A scalar `float32` or `float64` `Tensor` or a Python number.
+      The initial learning rate.
+    global_step: A Python number. Global step to use for the decay computation.
+      Must not be negative.
     decay_steps: How often to apply decay.
     decay_rate: A Python number. The decay rate.
     staircase: Whether to apply decay in a discrete staircase, as opposed to
@@ -412,11 +417,7 @@ def inverse_time_decay(learning_rate,
   @end_compatibility
   """
   decayed_lr = learning_rate_schedule.InverseTimeDecay(
-      learning_rate,
-      decay_steps,
-      decay_rate,
-      staircase=staircase,
-      name=name)
+      learning_rate, decay_steps, decay_rate, staircase=staircase, name=name)

   if not context.executing_eagerly():
     decayed_lr = decayed_lr(global_step)
@@ -455,13 +456,14 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
   Args:
     learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
       The initial learning rate.
-    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
-      Global step to use for the decay computation.
-    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
-      Number of steps to decay over.
-    alpha: A scalar `float32` or `float64` Tensor or a Python number.
-      Minimum learning rate value as a fraction of learning_rate.
+    global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
+      step to use for the decay computation.
+    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number
+      of steps to decay over.
+    alpha: A scalar `float32` or `float64` Tensor or a Python number. Minimum
+      learning rate value as a fraction of learning_rate.
     name: String. Optional name of the operation. Defaults to 'CosineDecay'.

   Returns:
     A scalar `Tensor` of the same type as `learning_rate`. The decayed
     learning rate.
@@ -519,17 +521,18 @@ def cosine_decay_restarts(learning_rate,
   Args:
     learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
       The initial learning rate.
-    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
-      Global step to use for the decay computation.
+    global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
+      step to use for the decay computation.
     first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
       Number of steps to decay over.
-    t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
-      Used to derive the number of iterations in the i-th period
+    t_mul: A scalar `float32` or `float64` `Tensor` or a Python number. Used to
+      derive the number of iterations in the i-th period
     m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
       Used to derive the initial learning rate of the i-th period:
-    alpha: A scalar `float32` or `float64` Tensor or a Python number.
-      Minimum learning rate value as a fraction of the learning_rate.
+    alpha: A scalar `float32` or `float64` Tensor or a Python number. Minimum
+      learning rate value as a fraction of the learning_rate.
     name: String. Optional name of the operation. Defaults to 'SGDRDecay'.

   Returns:
     A scalar `Tensor` of the same type as `learning_rate`. The decayed
     learning rate.
@@ -602,16 +605,17 @@ def linear_cosine_decay(learning_rate,
   Args:
     learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
       The initial learning rate.
-    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
-      Global step to use for the decay computation.
-    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
-      Number of steps to decay over.
-    num_periods: Number of periods in the cosine part of the decay.
-      See computation above.
+    global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
+      step to use for the decay computation.
+    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number
+      of steps to decay over.
+    num_periods: Number of periods in the cosine part of the decay. See
+      computation above.
     alpha: See computation above.
     beta: See computation above.
     name: String. Optional name of the operation. Defaults to
       'LinearCosineDecay'.

   Returns:
     A scalar `Tensor` of the same type as `learning_rate`. The decayed
     learning rate.
@@ -690,18 +694,19 @@ def noisy_linear_cosine_decay(learning_rate,
   Args:
     learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
       The initial learning rate.
-    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
-      Global step to use for the decay computation.
-    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
-      Number of steps to decay over.
+    global_step: A scalar `int32` or `int64` `Tensor` or a Python number. Global
+      step to use for the decay computation.
+    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. Number
+      of steps to decay over.
     initial_variance: initial variance for the noise. See computation above.
     variance_decay: decay for the noise's variance. See computation above.
-    num_periods: Number of periods in the cosine part of the decay.
-      See computation above.
+    num_periods: Number of periods in the cosine part of the decay. See
+      computation above.
     alpha: See computation above.
     beta: See computation above.
     name: String. Optional name of the operation. Defaults to
       'NoisyLinearCosineDecay'.

   Returns:
     A scalar `Tensor` of the same type as `learning_rate`. The decayed
     learning rate.
@@ -77,7 +77,8 @@ class Scaffold(object):
   The following pieces are directly accessible as attributes of the `Scaffold`
   object:

-  * `saver`: A `tf.train.Saver` object taking care of saving the variables.
+  * `saver`: A `tf.compat.v1.train.Saver` object taking care of saving the
+  variables.
     Picked from and stored into the `SAVERS` collection in the graph by default.
   * `init_op`: An op to run to initialize the variables. Picked from and
     stored into the `INIT_OP` collection in the graph by default.
@@ -133,9 +134,9 @@ class Scaffold(object):
       local_init_op: Optional op to initialize local variables.
       summary_op: Optional op to gather all summaries. Must return a scalar
         string tensor containing a serialized `Summary` proto.
-      saver: Optional `tf.train.Saver` object to use to save and restore
-        variables. May also be a `tf.train.Checkpoint` object, in which case
-        object-based checkpoints are saved. This will also load some
+      saver: Optional `tf.compat.v1.train.Saver` object to use to save and
+        restore variables. May also be a `tf.train.Checkpoint` object, in which
+        case object-based checkpoints are saved. This will also load some
         object-based checkpoints saved from elsewhere, but that loading may be
         fragile since it uses fixed keys rather than performing a full
         graph-based match. For example if a variable has two paths from the
@@ -199,8 +200,9 @@ class Scaffold(object):
               resources.report_uninitialized_resources()
           ], 0)

-      self._ready_op = Scaffold.get_or_default(
-          'ready_op', ops.GraphKeys.READY_OP, default_ready_op)
+      self._ready_op = Scaffold.get_or_default('ready_op',
+                                               ops.GraphKeys.READY_OP,
+                                               default_ready_op)
     if self._ready_for_local_init_op is None:

       def default_ready_for_local_init_op():
@@ -219,8 +221,9 @@ class Scaffold(object):
           'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
           Scaffold.default_local_init_op)
     if self._summary_op is None:
-      self._summary_op = Scaffold.get_or_default(
-          'summary_op', ops.GraphKeys.SUMMARY_OP, summary.merge_all)
+      self._summary_op = Scaffold.get_or_default('summary_op',
+                                                 ops.GraphKeys.SUMMARY_OP,
+                                                 summary.merge_all)
     # pylint: disable=g-long-lambda
     if self._saver is None:
       self._saver = training_saver._get_saver_or_default()  # pylint: disable=protected-access
@@ -292,7 +295,8 @@ class Scaffold(object):
 
     This op is used during session initialization when a Scaffold is
     initialized without specifying the local_init_op arg. It includes
-    `tf.local_variables_initializer`, `tf.tables_initializer`, and also
+    `tf.compat.v1.local_variables_initializer`,
+    `tf.compat.v1.tables_initializer`, and also
     initializes local session resources.
 
     Returns:
@@ -435,7 +439,8 @@ def MonitoredTrainingSession(
   For a chief, this utility sets proper session initializer/restorer. It also
   creates hooks related to checkpoint and summary saving. For workers, this
   utility sets proper session creator which waits for the chief to
-  initialize/restore. Please check `tf.train.MonitoredSession` for more
+  initialize/restore. Please check `tf.compat.v1.train.MonitoredSession` for
+  more
   information.
 
 
@@ -464,8 +469,9 @@ def MonitoredTrainingSession(
       to disk using a default summary saver. If both `save_summaries_steps` and
       `save_summaries_secs` are set to `None`, then the default summary saver
       isn't used. Default not enabled.
-    config: an instance of `tf.ConfigProto` proto used to configure the session.
-      It's the `config` argument of constructor of `tf.Session`.
+    config: an instance of `tf.compat.v1.ConfigProto` proto used to configure
+      the session. It's the `config` argument of constructor of
+      `tf.compat.v1.Session`.
     stop_grace_period_secs: Number of seconds given to threads to stop after
       `close()` has been called.
     log_step_count_steps: The frequency, in number of global steps, that the
@@ -591,7 +597,7 @@ class SessionCreator(object):
 
 @tf_export(v1=['train.ChiefSessionCreator'])
 class ChiefSessionCreator(SessionCreator):
-  """Creates a tf.Session for a chief."""
+  """Creates a tf.compat.v1.Session for a chief."""
 
   def __init__(self,
                scaffold=None,
@@ -643,7 +649,7 @@ class ChiefSessionCreator(SessionCreator):
 
 @tf_export(v1=['train.WorkerSessionCreator'])
 class WorkerSessionCreator(SessionCreator):
-  """Creates a tf.Session for a worker."""
+  """Creates a tf.compat.v1.Session for a worker."""
 
   def __init__(self,
                scaffold=None,
@@ -757,8 +763,9 @@ class _MonitoredSession(object):
       `step_fn` will be returned from `run_step_fn`, unless a stop is
       requested. In that case, the next `should_stop` call will return True.
       Example usage: ```python
-      with tf.Graph().as_default(): c = tf.placeholder(dtypes.float32) v =
-      tf.add(c, 4.0) w = tf.add(c, 0.5)
+      with tf.Graph().as_default(): c =
+      tf.compat.v1.placeholder(dtypes.float32) v = tf.add(c, 4.0) w =
+      tf.add(c, 0.5)
       def step_fn(step_context):
         a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
         if a <= 4.5: step_context.request_stop()
@@ -808,7 +815,7 @@ class _MonitoredSession(object):
       """Initializes the `step_context` argument for a `step_fn` invocation.
 
       Args:
-        session: An instance of `tf.Session`.
+        session: An instance of `tf.compat.v1.Session`.
         run_with_hooks_fn: A function for running fetches and hooks.
       """
       self._session = session
@@ -901,13 +908,13 @@ class _MonitoredSession(object):
     return self._coordinated_creator.tf_sess is None
 
   def _tf_sess(self):
-    """Return underlying tf.Session object.
+    """Return underlying tf.compat.v1.Session object.
 
     Warning: accessing the returned object in user code is likely to cause races
     or "flaky tests".
 
     Returns:
-      A tf.Session object.
+      A tf.compat.v1.Session object.
     """
     return self._coordinated_creator.tf_sess
 
@@ -955,7 +962,7 @@ class MonitoredSession(_MonitoredSession):
   * suppresses `OutOfRange` error which indicates that all inputs have been
     processed if the monitored_session is used as a context
 
-  How to set `tf.Session` arguments:
+  How to set `tf.compat.v1.Session` arguments:
 
   * In most cases you can set session arguments as follows:
 
@@ -973,7 +980,8 @@ class MonitoredSession(_MonitoredSession):
 
   See `MonitoredTrainingSession` for an example usage based on chief or worker.
 
-  Note: This is not a `tf.Session`. For example, it cannot do following:
+  Note: This is not a `tf.compat.v1.Session`. For example, it cannot do
+  following:
 
   * it cannot be set as default session.
   * it cannot be sent to saver.save.
@@ -1004,14 +1012,15 @@ class SingularMonitoredSession(_MonitoredSession):
   """Session-like object that handles initialization, restoring, and hooks.
 
   Please note that this utility is not recommended for distributed settings.
-  For distributed settings, please use `tf.train.MonitoredSession`. The
+  For distributed settings, please use `tf.compat.v1.train.MonitoredSession`.
+  The
   differences between `MonitoredSession` and `SingularMonitoredSession` are:
 
   * `MonitoredSession` handles `AbortedError` and `UnavailableError` for
     distributed settings, but `SingularMonitoredSession` does not.
   * `MonitoredSession` can be created in `chief` or `worker` modes.
     `SingularMonitoredSession` is always created as `chief`.
-  * You can access the raw `tf.Session` object used by
+  * You can access the raw `tf.compat.v1.Session` object used by
     `SingularMonitoredSession`, whereas in MonitoredSession the raw session is
     private. This can be used:
       - To `run` without hooks.
@@ -1093,7 +1102,7 @@ class SingularMonitoredSession(_MonitoredSession):
 
 
 class _WrappedSession(object):
-  """Wrapper around a `tf.Session`.
+  """Wrapper around a `tf.compat.v1.Session`.
 
   This wrapper is used as a base class for various session wrappers
   that provide additional functionality such as monitoring, coordination,
@@ -1108,7 +1117,8 @@ class _WrappedSession(object):
     """Creates a `_WrappedSession`.
 
     Args:
-      sess: A `tf.Session` or `_WrappedSession` object. The wrapped session.
+      sess: A `tf.compat.v1.Session` or `_WrappedSession` object. The wrapped
+        session.
     """
     self._sess = sess
     self._wrapped_is_stoppable = isinstance(self._sess, _WrappedSession)
@@ -1293,7 +1303,7 @@ class _CoordinatedSession(_WrappedSession):
     """Create a new `_CoordinatedSession`.
 
     Args:
-      sess: A `tf.Session` object. The wrapped session.
+      sess: A `tf.compat.v1.Session` object. The wrapped session.
       coord: A `tf.train.Coordinator` object.
       stop_grace_period_secs: Number of seconds given to threads to stop after
         `close()` has been called.
@@ -1364,7 +1374,7 @@ class _HookedSession(_WrappedSession):
     """Initializes a _HookedSession object.
 
     Args:
-      sess: A `tf.Session` or a `_WrappedSession` object.
+      sess: A `tf.compat.v1.Session` or a `_WrappedSession` object.
       hooks: An iterable of `SessionRunHook' objects.
     """
 
@@ -56,9 +56,9 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
   E.g.:
 
   ```
-    with tf.variable_scope('scope1'):
-      with tf.variable_scope('scope2'):
-        var = tf.get_variable('foo')
+    with tf.compat.v1.variable_scope('scope1'):
+      with tf.compat.v1.variable_scope('scope2'):
+        var = tf.compat.v1.get_variable('foo')
         update_1 = tf.assign_moving_average(var, 0.0, 1.0)
         update_2 = tf.assign_moving_average(var, 0.0, 0.9)
 
@@ -73,12 +73,13 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
     decay: A float Tensor or float value. The moving average decay.
     zero_debias: A python bool. If true, assume the variable is 0-initialized
       and unbias it, as in https://arxiv.org/abs/1412.6980. See docstring in
       `_zero_debias` for more details.
     name: Optional name of the returned operation.
 
   Returns:
     A tensor which if evaluated will compute and return the new moving average.
   """
+
   def update_fn(v, value, decay=decay):
     decay = ops.convert_to_tensor(1.0 - decay, name="decay")
     if decay.dtype != v.dtype.base_dtype:
@@ -96,8 +97,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
   # In a replica context, we update variable using the mean of value across
   # replicas.
   def merge_fn(strategy, v, value):
-    value = strategy.extended.reduce_to(
-        ds_reduce_util.ReduceOp.MEAN, value, v)
+    value = strategy.extended.reduce_to(ds_reduce_util.ReduceOp.MEAN, value,
+                                        v)
     return strategy.extended.update(v, update_fn, args=(value,))
 
   return replica_context.merge_call(merge_fn, args=(variable, value))
@@ -124,15 +125,15 @@ def weighted_moving_average(value,
   Args:
     value: A numeric `Tensor`.
     decay: A float `Tensor` or float value. The moving average decay.
-    weight: `Tensor` that keeps the current value of a weight.
-      Shape should be able to multiply `value`.
+    weight: `Tensor` that keeps the current value of a weight. Shape should be
+      able to multiply `value`.
     truediv: Boolean, if `True`, dividing by `moving_average(weight)` is
       floating point division. If `False`, use division implied by dtypes.
     collections: List of graph collections keys to add the internal variables
-      `value * weight` and `weight` to.
-      Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
-    name: Optional name of the returned operation.
-      Defaults to "WeightedMovingAvg".
+      `value * weight` and `weight` to. Defaults to
+      `[GraphKeys.GLOBAL_VARIABLES]`.
+    name: Optional name of the returned operation. Defaults to
+      "WeightedMovingAvg".
 
   Returns:
     An Operation that updates and returns the weighted moving average.
@@ -203,27 +204,34 @@ def _zero_debias(unbiased_var, value, decay):
     tensor will also update the shadow variables appropriately.
   """
   with variable_scope.variable_scope(
-      unbiased_var.name[:-len(":0")], values=[unbiased_var,
-                                              value, decay]) as scope:
+      unbiased_var.name[:-len(":0")], values=[unbiased_var, value,
+                                              decay]) as scope:
     with ops.colocate_with(unbiased_var):
       with ops.init_scope():
         biased_initializer = init_ops.zeros_initializer(
-            dtype=unbiased_var.dtype)(unbiased_var.get_shape())
+            dtype=unbiased_var.dtype)(
+                unbiased_var.get_shape())
         local_step_initializer = init_ops.zeros_initializer()
 
       def _maybe_get_unique(name):
         """Get name for a unique variable, if not `reuse=True`."""
         if variable_scope.get_variable_scope().reuse:
           return name
-        vs_vars = [x.op.name for x in
-                   variable_scope.get_variable_scope().global_variables()]
+        vs_vars = [
+            x.op.name
+            for x in variable_scope.get_variable_scope().global_variables()
+        ]
         full_name = variable_scope.get_variable_scope().name + "/" + name
-        if full_name not in vs_vars: return name
+        if full_name not in vs_vars:
+          return name
         idx = 1
         while full_name + ("_%d" % idx) in vs_vars:
           idx += 1
         return name + ("_%d" % idx)
 
       biased_var = variable_scope.get_variable(
-          _maybe_get_unique("biased"), initializer=biased_initializer,
+          _maybe_get_unique("biased"),
+          initializer=biased_initializer,
           trainable=False)
       local_step = variable_scope.get_variable(
           _maybe_get_unique("local_step"),
@@ -233,18 +241,17 @@ def _zero_debias(unbiased_var, value, decay):
           trainable=False)
 
       # Get an update ops for both shadow variables.
-      update_biased = state_ops.assign_sub(biased_var,
-                                           (biased_var - value) * decay,
-                                           name=scope.name)
+      update_biased = state_ops.assign_sub(
+          biased_var, (biased_var - value) * decay, name=scope.name)
       update_local_step = local_step.assign_add(1)
 
       # Compute the value of the delta to update the unbiased EMA. Make sure to
       # use the new values of the biased variable and the local step.
       with ops.control_dependencies([update_biased, update_local_step]):
         # This function gets `1 - decay`, so use `1.0 - decay` in the exponent.
-        unbiased_ema_delta = (unbiased_var - biased_var.read_value() /
-                              (1 - math_ops.pow(
-                                  1.0 - decay, local_step.read_value())))
+        unbiased_ema_delta = (
+            unbiased_var - biased_var.read_value() /
+            (1 - math_ops.pow(1.0 - decay, local_step.read_value())))
 
       return unbiased_ema_delta
 
@@ -315,7 +322,7 @@ class ExponentialMovingAverage(object):
   for a given variable.
   * Build a model normally but load the checkpoint files to evaluate by using
     the shadow variable names. For this use the `average_name()` method. See
-    the `tf.train.Saver` for more
+    the `tf.compat.v1.train.Saver` for more
     information on restoring saved variables.
 
   Example of restoring the shadow variable values:
@@ -324,13 +331,17 @@ class ExponentialMovingAverage(object):
     # Create a Saver that loads variables from their saved shadow values.
     shadow_var0_name = ema.average_name(var0)
     shadow_var1_name = ema.average_name(var1)
-    saver = tf.train.Saver({shadow_var0_name: var0, shadow_var1_name: var1})
+    saver = tf.compat.v1.train.Saver({shadow_var0_name: var0, shadow_var1_name:
+    var1})
     saver.restore(...checkpoint filename...)
     # var0 and var1 now hold the moving average values
   ```
   """
 
-  def __init__(self, decay, num_updates=None, zero_debias=False,
+  def __init__(self,
+               decay,
+               num_updates=None,
+               zero_debias=False,
                name="ExponentialMovingAverage"):
     """Creates a new ExponentialMovingAverage object.
 
@@ -376,7 +387,7 @@ class ExponentialMovingAverage(object):
 
     shadow variables are created with `trainable=False` and added to the
     `GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to
-    `tf.global_variables()`.
+    `tf.compat.v1.global_variables()`.
 
     Returns an op that updates all shadow variables from the current value of
     their associated variables.
@@ -386,8 +397,8 @@ class ExponentialMovingAverage(object):
       be called in a loop.
 
     Args:
-      var_list: A list of Variable or Tensor objects. The variables
-        and Tensors must be of types bfloat16, float16, float32, or float64.
+      var_list: A list of Variable or Tensor objects. The variables and Tensors
+        must be of types bfloat16, float16, float32, or float64.
 
     Returns:
       An Operation that updates the moving averages.
@@ -417,10 +428,11 @@ class ExponentialMovingAverage(object):
       # tensors, we rely on the existing device allocation mechanism.
       with ops.init_scope():
         if isinstance(var, variables.Variable):
-          avg = slot_creator.create_slot(var,
-                                         var.initialized_value(),
-                                         self.name,
-                                         colocate_with_primary=True)
+          avg = slot_creator.create_slot(
+              var,
+              var.initialized_value(),
+              self.name,
+              colocate_with_primary=True)
           # NOTE(mrry): We only add `tf.Variable` objects to the
           # `MOVING_AVERAGE_VARIABLES` collection.
           ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
@@ -428,9 +440,9 @@ class ExponentialMovingAverage(object):
           avg = slot_creator.create_zeros_slot(
               var,
               self.name,
-              colocate_with_primary=(var.op.type in ["Variable",
-                                                     "VariableV2",
-                                                     "VarHandleOp"]))
+              colocate_with_primary=(var.op.type in [
+                  "Variable", "VariableV2", "VarHandleOp"
+              ]))
           if self._zero_debias:
             zero_debias_true.add(avg)
         self._averages[var] = avg
@@ -438,16 +450,16 @@ class ExponentialMovingAverage(object):
     with ops.name_scope(self.name) as scope:
       decay = ops.convert_to_tensor(self._decay, name="decay")
       if self._num_updates is not None:
-        num_updates = math_ops.cast(self._num_updates,
-                                    dtypes.float32,
-                                    name="num_updates")
+        num_updates = math_ops.cast(
+            self._num_updates, dtypes.float32, name="num_updates")
         decay = math_ops.minimum(decay,
                                  (1.0 + num_updates) / (10.0 + num_updates))
       updates = []
       for var in var_list:
         zero_debias = self._averages[var] in zero_debias_true
-        updates.append(assign_moving_average(
-            self._averages[var], var, decay, zero_debias=zero_debias))
+        updates.append(
+            assign_moving_average(
+                self._averages[var], var, decay, zero_debias=zero_debias))
       return control_flow_ops.group(*updates, name=scope)
 
   def average(self, var):
@@ -472,7 +484,7 @@ class ExponentialMovingAverage(object):
     To restore variables, you have to know the name of the shadow variables.
     That name and the original variable can then be passed to a `Saver()` object
     to restore the variable from the moving average value with:
-      `saver = tf.train.Saver({ema.average_name(var): var})`
+      `saver = tf.compat.v1.train.Saver({ema.average_name(var): var})`
 
     `average_name()` can be called whether or not `apply()` has been called.
 
@@ -499,7 +511,7 @@ class ExponentialMovingAverage(object):
 
     ```python
       variables_to_restore = ema.variables_to_restore()
-      saver = tf.train.Saver(variables_to_restore)
+      saver = tf.compat.v1.train.Saver(variables_to_restore)
     ```
 
     Below is an example of such mapping:
@@ -434,14 +434,14 @@ def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
 
   Raises:
     ValueError: if `sess` is None and there isn't any default session.
-    TypeError: if `sess` is not a `tf.Session` object.
+    TypeError: if `sess` is not a `tf.compat.v1.Session` object.
 
   Returns:
     A list of threads.
 
   Raises:
     RuntimeError: If called with eager execution enabled.
-    ValueError: If called without a default `tf.Session` registered.
+    ValueError: If called without a default `tf.compat.v1.Session` registered.
 
   @compatibility(eager)
   Not compatible with eager execution. To ingest data under eager execution,
@@ -56,7 +56,6 @@ from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util import compat
 from tensorflow.python.util.tf_export import tf_export
 
-
 # TODO(allenl): Remove these aliases once all users are migrated off.
 get_checkpoint_state = checkpoint_management.get_checkpoint_state
 update_checkpoint_state = checkpoint_management.update_checkpoint_state
@@ -174,13 +173,11 @@ class BaseSaverBuilder(object):
     tensors = []
     for spec in saveable.specs:
       tensors.append(
-          io_ops.restore_v2(
-              filename_tensor,
-              [spec.name],
-              [spec.slice_spec],
-              [spec.dtype])[0])
+          io_ops.restore_v2(filename_tensor, [spec.name], [spec.slice_spec],
+                            [spec.dtype])[0])
 
     return tensors
 
   # pylint: enable=unused-argument
 
   def sharded_filename(self, filename_tensor, shard, num_shards):
@@ -217,8 +214,8 @@ class BaseSaverBuilder(object):
       from each device.
 
     Args:
-      checkpoint_prefix: scalar String Tensor. Interpreted *NOT AS A
-        FILENAME*, but as a prefix of a V2 checkpoint;
+      checkpoint_prefix: scalar String Tensor. Interpreted *NOT AS A FILENAME*,
+        but as a prefix of a V2 checkpoint;
       per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
         returned by _GroupByDevices().
 
@@ -319,8 +316,8 @@ class BaseSaverBuilder(object):
       saveables: A list of SaveableObject objects.
       restore_sequentially: True if we want to restore variables sequentially
         within a shard.
-      reshape: True if we want to reshape loaded tensors to the shape of
-        the corresponding variable.
+      reshape: True if we want to reshape loaded tensors to the shape of the
+        corresponding variable.
       preferred_shard: Shard to open first when loading a sharded file.
       name: Name for the returned op.
 
@@ -361,12 +358,12 @@ class BaseSaverBuilder(object):
 
     Args:
       filename_tensor: Tensor for the path of the file to load.
-      per_device: A list of (device, SaveableObject) pairs, as
-        returned by _GroupByDevices().
+      per_device: A list of (device, SaveableObject) pairs, as returned by
+        _GroupByDevices().
       restore_sequentially: True if we want to restore variables sequentially
         within a shard.
-      reshape: True if we want to reshape loaded tensors to the shape of
-        the corresponding variable.
+      reshape: True if we want to reshape loaded tensors to the shape of the
+        corresponding variable.
 
     Returns:
       An Operation that restores the variables.
@@ -424,14 +421,13 @@ class BaseSaverBuilder(object):
 
     Args:
       names_to_saveables: A dictionary mapping name to a Variable or
-        SaveableObject. Each name will be associated with the
-        corresponding variable in the checkpoint.
-      reshape: If True, allow restoring parameters from a checkpoint
-        that where the parameters have a different shape. This is
-        only needed when you try to restore from a Dist-Belief checkpoint,
-        and only some times.
-      sharded: If True, shard the checkpoints, one per device that has
-        Variable nodes.
+        SaveableObject. Each name will be associated with the corresponding
+        variable in the checkpoint.
+      reshape: If True, allow restoring parameters from a checkpoint that where
+        the parameters have a different shape. This is only needed when you try
+        to restore from a Dist-Belief checkpoint, and only some times.
+      sharded: If True, shard the checkpoints, one per device that has Variable
+        nodes.
       max_to_keep: Maximum number of checkpoints to keep. As new checkpoints
         are created, old ones are deleted. If None or 0, no checkpoints are
         deleted from the filesystem but only the last one is kept in the
@@ -597,8 +593,8 @@ def _get_saver_or_default():
     if len(savers) > 1:
       raise RuntimeError(
           "More than one item in collection {}. "
-          "Please indicate which one to use by passing it to the constructor.".
-          format(collection_key))
+          "Please indicate which one to use by passing it to the constructor."
+          .format(collection_key))
     return savers[0]
   saver = Saver(sharded=True, allow_empty=True)
   if saver is not None:
@@ -662,9 +658,9 @@ class Saver(object):
   ```python
   ...
   # Create a saver.
-  saver = tf.train.Saver(...variables...)
+  saver = tf.compat.v1.train.Saver(...variables...)
   # Launch the graph and train, saving the model every 1,000 steps.
-  sess = tf.Session()
+  sess = tf.compat.v1.Session()
   for step in xrange(1000000):
       sess.run(..training_op..)
       if step % 1000 == 0:
@@ -717,13 +713,13 @@ class Saver(object):
   v2 = tf.Variable(..., name='v2')
 
   # Pass the variables as a dict:
-  saver = tf.train.Saver({'v1': v1, 'v2': v2})
+  saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2})
 
   # Or pass them as a list.
-  saver = tf.train.Saver([v1, v2])
+  saver = tf.compat.v1.train.Saver([v1, v2])
   # Passing a list is equivalent to passing a dict with the variable op names
   # as keys:
-  saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
+  saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]})
   ```
 
   The optional `reshape` argument, if `True`, allows restoring a variable from
@@ -738,35 +734,33 @@ class Saver(object):
       var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
         names to `SaveableObject`s. If `None`, defaults to the list of all
         saveable objects.
-      reshape: If `True`, allows restoring parameters from a checkpoint
-        where the variables have a different shape.
+      reshape: If `True`, allows restoring parameters from a checkpoint where
+        the variables have a different shape.
       sharded: If `True`, shard the checkpoints, one per device.
-      max_to_keep: Maximum number of recent checkpoints to keep.
-        Defaults to 5.
-      keep_checkpoint_every_n_hours: How often to keep checkpoints.
-        Defaults to 10,000 hours.
+      max_to_keep: Maximum number of recent checkpoints to keep. Defaults to 5.
+      keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults to
+        10,000 hours.
       name: String. Optional name to use as a prefix when adding operations.
       restore_sequentially: A `Bool`, which if true, causes restore of different
         variables to happen sequentially within each device. This can lower
         memory usage when restoring very large models.
       saver_def: Optional `SaverDef` proto to use instead of running the
-        builder. This is only useful for specialty code that wants to recreate
-        a `Saver` object for a previously built `Graph` that had a `Saver`.
-        The `saver_def` proto should be the one returned by the
-        `as_saver_def()` call of the `Saver` that was created for that `Graph`.
+        builder. This is only useful for specialty code that wants to recreate a
+        `Saver` object for a previously built `Graph` that had a `Saver`. The
+        `saver_def` proto should be the one returned by the `as_saver_def()`
+        call of the `Saver` that was created for that `Graph`.
       builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
         Defaults to `BulkSaverBuilder()`.
       defer_build: If `True`, defer adding the save and restore ops to the
         `build()` call. In that case `build()` should be called before
         finalizing the graph or using the saver.
-      allow_empty: If `False` (default) raise an error if there are no
-        variables in the graph. Otherwise, construct the saver anyway and make
-        it a no-op.
+      allow_empty: If `False` (default) raise an error if there are no variables
+        in the graph. Otherwise, construct the saver anyway and make it a no-op.
       write_version: controls what format to use when saving checkpoints. It
         also affects certain filepath matching logic. The V2 format is the
-        recommended choice: it is much more optimized than V1 in terms of
-        memory required and latency incurred during restore. Regardless of
-        this flag, the Saver is able to restore from both V2 and V1 checkpoints.
+        recommended choice: it is much more optimized than V1 in terms of memory
+        required and latency incurred during restore. Regardless of this
+        flag, the Saver is able to restore from both V2 and V1 checkpoints.
       pad_step_number: if True, pads the global step number in the checkpoint
         filepaths to some fixed width (8 by default). This is turned off by
         default.
@@ -877,7 +871,8 @@ class Saver(object):
           name=self._name,
           restore_sequentially=self._restore_sequentially,
           filename=checkpoint_path,
-          build_save=build_save, build_restore=build_restore)
+          build_save=build_save,
+          build_restore=build_restore)
     elif self.saver_def and self._name:
       # Since self._name is used as a name_scope by builder(), we are
       # overloading the use of this field to represent the "import_scope" as
@@ -997,8 +992,8 @@ class Saver(object):
         saver_def.filename_tensor_name, export_scope)
     saver_def.save_tensor_name = ops.strip_name_scope(
         saver_def.save_tensor_name, export_scope)
-    saver_def.restore_op_name = ops.strip_name_scope(
-        saver_def.restore_op_name, export_scope)
+    saver_def.restore_op_name = ops.strip_name_scope(saver_def.restore_op_name,
+                                                     export_scope)
     return saver_def
 
   @staticmethod
@@ -1092,14 +1087,13 @@ class Saver(object):
     Args:
       sess: A Session to use to save the variables.
      save_path: String. Prefix of filenames created for the checkpoint.
-      global_step: If provided the global step number is appended to
-        `save_path` to create the checkpoint filenames. The optional argument
-        can be a `Tensor`, a `Tensor` name or an integer.
+      global_step: If provided the global step number is appended to `save_path`
+        to create the checkpoint filenames. The optional argument can be a
+        `Tensor`, a `Tensor` name or an integer.
       latest_filename: Optional name for the protocol buffer file that will
-        contains the list of most recent checkpoints. That file,
-        kept in the same directory as the checkpoint files, is automatically
-        managed by the saver to keep track of recent checkpoints. Defaults to
-        'checkpoint'.
+        contains the list of most recent checkpoints. That file, kept in the
+        same directory as the checkpoint files, is automatically managed by the
+        saver to keep track of recent checkpoints. Defaults to 'checkpoint'.
       meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
       write_meta_graph: `Boolean` indicating whether or not to write the meta
         graph file.
@@ -1107,7 +1101,8 @@ class Saver(object):
         `CheckpointStateProto`.
       strip_default_attrs: Boolean. If `True`, default-valued attributes will be
         removed from the NodeDefs. For a detailed guide, see
-        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+        [Stripping Default-Valued
+        Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
       save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
         which in the same directory of save_path and with `_debug` added before
         the file extension. This is only enabled when `write_meta_graph` is
@@ -1151,8 +1146,7 @@ class Saver(object):
         checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
     else:
       checkpoint_file = save_path
-      if os.path.basename(
-          save_path) == latest_filename and not self._sharded:
+      if os.path.basename(save_path) == latest_filename and not self._sharded:
         # Guard against collision between data file and checkpoint state file.
         raise ValueError(
             "'latest_filename' collides with 'save_path': '%s' and '%s'" %
@@ -1197,7 +1191,8 @@ class Saver(object):
       if not context.executing_eagerly():
         with sess.graph.as_default():
           self.export_meta_graph(
-              meta_graph_filename, strip_default_attrs=strip_default_attrs,
+              meta_graph_filename,
+              strip_default_attrs=strip_default_attrs,
               save_debug_info=save_debug_info)
 
     if self._is_empty:
@@ -1225,11 +1220,12 @@ class Saver(object):
       clear_devices: Whether or not to clear the device field for an `Operation`
         or `Tensor` during export.
       clear_extraneous_savers: Remove any Saver-related information from the
-        graph (both Save/Restore ops and SaverDefs) that are not associated
-        with this Saver.
+        graph (both Save/Restore ops and SaverDefs) that are not associated with
+        this Saver.
       strip_default_attrs: Boolean. If `True`, default-valued attributes will be
         removed from the NodeDefs. For a detailed guide, see
-        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+        [Stripping Default-Valued
+        Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
       save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
         which in the same directory of filename and with `_debug` added before
         the file extension.
@@ -1274,8 +1270,8 @@ class Saver(object):
       raise ValueError("Can't load save_path when it is None.")
 
     if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)):
-      raise ValueError("The passed save_path is not a valid checkpoint: "
-                       + compat.as_text(save_path))
+      raise ValueError("The passed save_path is not a valid checkpoint: " +
+                       compat.as_text(save_path))
 
     logging.info("Restoring parameters from %s", compat.as_text(save_path))
    try:
@@ -1330,13 +1326,15 @@ class Saver(object):
       key: One of the GraphKeys or user-defined string.
       export_scope: Optional `string`. Name scope to remove.
     """
-    meta_graph.add_collection_def(meta_graph_def, key,
-                                  export_scope=export_scope)
+    meta_graph.add_collection_def(
+        meta_graph_def, key, export_scope=export_scope)
 
 
 @tf_export(v1=["train.import_meta_graph"])
-def import_meta_graph(meta_graph_or_file, clear_devices=False,
-                      import_scope=None, **kwargs):
+def import_meta_graph(meta_graph_or_file,
+                      clear_devices=False,
+                      import_scope=None,
+                      **kwargs):
   """Recreates a Graph saved in a `MetaGraphDef` proto.
 
   This function takes a `MetaGraphDef` protocol buffer as input. If
@@ -1358,10 +1356,10 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
   ```Python
   ...
   # Create a saver.
-  saver = tf.train.Saver(...variables...)
+  saver = tf.compat.v1.train.Saver(...variables...)
   # Remember the training_op we want to run by adding it to a collection.
-  tf.add_to_collection('train_op', train_op)
-  sess = tf.Session()
+  tf.compat.v1.add_to_collection('train_op', train_op)
+  sess = tf.compat.v1.Session()
   for step in xrange(1000000):
     sess.run(train_op)
     if step % 1000 == 0:
@@ -1374,12 +1372,13 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
   the model from scratch.
 
   ```Python
-  with tf.Session() as sess:
-    new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
+  with tf.compat.v1.Session() as sess:
+    new_saver =
+    tf.compat.v1.train.import_meta_graph('my-save-dir/my-model-10000.meta')
     new_saver.restore(sess, 'my-save-dir/my-model-10000')
-    # tf.get_collection() returns a list. In this example we only want the
-    # first one.
-    train_op = tf.get_collection('train_op')[0]
+    # tf.compat.v1.get_collection() returns a list. In this example we only want
+    # the first one.
+    train_op = tf.compat.v1.get_collection('train_op')[0]
   for step in xrange(1000000):
     sess.run(train_op)
   ```
@@ -1393,14 +1392,14 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
 
   ```Python
   # Saving contents and operations.
-  v1 = tf.placeholder(tf.float32, name="v1")
-  v2 = tf.placeholder(tf.float32, name="v2")
+  v1 = tf.compat.v1.placeholder(tf.float32, name="v1")
+  v2 = tf.compat.v1.placeholder(tf.float32, name="v2")
   v3 = tf.mul(v1, v2)
   vx = tf.Variable(10.0, name="vx")
   v4 = tf.add(v3, vx, name="v4")
-  saver = tf.train.Saver([vx])
-  sess = tf.Session()
-  sess.run(tf.initialize_all_variables())
+  saver = tf.compat.v1.train.Saver([vx])
+  sess = tf.compat.v1.Session()
+  sess.run(tf.compat.v1.initialize_all_variables())
   sess.run(vx.assign(tf.add(vx, vx)))
   result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
   print(result)
@@ -1411,8 +1410,8 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
 
   ```Python
   # Restoring variables and running operations.
-  saver = tf.train.import_meta_graph("./model_ex1.meta")
-  sess = tf.Session()
+  saver = tf.compat.v1.train.import_meta_graph("./model_ex1.meta")
+  sess = tf.compat.v1.Session()
   saver.restore(sess, "./model_ex1")
   result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
   print(result)
@@ -1441,13 +1440,16 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
   execution is enabled.
   @end_compatibility
   """  # pylint: disable=g-doc-exception
-  return _import_meta_graph_with_return_elements(
-      meta_graph_or_file, clear_devices, import_scope, **kwargs)[0]
+  return _import_meta_graph_with_return_elements(meta_graph_or_file,
+                                                 clear_devices, import_scope,
+                                                 **kwargs)[0]
 
 
-def _import_meta_graph_with_return_elements(
-    meta_graph_or_file, clear_devices=False, import_scope=None,
-    return_elements=None, **kwargs):
+def _import_meta_graph_with_return_elements(meta_graph_or_file,
+                                            clear_devices=False,
+                                            import_scope=None,
+                                            return_elements=None,
+                                            **kwargs):
   """Import MetaGraph, and return both a saver and returned elements."""
   if context.executing_eagerly():
     raise RuntimeError("Exporting/importing meta graphs is not supported when "
@@ -1466,13 +1468,13 @@ def _import_meta_graph_with_return_elements(
           return_elements=return_elements,
           **kwargs))
 
-  saver = _create_saver_from_imported_meta_graph(
-      meta_graph_def, import_scope, imported_vars)
+  saver = _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
+                                                 imported_vars)
   return saver, imported_return_elements
 
 
-def _create_saver_from_imported_meta_graph(
-    meta_graph_def, import_scope, imported_vars):
+def _create_saver_from_imported_meta_graph(meta_graph_def, import_scope,
+                                           imported_vars):
   """Return a saver for restoring variable values to an imported MetaGraph."""
   if meta_graph_def.HasField("saver_def"):
     # Infer the scope that is prepended by `import_scoped_meta_graph`.
@@ -1510,7 +1512,9 @@ def export_meta_graph(filename=None,
                       save_debug_info=False,
                       **kwargs):
   # pylint: disable=line-too-long
-  """Returns `MetaGraphDef` proto. Optionally writes it to filename.
+  """Returns `MetaGraphDef` proto.
+
+  Optionally writes it to filename.
 
   This function exports the graph, saver, and collection objects into
   `MetaGraphDef` protocol buffer with the intention of it being imported
@@ -1518,29 +1522,29 @@ def export_meta_graph(filename=None,
   a subgraph.

   Args:
-    filename: Optional filename including the path for writing the
-      generated `MetaGraphDef` protocol buffer.
+    filename: Optional filename including the path for writing the generated
+      `MetaGraphDef` protocol buffer.
     meta_info_def: `MetaInfoDef` protocol buffer.
     graph_def: `GraphDef` protocol buffer.
     saver_def: `SaverDef` protocol buffer.
     collection_list: List of string keys to collect.
     as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
     graph: The `Graph` to export. If `None`, use the default graph.
-    export_scope: Optional `string`. Name scope under which to extract
-      the subgraph. The scope name will be striped from the node definitions
-      for easy import later into new name scopes. If `None`, the whole graph
-      is exported. graph_def and export_scope cannot both be specified.
+    export_scope: Optional `string`. Name scope under which to extract the
+      subgraph. The scope name will be stripped from the node definitions for
+      easy import later into new name scopes. If `None`, the whole graph is
+      exported. graph_def and export_scope cannot both be specified.
     clear_devices: Whether or not to clear the device field for an `Operation`
       or `Tensor` during export.
-    clear_extraneous_savers: Remove any Saver-related information from the
-      graph (both Save/Restore ops and SaverDefs) that are not associated
-      with the provided SaverDef.
+    clear_extraneous_savers: Remove any Saver-related information from the graph
+      (both Save/Restore ops and SaverDefs) that are not associated with the
+      provided SaverDef.
     strip_default_attrs: Boolean. If `True`, default-valued attributes will be
       removed from the NodeDefs. For a detailed guide, see
      [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
-    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
-      which in the same directory of filename and with `_debug` added before
-      the file extend.
+    save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
+      which is in the same directory of filename and with `_debug` added before
+      the file extension.
     **kwargs: Optional keyed arguments.

   Returns:
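As a quick illustration of the export/import round trip these docstrings describe, here is a minimal graph-mode sketch using `tf.compat.v1` endpoints; the `/tmp` path and variable name are illustrative, not from the commit:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # meta graphs are a graph-mode feature

with tf.Graph().as_default():
  v = tf.get_variable("v", shape=[1], initializer=tf.zeros_initializer())
  saver = tf.train.Saver([v])
  saver.export_meta_graph("/tmp/example.meta")  # writes the MetaGraphDef

with tf.Graph().as_default():
  # Recreates the graph and returns a `Saver` built from the saved SaverDef.
  restorer = tf.train.import_meta_graph("/tmp/example.meta")
```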
@@ -1603,10 +1607,8 @@ def object_graph_key_mapping(checkpoint_path):
     Dictionary mapping tensor names to checkpoint keys.
   """
   reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
-  object_graph_string = reader.get_tensor(
-      trackable.OBJECT_GRAPH_PROTO_KEY)
-  object_graph_proto = (
-      trackable_object_graph_pb2.TrackableObjectGraph())
+  object_graph_string = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY)
+  object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
   object_graph_proto.ParseFromString(object_graph_string)
   names_to_keys = {}
   for node in object_graph_proto.nodes:
@@ -1615,9 +1617,11 @@ def object_graph_key_mapping(checkpoint_path):
   return names_to_keys


-def saver_from_object_based_checkpoint(
-    checkpoint_path, var_list=None, builder=None, names_to_keys=None,
-    cached_saver=None):
+def saver_from_object_based_checkpoint(checkpoint_path,
+                                       var_list=None,
+                                       builder=None,
+                                       names_to_keys=None,
+                                       cached_saver=None):
   """Return a `Saver` which reads from an object-based checkpoint.

   This function validates that all variables in the variables list are remapped
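A hedged sketch of how these two helpers might be exercised together; the direct import of the private `saver` module, the checkpoint prefix, and the variable are assumptions for illustration only, not public API:

```python
import tensorflow.compat.v1 as tf
# Private module under edit in this diff; importing it directly is an
# assumption for illustration.
from tensorflow.python.training import saver as saver_lib

tf.disable_eager_execution()

v = tf.get_variable("v", initializer=1.0)
ckpt = tf.train.Checkpoint(v=v)  # writes an object-based checkpoint
with tf.Session() as sess:
  sess.run(v.initializer)
  path = ckpt.save("/tmp/obj_ckpt", session=sess)

# Maps variable names to object-graph checkpoint keys, then builds a
# name-based `Saver` whose save specs are remapped to those keys.
names_to_keys = saver_lib.object_graph_key_mapping(path)
name_based_saver = saver_lib.saver_from_object_based_checkpoint(
    path, names_to_keys=names_to_keys)
```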
@@ -1659,8 +1663,8 @@ def saver_from_object_based_checkpoint(
   try:
     names_to_keys = object_graph_key_mapping(checkpoint_path)
   except errors.NotFoundError:
-    raise ValueError("Checkpoint in %s not an object-based checkpoint."
-                     % checkpoint_path)
+    raise ValueError("Checkpoint in %s not an object-based checkpoint." %
+                     checkpoint_path)
   if var_list is None:
     var_list = variables._all_saveable_objects()  # pylint: disable=protected-access
   if builder is None:
@@ -1677,7 +1681,8 @@ def saver_from_object_based_checkpoint(
     extra_names = previous_names - current_names
     intersecting_names = previous_names.intersection(current_names)
     raise errors.NotFoundError(
-        None, None,
+        None,
+        None,
         message=(
             "\n\nExisting variables not in the checkpoint: %s\n\n"
             "Variables names when this checkpoint was written which don't "
@@ -1695,9 +1700,9 @@ def saver_from_object_based_checkpoint(
             "existed, and if variable names have changed you may need to "
             "make this a dictionary with the old names as keys. If you're "
             "using an Estimator, you'll need to return a tf.train.Saver "
-            "inside a tf.train.Scaffold from your model_fn.")
-        % (", ".join(sorted(missing_names)), ", ".join(sorted(extra_names)),
-           len(intersecting_names)))
+            "inside a tf.train.Scaffold from your model_fn.") %
+        (", ".join(sorted(missing_names)), ", ".join(
+            sorted(extra_names)), len(intersecting_names)))
   for saveable in saveables:
     for spec in saveable.specs:
       spec.name = names_to_keys[spec.name]
@@ -32,21 +32,19 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
  """Creates a `tf.train.ServerDef` protocol buffer.

  Args:
-    server_or_cluster_def: A `tf.train.ServerDef` or
-      `tf.train.ClusterDef` protocol buffer, or a
-      `tf.train.ClusterSpec` object, describing the server to be
-      defined and/or the cluster of which it is a member.
-    job_name: (Optional.) Specifies the name of the job of which the server
-      is a member. Defaults to the value in `server_or_cluster_def`, if
-      specified.
+    server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
+      protocol buffer, or a `tf.train.ClusterSpec` object, describing the server
+      to be defined and/or the cluster of which it is a member.
+    job_name: (Optional.) Specifies the name of the job of which the server is a
+      member. Defaults to the value in `server_or_cluster_def`, if specified.
    task_index: (Optional.) Specifies the task index of the server in its job.
      Defaults to the value in `server_or_cluster_def`, if specified. Otherwise
      defaults to 0 if the server's job has only one task.
    protocol: (Optional.) Specifies the protocol to be used by the server.
-      Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value
-      in `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
-    config: (Options.) A `tf.ConfigProto` that specifies default configuration
-      options for all sessions that run on this server.
+      Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value in
+      `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
+    config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
+      configuration options for all sessions that run on this server.

  Returns:
    A `tf.train.ServerDef`.
@@ -88,7 +86,9 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,

  server_def = tensorflow_server_pb2.ServerDef(
      cluster=cluster_spec.as_cluster_def(),
-      job_name=job_name, task_index=task_index, protocol=protocol)
+      job_name=job_name,
+      task_index=task_index,
+      protocol=protocol)
  if config is not None:
    server_def.default_session_config.MergeFrom(config)
  return server_def
@@ -99,8 +99,8 @@ def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
class Server(object):
  """An in-process TensorFlow server, for use in distributed training.

-  A `tf.train.Server` instance encapsulates a set of devices and a
-  `tf.Session` target that
+  A `tf.distribute.Server` instance encapsulates a set of devices and a
+  `tf.compat.v1.Session` target that
  can participate in distributed training. A server belongs to a
  cluster (specified by a `tf.train.ClusterSpec`), and
  corresponds to a particular task in a named job. The server can
@@ -120,31 +120,30 @@ class Server(object):
    override any information provided in `server_or_cluster_def`.

    Args:
-      server_or_cluster_def: A `tf.train.ServerDef` or
-        `tf.train.ClusterDef` protocol buffer, or a
-        `tf.train.ClusterSpec` object, describing the server to be
-        created and/or the cluster of which it is a member.
-      job_name: (Optional.) Specifies the name of the job of which the server
-        is a member. Defaults to the value in `server_or_cluster_def`, if
+      server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
+        protocol buffer, or a `tf.train.ClusterSpec` object, describing the
+        server to be created and/or the cluster of which it is a member.
+      job_name: (Optional.) Specifies the name of the job of which the server is
+        a member. Defaults to the value in `server_or_cluster_def`, if
        specified.
-      task_index: (Optional.) Specifies the task index of the server in its
-        job. Defaults to the value in `server_or_cluster_def`, if specified.
+      task_index: (Optional.) Specifies the task index of the server in its job.
+        Defaults to the value in `server_or_cluster_def`, if specified.
        Otherwise defaults to 0 if the server's job has only one task.
      protocol: (Optional.) Specifies the protocol to be used by the server.
-        Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the
-        value in `server_or_cluster_def`, if specified. Otherwise defaults to
+        Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value
+        in `server_or_cluster_def`, if specified. Otherwise defaults to
        `"grpc"`.
-      config: (Options.) A `tf.ConfigProto` that specifies default
+      config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
-      start: (Optional.) Boolean, indicating whether to start the server
-        after creating it. Defaults to `True`.
+      start: (Optional.) Boolean, indicating whether to start the server after
+        creating it. Defaults to `True`.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow server.
    """
-    self._server_def = _make_server_def(server_or_cluster_def,
-                                        job_name, task_index, protocol, config)
+    self._server_def = _make_server_def(server_or_cluster_def, job_name,
+                                        task_index, protocol, config)
    self._server = c_api.TF_NewServer(self._server_def.SerializeToString())
    if start:
      self.start()
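For context, a minimal sketch of constructing a server as this docstring describes; the two-task cluster and ports are illustrative. (`tf.distribute.Server` is the new endpoint these docstrings adopt; `tf.compat.v1.train.Server` refers to the same class.)

```python
import tensorflow.compat.v1 as tf

cluster = tf.train.ClusterSpec({"worker": ["localhost:2222", "localhost:2223"]})
# This process hosts task 0 of the "worker" job; `start=True` is the default.
server = tf.distribute.Server(cluster, job_name="worker", task_index=0)
# server.join() would block this process to serve requests indefinitely.
```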
@@ -195,15 +194,15 @@ class Server(object):

  @property
  def target(self):
-    """Returns the target for a `tf.Session` to connect to this server.
+    """Returns the target for a `tf.compat.v1.Session` to connect to this server.

    To create a
-    `tf.Session` that
+    `tf.compat.v1.Session` that
    connects to this server, use the following snippet:

    ```python
-    server = tf.train.Server(...)
-    with tf.Session(server.target):
+    server = tf.distribute.Server(...)
+    with tf.compat.v1.Session(server.target):
      # ...
    ```

@@ -217,22 +216,24 @@ class Server(object):
    """Creates a new single-process cluster running on the local host.

    This method is a convenience wrapper for creating a
-    `tf.train.Server` with a `tf.train.ServerDef` that specifies a
+    `tf.distribute.Server` with a `tf.train.ServerDef` that specifies a
    single-process cluster containing a single task in a job called
    `"local"`.

    Args:
-      config: (Options.) A `tf.ConfigProto` that specifies default
+      config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to `True`.

    Returns:
-      A local `tf.train.Server`.
+      A local `tf.distribute.Server`.
    """
    # Specifying port 0 means that the OS will choose a free port for the
    # server.
-    return Server({"local": ["localhost:0"]}, protocol="grpc", config=config,
-                  start=start)
+    return Server({"local": ["localhost:0"]},
+                  protocol="grpc",
+                  config=config,
+                  start=start)

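A runnable sketch of the convenience wrapper documented here: a single-process cluster on a free port, driven through `server.target`:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

server = tf.distribute.Server.create_local_server()  # OS picks a free port
c = tf.constant(42)
with tf.Session(server.target) as sess:
  print(sess.run(c))  # runs against the in-process gRPC server -> 42
```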
@@ -242,7 +243,7 @@ class ClusterSpec(object):

  A `tf.train.ClusterSpec` represents the set of processes that
  participate in a distributed TensorFlow computation. Every
-  `tf.train.Server` is constructed in a particular cluster.
+  `tf.distribute.Server` is constructed in a particular cluster.

  To create a cluster with two jobs and five tasks, you specify the
  mapping from job names to lists of network addresses (typically
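The two-job, five-task example the docstring mentions, written out; hostnames are illustrative:

```python
import tensorflow.compat.v1 as tf

cluster = tf.train.ClusterSpec({
    "worker": ["worker0:2222", "worker1:2222", "worker2:2222"],
    "ps": ["ps0:2222", "ps1:2222"],
})
print(cluster.job_tasks("worker"))  # list of addresses, indexed by task
print(cluster.num_tasks("ps"))      # -> 2
```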
@@ -272,10 +273,9 @@ class ClusterSpec(object):
    """Creates a `ClusterSpec`.

    Args:
-      cluster: A dictionary mapping one or more job names to (i) a
-        list of network addresses, or (ii) a dictionary mapping integer
-        task indices to network addresses; or a `tf.train.ClusterDef`
-        protocol buffer.
+      cluster: A dictionary mapping one or more job names to (i) a list of
+        network addresses, or (ii) a dictionary mapping integer task indices to
+        network addresses; or a `tf.train.ClusterDef` protocol buffer.

    Raises:
      TypeError: If `cluster` is not a dictionary mapping strings to lists
@@ -298,14 +298,16 @@ class ClusterSpec(object):
      self._cluster_spec = {}
      for job_def in self._cluster_def.job:
        self._cluster_spec[job_def.name] = {
-            i: t for i, t in job_def.tasks.items()}
+            i: t for i, t in job_def.tasks.items()
+        }
    elif isinstance(cluster, ClusterSpec):
      self._cluster_def = cluster_pb2.ClusterDef()
      self._cluster_def.MergeFrom(cluster.as_cluster_def())
      self._cluster_spec = {}
      for job_def in self._cluster_def.job:
        self._cluster_spec[job_def.name] = {
-            i: t for i, t in job_def.tasks.items()}
+            i: t for i, t in job_def.tasks.items()
+        }
    else:
      raise TypeError("`cluster` must be a dictionary mapping one or more "
                      "job names to lists of network addresses, or a "
@@ -326,7 +328,8 @@ class ClusterSpec(object):
  def __str__(self):
    key_values = self.as_dict()
    string_items = [
-        repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)]
+        repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)
+    ]
    return "ClusterSpec({" + ", ".join(string_items) + "})"

  def as_dict(self):
@@ -427,8 +430,8 @@ class ClusterSpec(object):
    try:
      return job[task_index]
    except KeyError:
-      raise ValueError("No task with index %r in job %r"
-                       % (task_index, job_name))
+      raise ValueError("No task with index %r in job %r" %
+                       (task_index, job_name))

  def job_tasks(self, job_name):
    """Returns a mapping from task ID to address in the given job.
@@ -482,6 +485,6 @@ class ClusterSpec(object):
      try:
        task_address = compat.as_bytes(task_address)
      except TypeError:
-        raise TypeError(
-            "Task address %r must be bytes or unicode" % task_address)
+        raise TypeError("Task address %r must be bytes or unicode" %
+                        task_address)
      job_def.tasks[i] = task_address
@@ -107,7 +107,7 @@ class GrpcServerTest(test.TestCase):
    self.assertAllEqual(2.0, sess.run(v1))

  def _useRPCConfig(self):
-    """Return a `tf.ConfigProto` that ensures we use the RPC stack for tests.
+    """Return a `tf.compat.v1.ConfigProto` that ensures we use the RPC stack for tests.

    This configuration ensures that we continue to exercise the gRPC
    stack when testing, rather than using the in-process optimization,
@@ -115,7 +115,7 @@ class GrpcServerTest(test.TestCase):
    master in the same process.

    Returns:
-      A `tf.ConfigProto`.
+      A `tf.compat.v1.ConfigProto`.
    """
    return config_pb2.ConfigProto(rpc_options=config_pb2.RPCOptions(
        use_rpc_for_inprocess_master=True))
@@ -174,8 +174,8 @@ class SessionManagerTest(test.TestCase):
      self.assertFalse(initialized)
      sess.run(v.initializer)
      self.assertEquals(1, sess.run(v))
-      saver.save(sess,
-                 os.path.join(checkpoint_dir, "recover_session_checkpoint"))
+      saver.save(sess, os.path.join(checkpoint_dir,
+                                    "recover_session_checkpoint"))
    self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
    self._test_recovered_variable(
        checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
@@ -202,9 +202,9 @@ class SessionManagerTest(test.TestCase):
  def testInitWithNoneLocalInitOpError(self):
    # Creating a SessionManager with a None local_init_op but
    # non-None ready_for_local_init_op raises ValueError
-    with self.assertRaisesRegexp(ValueError,
-                                 "If you pass a ready_for_local_init_op "
+    with self.assertRaisesRegexp(
+        ValueError, "If you pass a ready_for_local_init_op "
                                 "you must also pass a local_init_op "):
      session_manager.SessionManager(
          ready_for_local_init_op=variables.report_uninitialized_variables(
              variables.global_variables()),
@@ -231,8 +231,8 @@ class SessionManagerTest(test.TestCase):
      self.assertFalse(initialized)
      sess.run(v.initializer)
      self.assertEquals(1, sess.run(v))
-      saver.save(sess,
-                 os.path.join(checkpoint_dir, "recover_session_checkpoint"))
+      saver.save(sess, os.path.join(checkpoint_dir,
+                                    "recover_session_checkpoint"))
    # Create a new Graph and SessionManager and recover.
    with ops.Graph().as_default():
      v = variables.VariableV1(2, name="v")
@@ -266,7 +266,7 @@ class SessionManagerTest(test.TestCase):

  @test_util.run_v1_only("b/120545219")
  def testRecoverSessionWithReadyForLocalInitOpFailsToReadyLocal(self):
-    # We use ready_for_local_init_op=tf.report_uninitialized_variables(),
+    # We use ready_for_local_init_op=report_uninitialized_variables(),
    # which causes recover_session to not run local_init_op, and to return
    # initialized=False

@@ -290,8 +290,8 @@ class SessionManagerTest(test.TestCase):
      self.assertFalse(initialized)
      sess.run(v.initializer)
      self.assertEquals(1, sess.run(v))
-      saver.save(sess,
-                 os.path.join(checkpoint_dir, "recover_session_checkpoint"))
+      saver.save(sess, os.path.join(checkpoint_dir,
+                                    "recover_session_checkpoint"))
    # Create a new Graph and SessionManager and recover.
    with ops.Graph().as_default():
      v = variables.VariableV1(2, name="v")
@@ -780,8 +780,8 @@ class ObsoleteSessionManagerTest(test.TestCase):
      self.assertFalse(initialized)
      sess.run(v.initializer)
      self.assertEquals(1, sess.run(v))
-      saver.save(sess,
-                 os.path.join(checkpoint_dir, "recover_session_checkpoint"))
+      saver.save(sess, os.path.join(checkpoint_dir,
+                                    "recover_session_checkpoint"))
    # Create a new Graph and SessionManager and recover.
    with ops.Graph().as_default():
      v = variables.VariableV1(2, name="v")
@@ -68,7 +68,7 @@ look at following code:

  Above user code leads to following execution:
    call hooks.begin()
-    sess = tf.Session()
+    sess = tf.compat.v1.Session()
    call hooks.after_create_session()
    while not stop is requested:
      call hooks.before_run()
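The call order spelled out above can be traced with a tiny hook; this sketch (class name ours) runs one step under a `MonitoredSession`:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

class TraceHook(tf.train.SessionRunHook):
  """Prints each lifecycle call in the order documented above."""

  def begin(self):
    print("begin")

  def after_create_session(self, session, coord):
    print("after_create_session")

  def before_run(self, run_context):
    print("before_run")

  def after_run(self, run_context, run_values):
    print("after_run")

  def end(self, session):
    print("end")

c = tf.constant(1)
with tf.train.MonitoredSession(hooks=[TraceHook()]) as sess:
  sess.run(c)
```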
@@ -56,9 +56,9 @@ class SummaryWriter(_FileWriter):
  ```python
  ...create a graph...
  # Launch the graph in a session.
-  sess = tf.Session()
+  sess = tf.compat.v1.Session()
  # Create a summary writer, add the 'graph' to the event file.
-  writer = tf.summary.FileWriter(<some-directory>, sess.graph)
+  writer = tf.compat.v1.summary.FileWriter(<some-directory>, sess.graph)
  ```

  The other arguments to the constructor control the asynchronous writes to
@@ -45,7 +45,7 @@ class Supervisor(object):
  """A training helper that checkpoints models and computes summaries.

  This class is deprecated. Please use
-  `tf.train.MonitoredTrainingSession` instead.
+  `tf.compat.v1.train.MonitoredTrainingSession` instead.

  The Supervisor is a small wrapper around a `Coordinator`, a `Saver`,
  and a `SessionManager` that takes care of common needs of TensorFlow
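Since the docstring steers users to `tf.compat.v1.train.MonitoredTrainingSession`, here is a minimal sketch of that replacement; the checkpoint directory and the trivial train op are illustrative:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)  # stand-in for a real training step
hooks = [tf.train.StopAtStepHook(last_step=5)]
with tf.train.MonitoredTrainingSession(checkpoint_dir="/tmp/mts_demo",
                                       hooks=hooks) as sess:
  while not sess.should_stop():
    sess.run(train_op)
```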
@@ -97,7 +97,7 @@ class Supervisor(object):
  # or job_def.name, or job_def.tasks. It's entirely up to the end user.
  # But there can be only one *chief*.
  is_chief = (server_def.task_index == 0)
-  server = tf.train.Server(server_def)
+  server = tf.distribute.Server(server_def)

  with tf.Graph().as_default():
    ...add operations to the graph...
@@ -140,7 +140,7 @@ class Supervisor(object):
  * Specifying `'grpc://hostname:port'` requests a session that uses
    the RPC interface to a specific host, and also allows the in-process
    master to access remote tensorflow workers. Often, it is
-    appropriate to pass `server.target` (for some `tf.train.Server`
+    appropriate to pass `server.target` (for some `tf.distribute.Server`
    named `server).

  #### Advanced use
@@ -237,17 +237,16 @@ class Supervisor(object):
      ready_op: 1-D string `Tensor`. This tensor is evaluated by supervisors in
        `prepare_or_wait_for_session()` to check if the model is ready to use.
        The model is considered ready if it returns an empty array. Defaults to
-        the tensor returned from `tf.report_uninitialized_variables()` If
-        `None`, the model is not checked for readiness.
+        the tensor returned from `tf.compat.v1.report_uninitialized_variables()`.
+        If `None`, the model is not checked for readiness.
      ready_for_local_init_op: 1-D string `Tensor`. This tensor is evaluated by
        supervisors in `prepare_or_wait_for_session()` to check if the model is
-        ready to run the local_init_op.
-        The model is considered ready if it returns an empty array. Defaults to
-        `None`. If `None`, the model is not checked for readiness before running
-        local_init_op.
-      is_chief: If True, create a chief supervisor in charge of initializing
-        and restoring the model. If False, create a supervisor that relies
-        on a chief supervisor for inits and restore.
+        ready to run the local_init_op. The model is considered ready if it
+        returns an empty array. Defaults to `None`. If `None`, the model is not
+        checked for readiness before running local_init_op.
+      is_chief: If True, create a chief supervisor in charge of initializing and
+        restoring the model. If False, create a supervisor that relies on a
+        chief supervisor for inits and restore.
      init_op: `Operation`. Used by chief supervisors to initialize the model
        when it can not be recovered. Defaults to an `Operation` that
        initializes all global variables. If `None`, no initialization is done
@@ -255,20 +254,19 @@ class Supervisor(object):
      init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
        This feed dictionary will be used when `init_op` is evaluated.
      local_init_op: `Operation`. Used by all supervisors to run initializations
-        that should run for every new supervisor instance. By default these
-        are table initializers and initializers for local variables.
-        If `None`, no further per supervisor-instance initialization is
-        done automatically.
+        that should run for every new supervisor instance. By default these are
+        table initializers and initializers for local variables. If `None`, no
+        further per supervisor-instance initialization is done automatically.
      logdir: A string. Optional path to a directory where to checkpoint the
-        model and log events for the visualizer. Used by chief supervisors.
-        The directory will be created if it does not exist.
-      summary_op: An `Operation` that returns a Summary for the event logs.
-        Used by chief supervisors if a `logdir` was specified. Defaults to the
+        model and log events for the visualizer. Used by chief supervisors. The
+        directory will be created if it does not exist.
+      summary_op: An `Operation` that returns a Summary for the event logs. Used
+        by chief supervisors if a `logdir` was specified. Defaults to the
        operation returned from summary.merge_all(). If `None`, summaries are
        not computed automatically.
      saver: A Saver object. Used by chief supervisors if a `logdir` was
-        specified. Defaults to the saved returned by Saver().
-        If `None`, the model is not saved automatically.
+        specified. Defaults to the saver returned by Saver(). If `None`, the
+        model is not saved automatically.
      global_step: An integer Tensor of size 1 that counts steps. The value
        from 'global_step' is used in summaries and checkpoint filenames.
        Default to the op named 'global_step' in the graph if it exists, is of
@@ -280,20 +278,20 @@ class Supervisor(object):
        disable summaries.
      save_model_secs: Number of seconds between the creation of model
        checkpoints. Defaults to 600 seconds. Pass 0 to disable checkpoints.
-      recovery_wait_secs: Number of seconds between checks that the model
-        is ready. Used by supervisors when waiting for a chief supervisor
-        to initialize or restore the model. Defaults to 30 seconds.
+      recovery_wait_secs: Number of seconds between checks that the model is
+        ready. Used by supervisors when waiting for a chief supervisor to
+        initialize or restore the model. Defaults to 30 seconds.
      stop_grace_secs: Grace period, in seconds, given to running threads to
        stop when `stop()` is called. Defaults to 120 seconds.
      checkpoint_basename: The basename for checkpoint saving.
      session_manager: `SessionManager`, which manages Session creation and
        recovery. If it is `None`, a default `SessionManager` will be created
        with the set of arguments passed in for backwards compatibility.
-      summary_writer: `SummaryWriter` to use or `USE_DEFAULT`. Can be `None`
-        to indicate that no summaries should be written.
-      init_fn: Optional callable used to initialize the model. Called
-        after the optional `init_op` is called. The callable must accept one
-        argument, the session being initialized.
+      summary_writer: `SummaryWriter` to use or `USE_DEFAULT`. Can be `None` to
+        indicate that no summaries should be written.
+      init_fn: Optional callable used to initialize the model. Called after the
+        optional `init_op` is called. The callable must accept one argument,
+        the session being initialized.
      local_init_run_options: RunOptions to be passed as the SessionManager
        local_init_run_options parameter.

@@ -397,12 +395,11 @@ class Supervisor(object):
    """Initializes ready_op.

    Args:
-      ready_op: `Tensor` to check if the model is initialized.
-        If it's set to USE_DEFAULT, creates an op that checks all
-        the variables are initialized.
+      ready_op: `Tensor` to check if the model is initialized. If it's set to
+        USE_DEFAULT, creates an op that checks all the variables are
+        initialized.
      ready_for_local_init_op: `Tensor` to check if the model is ready to run
-        local_init_op.
-        If it's set to USE_DEFAULT, creates an op that checks all
+        local_init_op. If it's set to USE_DEFAULT, creates an op that checks all
        the global variables are initialized.
    """
    if ready_op is Supervisor.USE_DEFAULT:
@@ -440,9 +437,9 @@ class Supervisor(object):

    Args:
      local_init_op: `Operation` run for every new supervisor instance. If set
        to USE_DEFAULT, use the first op from the GraphKeys.LOCAL_INIT_OP
        collection. If the collection is empty, create an op that initializes
        all local variables and all tables.
    """
    if local_init_op is Supervisor.USE_DEFAULT:
      local_init_op = self._get_first_op_from_collection(
@@ -461,8 +458,8 @@ class Supervisor(object):
    """Initializes saver.

    Args:
-      saver: A `Saver` object. If set to USE_DEFAULT, create one that
-        saves all the variables.
+      saver: A `Saver` object. If set to USE_DEFAULT, create one that saves all
+        the variables.
    """
    if saver is Supervisor.USE_DEFAULT:
      saver = self._get_first_op_from_collection(ops.GraphKeys.SAVERS)
@@ -475,8 +472,8 @@ class Supervisor(object):
    """Initializes summary_op.

    Args:
-      summary_op: An Operation that returns a Summary for the event logs.
-        If set to USE_DEFAULT, create an op that merges all the summaries.
+      summary_op: An Operation that returns a Summary for the event logs. If set
+        to USE_DEFAULT, create an op that merges all the summaries.
    """
    if summary_op is Supervisor.USE_DEFAULT:
      summary_op = self._get_first_op_from_collection(ops.GraphKeys.SUMMARY_OP)
@@ -490,8 +487,8 @@ class Supervisor(object):
    """Initializes global_step.

    Args:
-      global_step: An integer Tensor of size 1 that counts steps. If
-        set to USE_DEFAULT, creates global_step tensor.
+      global_step: An integer Tensor of size 1 that counts steps. If set to
+        USE_DEFAULT, creates global_step tensor.
    """
    if global_step is Supervisor.USE_DEFAULT:
      global_step = self._get_first_op_from_collection(
@@ -630,8 +627,9 @@ class Supervisor(object):
    """Writes graph_def to `logdir` and adds it to summary if applicable."""
    assert self._is_chief
    if self._logdir:
-      training_util.write_graph(self._graph.as_graph_def(add_shapes=True),
-                                self._logdir, "graph.pbtxt")
+      training_util.write_graph(
+          self._graph.as_graph_def(add_shapes=True), self._logdir,
+          "graph.pbtxt")
    if self._summary_writer and not self._graph_added_to_summary:
      self._summary_writer.add_graph(self._graph)
      self._summary_writer.add_meta_graph(self._meta_graph_def)
@@ -675,8 +673,7 @@ class Supervisor(object):
      # if there is no step value.
      current_step = training_util.global_step(sess, self._global_step)
      self._summary_writer.add_session_log(
-          SessionLog(status=SessionLog.START),
-          current_step)
+          SessionLog(status=SessionLog.START), current_step)

    threads = []
    if self._save_summaries_secs and self._summary_writer:
@@ -690,7 +687,9 @@ class Supervisor(object):
      t.start()
    return threads

-  def prepare_or_wait_for_session(self, master="", config=None,
+  def prepare_or_wait_for_session(self,
+                                  master="",
+                                  config=None,
                                  wait_for_checkpoint=False,
                                  max_wait_secs=7200,
                                  start_standard_services=True):
@@ -702,10 +701,10 @@ class Supervisor(object):
    manager to start the standard services.

    Args:
-      master: name of the TensorFlow master to use. See the `tf.Session`
-        constructor for how this is interpreted.
-      config: Optional ConfigProto proto used to configure the session,
-        which is passed as-is to create the session.
+      master: name of the TensorFlow master to use. See the
+        `tf.compat.v1.Session` constructor for how this is interpreted.
+      config: Optional ConfigProto proto used to configure the session, which is
+        passed as-is to create the session.
      wait_for_checkpoint: Whether we should wait for the availability of a
        checkpoint before creating Session. Defaults to False.
      max_wait_secs: Maximum time to wait for the session to become available.
@@ -724,18 +723,22 @@ class Supervisor(object):

    if self._is_chief:
      sess = self._session_manager.prepare_session(
-          master, init_op=self.init_op, saver=self.saver,
-          checkpoint_dir=self._logdir, wait_for_checkpoint=wait_for_checkpoint,
-          max_wait_secs=max_wait_secs, config=config,
-          init_feed_dict=self._init_feed_dict, init_fn=self._init_fn)
+          master,
+          init_op=self.init_op,
+          saver=self.saver,
+          checkpoint_dir=self._logdir,
+          wait_for_checkpoint=wait_for_checkpoint,
+          max_wait_secs=max_wait_secs,
+          config=config,
+          init_feed_dict=self._init_feed_dict,
+          init_fn=self._init_fn)
      self._write_graph()
      if start_standard_services:
        logging.info("Starting standard services.")
        self.start_standard_services(sess)
    else:
-      sess = self._session_manager.wait_for_session(master,
-                                                    config=config,
-                                                    max_wait_secs=max_wait_secs)
+      sess = self._session_manager.wait_for_session(
+          master, config=config, max_wait_secs=max_wait_secs)
    if start_standard_services:
      logging.info("Starting queue runners.")
      self.start_queue_runners(sess)
@@ -772,8 +775,8 @@ class Supervisor(object):
    queue_runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
    threads = []
    for qr in queue_runners:
-      threads.extend(qr.create_threads(sess, coord=self._coord, daemon=True,
-                                       start=True))
+      threads.extend(
+          qr.create_threads(sess, coord=self._coord, daemon=True, start=True))
    return threads

  def loop(self, timer_interval_secs, target, args=None, kwargs=None):
@@ -795,8 +798,12 @@ class Supervisor(object):
    Returns:
      The started thread.
    """
-    looper = coordinator.LooperThread(self._coord, timer_interval_secs,
-                                      target=target, args=args, kwargs=kwargs)
+    looper = coordinator.LooperThread(
+        self._coord,
+        timer_interval_secs,
+        target=target,
+        args=args,
+        kwargs=kwargs)
    looper.start()
    return looper

@@ -812,13 +819,13 @@ class Supervisor(object):
      threads: Optional list of threads to join with the coordinator. If
        `None`, defaults to the threads running the standard services, the
        threads started for `QueueRunners`, and the threads started by the
-        `loop()` method. To wait on additional threads, pass the
-        list in this parameter.
+        `loop()` method. To wait on additional threads, pass the list in this
+        parameter.
      close_summary_writer: Whether to close the `summary_writer`. Defaults to
        `True` if the summary writer was created by the supervisor, `False`
        otherwise.
-      ignore_live_threads: If `True` ignores threads that remain running after
-        a grace period when joining threads via the coordinator, instead of
+      ignore_live_threads: If `True` ignores threads that remain running after a
+        grace period when joining threads via the coordinator, instead of
        raising a RuntimeError.
    """
    self._coord.request_stop()
@@ -926,7 +933,9 @@ class Supervisor(object):

  # pylint: disable=g-doc-return-or-yield,broad-except
  @contextlib.contextmanager
-  def managed_session(self, master="", config=None,
+  def managed_session(self,
+                      master="",
+                      config=None,
                      start_standard_services=True,
                      close_summary_writer=True):
    """Returns a context manager for a managed session.
@@ -940,7 +949,7 @@ class Supervisor(object):

    ```python
    def train():
-      sv = tf.train.Supervisor(...)
+      sv = tf.compat.v1.train.Supervisor(...)
      with sv.managed_session(<master>) as sess:
        for step in xrange(..):
          if sv.should_stop():
@@ -973,14 +982,14 @@ class Supervisor(object):
    the training loop and are considered normal termination.

    Args:
-      master: name of the TensorFlow master to use. See the `tf.Session`
-        constructor for how this is interpreted.
-      config: Optional `ConfigProto` proto used to configure the session.
-        Passed as-is to create the session.
-      start_standard_services: Whether to start the standard services,
-        such as checkpoint, summary and step counter.
-      close_summary_writer: Whether to close the summary writer when
-        closing the session. Defaults to True.
+      master: name of the TensorFlow master to use. See the
+        `tf.compat.v1.Session` constructor for how this is interpreted.
+      config: Optional `ConfigProto` proto used to configure the session. Passed
+        as-is to create the session.
+      start_standard_services: Whether to start the standard services, such as
+        checkpoint, summary and step counter.
+      close_summary_writer: Whether to close the summary writer when closing the
+        session. Defaults to True.

    Returns:
      A context manager that yields a `Session` restored from the latest
@@ -989,7 +998,8 @@ class Supervisor(object):
    """
    try:
      sess = self.prepare_or_wait_for_session(
-          master=master, config=config,
+          master=master,
+          config=config,
          start_standard_services=start_standard_services)
      yield sess
    except Exception as e:
@@ -1011,6 +1021,7 @@ class Supervisor(object):
      except Exception:
        # Silently ignore exceptions raised by close().
        pass
+
  # pylint: enable=g-doc-return-or-yield,broad-except


@@ -1030,8 +1041,8 @@ class SVSummaryThread(coordinator.LooperThread):

  def run_loop(self):
    if self._sv.global_step is not None:
-      summary_strs, global_step = self._sess.run([self._sv.summary_op,
-                                                  self._sv.global_step])
+      summary_strs, global_step = self._sess.run(
+          [self._sv.summary_op, self._sv.global_step])
    else:
      summary_strs = self._sess.run(self._sv.summary_op)
      global_step = None
@@ -1063,8 +1074,7 @@ class SVStepCounterThread(coordinator.LooperThread):

  def start_loop(self):
    self._last_time = time.time()
-    self._last_step = training_util.global_step(
-        self._sess, self._step_counter)
+    self._last_step = training_util.global_step(self._sess, self._step_counter)

  def run_loop(self):
    # Count the steps.
@@ -1080,12 +1090,13 @@ class SVStepCounterThread(coordinator.LooperThread):
      steps_per_sec = added_steps / elapsed_time
    else:
      steps_per_sec = float("inf")
-    summary = Summary(value=[Summary.Value(tag=self._summary_tag,
-                                           simple_value=steps_per_sec)])
+    summary = Summary(value=[
+        Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
+    ])
    if self._sv.summary_writer:
      self._sv.summary_writer.add_summary(summary, current_step)
-    logging.log_first_n(logging.INFO, "%s: %g", 10,
-                        self._summary_tag, steps_per_sec)
+    logging.log_first_n(logging.INFO, "%s: %g", 10, self._summary_tag,
+                        steps_per_sec)


class SVTimerCheckpointThread(coordinator.LooperThread):
@@ -1104,13 +1115,13 @@ class SVTimerCheckpointThread(coordinator.LooperThread):

  def run_loop(self):
    logging.info("Saving checkpoint to path %s", self._sv.save_path)
-    self._sv.saver.save(self._sess, self._sv.save_path,
-                        global_step=self._sv.global_step)
+    self._sv.saver.save(
+        self._sess, self._sv.save_path, global_step=self._sv.global_step)
    if self._sv.summary_writer and self._sv.global_step is not None:
      current_step = training_util.global_step(self._sess, self._sv.global_step)
      self._sv.summary_writer.add_session_log(
-          SessionLog(status=SessionLog.CHECKPOINT,
-                     checkpoint_path=self._sv.save_path),
+          SessionLog(
+              status=SessionLog.CHECKPOINT, checkpoint_path=self._sv.save_path),
          current_step)


@@ -110,7 +110,7 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
  # Note that if you want to have 2 backup replicas, you can change
  # total_num_replicas=52 and make sure this number matches how many physical
  # replicas you started in your job.
-  opt = tf.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=50,
+  opt = tf.compat.v1.train.SyncReplicasOptimizer(opt, replicas_to_aggregate=50,
                                       total_num_replicas=50)

  # Some models have startup_delays to help stabilize the model but when using
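Continuing the snippet in this hunk, the wrapped optimizer supplies a session hook for synchronization; a hedged sketch, where `loss` and `global_step` stand in for the user's own model pieces:

```python
import tensorflow.compat.v1 as tf

opt = tf.train.SyncReplicasOptimizer(
    tf.train.GradientDescentOptimizer(0.1),
    replicas_to_aggregate=50,
    total_num_replicas=50)
# train_op = opt.minimize(loss, global_step=global_step)  # user's model
sync_hook = opt.make_session_run_hook(is_chief=True)
# Pass [sync_hook] to tf.compat.v1.train.MonitoredTrainingSession(hooks=...).
```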
@@ -38,11 +38,9 @@ from tensorflow.python.util import nest
from tensorflow.python.util import serialization
from tensorflow.python.util import tf_decorator
-
-
# Key where the object graph proto is saved in a TensorBundle
OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"


# A key indicating a variable's value in an object's checkpointed Tensors
# (Trackable._gather_saveables_for_checkpoint). If this is the only key and
# the object has no dependencies, then its value may be restored on object
@ -74,8 +72,7 @@ class CheckpointInitialValue(ops.Tensor):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, checkpoint_position, shape=None):
|
def __init__(self, checkpoint_position, shape=None):
|
||||||
self.wrapped_value = checkpoint_position.value_tensors()[
|
self.wrapped_value = checkpoint_position.value_tensors()[VARIABLE_VALUE_KEY]
|
||||||
VARIABLE_VALUE_KEY]
|
|
||||||
if shape:
|
if shape:
|
||||||
# We need to set the static shape information on the initializer if
|
# We need to set the static shape information on the initializer if
|
||||||
# possible so we don't get a variable with an unknown shape.
|
# possible so we don't get a variable with an unknown shape.
|
||||||
@ -97,8 +94,8 @@ class NoRestoreSaveable(saveable_object.SaveableObject):
|
|||||||
"""Embeds a tensor in a checkpoint with no restore ops."""
|
"""Embeds a tensor in a checkpoint with no restore ops."""
|
||||||
|
|
||||||
def __init__(self, tensor, name, dtype=None, device=None):
|
def __init__(self, tensor, name, dtype=None, device=None):
|
||||||
spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype,
|
spec = saveable_object.SaveSpec(
|
||||||
device=device)
|
tensor, "", name, dtype=dtype, device=device)
|
||||||
super(NoRestoreSaveable, self).__init__(tensor, [spec], name)
|
super(NoRestoreSaveable, self).__init__(tensor, [spec], name)
|
||||||
|
|
||||||
def restore(self, restored_tensors, restored_shapes):
|
def restore(self, restored_tensors, restored_shapes):
|
||||||
@ -123,7 +120,8 @@ class PythonStateSaveable(saveable_object.SaveableObject):
|
|||||||
"""Create a new `SaveableObject` which freezes current state as a constant.
|
"""Create a new `SaveableObject` which freezes current state as a constant.
|
||||||
|
|
||||||
Used when executing eagerly to embed the current state as a constant, or
|
Used when executing eagerly to embed the current state as a constant, or
|
||||||
when creating a static tf.train.Saver with the frozen current Python state.
|
when creating a static tf.compat.v1.train.Saver with the frozen current
|
||||||
|
Python state.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
|
A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
|
||||||
@ -140,24 +138,26 @@ class PythonStringStateSaveable(PythonStateSaveable):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: The checkpoint key to write to.
|
name: The checkpoint key to write to.
|
||||||
state_callback: A function taking no arguments which returns a
|
state_callback: A function taking no arguments which returns a string.
|
||||||
string. This function is run every time a checkpoint is written.
|
This function is run every time a checkpoint is written.
|
||||||
restore_callback: A function taking a Python string, used to restore
|
restore_callback: A function taking a Python string, used to restore
|
||||||
state. Optional; defaults to doing nothing, in which case it is ignored
|
state. Optional; defaults to doing nothing, in which case it is ignored
|
||||||
by status assertions such as assert_consumed().
|
by status assertions such as assert_consumed().
|
||||||
"""
|
"""
|
||||||
self._has_trivial_state_callback = (restore_callback is None)
|
self._has_trivial_state_callback = (restore_callback is None)
|
||||||
|
|
||||||
def _state_callback_wrapper():
|
def _state_callback_wrapper():
|
||||||
with ops.init_scope():
|
with ops.init_scope():
|
||||||
return state_callback()
|
return state_callback()
|
||||||
|
|
||||||
self._state_callback = _state_callback_wrapper
|
self._state_callback = _state_callback_wrapper
|
||||||
self._restore_callback = restore_callback
|
self._restore_callback = restore_callback
|
||||||
with ops.device("/cpu:0"):
|
with ops.device("/cpu:0"):
|
||||||
self._save_string = constant_op.constant("", dtype=dtypes.string)
|
self._save_string = constant_op.constant("", dtype=dtypes.string)
|
||||||
spec = saveable_object.SaveSpec(
|
spec = saveable_object.SaveSpec(
|
||||||
self._save_string, "", name, dtype=dtypes.string)
|
self._save_string, "", name, dtype=dtypes.string)
|
||||||
super(PythonStringStateSaveable, self).__init__(
|
super(PythonStringStateSaveable, self).__init__(self._save_string, [spec],
|
||||||
self._save_string, [spec], name)
|
name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def optional_restore(self):
|
def optional_restore(self):
|
||||||
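To make the `state_callback`/`restore_callback` contract above concrete, a sketch using the class defined in this file, with invented JSON state for illustration:

```python
import json

state = {"epochs_seen": 3}  # Hypothetical Python state to checkpoint.

def state_callback():
  # Run every time a checkpoint is written; must return a string.
  return json.dumps(state)

def restore_callback(serialized):
  # Receives the saved string back at restore time.
  state.update(json.loads(serialized))

saveable = PythonStringStateSaveable(
    name="demo_python_state",
    state_callback=state_callback,
    restore_callback=restore_callback)
```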
@ -170,8 +170,10 @@ class PythonStringStateSaveable(PythonStateSaveable):

def freeze(self):
def freeze(self):
"""Create a frozen `SaveableObject` which saves the current state."""
"""Create a frozen `SaveableObject` which saves the current state."""

def _constant_state():
def _constant_state():
return constant_op.constant(self._state_callback(), dtype=dtypes.string)
return constant_op.constant(self._state_callback(), dtype=dtypes.string)

return NoRestoreSaveable(
return NoRestoreSaveable(
tensor=_constant_state,
tensor=_constant_state,
dtype=dtypes.string,
dtype=dtypes.string,
@ -217,6 +219,7 @@ class CheckpointPosition(object):

Args:
Args:
trackable: The object to record a correspondence for.
trackable: The object to record a correspondence for.

Returns:
Returns:
True if this is a new assignment, False if this object has already been
True if this is a new assignment, False if this object has already been
mapped to a checkpointed `Object` proto.
mapped to a checkpointed `Object` proto.
@ -263,21 +266,21 @@ class CheckpointPosition(object):
# consistent (if the dependency DAG is not a tree then there are
# consistent (if the dependency DAG is not a tree then there are
# multiple paths to the same object).
# multiple paths to the same object).
if current_assignment is not trackable:
if current_assignment is not trackable:
logging.warning(
logging.warning((
("Inconsistent references when loading the checkpoint into this "
"Inconsistent references when loading the checkpoint into this "
"object graph. Either the Trackable object references in the "
"object graph. Either the Trackable object references in the "
"Python program have changed in an incompatible way, or the "
"Python program have changed in an incompatible way, or the "
"checkpoint was generated in an incompatible program.\n\nTwo "
"checkpoint was generated in an incompatible program.\n\nTwo "
"checkpoint references resolved to different objects (%s and %s).")
"checkpoint references resolved to different objects (%s and %s)."),
% (current_assignment, trackable))
current_assignment, trackable)
return False  # Not a new assignment
return False  # Not a new assignment

def is_simple_variable(self):
def is_simple_variable(self):
"""Determine whether this value is restorable with a Tensor initializer."""
"""Determine whether this value is restorable with a Tensor initializer."""
attributes = self.object_proto.attributes
attributes = self.object_proto.attributes
return (len(attributes) == 1
return (len(attributes) == 1 and
and attributes[0].name == VARIABLE_VALUE_KEY
attributes[0].name == VARIABLE_VALUE_KEY and
and not self.object_proto.children)
not self.object_proto.children)

def value_tensors(self):
def value_tensors(self):
"""Create value `Tensor`s for this object's attributes.
"""Create value `Tensor`s for this object's attributes.
@ -335,8 +338,9 @@ class CheckpointPosition(object):
# If we've already created and cached a SaveableObject for this
# If we've already created and cached a SaveableObject for this
# attribute, we can re-use it to avoid re-creating some ops when graph
# attribute, we can re-use it to avoid re-creating some ops when graph
# building.
# building.
saveable_list = saveables_cache.get(
saveable_list = saveables_cache.get(self.trackable,
self.trackable, {}).get(serialized_tensor.name, (None,))
{}).get(serialized_tensor.name,
(None,))
if len(saveable_list) == 1:
if len(saveable_list) == 1:
# Almost every attribute will have exactly one SaveableObject.
# Almost every attribute will have exactly one SaveableObject.
saveable, = saveable_list
saveable, = saveable_list
@ -370,8 +374,8 @@ class CheckpointPosition(object):
else:
else:
saveable = saveable_factory
saveable = saveable_factory
if saveables_cache is not None:
if saveables_cache is not None:
saveables_cache.setdefault(
saveables_cache.setdefault(self.trackable,
self.trackable, {})[serialized_tensor.name] = [saveable]
{})[serialized_tensor.name] = [saveable]
if isinstance(saveable, PythonStateSaveable):
if isinstance(saveable, PythonStateSaveable):
python_saveables.append(saveable)
python_saveables.append(saveable)
else:
else:
@ -388,11 +392,10 @@ class CheckpointPosition(object):
A list of operations when graph building, or an empty list when executing
A list of operations when graph building, or an empty list when executing
eagerly.
eagerly.
"""
"""
(restore_ops,
(restore_ops, tensor_saveables,
tensor_saveables,
python_saveables) = self._gather_ops_or_named_saveables()
python_saveables) = self._gather_ops_or_named_saveables()
restore_ops.extend(self._checkpoint.restore_saveables(
restore_ops.extend(
tensor_saveables, python_saveables))
self._checkpoint.restore_saveables(tensor_saveables, python_saveables))
return restore_ops
return restore_ops

@property
@property
@ -416,13 +419,11 @@ class CheckpointPosition(object):

_DeferredSlotVariableRestoration = collections.namedtuple(
_DeferredSlotVariableRestoration = collections.namedtuple(
"_DeferredSlotVariableRestoration",
"_DeferredSlotVariableRestoration", [
[
"original_variable",
"original_variable",
"slot_variable_id",
"slot_variable_id",
"slot_name",
"slot_name",
]
])
)

_SlotVariableRestoration = collections.namedtuple(
_SlotVariableRestoration = collections.namedtuple(
"_SlotVariableRestoration",
"_SlotVariableRestoration",
@ -446,6 +447,7 @@ def no_automatic_dependency_tracking(method):

Args:
Args:
method: The method to decorate.
method: The method to decorate.

Returns:
Returns:
A decorated method which sets and un-sets automatic dependency tracking for
A decorated method which sets and un-sets automatic dependency tracking for
the object the method is called on (not thread safe).
the object the method is called on (not thread safe).
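A sketch of how this decorator is used on a `Trackable` subclass, assuming the module paths `tensorflow.python.training.tracking.base` (this file) and its sibling `tracking`:

```python
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import tracking

class Holder(tracking.AutoTrackable):

  @base.no_automatic_dependency_tracking
  def _set_scratch(self, value):
    # Attribute assignments here are not recorded as checkpoint
    # dependencies, so `_scratch` will not be saved or restored.
    self._scratch = value
```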
@ -595,16 +597,21 @@ class Trackable(object):

Args:
Args:
name: The local name of the dependency.
name: The local name of the dependency.

Returns:
Returns:
A `Trackable` object, or `None` if no dependency by this name was
A `Trackable` object, or `None` if no dependency by this name was
found.
found.
"""
"""
return self._self_unconditional_dependency_names.get(name, None)
return self._self_unconditional_dependency_names.get(name, None)

def _add_variable_with_custom_getter(
def _add_variable_with_custom_getter(self,
self, name, shape=None, dtype=dtypes.float32,
name,
initializer=None, getter=None, overwrite=False,
shape=None,
**kwargs_for_getter):
dtype=dtypes.float32,
initializer=None,
getter=None,
overwrite=False,
**kwargs_for_getter):
"""Restore-on-create for a variable be saved with this `Trackable`.
"""Restore-on-create for a variable be saved with this `Trackable`.

If the user has requested that this object or another `Trackable` which
If the user has requested that this object or another `Trackable` which
@ -640,11 +647,9 @@ class Trackable(object):
name=name, shape=shape)
name=name, shape=shape)
else:
else:
checkpoint_initializer = None
checkpoint_initializer = None
if (checkpoint_initializer is not None
if (checkpoint_initializer is not None and
and not (
not (isinstance(initializer, CheckpointInitialValue) and
isinstance(initializer, CheckpointInitialValue)
(initializer.restore_uid > checkpoint_initializer.restore_uid))):
and (initializer.restore_uid
> checkpoint_initializer.restore_uid))):
# If multiple Trackable objects are "creating" the same variable
# If multiple Trackable objects are "creating" the same variable
# via the magic of custom getters, the one with the highest restore UID
# via the magic of custom getters, the one with the highest restore UID
# (the one called last) has to make the final initializer. If another
# (the one called last) has to make the final initializer. If another
@ -654,7 +659,10 @@ class Trackable(object):
initializer = checkpoint_initializer
initializer = checkpoint_initializer
shape = None
shape = None
new_variable = getter(
new_variable = getter(
name=name, shape=shape, dtype=dtype, initializer=initializer,
name=name,
shape=shape,
dtype=dtype,
initializer=initializer,
**kwargs_for_getter)
**kwargs_for_getter)

# If we set an initializer and the variable processed it, tracking will not
# If we set an initializer and the variable processed it, tracking will not
@ -662,8 +670,7 @@ class Trackable(object):
# is a non-trivial restoration queued, it will handle that. This also
# is a non-trivial restoration queued, it will handle that. This also
# handles slot variables.
# handles slot variables.
if not overwrite or isinstance(new_variable, Trackable):
if not overwrite or isinstance(new_variable, Trackable):
return self._track_trackable(new_variable, name=name,
return self._track_trackable(new_variable, name=name, overwrite=overwrite)
overwrite=overwrite)
else:
else:
# TODO(allenl): Some variable types are not yet supported. Remove this
# TODO(allenl): Some variable types are not yet supported. Remove this
# fallback once all get_variable() return types are Trackable.
# fallback once all get_variable() return types are Trackable.
@ -681,6 +688,7 @@ class Trackable(object):
name: The object-local name of the dependency holding the variable's
name: The object-local name of the dependency holding the variable's
value.
value.
shape: The shape of the variable being loaded into.
shape: The shape of the variable being loaded into.

Returns:
Returns:
An callable for use as a variable's initializer/initial_value, or None if
An callable for use as a variable's initializer/initial_value, or None if
one should not be set (either because there was no variable with this name
one should not be set (either because there was no variable with this name
@ -718,8 +726,8 @@ class Trackable(object):

Args:
Args:
trackable: A `Trackable` which this object depends on.
trackable: A `Trackable` which this object depends on.
name: A local name for `trackable`, used for loading checkpoints into
name: A local name for `trackable`, used for loading checkpoints into the
the correct objects.
correct objects.
overwrite: Boolean, whether silently replacing dependencies is OK. Used
overwrite: Boolean, whether silently replacing dependencies is OK. Used
for __setattr__, where throwing an error on attribute reassignment would
for __setattr__, where throwing an error on attribute reassignment would
be inappropriate.
be inappropriate.
@ -734,13 +742,11 @@ class Trackable(object):
"""
"""
self._maybe_initialize_trackable()
self._maybe_initialize_trackable()
if not isinstance(trackable, Trackable):
if not isinstance(trackable, Trackable):
raise TypeError(
raise TypeError(("Trackable._track_trackable() passed type %s, not a "
("Trackable._track_trackable() passed type %s, not a "
"Trackable.") % (type(trackable),))
"Trackable.") % (type(trackable),))
new_reference = TrackableReference(name=name, ref=trackable)
new_reference = TrackableReference(name=name, ref=trackable)
current_object = self._lookup_dependency(name)
current_object = self._lookup_dependency(name)
if (current_object is not None
if (current_object is not None and current_object is not trackable):
and current_object is not trackable):
if not overwrite:
if not overwrite:
raise ValueError(
raise ValueError(
("Called Trackable._track_trackable() with name='%s', "
("Called Trackable._track_trackable() with name='%s', "
@ -755,8 +761,7 @@ class Trackable(object):
index] = new_reference
index] = new_reference
elif current_object is None:
elif current_object is None:
self._self_unconditional_checkpoint_dependencies.append(new_reference)
self._self_unconditional_checkpoint_dependencies.append(new_reference)
self._handle_deferred_dependencies(
self._handle_deferred_dependencies(name=name, trackable=trackable)
name=name, trackable=trackable)
self._self_unconditional_dependency_names[name] = trackable
self._self_unconditional_dependency_names[name] = trackable
return trackable
return trackable

@ -780,8 +785,7 @@ class Trackable(object):
Args:
Args:
name: The name of the dependency within this object (`self`), used to
name: The name of the dependency within this object (`self`), used to
match `trackable` with values saved in a checkpoint.
match `trackable` with values saved in a checkpoint.
trackable: The Trackable object to restore (inheriting from
trackable: The Trackable object to restore (inheriting from `Trackable`).
`Trackable`).
"""
"""
self._maybe_initialize_trackable()
self._maybe_initialize_trackable()
trackable._maybe_initialize_trackable()  # pylint: disable=protected-access
trackable._maybe_initialize_trackable()  # pylint: disable=protected-access
@ -809,15 +813,15 @@ class Trackable(object):
restore_ops = []
restore_ops = []
while visit_queue:
while visit_queue:
current_position = visit_queue.popleft()
current_position = visit_queue.popleft()
restore_ops.extend(nest.flatten(
restore_ops.extend(
current_position.trackable  # pylint: disable=protected-access
nest.flatten(current_position.trackable  # pylint: disable=protected-access
._single_restoration_from_checkpoint_position(
._single_restoration_from_checkpoint_position(
checkpoint_position=current_position,
checkpoint_position=current_position,
visit_queue=visit_queue)))
visit_queue=visit_queue)))
return restore_ops
return restore_ops

def _single_restoration_from_checkpoint_position(
def _single_restoration_from_checkpoint_position(self, checkpoint_position,
self, checkpoint_position, visit_queue):
visit_queue):
"""Restore this object, and either queue its dependencies or defer them."""
"""Restore this object, and either queue its dependencies or defer them."""
self._maybe_initialize_trackable()
self._maybe_initialize_trackable()
checkpoint = checkpoint_position.checkpoint
checkpoint = checkpoint_position.checkpoint
@ -831,14 +835,13 @@ class Trackable(object):
restore_ops = ()
restore_ops = ()
for child in checkpoint_position.object_proto.children:
for child in checkpoint_position.object_proto.children:
child_position = CheckpointPosition(
child_position = CheckpointPosition(
checkpoint=checkpoint,
checkpoint=checkpoint, proto_id=child.node_id)
proto_id=child.node_id)
local_object = self._lookup_dependency(child.local_name)
local_object = self._lookup_dependency(child.local_name)
if local_object is None:
if local_object is None:
# We don't yet have a dependency registered with this name. Save it
# We don't yet have a dependency registered with this name. Save it
# in case we do.
# in case we do.
self._deferred_dependencies.setdefault(child.local_name, []).append(
self._deferred_dependencies.setdefault(child.local_name,
child_position)
[]).append(child_position)
else:
else:
if child_position.bind_object(trackable=local_object):
if child_position.bind_object(trackable=local_object):
# This object's correspondence is new, so dependencies need to be
# This object's correspondence is new, so dependencies need to be
@ -853,7 +856,8 @@ class Trackable(object):

Keys in the returned dictionary are local to this object and in a separate
Keys in the returned dictionary are local to this object and in a separate
namespace from dependencies. Values may either be `SaveableObject` factories
namespace from dependencies. Values may either be `SaveableObject` factories
or variables easily converted to `SaveableObject`s (as in `tf.train.Saver`'s
or variables easily converted to `SaveableObject`s (as in
`tf.compat.v1.train.Saver`'s
`var_list` constructor argument).
`var_list` constructor argument).

`SaveableObjects` have a name set, which Trackable needs to generate
`SaveableObjects` have a name set, which Trackable needs to generate
@ -861,7 +865,8 @@ class Trackable(object):
should return a dictionary of callables which take `name` arguments and
should return a dictionary of callables which take `name` arguments and
return `SaveableObjects` with that name.
return `SaveableObjects` with that name.

If this object may also be passed to the global-name-based `tf.train.Saver`,
If this object may also be passed to the global-name-based
`tf.compat.v1.train.Saver`,
the returned callables should have a default value for their name argument
the returned callables should have a default value for their name argument
(i.e. be callable with no arguments).
(i.e. be callable with no arguments).

@ -884,6 +889,7 @@ class Trackable(object):
except NotImplementedError:
except NotImplementedError:
return {}
return {}
weak_self = weakref.ref(self)
weak_self = weakref.ref(self)

def _state_callback():
def _state_callback():
"""Serializes `self.get_config()` for saving."""
"""Serializes `self.get_config()` for saving."""
dereferenced_self = weak_self()
dereferenced_self = weak_self()
@ -898,9 +904,12 @@ class Trackable(object):
return ""
return ""
else:
else:
return ""
return ""
return {OBJECT_CONFIG_JSON_KEY: functools.partial(
PythonStringStateSaveable,
return {
state_callback=_state_callback)}
OBJECT_CONFIG_JSON_KEY:
functools.partial(
PythonStringStateSaveable, state_callback=_state_callback)
}

def _list_functions_for_serialization(self):
def _list_functions_for_serialization(self):
"""Lists the functions of this trackable to serialize.
"""Lists the functions of this trackable to serialize.
@ -61,8 +61,8 @@ class _CheckpointRestoreCoordinator(object):
"""Specify the checkpoint being loaded.
"""Specify the checkpoint being loaded.

Args:
Args:
object_graph_proto: The TrackableObjectGraph protocol buffer
object_graph_proto: The TrackableObjectGraph protocol buffer associated
associated with this checkpoint.
with this checkpoint.
save_path: A string, the path to the checkpoint, as returned by
save_path: A string, the path to the checkpoint, as returned by
`tf.train.latest_checkpoint`.
`tf.train.latest_checkpoint`.
save_path_tensor: A string `Tensor` which contains or will be fed the save
save_path_tensor: A string `Tensor` which contains or will be fed the save
@ -142,12 +142,10 @@ class _CheckpointRestoreCoordinator(object):
"""
"""
restore_ops = []
restore_ops = []
# Eagerly run restorations for Python state.
# Eagerly run restorations for Python state.
reader = pywrap_tensorflow.NewCheckpointReader(
reader = pywrap_tensorflow.NewCheckpointReader(self.save_path_string)
self.save_path_string)
for saveable in python_saveables:
for saveable in python_saveables:
spec_names = [spec.name for spec in saveable.specs]
spec_names = [spec.name for spec in saveable.specs]
saveable.python_restore(
saveable.python_restore([reader.get_tensor(name) for name in spec_names])
[reader.get_tensor(name) for name in spec_names])

# If we have new SaveableObjects, extract and cache restore ops.
# If we have new SaveableObjects, extract and cache restore ops.
if tensor_saveables:
if tensor_saveables:
@ -205,14 +203,13 @@ class _NameBasedRestoreCoordinator(object):
# whether it's optional to restore it. If it's optional we don't need
# whether it's optional to restore it. If it's optional we don't need
# to make assertions fail.
# to make assertions fail.
if not saveable_factory("").optional_restore:
if not saveable_factory("").optional_restore:
self.unused_attributes.setdefault(trackable, []).append(
self.unused_attributes.setdefault(trackable,
attribute_name)
[]).append(attribute_name)
continue
continue
else:
else:
saveable = saveable_factory
saveable = saveable_factory
names_to_saveables = saveable_object_util.op_list_to_dict(
names_to_saveables = saveable_object_util.op_list_to_dict(
[saveable],
[saveable], convert_variable_to_tensor=False)
convert_variable_to_tensor=False)
for name, op in names_to_saveables.items():
for name, op in names_to_saveables.items():
for saveable_object in saveable_object_util.saveable_objects_for_op(
for saveable_object in saveable_object_util.saveable_objects_for_op(
op=op, name=name):
op=op, name=name):
@ -224,8 +221,7 @@ class _NameBasedRestoreCoordinator(object):
# run_restore_ops/initialize_or_restore on the status object for name-based
# run_restore_ops/initialize_or_restore on the status object for name-based
# checkpoints.
# checkpoints.
assert context.executing_eagerly()
assert context.executing_eagerly()
for saveable in self.globally_named_object_attributes(
for saveable in self.globally_named_object_attributes(trackable):
trackable):
restored_tensors = []
restored_tensors = []
tensor_missing = False
tensor_missing = False
for spec in saveable.specs:
for spec in saveable.specs:
@ -248,14 +244,18 @@ class _NameBasedRestoreCoordinator(object):
# Ignores values missing from the checkpoint, as with object-based
# Ignores values missing from the checkpoint, as with object-based
# restore. Status assertions can be used to check exact matches,
# restore. Status assertions can be used to check exact matches,
# although it's unlikely to ever happen for name-based checkpoints.
# although it's unlikely to ever happen for name-based checkpoints.
saveable.restore(restored_tensors=restored_tensors,
saveable.restore(
restored_shapes=None)
restored_tensors=restored_tensors, restored_shapes=None)

# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange
# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange
# or consolidating the implementation with get_variable.
# or consolidating the implementation with get_variable.
def _default_getter(name, shape, dtype, initializer=None,
def _default_getter(name,
partition_info=None, **kwargs):
shape,
dtype,
initializer=None,
partition_info=None,
**kwargs):
"""A pared-down version of get_variable which does not reuse variables."""
"""A pared-down version of get_variable which does not reuse variables."""
dtype = dtypes.as_dtype(dtype)
dtype = dtypes.as_dtype(dtype)
shape_object = tensor_shape.as_shape(shape)
shape_object = tensor_shape.as_shape(shape)
@ -263,7 +263,9 @@ def _default_getter(name, shape, dtype, initializer=None,
if initializer is None:
if initializer is None:
initializer, initializing_from_value = (
initializer, initializing_from_value = (
variable_scope._get_default_variable_store()._get_default_initializer(  # pylint: disable=protected-access
variable_scope._get_default_variable_store()._get_default_initializer(  # pylint: disable=protected-access
name=name, shape=shape_object, dtype=dtype))
name=name,
shape=shape_object,
dtype=dtype))
else:
else:
initializing_from_value = not callable(initializer)
initializing_from_value = not callable(initializer)
# Same logic as get_variable
# Same logic as get_variable
@ -276,24 +278,33 @@ def _default_getter(name, shape, dtype, initializer=None,
# Instantiate initializer if provided initializer is a type object.
# Instantiate initializer if provided initializer is a type object.
if isinstance(initializer, type(init_ops.Initializer)):
if isinstance(initializer, type(init_ops.Initializer)):
initializer = initializer(dtype=dtype)
initializer = initializer(dtype=dtype)

def initial_value():
def initial_value():
return initializer(
return initializer(
shape_object.as_list(), dtype=dtype, partition_info=partition_info)
shape_object.as_list(), dtype=dtype, partition_info=partition_info)

return variables.VariableV1(
return variables.VariableV1(
initial_value=initial_value,
initial_value=initial_value,
name=name,
name=name,
dtype=variable_dtype,
dtype=variable_dtype,
use_resource=True,
use_resource=True,
**kwargs
**kwargs)
)

def add_variable(trackable, name, shape=None, dtype=dtypes.float32,
def add_variable(trackable,
initializer=None, trainable=True):
name,
shape=None,
dtype=dtypes.float32,
initializer=None,
trainable=True):
"""Add a variable to a Trackable with no scope influence."""
"""Add a variable to a Trackable with no scope influence."""
return trackable._add_variable_with_custom_getter(  # pylint: disable=protected-access
return trackable._add_variable_with_custom_getter(  # pylint: disable=protected-access
name=name, shape=shape, dtype=dtype,
name=name,
initializer=initializer, getter=_default_getter, trainable=trainable)
shape=shape,
dtype=dtype,
initializer=initializer,
getter=_default_getter,
trainable=trainable)

def object_metadata(save_path):
def object_metadata(save_path):
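`add_variable` above is the scope-free entry point into `_add_variable_with_custom_getter`; a minimal sketch, assuming this file lives at `tensorflow.python.training.tracking.util`:

```python
from tensorflow.python.framework import dtypes
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util as trackable_utils

# Attach one variable to a trackable object, bypassing variable scopes.
obj = tracking.AutoTrackable()
v = trackable_utils.add_variable(
    obj, name="v", shape=[2], dtype=dtypes.float32)
```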
@ -313,6 +324,7 @@ def object_metadata(save_path):
Args:
Args:
save_path: The path to the checkpoint, as returned by `save` or
save_path: The path to the checkpoint, as returned by `save` or
`tf.train.latest_checkpoint`.
`tf.train.latest_checkpoint`.

Returns:
Returns:
A parsed `tf.contrib.checkpoint.TrackableObjectGraph` protocol buffer.
A parsed `tf.contrib.checkpoint.TrackableObjectGraph` protocol buffer.
Raises:
Raises:
@ -320,16 +332,14 @@ def object_metadata(save_path):
"""
"""
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
try:
try:
object_graph_string = reader.get_tensor(
object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
base.OBJECT_GRAPH_PROTO_KEY)
except errors_impl.NotFoundError:
except errors_impl.NotFoundError:
raise ValueError(
raise ValueError(
('The specified checkpoint "%s" does not appear to be object-based (it '
('The specified checkpoint "%s" does not appear to be object-based (it '
'is missing the key "%s"). Likely it was created with a name-based '
'is missing the key "%s"). Likely it was created with a name-based '
'saver and does not contain an object dependency graph.') % (
"saver and does not contain an object dependency graph.") %
save_path, base.OBJECT_GRAPH_PROTO_KEY))
(save_path, base.OBJECT_GRAPH_PROTO_KEY))
object_graph_proto = (
object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
trackable_object_graph_pb2.TrackableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string)
object_graph_proto.ParseFromString(object_graph_string)
return object_graph_proto
return object_graph_proto

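A usage sketch for `object_metadata`, run eagerly so `save` needs no session; the path is made up and the module alias assumes `tensorflow.python.training.tracking.util`:

```python
import tensorflow as tf
from tensorflow.python.training.tracking import util as trackable_utils

tf.compat.v1.enable_eager_execution()

ckpt = tf.train.Checkpoint(v=tf.Variable(1.0))
save_path = ckpt.save("/tmp/object_metadata_demo/ckpt")

# Parse the object graph proto stored under OBJECT_GRAPH_PROTO_KEY.
proto = trackable_utils.object_metadata(save_path)
for node in proto.nodes:
  print([child.local_name for child in node.children])
```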
@ -343,8 +353,8 @@ def list_objects(root_trackable):
(i.e. if they would be saved with a checkpoint).
(i.e. if they would be saved with a checkpoint).

Args:
Args:
root_trackable: A `Trackable` object whose dependencies should be
root_trackable: A `Trackable` object whose dependencies should be flattened.
flattened.
Returns:
Returns:
A flat list of objects.
A flat list of objects.
"""
"""
@ -362,12 +372,16 @@ def gather_initializers(root_trackable):

Args:
Args:
root_trackable: A `Trackable` object to gather initializers for.
root_trackable: A `Trackable` object to gather initializers for.

Returns:
Returns:
A list of initialization ops.
A list of initialization ops.
"""
"""
trackable_objects = list_objects(root_trackable)
trackable_objects = list_objects(root_trackable)
return [c.initializer for c in trackable_objects
return [
if hasattr(c, "initializer") and c.initializer is not None]
c.initializer
for c in trackable_objects
if hasattr(c, "initializer") and c.initializer is not None
]

@tf_contextlib.contextmanager
@tf_contextlib.contextmanager
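`gather_initializers` is the graph-mode counterpart to restore-on-create; a short sketch under the same assumed module path:

```python
import tensorflow as tf
from tensorflow.python.training.tracking import util as trackable_utils

# Collect initializers for every variable reachable from `root`.
root = tf.train.Checkpoint(v=tf.Variable(tf.zeros([2])))
init_ops = trackable_utils.gather_initializers(root)

with tf.compat.v1.Session() as sess:
  sess.run(init_ops)
```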
@ -380,7 +394,7 @@ def capture_dependencies(template):
object to add dependencies on variables created in a block of code which is
object to add dependencies on variables created in a block of code which is
not aware of object-based saving (and instead uses variable names
not aware of object-based saving (and instead uses variable names
heavily). This is how `Template` objects add dependencies on variables and
heavily). This is how `Template` objects add dependencies on variables and
sub-`Template`s. Where possible, use `tf.make_template` directly.
sub-`Template`s. Where possible, use `tf.compat.v1.make_template` directly.

Args:
Args:
template: The `Template` object to register dependencies with.
template: The `Template` object to register dependencies with.
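Since the docstring now points at `tf.compat.v1.make_template`, a minimal example of the `Template` pattern it refers to (names are illustrative):

```python
import tensorflow as tf

def linear(x):
  # make_template reuses "w" across calls instead of creating a new one.
  w = tf.compat.v1.get_variable(
      "w", shape=[], initializer=tf.compat.v1.ones_initializer())
  return x * w

template = tf.compat.v1.make_template("linear", linear)
y1 = template(tf.constant(2.0))  # Creates linear/w.
y2 = template(tf.constant(3.0))  # Reuses linear/w.
```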
@ -390,8 +404,11 @@ def capture_dependencies(template):
"""
"""
name_prefix = template.variable_scope.name
name_prefix = template.variable_scope.name

def _trackable_custom_creator(next_creator, name, initial_value,
def _trackable_custom_creator(next_creator,
trackable_parent=None, **kwargs):
name,
initial_value,
trackable_parent=None,
**kwargs):
"""A variable creation hook which adds Trackable dependencies.
"""A variable creation hook which adds Trackable dependencies.

Set for example during a `Template`'s first wrapped function
Set for example during a `Template`'s first wrapped function
@ -415,21 +432,20 @@ def capture_dependencies(template):
initial_value: See `variable_scope.variable_creator_scope`. Taken
initial_value: See `variable_scope.variable_creator_scope`. Taken
explicitly so the argument can be re-named and used with
explicitly so the argument can be re-named and used with
`Trackable._add_variable_with_custom_getter`.
`Trackable._add_variable_with_custom_getter`.
trackable_parent: If not None, a more deeply nested trackable
trackable_parent: If not None, a more deeply nested trackable object and
object and its name prefix which were passed to `capture_dependencies`
its name prefix which were passed to `capture_dependencies` to add a
to add a dependency on (rather than depending on the variable directly).
dependency on (rather than depending on the variable directly).
**kwargs: Passed through to the next creator.
**kwargs: Passed through to the next creator.

Returns:
Returns:
The output of `next_creator`: the fetched/created variable object.
The output of `next_creator`: the fetched/created variable object.
"""
"""

def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
inner_kwargs.pop("name")  # Ignored; this is the scope-stripped name which
inner_kwargs.pop("name")  # Ignored; this is the scope-stripped name which
# we don't want to propagate.
# we don't want to propagate.
return next_creator(
return next_creator(initial_value=initializer, name=name, **inner_kwargs)
initial_value=initializer,
name=name,
**inner_kwargs)
if name is not None and name.startswith(name_prefix):
if name is not None and name.startswith(name_prefix):
scope_stripped_name = name[len(name_prefix) + 1:]
scope_stripped_name = name[len(name_prefix) + 1:]
if not trackable_parent:
if not trackable_parent:
@ -450,8 +466,10 @@ def capture_dependencies(template):
name=parent_name_prefix[len(name_prefix) + 1:],
name=parent_name_prefix[len(name_prefix) + 1:],
overwrite=True)
overwrite=True)
return next_creator(
return next_creator(
name=name, initial_value=initial_value,
name=name,
trackable_parent=(template, name_prefix), **kwargs)
initial_value=initial_value,
trackable_parent=(template, name_prefix),
**kwargs)

with variable_scope.variable_creator_scope(_trackable_custom_creator):
with variable_scope.variable_creator_scope(_trackable_custom_creator):
yield
yield
@ -490,9 +508,8 @@ def streaming_restore(status, session=None):
"""When graph building, runs restore ops as soon as they come in.
"""When graph building, runs restore ops as soon as they come in.

Args:
Args:
status: A _LoadStatus objects from an object-based saver's
status: A _LoadStatus objects from an object-based saver's restore().
restore(). Streaming restore from name-based checkpoints is not currently
Streaming restore from name-based checkpoints is not currently supported.
supported.
session: A session to run new restore ops in.
session: A session to run new restore ops in.
"""
"""
if context.executing_eagerly():
if context.executing_eagerly():
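A graph-mode sketch of the `streaming_restore` flow described above, assuming an object-based checkpoint already exists in the (hypothetical) directory and the `tensorflow.python.training.tracking.util` module path:

```python
import tensorflow as tf
from tensorflow.python.training.tracking import util as trackable_utils

root = tf.train.Checkpoint(v=tf.Variable(0.0))
status = root.restore(tf.train.latest_checkpoint("/tmp/stream_demo"))

with tf.compat.v1.Session() as sess:
  # Restore ops created later (deferred restorations) run immediately.
  trackable_utils.streaming_restore(status=status, session=sess)
```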
@ -553,13 +570,13 @@ class CheckpointLoadStatus(_LoadStatus):
if self._checkpoint.slot_restorations:
if self._checkpoint.slot_restorations:
# Sanity check; this collection should be clear if everything has been
# Sanity check; this collection should be clear if everything has been
# restored.
# restored.
raise AssertionError("Unresolved slot restorations: %s" % (
raise AssertionError("Unresolved slot restorations: %s" %
self._checkpoint.slot_restorations,))
(self._checkpoint.slot_restorations,))
if self._checkpoint.unused_attributes:
if self._checkpoint.unused_attributes:
raise AssertionError(
raise AssertionError(
("Unused attributes in these objects (the attributes exist in the "
("Unused attributes in these objects (the attributes exist in the "
"checkpoint but not in the objects): %s") % (
"checkpoint but not in the objects): %s") %
list(self._checkpoint.unused_attributes.items()),))
(list(self._checkpoint.unused_attributes.items()),))
return self
return self

def assert_existing_objects_matched(self):
def assert_existing_objects_matched(self):
@ -581,10 +598,10 @@ class CheckpointLoadStatus(_LoadStatus):
"""
"""
for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
if (trackable is not None
if (trackable is not None and
and trackable._update_uid < self._checkpoint.restore_uid):  # pylint: disable=protected-access
trackable._update_uid < self._checkpoint.restore_uid):  # pylint: disable=protected-access
raise AssertionError(
raise AssertionError("Object not assigned a value from checkpoint: %s" %
"Object not assigned a value from checkpoint: %s" % (node,))
(node,))
for trackable_object in self._graph_view.list_objects():
for trackable_object in self._graph_view.list_objects():
# Remove data structures that do not contain any variables from
# Remove data structures that do not contain any variables from
# restoration checks.
# restoration checks.
@ -594,14 +611,14 @@ class CheckpointLoadStatus(_LoadStatus):
continue
continue
self._checkpoint.all_python_objects.add(trackable_object)
self._checkpoint.all_python_objects.add(trackable_object)
unused_python_objects = (
unused_python_objects = (
object_identity.ObjectIdentitySet(self._checkpoint.all_python_objects)
object_identity.ObjectIdentitySet(self._checkpoint.all_python_objects) -
- object_identity.ObjectIdentitySet(
object_identity.ObjectIdentitySet(
self._checkpoint.object_by_proto_id.values()))
self._checkpoint.object_by_proto_id.values()))
if unused_python_objects:
if unused_python_objects:
raise AssertionError(
raise AssertionError(
("Some Python objects were not bound to checkpointed values, likely "
("Some Python objects were not bound to checkpointed values, likely "
"due to changes in the Python program: %s")
"due to changes in the Python program: %s") %
% (list(unused_python_objects),))
(list(unused_python_objects),))
return self
return self

def assert_nontrivial_match(self):
def assert_nontrivial_match(self):
@ -610,8 +627,7 @@ class CheckpointLoadStatus(_LoadStatus):
self._checkpoint.all_python_objects.add(trackable_object)
self._checkpoint.all_python_objects.add(trackable_object)
if len(self._checkpoint.object_by_proto_id) <= 1:
if len(self._checkpoint.object_by_proto_id) <= 1:
unused_python_objects = (
unused_python_objects = (
object_identity.ObjectIdentitySet(
object_identity.ObjectIdentitySet(self._checkpoint.all_python_objects)
self._checkpoint.all_python_objects)
- object_identity.ObjectIdentitySet(
- object_identity.ObjectIdentitySet(
self._checkpoint.object_by_proto_id.values()))
self._checkpoint.object_by_proto_id.values()))
if unused_python_objects:
if unused_python_objects:
@ -622,8 +638,8 @@ class CheckpointLoadStatus(_LoadStatus):
"checkpointed value: %s") % (list(unused_python_objects),))
"checkpointed value: %s") % (list(unused_python_objects),))
else:
else:
raise AssertionError(
raise AssertionError(
"Nothing to load. No dependencies have been added to %s yet." % (
"Nothing to load. No dependencies have been added to %s yet." %
self._graph_view.root,))
(self._graph_view.root,))
return self
return self

def run_restore_ops(self, session=None):
def run_restore_ops(self, session=None):
@ -760,8 +776,8 @@ class NameBasedSaverStatus(_LoadStatus):
unused_attributes = dict(self._checkpoint.unused_attributes)
unused_attributes = dict(self._checkpoint.unused_attributes)
if unused_attributes:
if unused_attributes:
raise AssertionError(
raise AssertionError(
"Some objects had attributes which were not restored: %s"
"Some objects had attributes which were not restored: %s" %
% (unused_attributes,))
(unused_attributes,))
for trackable in self._graph_view.list_objects():
for trackable in self._graph_view.list_objects():
# pylint: disable=protected-access
# pylint: disable=protected-access
trackable._maybe_initialize_trackable()
trackable._maybe_initialize_trackable()
@ -799,12 +815,11 @@ class NameBasedSaverStatus(_LoadStatus):
continue
continue
# pylint: enable=protected-access
# pylint: enable=protected-access
saveable_objects.extend(
saveable_objects.extend(
self._checkpoint.globally_named_object_attributes(
self._checkpoint.globally_named_object_attributes(trackable))
trackable))
return saveable_objects
return saveable_objects

def run_restore_ops(self, session=None):
def run_restore_ops(self, session=None):
"""Load the name-based training checkpoint using a new `tf.train.Saver`."""
"""Load the name-based checkpoint using a new `tf.compat.v1.train.Saver`."""
if context.executing_eagerly():
if context.executing_eagerly():
return  # Nothing to do, variables are restored on creation.
return  # Nothing to do, variables are restored on creation.
if session is None:
if session is None:
@ -840,7 +855,8 @@ class TrackableSaver(object):
"""Saves and restores a `Trackable` object and its dependencies.
"""Saves and restores a `Trackable` object and its dependencies.

See `Trackable` for details of dependency management. `Saver` wraps
See `Trackable` for details of dependency management. `Saver` wraps
`tf.train.Saver` for saving, including extra information about the graph of
`tf.compat.v1.train.Saver` for saving, including extra information about the
graph of
dependencies between Python objects. When restoring, it uses this information
dependencies between Python objects. When restoring, it uses this information
about the save-time dependency graph to more robustly match objects with their
about the save-time dependency graph to more robustly match objects with their
checkpointed values. When executing eagerly, it supports restoring variables
checkpointed values. When executing eagerly, it supports restoring variables
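The public face of `TrackableSaver` is `tf.train.Checkpoint`; a minimal eager sketch of the attribute-name-based matching the docstring describes (paths are made up):

```python
import tensorflow as tf

tf.compat.v1.enable_eager_execution()

# Dependency names come from the keyword arguments, not Variable.name.
root = tf.train.Checkpoint(weights=tf.Variable(tf.zeros([10])))
save_path = root.save("/tmp/trackable_saver_demo/ckpt")

# Restoring matches objects through the saved dependency graph.
root.restore(save_path).assert_consumed()
```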
@ -851,7 +867,8 @@ class TrackableSaver(object):
checkpoint was written. To avoid breaking existing checkpoints when modifying
checkpoint was written. To avoid breaking existing checkpoints when modifying
a class, dependency names (the names of attributes to which `Trackable`
a class, dependency names (the names of attributes to which `Trackable`
objects are assigned) may not change. These names are local to objects, in
objects are assigned) may not change. These names are local to objects, in
contrast to the `Variable.name`-based save/restore from `tf.train.Saver`, and
contrast to the `Variable.name`-based save/restore from
`tf.compat.v1.train.Saver`, and
so allow additional program transformations.
so allow additional program transformations.
"""
"""

@ -877,8 +894,7 @@ class TrackableSaver(object):
self._restore_op_cache = {}
self._restore_op_cache = {}
self._graph_view = graph_view
self._graph_view = graph_view

def _gather_saveables(
def _gather_saveables(self, object_graph_tensor=None):
self, object_graph_tensor=None):
"""Wraps _serialize_object_graph to include the object graph proto."""
"""Wraps _serialize_object_graph to include the object graph proto."""
(named_saveable_objects, graph_proto,
(named_saveable_objects, graph_proto,
feed_additions) = self._graph_view.serialize_object_graph()
feed_additions) = self._graph_view.serialize_object_graph()
@ -892,14 +908,12 @@ class TrackableSaver(object):
assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
named_saveable_objects.append(
named_saveable_objects.append(
base.NoRestoreSaveable(
base.NoRestoreSaveable(
tensor=object_graph_tensor,
tensor=object_graph_tensor, name=base.OBJECT_GRAPH_PROTO_KEY))
name=base.OBJECT_GRAPH_PROTO_KEY))
return named_saveable_objects, graph_proto, feed_additions
return named_saveable_objects, graph_proto, feed_additions

def _save_cached_when_graph_building(
def _save_cached_when_graph_building(self,
self,
file_prefix,
file_prefix,
object_graph_tensor=None):
object_graph_tensor=None):
"""Create or retrieve save ops.
"""Create or retrieve save ops.

Args:
Args:
@ -921,8 +935,7 @@ class TrackableSaver(object):
# save() is called so they pick up new Tensors passed to their
# save() is called so they pick up new Tensors passed to their
# constructors. That means the Saver needs to be copied with a new
# constructors. That means the Saver needs to be copied with a new
# var_list.
# var_list.
or context.executing_eagerly()
or context.executing_eagerly() or ops.inside_function()):
or ops.inside_function()):
saver = functional_saver.MultiDeviceSaver(named_saveable_objects)
saver = functional_saver.MultiDeviceSaver(named_saveable_objects)
save_op = saver.save(file_prefix)
save_op = saver.save(file_prefix)
with ops.device("/cpu:0"):
with ops.device("/cpu:0"):
@ -954,8 +967,8 @@ class TrackableSaver(object):
The full path to the checkpoint.
The full path to the checkpoint.
"""
"""
feed_dict = {}
feed_dict = {}
use_session = (not context.executing_eagerly()
use_session = (not context.executing_eagerly() and
and not ops.inside_function())
not ops.inside_function())
if checkpoint_number:
if checkpoint_number:
file_prefix = "%s-%d" % (file_prefix, checkpoint_number)
file_prefix = "%s-%d" % (file_prefix, checkpoint_number)
if use_session:
if use_session:
@ -976,8 +989,7 @@ class TrackableSaver(object):

file_io.recursive_create_dir(os.path.dirname(file_prefix))
file_io.recursive_create_dir(os.path.dirname(file_prefix))
save_path, new_feed_additions = self._save_cached_when_graph_building(
save_path, new_feed_additions = self._save_cached_when_graph_building(
file_prefix=file_prefix_tensor,
file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor)
object_graph_tensor=object_graph_tensor)
if new_feed_additions:
if new_feed_additions:
feed_dict.update(new_feed_additions)
feed_dict.update(new_feed_additions)
if not use_session:
if not use_session:
@ -1024,7 +1036,7 @@ class TrackableSaver(object):
If the checkpoint has not been consumed completely, then the list of restore
If the checkpoint has not been consumed completely, then the list of restore
ops will grow as more objects are added to the dependency graph.
ops will grow as more objects are added to the dependency graph.

Name-based `tf.train.Saver` checkpoints can be loaded using this
Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this
method. There is no deferred loading, and names are used to match
method. There is no deferred loading, and names are used to match
variables. No restore ops are created/run until `run_restore_ops()` or
variables. No restore ops are created/run until `run_restore_ops()` or
`initialize_or_restore()` are called on the returned status object, even
`initialize_or_restore()` are called on the returned status object, even
@ -1035,9 +1047,9 @@ class TrackableSaver(object):
save_path: The path to the checkpoint, as returned by `save` or
save_path: The path to the checkpoint, as returned by `save` or
`tf.train.latest_checkpoint`. If None (as when there is no latest
`tf.train.latest_checkpoint`. If None (as when there is no latest
checkpoint for `tf.train.latest_checkpoint` to return), returns an
checkpoint for `tf.train.latest_checkpoint` to return), returns an
object which may run initializers for objects in the dependency
object which may run initializers for objects in the dependency graph.
graph. If the checkpoint was written by the name-based `tf.train.Saver`,
If the checkpoint was written by the name-based
names are used to match variables.
`tf.compat.v1.train.Saver`, names are used to match variables.

Returns:
Returns:
A load status object, which can be used to make assertions about the
A load status object, which can be used to make assertions about the
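The restore-or-initialize pattern this docstring describes, as a graph-mode sketch (the directory is hypothetical and may contain no checkpoint at all):

```python
import tensorflow as tf

root = tf.train.Checkpoint(v=tf.Variable(tf.zeros([3])))
status = root.restore(tf.train.latest_checkpoint("/tmp/restore_demo"))

with tf.compat.v1.Session() as sess:
  # Runs restore ops if a checkpoint was found, initializers otherwise.
  status.initialize_or_restore(sess)
```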
@ -1057,8 +1069,7 @@ class TrackableSaver(object):
else:
else:
dtype_map = reader.get_variable_to_dtype_map()
dtype_map = reader.get_variable_to_dtype_map()
try:
try:
object_graph_string = reader.get_tensor(
object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
base.OBJECT_GRAPH_PROTO_KEY)
except errors_impl.NotFoundError:
except errors_impl.NotFoundError:
# The object graph proto does not exist in this checkpoint. Try the
# The object graph proto does not exist in this checkpoint. Try the
# name-based compatibility mode.
# name-based compatibility mode.
@ -1069,8 +1080,7 @@ class TrackableSaver(object):
# pylint: disable=protected-access
# pylint: disable=protected-access
existing_trackable._maybe_initialize_trackable()
existing_trackable._maybe_initialize_trackable()
existing_trackable._name_based_restores.add(restore_coordinator)
existing_trackable._name_based_restores.add(restore_coordinator)
existing_trackable._name_based_attribute_restore(
existing_trackable._name_based_attribute_restore(restore_coordinator)
restore_coordinator)
# pylint: enable=protected-access
# pylint: enable=protected-access
return NameBasedSaverStatus(
return NameBasedSaverStatus(
restore_coordinator, graph_view=self._graph_view)
restore_coordinator, graph_view=self._graph_view)
@ -1085,8 +1095,7 @@ class TrackableSaver(object):
with ops.device("/cpu:0"):
with ops.device("/cpu:0"):
file_prefix_tensor = constant_op.constant(save_path)
file_prefix_tensor = constant_op.constant(save_path)
file_prefix_feed_dict = None
file_prefix_feed_dict = None
object_graph_proto = (
object_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph())
trackable_object_graph_pb2.TrackableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string)
object_graph_proto.ParseFromString(object_graph_string)
checkpoint = _CheckpointRestoreCoordinator(
checkpoint = _CheckpointRestoreCoordinator(
object_graph_proto=object_graph_proto,
object_graph_proto=object_graph_proto,
@ -1094,8 +1103,8 @@ class TrackableSaver(object):
save_path_tensor=file_prefix_tensor,
save_path_tensor=file_prefix_tensor,
restore_op_cache=self._restore_op_cache,
restore_op_cache=self._restore_op_cache,
graph_view=self._graph_view)
graph_view=self._graph_view)
base.CheckpointPosition(checkpoint=checkpoint, proto_id=0).restore(
base.CheckpointPosition(
self._graph_view.root)
checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)
load_status = CheckpointLoadStatus(
load_status = CheckpointLoadStatus(
checkpoint,
checkpoint,
graph_view=self._graph_view,
graph_view=self._graph_view,
@ -1104,7 +1113,7 @@ class TrackableSaver(object):

def frozen_saver(root_trackable):
def frozen_saver(root_trackable):
"""Creates a static `tf.train.Saver` from a trackable object.
"""Creates a static `tf.compat.v1.train.Saver` from a trackable object.

The returned `Saver` saves object-based checkpoints, but these checkpoints
The returned `Saver` saves object-based checkpoints, but these checkpoints
will no longer reflect structural changes to the object graph, only changes to
will no longer reflect structural changes to the object graph, only changes to
@ -1135,9 +1144,9 @@ def saver_with_op_caching(obj):
|
|||||||
saveables_cache = None
|
saveables_cache = None
|
||||||
else:
|
else:
|
||||||
saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()
|
saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()
|
||||||
return TrackableSaver(graph_view_lib.ObjectGraphView(
|
return TrackableSaver(
|
||||||
weakref.ref(obj),
|
graph_view_lib.ObjectGraphView(
|
||||||
saveables_cache=saveables_cache))
|
weakref.ref(obj), saveables_cache=saveables_cache))
|
||||||
|
|
||||||
|
|
||||||
# Mentions graph building / Sessions. The v2 version is below.
|
# Mentions graph building / Sessions. The v2 version is below.
|
||||||
@ -1146,7 +1155,7 @@ class CheckpointV1(tracking.AutoTrackable):
|
|||||||
"""Groups trackable objects, saving and restoring them.
|
"""Groups trackable objects, saving and restoring them.
|
||||||
|
|
||||||
`Checkpoint`'s constructor accepts keyword arguments whose values are types
|
`Checkpoint`'s constructor accepts keyword arguments whose values are types
|
||||||
that contain trackable state, such as `tf.train.Optimizer`
|
that contain trackable state, such as `tf.compat.v1.train.Optimizer`
|
||||||
implementations, `tf.Variable`, `tf.keras.Layer` implementations, or
|
implementations, `tf.Variable`, `tf.keras.Layer` implementations, or
|
||||||
`tf.keras.Model` implementations. It saves these values with a checkpoint, and
|
`tf.keras.Model` implementations. It saves these values with a checkpoint, and
|
||||||
maintains a `save_counter` for numbering checkpoints.
|
maintains a `save_counter` for numbering checkpoints.
|
||||||
@ -1164,7 +1173,7 @@ class CheckpointV1(tracking.AutoTrackable):
|
|||||||
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
|
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
|
||||||
train_op = optimizer.minimize( ... )
|
train_op = optimizer.minimize( ... )
|
||||||
status.assert_consumed() # Optional sanity checks.
|
status.assert_consumed() # Optional sanity checks.
|
||||||
with tf.Session() as session:
|
with tf.compat.v1.Session() as session:
|
||||||
# Use the Session to restore variables, or initialize them if
|
# Use the Session to restore variables, or initialize them if
|
||||||
# tf.train.latest_checkpoint returned None.
|
# tf.train.latest_checkpoint returned None.
|
||||||
status.initialize_or_restore(session)
|
status.initialize_or_restore(session)
|
||||||
@ -1179,7 +1188,7 @@ class CheckpointV1(tracking.AutoTrackable):
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import os
|
import os
|
||||||
|
|
||||||
tf.enable_eager_execution()
|
tf.compat.v1.enable_eager_execution()
|
||||||
|
|
||||||
checkpoint_directory = "/tmp/training_checkpoints"
|
checkpoint_directory = "/tmp/training_checkpoints"
|
||||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||||
@@ -1193,13 +1202,14 @@ class CheckpointV1(tracking.AutoTrackable):
   ```
 
   `Checkpoint.save` and `Checkpoint.restore` write and read object-based
-  checkpoints, in contrast to `tf.train.Saver` which writes and reads
+  checkpoints, in contrast to `tf.compat.v1.train.Saver` which writes and reads
   `variable.name` based checkpoints. Object-based checkpointing saves a graph of
   dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s,
   etc.) with named edges, and this graph is used to match variables when
   restoring a checkpoint. It can be more robust to changes in the Python
   program, and helps to support restore-on-create for variables when executing
-  eagerly. Prefer `tf.train.Checkpoint` over `tf.train.Saver` for new code.
+  eagerly. Prefer `tf.train.Checkpoint` over `tf.compat.v1.train.Saver` for new
+  code.
 
   `Checkpoint` objects have dependencies on the objects passed as keyword
   arguments to their constructors, and each dependency is given a name that is
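
The "named edges" mentioned above are the constructor keyword arguments: matching during restore follows paths through the object graph rather than `variable.name`. A small sketch under the same assumptions (eager execution; illustrative path):

```python
import tensorflow as tf

tf.compat.v1.enable_eager_execution()

# Only the edge name "m" matters for matching; the Python identifiers and
# the underlying variable names do not.
source = tf.Variable(2.)
path = tf.train.Checkpoint(m=source).save("/tmp/edge_demo/ckpt")

target = tf.Variable(0.)
tf.train.Checkpoint(m=target).restore(path).assert_existing_objects_matched()
print(target.numpy())  # 2.0, matched via the dependency edge "m"
```
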
@@ -1244,6 +1254,7 @@ class CheckpointV1(tracking.AutoTrackable):
     Args:
       **kwargs: Keyword arguments are set as attributes of this object, and are
         saved with the checkpoint. Values must be trackable objects.
+
     Raises:
       ValueError: If objects in `kwargs` are not trackable.
     """
@@ -1269,8 +1280,12 @@ class CheckpointV1(tracking.AutoTrackable):
     # add_variable creates a dependency named "save_counter"; NoDependency
     # prevents creating a second dependency named "_save_counter".
     self._save_counter = data_structures.NoDependency(
-        add_variable(self, name="save_counter", initializer=0,
-                     dtype=dtypes.int64, trainable=False))
+        add_variable(
+            self,
+            name="save_counter",
+            initializer=0,
+            dtype=dtypes.int64,
+            trainable=False))
 
   def write(self, file_prefix, session=None):
     """Writes a training checkpoint.
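
The `save_counter` built here is itself tracked and saved, and it numbers successive `save` calls. A sketch of the observable behavior (eager execution; illustrative directory):

```python
import tensorflow as tf

tf.compat.v1.enable_eager_execution()

ckpt = tf.train.Checkpoint(v=tf.Variable(1.))
ckpt.save("/tmp/counter_demo/ckpt")   # writes ".../ckpt-1"
ckpt.save("/tmp/counter_demo/ckpt")   # writes ".../ckpt-2"
print(ckpt.save_counter.numpy())      # 2; saved with the checkpoint itself
```
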
@@ -1294,9 +1309,7 @@ class CheckpointV1(tracking.AutoTrackable):
     Returns:
       The full path to the checkpoint (i.e. `file_prefix`).
     """
-    output = self._saver.save(
-        file_prefix=file_prefix,
-        session=session)
+    output = self._saver.save(file_prefix=file_prefix, session=session)
     if tensor_util.is_tensor(output):
       if context.executing_eagerly():
         return compat.as_str(output.numpy())
@@ -1370,8 +1383,8 @@ class CheckpointV1(tracking.AutoTrackable):
       checkpoint_number = session.run(self._save_assign_op)
     else:
       checkpoint_number = assign_op.numpy()
-    file_path = self.write("%s-%d" % (file_prefix, checkpoint_number),
-                           session=session)
+    file_path = self.write(
+        "%s-%d" % (file_prefix, checkpoint_number), session=session)
     checkpoint_management.update_checkpoint_state_internal(
         save_dir=os.path.dirname(file_prefix),
         model_checkpoint_path=file_path,
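
These two hunks touch the `write`/`save` pair: `write` uses `file_prefix` verbatim and does no bookkeeping, while `save` appends the incremented counter and updates the metadata consumed by `tf.train.latest_checkpoint`. A sketch (eager execution; illustrative paths):

```python
import tensorflow as tf

tf.compat.v1.enable_eager_execution()

ckpt = tf.train.Checkpoint(v=tf.Variable(1.))

# `write`: exact prefix, no numbering, no metadata update.
ckpt.write("/tmp/write_demo/manual")
print(tf.train.latest_checkpoint("/tmp/write_demo"))  # None

# `save`: numbered prefix ("%s-%d") plus checkpoint-state update.
path = ckpt.save("/tmp/save_demo/ckpt")
print(path)                                           # /tmp/save_demo/ckpt-1
print(tf.train.latest_checkpoint("/tmp/save_demo"))   # the same path
```
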
@@ -1417,7 +1430,7 @@ class CheckpointV1(tracking.AutoTrackable):
     If the checkpoint has not been consumed completely, then the list of restore
     ops will grow as more objects are added to the dependency graph.
 
-    Name-based `tf.train.Saver` checkpoints can be loaded using this
+    Name-based `tf.compat.v1.train.Saver` checkpoints can be loaded using this
     method. Names are used to match variables. No restore ops are created/run
     until `run_restore_ops()` or `initialize_or_restore()` are called on the
     returned status object when graph building, but there is restore-on-creation
@@ -1428,9 +1441,9 @@ class CheckpointV1(tracking.AutoTrackable):
       save_path: The path to the checkpoint, as returned by `save` or
         `tf.train.latest_checkpoint`. If None (as when there is no latest
         checkpoint for `tf.train.latest_checkpoint` to return), returns an
-        object which may run initializers for objects in the dependency
-        graph. If the checkpoint was written by the name-based `tf.train.Saver`,
-        names are used to match variables.
+        object which may run initializers for objects in the dependency graph.
+        If the checkpoint was written by the name-based
+        `tf.compat.v1.train.Saver`, names are used to match variables.
 
     Returns:
       A load status object, which can be used to make assertions about the
@@ -1453,7 +1466,8 @@ class CheckpointV1(tracking.AutoTrackable):
         built, and so has not created any variables, will pass this assertion
         but fail `assert_consumed`. Useful when loading part of a larger
         checkpoint into a new Python program, e.g. a training checkpoint with
-        a `tf.train.Optimizer` was saved but only the state required for
+        a `tf.compat.v1.train.Optimizer` was saved but only the state required
+        for
         inference is being loaded. This method returns the status object, and
         so may be chained with `initialize_or_restore` or `run_restore_ops`.
 
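
The contrast this docstring draws between `assert_consumed` and `assert_existing_objects_matched` can be sketched as follows (eager execution; illustrative names):

```python
import tensorflow as tf

tf.compat.v1.enable_eager_execution()

# A "training" checkpoint holding more state than inference needs.
training = tf.train.Checkpoint(weights=tf.Variable(3.), step=tf.Variable(10))
path = training.save("/tmp/partial_demo/ckpt")

# Partial load: assert_consumed() would raise because "step" is never
# matched, but assert_existing_objects_matched() passes.
inference = tf.train.Checkpoint(weights=tf.Variable(0.))
inference.restore(path).assert_existing_objects_matched()
```
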
@@ -1488,7 +1502,7 @@ class Checkpoint(tracking.AutoTrackable):
   """Groups trackable objects, saving and restoring them.
 
   `Checkpoint`'s constructor accepts keyword arguments whose values are types
-  that contain trackable state, such as `tf.train.Optimizer`
+  that contain trackable state, such as `tf.keras.optimizers.Optimizer`
   implementations, `tf.Variable`, `tf.keras.Layer` implementations, or
   `tf.keras.Model` implementations. It saves these values with a checkpoint, and
   maintains a `save_counter` for numbering checkpoints.
@@ -1511,7 +1525,8 @@ class Checkpoint(tracking.AutoTrackable):
   ```
 
   `Checkpoint.save` and `Checkpoint.restore` write and read object-based
-  checkpoints, in contrast to TensorFlow 1.x's `tf.train.Saver` which writes and
+  checkpoints, in contrast to TensorFlow 1.x's `tf.compat.v1.train.Saver` which
+  writes and
   reads `variable.name` based checkpoints. Object-based checkpointing saves a
   graph of dependencies between Python objects (`Layer`s, `Optimizer`s,
   `Variable`s, etc.) with named edges, and this graph is used to match variables
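
In the v2 class no method takes a `session`; with eager execution saves and restores happen immediately. A minimal sketch, assuming TF 2.x defaults (illustrative path):

```python
import tensorflow as tf  # assuming TF 2.x, where eager execution is the default

model = tf.keras.layers.Dense(2)
optimizer = tf.keras.optimizers.Adam(0.1)
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)

model(tf.ones([1, 3]))                  # create the layer's variables first
path = ckpt.save("/tmp/tf2_demo/ckpt")  # note: no `session` parameter
ckpt.restore(path).assert_existing_objects_matched()
```
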
@@ -1561,6 +1576,7 @@ class Checkpoint(tracking.AutoTrackable):
     Args:
       **kwargs: Keyword arguments are set as attributes of this object, and are
         saved with the checkpoint. Values must be trackable objects.
+
     Raises:
       ValueError: If objects in `kwargs` are not trackable.
     """
@@ -1586,8 +1602,12 @@ class Checkpoint(tracking.AutoTrackable):
     # add_variable creates a dependency named "save_counter"; NoDependency
     # prevents creating a second dependency named "_save_counter".
     self._save_counter = data_structures.NoDependency(
-        add_variable(self, name="save_counter", initializer=0,
-                     dtype=dtypes.int64, trainable=False))
+        add_variable(
+            self,
+            name="save_counter",
+            initializer=0,
+            dtype=dtypes.int64,
+            trainable=False))
 
   def write(self, file_prefix):
     """Writes a training checkpoint.
@@ -1608,8 +1628,7 @@ class Checkpoint(tracking.AutoTrackable):
     Returns:
       The full path to the checkpoint (i.e. `file_prefix`).
     """
-    output = self._saver.save(
-        file_prefix=file_prefix)
+    output = self._saver.save(file_prefix=file_prefix)
     if tensor_util.is_tensor(output):
       if context.executing_eagerly():
         return compat.as_str(output.numpy())
@@ -1711,7 +1730,8 @@ class Checkpoint(tracking.AutoTrackable):
     were not found in the checkpoint, or if any checkpointed values do not have
     a matching Python object.
 
-    Name-based `tf.train.Saver` checkpoints from TensorFlow 1.x can be loaded
+    Name-based `tf.compat.v1.train.Saver` checkpoints from TensorFlow 1.x can be
+    loaded
     using this method. Names are used to match variables. Re-encode name-based
     checkpoints using `tf.train.Checkpoint.save` as soon as possible.
 
@@ -1719,9 +1739,9 @@ class Checkpoint(tracking.AutoTrackable):
       save_path: The path to the checkpoint, as returned by `save` or
         `tf.train.latest_checkpoint`. If None (as when there is no latest
         checkpoint for `tf.train.latest_checkpoint` to return), returns an
-        object which may run initializers for objects in the dependency
-        graph. If the checkpoint was written by the name-based `tf.train.Saver`,
-        names are used to match variables.
+        object which may run initializers for objects in the dependency graph.
+        If the checkpoint was written by the name-based
+        `tf.compat.v1.train.Saver`, names are used to match variables.
 
     Returns:
       A load status object, which can be used to make assertions about the
@@ -1744,7 +1764,8 @@ class Checkpoint(tracking.AutoTrackable):
         built, and so has not created any variables, will pass this assertion
         but fail `assert_consumed`. Useful when loading part of a larger
         checkpoint into a new Python program, e.g. a training checkpoint with
-        a `tf.train.Optimizer` was saved but only the state required for
+        a `tf.compat.v1.train.Optimizer` was saved but only the state required
+        for
         inference is being loaded. This method returns the status object, and
         so may be chained with other assertions.
 
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Utility functions for training."""
 from __future__ import absolute_import
 from __future__ import division
@@ -34,7 +33,6 @@ from tensorflow.python.util.tf_export import tf_export
 # collection keys.
 GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache'
 
-
 # TODO(drpng): remove this after legacy uses are resolved.
 write_graph = graph_io.write_graph
 
@@ -47,11 +45,12 @@ def global_step(sess, global_step_tensor):
   # Create a variable to hold the global_step.
   global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
   # Create a session.
-  sess = tf.Session()
+  sess = tf.compat.v1.Session()
   # Initialize the variable
   sess.run(global_step_tensor.initializer)
   # Get the variable value.
-  print('global_step: %s' % tf.train.global_step(sess, global_step_tensor))
+  print('global_step: %s' % tf.compat.v1.train.global_step(sess,
+                                                           global_step_tensor))
 
   global_step: 10
   ```
@@ -109,8 +108,8 @@ def create_global_step(graph=None):
   """Create global step tensor in graph.
 
   Args:
-    graph: The graph in which to create the global step tensor. If missing,
-      use default graph.
+    graph: The graph in which to create the global step tensor. If missing, use
+      default graph.
 
   Returns:
     Global step tensor.
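
`create_global_step` raises if the global step already exists; the idempotent entry point exported alongside it is `get_or_create_global_step`. A sketch:

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  # Creates the int64 global step on first use, fetches it afterwards.
  step = tf.compat.v1.train.get_or_create_global_step(graph=graph)
  print(step.name)  # "global_step:0"
```
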
@@ -130,8 +129,9 @@ def create_global_step(graph=None):
         initializer=init_ops.zeros_initializer(),
         trainable=False,
         aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
-        collections=[ops.GraphKeys.GLOBAL_VARIABLES,
-                     ops.GraphKeys.GLOBAL_STEP])
+        collections=[
+            ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP
+        ])
   # Create in proper graph and base name_scope.
   with graph.as_default() as g, g.name_scope(None):
     return variable_scope.get_variable(
@@ -141,8 +141,7 @@ def create_global_step(graph=None):
         initializer=init_ops.zeros_initializer(),
         trainable=False,
         aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
-        collections=[ops.GraphKeys.GLOBAL_VARIABLES,
-                     ops.GraphKeys.GLOBAL_STEP])
+        collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])
 
 
 @tf_export(v1=['train.get_or_create_global_step'])
@@ -173,9 +172,8 @@ def assert_global_step(global_step_tensor):
   if not (isinstance(global_step_tensor, variables.Variable) or
           isinstance(global_step_tensor, ops.Tensor) or
           resource_variable_ops.is_resource_variable(global_step_tensor)):
-    raise TypeError(
-        'Existing "global_step" must be a Variable or Tensor: %s.' %
-        global_step_tensor)
+    raise TypeError('Existing "global_step" must be a Variable or Tensor: %s.' %
+                    global_step_tensor)
 
   if not global_step_tensor.dtype.base_dtype.is_integer:
     raise TypeError('Existing "global_step" does not have integer type: %s' %
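
`assert_global_step` enforces exactly the two checks reformatted above: the value must be a `Variable` or `Tensor`, and its dtype must be integral. A sketch:

```python
import tensorflow as tf

good = tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')
tf.compat.v1.train.assert_global_step(good)  # passes: integer Variable

bad = tf.Variable(0.5, name='not_a_step')    # float dtype
try:
  tf.compat.v1.train.assert_global_step(bad)
except TypeError as e:
  print(e)  # Existing "global_step" does not have integer type: ...
```
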
@@ -49,8 +49,8 @@ class VocabInfo(
   VocabInfo to warm-start.
 
   Attributes:
-    new_vocab: [Required] A path to the new vocabulary file (used with the
-      model to be trained).
+    new_vocab: [Required] A path to the new vocabulary file (used with the model
+      to be trained).
     new_vocab_size: [Required] An integer indicating how many entries of the new
       vocabulary will used in training.
     num_oov_buckets: [Required] An integer indicating how many OOV buckets are
@@ -76,7 +76,7 @@ class VocabInfo(
       num_oov_buckets=1,
       old_vocab='pretrained_embeddings_vocab',
       old_vocab_size=10000,
-      backup_initializer=tf.truncated_normal_initializer(
+      backup_initializer=tf.compat.v1.truncated_normal_initializer(
           mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
       axis=0)
 
@@ -86,7 +86,7 @@ class VocabInfo(
       num_oov_buckets=0,  # No OOV for classes.
       old_vocab='old_class_vocab',
       old_vocab_size=8,
-      backup_initializer=tf.glorot_uniform_initializer(),
+      backup_initializer=tf.compat.v1.glorot_uniform_initializer(),
       axis=1)
 
   softmax_output_layer_bias_vocab_info = tf.VocabInfo(
@@ -95,7 +95,7 @@ class VocabInfo(
       num_oov_buckets=0,  # No OOV for classes.
       old_vocab='old_class_vocab',
       old_vocab_size=8,
-      backup_initializer=tf.zeros_initializer(),
+      backup_initializer=tf.compat.v1.zeros_initializer(),
       axis=0)
 
   Currently, only axis=0 and axis=1 are supported.
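
Tying the three examples above together: `axis=0` remaps the row (vocabulary) dimension, as in an embedding table, while `axis=1` remaps columns, as in the class dimension of an output layer. A standalone sketch of the `axis=0` case, assuming the `tf.estimator.VocabInfo` export; file names and sizes are illustrative:

```python
import math

import tensorflow as tf

embedding_dim = 16  # illustrative

vocab_info = tf.estimator.VocabInfo(
    new_vocab='new_vocab.txt',  # hypothetical vocabulary files
    new_vocab_size=100,
    num_oov_buckets=1,
    old_vocab='old_vocab.txt',
    old_vocab_size=80,
    backup_initializer=tf.compat.v1.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(embedding_dim)),
    axis=0)  # axis=0: rows correspond to vocabulary entries
```
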
@@ -255,8 +255,7 @@ def _warm_start_var_with_vocab(var,
   partition_info = None
   if slice_info:
     partition_info = variable_scope._PartitionInfo(
-        full_shape=slice_info.full_shape,
-        var_offset=slice_info.var_offset)
+        full_shape=slice_info.full_shape, var_offset=slice_info.var_offset)
 
   if axis == 0:
     new_row_vocab_size = current_vocab_size
@@ -301,6 +300,8 @@ def _warm_start_var_with_vocab(var,
       new_init_val = ops.convert_to_tensor(
           init(shape=v_shape, partition_info=partition_info))
       v._initializer_op = state_ops.assign(v, new_init_val)
 
 
   # pylint: enable=protected-access
+
+
@@ -314,7 +315,8 @@ def _get_grouped_variables(vars_to_warm_start):
     vars_to_warm_start: One of the following:
 
       - A regular expression (string) that captures which variables to
-        warm-start (see tf.get_collection). This expression will only consider
+        warm-start (see tf.compat.v1.get_collection). This expression will
+        only consider
         variables in the TRAINABLE_VARIABLES collection.
       - A list of Variables to warm-start.
       - A list of strings, each representing a full variable name to warm-start.
@@ -330,14 +332,13 @@ def _get_grouped_variables(vars_to_warm_start):
     # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
     # everything (in TRAINABLE_VARIABLES) here.
     list_of_vars = ops.get_collection(
-        ops.GraphKeys.TRAINABLE_VARIABLES,
-        scope=vars_to_warm_start)
+        ops.GraphKeys.TRAINABLE_VARIABLES, scope=vars_to_warm_start)
   elif isinstance(vars_to_warm_start, list):
     if all(isinstance(v, str) for v in vars_to_warm_start):
       list_of_vars = []
       for v in vars_to_warm_start:
-        list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
-                                           scope=v)
+        list_of_vars += ops.get_collection(
+            ops.GraphKeys.GLOBAL_VARIABLES, scope=v)
     elif all(checkpoint_utils._is_variable(v) for v in vars_to_warm_start):  # pylint: disable=protected-access
       list_of_vars = vars_to_warm_start
     else:
@@ -377,16 +378,16 @@ def warm_start(ckpt_to_initialize_from,
     vars_to_warm_start: [Optional] One of the following:
 
       - A regular expression (string) that captures which variables to
-        warm-start (see tf.get_collection). This expression will only consider
-        variables in the TRAINABLE_VARIABLES collection -- if you need to
-        warm-start non_TRAINABLE vars (such as optimizer accumulators or batch
-        norm statistics), please use the below option.
+        warm-start (see tf.compat.v1.get_collection). This expression will only
+        consider variables in the TRAINABLE_VARIABLES collection -- if you need
+        to warm-start non_TRAINABLE vars (such as optimizer accumulators or
+        batch norm statistics), please use the below option.
       - A list of Variables to warm-start. If you do not have access to the
         `Variable` objects at the call site, please use the below option.
-      - A list of strings, each a regex scope provided to tf.get_collection with
-        GLOBAL_VARIABLES (please see tf.get_collection). For backwards
-        compatibility reasons, this is separate from the single-string argument
-        type.
+      - A list of strings, each a regex scope provided to
+        tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see
+        tf.compat.v1.get_collection). For backwards compatibility reasons,
+        this is separate from the single-string argument type.
       - `None`, in which case only variables specified in
         `var_name_to_vocab_info` will be warm-started.
 
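
In the single-string form, `vars_to_warm_start` is a regex matched against TRAINABLE_VARIABLES, as the docstring above describes. A sketch, to be called while the graph is being built; the checkpoint path and scope are illustrative:

```python
import tensorflow as tf

with tf.Graph().as_default():
  with tf.compat.v1.variable_scope('dense'):
    tf.compat.v1.get_variable('kernel', shape=[10, 10])

  # Initialize all trainable variables matching the regex from a previous
  # run's checkpoint (illustrative path; must exist at run time).
  tf.compat.v1.train.warm_start(
      ckpt_to_initialize_from='/tmp/previous_run',
      vars_to_warm_start='dense.*')
```
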
@@ -404,6 +405,7 @@ def warm_start(ckpt_to_initialize_from,
       effect on the set of variables that is warm-started, and only controls
       name mapping (use `vars_to_warm_start` for controlling what variables to
       warm-start).
+
   Raises:
     ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
       configuration for variable names that are not used. This is to ensure