Break all remaining rependencies in tensorflow/python/training on

contrib/.../framework_py

PiperOrigin-RevId: 268579788
This commit is contained in:
Gunhan Gulsoy 2019-09-11 17:36:06 -07:00 committed by TensorFlower Gardener
parent 53d31e990f
commit 6cb5fb444c
3 changed files with 50 additions and 53 deletions

View File

@ -5916,7 +5916,6 @@ tf_py_test(
":training", ":training",
":variable_scope", ":variable_scope",
":variables", ":variables",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/testing:testing_py", "//tensorflow/contrib/testing:testing_py",
"//tensorflow/core:protos_all_py", "//tensorflow/core:protos_all_py",
], ],
@ -6009,7 +6008,6 @@ tf_py_test(
":summary", ":summary",
":training", ":training",
":variables", ":variables",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/testing:testing_py", "//tensorflow/contrib/testing:testing_py",
"//tensorflow/core:protos_all_py", "//tensorflow/core:protos_all_py",
"//tensorflow/python/distribute:collective_all_reduce_strategy", "//tensorflow/python/distribute:collective_all_reduce_strategy",

View File

@ -24,8 +24,6 @@ import shutil
import tempfile import tempfile
import time import time
from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.testing.python.framework import fake_summary_writer from tensorflow.contrib.testing.python.framework import fake_summary_writer
from tensorflow.python.client import session as session_lib from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
@ -47,6 +45,7 @@ from tensorflow.python.platform import tf_logging
from tensorflow.python.summary import summary as summary_lib from tensorflow.python.summary import summary as summary_lib
from tensorflow.python.summary.writer import writer_cache from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import monitored_session from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_run_hook from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util from tensorflow.python.training import training_util
@ -151,7 +150,7 @@ class StopAtStepTest(test.TestCase):
def test_stop_based_on_last_step(self): def test_stop_based_on_last_step(self):
h = basic_session_run_hooks.StopAtStepHook(last_step=10) h = basic_session_run_hooks.StopAtStepHook(last_step=10)
with ops.Graph().as_default(): with ops.Graph().as_default():
global_step = variables.get_or_create_global_step() global_step = training_util.get_or_create_global_step()
no_op = control_flow_ops.no_op() no_op = control_flow_ops.no_op()
h.begin() h.begin()
with session_lib.Session() as sess: with session_lib.Session() as sess:
@ -175,7 +174,7 @@ class StopAtStepTest(test.TestCase):
h = basic_session_run_hooks.StopAtStepHook(num_steps=10) h = basic_session_run_hooks.StopAtStepHook(num_steps=10)
with ops.Graph().as_default(): with ops.Graph().as_default():
global_step = variables.get_or_create_global_step() global_step = training_util.get_or_create_global_step()
no_op = control_flow_ops.no_op() no_op = control_flow_ops.no_op()
h.begin() h.begin()
with session_lib.Session() as sess: with session_lib.Session() as sess:
@ -202,7 +201,7 @@ class StopAtStepTest(test.TestCase):
h = basic_session_run_hooks.StopAtStepHook(num_steps=10) h = basic_session_run_hooks.StopAtStepHook(num_steps=10)
with ops.Graph().as_default(): with ops.Graph().as_default():
global_step = variables.get_or_create_global_step() global_step = training_util.get_or_create_global_step()
no_op = control_flow_ops.no_op() no_op = control_flow_ops.no_op()
h.begin() h.begin()
with session_lib.Session() as sess: with session_lib.Session() as sess:
@ -391,7 +390,7 @@ class CheckpointSaverHookTest(test.TestCase):
self.graph = ops.Graph() self.graph = ops.Graph()
with self.graph.as_default(): with self.graph.as_default():
self.scaffold = monitored_session.Scaffold() self.scaffold = monitored_session.Scaffold()
self.global_step = variables.get_or_create_global_step() self.global_step = training_util.get_or_create_global_step()
self.train_op = training_util._increment_global_step(1) self.train_op = training_util._increment_global_step(1)
def tearDown(self): def tearDown(self):
@ -467,7 +466,7 @@ class CheckpointSaverHookTest(test.TestCase):
def test_listener_with_monitored_session(self): def test_listener_with_monitored_session(self):
with ops.Graph().as_default(): with ops.Graph().as_default():
scaffold = monitored_session.Scaffold() scaffold = monitored_session.Scaffold()
global_step = variables.get_or_create_global_step() global_step = training_util.get_or_create_global_step()
train_op = training_util._increment_global_step(1) train_op = training_util._increment_global_step(1)
listener = MockCheckpointSaverListener() listener = MockCheckpointSaverListener()
hook = basic_session_run_hooks.CheckpointSaverHook( hook = basic_session_run_hooks.CheckpointSaverHook(
@ -494,7 +493,7 @@ class CheckpointSaverHookTest(test.TestCase):
def test_listener_stops_training_in_after_save(self): def test_listener_stops_training_in_after_save(self):
with ops.Graph().as_default(): with ops.Graph().as_default():
scaffold = monitored_session.Scaffold() scaffold = monitored_session.Scaffold()
variables.get_or_create_global_step() training_util.get_or_create_global_step()
train_op = training_util._increment_global_step(1) train_op = training_util._increment_global_step(1)
listener = MockCheckpointSaverListener() listener = MockCheckpointSaverListener()
hook = basic_session_run_hooks.CheckpointSaverHook( hook = basic_session_run_hooks.CheckpointSaverHook(
@ -512,7 +511,7 @@ class CheckpointSaverHookTest(test.TestCase):
def test_listener_with_default_saver(self): def test_listener_with_default_saver(self):
with ops.Graph().as_default(): with ops.Graph().as_default():
global_step = variables.get_or_create_global_step() global_step = training_util.get_or_create_global_step()
train_op = training_util._increment_global_step(1) train_op = training_util._increment_global_step(1)
listener = MockCheckpointSaverListener() listener = MockCheckpointSaverListener()
hook = basic_session_run_hooks.CheckpointSaverHook( hook = basic_session_run_hooks.CheckpointSaverHook(
@ -535,7 +534,7 @@ class CheckpointSaverHookTest(test.TestCase):
}, listener_counts) }, listener_counts)
with ops.Graph().as_default(): with ops.Graph().as_default():
global_step = variables.get_or_create_global_step() global_step = training_util.get_or_create_global_step()
with monitored_session.SingularMonitoredSession( with monitored_session.SingularMonitoredSession(
checkpoint_dir=self.model_dir) as sess2: checkpoint_dir=self.model_dir) as sess2:
global_step_saved_val = sess2.run(global_step) global_step_saved_val = sess2.run(global_step)
@ -543,7 +542,7 @@ class CheckpointSaverHookTest(test.TestCase):
def test_two_listeners_with_default_saver(self): def test_two_listeners_with_default_saver(self):
with ops.Graph().as_default(): with ops.Graph().as_default():
global_step = variables.get_or_create_global_step() global_step = training_util.get_or_create_global_step()
train_op = training_util._increment_global_step(1) train_op = training_util._increment_global_step(1)
listener1 = MockCheckpointSaverListener() listener1 = MockCheckpointSaverListener()
listener2 = MockCheckpointSaverListener() listener2 = MockCheckpointSaverListener()
@ -569,7 +568,7 @@ class CheckpointSaverHookTest(test.TestCase):
self.assertEqual(listener1_counts, listener2_counts) self.assertEqual(listener1_counts, listener2_counts)
with ops.Graph().as_default(): with ops.Graph().as_default():
global_step = variables.get_or_create_global_step() global_step = training_util.get_or_create_global_step()
with monitored_session.SingularMonitoredSession( with monitored_session.SingularMonitoredSession(
checkpoint_dir=self.model_dir) as sess2: checkpoint_dir=self.model_dir) as sess2:
global_step_saved_val = sess2.run(global_step) global_step_saved_val = sess2.run(global_step)
@ -786,7 +785,7 @@ class CheckpointSaverHookMultiStepTest(test.TestCase):
self.steps_per_run = 5 self.steps_per_run = 5
with self.graph.as_default(): with self.graph.as_default():
self.scaffold = monitored_session.Scaffold() self.scaffold = monitored_session.Scaffold()
self.global_step = variables.get_or_create_global_step() self.global_step = training_util.get_or_create_global_step()
self.train_op = training_util._increment_global_step(self.steps_per_run) self.train_op = training_util._increment_global_step(self.steps_per_run)
def tearDown(self): def tearDown(self):
@ -926,7 +925,7 @@ class StepCounterHookTest(test.TestCase):
def test_step_counter_every_n_steps(self, mock_time): def test_step_counter_every_n_steps(self, mock_time):
mock_time.return_value = MOCK_START_TIME mock_time.return_value = MOCK_START_TIME
with ops.Graph().as_default() as g, session_lib.Session() as sess: with ops.Graph().as_default() as g, session_lib.Session() as sess:
variables.get_or_create_global_step() training_util.get_or_create_global_step()
train_op = training_util._increment_global_step(1) train_op = training_util._increment_global_step(1)
summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g) summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
hook = basic_session_run_hooks.StepCounterHook( hook = basic_session_run_hooks.StepCounterHook(
@ -956,7 +955,7 @@ class StepCounterHookTest(test.TestCase):
def test_step_counter_every_n_secs(self, mock_time): def test_step_counter_every_n_secs(self, mock_time):
mock_time.return_value = MOCK_START_TIME mock_time.return_value = MOCK_START_TIME
with ops.Graph().as_default() as g, session_lib.Session() as sess: with ops.Graph().as_default() as g, session_lib.Session() as sess:
variables.get_or_create_global_step() training_util.get_or_create_global_step()
train_op = training_util._increment_global_step(1) train_op = training_util._increment_global_step(1)
summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g) summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
hook = basic_session_run_hooks.StepCounterHook( hook = basic_session_run_hooks.StepCounterHook(
@ -1018,7 +1017,7 @@ class StepCounterHookTest(test.TestCase):
def test_log_warning_if_global_step_not_increased(self): def test_log_warning_if_global_step_not_increased(self):
with ops.Graph().as_default(), session_lib.Session() as sess: with ops.Graph().as_default(), session_lib.Session() as sess:
variables.get_or_create_global_step() training_util.get_or_create_global_step()
train_op = training_util._increment_global_step(0) # keep same. train_op = training_util._increment_global_step(0) # keep same.
self.evaluate(variables_lib.global_variables_initializer()) self.evaluate(variables_lib.global_variables_initializer())
hook = basic_session_run_hooks.StepCounterHook( hook = basic_session_run_hooks.StepCounterHook(
@ -1039,7 +1038,7 @@ class StepCounterHookTest(test.TestCase):
steps_per_run, steps_per_run,
graph, graph,
sess): sess):
variables.get_or_create_global_step() training_util.get_or_create_global_step()
self.train_op = training_util._increment_global_step(steps_per_run) self.train_op = training_util._increment_global_step(steps_per_run)
self.summary_writer = fake_summary_writer.FakeSummaryWriter( self.summary_writer = fake_summary_writer.FakeSummaryWriter(
self.log_dir, graph) self.log_dir, graph)
@ -1137,7 +1136,7 @@ class SummarySaverHookTest(test.TestCase):
self.summary_op = summary_lib.scalar('my_summary', tensor) self.summary_op = summary_lib.scalar('my_summary', tensor)
self.summary_op2 = summary_lib.scalar('my_summary2', tensor2) self.summary_op2 = summary_lib.scalar('my_summary2', tensor2)
variables.get_or_create_global_step() training_util.get_or_create_global_step()
self.train_op = training_util._increment_global_step(1) self.train_op = training_util._increment_global_step(1)
def test_raise_when_scaffold_and_summary_op_both_missing(self): def test_raise_when_scaffold_and_summary_op_both_missing(self):
@ -1292,7 +1291,7 @@ class GlobalStepWaiterHookTest(test.TestCase):
def test_not_wait_for_step_zero(self): def test_not_wait_for_step_zero(self):
with ops.Graph().as_default(): with ops.Graph().as_default():
variables.get_or_create_global_step() training_util.get_or_create_global_step()
hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=0) hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=0)
hook.begin() hook.begin()
with session_lib.Session() as sess: with session_lib.Session() as sess:
@ -1304,7 +1303,7 @@ class GlobalStepWaiterHookTest(test.TestCase):
@test.mock.patch.object(time, 'sleep') @test.mock.patch.object(time, 'sleep')
def test_wait_for_step(self, mock_sleep): def test_wait_for_step(self, mock_sleep):
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=1000) hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=1000)
hook.begin() hook.begin()
@ -1418,7 +1417,7 @@ class ResourceSummarySaverHookTest(test.TestCase):
self.summary_op = summary_lib.scalar('my_summary', tensor) self.summary_op = summary_lib.scalar('my_summary', tensor)
with variable_scope.variable_scope('foo', use_resource=True): with variable_scope.variable_scope('foo', use_resource=True):
variables.create_global_step() training_util.create_global_step()
self.train_op = training_util._increment_global_step(1) self.train_op = training_util._increment_global_step(1)
def test_save_steps(self): def test_save_steps(self):
@ -1475,7 +1474,7 @@ class ProfilerHookTest(test.TestCase):
self.graph = ops.Graph() self.graph = ops.Graph()
self.filepattern = os.path.join(self.output_dir, 'timeline-*.json') self.filepattern = os.path.join(self.output_dir, 'timeline-*.json')
with self.graph.as_default(): with self.graph.as_default():
self.global_step = variables.get_or_create_global_step() self.global_step = training_util.get_or_create_global_step()
self.train_op = state_ops.assign_add(self.global_step, 1) self.train_op = state_ops.assign_add(self.global_step, 1)
def tearDown(self): def tearDown(self):

View File

@ -27,7 +27,6 @@ import threading
import time import time
import traceback import traceback
from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.testing.python.framework import util_test from tensorflow.contrib.testing.python.framework import util_test
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import debug_pb2 from tensorflow.core.protobuf import debug_pb2
@ -52,6 +51,7 @@ from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
class ScaffoldTest(test.TestCase): class ScaffoldTest(test.TestCase):
@ -274,7 +274,7 @@ class MonitoredTrainingSessionTest(test.TestCase):
def test_saving_restoring_checkpoint(self): def test_saving_restoring_checkpoint(self):
logdir = _test_dir(self.get_temp_dir(), 'test_saving_restoring_checkpoint') logdir = _test_dir(self.get_temp_dir(), 'test_saving_restoring_checkpoint')
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
do_step = state_ops.assign_add(gstep, 1) do_step = state_ops.assign_add(gstep, 1)
with monitored_session.MonitoredTrainingSession( with monitored_session.MonitoredTrainingSession(
is_chief=True, checkpoint_dir=logdir) as session: is_chief=True, checkpoint_dir=logdir) as session:
@ -289,7 +289,7 @@ class MonitoredTrainingSessionTest(test.TestCase):
def test_save_checkpoint_steps(self): def test_save_checkpoint_steps(self):
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_steps') logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_steps')
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
new_gstep = state_ops.assign_add(gstep, 1) new_gstep = state_ops.assign_add(gstep, 1)
with monitored_session.MonitoredTrainingSession( with monitored_session.MonitoredTrainingSession(
is_chief=True, is_chief=True,
@ -306,7 +306,7 @@ class MonitoredTrainingSessionTest(test.TestCase):
def test_save_checkpoint_secs(self): def test_save_checkpoint_secs(self):
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_secs') logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_secs')
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
new_gstep = state_ops.assign_add(gstep, 1) new_gstep = state_ops.assign_add(gstep, 1)
with monitored_session.MonitoredTrainingSession( with monitored_session.MonitoredTrainingSession(
is_chief=True, is_chief=True,
@ -325,7 +325,7 @@ class MonitoredTrainingSessionTest(test.TestCase):
def test_summaries_steps(self): def test_summaries_steps(self):
logdir = _test_dir(self.get_temp_dir(), 'test_summaries_steps') logdir = _test_dir(self.get_temp_dir(), 'test_summaries_steps')
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
new_gstep = state_ops.assign_add(gstep, 1) new_gstep = state_ops.assign_add(gstep, 1)
summary.scalar('my_summary_tag', new_gstep * 2) summary.scalar('my_summary_tag', new_gstep * 2)
with monitored_session.MonitoredTrainingSession( with monitored_session.MonitoredTrainingSession(
@ -343,7 +343,7 @@ class MonitoredTrainingSessionTest(test.TestCase):
def test_summaries_secs(self): def test_summaries_secs(self):
logdir = _test_dir(self.get_temp_dir(), 'test_summaries_secs') logdir = _test_dir(self.get_temp_dir(), 'test_summaries_secs')
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
new_gstep = state_ops.assign_add(gstep, 1) new_gstep = state_ops.assign_add(gstep, 1)
summary.scalar('my_summary_tag', new_gstep * 2) summary.scalar('my_summary_tag', new_gstep * 2)
with monitored_session.MonitoredTrainingSession( with monitored_session.MonitoredTrainingSession(
@ -365,7 +365,7 @@ class MonitoredTrainingSessionTest(test.TestCase):
logdir = _test_dir(self.get_temp_dir(), 'test_saving_restoring_checkpoint') logdir = _test_dir(self.get_temp_dir(), 'test_saving_restoring_checkpoint')
fake_hook = FakeHook() fake_hook = FakeHook()
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
do_step = state_ops.assign_add(gstep, 1) do_step = state_ops.assign_add(gstep, 1)
with monitored_session.MonitoredTrainingSession( with monitored_session.MonitoredTrainingSession(
is_chief=True, is_chief=True,
@ -414,7 +414,7 @@ class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
logdir = _test_dir(self.get_temp_dir(), 'test_summaries_enabled') logdir = _test_dir(self.get_temp_dir(), 'test_summaries_enabled')
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
new_gstep = state_ops.assign_add(gstep, 1) new_gstep = state_ops.assign_add(gstep, 1)
summary.scalar('my_summary_tag', new_gstep * 2) summary.scalar('my_summary_tag', new_gstep * 2)
with context, monitored_session.MonitoredTrainingSession( with context, monitored_session.MonitoredTrainingSession(
@ -435,7 +435,7 @@ class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
logdir = _test_dir(self.get_temp_dir(), 'test_summaries_disabled') logdir = _test_dir(self.get_temp_dir(), 'test_summaries_disabled')
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
new_gstep = state_ops.assign_add(gstep, 1) new_gstep = state_ops.assign_add(gstep, 1)
summary.scalar('my_summary_tag', new_gstep * 2) summary.scalar('my_summary_tag', new_gstep * 2)
with context, monitored_session.MonitoredTrainingSession( with context, monitored_session.MonitoredTrainingSession(
@ -455,7 +455,7 @@ class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_enabled') logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_enabled')
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
new_gstep = state_ops.assign_add(gstep, 1) new_gstep = state_ops.assign_add(gstep, 1)
with context, monitored_session.MonitoredTrainingSession( with context, monitored_session.MonitoredTrainingSession(
checkpoint_dir=logdir, checkpoint_dir=logdir,
@ -475,7 +475,7 @@ class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled') logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled')
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
new_gstep = state_ops.assign_add(gstep, 1) new_gstep = state_ops.assign_add(gstep, 1)
with context, monitored_session.MonitoredTrainingSession( with context, monitored_session.MonitoredTrainingSession(
checkpoint_dir=logdir, checkpoint_dir=logdir,
@ -496,7 +496,7 @@ class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled') logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled')
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
new_gstep = state_ops.assign_add(gstep, 1) new_gstep = state_ops.assign_add(gstep, 1)
with context, monitored_session.MonitoredTrainingSession( with context, monitored_session.MonitoredTrainingSession(
checkpoint_dir=logdir, checkpoint_dir=logdir,
@ -1430,7 +1430,7 @@ class MonitoredSessionTest(test.TestCase):
def test_last_step(self): def test_last_step(self):
logdir = _test_dir(self.get_temp_dir(), 'test_last_step') logdir = _test_dir(self.get_temp_dir(), 'test_last_step')
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
do_step = state_ops.assign_add(gstep, 1) do_step = state_ops.assign_add(gstep, 1)
# Run till step 3 and save. # Run till step 3 and save.
hooks = [basic_session_run_hooks.StopAtStepHook(last_step=3)] hooks = [basic_session_run_hooks.StopAtStepHook(last_step=3)]
@ -1465,7 +1465,7 @@ class MonitoredSessionTest(test.TestCase):
def test_num_steps(self): def test_num_steps(self):
logdir = _test_dir(self.get_temp_dir(), 'test_num_steps') logdir = _test_dir(self.get_temp_dir(), 'test_num_steps')
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
do_step = state_ops.assign_add(gstep, 1) do_step = state_ops.assign_add(gstep, 1)
# Do 3 steps and save. # Do 3 steps and save.
hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=3)] hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=3)]
@ -1504,7 +1504,7 @@ class MonitoredSessionTest(test.TestCase):
def test_recovery(self): def test_recovery(self):
logdir = _test_dir(self.get_temp_dir(), 'test_recovery') logdir = _test_dir(self.get_temp_dir(), 'test_recovery')
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
do_step = state_ops.assign_add(gstep, 1) do_step = state_ops.assign_add(gstep, 1)
scaffold = monitored_session.Scaffold() scaffold = monitored_session.Scaffold()
# Use a hook to save the model every 100 steps. It also saves it at # Use a hook to save the model every 100 steps. It also saves it at
@ -1536,7 +1536,7 @@ class MonitoredSessionTest(test.TestCase):
def test_retry_initialization_on_aborted_error(self): def test_retry_initialization_on_aborted_error(self):
# Tests that we silently retry on abort during initialization. # Tests that we silently retry on abort during initialization.
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
self.init_raised_aborted_error = False self.init_raised_aborted_error = False
def _init_fn(scaffold, session): def _init_fn(scaffold, session):
@ -1557,7 +1557,7 @@ class MonitoredSessionTest(test.TestCase):
# Tests that we silently retry on error. Note that this does not test # Tests that we silently retry on error. Note that this does not test
# recovery as we do not use a CheckpointSaver in this test. # recovery as we do not use a CheckpointSaver in this test.
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
do_step = state_ops.assign_add(gstep, 1) do_step = state_ops.assign_add(gstep, 1)
hook = RaiseOnceAtCountN(4, ex) hook = RaiseOnceAtCountN(4, ex)
with monitored_session.MonitoredSession(hooks=[hook]) as session: with monitored_session.MonitoredSession(hooks=[hook]) as session:
@ -1587,7 +1587,7 @@ class MonitoredSessionTest(test.TestCase):
logdir = _test_dir(self.get_temp_dir(), logdir = _test_dir(self.get_temp_dir(),
'test_recover_and_retry_on_aborted_error') 'test_recover_and_retry_on_aborted_error')
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
do_step = state_ops.assign_add(gstep, 1) do_step = state_ops.assign_add(gstep, 1)
scaffold = monitored_session.Scaffold() scaffold = monitored_session.Scaffold()
abort_hook = RaiseOnceAtCountN( abort_hook = RaiseOnceAtCountN(
@ -1615,7 +1615,7 @@ class MonitoredSessionTest(test.TestCase):
def test_exit_cleanly_on_out_of_range_exception(self): def test_exit_cleanly_on_out_of_range_exception(self):
# Tests that we stop cleanly when OutOfRange is raised. # Tests that we stop cleanly when OutOfRange is raised.
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
do_step = state_ops.assign_add(gstep, 1) do_step = state_ops.assign_add(gstep, 1)
hook = RaiseOnceAtCountN(2, errors_impl.OutOfRangeError(None, None, hook = RaiseOnceAtCountN(2, errors_impl.OutOfRangeError(None, None,
'EOI')) 'EOI'))
@ -1634,7 +1634,7 @@ class MonitoredSessionTest(test.TestCase):
def test_exit_cleanly_on_stop_iteration_exception(self): def test_exit_cleanly_on_stop_iteration_exception(self):
# Tests that we stop cleanly when OutOfRange is raised. # Tests that we stop cleanly when OutOfRange is raised.
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
do_step = state_ops.assign_add(gstep, 1) do_step = state_ops.assign_add(gstep, 1)
hook = RaiseOnceAtCountN(2, StopIteration) hook = RaiseOnceAtCountN(2, StopIteration)
session = monitored_session.MonitoredSession(hooks=[hook]) session = monitored_session.MonitoredSession(hooks=[hook])
@ -1653,7 +1653,7 @@ class MonitoredSessionTest(test.TestCase):
# Tests that regular exceptions just pass through a "with # Tests that regular exceptions just pass through a "with
# MonitoredSession" block and set the session in stop mode. # MonitoredSession" block and set the session in stop mode.
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
do_step = state_ops.assign_add(gstep, 1) do_step = state_ops.assign_add(gstep, 1)
hook = RaiseOnceAtCountN(4, RuntimeError('regular exception')) hook = RaiseOnceAtCountN(4, RuntimeError('regular exception'))
session = monitored_session.MonitoredSession(hooks=[hook]) session = monitored_session.MonitoredSession(hooks=[hook])
@ -1675,7 +1675,7 @@ class MonitoredSessionTest(test.TestCase):
# passes through a "run()" call within a "with MonitoredSession" block and # passes through a "run()" call within a "with MonitoredSession" block and
# set the session in stop mode. # set the session in stop mode.
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
session = monitored_session.MonitoredSession() session = monitored_session.MonitoredSession()
run_performed_without_error = False run_performed_without_error = False
with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'): with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'):
@ -1696,7 +1696,7 @@ class MonitoredSessionTest(test.TestCase):
# passes through returning from a "with MonitoredSession" block and # passes through returning from a "with MonitoredSession" block and
# set the session in stop mode. # set the session in stop mode.
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
session = monitored_session.MonitoredSession() session = monitored_session.MonitoredSession()
with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'): with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'):
with session: with session:
@ -1714,7 +1714,7 @@ class MonitoredSessionTest(test.TestCase):
def test_stop_cleanly_when_no_exception_in_with_body(self): def test_stop_cleanly_when_no_exception_in_with_body(self):
# Tests that regular exceptions pass through # Tests that regular exceptions pass through
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
do_step = state_ops.assign_add(gstep, 1) do_step = state_ops.assign_add(gstep, 1)
session = monitored_session.MonitoredSession() session = monitored_session.MonitoredSession()
with session: with session:
@ -1728,7 +1728,7 @@ class MonitoredSessionTest(test.TestCase):
def test_raises_regular_exceptions_in_with_body(self): def test_raises_regular_exceptions_in_with_body(self):
# Tests that regular exceptions in "with body" are seen outside. # Tests that regular exceptions in "with body" are seen outside.
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
do_step = state_ops.assign_add(gstep, 1) do_step = state_ops.assign_add(gstep, 1)
session = monitored_session.MonitoredSession() session = monitored_session.MonitoredSession()
# We should see that exception. # We should see that exception.
@ -2184,7 +2184,7 @@ class SingularMonitoredSessionTest(test.TestCase):
def test_do_not_handle_aborted_error(self): def test_do_not_handle_aborted_error(self):
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
class _RaiseAbortedHook(session_run_hook.SessionRunHook): class _RaiseAbortedHook(session_run_hook.SessionRunHook):
@ -2204,7 +2204,7 @@ class SingularMonitoredSessionTest(test.TestCase):
def test_exit_cleanly_on_out_of_range_exception(self): def test_exit_cleanly_on_out_of_range_exception(self):
# Tests that we stop cleanly when OutOfRange is raised. # Tests that we stop cleanly when OutOfRange is raised.
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
do_step = state_ops.assign_add(gstep, 1) do_step = state_ops.assign_add(gstep, 1)
hook = RaiseOnceAtCountN(2, errors_impl.OutOfRangeError(None, None, hook = RaiseOnceAtCountN(2, errors_impl.OutOfRangeError(None, None,
'EOI')) 'EOI'))
@ -2225,7 +2225,7 @@ class SingularMonitoredSessionTest(test.TestCase):
# passes through a "run()" call within a "with MonitoredSession" block and # passes through a "run()" call within a "with MonitoredSession" block and
# set the session in stop mode. # set the session in stop mode.
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
session = monitored_session.SingularMonitoredSession() session = monitored_session.SingularMonitoredSession()
run_performed_without_error = False run_performed_without_error = False
with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'): with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'):
@ -2244,7 +2244,7 @@ class SingularMonitoredSessionTest(test.TestCase):
def test_stop_cleanly_when_no_exception_in_with_body(self): def test_stop_cleanly_when_no_exception_in_with_body(self):
# Tests that regular exceptions pass through # Tests that regular exceptions pass through
with ops.Graph().as_default(): with ops.Graph().as_default():
gstep = variables_lib.get_or_create_global_step() gstep = training_util.get_or_create_global_step()
do_step = state_ops.assign_add(gstep, 1) do_step = state_ops.assign_add(gstep, 1)
session = monitored_session.SingularMonitoredSession() session = monitored_session.SingularMonitoredSession()
with session: with session: