Support for broadcasting in ElementwiseTwoInput.

MUL and ADD added to ElementwiseTwoInput (for some cases).
Removed ApplyMask (now done as MUL in ElementwiseTwoInput).

PiperOrigin-RevId: 292931827
Change-Id: I32ef9b5ece3f47011d6da3f361a5a6b25a8b9127
Raman Sarokin 2020-02-03 09:05:30 -08:00 committed by TensorFlower Gardener
parent 05a122df52
commit 59145d293a
9 changed files with 234 additions and 222 deletions
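For orientation, a minimal usage sketch of the new broadcast-aware API added in this commit, based on the MulBroadcastChannels test below. op_def is assumed to be an already populated OperationDef, and the wrapper function name is illustrative:

// Hypothetical wrapper: multiply by a per-pixel scalar by broadcasting the
// single channel of the second input, as in the MulBroadcastChannels test.
ElementwiseTwoInput MakeChannelBroadcastMul(const OperationDef& op_def) {
  BroadcastSettings broadcast;
  broadcast.width = false;
  broadcast.height = false;
  broadcast.channels = true;
  return CreateElementwiseTwoInput(op_def, OperationType::MUL, broadcast);
}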

@@ -105,13 +105,15 @@ ElementwiseOneInput CreateElementwiseOneInput(const OperationDef& definition,
ElementwiseTwoInput::ElementwiseTwoInput(ElementwiseTwoInput&& operation)
: ElementwiseOperation(std::move(operation)),
link_index_(operation.link_index_),
op_type_(operation.op_type_) {}
op_type_(operation.op_type_),
broadcast_(operation.broadcast_) {}
ElementwiseTwoInput& ElementwiseTwoInput::operator=(
ElementwiseTwoInput&& operation) {
if (this != &operation) {
link_index_ = operation.link_index_;
op_type_ = operation.op_type_;
broadcast_ = operation.broadcast_;
ElementwiseOperation::operator=(std::move(operation));
}
return *this;
@@ -121,31 +123,46 @@ void ElementwiseTwoInput::SetLinkIndex(int index) { link_index_ = index; }
std::string ElementwiseTwoInput::GetCoreCode(
const LinkingContext& context) const {
const std::string size_name = "src_size_" + std::to_string(link_index_);
TensorCodeGenerator src_tensor(
absl::StrCat("src_data_", link_index_),
WHSPoint{"src_size.x", "src_size.y", "src_size.z"},
WHSPoint{size_name + ".x", size_name + ".y", size_name + ".z"},
definition_.src_tensors[1]);
std::string result;
const std::string x_coord = broadcast_.width ? "0" : context.x_coord;
const std::string y_coord = broadcast_.height ? "0" : context.y_coord;
const std::string s_coord = broadcast_.channels ? "0" : context.s_coord;
const std::string second_var = "second_var_" + std::to_string(link_index_);
std::string result = " FLT4 " + second_var + " = " +
src_tensor.ReadWHS(x_coord, y_coord, s_coord) + ";\n";
if (broadcast_.channels) {
result += " " + second_var + ".y = " + second_var + ".x;\n";
result += " " + second_var + ".z = " + second_var + ".x;\n";
result += " " + second_var + ".w = " + second_var + ".x;\n";
}
switch (op_type_) {
case OperationType::ADD:
result += "$0 += $1;\n";
break;
case OperationType::DIV:
result = "$0 /= $1;\n";
result += "$0 /= $1;\n";
break;
case OperationType::MUL:
result += "$0 *= $1;\n";
break;
case OperationType::POW:
result = "$0 = pow($0, $1);\n";
result += "$0 = pow($0, $1);\n";
break;
case OperationType::SQUARED_DIFF:
result = "$0 -= $1;\n";
result += "$0 -= $1;\n";
result += "$0 *= $0;\n";
break;
case OperationType::SUB:
result = "$0 -= $1;\n";
result += "$0 -= $1;\n";
break;
default:
return "Unknown operation type;\n";
}
return absl::Substitute(
result, context.var_name,
src_tensor.ReadWHS(context.x_coord, context.y_coord, context.s_coord));
return absl::Substitute(result, context.var_name, second_var);
}
std::string ElementwiseTwoInput::GetArgsDeclaration() const {
@@ -164,9 +181,21 @@ Status ElementwiseTwoInput::BindArguments(CLKernel* kernel) {
return OkStatus();
}
ElementwiseTwoInput CreateElementwiseTwoInput(
const OperationDef& definition, const OperationType& op_type,
const BroadcastSettings& broadcast) {
ElementwiseTwoInput operation(definition, op_type, broadcast);
operation.SetLinkIndex(0);
return operation;
}
ElementwiseTwoInput CreateElementwiseTwoInput(const OperationDef& definition,
const OperationType& op_type) {
ElementwiseTwoInput operation(definition, op_type);
BroadcastSettings broadcast;
broadcast.width = false;
broadcast.height = false;
broadcast.channels = false;
ElementwiseTwoInput operation(definition, op_type, broadcast);
operation.SetLinkIndex(0);
return operation;
}
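As a rough sketch (not part of the diff) of what GetCoreCode now assembles for MUL with channel broadcast: the second operand is read once at slice 0 and its x component is replicated across the vector before the per-element op. The read expression and the destination name value_0 are placeholders; the real ones come from TensorCodeGenerator::ReadWHS and the LinkingContext.

#include <string>

// Assumed names: "src.Read(X, Y, 0)" stands in for the storage-dependent read
// emitted by TensorCodeGenerator, and "value_0" for context.var_name.
std::string MulChannelBroadcastCoreCodeSketch() {
  const std::string second_var = "second_var_0";  // link_index_ == 0
  std::string result = "  FLT4 " + second_var + " = src.Read(X, Y, 0);\n";
  // Channel broadcast: replicate the single channel across the FLT4 vector.
  result += "  " + second_var + ".y = " + second_var + ".x;\n";
  result += "  " + second_var + ".z = " + second_var + ".x;\n";
  result += "  " + second_var + ".w = " + second_var + ".x;\n";
  // After absl::Substitute, "$0 *= $1" becomes:
  result += "  value_0 *= " + second_var + ";\n";
  return result;
}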

@@ -48,13 +48,22 @@ class ElementwiseOneInput : public ElementwiseOperation {
ElementwiseOneInput CreateElementwiseOneInput(const OperationDef& definition,
const OperationType& op_type);
struct BroadcastSettings {
bool width;
bool height;
bool channels;
};
// Class for simple two input operations without any parameters, for example
// sub, div and etc.
class ElementwiseTwoInput : public ElementwiseOperation {
public:
explicit ElementwiseTwoInput(const OperationDef& definition,
const OperationType& op_type)
: ElementwiseOperation(definition), op_type_(op_type) {}
const OperationType& op_type,
const BroadcastSettings& broadcast)
: ElementwiseOperation(definition),
op_type_(op_type),
broadcast_(broadcast) {}
// Move only
ElementwiseTwoInput(ElementwiseTwoInput&& operation);
@@ -70,8 +79,13 @@ class ElementwiseTwoInput : public ElementwiseOperation {
private:
int link_index_;
OperationType op_type_;
BroadcastSettings broadcast_;
};
ElementwiseTwoInput CreateElementwiseTwoInput(
const OperationDef& definition, const OperationType& op_type,
const BroadcastSettings& broadcast);
ElementwiseTwoInput CreateElementwiseTwoInput(const OperationDef& definition,
const OperationType& op_type);

@@ -397,6 +397,128 @@ TEST_F(OpenCLOperationTest, Pow) {
}
}
TEST_F(OpenCLOperationTest, Add) {
TensorFloat32 src_tensor_0, src_tensor_1;
src_tensor_0.shape = BHWC(1, 2, 1, 2);
src_tensor_1.shape = BHWC(1, 2, 1, 2);
src_tensor_0.data = {1.0f, 2.0f, 3.0f, 4.5f};
src_tensor_1.data = {0.5f, 1.0f, 3.0f, 1.5f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ElementwiseTwoInput operation =
CreateElementwiseTwoInput(op_def, OperationType::ADD);
ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1},
creation_context_, &operation,
BHWC(1, 2, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {1.5f, 3.0f, 6.0f, 6.0f}));
}
}
}
TEST_F(OpenCLOperationTest, Mul) {
TensorFloat32 src_tensor_0, src_tensor_1;
src_tensor_0.shape = BHWC(1, 2, 1, 2);
src_tensor_1.shape = BHWC(1, 2, 1, 2);
src_tensor_0.data = {1.0f, 2.0f, 3.0f, 4.5f};
src_tensor_1.data = {0.5f, 1.0f, 3.0f, 1.5f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ElementwiseTwoInput operation =
CreateElementwiseTwoInput(op_def, OperationType::MUL);
ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1},
creation_context_, &operation,
BHWC(1, 2, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {0.5f, 2.0f, 9.0f, 6.75f}));
}
}
}
TEST_F(OpenCLOperationTest, MulBroadcastHW) {
TensorFloat32 src_tensor_0, src_tensor_1;
src_tensor_0.shape = BHWC(1, 2, 1, 2);
src_tensor_1.shape = BHWC(1, 1, 1, 2);
src_tensor_0.data = {1.0f, 2.0f, 3.0f, 4.5f};
src_tensor_1.data = {0.5f, 3.0f};
BroadcastSettings broadcast;
broadcast.width = true;
broadcast.height = true;
broadcast.channels = false;
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ElementwiseTwoInput operation =
CreateElementwiseTwoInput(op_def, OperationType::MUL, broadcast);
ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1},
creation_context_, &operation,
BHWC(1, 2, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {0.5f, 6.0f, 1.5f, 13.5f}));
}
}
}
TEST_F(OpenCLOperationTest, MulBroadcastChannels) {
TensorFloat32 src_tensor_0, src_tensor_1;
src_tensor_0.shape = BHWC(1, 2, 1, 2);
src_tensor_1.shape = BHWC(1, 2, 1, 1);
src_tensor_0.data = {1.0f, 2.0f, 3.0f, 4.5f};
src_tensor_1.data = {0.5f, 3.0f};
BroadcastSettings broadcast;
broadcast.width = false;
broadcast.height = false;
broadcast.channels = true;
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ElementwiseTwoInput operation =
CreateElementwiseTwoInput(op_def, OperationType::MUL, broadcast);
ASSERT_OK(ExecuteGPUOperation({src_tensor_0, src_tensor_1},
creation_context_, &operation,
BHWC(1, 2, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {0.5f, 1.0f, 9.0f, 13.5f}));
}
}
}
} // namespace
} // namespace cl
} // namespace gpu

@@ -175,76 +175,6 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
return OkStatus();
}
ApplyMask::ApplyMask(ApplyMask&& operation)
: ElementwiseOperation(std::move(operation)),
mask_type_(operation.mask_type_),
link_index_(operation.link_index_) {}
ApplyMask& ApplyMask::operator=(ApplyMask&& operation) {
if (this != &operation) {
mask_type_ = operation.mask_type_;
link_index_ = operation.link_index_;
ElementwiseOperation::operator=(std::move(operation));
}
return *this;
}
void ApplyMask::SetLinkIndex(int index) { link_index_ = index; }
std::string ApplyMask::GetCoreCode(const LinkingContext& context) const {
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
TensorCodeGenerator mask(
tensor_name,
WHSPoint{size_name + ".x", size_name + ".y", size_name + ".z"},
definition_.src_tensors[1]);
switch (mask_type_) {
case MaskType::TENSOR:
return context.var_name + " *= " +
mask.ReadWHS(context.x_coord, context.y_coord, context.s_coord) +
";\n";
case MaskType::CHANNELS:
return context.var_name +
" *= " + mask.ReadWHS("0", "0", context.s_coord) + ";\n";
case MaskType::LAYER:
return context.var_name +
" *= " + mask.ReadWHS(context.x_coord, context.y_coord, "0") +
".x;\n";
}
}
std::string ApplyMask::GetArgsDeclaration() const {
std::string args;
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
absl::StrAppend(&args, ",\n",
GetTensorDeclaration(AccessType::READ, tensor_name,
definition_.src_tensors[1]));
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
absl::StrAppend(&args, ",\n int4 ", size_name);
return args;
}
Status ApplyMask::BindArguments(CLKernel* kernel) {
RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[1]->GetMemoryPtr()));
RETURN_IF_ERROR(kernel->SetBytesAuto(src_[1]->GetWBatchedHSB()));
return OkStatus();
}
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
const BHWC& mask_shape) {
ApplyMask::MaskType mask_type;
if (mask_shape == src_shape) {
mask_type = ApplyMask::MaskType::TENSOR;
} else if (mask_shape.c == 1) {
mask_type = ApplyMask::MaskType::LAYER;
} else {
mask_type = ApplyMask::MaskType::CHANNELS;
}
ApplyMask operation(definition, mask_type);
operation.SetLinkIndex(0);
return operation;
}
} // namespace cl
} // namespace gpu
} // namespace tflite

@@ -126,36 +126,6 @@ Status MultiplyAdd::UploadAdd(const ::tflite::gpu::Tensor<Linear, T>& add,
return OkStatus();
}
class ApplyMask : public ElementwiseOperation {
public:
// Move only
ApplyMask(ApplyMask&& operation);
ApplyMask& operator=(ApplyMask&& operation);
ApplyMask(const ApplyMask&) = delete;
ApplyMask& operator=(const ApplyMask&) = delete;
void SetLinkIndex(int index) override;
std::string GetCoreCode(const LinkingContext& context) const override;
std::string GetArgsDeclaration() const override;
Status BindArguments(CLKernel* kernel) override;
private:
friend ApplyMask CreateApplyMask(const OperationDef& definition,
const BHWC& src_shape,
const BHWC& mask_shape);
enum class MaskType { LAYER, CHANNELS, TENSOR };
explicit ApplyMask(const OperationDef& definition, MaskType mask_type)
: ElementwiseOperation(definition), mask_type_(mask_type) {}
MaskType mask_type_;
int link_index_;
};
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
const BHWC& mask_shape);
} // namespace cl
} // namespace gpu
} // namespace tflite

@@ -181,96 +181,6 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMad) {
}
}
TEST_F(OpenCLOperationTest, ApplyMaskOneChannel) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 2);
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
TensorFloat32 mask_tensor;
mask_tensor.shape = BHWC(1, 2, 2, 1);
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ApplyMask operation =
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
creation_context_, &operation,
BHWC(1, 2, 2, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-8.0f, -6.0f, -0.5f, 0.0f, 1.0f,
3.0f, 0.0f, 0.0f}));
}
}
}
TEST_F(OpenCLOperationTest, ApplyMaskEqualSizes) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 2);
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
TensorFloat32 mask_tensor;
mask_tensor.shape = BHWC(1, 2, 2, 2);
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f, 2.0f, 0.5f, 1.0f, 0.0f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ApplyMask operation =
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
creation_context_, &operation,
BHWC(1, 2, 2, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -1.0f, 0.0f, 2.0f,
1.5f, 4.0f, 0.0f}));
}
}
}
TEST_F(OpenCLOperationTest, ApplyMaskVector) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 2);
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
TensorFloat32 mask_tensor;
mask_tensor.shape = BHWC(1, 1, 1, 2);
mask_tensor.data = {2.0f, 0.5f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ApplyMask operation =
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
creation_context_, &operation,
BHWC(1, 2, 2, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -2.0f, 0.0f, 2.0f,
1.5f, 8.0f, 3.0f}));
}
}
}
} // namespace
} // namespace cl
} // namespace gpu

@@ -26,11 +26,32 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
namespace tflite {
namespace gpu {
namespace cl {
namespace {
bool IsWidthBroadcastedForSecondInput(
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.w != inputs[1]->tensor.shape.w &&
inputs[1]->tensor.shape.w == 1;
}
bool IsHeightBroadcastedForSecondInput(
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.h != inputs[1]->tensor.shape.h &&
inputs[1]->tensor.shape.h == 1;
}
bool IsChannelsBroadcastedForSecondInput(
const std::vector<Value<TensorRef<BHWC>>*>& inputs) {
return inputs.size() == 2 &&
inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c &&
inputs[1]->tensor.shape.c == 1;
}
} // namespace
Status GPUOperationFromNode(const CreationContext& creation_context,
const OperationDef& op_def, ModelHints hints,
@@ -59,12 +80,23 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
if (adds || adds_scalar) {
return SelectBroadcastAdd(attr, creation_context, op_def, gpu_op);
} else {
auto output = outputs[0];
std::vector<int> channels(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
channels[i] = inputs[i]->tensor.shape.c;
BroadcastSettings broadcast;
broadcast.width = IsWidthBroadcastedForSecondInput(inputs);
broadcast.height = IsHeightBroadcastedForSecondInput(inputs);
broadcast.channels = IsChannelsBroadcastedForSecondInput(inputs);
if (broadcast.width || broadcast.height || broadcast.channels) {
ElementwiseTwoInput operation =
CreateElementwiseTwoInput(op_def, op_type, broadcast);
*gpu_op =
absl::make_unique<ElementwiseTwoInput>(std::move(operation));
} else {
auto output = outputs[0];
std::vector<int> channels(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
channels[i] = inputs[i]->tensor.shape.c;
}
SelectAdd(op_def, channels, output->tensor.shape.c, gpu_op);
}
SelectAdd(op_def, channels, output->tensor.shape.c, gpu_op);
return OkStatus();
}
}
@@ -121,8 +153,20 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
return SelectMultiplyScalar(attr, creation_context, op_def, gpu_op);
} else {
SelectApplyMask(op_def, inputs[0]->tensor.shape,
inputs[1]->tensor.shape, gpu_op);
if (inputs.size() == 2) {
BroadcastSettings broadcast;
broadcast.width = IsWidthBroadcastedForSecondInput(inputs);
broadcast.height = IsHeightBroadcastedForSecondInput(inputs);
broadcast.channels = IsChannelsBroadcastedForSecondInput(inputs);
ElementwiseTwoInput operation =
CreateElementwiseTwoInput(op_def, op_type, broadcast);
*gpu_op =
absl::make_unique<ElementwiseTwoInput>(std::move(operation));
return OkStatus();
} else {
return UnimplementedError(
"No support of multiply with more than 2 inputs");
}
return OkStatus();
}
}
@@ -190,8 +234,12 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
case OperationType::POW:
case OperationType::SQUARED_DIFF:
case OperationType::SUB: {
BroadcastSettings broadcast;
broadcast.width = IsWidthBroadcastedForSecondInput(inputs);
broadcast.height = IsHeightBroadcastedForSecondInput(inputs);
broadcast.channels = IsChannelsBroadcastedForSecondInput(inputs);
ElementwiseTwoInput operation =
CreateElementwiseTwoInput(op_def, op_type);
CreateElementwiseTwoInput(op_def, op_type, broadcast);
*gpu_op = absl::make_unique<ElementwiseTwoInput>(std::move(operation));
return OkStatus();
}
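A small sketch of the shape rule the Is*BroadcastedForSecondInput helpers above encode: a dimension of the second input is treated as broadcast when its extent is 1 while the first input's extent differs. The shapes mirror the MulBroadcastHW test tensors; the function name is illustrative.

#include "tensorflow/lite/delegates/gpu/common/shape.h"

bool SecondInputBroadcastsOnlyHeight() {
  const tflite::gpu::BHWC shape0(1, 2, 1, 2);  // first input
  const tflite::gpu::BHWC shape1(1, 1, 1, 2);  // second input
  const bool h = shape0.h != shape1.h && shape1.h == 1;  // true
  const bool w = shape0.w != shape1.w && shape1.w == 1;  // false: both widths are 1
  const bool c = shape0.c != shape1.c && shape1.c == 1;  // false: both have 2 channels
  return h && !w && !c;
}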

@@ -43,13 +43,6 @@ namespace tflite {
namespace gpu {
namespace cl {
void SelectApplyMask(const OperationDef& op_def, const BHWC& src_shape,
const BHWC& mask_shape,
std::unique_ptr<GPUOperation>* ptr) {
ApplyMask operation = CreateApplyMask(op_def, src_shape, mask_shape);
*ptr = absl::make_unique<ApplyMask>(std::move(operation));
}
void SelectLSTM(const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr) {
LSTM operation = CreateLSTM(op_def);

@@ -27,10 +27,6 @@ namespace tflite {
namespace gpu {
namespace cl {
void SelectApplyMask(const OperationDef& op_def, const BHWC& src_shape,
const BHWC& mask_shape,
std::unique_ptr<GPUOperation>* ptr);
void SelectLSTM(const OperationDef& op_def, std::unique_ptr<GPUOperation>* ptr);
void SelectReLU(const CreationContext& creation_context,