Properly fix the earlier compile error for --config=rocm, plus some minor changes
This commit is contained in:
parent
896ad1053b
commit
aa26dce3be
@ -33,7 +33,7 @@ filegroup(
|
||||
cc_library(
|
||||
name = "rocm_diagnostics",
|
||||
srcs = if_rocm_is_configured(["rocm_diagnostics.cc"]),
|
||||
hdrs = [],
|
||||
hdrs = if_rocm_is_configured(["rocm_diagnostics.h"]),
|
||||
deps = if_rocm_is_configured([
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -30,7 +30,7 @@ limitations under the License.
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
|
||||
#include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"
|
||||
#include "tensorflow/stream_executor/lib/error.h"
|
||||
#include "tensorflow/stream_executor/lib/numbers.h"
|
||||
#include "tensorflow/stream_executor/lib/process_state.h"
|
||||
@ -40,7 +40,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/platform/logging.h"
|
||||
|
||||
namespace stream_executor {
|
||||
namespace gpu {
|
||||
namespace rocm {
|
||||
|
||||
string DriverVersionToString(DriverVersion version) {
|
||||
return absl::StrFormat("%d.%d.%d", std::get<0>(version), std::get<1>(version),
|
||||
@ -95,6 +95,12 @@ port::StatusOr<DriverVersion> StringToDriverVersion(const string& value) {
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace stream_executor
|
||||
|
||||
namespace stream_executor {
|
||||
namespace gpu {
|
||||
|
||||
// -- class Diagnostician
|
||||
|
||||
string Diagnostician::GetDevNodePath(int dev_node_ordinal) {
|
||||
@ -133,11 +139,11 @@ void Diagnostician::LogDiagnosticInformation() {
|
||||
}
|
||||
port::StatusOr<DriverVersion> dso_version = FindDsoVersion();
|
||||
LOG(INFO) << "librocm reported version is: "
|
||||
<< DriverVersionStatusToString(dso_version);
|
||||
<< rocm::DriverVersionStatusToString(dso_version);
|
||||
|
||||
port::StatusOr<DriverVersion> kernel_version = FindKernelDriverVersion();
|
||||
LOG(INFO) << "kernel reported version is: "
|
||||
<< DriverVersionStatusToString(kernel_version);
|
||||
<< rocm::DriverVersionStatusToString(kernel_version);
|
||||
|
||||
if (kernel_version.ok() && dso_version.ok()) {
|
||||
WarnOnDsoKernelMismatch(dso_version, kernel_version);
|
||||
@ -175,7 +181,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindDsoVersion() {
|
||||
// TODO(b/22689637): Eliminate the explicit namespace if possible.
|
||||
auto stripped_dso_version = port::StripSuffixString(dso_version, ".ld64");
|
||||
auto result = static_cast<port::StatusOr<DriverVersion>*>(data);
|
||||
*result = StringToDriverVersion(stripped_dso_version);
|
||||
*result = rocm::StringToDriverVersion(stripped_dso_version);
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
@ -205,7 +211,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindKernelModuleVersion(
|
||||
// TODO(b/22689637): Eliminate the explicit namespace if possible.
|
||||
auto stripped_kernel_version =
|
||||
port::StripSuffixString(kernel_version, ".ld64");
|
||||
return StringToDriverVersion(stripped_kernel_version);
|
||||
return rocm::StringToDriverVersion(stripped_kernel_version);
|
||||
}
|
||||
|
||||
void Diagnostician::WarnOnDsoKernelMismatch(
|
||||
@ -214,12 +220,12 @@ void Diagnostician::WarnOnDsoKernelMismatch(
|
||||
if (kernel_version.ok() && dso_version.ok() &&
|
||||
dso_version.ValueOrDie() == kernel_version.ValueOrDie()) {
|
||||
LOG(INFO) << "kernel version seems to match DSO: "
|
||||
<< DriverVersionToString(kernel_version.ValueOrDie());
|
||||
<< rocm::DriverVersionToString(kernel_version.ValueOrDie());
|
||||
} else {
|
||||
LOG(ERROR) << "kernel version "
|
||||
<< DriverVersionStatusToString(kernel_version)
|
||||
<< rocm::DriverVersionStatusToString(kernel_version)
|
||||
<< " does not match DSO version "
|
||||
<< DriverVersionStatusToString(dso_version)
|
||||
<< rocm::DriverVersionStatusToString(dso_version)
|
||||
<< " -- cannot find working devices in this configuration";
|
||||
}
|
||||
}
|
||||
|
41
tensorflow/stream_executor/rocm/rocm_diagnostics.h
Normal file
41
tensorflow/stream_executor/rocm/rocm_diagnostics.h
Normal file
@ -0,0 +1,41 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_
|
||||
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_
|
||||
|
||||
#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
|
||||
|
||||
namespace stream_executor {
|
||||
namespace rocm {
|
||||
|
||||
// e.g. DriverVersion{346, 3, 4}
|
||||
using DriverVersion = gpu::DriverVersion;
|
||||
|
||||
// Converts a parsed driver version to string form.
|
||||
string DriverVersionToString(DriverVersion version);
|
||||
|
||||
// Converts a parsed driver version or status value to natural string form.
|
||||
string DriverVersionStatusToString(port::StatusOr<DriverVersion> version);
|
||||
|
||||
// Converts a string of a form like "331.79" to a DriverVersion{331, 79}.
|
||||
port::StatusOr<DriverVersion> StringToDriverVersion(const string& value);
|
||||
|
||||
using Diagnostician = gpu::Diagnostician;
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace stream_executor
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#include "absl/base/casts.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_driver.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_event.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
|
||||
@ -41,6 +40,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/platform/logging.h"
|
||||
#include "tensorflow/stream_executor/platform/port.h"
|
||||
#include "tensorflow/stream_executor/plugin_registry.h"
|
||||
#include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"
|
||||
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
|
||||
#include "tensorflow/stream_executor/stream.h"
|
||||
#include "tensorflow/stream_executor/stream_executor_internal.h"
|
||||
@ -655,7 +655,7 @@ port::Status GpuExecutor::BlockHostUntilDone(Stream* stream) {
|
||||
blas::BlasSupport* GpuExecutor::CreateBlas() {
|
||||
PluginRegistry* registry = PluginRegistry::Instance();
|
||||
port::StatusOr<PluginRegistry::BlasFactory> status =
|
||||
registry->GetFactory<PluginRegistry::BlasFactory>(kROCmPlatformId,
|
||||
registry->GetFactory<PluginRegistry::BlasFactory>(rocm::kROCmPlatformId,
|
||||
plugin_config_.blas());
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Unable to retrieve BLAS factory: "
|
||||
@ -669,7 +669,7 @@ blas::BlasSupport* GpuExecutor::CreateBlas() {
|
||||
dnn::DnnSupport* GpuExecutor::CreateDnn() {
|
||||
PluginRegistry* registry = PluginRegistry::Instance();
|
||||
port::StatusOr<PluginRegistry::DnnFactory> status =
|
||||
registry->GetFactory<PluginRegistry::DnnFactory>(kROCmPlatformId,
|
||||
registry->GetFactory<PluginRegistry::DnnFactory>(rocm::kROCmPlatformId,
|
||||
plugin_config_.dnn());
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Unable to retrieve DNN factory: "
|
||||
@ -683,7 +683,7 @@ dnn::DnnSupport* GpuExecutor::CreateDnn() {
|
||||
fft::FftSupport* GpuExecutor::CreateFft() {
|
||||
PluginRegistry* registry = PluginRegistry::Instance();
|
||||
port::StatusOr<PluginRegistry::FftFactory> status =
|
||||
registry->GetFactory<PluginRegistry::FftFactory>(kROCmPlatformId,
|
||||
registry->GetFactory<PluginRegistry::FftFactory>(rocm::kROCmPlatformId,
|
||||
plugin_config_.fft());
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Unable to retrieve FFT factory: "
|
||||
@ -697,7 +697,7 @@ fft::FftSupport* GpuExecutor::CreateFft() {
|
||||
rng::RngSupport* GpuExecutor::CreateRng() {
|
||||
PluginRegistry* registry = PluginRegistry::Instance();
|
||||
port::StatusOr<PluginRegistry::RngFactory> status =
|
||||
registry->GetFactory<PluginRegistry::RngFactory>(kROCmPlatformId,
|
||||
registry->GetFactory<PluginRegistry::RngFactory>(rocm::kROCmPlatformId,
|
||||
plugin_config_.rng());
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Unable to retrieve RNG factory: "
|
||||
@ -878,12 +878,9 @@ DeviceDescription* GpuExecutor::PopulateDeviceDescription() const {
|
||||
{
|
||||
int driver_version = 0;
|
||||
(void)GpuDriver::GetDriverVersion(&driver_version);
|
||||
string augmented_driver_version =
|
||||
absl::StrFormat("%d (%s)", driver_version, "__FIXME__");
|
||||
// FIXME:
|
||||
// uncomment the line below once the "DriverVersionStatusToString"
|
||||
// routine is moved from the "cuda" namespace to the "gpu" namespace
|
||||
// DriverVersionStatusToString(Diagnostician::FindDsoVersion()).c_str());
|
||||
string augmented_driver_version = absl::StrFormat(
|
||||
"%d (%s)", driver_version,
|
||||
rocm::DriverVersionStatusToString(Diagnostician::FindDsoVersion()).c_str());
|
||||
builder.set_driver_version(augmented_driver_version);
|
||||
}
|
||||
|
||||
|
@ -94,7 +94,7 @@ port::StatusOr<StreamExecutor*> ROCmPlatform::FirstExecutorForBus(
|
||||
absl::StrFormat("Executor for bus %d not found.", bus_ordinal)};
|
||||
}
|
||||
|
||||
Platform::Id ROCmPlatform::id() const { return kROCmPlatformId; }
|
||||
Platform::Id ROCmPlatform::id() const { return rocm::kROCmPlatformId; }
|
||||
|
||||
int ROCmPlatform::VisibleDeviceCount() const {
|
||||
// Throw away the result - it logs internally, and this [containing] function
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
|
||||
|
||||
namespace stream_executor {
|
||||
namespace gpu {
|
||||
namespace rocm {
|
||||
|
||||
PLATFORM_DEFINE_ID(kROCmPlatformId);
|
||||
|
||||
|
@ -19,16 +19,16 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/platform.h"
|
||||
|
||||
namespace stream_executor {
|
||||
namespace gpu {
|
||||
namespace rocm {
|
||||
|
||||
// Opaque and unique identifier for the ROCm platform.
|
||||
// This is needed so that plugins can refer to/identify this platform without
|
||||
// instantiating a ROCmPlatform object.
|
||||
// This is broken out here to avoid a circular dependency between ROCmPlatform
|
||||
// and GpuExecutor.
|
||||
// and ROCmExecutor.
|
||||
extern const Platform::Id kROCmPlatformId;
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace rocm
|
||||
} // namespace stream_executor
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_PLATFORM_ID_H_
|
||||
|
@ -253,7 +253,7 @@ REGISTER_MODULE_INITIALIZER(register_hiprand, {
|
||||
se::port::Status status =
|
||||
se::PluginRegistry::Instance()
|
||||
->RegisterFactory<se::PluginRegistry::RngFactory>(
|
||||
se::gpu::kROCmPlatformId, se::gpu::kGpuRandPlugin, "hipRAND",
|
||||
se::rocm::kROCmPlatformId, se::gpu::kGpuRandPlugin, "hipRAND",
|
||||
[](se::internal::StreamExecutorInterface* parent)
|
||||
-> se::rng::RngSupport* {
|
||||
se::gpu::GpuExecutor* rocm_executor =
|
||||
@ -280,5 +280,5 @@ REGISTER_MODULE_INITIALIZER(register_hiprand, {
|
||||
}
|
||||
|
||||
se::PluginRegistry::Instance()->SetDefaultFactory(
|
||||
se::gpu::kROCmPlatformId, se::PluginKind::kRng, se::gpu::kGpuRandPlugin);
|
||||
se::rocm::kROCmPlatformId, se::PluginKind::kRng, se::gpu::kGpuRandPlugin);
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user