Properly fix the earlier compile error for --config=rocm, plus some minor changes

This commit is contained in:
Deven Desai 2019-01-31 15:16:57 +00:00
parent 896ad1053b
commit aa26dce3be
8 changed files with 72 additions and 28 deletions

View File

@ -33,7 +33,7 @@ filegroup(
cc_library(
name = "rocm_diagnostics",
srcs = if_rocm_is_configured(["rocm_diagnostics.cc"]),
hdrs = [],
hdrs = if_rocm_is_configured(["rocm_diagnostics.h"]),
deps = if_rocm_is_configured([
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",

View File

@ -30,7 +30,7 @@ limitations under the License.
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
#include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/numbers.h"
#include "tensorflow/stream_executor/lib/process_state.h"
@ -40,7 +40,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/logging.h"
namespace stream_executor {
namespace gpu {
namespace rocm {
string DriverVersionToString(DriverVersion version) {
return absl::StrFormat("%d.%d.%d", std::get<0>(version), std::get<1>(version),
@ -95,6 +95,12 @@ port::StatusOr<DriverVersion> StringToDriverVersion(const string& value) {
return result;
}
} // namespace rocm
} // namespace stream_executor
namespace stream_executor {
namespace gpu {
// -- class Diagnostician
string Diagnostician::GetDevNodePath(int dev_node_ordinal) {
@ -133,11 +139,11 @@ void Diagnostician::LogDiagnosticInformation() {
}
port::StatusOr<DriverVersion> dso_version = FindDsoVersion();
LOG(INFO) << "librocm reported version is: "
<< DriverVersionStatusToString(dso_version);
<< rocm::DriverVersionStatusToString(dso_version);
port::StatusOr<DriverVersion> kernel_version = FindKernelDriverVersion();
LOG(INFO) << "kernel reported version is: "
<< DriverVersionStatusToString(kernel_version);
<< rocm::DriverVersionStatusToString(kernel_version);
if (kernel_version.ok() && dso_version.ok()) {
WarnOnDsoKernelMismatch(dso_version, kernel_version);
@ -175,7 +181,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindDsoVersion() {
// TODO(b/22689637): Eliminate the explicit namespace if possible.
auto stripped_dso_version = port::StripSuffixString(dso_version, ".ld64");
auto result = static_cast<port::StatusOr<DriverVersion>*>(data);
*result = StringToDriverVersion(stripped_dso_version);
*result = rocm::StringToDriverVersion(stripped_dso_version);
return 1;
}
return 0;
@ -205,7 +211,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindKernelModuleVersion(
// TODO(b/22689637): Eliminate the explicit namespace if possible.
auto stripped_kernel_version =
port::StripSuffixString(kernel_version, ".ld64");
return StringToDriverVersion(stripped_kernel_version);
return rocm::StringToDriverVersion(stripped_kernel_version);
}
void Diagnostician::WarnOnDsoKernelMismatch(
@ -214,12 +220,12 @@ void Diagnostician::WarnOnDsoKernelMismatch(
if (kernel_version.ok() && dso_version.ok() &&
dso_version.ValueOrDie() == kernel_version.ValueOrDie()) {
LOG(INFO) << "kernel version seems to match DSO: "
<< DriverVersionToString(kernel_version.ValueOrDie());
<< rocm::DriverVersionToString(kernel_version.ValueOrDie());
} else {
LOG(ERROR) << "kernel version "
<< DriverVersionStatusToString(kernel_version)
<< rocm::DriverVersionStatusToString(kernel_version)
<< " does not match DSO version "
<< DriverVersionStatusToString(dso_version)
<< rocm::DriverVersionStatusToString(dso_version)
<< " -- cannot find working devices in this configuration";
}
}

View File

@ -0,0 +1,41 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_
#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
namespace stream_executor {
namespace rocm {
// ROCm-namespace entry points for driver diagnostics. The types are aliases of
// their gpu:: counterparts (presumably declared in gpu_diagnostics.h, included
// above), so values can be passed between rocm:: and gpu:: code unchanged; the
// three conversion helpers below are declared in this namespace.

// e.g. DriverVersion{346, 3, 4}
// Alias of gpu::DriverVersion, not a distinct type.
using DriverVersion = gpu::DriverVersion;
// Converts a parsed driver version to string form.
// Takes the version by value and returns the formatted text.
string DriverVersionToString(DriverVersion version);
// Converts a parsed driver version or status value to natural string form.
// Accepts a StatusOr so callers can log the result of a version lookup
// directly, whether or not the lookup succeeded.
string DriverVersionStatusToString(port::StatusOr<DriverVersion> version);
// Converts a string of a form like "331.79" to a DriverVersion{331, 79}.
// Returns a StatusOr; presumably a non-OK status signals a string that could
// not be parsed.
// NOTE(review): the example here has two components, while the DriverVersion
// example above has three -- confirm which input forms the implementation
// accepts and how a missing component is filled in.
port::StatusOr<DriverVersion> StringToDriverVersion(const string& value);
// Alias of gpu::Diagnostician, so it can be named as rocm::Diagnostician.
using Diagnostician = gpu::Diagnostician;
} // namespace rocm
} // namespace stream_executor
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "absl/base/casts.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
#include "tensorflow/stream_executor/gpu/gpu_driver.h"
#include "tensorflow/stream_executor/gpu/gpu_event.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
@ -41,6 +40,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
@ -655,7 +655,7 @@ port::Status GpuExecutor::BlockHostUntilDone(Stream* stream) {
blas::BlasSupport* GpuExecutor::CreateBlas() {
PluginRegistry* registry = PluginRegistry::Instance();
port::StatusOr<PluginRegistry::BlasFactory> status =
registry->GetFactory<PluginRegistry::BlasFactory>(kROCmPlatformId,
registry->GetFactory<PluginRegistry::BlasFactory>(rocm::kROCmPlatformId,
plugin_config_.blas());
if (!status.ok()) {
LOG(ERROR) << "Unable to retrieve BLAS factory: "
@ -669,7 +669,7 @@ blas::BlasSupport* GpuExecutor::CreateBlas() {
dnn::DnnSupport* GpuExecutor::CreateDnn() {
PluginRegistry* registry = PluginRegistry::Instance();
port::StatusOr<PluginRegistry::DnnFactory> status =
registry->GetFactory<PluginRegistry::DnnFactory>(kROCmPlatformId,
registry->GetFactory<PluginRegistry::DnnFactory>(rocm::kROCmPlatformId,
plugin_config_.dnn());
if (!status.ok()) {
LOG(ERROR) << "Unable to retrieve DNN factory: "
@ -683,7 +683,7 @@ dnn::DnnSupport* GpuExecutor::CreateDnn() {
fft::FftSupport* GpuExecutor::CreateFft() {
PluginRegistry* registry = PluginRegistry::Instance();
port::StatusOr<PluginRegistry::FftFactory> status =
registry->GetFactory<PluginRegistry::FftFactory>(kROCmPlatformId,
registry->GetFactory<PluginRegistry::FftFactory>(rocm::kROCmPlatformId,
plugin_config_.fft());
if (!status.ok()) {
LOG(ERROR) << "Unable to retrieve FFT factory: "
@ -697,7 +697,7 @@ fft::FftSupport* GpuExecutor::CreateFft() {
rng::RngSupport* GpuExecutor::CreateRng() {
PluginRegistry* registry = PluginRegistry::Instance();
port::StatusOr<PluginRegistry::RngFactory> status =
registry->GetFactory<PluginRegistry::RngFactory>(kROCmPlatformId,
registry->GetFactory<PluginRegistry::RngFactory>(rocm::kROCmPlatformId,
plugin_config_.rng());
if (!status.ok()) {
LOG(ERROR) << "Unable to retrieve RNG factory: "
@ -878,12 +878,9 @@ DeviceDescription* GpuExecutor::PopulateDeviceDescription() const {
{
int driver_version = 0;
(void)GpuDriver::GetDriverVersion(&driver_version);
string augmented_driver_version =
absl::StrFormat("%d (%s)", driver_version, "__FIXME__");
// FIXME:
// uncomment the line below once the "DriverVersionStatusToString"
// routine is moved from the "cuda" namespace to the "gpu" naemspace
// DriverVersionStatusToString(Diagnostician::FindDsoVersion()).c_str());
string augmented_driver_version = absl::StrFormat(
"%d (%s)", driver_version,
rocm::DriverVersionStatusToString(Diagnostician::FindDsoVersion()).c_str());
builder.set_driver_version(augmented_driver_version);
}

View File

@ -94,7 +94,7 @@ port::StatusOr<StreamExecutor*> ROCmPlatform::FirstExecutorForBus(
absl::StrFormat("Executor for bus %d not found.", bus_ordinal)};
}
Platform::Id ROCmPlatform::id() const { return kROCmPlatformId; }
Platform::Id ROCmPlatform::id() const { return rocm::kROCmPlatformId; }
int ROCmPlatform::VisibleDeviceCount() const {
// Throw away the result - it logs internally, and this [containing] function

View File

@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
namespace stream_executor {
namespace gpu {
namespace rocm {
PLATFORM_DEFINE_ID(kROCmPlatformId);

View File

@ -19,16 +19,16 @@ limitations under the License.
#include "tensorflow/stream_executor/platform.h"
namespace stream_executor {
namespace gpu {
namespace rocm {
// Opaque and unique identifier for the ROCm platform.
// This is needed so that plugins can refer to/identify this platform without
// instantiating a ROCmPlatform object.
// This is broken out here to avoid a circular dependency between ROCmPlatform
// and GpuExecutor.
// and ROCmExecutor.
extern const Platform::Id kROCmPlatformId;
} // namespace gpu
} // namespace rocm
} // namespace stream_executor
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_PLATFORM_ID_H_

View File

@ -253,7 +253,7 @@ REGISTER_MODULE_INITIALIZER(register_hiprand, {
se::port::Status status =
se::PluginRegistry::Instance()
->RegisterFactory<se::PluginRegistry::RngFactory>(
se::gpu::kROCmPlatformId, se::gpu::kGpuRandPlugin, "hipRAND",
se::rocm::kROCmPlatformId, se::gpu::kGpuRandPlugin, "hipRAND",
[](se::internal::StreamExecutorInterface* parent)
-> se::rng::RngSupport* {
se::gpu::GpuExecutor* rocm_executor =
@ -280,5 +280,5 @@ REGISTER_MODULE_INITIALIZER(register_hiprand, {
}
se::PluginRegistry::Instance()->SetDefaultFactory(
se::gpu::kROCmPlatformId, se::PluginKind::kRng, se::gpu::kGpuRandPlugin);
se::rocm::kROCmPlatformId, se::PluginKind::kRng, se::gpu::kGpuRandPlugin);
});