Added support of batch to all operations with any storage types.

Changed layout for BUFFER and IMAGE_BUFFER

PiperOrigin-RevId: 272498500
This commit is contained in:
A. Unique TensorFlower 2019-10-02 12:49:28 -07:00 committed by TensorFlower Gardener
parent e7fb9e5202
commit 811e7b67c8
4 changed files with 14 additions and 14 deletions

View File

@ -182,14 +182,6 @@ std::vector<TensorStorageType> Environment::GetSupportedTextureStorages()
return storage_types;
}
std::vector<TensorStorageType> Environment::GetSupportedBatchStorages() const {
std::vector<TensorStorageType> storage_types = {TensorStorageType::BUFFER};
if (device_.IsAdreno() && device_.SupportsImageBuffer()) {
storage_types.push_back(TensorStorageType::IMAGE_BUFFER);
}
return storage_types;
}
std::vector<TensorStorageType> Environment::GetSupportedStorages() const {
std::vector<TensorStorageType> storage_types = {TensorStorageType::TEXTURE_2D,
TensorStorageType::BUFFER};

View File

@ -56,7 +56,6 @@ class Environment {
std::vector<CalculationsPrecision> GetSupportedPrecisions() const;
bool IsSupported(CalculationsPrecision precision) const;
std::vector<TensorStorageType> GetSupportedTextureStorages() const;
std::vector<TensorStorageType> GetSupportedBatchStorages() const;
std::vector<TensorStorageType> GetSupportedStorages() const;
void SetHighPerformance() const;

View File

@ -234,9 +234,20 @@ std::string TensorCodeGenerator::GetGlobalAddressNoDeclaration(
switch (descriptor_.storage_type) {
case TensorStorageType::BUFFER:
case TensorStorageType::IMAGE_BUFFER:
return absl::Substitute("(((($3) * $4 + $2) * $5 + ($1)) * $6 + ($0))", x,
y, z, b, sizes_.depth, sizes_.height,
sizes_.width);
return absl::Substitute("(((($3) * $4 + $2) * $5 + ($1)) * $6 + ($0))", b,
x, y, z, sizes_.height, sizes_.width,
sizes_.batch_size);
case TensorStorageType::TEXTURE_2D:
return absl::Substitute("(int2)(($0) * ($4) + ($1), ($2) * $5 + ($3))", x,
b, y, z, sizes_.batch_size, sizes_.depth);
case TensorStorageType::SINGLE_TEXTURE_2D:
return absl::Substitute("(int2)(($0) * ($3) + ($1), ($2))", x, b, y,
sizes_.batch_size);
case TensorStorageType::TEXTURE_ARRAY:
return absl::Substitute("(int4)(($0) * ($4) + ($1), ($2), ($3), 0)", x, b,
y, z, sizes_.batch_size);
case TensorStorageType::UNKNOWN:
return "error";
default:
return "error";
}

View File

@ -95,8 +95,6 @@ class Tensor {
switch (descriptor_.storage_type) {
case TensorStorageType::BUFFER:
case TensorStorageType::IMAGE_BUFFER:
return (((b * Depth() + d) * shape_.h + y) * shape_.w + x) * 4 +
sub_d; // BDHWC4
case TensorStorageType::TEXTURE_ARRAY:
return (((d * shape_.h + y) * shape_.w + x) * shape_.b + b) * 4 +
sub_d; // DHWBC4