Qualify uses of std::string
PiperOrigin-RevId: 297212802 Change-Id: Ic65150e7ab418be034f48d45ce25ef5d19105836
This commit is contained in:
parent
9098692bf4
commit
155ce6c067
@ -18,7 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
|
|
||||||
string AllocatorStats::DebugString() const {
|
std::string AllocatorStats::DebugString() const {
|
||||||
return absl::StrFormat(
|
return absl::StrFormat(
|
||||||
"Limit: %20lld\n"
|
"Limit: %20lld\n"
|
||||||
"InUse: %20lld\n"
|
"InUse: %20lld\n"
|
||||||
|
@ -51,7 +51,7 @@ struct AllocatorStats {
|
|||||||
bytes_reserved(0),
|
bytes_reserved(0),
|
||||||
peak_bytes_reserved(0) {}
|
peak_bytes_reserved(0) {}
|
||||||
|
|
||||||
string DebugString() const;
|
std::string DebugString() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace stream_executor
|
} // namespace stream_executor
|
||||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
namespace blas {
|
namespace blas {
|
||||||
|
|
||||||
string TransposeString(Transpose t) {
|
std::string TransposeString(Transpose t) {
|
||||||
switch (t) {
|
switch (t) {
|
||||||
case Transpose::kNoTranspose:
|
case Transpose::kNoTranspose:
|
||||||
return "NoTranspose";
|
return "NoTranspose";
|
||||||
@ -33,7 +33,7 @@ string TransposeString(Transpose t) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
string UpperLowerString(UpperLower ul) {
|
std::string UpperLowerString(UpperLower ul) {
|
||||||
switch (ul) {
|
switch (ul) {
|
||||||
case UpperLower::kUpper:
|
case UpperLower::kUpper:
|
||||||
return "Upper";
|
return "Upper";
|
||||||
@ -44,7 +44,7 @@ string UpperLowerString(UpperLower ul) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
string DiagonalString(Diagonal d) {
|
std::string DiagonalString(Diagonal d) {
|
||||||
switch (d) {
|
switch (d) {
|
||||||
case Diagonal::kUnit:
|
case Diagonal::kUnit:
|
||||||
return "Unit";
|
return "Unit";
|
||||||
@ -55,7 +55,7 @@ string DiagonalString(Diagonal d) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
string SideString(Side s) {
|
std::string SideString(Side s) {
|
||||||
switch (s) {
|
switch (s) {
|
||||||
case Side::kLeft:
|
case Side::kLeft:
|
||||||
return "Left";
|
return "Left";
|
||||||
@ -68,9 +68,11 @@ string SideString(Side s) {
|
|||||||
|
|
||||||
// -- AlgorithmConfig
|
// -- AlgorithmConfig
|
||||||
|
|
||||||
string AlgorithmConfig::ToString() const { return absl::StrCat(algorithm_); }
|
std::string AlgorithmConfig::ToString() const {
|
||||||
|
return absl::StrCat(algorithm_);
|
||||||
|
}
|
||||||
|
|
||||||
string ComputationTypeString(ComputationType ty) {
|
std::string ComputationTypeString(ComputationType ty) {
|
||||||
switch (ty) {
|
switch (ty) {
|
||||||
case ComputationType::kF16:
|
case ComputationType::kF16:
|
||||||
return "f16";
|
return "f16";
|
||||||
|
@ -67,27 +67,27 @@ namespace blas {
|
|||||||
enum class Transpose { kNoTranspose, kTranspose, kConjugateTranspose };
|
enum class Transpose { kNoTranspose, kTranspose, kConjugateTranspose };
|
||||||
|
|
||||||
// Returns a name for t.
|
// Returns a name for t.
|
||||||
string TransposeString(Transpose t);
|
std::string TransposeString(Transpose t);
|
||||||
|
|
||||||
// Specifies whether the upper or lower triangular part of a
|
// Specifies whether the upper or lower triangular part of a
|
||||||
// symmetric/Hermitian matrix is used.
|
// symmetric/Hermitian matrix is used.
|
||||||
enum class UpperLower { kUpper, kLower };
|
enum class UpperLower { kUpper, kLower };
|
||||||
|
|
||||||
// Returns a name for ul.
|
// Returns a name for ul.
|
||||||
string UpperLowerString(UpperLower ul);
|
std::string UpperLowerString(UpperLower ul);
|
||||||
|
|
||||||
// Specifies whether a matrix is unit triangular.
|
// Specifies whether a matrix is unit triangular.
|
||||||
enum class Diagonal { kUnit, kNonUnit };
|
enum class Diagonal { kUnit, kNonUnit };
|
||||||
|
|
||||||
// Returns a name for d.
|
// Returns a name for d.
|
||||||
string DiagonalString(Diagonal d);
|
std::string DiagonalString(Diagonal d);
|
||||||
|
|
||||||
// Specifies whether a Hermitian matrix appears on the left or right in
|
// Specifies whether a Hermitian matrix appears on the left or right in
|
||||||
// operation.
|
// operation.
|
||||||
enum class Side { kLeft, kRight };
|
enum class Side { kLeft, kRight };
|
||||||
|
|
||||||
// Returns a name for s.
|
// Returns a name for s.
|
||||||
string SideString(Side s);
|
std::string SideString(Side s);
|
||||||
|
|
||||||
// Type with which intermediate computations of a blas routine are performed.
|
// Type with which intermediate computations of a blas routine are performed.
|
||||||
//
|
//
|
||||||
@ -104,7 +104,7 @@ enum class ComputationType {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Converts a ComputationType to a string.
|
// Converts a ComputationType to a string.
|
||||||
string ComputationTypeString(ComputationType ty);
|
std::string ComputationTypeString(ComputationType ty);
|
||||||
|
|
||||||
std::ostream &operator<<(std::ostream &os, ComputationType ty);
|
std::ostream &operator<<(std::ostream &os, ComputationType ty);
|
||||||
|
|
||||||
@ -157,7 +157,7 @@ class AlgorithmConfig {
|
|||||||
bool operator!=(const AlgorithmConfig &other) const {
|
bool operator!=(const AlgorithmConfig &other) const {
|
||||||
return !(*this == other);
|
return !(*this == other);
|
||||||
}
|
}
|
||||||
string ToString() const;
|
std::string ToString() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
AlgorithmType algorithm_;
|
AlgorithmType algorithm_;
|
||||||
@ -1383,7 +1383,7 @@ class BlasSupport {
|
|||||||
const DeviceMemory<std::complex<double>> &a, int lda,
|
const DeviceMemory<std::complex<double>> &a, int lda,
|
||||||
DeviceMemory<std::complex<double>> *b, int ldb) = 0;
|
DeviceMemory<std::complex<double>> *b, int ldb) = 0;
|
||||||
|
|
||||||
virtual port::Status GetVersion(string *version) = 0;
|
virtual port::Status GetVersion(std::string *version) = 0;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
BlasSupport() {}
|
BlasSupport() {}
|
||||||
@ -2196,7 +2196,7 @@ class BlasSupport {
|
|||||||
uint64 n, std::complex<double> alpha, \
|
uint64 n, std::complex<double> alpha, \
|
||||||
const DeviceMemory<std::complex<double>> &a, int lda, \
|
const DeviceMemory<std::complex<double>> &a, int lda, \
|
||||||
DeviceMemory<std::complex<double>> *b, int ldb) override; \
|
DeviceMemory<std::complex<double>> *b, int ldb) override; \
|
||||||
port::Status GetVersion(string *version) override;
|
port::Status GetVersion(std::string *version) override;
|
||||||
|
|
||||||
} // namespace blas
|
} // namespace blas
|
||||||
} // namespace stream_executor
|
} // namespace stream_executor
|
||||||
|
@ -55,10 +55,11 @@ DeviceDescription::DeviceDescription()
|
|||||||
core_count_(-1),
|
core_count_(-1),
|
||||||
ecc_enabled_(false) {}
|
ecc_enabled_(false) {}
|
||||||
|
|
||||||
std::unique_ptr<std::map<string, string>> DeviceDescription::ToMap() const {
|
std::unique_ptr<std::map<std::string, std::string>> DeviceDescription::ToMap()
|
||||||
std::unique_ptr<std::map<string, string>> owned_result{
|
const {
|
||||||
new std::map<string, string>};
|
std::unique_ptr<std::map<std::string, std::string>> owned_result{
|
||||||
std::map<string, string> &result = *owned_result;
|
new std::map<std::string, std::string>};
|
||||||
|
std::map<std::string, std::string> &result = *owned_result;
|
||||||
result["Device Vendor"] = device_vendor();
|
result["Device Vendor"] = device_vendor();
|
||||||
result["Platform Version"] = platform_version();
|
result["Platform Version"] = platform_version();
|
||||||
result["Driver Version"] = driver_version();
|
result["Driver Version"] = driver_version();
|
||||||
|
@ -42,22 +42,22 @@ class DeviceDescription {
|
|||||||
// Returns the platform being run on; this value is primarily intended for
|
// Returns the platform being run on; this value is primarily intended for
|
||||||
// printing, and comes out something like "OpenCL 1.2" or "Compute Capability
|
// printing, and comes out something like "OpenCL 1.2" or "Compute Capability
|
||||||
// 3.5".
|
// 3.5".
|
||||||
const string &platform_version() const { return platform_version_; }
|
const std::string &platform_version() const { return platform_version_; }
|
||||||
|
|
||||||
// Returns the driver version interfacing with the underlying platform. Vendor
|
// Returns the driver version interfacing with the underlying platform. Vendor
|
||||||
// dependent format.
|
// dependent format.
|
||||||
const string &driver_version() const { return driver_version_; }
|
const std::string &driver_version() const { return driver_version_; }
|
||||||
|
|
||||||
// Return the runtime version, if one is provided by the underlying platform.
|
// Return the runtime version, if one is provided by the underlying platform.
|
||||||
// Vendor dependent format / usefulness.
|
// Vendor dependent format / usefulness.
|
||||||
const string &runtime_version() const { return runtime_version_; }
|
const std::string &runtime_version() const { return runtime_version_; }
|
||||||
|
|
||||||
// Returns the name that the device reports. Vendor dependent.
|
// Returns the name that the device reports. Vendor dependent.
|
||||||
const string &name() const { return name_; }
|
const std::string &name() const { return name_; }
|
||||||
|
|
||||||
// Returns the PCI bus identifier for this device, of the form
|
// Returns the PCI bus identifier for this device, of the form
|
||||||
// [domain]:[bus]:[device].[function]
|
// [domain]:[bus]:[device].[function]
|
||||||
const string &pci_bus_id() const { return pci_bus_id_; }
|
const std::string &pci_bus_id() const { return pci_bus_id_; }
|
||||||
|
|
||||||
// Returns the NUMA node associated with this device, for use in
|
// Returns the NUMA node associated with this device, for use in
|
||||||
// determining socket locality. If the NUMA node could not be determined, -1
|
// determining socket locality. If the NUMA node could not be determined, -1
|
||||||
@ -126,7 +126,7 @@ class DeviceDescription {
|
|||||||
|
|
||||||
// Returns the device vendor string, e.g., "NVIDIA Corporation", "Advanced
|
// Returns the device vendor string, e.g., "NVIDIA Corporation", "Advanced
|
||||||
// Micro Devices, Inc.", or "GenuineIntel".
|
// Micro Devices, Inc.", or "GenuineIntel".
|
||||||
const string &device_vendor() const { return device_vendor_; }
|
const std::string &device_vendor() const { return device_vendor_; }
|
||||||
|
|
||||||
// Returns the CUDA compute capability if we're running on the CUDA platform.
|
// Returns the CUDA compute capability if we're running on the CUDA platform.
|
||||||
// If a CUDA compute capability is not available, the major version will be
|
// If a CUDA compute capability is not available, the major version will be
|
||||||
@ -150,7 +150,7 @@ class DeviceDescription {
|
|||||||
// TODO(leary): resident blocks per core will be useful.
|
// TODO(leary): resident blocks per core will be useful.
|
||||||
|
|
||||||
// Convenience typedef for the string-based DeviceDescription mapping.
|
// Convenience typedef for the string-based DeviceDescription mapping.
|
||||||
typedef std::map<string, string> Map;
|
typedef std::map<std::string, std::string> Map;
|
||||||
|
|
||||||
// Returns a mapping from readable names to readable values that describe the
|
// Returns a mapping from readable names to readable values that describe the
|
||||||
// device. This is useful for things like printing.
|
// device. This is useful for things like printing.
|
||||||
@ -169,12 +169,12 @@ class DeviceDescription {
|
|||||||
// above.
|
// above.
|
||||||
//
|
//
|
||||||
// N.B. If another field is added, update ToMap() above.
|
// N.B. If another field is added, update ToMap() above.
|
||||||
string device_vendor_;
|
std::string device_vendor_;
|
||||||
string platform_version_;
|
std::string platform_version_;
|
||||||
string driver_version_;
|
std::string driver_version_;
|
||||||
string runtime_version_;
|
std::string runtime_version_;
|
||||||
string pci_bus_id_;
|
std::string pci_bus_id_;
|
||||||
string name_;
|
std::string name_;
|
||||||
|
|
||||||
ThreadDim thread_dim_limit_;
|
ThreadDim thread_dim_limit_;
|
||||||
BlockDim block_dim_limit_;
|
BlockDim block_dim_limit_;
|
||||||
@ -221,22 +221,24 @@ class DeviceDescriptionBuilder {
|
|||||||
// For descriptions of the following fields, see comments on the corresponding
|
// For descriptions of the following fields, see comments on the corresponding
|
||||||
// DeviceDescription::* accessors above.
|
// DeviceDescription::* accessors above.
|
||||||
|
|
||||||
void set_device_vendor(const string &value) {
|
void set_device_vendor(const std::string &value) {
|
||||||
device_description_->device_vendor_ = value;
|
device_description_->device_vendor_ = value;
|
||||||
}
|
}
|
||||||
void set_platform_version(const string &value) {
|
void set_platform_version(const std::string &value) {
|
||||||
device_description_->platform_version_ = value;
|
device_description_->platform_version_ = value;
|
||||||
}
|
}
|
||||||
void set_driver_version(const string &value) {
|
void set_driver_version(const std::string &value) {
|
||||||
device_description_->driver_version_ = value;
|
device_description_->driver_version_ = value;
|
||||||
}
|
}
|
||||||
void set_runtime_version(const string &value) {
|
void set_runtime_version(const std::string &value) {
|
||||||
device_description_->runtime_version_ = value;
|
device_description_->runtime_version_ = value;
|
||||||
}
|
}
|
||||||
void set_pci_bus_id(const string &value) {
|
void set_pci_bus_id(const std::string &value) {
|
||||||
device_description_->pci_bus_id_ = value;
|
device_description_->pci_bus_id_ = value;
|
||||||
}
|
}
|
||||||
void set_name(const string &value) { device_description_->name_ = value; }
|
void set_name(const std::string &value) {
|
||||||
|
device_description_->name_ = value;
|
||||||
|
}
|
||||||
|
|
||||||
void set_thread_dim_limit(const ThreadDim &value) {
|
void set_thread_dim_limit(const ThreadDim &value) {
|
||||||
device_description_->thread_dim_limit_ = value;
|
device_description_->thread_dim_limit_ = value;
|
||||||
|
@ -71,13 +71,13 @@ struct DeviceOptions {
|
|||||||
return !(*this == other);
|
return !(*this == other);
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToString() {
|
std::string ToString() {
|
||||||
return flags_ == 0 ? "none" : "kDoNotReclaimStackAllocation";
|
return flags_ == 0 ? "none" : "kDoNotReclaimStackAllocation";
|
||||||
}
|
}
|
||||||
|
|
||||||
// Platform-specific device options. Expressed as key-value pairs to avoid
|
// Platform-specific device options. Expressed as key-value pairs to avoid
|
||||||
// DeviceOptions subclass proliferation.
|
// DeviceOptions subclass proliferation.
|
||||||
std::map<string, string> non_portable_tags;
|
std::map<std::string, std::string> non_portable_tags;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
unsigned flags_;
|
unsigned flags_;
|
||||||
|
@ -33,7 +33,7 @@ uint64 AlgorithmDesc::hash() const {
|
|||||||
return absl::Hash<decltype(p)>()(p);
|
return absl::Hash<decltype(p)>()(p);
|
||||||
}
|
}
|
||||||
|
|
||||||
string AlgorithmDesc::ToString() const {
|
std::string AlgorithmDesc::ToString() const {
|
||||||
if (tensor_ops_enabled()) {
|
if (tensor_ops_enabled()) {
|
||||||
return absl::StrCat(algo_id(), "#TC");
|
return absl::StrCat(algo_id(), "#TC");
|
||||||
} else {
|
} else {
|
||||||
@ -74,7 +74,7 @@ bool DnnSupport::GetConvolveBackwardFilterAlgorithms(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
string QuantizedActivationModeString(QuantizedActivationMode mode) {
|
std::string QuantizedActivationModeString(QuantizedActivationMode mode) {
|
||||||
switch (mode) {
|
switch (mode) {
|
||||||
case dnn::QuantizedActivationMode::k8Bit:
|
case dnn::QuantizedActivationMode::k8Bit:
|
||||||
return "uint8";
|
return "uint8";
|
||||||
@ -89,7 +89,7 @@ string QuantizedActivationModeString(QuantizedActivationMode mode) {
|
|||||||
return "unknown quantized_activation_mode";
|
return "unknown quantized_activation_mode";
|
||||||
}
|
}
|
||||||
|
|
||||||
string ActivationModeString(ActivationMode mode) {
|
std::string ActivationModeString(ActivationMode mode) {
|
||||||
switch (mode) {
|
switch (mode) {
|
||||||
case ActivationMode::kSigmoid:
|
case ActivationMode::kSigmoid:
|
||||||
return "sigmoid";
|
return "sigmoid";
|
||||||
@ -109,7 +109,7 @@ string ActivationModeString(ActivationMode mode) {
|
|||||||
return "unknown activation_mode";
|
return "unknown activation_mode";
|
||||||
}
|
}
|
||||||
|
|
||||||
string ElementwiseOperationString(ElementwiseOperation op) {
|
std::string ElementwiseOperationString(ElementwiseOperation op) {
|
||||||
switch (op) {
|
switch (op) {
|
||||||
case ElementwiseOperation::kAdd:
|
case ElementwiseOperation::kAdd:
|
||||||
return "add";
|
return "add";
|
||||||
@ -121,7 +121,7 @@ string ElementwiseOperationString(ElementwiseOperation op) {
|
|||||||
return "unknown element wise op";
|
return "unknown element wise op";
|
||||||
}
|
}
|
||||||
|
|
||||||
string DataLayoutString(DataLayout layout) {
|
std::string DataLayoutString(DataLayout layout) {
|
||||||
switch (layout) {
|
switch (layout) {
|
||||||
case DataLayout::kYXDepthBatch:
|
case DataLayout::kYXDepthBatch:
|
||||||
return "YXDepthBatch";
|
return "YXDepthBatch";
|
||||||
@ -139,7 +139,7 @@ string DataLayoutString(DataLayout layout) {
|
|||||||
return "unknown data layout";
|
return "unknown data layout";
|
||||||
}
|
}
|
||||||
|
|
||||||
string FilterLayoutString(FilterLayout layout) {
|
std::string FilterLayoutString(FilterLayout layout) {
|
||||||
switch (layout) {
|
switch (layout) {
|
||||||
case FilterLayout::kOutputInputYX:
|
case FilterLayout::kOutputInputYX:
|
||||||
return "OutputInputYX";
|
return "OutputInputYX";
|
||||||
@ -157,7 +157,7 @@ string FilterLayoutString(FilterLayout layout) {
|
|||||||
return "unknown filter layout";
|
return "unknown filter layout";
|
||||||
}
|
}
|
||||||
|
|
||||||
string PadAlignmentString(PadAlignment alignment) {
|
std::string PadAlignmentString(PadAlignment alignment) {
|
||||||
switch (alignment) {
|
switch (alignment) {
|
||||||
case PadAlignment::kDefault:
|
case PadAlignment::kDefault:
|
||||||
return "default";
|
return "default";
|
||||||
@ -173,7 +173,7 @@ std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment) {
|
|||||||
return str << PadAlignmentString(alignment);
|
return str << PadAlignmentString(alignment);
|
||||||
}
|
}
|
||||||
|
|
||||||
string ShortPoolingModeString(PoolingMode mode) {
|
std::string ShortPoolingModeString(PoolingMode mode) {
|
||||||
switch (mode) {
|
switch (mode) {
|
||||||
case PoolingMode::kMaximum:
|
case PoolingMode::kMaximum:
|
||||||
return "Max";
|
return "Max";
|
||||||
@ -247,12 +247,12 @@ std::vector<int64> ReorderDims(const std::vector<int64>& input,
|
|||||||
|
|
||||||
// -- AlgorithmConfig
|
// -- AlgorithmConfig
|
||||||
|
|
||||||
string AlgorithmConfig::ToString() const {
|
std::string AlgorithmConfig::ToString() const {
|
||||||
string algo = "none";
|
std::string algo = "none";
|
||||||
if (algorithm().has_value()) {
|
if (algorithm().has_value()) {
|
||||||
algo = algorithm()->ToString();
|
algo = algorithm()->ToString();
|
||||||
}
|
}
|
||||||
string algo_no_scratch = "none";
|
std::string algo_no_scratch = "none";
|
||||||
if (algorithm_no_scratch().has_value()) {
|
if (algorithm_no_scratch().has_value()) {
|
||||||
algo_no_scratch = algorithm_no_scratch()->ToString();
|
algo_no_scratch = algorithm_no_scratch()->ToString();
|
||||||
}
|
}
|
||||||
@ -306,8 +306,8 @@ void BatchDescriptor::CloneFrom(const BatchDescriptor& other) {
|
|||||||
quantized_activation_mode_ = other.quantized_activation_mode_;
|
quantized_activation_mode_ = other.quantized_activation_mode_;
|
||||||
}
|
}
|
||||||
|
|
||||||
string BatchDescriptor::ToString() const {
|
std::string BatchDescriptor::ToString() const {
|
||||||
string spatial;
|
std::string spatial;
|
||||||
for (int i = 0; i < ndims(); i++) {
|
for (int i = 0; i < ndims(); i++) {
|
||||||
absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]);
|
absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]);
|
||||||
}
|
}
|
||||||
@ -318,19 +318,19 @@ string BatchDescriptor::ToString() const {
|
|||||||
DataLayoutString(layout()));
|
DataLayoutString(layout()));
|
||||||
}
|
}
|
||||||
|
|
||||||
string BatchDescriptor::ToShortString() const {
|
std::string BatchDescriptor::ToShortString() const {
|
||||||
// All the constituent strings are less than 15 characters, so the
|
// All the constituent strings are less than 15 characters, so the
|
||||||
// small string optimization ensures that there will be at most one
|
// small string optimization ensures that there will be at most one
|
||||||
// heap memory allocation.
|
// heap memory allocation.
|
||||||
string depth = absl::StrCat("d", feature_map_count());
|
std::string depth = absl::StrCat("d", feature_map_count());
|
||||||
string batch = absl::StrCat("b", count());
|
std::string batch = absl::StrCat("b", count());
|
||||||
|
|
||||||
string spatial = "s";
|
std::string spatial = "s";
|
||||||
for (int i = 0; i < ndims(); i++) {
|
for (int i = 0; i < ndims(); i++) {
|
||||||
absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]);
|
absl::StrAppendFormat(&spatial, "%d ", spatial_size()[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
string suffix;
|
std::string suffix;
|
||||||
if (value_min() != value_max()) {
|
if (value_min() != value_max()) {
|
||||||
absl::StrAppend(&suffix, "[", value_min(), ";", value_max(), "]");
|
absl::StrAppend(&suffix, "[", value_min(), ";", value_max(), "]");
|
||||||
}
|
}
|
||||||
@ -419,8 +419,8 @@ void FilterDescriptor::CloneFrom(const FilterDescriptor& other) {
|
|||||||
tensor_ = other.tensor_;
|
tensor_ = other.tensor_;
|
||||||
}
|
}
|
||||||
|
|
||||||
string FilterDescriptor::ToString() const {
|
std::string FilterDescriptor::ToString() const {
|
||||||
string desc = absl::StrFormat(
|
std::string desc = absl::StrFormat(
|
||||||
"{output_feature_map_count: %d input_feature_map_count: %d "
|
"{output_feature_map_count: %d input_feature_map_count: %d "
|
||||||
"layout: %s shape: ",
|
"layout: %s shape: ",
|
||||||
output_feature_map_count(), input_feature_map_count(),
|
output_feature_map_count(), input_feature_map_count(),
|
||||||
@ -433,14 +433,14 @@ string FilterDescriptor::ToString() const {
|
|||||||
return desc;
|
return desc;
|
||||||
}
|
}
|
||||||
|
|
||||||
string FilterDescriptor::ToShortString() const {
|
std::string FilterDescriptor::ToShortString() const {
|
||||||
// All the constituent strings are less than 15 characters, so the
|
// All the constituent strings are less than 15 characters, so the
|
||||||
// small string optimization ensures that there will be at most one
|
// small string optimization ensures that there will be at most one
|
||||||
// heap memory allocation.
|
// heap memory allocation.
|
||||||
string od = absl::StrCat("od", output_feature_map_count());
|
std::string od = absl::StrCat("od", output_feature_map_count());
|
||||||
string id = absl::StrCat("id", input_feature_map_count());
|
std::string id = absl::StrCat("id", input_feature_map_count());
|
||||||
|
|
||||||
string spatial = "s";
|
std::string spatial = "s";
|
||||||
for (int i = 0; i < ndims(); i++) {
|
for (int i = 0; i < ndims(); i++) {
|
||||||
absl::StrAppendFormat(&spatial, "%d ", input_filter_dims()[i]);
|
absl::StrAppendFormat(&spatial, "%d ", input_filter_dims()[i]);
|
||||||
}
|
}
|
||||||
@ -491,10 +491,10 @@ ConvolutionDescriptor::ConvolutionDescriptor()
|
|||||||
|
|
||||||
ConvolutionDescriptor::~ConvolutionDescriptor() {}
|
ConvolutionDescriptor::~ConvolutionDescriptor() {}
|
||||||
|
|
||||||
string ConvolutionDescriptor::ToString() const {
|
std::string ConvolutionDescriptor::ToString() const {
|
||||||
string padding;
|
std::string padding;
|
||||||
string strides;
|
std::string strides;
|
||||||
string dilations;
|
std::string dilations;
|
||||||
for (int i = 0; i < ndims(); i++) {
|
for (int i = 0; i < ndims(); i++) {
|
||||||
absl::StrAppendFormat(&padding, "%d ", this->padding()[i]);
|
absl::StrAppendFormat(&padding, "%d ", this->padding()[i]);
|
||||||
absl::StrAppendFormat(&strides, "%d ", this->strides()[i]);
|
absl::StrAppendFormat(&strides, "%d ", this->strides()[i]);
|
||||||
@ -507,8 +507,8 @@ string ConvolutionDescriptor::ToString() const {
|
|||||||
padding, PadAlignmentString(pad_alignment()), strides, dilations);
|
padding, PadAlignmentString(pad_alignment()), strides, dilations);
|
||||||
}
|
}
|
||||||
|
|
||||||
string ConvolutionDescriptor::ToShortString() const {
|
std::string ConvolutionDescriptor::ToShortString() const {
|
||||||
string desc;
|
std::string desc;
|
||||||
for (int i = 0; i < ndims(); i++) {
|
for (int i = 0; i < ndims(); i++) {
|
||||||
if (i > 0) absl::StrAppend(&desc, "_");
|
if (i > 0) absl::StrAppend(&desc, "_");
|
||||||
absl::StrAppendFormat(&desc, "p%d:%d", i, padding()[i]);
|
absl::StrAppendFormat(&desc, "p%d:%d", i, padding()[i]);
|
||||||
@ -543,11 +543,11 @@ void PoolingDescriptor::CloneFrom(const PoolingDescriptor& other) {
|
|||||||
propagate_nans_ = other.propagate_nans_;
|
propagate_nans_ = other.propagate_nans_;
|
||||||
}
|
}
|
||||||
|
|
||||||
string PoolingDescriptor::ToString() const {
|
std::string PoolingDescriptor::ToString() const {
|
||||||
const char* mode_string =
|
const char* mode_string =
|
||||||
mode_ == dnn::PoolingMode::kMaximum ? "kMaximum" : "kAverage";
|
mode_ == dnn::PoolingMode::kMaximum ? "kMaximum" : "kAverage";
|
||||||
|
|
||||||
string window, strides, padding;
|
std::string window, strides, padding;
|
||||||
for (int i = 0; i < ndims_; i++) {
|
for (int i = 0; i < ndims_; i++) {
|
||||||
absl::StrAppendFormat(&window, "%d ", window_[i]);
|
absl::StrAppendFormat(&window, "%d ", window_[i]);
|
||||||
absl::StrAppendFormat(&strides, "%d ", strides_[i]);
|
absl::StrAppendFormat(&strides, "%d ", strides_[i]);
|
||||||
@ -561,8 +561,8 @@ string PoolingDescriptor::ToString() const {
|
|||||||
mode_string, window, strides, padding, propagate_string);
|
mode_string, window, strides, padding, propagate_string);
|
||||||
}
|
}
|
||||||
|
|
||||||
string PoolingDescriptor::ToShortString() const {
|
std::string PoolingDescriptor::ToShortString() const {
|
||||||
string window, strides, padding;
|
std::string window, strides, padding;
|
||||||
for (int i = 0; i < ndims_; i++) {
|
for (int i = 0; i < ndims_; i++) {
|
||||||
absl::StrAppendFormat(&window, "_w%d:%d", i, window_[i]);
|
absl::StrAppendFormat(&window, "_w%d:%d", i, window_[i]);
|
||||||
absl::StrAppendFormat(&strides, "_s%d:%d", i, strides_[i]);
|
absl::StrAppendFormat(&strides, "_s%d:%d", i, strides_[i]);
|
||||||
@ -592,14 +592,14 @@ void NormalizeDescriptor::CloneFrom(const NormalizeDescriptor& other) {
|
|||||||
segment_size_ = other.segment_size_;
|
segment_size_ = other.segment_size_;
|
||||||
}
|
}
|
||||||
|
|
||||||
string NormalizeDescriptor::ToString() const {
|
std::string NormalizeDescriptor::ToString() const {
|
||||||
return absl::StrFormat(
|
return absl::StrFormat(
|
||||||
"{bias: %f range: %d alpha: %f beta: %f wrap_around: %d "
|
"{bias: %f range: %d alpha: %f beta: %f wrap_around: %d "
|
||||||
"segment_size: %d}",
|
"segment_size: %d}",
|
||||||
bias_, range_, alpha_, beta_, wrap_around_, segment_size_);
|
bias_, range_, alpha_, beta_, wrap_around_, segment_size_);
|
||||||
}
|
}
|
||||||
|
|
||||||
string NormalizeDescriptor::ToShortString() const {
|
std::string NormalizeDescriptor::ToShortString() const {
|
||||||
return absl::StrCat("bias:", bias_, "_range:", range_, "_alpha:", alpha_,
|
return absl::StrCat("bias:", bias_, "_range:", range_, "_alpha:", alpha_,
|
||||||
"_beta:", beta_, "_wrap:", wrap_around_,
|
"_beta:", beta_, "_wrap:", wrap_around_,
|
||||||
"_size:", segment_size_);
|
"_size:", segment_size_);
|
||||||
|
@ -101,7 +101,7 @@ inline absl::Span<int64> AsInt64Slice(T* repeated_field) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Returns a string representation of the given data layout.
|
// Returns a string representation of the given data layout.
|
||||||
string DataLayoutString(DataLayout layout);
|
std::string DataLayoutString(DataLayout layout);
|
||||||
|
|
||||||
// Specifies a quantization for activations in a given BatchDescriptor.
|
// Specifies a quantization for activations in a given BatchDescriptor.
|
||||||
enum class QuantizedActivationMode {
|
enum class QuantizedActivationMode {
|
||||||
@ -209,7 +209,7 @@ class RnnStateTensorDescriptor {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Returns a string representation of the given quantization mode.
|
// Returns a string representation of the given quantization mode.
|
||||||
string QuantizedActivationModeString(QuantizedActivationMode mode);
|
std::string QuantizedActivationModeString(QuantizedActivationMode mode);
|
||||||
|
|
||||||
// Describes the dimensions that a layer consumes/produces.
|
// Describes the dimensions that a layer consumes/produces.
|
||||||
//
|
//
|
||||||
@ -260,8 +260,8 @@ class BatchDescriptor {
|
|||||||
// Clones values from 'other' for initialization.
|
// Clones values from 'other' for initialization.
|
||||||
void CloneFrom(const BatchDescriptor& other);
|
void CloneFrom(const BatchDescriptor& other);
|
||||||
|
|
||||||
string ToString() const;
|
std::string ToString() const;
|
||||||
string ToShortString() const;
|
std::string ToShortString() const;
|
||||||
|
|
||||||
// Pre-condition:
|
// Pre-condition:
|
||||||
// value_max_ == 0
|
// value_max_ == 0
|
||||||
@ -374,7 +374,7 @@ class BatchDescriptor {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Returns a string representation of the given filter layout.
|
// Returns a string representation of the given filter layout.
|
||||||
string FilterLayoutString(FilterLayout layout);
|
std::string FilterLayoutString(FilterLayout layout);
|
||||||
|
|
||||||
// Describes a filter for the convolution. This is the "window" from
|
// Describes a filter for the convolution. This is the "window" from
|
||||||
// height-by-width patches of each of the feature maps in the input layer to the
|
// height-by-width patches of each of the feature maps in the input layer to the
|
||||||
@ -439,8 +439,8 @@ class FilterDescriptor {
|
|||||||
|
|
||||||
void CloneFrom(const FilterDescriptor& other);
|
void CloneFrom(const FilterDescriptor& other);
|
||||||
|
|
||||||
string ToString() const;
|
std::string ToString() const;
|
||||||
string ToShortString() const;
|
std::string ToShortString() const;
|
||||||
TensorDescriptorProto ToProto(DataType data_type) const;
|
TensorDescriptorProto ToProto(DataType data_type) const;
|
||||||
|
|
||||||
// Returns the number of weights required as parameters for a convolution
|
// Returns the number of weights required as parameters for a convolution
|
||||||
@ -486,7 +486,7 @@ enum class PadAlignment : int64 {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Returns a string representation of the given padding alignment.
|
// Returns a string representation of the given padding alignment.
|
||||||
string PadAlignmentString(PadAlignment alignment);
|
std::string PadAlignmentString(PadAlignment alignment);
|
||||||
|
|
||||||
// Print alignment to str. Needed to use CHECK_EQ between two PadAlignments.
|
// Print alignment to str. Needed to use CHECK_EQ between two PadAlignments.
|
||||||
std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment);
|
std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment);
|
||||||
@ -529,8 +529,8 @@ class ConvolutionDescriptor {
|
|||||||
explicit ConvolutionDescriptor(int ndims);
|
explicit ConvolutionDescriptor(int ndims);
|
||||||
~ConvolutionDescriptor();
|
~ConvolutionDescriptor();
|
||||||
|
|
||||||
string ToString() const;
|
std::string ToString() const;
|
||||||
string ToShortString() const;
|
std::string ToShortString() const;
|
||||||
ConvolutionDescriptorProto ToProto() const { return proto_; }
|
ConvolutionDescriptorProto ToProto() const { return proto_; }
|
||||||
|
|
||||||
ConvolutionDescriptor& set_zero_padding_height(int64 value) {
|
ConvolutionDescriptor& set_zero_padding_height(int64 value) {
|
||||||
@ -578,7 +578,7 @@ class ConvolutionDescriptor {
|
|||||||
: ConvolutionMode::CROSS_CORRELATION);
|
: ConvolutionMode::CROSS_CORRELATION);
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
ConvolutionDescriptor& set_name(const string& name) {
|
ConvolutionDescriptor& set_name(const std::string& name) {
|
||||||
proto_.set_name(name);
|
proto_.set_name(name);
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@ -621,7 +621,7 @@ class ConvolutionDescriptor {
|
|||||||
return AsInt64Slice(proto_.paddings());
|
return AsInt64Slice(proto_.paddings());
|
||||||
}
|
}
|
||||||
|
|
||||||
string name() const { return proto_.name(); }
|
std::string name() const { return proto_.name(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
absl::Span<int64> strides() { return AsInt64Slice(proto_.mutable_strides()); }
|
absl::Span<int64> strides() { return AsInt64Slice(proto_.mutable_strides()); }
|
||||||
@ -658,7 +658,7 @@ enum class SpaceConcatenateMode : int64 {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Returns a short name for the pooling mode, e.g. "Avg".
|
// Returns a short name for the pooling mode, e.g. "Avg".
|
||||||
string ShortPoolingModeString(PoolingMode mode);
|
std::string ShortPoolingModeString(PoolingMode mode);
|
||||||
|
|
||||||
// Describes a pooling operation to be enqueued onto a stream via a platform's
|
// Describes a pooling operation to be enqueued onto a stream via a platform's
|
||||||
// DnnSupport.
|
// DnnSupport.
|
||||||
@ -722,7 +722,7 @@ class PoolingDescriptor {
|
|||||||
propagate_nans_ = value;
|
propagate_nans_ = value;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
PoolingDescriptor& set_name(const string& name) {
|
PoolingDescriptor& set_name(const std::string& name) {
|
||||||
name_ = name;
|
name_ = name;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
@ -730,8 +730,8 @@ class PoolingDescriptor {
|
|||||||
int ndims() const { return ndims_; }
|
int ndims() const { return ndims_; }
|
||||||
void CloneFrom(const PoolingDescriptor& other);
|
void CloneFrom(const PoolingDescriptor& other);
|
||||||
|
|
||||||
string ToString() const;
|
std::string ToString() const;
|
||||||
string ToShortString() const;
|
std::string ToShortString() const;
|
||||||
|
|
||||||
PoolingMode mode() const { return mode_; }
|
PoolingMode mode() const { return mode_; }
|
||||||
int64 window_height() const { return GetDim(window_, DimIndex::Y); }
|
int64 window_height() const { return GetDim(window_, DimIndex::Y); }
|
||||||
@ -747,13 +747,13 @@ class PoolingDescriptor {
|
|||||||
absl::Span<const int64> padding() const { return padding_; }
|
absl::Span<const int64> padding() const { return padding_; }
|
||||||
absl::Span<const int64> strides() const { return strides_; }
|
absl::Span<const int64> strides() const { return strides_; }
|
||||||
bool propagate_nans() const { return propagate_nans_; }
|
bool propagate_nans() const { return propagate_nans_; }
|
||||||
string name() const { return name_; }
|
std::string name() const { return name_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
PoolingMode mode_;
|
PoolingMode mode_;
|
||||||
int ndims_;
|
int ndims_;
|
||||||
bool propagate_nans_;
|
bool propagate_nans_;
|
||||||
string name_; // Name as in Tensorflow NodeDef, for debugging purposes.
|
std::string name_; // Name as in Tensorflow NodeDef, for debugging purposes.
|
||||||
|
|
||||||
// Stored as: ..., y, x.
|
// Stored as: ..., y, x.
|
||||||
std::vector<int64> window_;
|
std::vector<int64> window_;
|
||||||
@ -783,7 +783,7 @@ class AlgorithmDesc {
|
|||||||
|
|
||||||
AlgorithmProto ToProto() const { return proto_; }
|
AlgorithmProto ToProto() const { return proto_; }
|
||||||
|
|
||||||
string ToString() const;
|
std::string ToString() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
AlgorithmProto proto_;
|
AlgorithmProto proto_;
|
||||||
@ -860,7 +860,7 @@ class AlgorithmConfig {
|
|||||||
bool operator!=(const AlgorithmConfig& other) const {
|
bool operator!=(const AlgorithmConfig& other) const {
|
||||||
return !(*this == other);
|
return !(*this == other);
|
||||||
}
|
}
|
||||||
string ToString() const;
|
std::string ToString() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
absl::optional<AlgorithmDesc> algorithm_;
|
absl::optional<AlgorithmDesc> algorithm_;
|
||||||
@ -927,8 +927,8 @@ class NormalizeDescriptor {
|
|||||||
|
|
||||||
void CloneFrom(const NormalizeDescriptor& other);
|
void CloneFrom(const NormalizeDescriptor& other);
|
||||||
|
|
||||||
string ToString() const;
|
std::string ToString() const;
|
||||||
string ToShortString() const;
|
std::string ToShortString() const;
|
||||||
|
|
||||||
float bias() const { return bias_; }
|
float bias() const { return bias_; }
|
||||||
int32 range() const { return range_; }
|
int32 range() const { return range_; }
|
||||||
@ -947,13 +947,13 @@ class NormalizeDescriptor {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Returns a string representation of the given activation mode.
|
// Returns a string representation of the given activation mode.
|
||||||
string ActivationModeString(ActivationMode mode);
|
std::string ActivationModeString(ActivationMode mode);
|
||||||
|
|
||||||
// Describes the operation that DoElementwiseOperation should perform on its
|
// Describes the operation that DoElementwiseOperation should perform on its
|
||||||
// inputs.
|
// inputs.
|
||||||
enum class ElementwiseOperation { kAdd, kMultiply };
|
enum class ElementwiseOperation { kAdd, kMultiply };
|
||||||
|
|
||||||
string ElementwiseOperationString(ElementwiseOperation op);
|
std::string ElementwiseOperationString(ElementwiseOperation op);
|
||||||
|
|
||||||
// A simple class representing the version of the backing library, to
|
// A simple class representing the version of the backing library, to
|
||||||
// workaround the "too perfect forwarding" issue in gcc6+ compilers.
|
// workaround the "too perfect forwarding" issue in gcc6+ compilers.
|
||||||
|
@ -91,7 +91,7 @@ KernelCacheConfig KernelBase::GetPreferredCacheConfig() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void KernelBase::set_name(absl::string_view name) {
|
void KernelBase::set_name(absl::string_view name) {
|
||||||
name_ = string(name);
|
name_ = std::string(name);
|
||||||
|
|
||||||
// CUDA splitter prefixes stub functions with __device_stub_.
|
// CUDA splitter prefixes stub functions with __device_stub_.
|
||||||
demangled_name_ =
|
demangled_name_ =
|
||||||
|
@ -178,8 +178,8 @@ class KernelBase {
|
|||||||
KernelCacheConfig GetPreferredCacheConfig() const;
|
KernelCacheConfig GetPreferredCacheConfig() const;
|
||||||
|
|
||||||
void set_name(absl::string_view name);
|
void set_name(absl::string_view name);
|
||||||
const string &name() const { return name_; }
|
const std::string &name() const { return name_; }
|
||||||
const string &demangled_name() const { return demangled_name_; }
|
const std::string &demangled_name() const { return demangled_name_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// The StreamExecutor that loads this kernel object.
|
// The StreamExecutor that loads this kernel object.
|
||||||
@ -188,8 +188,8 @@ class KernelBase {
|
|||||||
// Implementation delegated to for platform-specific functionality.
|
// Implementation delegated to for platform-specific functionality.
|
||||||
std::unique_ptr<internal::KernelInterface> implementation_;
|
std::unique_ptr<internal::KernelInterface> implementation_;
|
||||||
|
|
||||||
string name_;
|
std::string name_;
|
||||||
string demangled_name_;
|
std::string demangled_name_;
|
||||||
|
|
||||||
KernelMetadata metadata_;
|
KernelMetadata metadata_;
|
||||||
|
|
||||||
|
@ -19,11 +19,11 @@ limitations under the License.
|
|||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
|
|
||||||
KernelLoaderSpec::KernelLoaderSpec(absl::string_view kernelname)
|
KernelLoaderSpec::KernelLoaderSpec(absl::string_view kernelname)
|
||||||
: kernelname_(string(kernelname)) {}
|
: kernelname_(std::string(kernelname)) {}
|
||||||
|
|
||||||
OnDiskKernelLoaderSpec::OnDiskKernelLoaderSpec(absl::string_view filename,
|
OnDiskKernelLoaderSpec::OnDiskKernelLoaderSpec(absl::string_view filename,
|
||||||
absl::string_view kernelname)
|
absl::string_view kernelname)
|
||||||
: KernelLoaderSpec(kernelname), filename_(string(filename)) {}
|
: KernelLoaderSpec(kernelname), filename_(std::string(filename)) {}
|
||||||
|
|
||||||
CudaPtxOnDisk::CudaPtxOnDisk(absl::string_view filename,
|
CudaPtxOnDisk::CudaPtxOnDisk(absl::string_view filename,
|
||||||
absl::string_view kernelname)
|
absl::string_view kernelname)
|
||||||
@ -77,13 +77,13 @@ CudaPtxInMemory::CudaPtxInMemory(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
string CudaPtxInMemory::DecompressPtx(const char *ptx) {
|
std::string CudaPtxInMemory::DecompressPtx(const char *ptx) {
|
||||||
// Get the length of the PTX string from the beginning of the buffer.
|
// Get the length of the PTX string from the beginning of the buffer.
|
||||||
uint64 ptx_length = *reinterpret_cast<const uint64 *>(ptx);
|
uint64 ptx_length = *reinterpret_cast<const uint64 *>(ptx);
|
||||||
// Get the PTX string from the buffer with offset and length.
|
// Get the PTX string from the buffer with offset and length.
|
||||||
string compressed_ptx(ptx + sizeof(uint64),
|
std::string compressed_ptx(ptx + sizeof(uint64),
|
||||||
ptx + sizeof(uint64) + ptx_length);
|
ptx + sizeof(uint64) + ptx_length);
|
||||||
string decompressed_ptx;
|
std::string decompressed_ptx;
|
||||||
// Decompress the PTX string with bzip2.
|
// Decompress the PTX string with bzip2.
|
||||||
LOG(FATAL) << "bzip2 decompression is not supported yet.";
|
LOG(FATAL) << "bzip2 decompression is not supported yet.";
|
||||||
return decompressed_ptx;
|
return decompressed_ptx;
|
||||||
|
@ -73,7 +73,7 @@ class KernelLoaderSpec {
|
|||||||
virtual ~KernelLoaderSpec() {}
|
virtual ~KernelLoaderSpec() {}
|
||||||
|
|
||||||
// Returns the kernel name to load out of the program.
|
// Returns the kernel name to load out of the program.
|
||||||
const string &kernelname() const { return kernelname_; }
|
const std::string &kernelname() const { return kernelname_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
explicit KernelLoaderSpec(absl::string_view kernelname);
|
explicit KernelLoaderSpec(absl::string_view kernelname);
|
||||||
@ -81,7 +81,7 @@ class KernelLoaderSpec {
|
|||||||
private:
|
private:
|
||||||
// The kernel name that should be loaded out of the program description given
|
// The kernel name that should be loaded out of the program description given
|
||||||
// above.
|
// above.
|
||||||
string kernelname_;
|
std::string kernelname_;
|
||||||
|
|
||||||
SE_DISALLOW_COPY_AND_ASSIGN(KernelLoaderSpec);
|
SE_DISALLOW_COPY_AND_ASSIGN(KernelLoaderSpec);
|
||||||
};
|
};
|
||||||
@ -94,7 +94,7 @@ class OnDiskKernelLoaderSpec : public KernelLoaderSpec {
|
|||||||
~OnDiskKernelLoaderSpec() override {}
|
~OnDiskKernelLoaderSpec() override {}
|
||||||
|
|
||||||
// Returns the path to the on-disk loadable kernel file.
|
// Returns the path to the on-disk loadable kernel file.
|
||||||
const string &filename() const { return filename_; }
|
const std::string &filename() const { return filename_; }
|
||||||
|
|
||||||
// Returns the canonical suffix for this on-disk kernel loader spec format;
|
// Returns the canonical suffix for this on-disk kernel loader spec format;
|
||||||
// e.g. PTX files on disk have a canonical suffix of ".ptx".
|
// e.g. PTX files on disk have a canonical suffix of ".ptx".
|
||||||
@ -104,7 +104,7 @@ class OnDiskKernelLoaderSpec : public KernelLoaderSpec {
|
|||||||
OnDiskKernelLoaderSpec(absl::string_view filename,
|
OnDiskKernelLoaderSpec(absl::string_view filename,
|
||||||
absl::string_view kernelname);
|
absl::string_view kernelname);
|
||||||
|
|
||||||
string filename_;
|
std::string filename_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
SE_DISALLOW_COPY_AND_ASSIGN(OnDiskKernelLoaderSpec);
|
SE_DISALLOW_COPY_AND_ASSIGN(OnDiskKernelLoaderSpec);
|
||||||
@ -128,12 +128,12 @@ class CudaCubinOnDisk : public OnDiskKernelLoaderSpec {
|
|||||||
CudaCubinOnDisk(absl::string_view filename, absl::string_view kernelname);
|
CudaCubinOnDisk(absl::string_view filename, absl::string_view kernelname);
|
||||||
~CudaCubinOnDisk() override {}
|
~CudaCubinOnDisk() override {}
|
||||||
|
|
||||||
const string &filename() const { return filename_; }
|
const std::string &filename() const { return filename_; }
|
||||||
|
|
||||||
const char *CanonicalSuffix() const override { return ".cubin"; }
|
const char *CanonicalSuffix() const override { return ".cubin"; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
string filename_;
|
std::string filename_;
|
||||||
|
|
||||||
SE_DISALLOW_COPY_AND_ASSIGN(CudaCubinOnDisk);
|
SE_DISALLOW_COPY_AND_ASSIGN(CudaCubinOnDisk);
|
||||||
};
|
};
|
||||||
@ -192,7 +192,7 @@ class CudaPtxInMemory : public KernelLoaderSpec {
|
|||||||
int compute_capability_minor) const;
|
int compute_capability_minor) const;
|
||||||
|
|
||||||
// Decompresses the PTX string using bzip2.
|
// Decompresses the PTX string using bzip2.
|
||||||
static string DecompressPtx(const char *ptx);
|
static std::string DecompressPtx(const char *ptx);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// PTX translation unit text contents in memory. The key is of as a tuple
|
// PTX translation unit text contents in memory. The key is of as a tuple
|
||||||
@ -205,7 +205,7 @@ class CudaPtxInMemory : public KernelLoaderSpec {
|
|||||||
|
|
||||||
// Stores all decompressed ptx strings, with original ptx string as keys.
|
// Stores all decompressed ptx strings, with original ptx string as keys.
|
||||||
// It is marked as mutable for lazy decompression.
|
// It is marked as mutable for lazy decompression.
|
||||||
mutable std::map<const char *, string> decompressed_ptx_;
|
mutable std::map<const char *, std::string> decompressed_ptx_;
|
||||||
mutable absl::Mutex mu_;
|
mutable absl::Mutex mu_;
|
||||||
|
|
||||||
// Defines the minimum compute capability possible. Used when PTX has no
|
// Defines the minimum compute capability possible. Used when PTX has no
|
||||||
@ -246,11 +246,11 @@ class OpenCLTextInMemory : public KernelLoaderSpec {
|
|||||||
~OpenCLTextInMemory() override {}
|
~OpenCLTextInMemory() override {}
|
||||||
|
|
||||||
// Returns the OpenCL text contents.
|
// Returns the OpenCL text contents.
|
||||||
const string &text() const { return text_; }
|
const std::string &text() const { return text_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// OpenCL translation unit text contents in memory.
|
// OpenCL translation unit text contents in memory.
|
||||||
string text_;
|
std::string text_;
|
||||||
|
|
||||||
SE_DISALLOW_COPY_AND_ASSIGN(OpenCLTextInMemory);
|
SE_DISALLOW_COPY_AND_ASSIGN(OpenCLTextInMemory);
|
||||||
};
|
};
|
||||||
|
@ -56,7 +56,7 @@ struct ThreadDim : public Dim3D {
|
|||||||
: Dim3D(x, y, z) {}
|
: Dim3D(x, y, z) {}
|
||||||
|
|
||||||
// Returns a string representation of the thread dimensionality.
|
// Returns a string representation of the thread dimensionality.
|
||||||
string ToString() const {
|
std::string ToString() const {
|
||||||
return absl::StrCat("ThreadDim{", x, ", ", y, ", ", z, "}");
|
return absl::StrCat("ThreadDim{", x, ", ", y, ", ", z, "}");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -68,7 +68,7 @@ struct BlockDim : public Dim3D {
|
|||||||
: Dim3D(x, y, z) {}
|
: Dim3D(x, y, z) {}
|
||||||
|
|
||||||
// Returns a string representation of the block dimensionality.
|
// Returns a string representation of the block dimensionality.
|
||||||
string ToString() const {
|
std::string ToString() const {
|
||||||
return absl::StrCat("BlockDim{", x, ", ", y, ", ", z, "}");
|
return absl::StrCat("BlockDim{", x, ", ", y, ", ", z, "}");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -33,8 +33,8 @@ namespace port {
|
|||||||
// The API reference of abi::__cxa_demangle() can be found in
|
// The API reference of abi::__cxa_demangle() can be found in
|
||||||
// libstdc++'s manual.
|
// libstdc++'s manual.
|
||||||
// https://gcc.gnu.org/onlinedocs/libstdc++/libstdc++-html-USERS-4.3/a01696.html
|
// https://gcc.gnu.org/onlinedocs/libstdc++/libstdc++-html-USERS-4.3/a01696.html
|
||||||
string Demangle(const char *mangled) {
|
std::string Demangle(const char *mangled) {
|
||||||
string demangled;
|
std::string demangled;
|
||||||
int status = 0;
|
int status = 0;
|
||||||
char *result = nullptr;
|
char *result = nullptr;
|
||||||
#if HAS_CXA_DEMANGLE
|
#if HAS_CXA_DEMANGLE
|
||||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
|||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
namespace port {
|
namespace port {
|
||||||
|
|
||||||
string Demangle(const char* mangled);
|
std::string Demangle(const char* mangled);
|
||||||
|
|
||||||
} // namespace port
|
} // namespace port
|
||||||
} // namespace stream_executor
|
} // namespace stream_executor
|
||||||
|
@ -27,12 +27,12 @@ namespace port {
|
|||||||
using tensorflow::Env;
|
using tensorflow::Env;
|
||||||
using tensorflow::Thread;
|
using tensorflow::Thread;
|
||||||
|
|
||||||
inline Status FileExists(const string& filename) {
|
inline Status FileExists(const std::string& filename) {
|
||||||
return Env::Default()->FileExists(filename);
|
return Env::Default()->FileExists(filename);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline Status FileExists(const absl::string_view& filename) {
|
inline Status FileExists(const absl::string_view& filename) {
|
||||||
return Env::Default()->FileExists(string(filename));
|
return Env::Default()->FileExists(std::string(filename));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace port
|
} // namespace port
|
||||||
|
@ -28,7 +28,7 @@ namespace port {
|
|||||||
|
|
||||||
class HumanReadableNumBytes {
|
class HumanReadableNumBytes {
|
||||||
public:
|
public:
|
||||||
static string ToString(int64 num_bytes) {
|
static std::string ToString(int64 num_bytes) {
|
||||||
if (num_bytes == std::numeric_limits<int64>::min()) {
|
if (num_bytes == std::numeric_limits<int64>::min()) {
|
||||||
// Special case for number with not representable nagation.
|
// Special case for number with not representable nagation.
|
||||||
return "-8E";
|
return "-8E";
|
||||||
|
@ -32,7 +32,7 @@ bool safe_strto32(const char* str, int32* value) {
|
|||||||
// Convert strings to floating point values.
|
// Convert strings to floating point values.
|
||||||
// Leading and trailing spaces are allowed.
|
// Leading and trailing spaces are allowed.
|
||||||
// Values may be rounded on over- and underflow.
|
// Values may be rounded on over- and underflow.
|
||||||
bool safe_strto32(const string& str, int32* value) {
|
bool safe_strto32(const std::string& str, int32* value) {
|
||||||
return port::safe_strto32(str.c_str(), value);
|
return port::safe_strto32(str.c_str(), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ namespace port {
|
|||||||
// Convert strings to floating point values.
|
// Convert strings to floating point values.
|
||||||
// Leading and trailing spaces are allowed.
|
// Leading and trailing spaces are allowed.
|
||||||
// Values may be rounded on over- and underflow.
|
// Values may be rounded on over- and underflow.
|
||||||
bool safe_strto32(const string& str, int32* value);
|
bool safe_strto32(const std::string& str, int32* value);
|
||||||
|
|
||||||
} // namespace port
|
} // namespace port
|
||||||
} // namespace stream_executor
|
} // namespace stream_executor
|
||||||
|
@ -27,14 +27,14 @@ static bool IsAbsolutePath(absl::string_view path) {
|
|||||||
|
|
||||||
// For an array of paths of length count, append them all together,
|
// For an array of paths of length count, append them all together,
|
||||||
// ensuring that the proper path separators are inserted between them.
|
// ensuring that the proper path separators are inserted between them.
|
||||||
string JoinPathImpl(std::initializer_list<absl::string_view> paths) {
|
std::string JoinPathImpl(std::initializer_list<absl::string_view> paths) {
|
||||||
string result;
|
std::string result;
|
||||||
|
|
||||||
for (absl::string_view path : paths) {
|
for (absl::string_view path : paths) {
|
||||||
if (path.empty()) continue;
|
if (path.empty()) continue;
|
||||||
|
|
||||||
if (result.empty()) {
|
if (result.empty()) {
|
||||||
result = string(path);
|
result = std::string(path);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ namespace port {
|
|||||||
namespace internal {
|
namespace internal {
|
||||||
// TODO(rspringer): Move to cc/implementation file.
|
// TODO(rspringer): Move to cc/implementation file.
|
||||||
// Not part of the public API.
|
// Not part of the public API.
|
||||||
string JoinPathImpl(std::initializer_list<absl::string_view> paths);
|
std::string JoinPathImpl(std::initializer_list<absl::string_view> paths);
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
|
|
||||||
// Join multiple paths together.
|
// Join multiple paths together.
|
||||||
@ -47,7 +47,7 @@ string JoinPathImpl(std::initializer_list<absl::string_view> paths);
|
|||||||
// string path = file::JoinPath("/var/log", dirname, filename);
|
// string path = file::JoinPath("/var/log", dirname, filename);
|
||||||
// string path = file::JoinPath(FLAGS_test_srcdir, filename);
|
// string path = file::JoinPath(FLAGS_test_srcdir, filename);
|
||||||
template <typename... T>
|
template <typename... T>
|
||||||
inline string JoinPath(const T&... args) {
|
inline std::string JoinPath(const T&... args) {
|
||||||
return internal::JoinPathImpl({args...});
|
return internal::JoinPathImpl({args...});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,14 +29,14 @@ limitations under the License.
|
|||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
namespace port {
|
namespace port {
|
||||||
|
|
||||||
string Hostname() {
|
std::string Hostname() {
|
||||||
char hostname[1024];
|
char hostname[1024];
|
||||||
gethostname(hostname, sizeof hostname);
|
gethostname(hostname, sizeof hostname);
|
||||||
hostname[sizeof hostname - 1] = 0;
|
hostname[sizeof hostname - 1] = 0;
|
||||||
return std::string(hostname);
|
return std::string(hostname);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GetCurrentDirectory(string* dir) {
|
bool GetCurrentDirectory(std::string* dir) {
|
||||||
size_t len = 128;
|
size_t len = 128;
|
||||||
std::unique_ptr<char[]> a(new char[len]);
|
std::unique_ptr<char[]> a(new char[len]);
|
||||||
for (;;) {
|
for (;;) {
|
||||||
|
@ -21,8 +21,8 @@ limitations under the License.
|
|||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
namespace port {
|
namespace port {
|
||||||
|
|
||||||
string Hostname();
|
std::string Hostname();
|
||||||
bool GetCurrentDirectory(string* dir);
|
bool GetCurrentDirectory(std::string* dir);
|
||||||
|
|
||||||
} // namespace port
|
} // namespace port
|
||||||
} // namespace stream_executor
|
} // namespace stream_executor
|
||||||
|
@ -147,19 +147,19 @@ TEST(StatusOr, TestMoveOnlyVector) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(StatusOr, TestMoveWithValuesAndErrors) {
|
TEST(StatusOr, TestMoveWithValuesAndErrors) {
|
||||||
StatusOr<string> status_or(string(1000, '0'));
|
StatusOr<std::string> status_or(std::string(1000, '0'));
|
||||||
StatusOr<string> value1(string(1000, '1'));
|
StatusOr<std::string> value1(std::string(1000, '1'));
|
||||||
StatusOr<string> value2(string(1000, '2'));
|
StatusOr<std::string> value2(std::string(1000, '2'));
|
||||||
StatusOr<string> error1(Status(tensorflow::error::UNKNOWN, "error1"));
|
StatusOr<std::string> error1(Status(tensorflow::error::UNKNOWN, "error1"));
|
||||||
StatusOr<string> error2(Status(tensorflow::error::UNKNOWN, "error2"));
|
StatusOr<std::string> error2(Status(tensorflow::error::UNKNOWN, "error2"));
|
||||||
|
|
||||||
ASSERT_TRUE(status_or.ok());
|
ASSERT_TRUE(status_or.ok());
|
||||||
EXPECT_EQ(string(1000, '0'), status_or.ValueOrDie());
|
EXPECT_EQ(std::string(1000, '0'), status_or.ValueOrDie());
|
||||||
|
|
||||||
// Overwrite the value in status_or with another value.
|
// Overwrite the value in status_or with another value.
|
||||||
status_or = std::move(value1);
|
status_or = std::move(value1);
|
||||||
ASSERT_TRUE(status_or.ok());
|
ASSERT_TRUE(status_or.ok());
|
||||||
EXPECT_EQ(string(1000, '1'), status_or.ValueOrDie());
|
EXPECT_EQ(std::string(1000, '1'), status_or.ValueOrDie());
|
||||||
|
|
||||||
// Overwrite the value in status_or with an error.
|
// Overwrite the value in status_or with an error.
|
||||||
status_or = std::move(error1);
|
status_or = std::move(error1);
|
||||||
@ -174,23 +174,23 @@ TEST(StatusOr, TestMoveWithValuesAndErrors) {
|
|||||||
// Overwrite the error with a value.
|
// Overwrite the error with a value.
|
||||||
status_or = std::move(value2);
|
status_or = std::move(value2);
|
||||||
ASSERT_TRUE(status_or.ok());
|
ASSERT_TRUE(status_or.ok());
|
||||||
EXPECT_EQ(string(1000, '2'), status_or.ValueOrDie());
|
EXPECT_EQ(std::string(1000, '2'), status_or.ValueOrDie());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(StatusOr, TestCopyWithValuesAndErrors) {
|
TEST(StatusOr, TestCopyWithValuesAndErrors) {
|
||||||
StatusOr<string> status_or(string(1000, '0'));
|
StatusOr<std::string> status_or(std::string(1000, '0'));
|
||||||
StatusOr<string> value1(string(1000, '1'));
|
StatusOr<std::string> value1(std::string(1000, '1'));
|
||||||
StatusOr<string> value2(string(1000, '2'));
|
StatusOr<std::string> value2(std::string(1000, '2'));
|
||||||
StatusOr<string> error1(Status(tensorflow::error::UNKNOWN, "error1"));
|
StatusOr<std::string> error1(Status(tensorflow::error::UNKNOWN, "error1"));
|
||||||
StatusOr<string> error2(Status(tensorflow::error::UNKNOWN, "error2"));
|
StatusOr<std::string> error2(Status(tensorflow::error::UNKNOWN, "error2"));
|
||||||
|
|
||||||
ASSERT_TRUE(status_or.ok());
|
ASSERT_TRUE(status_or.ok());
|
||||||
EXPECT_EQ(string(1000, '0'), status_or.ValueOrDie());
|
EXPECT_EQ(std::string(1000, '0'), status_or.ValueOrDie());
|
||||||
|
|
||||||
// Overwrite the value in status_or with another value.
|
// Overwrite the value in status_or with another value.
|
||||||
status_or = value1;
|
status_or = value1;
|
||||||
ASSERT_TRUE(status_or.ok());
|
ASSERT_TRUE(status_or.ok());
|
||||||
EXPECT_EQ(string(1000, '1'), status_or.ValueOrDie());
|
EXPECT_EQ(std::string(1000, '1'), status_or.ValueOrDie());
|
||||||
|
|
||||||
// Overwrite the value in status_or with an error.
|
// Overwrite the value in status_or with an error.
|
||||||
status_or = error1;
|
status_or = error1;
|
||||||
@ -205,13 +205,13 @@ TEST(StatusOr, TestCopyWithValuesAndErrors) {
|
|||||||
// Overwrite the error with a value.
|
// Overwrite the error with a value.
|
||||||
status_or = value2;
|
status_or = value2;
|
||||||
ASSERT_TRUE(status_or.ok());
|
ASSERT_TRUE(status_or.ok());
|
||||||
EXPECT_EQ(string(1000, '2'), status_or.ValueOrDie());
|
EXPECT_EQ(std::string(1000, '2'), status_or.ValueOrDie());
|
||||||
|
|
||||||
// Verify original values unchanged.
|
// Verify original values unchanged.
|
||||||
EXPECT_EQ(string(1000, '1'), value1.ValueOrDie());
|
EXPECT_EQ(std::string(1000, '1'), value1.ValueOrDie());
|
||||||
EXPECT_EQ("error1", error1.status().error_message());
|
EXPECT_EQ("error1", error1.status().error_message());
|
||||||
EXPECT_EQ("error2", error2.status().error_message());
|
EXPECT_EQ("error2", error2.status().error_message());
|
||||||
EXPECT_EQ(string(1000, '2'), value2.ValueOrDie());
|
EXPECT_EQ(std::string(1000, '2'), value2.ValueOrDie());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(StatusOr, TestDefaultCtor) {
|
TEST(StatusOr, TestDefaultCtor) {
|
||||||
|
@ -39,10 +39,10 @@ class MultiPlatformManagerImpl {
|
|||||||
LOCKS_EXCLUDED(mu_);
|
LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
port::StatusOr<Platform*> InitializePlatformWithName(
|
port::StatusOr<Platform*> InitializePlatformWithName(
|
||||||
absl::string_view target, const std::map<string, string>& options)
|
absl::string_view target,
|
||||||
LOCKS_EXCLUDED(mu_);
|
const std::map<std::string, std::string>& options) LOCKS_EXCLUDED(mu_);
|
||||||
port::StatusOr<Platform*> InitializePlatformWithId(
|
port::StatusOr<Platform*> InitializePlatformWithId(
|
||||||
const Platform::Id& id, const std::map<string, string>& options)
|
const Platform::Id& id, const std::map<std::string, std::string>& options)
|
||||||
LOCKS_EXCLUDED(mu_);
|
LOCKS_EXCLUDED(mu_);
|
||||||
|
|
||||||
port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
|
port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
|
||||||
@ -66,13 +66,13 @@ class MultiPlatformManagerImpl {
|
|||||||
absl::Mutex mu_;
|
absl::Mutex mu_;
|
||||||
std::vector<std::unique_ptr<Listener>> listeners_ GUARDED_BY(mu_);
|
std::vector<std::unique_ptr<Listener>> listeners_ GUARDED_BY(mu_);
|
||||||
absl::flat_hash_map<Platform::Id, Platform*> id_map_ GUARDED_BY(mu_);
|
absl::flat_hash_map<Platform::Id, Platform*> id_map_ GUARDED_BY(mu_);
|
||||||
absl::flat_hash_map<string, Platform*> name_map_ GUARDED_BY(mu_);
|
absl::flat_hash_map<std::string, Platform*> name_map_ GUARDED_BY(mu_);
|
||||||
};
|
};
|
||||||
|
|
||||||
port::Status MultiPlatformManagerImpl::RegisterPlatform(
|
port::Status MultiPlatformManagerImpl::RegisterPlatform(
|
||||||
std::unique_ptr<Platform> platform) {
|
std::unique_ptr<Platform> platform) {
|
||||||
CHECK(platform != nullptr);
|
CHECK(platform != nullptr);
|
||||||
string key = absl::AsciiStrToLower(platform->Name());
|
std::string key = absl::AsciiStrToLower(platform->Name());
|
||||||
absl::MutexLock lock(&mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
if (name_map_.find(key) != name_map_.end()) {
|
if (name_map_.find(key) != name_map_.end()) {
|
||||||
return port::Status(port::error::INTERNAL,
|
return port::Status(port::error::INTERNAL,
|
||||||
@ -118,7 +118,8 @@ port::StatusOr<Platform*> MultiPlatformManagerImpl::PlatformWithId(
|
|||||||
}
|
}
|
||||||
|
|
||||||
port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithName(
|
port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithName(
|
||||||
absl::string_view target, const std::map<string, string>& options) {
|
absl::string_view target,
|
||||||
|
const std::map<std::string, std::string>& options) {
|
||||||
absl::MutexLock lock(&mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
|
|
||||||
SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
|
SE_ASSIGN_OR_RETURN(Platform * platform, LookupByNameLocked(target));
|
||||||
@ -134,7 +135,7 @@ port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithName(
|
|||||||
}
|
}
|
||||||
|
|
||||||
port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithId(
|
port::StatusOr<Platform*> MultiPlatformManagerImpl::InitializePlatformWithId(
|
||||||
const Platform::Id& id, const std::map<string, string>& options) {
|
const Platform::Id& id, const std::map<std::string, std::string>& options) {
|
||||||
absl::MutexLock lock(&mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
|
|
||||||
SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
|
SE_ASSIGN_OR_RETURN(Platform * platform, LookupByIdLocked(id));
|
||||||
@ -224,13 +225,14 @@ MultiPlatformManagerImpl& Impl() {
|
|||||||
|
|
||||||
/*static*/ port::StatusOr<Platform*>
|
/*static*/ port::StatusOr<Platform*>
|
||||||
MultiPlatformManager::InitializePlatformWithName(
|
MultiPlatformManager::InitializePlatformWithName(
|
||||||
absl::string_view target, const std::map<string, string>& options) {
|
absl::string_view target,
|
||||||
|
const std::map<std::string, std::string>& options) {
|
||||||
return Impl().InitializePlatformWithName(target, options);
|
return Impl().InitializePlatformWithName(target, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*static*/ port::StatusOr<Platform*>
|
/*static*/ port::StatusOr<Platform*>
|
||||||
MultiPlatformManager::InitializePlatformWithId(
|
MultiPlatformManager::InitializePlatformWithId(
|
||||||
const Platform::Id& id, const std::map<string, string>& options) {
|
const Platform::Id& id, const std::map<std::string, std::string>& options) {
|
||||||
return Impl().InitializePlatformWithId(id, options);
|
return Impl().InitializePlatformWithId(id, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,10 +111,12 @@ class MultiPlatformManager {
|
|||||||
// Ownership of the platform is NOT transferred to the caller --
|
// Ownership of the platform is NOT transferred to the caller --
|
||||||
// the MultiPlatformManager owns the platforms in a singleton-like fashion.
|
// the MultiPlatformManager owns the platforms in a singleton-like fashion.
|
||||||
static port::StatusOr<Platform*> InitializePlatformWithName(
|
static port::StatusOr<Platform*> InitializePlatformWithName(
|
||||||
absl::string_view target, const std::map<string, string>& options);
|
absl::string_view target,
|
||||||
|
const std::map<std::string, std::string>& options);
|
||||||
|
|
||||||
static port::StatusOr<Platform*> InitializePlatformWithId(
|
static port::StatusOr<Platform*> InitializePlatformWithId(
|
||||||
const Platform::Id& id, const std::map<string, string>& options);
|
const Platform::Id& id,
|
||||||
|
const std::map<std::string, std::string>& options);
|
||||||
|
|
||||||
// Retrieves the platforms satisfying the given filter, i.e. returns true.
|
// Retrieves the platforms satisfying the given filter, i.e. returns true.
|
||||||
// Returned Platforms are always initialized.
|
// Returned Platforms are always initialized.
|
||||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
|
|
||||||
string PlatformKindString(PlatformKind kind) {
|
std::string PlatformKindString(PlatformKind kind) {
|
||||||
switch (kind) {
|
switch (kind) {
|
||||||
case PlatformKind::kCuda:
|
case PlatformKind::kCuda:
|
||||||
return "CUDA";
|
return "CUDA";
|
||||||
@ -41,7 +41,7 @@ string PlatformKindString(PlatformKind kind) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
PlatformKind PlatformKindFromString(string kind) {
|
PlatformKind PlatformKindFromString(std::string kind) {
|
||||||
for (int i = 0; i < static_cast<int>(PlatformKind::kSize); ++i) {
|
for (int i = 0; i < static_cast<int>(PlatformKind::kSize); ++i) {
|
||||||
if (kind == PlatformKindString(static_cast<PlatformKind>(i))) {
|
if (kind == PlatformKindString(static_cast<PlatformKind>(i))) {
|
||||||
return static_cast<PlatformKind>(i);
|
return static_cast<PlatformKind>(i);
|
||||||
@ -91,7 +91,7 @@ Platform::~Platform() {}
|
|||||||
bool Platform::Initialized() const { return true; }
|
bool Platform::Initialized() const { return true; }
|
||||||
|
|
||||||
port::Status Platform::Initialize(
|
port::Status Platform::Initialize(
|
||||||
const std::map<string, string> &platform_options) {
|
const std::map<std::string, std::string> &platform_options) {
|
||||||
if (!platform_options.empty()) {
|
if (!platform_options.empty()) {
|
||||||
return port::Status(port::error::UNIMPLEMENTED,
|
return port::Status(port::error::UNIMPLEMENTED,
|
||||||
"this platform does not support custom initialization");
|
"this platform does not support custom initialization");
|
||||||
|
@ -60,11 +60,11 @@ bool PlatformIsRunnable(PlatformKind kind);
|
|||||||
bool PlatformIsRunnableOnDevice(PlatformKind kind);
|
bool PlatformIsRunnableOnDevice(PlatformKind kind);
|
||||||
|
|
||||||
// Returns a printable description of a PlatformKind.
|
// Returns a printable description of a PlatformKind.
|
||||||
string PlatformKindString(PlatformKind kind);
|
std::string PlatformKindString(PlatformKind kind);
|
||||||
|
|
||||||
// Returns the PlatformKind corresponding to the input string; returns kInvalid
|
// Returns the PlatformKind corresponding to the input string; returns kInvalid
|
||||||
// in the case of no match.
|
// in the case of no match.
|
||||||
PlatformKind PlatformKindFromString(string platform_string);
|
PlatformKind PlatformKindFromString(std::string platform_string);
|
||||||
|
|
||||||
// Checks that kind takes on a valid value.
|
// Checks that kind takes on a valid value.
|
||||||
void CheckPlatformKindIsValid(PlatformKind kind);
|
void CheckPlatformKindIsValid(PlatformKind kind);
|
||||||
@ -114,7 +114,7 @@ class Platform {
|
|||||||
virtual Id id() const = 0;
|
virtual Id id() const = 0;
|
||||||
|
|
||||||
// Name of this platform.
|
// Name of this platform.
|
||||||
virtual const string& Name() const = 0;
|
virtual const std::string& Name() const = 0;
|
||||||
|
|
||||||
// Returns the number of devices accessible on this platform.
|
// Returns the number of devices accessible on this platform.
|
||||||
//
|
//
|
||||||
@ -133,7 +133,7 @@ class Platform {
|
|||||||
// MultiPlatformManager, this method will be called automatically by
|
// MultiPlatformManager, this method will be called automatically by
|
||||||
// InitializePlatformWithId/InitializePlatformWithName.
|
// InitializePlatformWithId/InitializePlatformWithName.
|
||||||
virtual port::Status Initialize(
|
virtual port::Status Initialize(
|
||||||
const std::map<string, string>& platform_options);
|
const std::map<std::string, std::string>& platform_options);
|
||||||
|
|
||||||
// Returns a populated DeviceDescription for the device at the given ordinal.
|
// Returns a populated DeviceDescription for the device at the given ordinal.
|
||||||
// This should not require device initialization. Note that not all platforms
|
// This should not require device initialization. Note that not all platforms
|
||||||
|
@ -27,7 +27,7 @@ namespace stream_executor {
|
|||||||
const PluginId kNullPlugin = nullptr;
|
const PluginId kNullPlugin = nullptr;
|
||||||
|
|
||||||
// Returns the string representation of the specified PluginKind.
|
// Returns the string representation of the specified PluginKind.
|
||||||
string PluginKindString(PluginKind plugin_kind) {
|
std::string PluginKindString(PluginKind plugin_kind) {
|
||||||
switch (plugin_kind) {
|
switch (plugin_kind) {
|
||||||
case PluginKind::kBlas:
|
case PluginKind::kBlas:
|
||||||
return "BLAS";
|
return "BLAS";
|
||||||
@ -70,7 +70,7 @@ void PluginRegistry::MapPlatformKindToId(PlatformKind platform_kind,
|
|||||||
|
|
||||||
template <typename FACTORY_TYPE>
|
template <typename FACTORY_TYPE>
|
||||||
port::Status PluginRegistry::RegisterFactoryInternal(
|
port::Status PluginRegistry::RegisterFactoryInternal(
|
||||||
PluginId plugin_id, const string& plugin_name, FACTORY_TYPE factory,
|
PluginId plugin_id, const std::string& plugin_name, FACTORY_TYPE factory,
|
||||||
std::map<PluginId, FACTORY_TYPE>* factories) {
|
std::map<PluginId, FACTORY_TYPE>* factories) {
|
||||||
absl::MutexLock lock{&GetPluginRegistryMutex()};
|
absl::MutexLock lock{&GetPluginRegistryMutex()};
|
||||||
|
|
||||||
@ -110,7 +110,7 @@ bool PluginRegistry::SetDefaultFactory(Platform::Id platform_id,
|
|||||||
if (!HasFactory(platform_id, plugin_kind, plugin_id)) {
|
if (!HasFactory(platform_id, plugin_kind, plugin_id)) {
|
||||||
port::StatusOr<Platform*> status =
|
port::StatusOr<Platform*> status =
|
||||||
MultiPlatformManager::PlatformWithId(platform_id);
|
MultiPlatformManager::PlatformWithId(platform_id);
|
||||||
string platform_name = "<unregistered platform>";
|
std::string platform_name = "<unregistered platform>";
|
||||||
if (status.ok()) {
|
if (status.ok()) {
|
||||||
platform_name = status.ValueOrDie()->Name();
|
platform_name = status.ValueOrDie()->Name();
|
||||||
}
|
}
|
||||||
@ -194,7 +194,7 @@ bool PluginRegistry::HasFactory(Platform::Id platform_id,
|
|||||||
\
|
\
|
||||||
template <> \
|
template <> \
|
||||||
port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
|
port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
|
||||||
Platform::Id platform_id, PluginId plugin_id, const string& name, \
|
Platform::Id platform_id, PluginId plugin_id, const std::string& name, \
|
||||||
PluginRegistry::FACTORY_TYPE factory) { \
|
PluginRegistry::FACTORY_TYPE factory) { \
|
||||||
return RegisterFactoryInternal(plugin_id, name, factory, \
|
return RegisterFactoryInternal(plugin_id, name, factory, \
|
||||||
&factories_[platform_id].FACTORY_VAR); \
|
&factories_[platform_id].FACTORY_VAR); \
|
||||||
@ -202,7 +202,8 @@ bool PluginRegistry::HasFactory(Platform::Id platform_id,
|
|||||||
\
|
\
|
||||||
template <> \
|
template <> \
|
||||||
port::Status PluginRegistry::RegisterFactoryForAllPlatforms< \
|
port::Status PluginRegistry::RegisterFactoryForAllPlatforms< \
|
||||||
PluginRegistry::FACTORY_TYPE>(PluginId plugin_id, const string& name, \
|
PluginRegistry::FACTORY_TYPE>(PluginId plugin_id, \
|
||||||
|
const std::string& name, \
|
||||||
PluginRegistry::FACTORY_TYPE factory) { \
|
PluginRegistry::FACTORY_TYPE factory) { \
|
||||||
return RegisterFactoryInternal(plugin_id, name, factory, \
|
return RegisterFactoryInternal(plugin_id, name, factory, \
|
||||||
&generic_factories_.FACTORY_VAR); \
|
&generic_factories_.FACTORY_VAR); \
|
||||||
|
@ -62,13 +62,13 @@ class PluginRegistry {
|
|||||||
// with that platform (but execution should be otherwise unaffected).
|
// with that platform (but execution should be otherwise unaffected).
|
||||||
template <typename FactoryT>
|
template <typename FactoryT>
|
||||||
port::Status RegisterFactory(Platform::Id platform_id, PluginId plugin_id,
|
port::Status RegisterFactory(Platform::Id platform_id, PluginId plugin_id,
|
||||||
const string& name, FactoryT factory);
|
const std::string& name, FactoryT factory);
|
||||||
|
|
||||||
// Registers the specified factory as usable by _all_ platform types.
|
// Registers the specified factory as usable by _all_ platform types.
|
||||||
// Reports errors just as RegisterFactory.
|
// Reports errors just as RegisterFactory.
|
||||||
template <typename FactoryT>
|
template <typename FactoryT>
|
||||||
port::Status RegisterFactoryForAllPlatforms(PluginId plugin_id,
|
port::Status RegisterFactoryForAllPlatforms(PluginId plugin_id,
|
||||||
const string& name,
|
const std::string& name,
|
||||||
FactoryT factory);
|
FactoryT factory);
|
||||||
|
|
||||||
// TODO(b/22689637): Setter for temporary mapping until all users are using
|
// TODO(b/22689637): Setter for temporary mapping until all users are using
|
||||||
@ -122,7 +122,7 @@ class PluginRegistry {
|
|||||||
// Actually performs the work of registration.
|
// Actually performs the work of registration.
|
||||||
template <typename FactoryT>
|
template <typename FactoryT>
|
||||||
port::Status RegisterFactoryInternal(PluginId plugin_id,
|
port::Status RegisterFactoryInternal(PluginId plugin_id,
|
||||||
const string& plugin_name,
|
const std::string& plugin_name,
|
||||||
FactoryT factory,
|
FactoryT factory,
|
||||||
std::map<PluginId, FactoryT>* factories);
|
std::map<PluginId, FactoryT>* factories);
|
||||||
|
|
||||||
@ -155,7 +155,7 @@ class PluginRegistry {
|
|||||||
std::map<Platform::Id, DefaultFactories> default_factories_;
|
std::map<Platform::Id, DefaultFactories> default_factories_;
|
||||||
|
|
||||||
// Lookup table for plugin names.
|
// Lookup table for plugin names.
|
||||||
std::map<PluginId, string> plugin_names_;
|
std::map<PluginId, std::string> plugin_names_;
|
||||||
|
|
||||||
SE_DISALLOW_COPY_AND_ASSIGN(PluginRegistry);
|
SE_DISALLOW_COPY_AND_ASSIGN(PluginRegistry);
|
||||||
};
|
};
|
||||||
@ -164,7 +164,7 @@ class PluginRegistry {
|
|||||||
#define DECLARE_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE) \
|
#define DECLARE_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE) \
|
||||||
template <> \
|
template <> \
|
||||||
port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
|
port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
|
||||||
Platform::Id platform_id, PluginId plugin_id, const string& name, \
|
Platform::Id platform_id, PluginId plugin_id, const std::string& name, \
|
||||||
PluginRegistry::FACTORY_TYPE factory); \
|
PluginRegistry::FACTORY_TYPE factory); \
|
||||||
template <> \
|
template <> \
|
||||||
port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \
|
port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \
|
||||||
|
@ -35,55 +35,57 @@ namespace {
|
|||||||
// will be VLOG'ed. We need overloads, instead of
|
// will be VLOG'ed. We need overloads, instead of
|
||||||
// e.g. BatchDescriptorToVlogString(), as the code that calls these
|
// e.g. BatchDescriptorToVlogString(), as the code that calls these
|
||||||
// functions does not know what the type of the parameter is.
|
// functions does not know what the type of the parameter is.
|
||||||
string ToVlogString(const dnn::BatchDescriptor &descriptor) {
|
std::string ToVlogString(const dnn::BatchDescriptor &descriptor) {
|
||||||
return descriptor.ToShortString();
|
return descriptor.ToShortString();
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(const dnn::FilterDescriptor &descriptor) {
|
std::string ToVlogString(const dnn::FilterDescriptor &descriptor) {
|
||||||
return descriptor.ToShortString();
|
return descriptor.ToShortString();
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
|
std::string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
|
||||||
return descriptor.ToShortString();
|
return descriptor.ToShortString();
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
|
std::string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
|
||||||
return descriptor.ToShortString();
|
return descriptor.ToShortString();
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
|
std::string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
|
||||||
return descriptor.ToShortString();
|
return descriptor.ToShortString();
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(dnn::ActivationMode mode) {
|
std::string ToVlogString(dnn::ActivationMode mode) {
|
||||||
return dnn::ActivationModeString(mode);
|
return dnn::ActivationModeString(mode);
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
|
std::string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
|
||||||
return algo_config.ToString();
|
return algo_config.ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(dnn::ElementwiseOperation op) {
|
std::string ToVlogString(dnn::ElementwiseOperation op) {
|
||||||
return dnn::ElementwiseOperationString(op);
|
return dnn::ElementwiseOperationString(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(dnn::QuantizedActivationMode mode) {
|
std::string ToVlogString(dnn::QuantizedActivationMode mode) {
|
||||||
return dnn::QuantizedActivationModeString(mode);
|
return dnn::QuantizedActivationModeString(mode);
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
|
std::string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
|
||||||
|
|
||||||
string ToVlogString(blas::UpperLower ul) { return blas::UpperLowerString(ul); }
|
std::string ToVlogString(blas::UpperLower ul) {
|
||||||
|
return blas::UpperLowerString(ul);
|
||||||
|
}
|
||||||
|
|
||||||
string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
|
std::string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
|
||||||
|
|
||||||
string ToVlogString(blas::Side s) { return blas::SideString(s); }
|
std::string ToVlogString(blas::Side s) { return blas::SideString(s); }
|
||||||
|
|
||||||
string ToVlogString(blas::ComputationType ty) {
|
std::string ToVlogString(blas::ComputationType ty) {
|
||||||
return blas::ComputationTypeString(ty);
|
return blas::ComputationTypeString(ty);
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(const void *ptr) {
|
std::string ToVlogString(const void *ptr) {
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
return "null";
|
return "null";
|
||||||
}
|
}
|
||||||
@ -95,7 +97,7 @@ string ToVlogString(const void *ptr) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
string ToVlogString(const std::complex<T> &c) {
|
std::string ToVlogString(const std::complex<T> &c) {
|
||||||
// StrCat does not convert std::complex to text.
|
// StrCat does not convert std::complex to text.
|
||||||
std::ostringstream out;
|
std::ostringstream out;
|
||||||
out << c;
|
out << c;
|
||||||
@ -103,36 +105,36 @@ string ToVlogString(const std::complex<T> &c) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
string ToVlogString(const std::function<T> &f) {
|
std::string ToVlogString(const std::function<T> &f) {
|
||||||
return f == nullptr ? "null" : "<non-null function>";
|
return f == nullptr ? "null" : "<non-null function>";
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(const DeviceMemoryBase &memory) {
|
std::string ToVlogString(const DeviceMemoryBase &memory) {
|
||||||
return ToVlogString(memory.opaque());
|
return ToVlogString(memory.opaque());
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(const DeviceMemoryBase *memory) {
|
std::string ToVlogString(const DeviceMemoryBase *memory) {
|
||||||
return memory == nullptr ? "null" : ToVlogString(*memory);
|
return memory == nullptr ? "null" : ToVlogString(*memory);
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(const Eigen::half &h) {
|
std::string ToVlogString(const Eigen::half &h) {
|
||||||
return absl::StrCat(static_cast<float>(h));
|
return absl::StrCat(static_cast<float>(h));
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(int i) { return absl::StrCat(i); }
|
std::string ToVlogString(int i) { return absl::StrCat(i); }
|
||||||
|
|
||||||
string ToVlogString(uint32 i) { return absl::StrCat(i); }
|
std::string ToVlogString(uint32 i) { return absl::StrCat(i); }
|
||||||
|
|
||||||
string ToVlogString(uint64 i) { return absl::StrCat(i); }
|
std::string ToVlogString(uint64 i) { return absl::StrCat(i); }
|
||||||
|
|
||||||
string ToVlogString(int64 i) { return absl::StrCat(i); }
|
std::string ToVlogString(int64 i) { return absl::StrCat(i); }
|
||||||
|
|
||||||
string ToVlogString(float f) { return absl::StrCat(f); }
|
std::string ToVlogString(float f) { return absl::StrCat(f); }
|
||||||
|
|
||||||
string ToVlogString(double d) { return absl::StrCat(d); }
|
std::string ToVlogString(double d) { return absl::StrCat(d); }
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
|
std::string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
|
||||||
if (memory_or_constant.is_pointer()) {
|
if (memory_or_constant.is_pointer()) {
|
||||||
return ToVlogString(memory_or_constant.pointer());
|
return ToVlogString(memory_or_constant.pointer());
|
||||||
}
|
}
|
||||||
@ -140,8 +142,8 @@ string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
string ToVlogString(port::ArraySlice<T> elements) {
|
std::string ToVlogString(port::ArraySlice<T> elements) {
|
||||||
string str = absl::StrCat(
|
std::string str = absl::StrCat(
|
||||||
ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
|
ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
|
||||||
elements.size(), "]{");
|
elements.size(), "]{");
|
||||||
const char *separator = "";
|
const char *separator = "";
|
||||||
@ -166,11 +168,11 @@ string ToVlogString(port::ArraySlice<T> elements) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
string ToVlogString(port::MutableArraySlice<T> elements) {
|
std::string ToVlogString(port::MutableArraySlice<T> elements) {
|
||||||
return ToVlogString(port::ArraySlice<T>(elements));
|
return ToVlogString(port::ArraySlice<T>(elements));
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
|
std::string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
|
||||||
switch (depth_to_space_layout) {
|
switch (depth_to_space_layout) {
|
||||||
case dnn::DepthToSpaceLayout::DepthHeightWidth:
|
case dnn::DepthToSpaceLayout::DepthHeightWidth:
|
||||||
return "DepthToSpaceLayout::DepthHeightWidth";
|
return "DepthToSpaceLayout::DepthHeightWidth";
|
||||||
@ -178,7 +180,7 @@ string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
|
|||||||
return "unknown DepthToSpaceLayout";
|
return "unknown DepthToSpaceLayout";
|
||||||
}
|
}
|
||||||
|
|
||||||
string ToVlogString(dnn::DataType data_type) {
|
std::string ToVlogString(dnn::DataType data_type) {
|
||||||
switch (data_type) {
|
switch (data_type) {
|
||||||
case dnn::DataType::kFloat:
|
case dnn::DataType::kFloat:
|
||||||
return "dnn::DataType::kFloat";
|
return "dnn::DataType::kFloat";
|
||||||
@ -205,14 +207,14 @@ string ToVlogString(dnn::DataType data_type) {
|
|||||||
// See VLOG_CALL for a short-hand for this. This way of doing it saves
|
// See VLOG_CALL for a short-hand for this. This way of doing it saves
|
||||||
// a tremendous amount of boilerplate code given how many functions
|
// a tremendous amount of boilerplate code given how many functions
|
||||||
// there are on Stream and how many parameters they each have.
|
// there are on Stream and how many parameters they each have.
|
||||||
string CallStr(const char *function_name, Stream *stream,
|
std::string CallStr(const char *function_name, Stream *stream,
|
||||||
std::vector<std::pair<const char *, string>> params) {
|
std::vector<std::pair<const char *, std::string>> params) {
|
||||||
// Do not call this function unless VLOG is on since just
|
// Do not call this function unless VLOG is on since just
|
||||||
// constructing all the strings in params is expensive.
|
// constructing all the strings in params is expensive.
|
||||||
CHECK(VLOG_IS_ON(1));
|
CHECK(VLOG_IS_ON(1));
|
||||||
|
|
||||||
string str = absl::StrCat(stream->DebugStreamPointers(),
|
std::string str = absl::StrCat(stream->DebugStreamPointers(),
|
||||||
" Called Stream::", function_name, "(");
|
" Called Stream::", function_name, "(");
|
||||||
const char *separator = "";
|
const char *separator = "";
|
||||||
for (const auto ¶m : params) {
|
for (const auto ¶m : params) {
|
||||||
absl::StrAppend(&str, separator, param.first, "=", param.second);
|
absl::StrAppend(&str, separator, param.first, "=", param.second);
|
||||||
@ -5470,7 +5472,7 @@ void Stream::RunAfterBlockHostUntilDoneCallbacks() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
string Stream::DebugStreamPointers() const {
|
std::string Stream::DebugStreamPointers() const {
|
||||||
// Relies on the ToVlogString(const void*) overload above.
|
// Relies on the ToVlogString(const void*) overload above.
|
||||||
return absl::StrCat("[stream=", ToVlogString(this),
|
return absl::StrCat("[stream=", ToVlogString(this),
|
||||||
",impl=", ToVlogString(implementation_.get()), "]");
|
",impl=", ToVlogString(implementation_.get()), "]");
|
||||||
|
@ -2014,7 +2014,7 @@ class Stream {
|
|||||||
internal::TemporaryMemoryManager *temporary_memory_manager();
|
internal::TemporaryMemoryManager *temporary_memory_manager();
|
||||||
|
|
||||||
// Returns a debugging string "[stream=0x...,impl=0x...]".
|
// Returns a debugging string "[stream=0x...,impl=0x...]".
|
||||||
string DebugStreamPointers() const;
|
std::string DebugStreamPointers() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class host::HostBlas; // for parent_.
|
friend class host::HostBlas; // for parent_.
|
||||||
|
@ -286,8 +286,9 @@ class StreamExecutorInterface {
|
|||||||
// If ModuleHandle is set then we search for `symbol_name` only within the
|
// If ModuleHandle is set then we search for `symbol_name` only within the
|
||||||
// module corresponding to `module_handle`. Otherwise all loaded modules are
|
// module corresponding to `module_handle`. Otherwise all loaded modules are
|
||||||
// searched.
|
// searched.
|
||||||
virtual bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
|
virtual bool GetSymbol(const std::string &symbol_name,
|
||||||
void **mem, size_t *bytes) {
|
ModuleHandle module_handle, void **mem,
|
||||||
|
size_t *bytes) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ bool FLAGS_check_device_leaks = false;
|
|||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
string StackTraceIfVLOG10() {
|
std::string StackTraceIfVLOG10() {
|
||||||
if (VLOG_IS_ON(10)) {
|
if (VLOG_IS_ON(10)) {
|
||||||
return absl::StrCat(" ", port::CurrentStackTrace(), "\n");
|
return absl::StrCat(" ", port::CurrentStackTrace(), "\n");
|
||||||
} else {
|
} else {
|
||||||
@ -149,7 +149,7 @@ StreamExecutor::StreamExecutor(
|
|||||||
mem_alloc_bytes_(0),
|
mem_alloc_bytes_(0),
|
||||||
memory_limit_bytes_(GetMemoryLimitBytes()),
|
memory_limit_bytes_(GetMemoryLimitBytes()),
|
||||||
allocator_(this) {
|
allocator_(this) {
|
||||||
string name = absl::AsciiStrToLower(platform_->Name());
|
std::string name = absl::AsciiStrToLower(platform_->Name());
|
||||||
if (name == "cuda") {
|
if (name == "cuda") {
|
||||||
platform_kind_ = PlatformKind::kCuda;
|
platform_kind_ = PlatformKind::kCuda;
|
||||||
} else if (name == "rocm") {
|
} else if (name == "rocm") {
|
||||||
@ -239,7 +239,7 @@ port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
|
|||||||
if (config != SharedMemoryConfig::kDefault &&
|
if (config != SharedMemoryConfig::kDefault &&
|
||||||
config != SharedMemoryConfig::kFourByte &&
|
config != SharedMemoryConfig::kFourByte &&
|
||||||
config != SharedMemoryConfig::kEightByte) {
|
config != SharedMemoryConfig::kEightByte) {
|
||||||
string error_msg = absl::StrFormat(
|
std::string error_msg = absl::StrFormat(
|
||||||
"Invalid shared memory config specified: %d", static_cast<int>(config));
|
"Invalid shared memory config specified: %d", static_cast<int>(config));
|
||||||
LOG(ERROR) << error_msg;
|
LOG(ERROR) << error_msg;
|
||||||
return port::Status(port::error::INVALID_ARGUMENT, error_msg);
|
return port::Status(port::error::INVALID_ARGUMENT, error_msg);
|
||||||
@ -492,7 +492,7 @@ DeviceMemoryBase StreamExecutor::Allocate(uint64 size, int64 memory_space) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
|
port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
|
||||||
const string &symbol_name, ModuleHandle module_handle) {
|
const std::string &symbol_name, ModuleHandle module_handle) {
|
||||||
// If failed to get the symbol, opaque/bytes are unchanged. Initialize them to
|
// If failed to get the symbol, opaque/bytes are unchanged. Initialize them to
|
||||||
// be nullptr/0 for consistency with DeviceMemory semantics.
|
// be nullptr/0 for consistency with DeviceMemory semantics.
|
||||||
void *opaque = nullptr;
|
void *opaque = nullptr;
|
||||||
@ -515,7 +515,7 @@ port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool StreamExecutor::GetSymbol(const string &symbol_name,
|
bool StreamExecutor::GetSymbol(const std::string &symbol_name,
|
||||||
ModuleHandle module_handle, void **mem,
|
ModuleHandle module_handle, void **mem,
|
||||||
size_t *bytes) {
|
size_t *bytes) {
|
||||||
return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
|
return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
|
||||||
|
@ -50,7 +50,7 @@ struct AllocRecord {
|
|||||||
// Holds a representation of the stack at the time the associated buffer was
|
// Holds a representation of the stack at the time the associated buffer was
|
||||||
// allocated. Produced in a form described in
|
// allocated. Produced in a form described in
|
||||||
// //util/symbolize/symbolized_stacktrace.h.
|
// //util/symbolize/symbolized_stacktrace.h.
|
||||||
string stack_trace;
|
std::string stack_trace;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Forward declaration of private friend class.
|
// Forward declaration of private friend class.
|
||||||
@ -175,12 +175,12 @@ class StreamExecutor {
|
|||||||
// If `module_handle` is set then searches only within the module
|
// If `module_handle` is set then searches only within the module
|
||||||
// corresponding to `module_handle`.
|
// corresponding to `module_handle`.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name,
|
port::StatusOr<DeviceMemory<T>> GetSymbol(const std::string &symbol_name,
|
||||||
ModuleHandle module_handle = {});
|
ModuleHandle module_handle = {});
|
||||||
|
|
||||||
// An untyped version of GetSymbol.
|
// An untyped version of GetSymbol.
|
||||||
port::StatusOr<DeviceMemoryBase> GetUntypedSymbol(
|
port::StatusOr<DeviceMemoryBase> GetUntypedSymbol(
|
||||||
const string &symbol_name, ModuleHandle module_handle = {});
|
const std::string &symbol_name, ModuleHandle module_handle = {});
|
||||||
|
|
||||||
// Deallocate the DeviceMemory previously allocated via this interface.
|
// Deallocate the DeviceMemory previously allocated via this interface.
|
||||||
// Deallocation of a nullptr-representative value is permitted.
|
// Deallocation of a nullptr-representative value is permitted.
|
||||||
@ -554,7 +554,7 @@ class StreamExecutor {
|
|||||||
|
|
||||||
// Finds and retrieves device memory for the symbol on the underlying
|
// Finds and retrieves device memory for the symbol on the underlying
|
||||||
// platform.
|
// platform.
|
||||||
bool GetSymbol(const string &symbol_name, ModuleHandle module_handle,
|
bool GetSymbol(const std::string &symbol_name, ModuleHandle module_handle,
|
||||||
void **mem, size_t *bytes);
|
void **mem, size_t *bytes);
|
||||||
|
|
||||||
// Entrains a memcpy operation onto stream, with a host destination location
|
// Entrains a memcpy operation onto stream, with a host destination location
|
||||||
@ -805,7 +805,7 @@ inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count,
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
|
inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol(
|
||||||
const string &symbol_name, ModuleHandle module_handle) {
|
const std::string &symbol_name, ModuleHandle module_handle) {
|
||||||
port::StatusOr<DeviceMemoryBase> untyped_symbol =
|
port::StatusOr<DeviceMemoryBase> untyped_symbol =
|
||||||
GetUntypedSymbol(symbol_name, module_handle);
|
GetUntypedSymbol(symbol_name, module_handle);
|
||||||
if (!untyped_symbol.ok()) {
|
if (!untyped_symbol.ok()) {
|
||||||
|
Loading…
Reference in New Issue
Block a user