[ROCm] Skipping subtests that check GPU stream tracing/profiling.

ROCm platform currently does not support the ability to do GPU stream level tracing / profiling.

This commit skip subtests (within python unit-tests) that this functionality. The "skip" is guarded by the call to "is_buil_with_rocm()", and hence these unit-tests will not be affected in any way when running with TF which was not built with ROCm support (i.e. `--config=rocm`)
This commit is contained in:
Deven Desai 2019-07-08 14:18:54 +00:00
parent 09d232dcaf
commit 360f99cbc3
3 changed files with 20 additions and 3 deletions

View File

@ -104,7 +104,10 @@ class TimelineTest(test.TestCase):
step_stats = run_metadata.step_stats
devices = [d.device for d in step_stats.dev_stats]
self.assertTrue('/job:localhost/replica:0/task:0/device:GPU:0' in devices)
self.assertTrue('/device:GPU:0/stream:all' in devices)
if not test.is_built_with_rocm():
# skip this check for the ROCm platform
# stream level tracing is not yet supported on the ROCm platform
self.assertTrue('/device:GPU:0/stream:all' in devices)
tl = timeline.Timeline(step_stats)
ctf = tl.generate_chrome_trace_format()
self._validateTrace(ctf)

View File

@ -129,7 +129,10 @@ class RunMetadataTest(test.TestCase):
ret = _extract_node(run_meta, 'MatMul')
self.assertEqual(len(ret['gpu:0']), 1)
self.assertEqual(len(ret['gpu:0/stream:all']), 1, '%s' % run_meta)
if not test.is_built_with_rocm():
# skip this check for the ROCm platform
# stream level tracing is not yet supported on the ROCm platform
self.assertEqual(len(ret['gpu:0/stream:all']), 1, '%s' % run_meta)
@test_util.run_deprecated_v1
def testAllocationHistory(self):
@ -234,7 +237,11 @@ class RunMetadataTest(test.TestCase):
for node in ret['gpu:0']:
total_cpu_execs += node.op_end_rel_micros
self.assertGreaterEqual(len(ret['gpu:0/stream:all']), 4, '%s' % run_meta)
if not test.is_built_with_rocm():
# skip this check for the ROCm platform
# stream level tracing is not yet supported on the ROCm platform
self.assertGreaterEqual(len(ret['gpu:0/stream:all']),
4, '%s' % run_meta)
if __name__ == '__main__':

View File

@ -69,6 +69,13 @@ class ProfilerContextTest(test.TestCase):
os.path.join(test.get_temp_dir(), "profile_100")) as profiler:
profiler.profile_operations(options=opts)
with gfile.Open(outfile, "r") as f:
if test.is_built_with_rocm():
# The profiler output for ROCm mode, includes an extra warning
# related to the lack of stream tracing in ROCm mode.
# Need to skip this warning when doing the diff
profile_str = "\n".join(profile_str.split("\n")[7:])
self.assertEqual(profile_str, f.read())
@test_util.run_deprecated_v1