Added support of batch to all operations with any storage types.
Changed layout for BUFFER and IMAGE_BUFFER PiperOrigin-RevId: 272498500
This commit is contained in:
parent
e7fb9e5202
commit
811e7b67c8
@ -182,14 +182,6 @@ std::vector<TensorStorageType> Environment::GetSupportedTextureStorages()
|
||||
return storage_types;
|
||||
}
|
||||
|
||||
std::vector<TensorStorageType> Environment::GetSupportedBatchStorages() const {
|
||||
std::vector<TensorStorageType> storage_types = {TensorStorageType::BUFFER};
|
||||
if (device_.IsAdreno() && device_.SupportsImageBuffer()) {
|
||||
storage_types.push_back(TensorStorageType::IMAGE_BUFFER);
|
||||
}
|
||||
return storage_types;
|
||||
}
|
||||
|
||||
std::vector<TensorStorageType> Environment::GetSupportedStorages() const {
|
||||
std::vector<TensorStorageType> storage_types = {TensorStorageType::TEXTURE_2D,
|
||||
TensorStorageType::BUFFER};
|
||||
|
@ -56,7 +56,6 @@ class Environment {
|
||||
std::vector<CalculationsPrecision> GetSupportedPrecisions() const;
|
||||
bool IsSupported(CalculationsPrecision precision) const;
|
||||
std::vector<TensorStorageType> GetSupportedTextureStorages() const;
|
||||
std::vector<TensorStorageType> GetSupportedBatchStorages() const;
|
||||
std::vector<TensorStorageType> GetSupportedStorages() const;
|
||||
|
||||
void SetHighPerformance() const;
|
||||
|
@ -234,9 +234,20 @@ std::string TensorCodeGenerator::GetGlobalAddressNoDeclaration(
|
||||
switch (descriptor_.storage_type) {
|
||||
case TensorStorageType::BUFFER:
|
||||
case TensorStorageType::IMAGE_BUFFER:
|
||||
return absl::Substitute("(((($3) * $4 + $2) * $5 + ($1)) * $6 + ($0))", x,
|
||||
y, z, b, sizes_.depth, sizes_.height,
|
||||
sizes_.width);
|
||||
return absl::Substitute("(((($3) * $4 + $2) * $5 + ($1)) * $6 + ($0))", b,
|
||||
x, y, z, sizes_.height, sizes_.width,
|
||||
sizes_.batch_size);
|
||||
case TensorStorageType::TEXTURE_2D:
|
||||
return absl::Substitute("(int2)(($0) * ($4) + ($1), ($2) * $5 + ($3))", x,
|
||||
b, y, z, sizes_.batch_size, sizes_.depth);
|
||||
case TensorStorageType::SINGLE_TEXTURE_2D:
|
||||
return absl::Substitute("(int2)(($0) * ($3) + ($1), ($2))", x, b, y,
|
||||
sizes_.batch_size);
|
||||
case TensorStorageType::TEXTURE_ARRAY:
|
||||
return absl::Substitute("(int4)(($0) * ($4) + ($1), ($2), ($3), 0)", x, b,
|
||||
y, z, sizes_.batch_size);
|
||||
case TensorStorageType::UNKNOWN:
|
||||
return "error";
|
||||
default:
|
||||
return "error";
|
||||
}
|
||||
|
@ -95,8 +95,6 @@ class Tensor {
|
||||
switch (descriptor_.storage_type) {
|
||||
case TensorStorageType::BUFFER:
|
||||
case TensorStorageType::IMAGE_BUFFER:
|
||||
return (((b * Depth() + d) * shape_.h + y) * shape_.w + x) * 4 +
|
||||
sub_d; // BDHWC4
|
||||
case TensorStorageType::TEXTURE_ARRAY:
|
||||
return (((d * shape_.h + y) * shape_.w + x) * shape_.b + b) * 4 +
|
||||
sub_d; // DHWBC4
|
||||
|
Loading…
x
Reference in New Issue
Block a user