Support of broadcasting for ElementwiseTwoInput.
MUL and ADD added to ElementwiseTwoInput (for some cases). Removed ApplyMask (done as MUL in ElementwiseTwoInput).
PiperOrigin-RevId: 292931827
Change-Id: I32ef9b5ece3f47011d6da3f361a5a6b25a8b9127
This commit is contained in:
parent
05a122df52
commit
59145d293a
@ -105,13 +105,15 @@ ElementwiseOneInput CreateElementwiseOneInput(const OperationDef& definition,
|
||||
ElementwiseTwoInput::ElementwiseTwoInput(ElementwiseTwoInput&& operation)
|
||||
: ElementwiseOperation(std::move(operation)),
|
||||
link_index_(operation.link_index_),
|
||||
op_type_(operation.op_type_) {}
|
||||
op_type_(operation.op_type_),
|
||||
broadcast_(operation.broadcast_) {}
|
||||
|
||||
ElementwiseTwoInput& ElementwiseTwoInput::operator=(
|
||||
ElementwiseTwoInput&& operation) {
|
||||
if (this != &operation) {
|
||||
link_index_ = operation.link_index_;
|
||||
op_type_ = operation.op_type_;
|
||||
broadcast_ = operation.broadcast_;
|
||||
ElementwiseOperation::operator=(std::move(operation));
|
||||
}
|
||||
return *this;
|
||||
@ -121,31 +123,46 @@ void ElementwiseTwoInput::SetLinkIndex(int index) { link_index_ = index; }
|
||||
|
||||
std::string ElementwiseTwoInput::GetCoreCode(
|
||||
const LinkingContext& context) const {
|
||||
const std::string size_name = "src_size_" + std::to_string(link_index_);
|
||||
TensorCodeGenerator src_tensor(
|
||||
absl::StrCat("src_data_", link_index_),
|
||||
WHSPoint{"src_size.x", "src_size.y", "src_size.z"},
|
||||
WHSPoint{size_name + ".x", size_name + ".y", size_name + ".z"},
|
||||
definition_.src_tensors[1]);
|
||||
std::string result;
|
||||
const std::string x_coord = broadcast_.width ? "0" : context.x_coord;
|
||||
const std::string y_coord = broadcast_.height ? "0" : context.y_coord;
|
||||
const std::string s_coord = broadcast_.channels ? "0" : context.s_coord;
|
||||
const std::string second_var = "second_var_" + std::to_string(link_index_);
|
||||
std::string result = " FLT4 " + second_var + " = " +
|
||||
src_tensor.ReadWHS(x_coord, y_coord, s_coord) + ";\n";
|
||||
if (broadcast_.channels) {
|
||||
result += " " + second_var + ".y = " + second_var + ".x;\n";
|
||||
result += " " + second_var + ".z = " + second_var + ".x;\n";
|
||||
result += " " + second_var + ".w = " + second_var + ".x;\n";
|
||||
}
|
||||
switch (op_type_) {
|
||||
case OperationType::ADD:
|
||||
result += "$0 += $1;\n";
|
||||
break;
|
||||
case OperationType::DIV:
|
||||
result = "$0 /= $1;\n";
|
||||
result += "$0 /= $1;\n";
|
||||
break;
|
||||
case OperationType::MUL:
|
||||
result += "$0 *= $1;\n";
|
||||
break;
|
||||
case OperationType::POW:
|
||||
result = "$0 = pow($0, $1);\n";
|
||||
result += "$0 = pow($0, $1);\n";
|
||||
break;
|
||||
case OperationType::SQUARED_DIFF:
|
||||
result = "$0 -= $1;\n";
|
||||
result += "$0 -= $1;\n";
|
||||
result += "$0 *= $0;\n";
|
||||
break;
|
||||
case OperationType::SUB:
|
||||
result = "$0 -= $1;\n";
|
||||
result += "$0 -= $1;\n";
|
||||
break;
|
||||
default:
|
||||
return "Unknown operation type;\n";
|
||||
}
|
||||
return absl::Substitute(
|
||||
result, context.var_name,
|
||||
src_tensor.ReadWHS(context.x_coord, context.y_coord, context.s_coord));
|
||||
return absl::Substitute(result, context.var_name, second_var);
|
||||
}
|
||||
|
||||
std::string ElementwiseTwoInput::GetArgsDeclaration() const {
|
||||
@ -164,9 +181,21 @@ Status ElementwiseTwoInput::BindArguments(CLKernel* kernel) {
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
ElementwiseTwoInput CreateElementwiseTwoInput(
|
||||
const OperationDef& definition, const OperationType& op_type,
|
||||
const BroadcastSettings& broadcast) {
|
||||
ElementwiseTwoInput operation(definition, op_type, broadcast);
|
||||
operation.SetLinkIndex(0);
|
||||
return operation;
|
||||
}
|
||||
|
||||
ElementwiseTwoInput CreateElementwiseTwoInput(const OperationDef& definition,
|
||||
const OperationType& op_type) {
|
||||
ElementwiseTwoInput operation(definition, op_type);
|
||||
BroadcastSettings broadcast;
|
||||
broadcast.width = false;
|
||||
broadcast.height = false;
|
||||
broadcast.channels = false;
|
||||
ElementwiseTwoInput operation(definition, op_type, broadcast);
|
||||
operation.SetLinkIndex(0);
|
||||
return operation;
|
||||
}
|
||||
|
@ -48,13 +48,22 @@ class ElementwiseOneInput : public ElementwiseOperation {
|
||||
ElementwiseOneInput CreateElementwiseOneInput(const OperationDef& definition,
|
||||
const OperationType& op_type);
|
||||
|
||||
struct BroadcastSettings {
|
||||
bool width;
|
||||
bool height;
|
||||
bool channels;
|
||||
};
|
||||
|
||||
// Class for simple two input operations without any parameters, for example
|
||||
// sub, div and etc.
|
||||
class ElementwiseTwoInput : public ElementwiseOperation {
|
||||
public:
|
||||
explicit ElementwiseTwoInput(const OperationDef& definition,
|
||||
const OperationType& op_type)
|
||||
: ElementwiseOperation(definition), op_type_(op_type) {}
|
||||
const OperationType& op_type,
|
||||
const BroadcastSettings& broadcast)
|
||||
: ElementwiseOperation(definition),
|
||||
op_type_(op_type),
|
||||
broadcast_(broadcast) {}
|
||||
|
||||
// Move only
|
||||
ElementwiseTwoInput(ElementwiseTwoInput&& operation);
|
||||
@ -70,8 +79,13 @@ class ElementwiseTwoInput : public ElementwiseOperation {
|
||||
private:
|
||||
int link_index_;
|
||||
OperationType op_type_;
|
||||
BroadcastSettings broadcast_;
|
||||
};
|
||||
|
||||
ElementwiseTwoInput CreateElementwiseTwoInput(
|
||||
const OperationDef& definition, const OperationType& op_type,
|
||||
const BroadcastSettings& broadcast);
|
||||
|
||||
ElementwiseTwoInput CreateElementwiseTwoInput(const OperationDef& definition,
|
||||
const OperationType& op_type);
|
||||
|
||||
|
@ -397,6 +397,128 @@ TEST_F(OpenCLOperationTest, Pow) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, Add) {
|
||||
TensorFloat32 src_tensor_0, src_tensor_1;
|
||||
src_tensor_0.shape = BHWC(1, 2, 1, 2);
|
||||
src_tensor_1.shape = BHWC(1, 2, 1, 2);
|
||||
src_tensor_0.data = {1.0f, 2.0f, 3.0f, 4.5f};
|
||||
src_tensor_1.data = {0.5f, 1.0f, 3.0f, 1.5f};
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
ElementwiseTwoInput operation =
|
||||
CreateElementwiseTwoInput(op_def, OperationType::ADD);
|
||||
ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1},
|
||||
creation_context_, &operation,
|
||||
BHWC(1, 2, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {1.5f, 3.0f, 6.0f, 6.0f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, Mul) {
|
||||
TensorFloat32 src_tensor_0, src_tensor_1;
|
||||
src_tensor_0.shape = BHWC(1, 2, 1, 2);
|
||||
src_tensor_1.shape = BHWC(1, 2, 1, 2);
|
||||
src_tensor_0.data = {1.0f, 2.0f, 3.0f, 4.5f};
|
||||
src_tensor_1.data = {0.5f, 1.0f, 3.0f, 1.5f};
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
ElementwiseTwoInput operation =
|
||||
CreateElementwiseTwoInput(op_def, OperationType::MUL);
|
||||
ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1},
|
||||
creation_context_, &operation,
|
||||
BHWC(1, 2, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {0.5f, 2.0f, 9.0f, 6.75f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, MulBroadcastHW) {
|
||||
TensorFloat32 src_tensor_0, src_tensor_1;
|
||||
src_tensor_0.shape = BHWC(1, 2, 1, 2);
|
||||
src_tensor_1.shape = BHWC(1, 1, 1, 2);
|
||||
src_tensor_0.data = {1.0f, 2.0f, 3.0f, 4.5f};
|
||||
src_tensor_1.data = {0.5f, 3.0f};
|
||||
|
||||
BroadcastSettings broadcast;
|
||||
broadcast.width = true;
|
||||
broadcast.height = true;
|
||||
broadcast.channels = false;
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
ElementwiseTwoInput operation =
|
||||
CreateElementwiseTwoInput(op_def, OperationType::MUL, broadcast);
|
||||
ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1},
|
||||
creation_context_, &operation,
|
||||
BHWC(1, 2, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {0.5f, 6.0f, 1.5f, 13.5f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, MulBroadcastChannels) {
|
||||
TensorFloat32 src_tensor_0, src_tensor_1;
|
||||
src_tensor_0.shape = BHWC(1, 2, 1, 2);
|
||||
src_tensor_1.shape = BHWC(1, 2, 1, 1);
|
||||
src_tensor_0.data = {1.0f, 2.0f, 3.0f, 4.5f};
|
||||
src_tensor_1.data = {0.5f, 3.0f};
|
||||
|
||||
BroadcastSettings broadcast;
|
||||
broadcast.width = false;
|
||||
broadcast.height = false;
|
||||
broadcast.channels = true;
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
ElementwiseTwoInput operation =
|
||||
CreateElementwiseTwoInput(op_def, OperationType::MUL, broadcast);
|
||||
ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1},
|
||||
creation_context_, &operation,
|
||||
BHWC(1, 2, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {0.5f, 1.0f, 9.0f, 13.5f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
|
@ -175,76 +175,6 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
ApplyMask::ApplyMask(ApplyMask&& operation)
|
||||
: ElementwiseOperation(std::move(operation)),
|
||||
mask_type_(operation.mask_type_),
|
||||
link_index_(operation.link_index_) {}
|
||||
|
||||
ApplyMask& ApplyMask::operator=(ApplyMask&& operation) {
|
||||
if (this != &operation) {
|
||||
mask_type_ = operation.mask_type_;
|
||||
link_index_ = operation.link_index_;
|
||||
ElementwiseOperation::operator=(std::move(operation));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
void ApplyMask::SetLinkIndex(int index) { link_index_ = index; }
|
||||
|
||||
std::string ApplyMask::GetCoreCode(const LinkingContext& context) const {
|
||||
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
|
||||
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
|
||||
TensorCodeGenerator mask(
|
||||
tensor_name,
|
||||
WHSPoint{size_name + ".x", size_name + ".y", size_name + ".z"},
|
||||
definition_.src_tensors[1]);
|
||||
switch (mask_type_) {
|
||||
case MaskType::TENSOR:
|
||||
return context.var_name + " *= " +
|
||||
mask.ReadWHS(context.x_coord, context.y_coord, context.s_coord) +
|
||||
";\n";
|
||||
case MaskType::CHANNELS:
|
||||
return context.var_name +
|
||||
" *= " + mask.ReadWHS("0", "0", context.s_coord) + ";\n";
|
||||
case MaskType::LAYER:
|
||||
return context.var_name +
|
||||
" *= " + mask.ReadWHS(context.x_coord, context.y_coord, "0") +
|
||||
".x;\n";
|
||||
}
|
||||
}
|
||||
|
||||
std::string ApplyMask::GetArgsDeclaration() const {
|
||||
std::string args;
|
||||
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
|
||||
absl::StrAppend(&args, ",\n",
|
||||
GetTensorDeclaration(AccessType::READ, tensor_name,
|
||||
definition_.src_tensors[1]));
|
||||
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
|
||||
absl::StrAppend(&args, ",\n int4 ", size_name);
|
||||
return args;
|
||||
}
|
||||
|
||||
Status ApplyMask::BindArguments(CLKernel* kernel) {
|
||||
RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[1]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel->SetBytesAuto(src_[1]->GetWBatchedHSB()));
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
|
||||
const BHWC& mask_shape) {
|
||||
ApplyMask::MaskType mask_type;
|
||||
if (mask_shape == src_shape) {
|
||||
mask_type = ApplyMask::MaskType::TENSOR;
|
||||
} else if (mask_shape.c == 1) {
|
||||
mask_type = ApplyMask::MaskType::LAYER;
|
||||
} else {
|
||||
mask_type = ApplyMask::MaskType::CHANNELS;
|
||||
}
|
||||
ApplyMask operation(definition, mask_type);
|
||||
operation.SetLinkIndex(0);
|
||||
return operation;
|
||||
}
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -126,36 +126,6 @@ Status MultiplyAdd::UploadAdd(const ::tflite::gpu::Tensor<Linear, T>& add,
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
class ApplyMask : public ElementwiseOperation {
|
||||
public:
|
||||
// Move only
|
||||
ApplyMask(ApplyMask&& operation);
|
||||
ApplyMask& operator=(ApplyMask&& operation);
|
||||
ApplyMask(const ApplyMask&) = delete;
|
||||
ApplyMask& operator=(const ApplyMask&) = delete;
|
||||
|
||||
void SetLinkIndex(int index) override;
|
||||
std::string GetCoreCode(const LinkingContext& context) const override;
|
||||
std::string GetArgsDeclaration() const override;
|
||||
Status BindArguments(CLKernel* kernel) override;
|
||||
|
||||
private:
|
||||
friend ApplyMask CreateApplyMask(const OperationDef& definition,
|
||||
const BHWC& src_shape,
|
||||
const BHWC& mask_shape);
|
||||
|
||||
enum class MaskType { LAYER, CHANNELS, TENSOR };
|
||||
|
||||
explicit ApplyMask(const OperationDef& definition, MaskType mask_type)
|
||||
: ElementwiseOperation(definition), mask_type_(mask_type) {}
|
||||
|
||||
MaskType mask_type_;
|
||||
int link_index_;
|
||||
};
|
||||
|
||||
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
|
||||
const BHWC& mask_shape);
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -181,96 +181,6 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMad) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, ApplyMaskOneChannel) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 2, 2, 2);
|
||||
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
|
||||
TensorFloat32 mask_tensor;
|
||||
mask_tensor.shape = BHWC(1, 2, 2, 1);
|
||||
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f};
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
ApplyMask operation =
|
||||
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
|
||||
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
|
||||
creation_context_, &operation,
|
||||
BHWC(1, 2, 2, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {-8.0f, -6.0f, -0.5f, 0.0f, 1.0f,
|
||||
3.0f, 0.0f, 0.0f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, ApplyMaskEqualSizes) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 2, 2, 2);
|
||||
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
|
||||
TensorFloat32 mask_tensor;
|
||||
mask_tensor.shape = BHWC(1, 2, 2, 2);
|
||||
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f, 2.0f, 0.5f, 1.0f, 0.0f};
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
ApplyMask operation =
|
||||
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
|
||||
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
|
||||
creation_context_, &operation,
|
||||
BHWC(1, 2, 2, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -1.0f, 0.0f, 2.0f,
|
||||
1.5f, 4.0f, 0.0f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, ApplyMaskVector) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 2, 2, 2);
|
||||
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
|
||||
TensorFloat32 mask_tensor;
|
||||
mask_tensor.shape = BHWC(1, 1, 1, 2);
|
||||
mask_tensor.data = {2.0f, 0.5f};
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
ApplyMask operation =
|
||||
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
|
||||
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
|
||||
creation_context_, &operation,
|
||||
BHWC(1, 2, 2, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -2.0f, 0.0f, 2.0f,
|
||||
1.5f, 8.0f, 3.0f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
|
@ -26,11 +26,32 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
namespace {
|
||||
bool IsWidthBroadcastedForSecondInput(
|
||||
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
|
||||
return inputs.size() == 2 &&
|
||||
inputs[0]->tensor.shape.w != inputs[1]->tensor.shape.w &&
|
||||
inputs[1]->tensor.shape.w == 1;
|
||||
}
|
||||
bool IsHeightBroadcastedForSecondInput(
|
||||
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
|
||||
return inputs.size() == 2 &&
|
||||
inputs[0]->tensor.shape.h != inputs[1]->tensor.shape.h &&
|
||||
inputs[1]->tensor.shape.h == 1;
|
||||
}
|
||||
bool IsChannelsBroadcastedForSecondInput(
|
||||
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
|
||||
return inputs.size() == 2 &&
|
||||
inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c &&
|
||||
inputs[1]->tensor.shape.c == 1;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status GPUOperationFromNode(const CreationContext& creation_context,
|
||||
const OperationDef& op_def, ModelHints hints,
|
||||
@ -59,12 +80,23 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
|
||||
if (adds || adds_scalar) {
|
||||
return SelectBroadcastAdd(attr, creation_context, op_def, gpu_op);
|
||||
} else {
|
||||
auto output = outputs[0];
|
||||
std::vector<int> channels(inputs.size());
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
channels[i] = inputs[i]->tensor.shape.c;
|
||||
BroadcastSettings broadcast;
|
||||
broadcast.width = IsWidthBroadcastedForSecondInput(inputs);
|
||||
broadcast.height = IsHeightBroadcastedForSecondInput(inputs);
|
||||
broadcast.channels = IsChannelsBroadcastedForSecondInput(inputs);
|
||||
if (broadcast.width || broadcast.height || broadcast.channels) {
|
||||
ElementwiseTwoInput operation =
|
||||
CreateElementwiseTwoInput(op_def, op_type, broadcast);
|
||||
*gpu_op =
|
||||
absl::make_unique<ElementwiseTwoInput>(std::move(operation));
|
||||
} else {
|
||||
auto output = outputs[0];
|
||||
std::vector<int> channels(inputs.size());
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
channels[i] = inputs[i]->tensor.shape.c;
|
||||
}
|
||||
SelectAdd(op_def, channels, output->tensor.shape.c, gpu_op);
|
||||
}
|
||||
SelectAdd(op_def, channels, output->tensor.shape.c, gpu_op);
|
||||
return OkStatus();
|
||||
}
|
||||
}
|
||||
@ -121,8 +153,20 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
|
||||
|
||||
return SelectMultiplyScalar(attr, creation_context, op_def, gpu_op);
|
||||
} else {
|
||||
SelectApplyMask(op_def, inputs[0]->tensor.shape,
|
||||
inputs[1]->tensor.shape, gpu_op);
|
||||
if (inputs.size() == 2) {
|
||||
BroadcastSettings broadcast;
|
||||
broadcast.width = IsWidthBroadcastedForSecondInput(inputs);
|
||||
broadcast.height = IsHeightBroadcastedForSecondInput(inputs);
|
||||
broadcast.channels = IsChannelsBroadcastedForSecondInput(inputs);
|
||||
ElementwiseTwoInput operation =
|
||||
CreateElementwiseTwoInput(op_def, op_type, broadcast);
|
||||
*gpu_op =
|
||||
absl::make_unique<ElementwiseTwoInput>(std::move(operation));
|
||||
return OkStatus();
|
||||
} else {
|
||||
return UnimplementedError(
|
||||
"No support of multiply with more than 2 inputs");
|
||||
}
|
||||
return OkStatus();
|
||||
}
|
||||
}
|
||||
@ -190,8 +234,12 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
|
||||
case OperationType::POW:
|
||||
case OperationType::SQUARED_DIFF:
|
||||
case OperationType::SUB: {
|
||||
BroadcastSettings broadcast;
|
||||
broadcast.width = IsWidthBroadcastedForSecondInput(inputs);
|
||||
broadcast.height = IsHeightBroadcastedForSecondInput(inputs);
|
||||
broadcast.channels = IsChannelsBroadcastedForSecondInput(inputs);
|
||||
ElementwiseTwoInput operation =
|
||||
CreateElementwiseTwoInput(op_def, op_type);
|
||||
CreateElementwiseTwoInput(op_def, op_type, broadcast);
|
||||
*gpu_op = absl::make_unique<ElementwiseTwoInput>(std::move(operation));
|
||||
return OkStatus();
|
||||
}
|
||||
|
@ -43,13 +43,6 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
void SelectApplyMask(const OperationDef& op_def, const BHWC& src_shape,
|
||||
const BHWC& mask_shape,
|
||||
std::unique_ptr<GPUOperation>* ptr) {
|
||||
ApplyMask operation = CreateApplyMask(op_def, src_shape, mask_shape);
|
||||
*ptr = absl::make_unique<ApplyMask>(std::move(operation));
|
||||
}
|
||||
|
||||
void SelectLSTM(const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr) {
|
||||
LSTM operation = CreateLSTM(op_def);
|
||||
|
@ -27,10 +27,6 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
void SelectApplyMask(const OperationDef& op_def, const BHWC& src_shape,
|
||||
const BHWC& mask_shape,
|
||||
std::unique_ptr<GPUOperation>* ptr);
|
||||
|
||||
void SelectLSTM(const OperationDef& op_def, std::unique_ptr<GPUOperation>* ptr);
|
||||
|
||||
void SelectReLU(const CreationContext& creation_context,
|
||||
|
Loading…
Reference in New Issue
Block a user