Break all remaining rependencies in tensorflow/python/training on
contrib/.../framework_py PiperOrigin-RevId: 268579788
This commit is contained in:
parent
53d31e990f
commit
6cb5fb444c
@ -5916,7 +5916,6 @@ tf_py_test(
|
|||||||
":training",
|
":training",
|
||||||
":variable_scope",
|
":variable_scope",
|
||||||
":variables",
|
":variables",
|
||||||
"//tensorflow/contrib/framework:framework_py",
|
|
||||||
"//tensorflow/contrib/testing:testing_py",
|
"//tensorflow/contrib/testing:testing_py",
|
||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
],
|
],
|
||||||
@ -6009,7 +6008,6 @@ tf_py_test(
|
|||||||
":summary",
|
":summary",
|
||||||
":training",
|
":training",
|
||||||
":variables",
|
":variables",
|
||||||
"//tensorflow/contrib/framework:framework_py",
|
|
||||||
"//tensorflow/contrib/testing:testing_py",
|
"//tensorflow/contrib/testing:testing_py",
|
||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python/distribute:collective_all_reduce_strategy",
|
"//tensorflow/python/distribute:collective_all_reduce_strategy",
|
||||||
|
@ -24,8 +24,6 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from tensorflow.contrib.framework.python.framework import checkpoint_utils
|
|
||||||
from tensorflow.contrib.framework.python.ops import variables
|
|
||||||
from tensorflow.contrib.testing.python.framework import fake_summary_writer
|
from tensorflow.contrib.testing.python.framework import fake_summary_writer
|
||||||
from tensorflow.python.client import session as session_lib
|
from tensorflow.python.client import session as session_lib
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
@ -47,6 +45,7 @@ from tensorflow.python.platform import tf_logging
|
|||||||
from tensorflow.python.summary import summary as summary_lib
|
from tensorflow.python.summary import summary as summary_lib
|
||||||
from tensorflow.python.summary.writer import writer_cache
|
from tensorflow.python.summary.writer import writer_cache
|
||||||
from tensorflow.python.training import basic_session_run_hooks
|
from tensorflow.python.training import basic_session_run_hooks
|
||||||
|
from tensorflow.python.training import checkpoint_utils
|
||||||
from tensorflow.python.training import monitored_session
|
from tensorflow.python.training import monitored_session
|
||||||
from tensorflow.python.training import session_run_hook
|
from tensorflow.python.training import session_run_hook
|
||||||
from tensorflow.python.training import training_util
|
from tensorflow.python.training import training_util
|
||||||
@ -151,7 +150,7 @@ class StopAtStepTest(test.TestCase):
|
|||||||
def test_stop_based_on_last_step(self):
|
def test_stop_based_on_last_step(self):
|
||||||
h = basic_session_run_hooks.StopAtStepHook(last_step=10)
|
h = basic_session_run_hooks.StopAtStepHook(last_step=10)
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
global_step = variables.get_or_create_global_step()
|
global_step = training_util.get_or_create_global_step()
|
||||||
no_op = control_flow_ops.no_op()
|
no_op = control_flow_ops.no_op()
|
||||||
h.begin()
|
h.begin()
|
||||||
with session_lib.Session() as sess:
|
with session_lib.Session() as sess:
|
||||||
@ -175,7 +174,7 @@ class StopAtStepTest(test.TestCase):
|
|||||||
h = basic_session_run_hooks.StopAtStepHook(num_steps=10)
|
h = basic_session_run_hooks.StopAtStepHook(num_steps=10)
|
||||||
|
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
global_step = variables.get_or_create_global_step()
|
global_step = training_util.get_or_create_global_step()
|
||||||
no_op = control_flow_ops.no_op()
|
no_op = control_flow_ops.no_op()
|
||||||
h.begin()
|
h.begin()
|
||||||
with session_lib.Session() as sess:
|
with session_lib.Session() as sess:
|
||||||
@ -202,7 +201,7 @@ class StopAtStepTest(test.TestCase):
|
|||||||
h = basic_session_run_hooks.StopAtStepHook(num_steps=10)
|
h = basic_session_run_hooks.StopAtStepHook(num_steps=10)
|
||||||
|
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
global_step = variables.get_or_create_global_step()
|
global_step = training_util.get_or_create_global_step()
|
||||||
no_op = control_flow_ops.no_op()
|
no_op = control_flow_ops.no_op()
|
||||||
h.begin()
|
h.begin()
|
||||||
with session_lib.Session() as sess:
|
with session_lib.Session() as sess:
|
||||||
@ -391,7 +390,7 @@ class CheckpointSaverHookTest(test.TestCase):
|
|||||||
self.graph = ops.Graph()
|
self.graph = ops.Graph()
|
||||||
with self.graph.as_default():
|
with self.graph.as_default():
|
||||||
self.scaffold = monitored_session.Scaffold()
|
self.scaffold = monitored_session.Scaffold()
|
||||||
self.global_step = variables.get_or_create_global_step()
|
self.global_step = training_util.get_or_create_global_step()
|
||||||
self.train_op = training_util._increment_global_step(1)
|
self.train_op = training_util._increment_global_step(1)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
@ -467,7 +466,7 @@ class CheckpointSaverHookTest(test.TestCase):
|
|||||||
def test_listener_with_monitored_session(self):
|
def test_listener_with_monitored_session(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
scaffold = monitored_session.Scaffold()
|
scaffold = monitored_session.Scaffold()
|
||||||
global_step = variables.get_or_create_global_step()
|
global_step = training_util.get_or_create_global_step()
|
||||||
train_op = training_util._increment_global_step(1)
|
train_op = training_util._increment_global_step(1)
|
||||||
listener = MockCheckpointSaverListener()
|
listener = MockCheckpointSaverListener()
|
||||||
hook = basic_session_run_hooks.CheckpointSaverHook(
|
hook = basic_session_run_hooks.CheckpointSaverHook(
|
||||||
@ -494,7 +493,7 @@ class CheckpointSaverHookTest(test.TestCase):
|
|||||||
def test_listener_stops_training_in_after_save(self):
|
def test_listener_stops_training_in_after_save(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
scaffold = monitored_session.Scaffold()
|
scaffold = monitored_session.Scaffold()
|
||||||
variables.get_or_create_global_step()
|
training_util.get_or_create_global_step()
|
||||||
train_op = training_util._increment_global_step(1)
|
train_op = training_util._increment_global_step(1)
|
||||||
listener = MockCheckpointSaverListener()
|
listener = MockCheckpointSaverListener()
|
||||||
hook = basic_session_run_hooks.CheckpointSaverHook(
|
hook = basic_session_run_hooks.CheckpointSaverHook(
|
||||||
@ -512,7 +511,7 @@ class CheckpointSaverHookTest(test.TestCase):
|
|||||||
|
|
||||||
def test_listener_with_default_saver(self):
|
def test_listener_with_default_saver(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
global_step = variables.get_or_create_global_step()
|
global_step = training_util.get_or_create_global_step()
|
||||||
train_op = training_util._increment_global_step(1)
|
train_op = training_util._increment_global_step(1)
|
||||||
listener = MockCheckpointSaverListener()
|
listener = MockCheckpointSaverListener()
|
||||||
hook = basic_session_run_hooks.CheckpointSaverHook(
|
hook = basic_session_run_hooks.CheckpointSaverHook(
|
||||||
@ -535,7 +534,7 @@ class CheckpointSaverHookTest(test.TestCase):
|
|||||||
}, listener_counts)
|
}, listener_counts)
|
||||||
|
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
global_step = variables.get_or_create_global_step()
|
global_step = training_util.get_or_create_global_step()
|
||||||
with monitored_session.SingularMonitoredSession(
|
with monitored_session.SingularMonitoredSession(
|
||||||
checkpoint_dir=self.model_dir) as sess2:
|
checkpoint_dir=self.model_dir) as sess2:
|
||||||
global_step_saved_val = sess2.run(global_step)
|
global_step_saved_val = sess2.run(global_step)
|
||||||
@ -543,7 +542,7 @@ class CheckpointSaverHookTest(test.TestCase):
|
|||||||
|
|
||||||
def test_two_listeners_with_default_saver(self):
|
def test_two_listeners_with_default_saver(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
global_step = variables.get_or_create_global_step()
|
global_step = training_util.get_or_create_global_step()
|
||||||
train_op = training_util._increment_global_step(1)
|
train_op = training_util._increment_global_step(1)
|
||||||
listener1 = MockCheckpointSaverListener()
|
listener1 = MockCheckpointSaverListener()
|
||||||
listener2 = MockCheckpointSaverListener()
|
listener2 = MockCheckpointSaverListener()
|
||||||
@ -569,7 +568,7 @@ class CheckpointSaverHookTest(test.TestCase):
|
|||||||
self.assertEqual(listener1_counts, listener2_counts)
|
self.assertEqual(listener1_counts, listener2_counts)
|
||||||
|
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
global_step = variables.get_or_create_global_step()
|
global_step = training_util.get_or_create_global_step()
|
||||||
with monitored_session.SingularMonitoredSession(
|
with monitored_session.SingularMonitoredSession(
|
||||||
checkpoint_dir=self.model_dir) as sess2:
|
checkpoint_dir=self.model_dir) as sess2:
|
||||||
global_step_saved_val = sess2.run(global_step)
|
global_step_saved_val = sess2.run(global_step)
|
||||||
@ -786,7 +785,7 @@ class CheckpointSaverHookMultiStepTest(test.TestCase):
|
|||||||
self.steps_per_run = 5
|
self.steps_per_run = 5
|
||||||
with self.graph.as_default():
|
with self.graph.as_default():
|
||||||
self.scaffold = monitored_session.Scaffold()
|
self.scaffold = monitored_session.Scaffold()
|
||||||
self.global_step = variables.get_or_create_global_step()
|
self.global_step = training_util.get_or_create_global_step()
|
||||||
self.train_op = training_util._increment_global_step(self.steps_per_run)
|
self.train_op = training_util._increment_global_step(self.steps_per_run)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
@ -926,7 +925,7 @@ class StepCounterHookTest(test.TestCase):
|
|||||||
def test_step_counter_every_n_steps(self, mock_time):
|
def test_step_counter_every_n_steps(self, mock_time):
|
||||||
mock_time.return_value = MOCK_START_TIME
|
mock_time.return_value = MOCK_START_TIME
|
||||||
with ops.Graph().as_default() as g, session_lib.Session() as sess:
|
with ops.Graph().as_default() as g, session_lib.Session() as sess:
|
||||||
variables.get_or_create_global_step()
|
training_util.get_or_create_global_step()
|
||||||
train_op = training_util._increment_global_step(1)
|
train_op = training_util._increment_global_step(1)
|
||||||
summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
|
summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
|
||||||
hook = basic_session_run_hooks.StepCounterHook(
|
hook = basic_session_run_hooks.StepCounterHook(
|
||||||
@ -956,7 +955,7 @@ class StepCounterHookTest(test.TestCase):
|
|||||||
def test_step_counter_every_n_secs(self, mock_time):
|
def test_step_counter_every_n_secs(self, mock_time):
|
||||||
mock_time.return_value = MOCK_START_TIME
|
mock_time.return_value = MOCK_START_TIME
|
||||||
with ops.Graph().as_default() as g, session_lib.Session() as sess:
|
with ops.Graph().as_default() as g, session_lib.Session() as sess:
|
||||||
variables.get_or_create_global_step()
|
training_util.get_or_create_global_step()
|
||||||
train_op = training_util._increment_global_step(1)
|
train_op = training_util._increment_global_step(1)
|
||||||
summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
|
summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
|
||||||
hook = basic_session_run_hooks.StepCounterHook(
|
hook = basic_session_run_hooks.StepCounterHook(
|
||||||
@ -1018,7 +1017,7 @@ class StepCounterHookTest(test.TestCase):
|
|||||||
|
|
||||||
def test_log_warning_if_global_step_not_increased(self):
|
def test_log_warning_if_global_step_not_increased(self):
|
||||||
with ops.Graph().as_default(), session_lib.Session() as sess:
|
with ops.Graph().as_default(), session_lib.Session() as sess:
|
||||||
variables.get_or_create_global_step()
|
training_util.get_or_create_global_step()
|
||||||
train_op = training_util._increment_global_step(0) # keep same.
|
train_op = training_util._increment_global_step(0) # keep same.
|
||||||
self.evaluate(variables_lib.global_variables_initializer())
|
self.evaluate(variables_lib.global_variables_initializer())
|
||||||
hook = basic_session_run_hooks.StepCounterHook(
|
hook = basic_session_run_hooks.StepCounterHook(
|
||||||
@ -1039,7 +1038,7 @@ class StepCounterHookTest(test.TestCase):
|
|||||||
steps_per_run,
|
steps_per_run,
|
||||||
graph,
|
graph,
|
||||||
sess):
|
sess):
|
||||||
variables.get_or_create_global_step()
|
training_util.get_or_create_global_step()
|
||||||
self.train_op = training_util._increment_global_step(steps_per_run)
|
self.train_op = training_util._increment_global_step(steps_per_run)
|
||||||
self.summary_writer = fake_summary_writer.FakeSummaryWriter(
|
self.summary_writer = fake_summary_writer.FakeSummaryWriter(
|
||||||
self.log_dir, graph)
|
self.log_dir, graph)
|
||||||
@ -1137,7 +1136,7 @@ class SummarySaverHookTest(test.TestCase):
|
|||||||
self.summary_op = summary_lib.scalar('my_summary', tensor)
|
self.summary_op = summary_lib.scalar('my_summary', tensor)
|
||||||
self.summary_op2 = summary_lib.scalar('my_summary2', tensor2)
|
self.summary_op2 = summary_lib.scalar('my_summary2', tensor2)
|
||||||
|
|
||||||
variables.get_or_create_global_step()
|
training_util.get_or_create_global_step()
|
||||||
self.train_op = training_util._increment_global_step(1)
|
self.train_op = training_util._increment_global_step(1)
|
||||||
|
|
||||||
def test_raise_when_scaffold_and_summary_op_both_missing(self):
|
def test_raise_when_scaffold_and_summary_op_both_missing(self):
|
||||||
@ -1292,7 +1291,7 @@ class GlobalStepWaiterHookTest(test.TestCase):
|
|||||||
|
|
||||||
def test_not_wait_for_step_zero(self):
|
def test_not_wait_for_step_zero(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
variables.get_or_create_global_step()
|
training_util.get_or_create_global_step()
|
||||||
hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=0)
|
hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=0)
|
||||||
hook.begin()
|
hook.begin()
|
||||||
with session_lib.Session() as sess:
|
with session_lib.Session() as sess:
|
||||||
@ -1304,7 +1303,7 @@ class GlobalStepWaiterHookTest(test.TestCase):
|
|||||||
@test.mock.patch.object(time, 'sleep')
|
@test.mock.patch.object(time, 'sleep')
|
||||||
def test_wait_for_step(self, mock_sleep):
|
def test_wait_for_step(self, mock_sleep):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=1000)
|
hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=1000)
|
||||||
hook.begin()
|
hook.begin()
|
||||||
|
|
||||||
@ -1418,7 +1417,7 @@ class ResourceSummarySaverHookTest(test.TestCase):
|
|||||||
self.summary_op = summary_lib.scalar('my_summary', tensor)
|
self.summary_op = summary_lib.scalar('my_summary', tensor)
|
||||||
|
|
||||||
with variable_scope.variable_scope('foo', use_resource=True):
|
with variable_scope.variable_scope('foo', use_resource=True):
|
||||||
variables.create_global_step()
|
training_util.create_global_step()
|
||||||
self.train_op = training_util._increment_global_step(1)
|
self.train_op = training_util._increment_global_step(1)
|
||||||
|
|
||||||
def test_save_steps(self):
|
def test_save_steps(self):
|
||||||
@ -1475,7 +1474,7 @@ class ProfilerHookTest(test.TestCase):
|
|||||||
self.graph = ops.Graph()
|
self.graph = ops.Graph()
|
||||||
self.filepattern = os.path.join(self.output_dir, 'timeline-*.json')
|
self.filepattern = os.path.join(self.output_dir, 'timeline-*.json')
|
||||||
with self.graph.as_default():
|
with self.graph.as_default():
|
||||||
self.global_step = variables.get_or_create_global_step()
|
self.global_step = training_util.get_or_create_global_step()
|
||||||
self.train_op = state_ops.assign_add(self.global_step, 1)
|
self.train_op = state_ops.assign_add(self.global_step, 1)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
|
@ -27,7 +27,6 @@ import threading
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from tensorflow.contrib.framework.python.ops import variables as variables_lib
|
|
||||||
from tensorflow.contrib.testing.python.framework import util_test
|
from tensorflow.contrib.testing.python.framework import util_test
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.core.protobuf import debug_pb2
|
from tensorflow.core.protobuf import debug_pb2
|
||||||
@ -52,6 +51,7 @@ from tensorflow.python.training import coordinator
|
|||||||
from tensorflow.python.training import monitored_session
|
from tensorflow.python.training import monitored_session
|
||||||
from tensorflow.python.training import saver as saver_lib
|
from tensorflow.python.training import saver as saver_lib
|
||||||
from tensorflow.python.training import session_run_hook
|
from tensorflow.python.training import session_run_hook
|
||||||
|
from tensorflow.python.training import training_util
|
||||||
|
|
||||||
|
|
||||||
class ScaffoldTest(test.TestCase):
|
class ScaffoldTest(test.TestCase):
|
||||||
@ -274,7 +274,7 @@ class MonitoredTrainingSessionTest(test.TestCase):
|
|||||||
def test_saving_restoring_checkpoint(self):
|
def test_saving_restoring_checkpoint(self):
|
||||||
logdir = _test_dir(self.get_temp_dir(), 'test_saving_restoring_checkpoint')
|
logdir = _test_dir(self.get_temp_dir(), 'test_saving_restoring_checkpoint')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
do_step = state_ops.assign_add(gstep, 1)
|
do_step = state_ops.assign_add(gstep, 1)
|
||||||
with monitored_session.MonitoredTrainingSession(
|
with monitored_session.MonitoredTrainingSession(
|
||||||
is_chief=True, checkpoint_dir=logdir) as session:
|
is_chief=True, checkpoint_dir=logdir) as session:
|
||||||
@ -289,7 +289,7 @@ class MonitoredTrainingSessionTest(test.TestCase):
|
|||||||
def test_save_checkpoint_steps(self):
|
def test_save_checkpoint_steps(self):
|
||||||
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_steps')
|
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_steps')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
new_gstep = state_ops.assign_add(gstep, 1)
|
new_gstep = state_ops.assign_add(gstep, 1)
|
||||||
with monitored_session.MonitoredTrainingSession(
|
with monitored_session.MonitoredTrainingSession(
|
||||||
is_chief=True,
|
is_chief=True,
|
||||||
@ -306,7 +306,7 @@ class MonitoredTrainingSessionTest(test.TestCase):
|
|||||||
def test_save_checkpoint_secs(self):
|
def test_save_checkpoint_secs(self):
|
||||||
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_secs')
|
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_secs')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
new_gstep = state_ops.assign_add(gstep, 1)
|
new_gstep = state_ops.assign_add(gstep, 1)
|
||||||
with monitored_session.MonitoredTrainingSession(
|
with monitored_session.MonitoredTrainingSession(
|
||||||
is_chief=True,
|
is_chief=True,
|
||||||
@ -325,7 +325,7 @@ class MonitoredTrainingSessionTest(test.TestCase):
|
|||||||
def test_summaries_steps(self):
|
def test_summaries_steps(self):
|
||||||
logdir = _test_dir(self.get_temp_dir(), 'test_summaries_steps')
|
logdir = _test_dir(self.get_temp_dir(), 'test_summaries_steps')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
new_gstep = state_ops.assign_add(gstep, 1)
|
new_gstep = state_ops.assign_add(gstep, 1)
|
||||||
summary.scalar('my_summary_tag', new_gstep * 2)
|
summary.scalar('my_summary_tag', new_gstep * 2)
|
||||||
with monitored_session.MonitoredTrainingSession(
|
with monitored_session.MonitoredTrainingSession(
|
||||||
@ -343,7 +343,7 @@ class MonitoredTrainingSessionTest(test.TestCase):
|
|||||||
def test_summaries_secs(self):
|
def test_summaries_secs(self):
|
||||||
logdir = _test_dir(self.get_temp_dir(), 'test_summaries_secs')
|
logdir = _test_dir(self.get_temp_dir(), 'test_summaries_secs')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
new_gstep = state_ops.assign_add(gstep, 1)
|
new_gstep = state_ops.assign_add(gstep, 1)
|
||||||
summary.scalar('my_summary_tag', new_gstep * 2)
|
summary.scalar('my_summary_tag', new_gstep * 2)
|
||||||
with monitored_session.MonitoredTrainingSession(
|
with monitored_session.MonitoredTrainingSession(
|
||||||
@ -365,7 +365,7 @@ class MonitoredTrainingSessionTest(test.TestCase):
|
|||||||
logdir = _test_dir(self.get_temp_dir(), 'test_saving_restoring_checkpoint')
|
logdir = _test_dir(self.get_temp_dir(), 'test_saving_restoring_checkpoint')
|
||||||
fake_hook = FakeHook()
|
fake_hook = FakeHook()
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
do_step = state_ops.assign_add(gstep, 1)
|
do_step = state_ops.assign_add(gstep, 1)
|
||||||
with monitored_session.MonitoredTrainingSession(
|
with monitored_session.MonitoredTrainingSession(
|
||||||
is_chief=True,
|
is_chief=True,
|
||||||
@ -414,7 +414,7 @@ class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
|
|||||||
|
|
||||||
logdir = _test_dir(self.get_temp_dir(), 'test_summaries_enabled')
|
logdir = _test_dir(self.get_temp_dir(), 'test_summaries_enabled')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
new_gstep = state_ops.assign_add(gstep, 1)
|
new_gstep = state_ops.assign_add(gstep, 1)
|
||||||
summary.scalar('my_summary_tag', new_gstep * 2)
|
summary.scalar('my_summary_tag', new_gstep * 2)
|
||||||
with context, monitored_session.MonitoredTrainingSession(
|
with context, monitored_session.MonitoredTrainingSession(
|
||||||
@ -435,7 +435,7 @@ class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
|
|||||||
|
|
||||||
logdir = _test_dir(self.get_temp_dir(), 'test_summaries_disabled')
|
logdir = _test_dir(self.get_temp_dir(), 'test_summaries_disabled')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
new_gstep = state_ops.assign_add(gstep, 1)
|
new_gstep = state_ops.assign_add(gstep, 1)
|
||||||
summary.scalar('my_summary_tag', new_gstep * 2)
|
summary.scalar('my_summary_tag', new_gstep * 2)
|
||||||
with context, monitored_session.MonitoredTrainingSession(
|
with context, monitored_session.MonitoredTrainingSession(
|
||||||
@ -455,7 +455,7 @@ class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
|
|||||||
|
|
||||||
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_enabled')
|
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_enabled')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
new_gstep = state_ops.assign_add(gstep, 1)
|
new_gstep = state_ops.assign_add(gstep, 1)
|
||||||
with context, monitored_session.MonitoredTrainingSession(
|
with context, monitored_session.MonitoredTrainingSession(
|
||||||
checkpoint_dir=logdir,
|
checkpoint_dir=logdir,
|
||||||
@ -475,7 +475,7 @@ class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
|
|||||||
|
|
||||||
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled')
|
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
new_gstep = state_ops.assign_add(gstep, 1)
|
new_gstep = state_ops.assign_add(gstep, 1)
|
||||||
with context, monitored_session.MonitoredTrainingSession(
|
with context, monitored_session.MonitoredTrainingSession(
|
||||||
checkpoint_dir=logdir,
|
checkpoint_dir=logdir,
|
||||||
@ -496,7 +496,7 @@ class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
|
|||||||
|
|
||||||
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled')
|
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
new_gstep = state_ops.assign_add(gstep, 1)
|
new_gstep = state_ops.assign_add(gstep, 1)
|
||||||
with context, monitored_session.MonitoredTrainingSession(
|
with context, monitored_session.MonitoredTrainingSession(
|
||||||
checkpoint_dir=logdir,
|
checkpoint_dir=logdir,
|
||||||
@ -1430,7 +1430,7 @@ class MonitoredSessionTest(test.TestCase):
|
|||||||
def test_last_step(self):
|
def test_last_step(self):
|
||||||
logdir = _test_dir(self.get_temp_dir(), 'test_last_step')
|
logdir = _test_dir(self.get_temp_dir(), 'test_last_step')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
do_step = state_ops.assign_add(gstep, 1)
|
do_step = state_ops.assign_add(gstep, 1)
|
||||||
# Run till step 3 and save.
|
# Run till step 3 and save.
|
||||||
hooks = [basic_session_run_hooks.StopAtStepHook(last_step=3)]
|
hooks = [basic_session_run_hooks.StopAtStepHook(last_step=3)]
|
||||||
@ -1465,7 +1465,7 @@ class MonitoredSessionTest(test.TestCase):
|
|||||||
def test_num_steps(self):
|
def test_num_steps(self):
|
||||||
logdir = _test_dir(self.get_temp_dir(), 'test_num_steps')
|
logdir = _test_dir(self.get_temp_dir(), 'test_num_steps')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
do_step = state_ops.assign_add(gstep, 1)
|
do_step = state_ops.assign_add(gstep, 1)
|
||||||
# Do 3 steps and save.
|
# Do 3 steps and save.
|
||||||
hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=3)]
|
hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=3)]
|
||||||
@ -1504,7 +1504,7 @@ class MonitoredSessionTest(test.TestCase):
|
|||||||
def test_recovery(self):
|
def test_recovery(self):
|
||||||
logdir = _test_dir(self.get_temp_dir(), 'test_recovery')
|
logdir = _test_dir(self.get_temp_dir(), 'test_recovery')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
do_step = state_ops.assign_add(gstep, 1)
|
do_step = state_ops.assign_add(gstep, 1)
|
||||||
scaffold = monitored_session.Scaffold()
|
scaffold = monitored_session.Scaffold()
|
||||||
# Use a hook to save the model every 100 steps. It also saves it at
|
# Use a hook to save the model every 100 steps. It also saves it at
|
||||||
@ -1536,7 +1536,7 @@ class MonitoredSessionTest(test.TestCase):
|
|||||||
def test_retry_initialization_on_aborted_error(self):
|
def test_retry_initialization_on_aborted_error(self):
|
||||||
# Tests that we silently retry on abort during initialization.
|
# Tests that we silently retry on abort during initialization.
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
self.init_raised_aborted_error = False
|
self.init_raised_aborted_error = False
|
||||||
|
|
||||||
def _init_fn(scaffold, session):
|
def _init_fn(scaffold, session):
|
||||||
@ -1557,7 +1557,7 @@ class MonitoredSessionTest(test.TestCase):
|
|||||||
# Tests that we silently retry on error. Note that this does not test
|
# Tests that we silently retry on error. Note that this does not test
|
||||||
# recovery as we do not use a CheckpointSaver in this test.
|
# recovery as we do not use a CheckpointSaver in this test.
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
do_step = state_ops.assign_add(gstep, 1)
|
do_step = state_ops.assign_add(gstep, 1)
|
||||||
hook = RaiseOnceAtCountN(4, ex)
|
hook = RaiseOnceAtCountN(4, ex)
|
||||||
with monitored_session.MonitoredSession(hooks=[hook]) as session:
|
with monitored_session.MonitoredSession(hooks=[hook]) as session:
|
||||||
@ -1587,7 +1587,7 @@ class MonitoredSessionTest(test.TestCase):
|
|||||||
logdir = _test_dir(self.get_temp_dir(),
|
logdir = _test_dir(self.get_temp_dir(),
|
||||||
'test_recover_and_retry_on_aborted_error')
|
'test_recover_and_retry_on_aborted_error')
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
do_step = state_ops.assign_add(gstep, 1)
|
do_step = state_ops.assign_add(gstep, 1)
|
||||||
scaffold = monitored_session.Scaffold()
|
scaffold = monitored_session.Scaffold()
|
||||||
abort_hook = RaiseOnceAtCountN(
|
abort_hook = RaiseOnceAtCountN(
|
||||||
@ -1615,7 +1615,7 @@ class MonitoredSessionTest(test.TestCase):
|
|||||||
def test_exit_cleanly_on_out_of_range_exception(self):
|
def test_exit_cleanly_on_out_of_range_exception(self):
|
||||||
# Tests that we stop cleanly when OutOfRange is raised.
|
# Tests that we stop cleanly when OutOfRange is raised.
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
do_step = state_ops.assign_add(gstep, 1)
|
do_step = state_ops.assign_add(gstep, 1)
|
||||||
hook = RaiseOnceAtCountN(2, errors_impl.OutOfRangeError(None, None,
|
hook = RaiseOnceAtCountN(2, errors_impl.OutOfRangeError(None, None,
|
||||||
'EOI'))
|
'EOI'))
|
||||||
@ -1634,7 +1634,7 @@ class MonitoredSessionTest(test.TestCase):
|
|||||||
def test_exit_cleanly_on_stop_iteration_exception(self):
|
def test_exit_cleanly_on_stop_iteration_exception(self):
|
||||||
# Tests that we stop cleanly when OutOfRange is raised.
|
# Tests that we stop cleanly when OutOfRange is raised.
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
do_step = state_ops.assign_add(gstep, 1)
|
do_step = state_ops.assign_add(gstep, 1)
|
||||||
hook = RaiseOnceAtCountN(2, StopIteration)
|
hook = RaiseOnceAtCountN(2, StopIteration)
|
||||||
session = monitored_session.MonitoredSession(hooks=[hook])
|
session = monitored_session.MonitoredSession(hooks=[hook])
|
||||||
@ -1653,7 +1653,7 @@ class MonitoredSessionTest(test.TestCase):
|
|||||||
# Tests that regular exceptions just pass through a "with
|
# Tests that regular exceptions just pass through a "with
|
||||||
# MonitoredSession" block and set the session in stop mode.
|
# MonitoredSession" block and set the session in stop mode.
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
do_step = state_ops.assign_add(gstep, 1)
|
do_step = state_ops.assign_add(gstep, 1)
|
||||||
hook = RaiseOnceAtCountN(4, RuntimeError('regular exception'))
|
hook = RaiseOnceAtCountN(4, RuntimeError('regular exception'))
|
||||||
session = monitored_session.MonitoredSession(hooks=[hook])
|
session = monitored_session.MonitoredSession(hooks=[hook])
|
||||||
@ -1675,7 +1675,7 @@ class MonitoredSessionTest(test.TestCase):
|
|||||||
# passes through a "run()" call within a "with MonitoredSession" block and
|
# passes through a "run()" call within a "with MonitoredSession" block and
|
||||||
# set the session in stop mode.
|
# set the session in stop mode.
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
session = monitored_session.MonitoredSession()
|
session = monitored_session.MonitoredSession()
|
||||||
run_performed_without_error = False
|
run_performed_without_error = False
|
||||||
with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'):
|
with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'):
|
||||||
@ -1696,7 +1696,7 @@ class MonitoredSessionTest(test.TestCase):
|
|||||||
# passes through returning from a "with MonitoredSession" block and
|
# passes through returning from a "with MonitoredSession" block and
|
||||||
# set the session in stop mode.
|
# set the session in stop mode.
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
session = monitored_session.MonitoredSession()
|
session = monitored_session.MonitoredSession()
|
||||||
with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'):
|
with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'):
|
||||||
with session:
|
with session:
|
||||||
@ -1714,7 +1714,7 @@ class MonitoredSessionTest(test.TestCase):
|
|||||||
def test_stop_cleanly_when_no_exception_in_with_body(self):
|
def test_stop_cleanly_when_no_exception_in_with_body(self):
|
||||||
# Tests that regular exceptions pass through
|
# Tests that regular exceptions pass through
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
do_step = state_ops.assign_add(gstep, 1)
|
do_step = state_ops.assign_add(gstep, 1)
|
||||||
session = monitored_session.MonitoredSession()
|
session = monitored_session.MonitoredSession()
|
||||||
with session:
|
with session:
|
||||||
@ -1728,7 +1728,7 @@ class MonitoredSessionTest(test.TestCase):
|
|||||||
def test_raises_regular_exceptions_in_with_body(self):
|
def test_raises_regular_exceptions_in_with_body(self):
|
||||||
# Tests that regular exceptions in "with body" are seen outside.
|
# Tests that regular exceptions in "with body" are seen outside.
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
do_step = state_ops.assign_add(gstep, 1)
|
do_step = state_ops.assign_add(gstep, 1)
|
||||||
session = monitored_session.MonitoredSession()
|
session = monitored_session.MonitoredSession()
|
||||||
# We should see that exception.
|
# We should see that exception.
|
||||||
@ -2184,7 +2184,7 @@ class SingularMonitoredSessionTest(test.TestCase):
|
|||||||
|
|
||||||
def test_do_not_handle_aborted_error(self):
|
def test_do_not_handle_aborted_error(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
|
|
||||||
class _RaiseAbortedHook(session_run_hook.SessionRunHook):
|
class _RaiseAbortedHook(session_run_hook.SessionRunHook):
|
||||||
|
|
||||||
@ -2204,7 +2204,7 @@ class SingularMonitoredSessionTest(test.TestCase):
|
|||||||
def test_exit_cleanly_on_out_of_range_exception(self):
|
def test_exit_cleanly_on_out_of_range_exception(self):
|
||||||
# Tests that we stop cleanly when OutOfRange is raised.
|
# Tests that we stop cleanly when OutOfRange is raised.
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
do_step = state_ops.assign_add(gstep, 1)
|
do_step = state_ops.assign_add(gstep, 1)
|
||||||
hook = RaiseOnceAtCountN(2, errors_impl.OutOfRangeError(None, None,
|
hook = RaiseOnceAtCountN(2, errors_impl.OutOfRangeError(None, None,
|
||||||
'EOI'))
|
'EOI'))
|
||||||
@ -2225,7 +2225,7 @@ class SingularMonitoredSessionTest(test.TestCase):
|
|||||||
# passes through a "run()" call within a "with MonitoredSession" block and
|
# passes through a "run()" call within a "with MonitoredSession" block and
|
||||||
# set the session in stop mode.
|
# set the session in stop mode.
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
session = monitored_session.SingularMonitoredSession()
|
session = monitored_session.SingularMonitoredSession()
|
||||||
run_performed_without_error = False
|
run_performed_without_error = False
|
||||||
with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'):
|
with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'):
|
||||||
@ -2244,7 +2244,7 @@ class SingularMonitoredSessionTest(test.TestCase):
|
|||||||
def test_stop_cleanly_when_no_exception_in_with_body(self):
|
def test_stop_cleanly_when_no_exception_in_with_body(self):
|
||||||
# Tests that regular exceptions pass through
|
# Tests that regular exceptions pass through
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
gstep = variables_lib.get_or_create_global_step()
|
gstep = training_util.get_or_create_global_step()
|
||||||
do_step = state_ops.assign_add(gstep, 1)
|
do_step = state_ops.assign_add(gstep, 1)
|
||||||
session = monitored_session.SingularMonitoredSession()
|
session = monitored_session.SingularMonitoredSession()
|
||||||
with session:
|
with session:
|
||||||
|
Loading…
Reference in New Issue
Block a user