Add traceme for ExecutorState::Process.
PiperOrigin-RevId: 273607269
This commit is contained in:
parent
0f5189f2bd
commit
8f90210c1e
@ -665,7 +665,9 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts) {
|
||||
|
||||
// Checks RunMetadata is well-formed
|
||||
ASSERT_TRUE(run_metadata.has_step_stats());
|
||||
EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
|
||||
// TODO(b/142268415): Should be 2 after separating traceme events from
|
||||
// StepStats.
|
||||
EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 3);
|
||||
}
|
||||
|
||||
TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts_Callable) {
|
||||
@ -697,7 +699,9 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts_Callable) {
|
||||
|
||||
// Checks RunMetadata is well-formed
|
||||
ASSERT_TRUE(run_metadata.has_step_stats());
|
||||
EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
|
||||
// TODO(b/142268415): Should be 2 after separating traceme events from
|
||||
// StepStats.
|
||||
EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 3);
|
||||
}
|
||||
|
||||
TEST_F(DirectSessionMinusAXTest, UseRunHandlerPool) {
|
||||
|
@ -1667,6 +1667,15 @@ bool MightTrace(const NodeItem& item,
|
||||
}
|
||||
|
||||
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
|
||||
profiler::TraceMe activity(
|
||||
[&] {
|
||||
int64 id = step_id_;
|
||||
if (step_container_ && step_container_->step_id()) {
|
||||
id = step_container_->step_id();
|
||||
}
|
||||
return absl::StrCat("ExecutorState::Process#id=", id, "#");
|
||||
},
|
||||
2);
|
||||
WithContext wc(context_);
|
||||
TaggedNodeSeq ready;
|
||||
TaggedNodeReadyQueue inline_ready;
|
||||
|
@ -393,7 +393,7 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
res = sess.run([[], tuple(), {}])
|
||||
self.assertTrue(isinstance(res, list))
|
||||
self.assertEquals(3, len(res))
|
||||
self.assertEqual(3, len(res))
|
||||
self.assertTrue(isinstance(res[0], list))
|
||||
self.assertEqual(0, len(res[0]))
|
||||
self.assertTrue(isinstance(res[1], tuple))
|
||||
@ -403,7 +403,7 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
res = sess.run([[], tuple(), {}, a])
|
||||
self.assertTrue(isinstance(res, list))
|
||||
self.assertEquals(4, len(res))
|
||||
self.assertEqual(4, len(res))
|
||||
self.assertTrue(isinstance(res[0], list))
|
||||
self.assertEqual(0, len(res[0]))
|
||||
self.assertTrue(isinstance(res[1], tuple))
|
||||
@ -1231,13 +1231,13 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
versions.GRAPH_DEF_VERSION_MIN_CONSUMER),
|
||||
sess.graph_def)
|
||||
c = constant_op.constant(5.0, name='c')
|
||||
self.assertEquals(len(sess.graph_def.node), 1)
|
||||
self.assertEqual(len(sess.graph_def.node), 1)
|
||||
d = constant_op.constant(6.0, name='d')
|
||||
self.assertEquals(len(sess.graph_def.node), 2)
|
||||
self.assertEqual(len(sess.graph_def.node), 2)
|
||||
self.assertAllEqual(c.eval(), 5.0)
|
||||
self.assertAllEqual(d.eval(), 6.0)
|
||||
e = constant_op.constant(7.0, name='e')
|
||||
self.assertEquals(len(sess.graph_def.node), 3)
|
||||
self.assertEqual(len(sess.graph_def.node), 3)
|
||||
self.assertAllEqual(e.eval(), 7.0)
|
||||
|
||||
def testUseAfterClose(self):
|
||||
@ -1683,7 +1683,9 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
run_metadata=run_metadata)
|
||||
|
||||
self.assertTrue(run_metadata.HasField('step_stats'))
|
||||
self.assertEquals(len(run_metadata.step_stats.dev_stats), 1)
|
||||
# TODO(b/142268415): Should be 1 after separating traceme events from
|
||||
# StepStats.
|
||||
self.assertEqual(len(run_metadata.step_stats.dev_stats), 2)
|
||||
|
||||
def testRunOptionsRunMetadata(self):
|
||||
run_options = config_pb2.RunOptions(
|
||||
@ -1708,7 +1710,9 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
run_metadata=run_metadata)
|
||||
|
||||
self.assertTrue(run_metadata.HasField('step_stats'))
|
||||
self.assertEquals(len(run_metadata.step_stats.dev_stats), 1)
|
||||
# TODO(b/142268415): Should be 1 after separating traceme events from
|
||||
# StepStats.
|
||||
self.assertEqual(len(run_metadata.step_stats.dev_stats), 2)
|
||||
|
||||
def testFeedShapeCompatibility(self):
|
||||
with session.Session() as sess:
|
||||
|
@ -169,8 +169,10 @@ class RunMetadataTest(test.TestCase):
|
||||
ops.reset_default_graph()
|
||||
with ops.device('/cpu:0'):
|
||||
tfprof_node, run_meta = _run_model()
|
||||
self.assertEqual(tfprof_node.children[0].name, 'MatMul')
|
||||
self.assertGreater(tfprof_node.children[0].exec_micros, 0)
|
||||
# TODO(b/142268415): Should be children[0] after separating traceme events
|
||||
# from StepStats.
|
||||
self.assertEqual(tfprof_node.children[1].name, 'MatMul')
|
||||
self.assertGreater(tfprof_node.children[1].exec_micros, 0)
|
||||
|
||||
ret = _extract_node(run_meta, 'MatMul')
|
||||
self.assertEqual(len(ret['cpu:0']), 1)
|
||||
|
Loading…
Reference in New Issue
Block a user