ROCm platform currently does not support the ability to do GPU stream level tracing / profiling. This commit skip subtests (within python unit-tests) that this functionality. The "skip" is guarded by the call to "is_buil_with_rocm()", and hence these unit-tests will not be affected in any way when running with TF which was not built with ROCm support (i.e. `--config=rocm`)
129 lines
4.7 KiB
Python
129 lines
4.7 KiB
Python
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
|
|
from tensorflow.python.client import session
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.ops import variables
|
|
from tensorflow.python.platform import gfile
|
|
from tensorflow.python.platform import test
|
|
from tensorflow.python.profiler import option_builder
|
|
|
|
# pylint: disable=g-bad-import-order
|
|
from tensorflow.python.profiler import profile_context
|
|
from tensorflow.python.profiler.internal import model_analyzer_testlib as lib
|
|
|
|
builder = option_builder.ProfileOptionBuilder
|
|
|
|
|
|
class ProfilerContextTest(test.TestCase):
|
|
|
|
@test_util.run_deprecated_v1
|
|
def testBasics(self):
|
|
ops.reset_default_graph()
|
|
outfile = os.path.join(test.get_temp_dir(), "dump")
|
|
opts = builder(builder.time_and_memory()
|
|
).with_file_output(outfile).build()
|
|
|
|
x = lib.BuildFullModel()
|
|
|
|
profile_str = None
|
|
profile_step100 = os.path.join(test.get_temp_dir(), "profile_100")
|
|
with profile_context.ProfileContext(test.get_temp_dir()) as pctx:
|
|
pctx.add_auto_profiling("op", options=opts, profile_steps=[15, 50, 100])
|
|
with session.Session() as sess:
|
|
self.evaluate(variables.global_variables_initializer())
|
|
total_steps = 101
|
|
for i in range(total_steps):
|
|
self.evaluate(x)
|
|
if i == 14 or i == 49:
|
|
self.assertTrue(gfile.Exists(outfile))
|
|
gfile.Remove(outfile)
|
|
if i == 99:
|
|
self.assertTrue(gfile.Exists(profile_step100))
|
|
with gfile.Open(outfile, "r") as f:
|
|
profile_str = f.read()
|
|
gfile.Remove(outfile)
|
|
|
|
self.assertEqual(set([15, 50, 100]), set(pctx.get_profiles("op").keys()))
|
|
|
|
with lib.ProfilerFromFile(
|
|
os.path.join(test.get_temp_dir(), "profile_100")) as profiler:
|
|
profiler.profile_operations(options=opts)
|
|
with gfile.Open(outfile, "r") as f:
|
|
|
|
if test.is_built_with_rocm():
|
|
# The profiler output for ROCm mode, includes an extra warning
|
|
# related to the lack of stream tracing in ROCm mode.
|
|
# Need to skip this warning when doing the diff
|
|
profile_str = "\n".join(profile_str.split("\n")[7:])
|
|
|
|
self.assertEqual(profile_str, f.read())
|
|
|
|
@test_util.run_deprecated_v1
|
|
def testAutoTracingInDeubMode(self):
|
|
ops.reset_default_graph()
|
|
x = lib.BuildFullModel()
|
|
|
|
with profile_context.ProfileContext(test.get_temp_dir(), debug=True):
|
|
with session.Session() as sess:
|
|
self.evaluate(variables.global_variables_initializer())
|
|
for _ in range(10):
|
|
self.evaluate(x)
|
|
for f in gfile.ListDirectory(test.get_temp_dir()):
|
|
# Warm up, no tracing.
|
|
self.assertFalse("run_meta" in f)
|
|
self.evaluate(x)
|
|
self.assertTrue(
|
|
gfile.Exists(os.path.join(test.get_temp_dir(), "run_meta_11")))
|
|
gfile.Remove(os.path.join(test.get_temp_dir(), "run_meta_11"))
|
|
# fetched already.
|
|
self.evaluate(x)
|
|
for f in gfile.ListDirectory(test.get_temp_dir()):
|
|
self.assertFalse("run_meta" in f)
|
|
|
|
@test_util.run_deprecated_v1
|
|
def testDisabled(self):
|
|
ops.reset_default_graph()
|
|
x = lib.BuildFullModel()
|
|
with profile_context.ProfileContext(test.get_temp_dir(),
|
|
enabled=False) as pctx:
|
|
with session.Session() as sess:
|
|
self.evaluate(variables.global_variables_initializer())
|
|
for _ in range(10):
|
|
self.evaluate(x)
|
|
self.assertTrue(pctx.profiler is None)
|
|
self.assertTrue(
|
|
getattr(session.BaseSession, "profile_context", None) is None)
|
|
|
|
with profile_context.ProfileContext(test.get_temp_dir()) as pctx:
|
|
with session.Session() as sess:
|
|
self.evaluate(variables.global_variables_initializer())
|
|
for _ in range(10):
|
|
self.evaluate(x)
|
|
self.assertFalse(pctx.profiler is None)
|
|
self.assertFalse(
|
|
getattr(session.BaseSession, "profile_context", None) is None)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
test.main()
|