Move the singleton xla::gpu::GpuDebugInfoManager to xla::XlaDebugInfoManager, in the hope that it can be reused by XLA/CPU.
PiperOrigin-RevId: 356533750 Change-Id: Ib4dc25f9ce7bfe729bf19da6a92e0102518f502a
This commit is contained in:
parent
f11e226a34
commit
e1b9195cae
@ -5436,6 +5436,34 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_debug_info_manager",
|
||||
srcs = [
|
||||
"xla_debug_info_manager.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"xla_debug_info_manager.h",
|
||||
],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_proto_cc",
|
||||
":hlo_proto_util",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "xla_debug_info_manager_test",
|
||||
srcs = ["xla_debug_info_manager_test.cc"],
|
||||
deps = [
|
||||
":hlo_proto_cc",
|
||||
":xla_debug_info_manager",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
],
|
||||
)
|
||||
|
||||
# copybara:uncomment_begin(google-only)
|
||||
# py_proto_library(
|
||||
# name = "hlo_pb2",
|
||||
|
@ -678,35 +678,6 @@ alias(
|
||||
actual = if_cuda_or_rocm(":nccl_utils", ":empty"),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_debug_info_manager",
|
||||
srcs = [
|
||||
"gpu_debug_info_manager.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gpu_debug_info_manager.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:hlo_proto_util",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "gpu_debug_info_manager_test",
|
||||
srcs = ["gpu_debug_info_manager_test.cc"],
|
||||
tags = tf_cuda_tests_tags(),
|
||||
deps = [
|
||||
":gpu_debug_info_manager",
|
||||
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_executable",
|
||||
srcs = [
|
||||
@ -763,7 +734,6 @@ cc_library(
|
||||
":cudnn_batchnorm_runner",
|
||||
":gpu_constants",
|
||||
":gpu_conv_runner",
|
||||
":gpu_debug_info_manager",
|
||||
":gpu_executable_run_options",
|
||||
":gpu_types",
|
||||
":hlo_execution_profiler",
|
||||
@ -795,6 +765,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:logical_buffer",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||
"//tensorflow/compiler/xla/service:xla_debug_info_manager",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
|
||||
@ -33,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/logical_buffer.h"
|
||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||
#include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"
|
||||
#include "tensorflow/compiler/xla/shape_tree.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -69,12 +69,12 @@ GpuExecutable::GpuExecutable(GpuExecutable::Params params)
|
||||
entry_computation_profile_index_(params.entry_computation_profile_index),
|
||||
constants_(std::move(params.constants)),
|
||||
output_info_(std::move(params.output_info)) {
|
||||
GpuDebugInfoManager::Get()->RegisterModule(module_name_, shared_module(),
|
||||
XlaDebugInfoManager::Get()->RegisterModule(module_name_, shared_module(),
|
||||
debug_buffer_assignment_);
|
||||
}
|
||||
|
||||
GpuExecutable::~GpuExecutable() {
|
||||
GpuDebugInfoManager::Get()->UnregisterModule(module_name_, shared_module(),
|
||||
XlaDebugInfoManager::Get()->UnregisterModule(module_name_, shared_module(),
|
||||
debug_buffer_assignment_);
|
||||
|
||||
{
|
||||
@ -131,9 +131,9 @@ Status GpuExecutable::ExecuteThunks(
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckCompatibilityWithServiceExecutableRunOptions(run_options));
|
||||
GpuDebugInfoManager::Get()->OnModuleStart(module_name_);
|
||||
XlaDebugInfoManager::Get()->OnModuleStart(module_name_);
|
||||
auto cleanup = MakeCleanup(
|
||||
[&]() { GpuDebugInfoManager::Get()->OnModuleStop(module_name_); });
|
||||
[&]() { XlaDebugInfoManager::Get()->OnModuleStop(module_name_); });
|
||||
|
||||
se::Stream* main_stream = run_options->stream();
|
||||
se::StreamExecutor* executor = main_stream->parent();
|
||||
|
@ -13,14 +13,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h"
|
||||
#include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
void GpuDebugInfoManager::RegisterModule(
|
||||
void XlaDebugInfoManager::RegisterModule(
|
||||
const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
|
||||
std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
|
||||
tensorflow::mutex_lock lock(mutex_);
|
||||
@ -28,7 +27,7 @@ void GpuDebugInfoManager::RegisterModule(
|
||||
active_modules_[module_id].instances.emplace_back(hlo_module,
|
||||
buffer_assignment);
|
||||
} else {
|
||||
GpuModuleEntry m;
|
||||
XlaModuleEntry m;
|
||||
m.module_id = module_id;
|
||||
m.instances.emplace_back(hlo_module, buffer_assignment);
|
||||
active_modules_[module_id] = std::move(m);
|
||||
@ -38,14 +37,14 @@ void GpuDebugInfoManager::RegisterModule(
|
||||
// Unregister an active module. When the last active module of the same
|
||||
// module id is out of scope, we remove it from our database.
|
||||
// However, during tracing we defer the cleanup until after serialization.
|
||||
void GpuDebugInfoManager::UnregisterModule(
|
||||
void XlaDebugInfoManager::UnregisterModule(
|
||||
const ModuleIdentifier& module_id, std::shared_ptr<HloModule> hlo_module,
|
||||
std::shared_ptr<const BufferAssignmentProto> buffer_assignment) {
|
||||
tensorflow::mutex_lock lock(mutex_);
|
||||
CHECK(active_modules_.find(module_id) != active_modules_.end());
|
||||
GpuModuleEntry& active_module = active_modules_[module_id];
|
||||
XlaModuleEntry& active_module = active_modules_[module_id];
|
||||
auto instance_it =
|
||||
absl::c_find_if(active_module.instances, [&](GpuModuleInstance& e) {
|
||||
absl::c_find_if(active_module.instances, [&](XlaModuleInstance& e) {
|
||||
return e.hlo_module == hlo_module &&
|
||||
e.buffer_assignment == buffer_assignment;
|
||||
});
|
||||
@ -62,12 +61,12 @@ void GpuDebugInfoManager::UnregisterModule(
|
||||
}
|
||||
}
|
||||
|
||||
void GpuDebugInfoManager::OnModuleStart(ModuleIdentifier module_id) {
|
||||
void XlaDebugInfoManager::OnModuleStart(ModuleIdentifier module_id) {
|
||||
tensorflow::mutex_lock lock(mutex_);
|
||||
running_module_ids_[module_id]++;
|
||||
}
|
||||
|
||||
void GpuDebugInfoManager::OnModuleStop(ModuleIdentifier module_id) {
|
||||
void XlaDebugInfoManager::OnModuleStop(ModuleIdentifier module_id) {
|
||||
tensorflow::mutex_lock lock(mutex_);
|
||||
if (--running_module_ids_[module_id] == 0) {
|
||||
if (!tracing_active_) {
|
||||
@ -76,17 +75,17 @@ void GpuDebugInfoManager::OnModuleStop(ModuleIdentifier module_id) {
|
||||
}
|
||||
}
|
||||
|
||||
void GpuDebugInfoManager::StartTracing() {
|
||||
void XlaDebugInfoManager::StartTracing() {
|
||||
tensorflow::mutex_lock lock(mutex_);
|
||||
tracing_active_ = true;
|
||||
}
|
||||
|
||||
void GpuDebugInfoManager::StopTracing(
|
||||
std::vector<GpuModuleDebugInfo>* module_debug_info) {
|
||||
std::vector<GpuModuleEntry> modules_to_serialize;
|
||||
void XlaDebugInfoManager::StopTracing(
|
||||
std::vector<XlaModuleDebugInfo>* module_debug_info) {
|
||||
std::vector<XlaModuleEntry> modules_to_serialize;
|
||||
{
|
||||
tensorflow::mutex_lock lock(mutex_);
|
||||
CHECK(tracing_active_);
|
||||
if (!tracing_active_) return;
|
||||
tracing_active_ = false;
|
||||
for (const auto& running_module_id : running_module_ids_) {
|
||||
const ModuleIdentifier& module_id = running_module_id.first;
|
||||
@ -94,13 +93,13 @@ void GpuDebugInfoManager::StopTracing(
|
||||
LOG(ERROR) << "Cannot find debug info for module: " << module_id;
|
||||
continue;
|
||||
}
|
||||
const GpuModuleEntry& active_module = active_modules_[module_id];
|
||||
const XlaModuleEntry& active_module = active_modules_[module_id];
|
||||
|
||||
// Copy the instance so that we can serialize without holding the lock.
|
||||
// All instances are equivalent from the perspective of symbolization.
|
||||
// We only use the first one.
|
||||
if (!active_module.instances.empty()) {
|
||||
GpuModuleEntry e;
|
||||
XlaModuleEntry e;
|
||||
e.module_id = active_module.module_id;
|
||||
e.instances.push_back(active_module.instances[0]);
|
||||
modules_to_serialize.push_back(std::move(e));
|
||||
@ -140,7 +139,7 @@ void GpuDebugInfoManager::StopTracing(
|
||||
if (module_debug_info) {
|
||||
module_debug_info->clear();
|
||||
for (const auto& m : modules_to_serialize) {
|
||||
GpuModuleDebugInfo info;
|
||||
XlaModuleDebugInfo info;
|
||||
info.module_id = m.module_id;
|
||||
// In real world, hlo_module and buffer_assignment will always be
|
||||
// non-nullptr. Due to the inconvenience of creation of buffer_assignment
|
||||
@ -156,5 +155,4 @@ void GpuDebugInfoManager::StopTracing(
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_DEBUG_INFO_MANAGER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_DEBUG_INFO_MANAGER_H_
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_XLA_DEBUG_INFO_MANAGER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_XLA_DEBUG_INFO_MANAGER_H_
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
@ -22,13 +22,12 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
using ModuleIdentifier = string;
|
||||
|
||||
struct GpuModuleDebugInfo {
|
||||
struct XlaModuleDebugInfo {
|
||||
ModuleIdentifier module_id;
|
||||
// The hlo proto associated with this gpu program.
|
||||
// The hlo proto associated with this xla program.
|
||||
std::unique_ptr<HloProto> hlo_proto;
|
||||
// TODO(b/133503446): We might need add performance info from cost analysis
|
||||
// and DeviceDescription which contains peak memory bandwidth, clock speed,
|
||||
@ -44,14 +43,14 @@ struct GpuModuleDebugInfo {
|
||||
// information. We only keep track of unique debug information, identified
|
||||
// by module_id.
|
||||
// This class is thread-safe.
|
||||
class GpuDebugInfoManager {
|
||||
class XlaDebugInfoManager {
|
||||
public:
|
||||
static GpuDebugInfoManager* Get() {
|
||||
static GpuDebugInfoManager* singleton = new GpuDebugInfoManager();
|
||||
static XlaDebugInfoManager* Get() {
|
||||
static XlaDebugInfoManager* singleton = new XlaDebugInfoManager();
|
||||
return singleton;
|
||||
}
|
||||
|
||||
// Register an active module to GpuDebugInfoManager. We keep track of all
|
||||
// Register an active module to XlaDebugInfoManager. We keep track of all
|
||||
// existing HloModules within the process.
|
||||
// Modules with same module id can be registered and tracked separately.
|
||||
void RegisterModule(
|
||||
@ -79,12 +78,12 @@ class GpuDebugInfoManager {
|
||||
// Then drop all modules that have no instances registered. Dump debug
|
||||
// information for all the running modules to module_debug_info if specified.
|
||||
void StopTracing(
|
||||
std::vector<GpuModuleDebugInfo>* module_debug_info = nullptr);
|
||||
std::vector<XlaModuleDebugInfo>* module_debug_info = nullptr);
|
||||
|
||||
friend class GpuDebugInfoManagerTest;
|
||||
friend class XlaDebugInfoManagerTest;
|
||||
|
||||
private:
|
||||
GpuDebugInfoManager() {}
|
||||
XlaDebugInfoManager() {}
|
||||
|
||||
// Test accessors.
|
||||
std::set<ModuleIdentifier> GetRunningModules() {
|
||||
@ -108,8 +107,8 @@ class GpuDebugInfoManager {
|
||||
// can have same unique id if they are actually same program. From the
|
||||
// perspective of symbol table, they are identical, but for the life time
|
||||
// tracking, they need to be tracked separately.
|
||||
struct GpuModuleInstance {
|
||||
GpuModuleInstance(std::shared_ptr<HloModule> m,
|
||||
struct XlaModuleInstance {
|
||||
XlaModuleInstance(std::shared_ptr<HloModule> m,
|
||||
std::shared_ptr<const BufferAssignmentProto> b)
|
||||
: hlo_module(std::move(m)), buffer_assignment(std::move(b)) {}
|
||||
std::shared_ptr<HloModule> hlo_module;
|
||||
@ -117,12 +116,12 @@ class GpuDebugInfoManager {
|
||||
bool active = true;
|
||||
};
|
||||
|
||||
// Each GpuModuleEntry can have multiple GpuModuleInstance's if XLA registers
|
||||
// Each XlaModuleEntry can have multiple XlaModuleInstance's if XLA registers
|
||||
// them with the same ModuleIdentifier.
|
||||
struct GpuModuleEntry {
|
||||
struct XlaModuleEntry {
|
||||
// The module symbol table/debug info that is shared by all instances.
|
||||
ModuleIdentifier module_id;
|
||||
std::vector<GpuModuleInstance> instances;
|
||||
std::vector<XlaModuleInstance> instances;
|
||||
};
|
||||
|
||||
tensorflow::mutex mutex_;
|
||||
@ -135,11 +134,10 @@ class GpuDebugInfoManager {
|
||||
// Active modules are those still tracked by us. There could be many more
|
||||
// active modules than running modules, we will try to reduce the trace size
|
||||
// by only transferring modules that were running during the tracing period.
|
||||
absl::flat_hash_map<ModuleIdentifier, GpuModuleEntry> active_modules_
|
||||
absl::flat_hash_map<ModuleIdentifier, XlaModuleEntry> active_modules_
|
||||
TF_GUARDED_BY(mutex_);
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_DEBUG_INFO_MANAGER_H_
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_XLA_DEBUG_INFO_MANAGER_H_
|
@ -12,17 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h"
|
||||
#include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
using ::testing::UnorderedElementsAre;
|
||||
|
||||
class GpuDebugInfoManagerTest : public HloTestBase {
|
||||
class XlaDebugInfoManagerTest : public HloTestBase {
|
||||
protected:
|
||||
struct DebugMetadata {
|
||||
// We allow same id to be registered multiple times. we need unique id to
|
||||
@ -41,7 +40,7 @@ class GpuDebugInfoManagerTest : public HloTestBase {
|
||||
debug_info.id = module_id;
|
||||
debug_info.module = std::make_shared<HloModule>(module_id, config);
|
||||
debug_info.buffer_assignment = nullptr;
|
||||
gpu_debug_info_manager_.RegisterModule(module_id, debug_info.module,
|
||||
xla_debug_info_manager_.RegisterModule(module_id, debug_info.module,
|
||||
debug_info.buffer_assignment);
|
||||
external_references_.push_back(std::move(debug_info));
|
||||
return serial_;
|
||||
@ -50,7 +49,7 @@ class GpuDebugInfoManagerTest : public HloTestBase {
|
||||
void UnregisterProgram(int unique_id) {
|
||||
for (int i = 0; i < external_references_.size(); i++) {
|
||||
if (external_references_[i].unique_id == unique_id) {
|
||||
gpu_debug_info_manager_.UnregisterModule(
|
||||
xla_debug_info_manager_.UnregisterModule(
|
||||
external_references_[i].id, external_references_[i].module,
|
||||
external_references_[i].buffer_assignment);
|
||||
external_references_.erase(external_references_.begin() + i);
|
||||
@ -62,7 +61,7 @@ class GpuDebugInfoManagerTest : public HloTestBase {
|
||||
void StartProgram(int unique_id) {
|
||||
for (int i = 0; i < external_references_.size(); i++) {
|
||||
if (external_references_[i].unique_id == unique_id) {
|
||||
gpu_debug_info_manager_.OnModuleStart(external_references_[i].id);
|
||||
xla_debug_info_manager_.OnModuleStart(external_references_[i].id);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -71,7 +70,7 @@ class GpuDebugInfoManagerTest : public HloTestBase {
|
||||
void StopProgram(int unique_id) {
|
||||
for (int i = 0; i < external_references_.size(); i++) {
|
||||
if (external_references_[i].unique_id == unique_id) {
|
||||
gpu_debug_info_manager_.OnModuleStop(external_references_[i].id);
|
||||
xla_debug_info_manager_.OnModuleStop(external_references_[i].id);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -83,17 +82,17 @@ class GpuDebugInfoManagerTest : public HloTestBase {
|
||||
}
|
||||
|
||||
std::set<ModuleIdentifier> GetRunningModule() {
|
||||
return gpu_debug_info_manager_.GetRunningModules();
|
||||
return xla_debug_info_manager_.GetRunningModules();
|
||||
}
|
||||
std::set<ModuleIdentifier> GetActiveModule() {
|
||||
return gpu_debug_info_manager_.GetActiveModules();
|
||||
return xla_debug_info_manager_.GetActiveModules();
|
||||
}
|
||||
|
||||
void StartTrace() { gpu_debug_info_manager_.StartTracing(); }
|
||||
void StartTrace() { xla_debug_info_manager_.StartTracing(); }
|
||||
|
||||
std::set<ModuleIdentifier> StopTrace() {
|
||||
std::vector<GpuModuleDebugInfo> module_debug_info;
|
||||
gpu_debug_info_manager_.StopTracing(&module_debug_info);
|
||||
std::vector<XlaModuleDebugInfo> module_debug_info;
|
||||
xla_debug_info_manager_.StopTracing(&module_debug_info);
|
||||
std::set<ModuleIdentifier> serialized;
|
||||
for (const auto& module : module_debug_info) {
|
||||
serialized.insert(module.module_id);
|
||||
@ -107,11 +106,11 @@ class GpuDebugInfoManagerTest : public HloTestBase {
|
||||
std::vector<DebugMetadata> external_references_;
|
||||
|
||||
// Use an instance per test instead of the singleton to avoid interference.
|
||||
GpuDebugInfoManager gpu_debug_info_manager_;
|
||||
XlaDebugInfoManager xla_debug_info_manager_;
|
||||
};
|
||||
|
||||
// Test the cases where no trace session is involved.
|
||||
TEST_F(GpuDebugInfoManagerTest, NoTraceBasic) {
|
||||
TEST_F(XlaDebugInfoManagerTest, NoTraceBasic) {
|
||||
auto program0 = RegisterProgram("program0");
|
||||
EXPECT_THAT(GetActiveModule(), UnorderedElementsAre("program0"));
|
||||
EXPECT_TRUE(GetRunningModule().empty());
|
||||
@ -135,7 +134,7 @@ TEST_F(GpuDebugInfoManagerTest, NoTraceBasic) {
|
||||
EXPECT_TRUE(GetActiveModule().empty());
|
||||
}
|
||||
|
||||
TEST_F(GpuDebugInfoManagerTest, NoTraceDuplicateIds) {
|
||||
TEST_F(XlaDebugInfoManagerTest, NoTraceDuplicateIds) {
|
||||
auto program0A = RegisterProgram("program0");
|
||||
auto program0B = RegisterProgram("program0"); // duplicates
|
||||
auto program1 = RegisterProgram("program1");
|
||||
@ -163,7 +162,7 @@ TEST_F(GpuDebugInfoManagerTest, NoTraceDuplicateIds) {
|
||||
}
|
||||
|
||||
// Test the cases where an active trace session is involved.
|
||||
TEST_F(GpuDebugInfoManagerTest, ActiveTrace) {
|
||||
TEST_F(XlaDebugInfoManagerTest, ActiveTrace) {
|
||||
auto program0A = RegisterProgram("program0");
|
||||
auto program0B = RegisterProgram("program0"); // duplicates
|
||||
auto program1 = RegisterProgram("program1");
|
||||
@ -195,7 +194,7 @@ TEST_F(GpuDebugInfoManagerTest, ActiveTrace) {
|
||||
EXPECT_TRUE(GetActiveModule().empty());
|
||||
}
|
||||
|
||||
TEST_F(GpuDebugInfoManagerTest, UnregisterDuringTrace) {
|
||||
TEST_F(XlaDebugInfoManagerTest, UnregisterDuringTrace) {
|
||||
auto program0A = RegisterProgram("program0");
|
||||
auto program0B = RegisterProgram("program0"); // duplicates
|
||||
auto program1 = RegisterProgram("program1");
|
||||
@ -211,5 +210,4 @@ TEST_F(GpuDebugInfoManagerTest, UnregisterDuringTrace) {
|
||||
UnregisterProgram(program0A);
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
@ -165,7 +165,7 @@ cc_library(
|
||||
copts = tf_profiler_copts(),
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||
"//tensorflow/compiler/xla/service/gpu:gpu_debug_info_manager",
|
||||
"//tensorflow/compiler/xla/service:xla_debug_info_manager",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler:profiler_options_proto_cc",
|
||||
|
@ -29,7 +29,7 @@ limitations under the License.
|
||||
#endif
|
||||
|
||||
TF_PROFILER_DISABLE_CXX17_WARNINGS
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h"
|
||||
#include "tensorflow/compiler/xla/service/xla_debug_info_manager.h"
|
||||
TF_PROFILER_ENABLE_CXX17_WARNINGS
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
@ -57,7 +57,7 @@ class MetadataCollector : public ProfilerInterface {
|
||||
|
||||
Status Start() override {
|
||||
if (!trace_active_) {
|
||||
xla::gpu::GpuDebugInfoManager::Get()->StartTracing();
|
||||
xla::XlaDebugInfoManager::Get()->StartTracing();
|
||||
trace_active_ = true;
|
||||
}
|
||||
return Status::OK();
|
||||
@ -65,7 +65,7 @@ class MetadataCollector : public ProfilerInterface {
|
||||
|
||||
Status Stop() override {
|
||||
if (trace_active_) {
|
||||
xla::gpu::GpuDebugInfoManager::Get()->StopTracing(&debug_info_);
|
||||
xla::XlaDebugInfoManager::Get()->StopTracing(&debug_info_);
|
||||
trace_active_ = false;
|
||||
}
|
||||
return Status::OK();
|
||||
@ -91,7 +91,7 @@ class MetadataCollector : public ProfilerInterface {
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<xla::gpu::GpuModuleDebugInfo> debug_info_;
|
||||
std::vector<xla::XlaModuleDebugInfo> debug_info_;
|
||||
bool trace_active_ = false;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(MetadataCollector);
|
||||
|
Loading…
x
Reference in New Issue
Block a user