Added support of batch to all operations with any storage types.
Changed layout for BUFFER and IMAGE_BUFFER PiperOrigin-RevId: 272498500
This commit is contained in:
parent
e7fb9e5202
commit
811e7b67c8
@ -182,14 +182,6 @@ std::vector<TensorStorageType> Environment::GetSupportedTextureStorages()
|
|||||||
return storage_types;
|
return storage_types;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<TensorStorageType> Environment::GetSupportedBatchStorages() const {
|
|
||||||
std::vector<TensorStorageType> storage_types = {TensorStorageType::BUFFER};
|
|
||||||
if (device_.IsAdreno() && device_.SupportsImageBuffer()) {
|
|
||||||
storage_types.push_back(TensorStorageType::IMAGE_BUFFER);
|
|
||||||
}
|
|
||||||
return storage_types;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<TensorStorageType> Environment::GetSupportedStorages() const {
|
std::vector<TensorStorageType> Environment::GetSupportedStorages() const {
|
||||||
std::vector<TensorStorageType> storage_types = {TensorStorageType::TEXTURE_2D,
|
std::vector<TensorStorageType> storage_types = {TensorStorageType::TEXTURE_2D,
|
||||||
TensorStorageType::BUFFER};
|
TensorStorageType::BUFFER};
|
||||||
|
@ -56,7 +56,6 @@ class Environment {
|
|||||||
std::vector<CalculationsPrecision> GetSupportedPrecisions() const;
|
std::vector<CalculationsPrecision> GetSupportedPrecisions() const;
|
||||||
bool IsSupported(CalculationsPrecision precision) const;
|
bool IsSupported(CalculationsPrecision precision) const;
|
||||||
std::vector<TensorStorageType> GetSupportedTextureStorages() const;
|
std::vector<TensorStorageType> GetSupportedTextureStorages() const;
|
||||||
std::vector<TensorStorageType> GetSupportedBatchStorages() const;
|
|
||||||
std::vector<TensorStorageType> GetSupportedStorages() const;
|
std::vector<TensorStorageType> GetSupportedStorages() const;
|
||||||
|
|
||||||
void SetHighPerformance() const;
|
void SetHighPerformance() const;
|
||||||
|
@ -234,9 +234,20 @@ std::string TensorCodeGenerator::GetGlobalAddressNoDeclaration(
|
|||||||
switch (descriptor_.storage_type) {
|
switch (descriptor_.storage_type) {
|
||||||
case TensorStorageType::BUFFER:
|
case TensorStorageType::BUFFER:
|
||||||
case TensorStorageType::IMAGE_BUFFER:
|
case TensorStorageType::IMAGE_BUFFER:
|
||||||
return absl::Substitute("(((($3) * $4 + $2) * $5 + ($1)) * $6 + ($0))", x,
|
return absl::Substitute("(((($3) * $4 + $2) * $5 + ($1)) * $6 + ($0))", b,
|
||||||
y, z, b, sizes_.depth, sizes_.height,
|
x, y, z, sizes_.height, sizes_.width,
|
||||||
sizes_.width);
|
sizes_.batch_size);
|
||||||
|
case TensorStorageType::TEXTURE_2D:
|
||||||
|
return absl::Substitute("(int2)(($0) * ($4) + ($1), ($2) * $5 + ($3))", x,
|
||||||
|
b, y, z, sizes_.batch_size, sizes_.depth);
|
||||||
|
case TensorStorageType::SINGLE_TEXTURE_2D:
|
||||||
|
return absl::Substitute("(int2)(($0) * ($3) + ($1), ($2))", x, b, y,
|
||||||
|
sizes_.batch_size);
|
||||||
|
case TensorStorageType::TEXTURE_ARRAY:
|
||||||
|
return absl::Substitute("(int4)(($0) * ($4) + ($1), ($2), ($3), 0)", x, b,
|
||||||
|
y, z, sizes_.batch_size);
|
||||||
|
case TensorStorageType::UNKNOWN:
|
||||||
|
return "error";
|
||||||
default:
|
default:
|
||||||
return "error";
|
return "error";
|
||||||
}
|
}
|
||||||
|
@ -95,8 +95,6 @@ class Tensor {
|
|||||||
switch (descriptor_.storage_type) {
|
switch (descriptor_.storage_type) {
|
||||||
case TensorStorageType::BUFFER:
|
case TensorStorageType::BUFFER:
|
||||||
case TensorStorageType::IMAGE_BUFFER:
|
case TensorStorageType::IMAGE_BUFFER:
|
||||||
return (((b * Depth() + d) * shape_.h + y) * shape_.w + x) * 4 +
|
|
||||||
sub_d; // BDHWC4
|
|
||||||
case TensorStorageType::TEXTURE_ARRAY:
|
case TensorStorageType::TEXTURE_ARRAY:
|
||||||
return (((d * shape_.h + y) * shape_.w + x) * shape_.b + b) * 4 +
|
return (((d * shape_.h + y) * shape_.w + x) * shape_.b + b) * 4 +
|
||||||
sub_d; // DHWBC4
|
sub_d; // DHWBC4
|
||||||
|
Loading…
x
Reference in New Issue
Block a user