diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index 2a03ff9ff14..b37c3542413 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -236,6 +236,12 @@ int GetNumberOfRuntimeInputsForNode(const TfLiteContext* context,
   return number_of_runtime_inputs;
 }
 
+int GetNumberOfConstInputsForNode(const TfLiteContext* context,
+                                  const TfLiteNode* tflite_node) {
+  return tflite_node->inputs->size -
+         GetNumberOfRuntimeInputsForNode(context, tflite_node);
+}
+
 int GetNumberOfRuntimeOutputsForNode(const TfLiteContext* context,
                                      const TfLiteNode* tflite_node) {
   int number_of_runtime_outputs = 0;
@@ -258,6 +264,42 @@ Status CheckTensorIsAvailable(const TfLiteContext* context,
   return OkStatus();
 }
 
+Status CheckInputsOutputs(const TfLiteContext* context,
+                          const TfLiteNode* tflite_node, int runtime_inputs,
+                          int outputs) {
+  int runtime_inputs_from_model =
+      GetNumberOfRuntimeInputsForNode(context, tflite_node);
+  if (runtime_inputs_from_model != runtime_inputs) {
+    return InternalError(absl::StrFormat(
+        "Expected %d runtime input tensor(s), but node has %d runtime "
+        "input(s).",
+        runtime_inputs, runtime_inputs_from_model));
+  }
+  int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node);
+  if (runtime_outputs != outputs) {
+    return InternalError(
+        absl::StrFormat("Expected %d output tensor(s), but node has %d "
+                        "output(s).",
+                        outputs, runtime_outputs));
+  }
+  return OkStatus();
+}
+
+Status CheckInputsConstsOutputs(const TfLiteContext* context,
+                                const TfLiteNode* tflite_node,
+                                int runtime_inputs, int const_inputs,
+                                int outputs) {
+  int const_inputs_from_model =
+      GetNumberOfConstInputsForNode(context, tflite_node);
+  if (const_inputs_from_model != const_inputs) {
+    return InternalError(absl::StrFormat(
+        "Expected %d const input tensor(s), but node has %d const "
+        "input(s).",
+        const_inputs, const_inputs_from_model));
+  }
+  return CheckInputsOutputs(context, tflite_node, runtime_inputs, outputs);
+}
+
 class ObjectReader {
  public:
   ObjectReader(GraphFloat32* graph, TfLiteContext* context,
@@ -367,6 +409,13 @@ class ObjectReader {
                : nullptr;
   }
 
+  Status VerifyInputsConstsOutputs(const TfLiteNode* tflite_node,
+                                   int runtime_inputs, int const_inputs,
+                                   int outputs) {
+    return CheckInputsConstsOutputs(context_, tflite_node, runtime_inputs,
+                                    const_inputs, outputs);
+  }
+
  private:
   GraphFloat32* graph_ = nullptr;
   const TfLiteContext* context_ = nullptr;
@@ -374,59 +423,6 @@ class ObjectReader {
   std::vector<Value<TensorRef<BHWC>>*>* tensor_to_value_;
 };
 
-Status CheckInputsOutputs(const TfLiteContext* context,
-                          const TfLiteNode* tflite_node, int inputs,
-                          int outputs) {
-  int runtime_inputs = GetNumberOfRuntimeInputsForNode(context, tflite_node);
-  if (runtime_inputs != inputs) {
-    return InternalError(
-        absl::StrFormat("Expected %d input tensor(s), but node has %d runtime "
-                        "input(s).",
-                        inputs, runtime_inputs));
-  }
-  int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node);
-  if (runtime_outputs != outputs) {
-    return InternalError(
-        absl::StrFormat("Expected %d output tensor(s), but node has %d runtime "
-                        "output(s).",
-                        outputs, runtime_outputs));
-  }
-  return OkStatus();
-}
-
-// The function checks input tensors including 1 constant tensor.
-Status CheckInputsOutputsAllowingOneConstInput(const TfLiteContext* context,
-                                               const TfLiteNode* tflite_node,
-                                               int inputs, int outputs) {
-  int number_of_const_inputs = 0;
-  int number_of_runtime_inputs = 0;
-  for (int i = 0; i < tflite_node->inputs->size; i++) {
-    if (IsConstantTensor(&context->tensors[tflite_node->inputs->data[i]])) {
-      number_of_const_inputs++;
-    } else {
-      number_of_runtime_inputs++;
-    }
-  }
-  if (tflite_node->inputs->size != inputs) {
-    return InternalError(absl::StrFormat(
-        "Expected %d input tensor(s), but node has %d input(s).", inputs,
-        tflite_node->inputs->size));
-  }
-  if (number_of_const_inputs > 1) {
-    return InternalError(absl::StrFormat(
-        "Expected 1 const input tensor, but node has %d const input(s).",
-        number_of_const_inputs));
-  }
-  int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node);
-  if (runtime_outputs != outputs) {
-    return InternalError(
-        absl::StrFormat("Expected %d output tensor(s), but node has %d runtime "
-                        "output(s).",
-                        outputs, runtime_outputs));
-  }
-  return OkStatus();
-}
-
 // A parser responsible for parsing TFLite operation and adding it to a graph.
 class TFLiteOperationParser {
  public:
@@ -893,8 +889,8 @@ class Conv2DOperationParser : public TFLiteOperationParser {
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration* registration) final {
     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
-    RETURN_IF_ERROR(
-        CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
+    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                       /*runtime_inputs=*/1, /*outputs=*/1));
     RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
     TfLiteConvParams* tf_options = nullptr;
     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
@@ -977,8 +973,8 @@ class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration* registration) final {
     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
-    RETURN_IF_ERROR(
-        CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
+    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                       /*runtime_inputs=*/1, /*outputs=*/1));
     RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
     TfLiteDepthwiseConvParams* tf_options;
     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
@@ -1095,16 +1091,20 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
                      const TfLiteRegistration* registration) final {
     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
     if (IsOneArgumentOperation()) {
-      RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
-                                         /*outputs=*/1));
+      RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
+                                               /*runtime_inputs=*/1,
+                                               /*const_inputs=*/0,
+                                               /*outputs=*/1));
     } else if (IsTwoArgumentOperation()) {
-      RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/2,
-                                         /*outputs=*/1));
+      RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
+                                               /*runtime_inputs=*/2,
+                                               /*const_inputs=*/0,
+                                               /*outputs=*/1));
     } else if (IsTwoArgumentOperationWithConst()) {
-      RETURN_IF_ERROR(CheckInputsOutputsAllowingOneConstInput(context,
-                                                              tflite_node,
-                                                              /*inputs=*/2,
-                                                              /*outputs=*/1));
+      RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
+                                               /*runtime_inputs=*/1,
+                                               /*const_inputs=*/1,
+                                               /*outputs=*/1));
     } else {
       return InvalidArgumentError("Op can only handle 1 or 2 operand(s).");
     }
@@ -1120,8 +1120,17 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
     node->operation.type = ToString(operation_type_);
 
     if (IsOneArgumentOperation()) {
+      RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
+                                                        /*runtime_inputs=*/1,
+                                                        /*const_inputs=*/0,
+                                                        /*outputs=*/1));
+
       RETURN_IF_ERROR(reader->AddInput(node, 0));
     } else if (IsTwoArgumentOperation()) {
+      RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
+                                                        /*runtime_inputs=*/2,
+                                                        /*const_inputs=*/0,
+                                                        /*outputs=*/1));
       if (tflite_node->inputs->size != 2) {
         return InvalidArgumentError("Applies only two input tensors");
       }
@@ -1156,14 +1165,12 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
             MaybeFuseActivationToTheSingleOutput(activation, graph, node));
       }
     } else if (IsTwoArgumentOperationWithConst()) {
+      RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
+                                                        /*runtime_inputs=*/1,
+                                                        /*const_inputs=*/1,
+                                                        /*outputs=*/1));
       ElementwiseAttributes attr;
       RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
-      auto const_vector =
-          absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
-              &attr.param);
-      if (const_vector) {
-        return InvalidArgumentError("Constant vector is not supported");
-      }
       node->operation.attributes = std::move(attr);
     } else {
       return InvalidArgumentError("Incorrect operation type passed");
@@ -1228,6 +1235,7 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
     switch (operation_type_) {
       case OperationType::MINIMUM:
      case OperationType::MAXIMUM:
+      case OperationType::SUB:
         return true;
       default:
         return false;
@@ -1311,7 +1319,7 @@ class HardSwishOperationParser : public TFLiteOperationParser {
   Status IsSupported(const TfLiteContext* context,
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration*) final {
-    return CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
+    return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
                               /*outputs=*/1);
   }
 
@@ -1350,7 +1358,8 @@ class LSTMOperationParser : public TFLiteOperationParser {
                      const TfLiteRegistration* registration) final {
     RETURN_IF_ERROR(CheckExactSupportedOpVersion(registration, 2));
     // TODO(eignasheva): Fix bad check.
-    // RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/5,
+    // RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+    //                                    /*runtime_inputs=*/5,
     //                                    /*outputs=*/4));
     TfLiteLSTMParams* tf_options = nullptr;
     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
@@ -1599,8 +1608,8 @@ class PadOperationParser : public TFLiteOperationParser {
       }
     }
     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
-    RETURN_IF_ERROR(
-        CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
+    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                       /*runtime_inputs=*/1, /*outputs=*/1));
     RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
     return OkStatus();
   }
@@ -1648,11 +1657,13 @@ class Pooling2DOperationParser : public TFLiteOperationParser {
     TfLitePoolParams* tf_options = nullptr;
     auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
     if (status.ok()) {  // custom case with indices as a second output
-      RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
+      RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                         /*runtime_inputs=*/1,
                                          /*outputs=*/2));
     } else {  // common pooling with 1 output
       RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
-      RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
+      RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                         /*runtime_inputs=*/1,
                                          /*outputs=*/1));
     }
     RETURN_IF_ERROR(CheckKernelsAndStrides(
@@ -1752,8 +1763,8 @@ class ReshapeOperationParser : public TFLiteOperationParser {
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration* registration) final {
     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
-    RETURN_IF_ERROR(
-        CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
+    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                       /*runtime_inputs=*/1, /*outputs=*/1));
     // TODO(eignasheva): add shape checking
     return OkStatus();
   }
@@ -1786,8 +1797,8 @@ class Resize2DOperationParser : public TFLiteOperationParser {
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration* registration) final {
     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
-    RETURN_IF_ERROR(
-        CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
+    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                       /*runtime_inputs=*/1, /*outputs=*/1));
     RETURN_IF_ERROR(CheckOnlyUpsamplingIsSupported(context, tflite_node));
     bool align_corners;
 
@@ -1974,8 +1985,8 @@ class SoftmaxOperationParser : public TFLiteOperationParser {
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration* registration) final {
     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
-    RETURN_IF_ERROR(
-        CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
+    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                       /*runtime_inputs=*/1, /*outputs=*/1));
     TfLiteSoftmaxParams* tf_options = nullptr;
     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
     if (tf_options->beta != 1) {
@@ -2018,8 +2029,8 @@ class SpaceToDepthOperationParser : public TFLiteOperationParser {
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration* registration) final {
     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
-    RETURN_IF_ERROR(
-        CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
+    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                       /*runtime_inputs=*/1, /*outputs=*/1));
     // TODO(impjdi): Dims check.
     TfLiteSpaceToDepthParams* s2d_params = nullptr;
     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &s2d_params));
@@ -2280,8 +2291,8 @@ class TransposeOperationParser : public TFLiteOperationParser {
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration* registration) final {
     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
-    RETURN_IF_ERROR(
-        CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
+    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                       /*runtime_inputs=*/1, /*outputs=*/1));
     return OkStatus();
   }
 
@@ -2317,8 +2328,8 @@ class Unpooling2DOperationParser : public TFLiteOperationParser {
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration* registration) final {
     TfLitePoolParams* tf_options = nullptr;
-    RETURN_IF_ERROR(
-        CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1));
+    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                       /*runtime_inputs=*/2, /*outputs=*/1));
     RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
     RETURN_IF_ERROR(CheckKernelsAndStrides(
         tf_options->filter_height, tf_options->filter_width,
@@ -2445,8 +2456,8 @@ class RoIToTransformMatrixOperationParser : public TFLiteOperationParser {
   Status IsSupported(const TfLiteContext* context,
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration* registration) final {
-    RETURN_IF_ERROR(
-        CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
+    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                       /*runtime_inputs=*/1, /*outputs=*/1));
     return OkStatus();
   }
 
@@ -2478,8 +2489,8 @@ class TransformTensorOperationParser : public TFLiteOperationParser {
   Status IsSupported(const TfLiteContext* context,
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration* registration) final {
-    RETURN_IF_ERROR(
-        CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1));
+    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                       /*runtime_inputs=*/2, /*outputs=*/1));
     return OkStatus();
   }
 
@@ -2515,8 +2526,8 @@ class TransformLandmarksOperationParser : public TFLiteOperationParser {
   Status IsSupported(const TfLiteContext* context,
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration* registration) final {
-    RETURN_IF_ERROR(
-        CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1));
+    RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
+                                       /*runtime_inputs=*/2, /*outputs=*/1));
     return OkStatus();
   }
 
@@ -2549,7 +2560,7 @@ class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser {
   Status IsSupported(const TfLiteContext* context,
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration* registration) final {
-    return CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
+    return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
                               /*outputs=*/1);
   }
 
@@ -2581,7 +2592,7 @@ class MeanOperationParser : public TFLiteOperationParser {
   Status IsSupported(const TfLiteContext* context,
                      const TfLiteNode* tflite_node,
                      const TfLiteRegistration* registration) final {
-    return CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
+    return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
                               /*outputs=*/1);
  }
 
@@ -2970,7 +2981,6 @@ bool IsAllFloatTensors(const TfLiteContext* context,
   }
   return true;
 }
-
 }  // namespace
 
 Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/BUILD b/tensorflow/lite/delegates/gpu/gl/kernels/BUILD
index 30d759df724..d2ef617a8e2 100644
--- a/tensorflow/lite/delegates/gpu/gl/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/BUILD
@@ -198,6 +198,7 @@ cc_library(
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/gl:node_shader",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
     ],
 )
 
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc
index 941a32a8769..35b233cbdcc 100644
--- a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include
 
 #include "absl/memory/memory.h"
+#include "absl/strings/substitute.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
 #include "tensorflow/lite/delegates/gpu/common/types.h"
 
@@ -130,89 +131,9 @@ class ElementwiseTwoArguments : public NodeShader {
     return true;
   }
 
-  Status ImplementElementwise(const GenerationContext& ctx,
-                              GeneratedCode* generated_code) const {
-    std::string source;
-    switch (operation_type_) {
-      case OperationType::SUB: {
-        source = "value_0 -= value_1;";
-        break;
-      }
-      case OperationType::DIV: {
-        source = "value_0 /= value_1;";
-        break;
-      }
-      case OperationType::MAXIMUM: {
-        source = "value_0 = max(value_0, value_1);";
-        break;
-      }
-      case OperationType::MINIMUM: {
-        source = "value_0 = min(value_0, value_1);";
-        break;
-      }
-      case OperationType::POW: {
-        // From documentation :
-        // The result is undefined if x<0 or if x=0 and y≤0.
-        source = "value_0 = pow(value_0, value_1);";
-        break;
-      }
-      case OperationType::SQUARED_DIFF: {
-        source = "value_0 = (value_0 - value_1) * (value_0 - value_1);";
-        break;
-      }
-
-      default:
-        return InvalidArgumentError(
-            "Incorrect elementwise with two arguments operation type.");
-    }
-    *generated_code = {
-        /*parameters=*/{},
-        /*objects=*/{},
-        /*shared_variables=*/{},
-        /*workload=*/uint3(),
-        /*workgroup=*/uint3(),
-        /*source_code=*/source,
-        /*input=*/IOStructure::AUTO,
-        /*output=*/IOStructure::AUTO,
-    };
-    return OkStatus();
-  }
-
-  Status ImplementElementwiseWithScalar(const GenerationContext& ctx,
-                                        const float scalar,
-                                        GeneratedCode* generated_code) const {
-    std::string source;
-    switch (operation_type_) {
-      case OperationType::MAXIMUM: {
-        source = "value_0 = max(value_0, $scalar$);";
-        break;
-      }
-      case OperationType::MINIMUM: {
-        source = "value_0 = min(value_0, $scalar$);";
-        break;
-      }
-
-      default:
-        return InvalidArgumentError(
-            "Incorrect elementwise with scalar operation type.");
-    }
-    *generated_code = {
-        /*parameters=*/{{"scalar", scalar}},
-        /*objects=*/{},
-        /*shared_variables=*/{},
-        /*workload=*/uint3(),
-        /*workgroup=*/uint3(),
-        /*source_code=*/source,
-        /*input=*/IOStructure::AUTO,
-        /*output=*/IOStructure::AUTO,
-    };
-    return OkStatus();
-  }
-
   bool IsSupportedBroadcast(const GenerationContext& ctx) const {
     auto inputs = ctx.graph->FindInputs(ctx.node->id);
     auto outputs = ctx.graph->FindOutputs(ctx.node->id);
-
     if (inputs.size() != 2) {
       return false;
     }
@@ -223,57 +144,87 @@ class ElementwiseTwoArguments : public NodeShader {
     return true;
   }
 
-  Status ImplementElementwiseBroadcast(const GenerationContext& ctx,
-                                       GeneratedCode* generated_code) const {
-    std::string source;
-    switch (operation_type_) {
-      case OperationType::SQUARED_DIFF: {
-        source = R"(
-          vec4 diff = $input_data_0[gid.x, gid.y, gid.z]$ -
-                      $input_data_1[0, 0, gid.z]$;
-          value_0 = diff * diff;
-        )";
-        break;
+  Status GenerateCode(const GenerationContext& ctx,
+                      GeneratedCode* generated_code) const final {
+    std::vector<Variable> parameters;
+    std::vector<std::pair<std::string, Object>> objects;
+    std::string argument0, argument1;
+    if (IsSupportedElemwise(ctx)) {
+      argument0 = "value_0";
+      argument1 = "value_1";
+    } else if (IsSupportedBroadcast(ctx)) {
+      argument0 = "$input_data_0[gid.x, gid.y, gid.z]$";
+      argument1 = "$input_data_1[0, 0, gid.z]$";
+    } else {  // Scalar of const vector case
+      const ElementwiseAttributes* attr = absl::any_cast<ElementwiseAttributes>(
+          &ctx.node->operation.attributes);
+      if (!attr) {
+        return InvalidArgumentError(
+            "Couldn't read attributes for the scalar of const vector case.");
+      }
+      auto* tensor =
+          absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
+              &attr->param);
+      auto* scalar = absl::get_if<float>(&attr->param);
+      if (!tensor && !scalar) {
+        return InvalidArgumentError(
+            "Couldn't read scalar of const vector data from the attributes.");
       }
+      argument0 = "value_0";
+      if (tensor) {
+        argument1 = "$const_data[gid.z]$";
+        objects.push_back({"const_data", MakeReadonlyObject(tensor->data)});
+      } else {
+        argument1 = "vec4($const_data$)";
+        parameters.push_back({"const_data", *scalar});
+      }
+    }
+
+    std::string source;
+    switch (operation_type_) {
+      case OperationType::DIV: {
+        source = "value_0 = $0/$1;";
+        break;
+      }
+      case OperationType::MAXIMUM: {
+        source = "value_0 = max($0, $1);";
+        break;
+      }
+      case OperationType::MINIMUM: {
+        source = "value_0 = min($0, $1);";
+        break;
+      }
+      case OperationType::SQUARED_DIFF: {
+        source = "value_0 = ($0 - $1) * ($0 - $1);";
+        break;
+      }
+      case OperationType::SUB: {
+        source = "value_0 = $0 - $1;";
+        break;
+      }
+      case OperationType::POW: {
+        source = "value_0 = pow($0, $1);";
+        break;
+      }
 
       default:
         return InvalidArgumentError(
-            "Incorrect elementwise with two arguments operation type.");
+            "Incorrect elementwise with scalar operation type.");
     }
+    source = absl::Substitute(source, argument0, argument1);
     *generated_code = {
-        /*parameters=*/{},
-        /*objects=*/{},
+        /*parameters=*/std::move(parameters),
+        /*objects=*/std::move(objects),
         /*shared_variables=*/{},
         /*workload=*/uint3(),
         /*workgroup=*/uint3(),
         /*source_code=*/source,
-        /*input=*/IOStructure::ONLY_DEFINITIONS,
+        /*input=*/IOStructure::AUTO,
         /*output=*/IOStructure::AUTO,
     };
     return OkStatus();
   }
 
-  Status GenerateCode(const GenerationContext& ctx,
-                      GeneratedCode* generated_code) const final {
-    if (IsSupportedElemwise(ctx)) {
-      return ImplementElementwise(ctx, generated_code);
-    }
-    if (IsSupportedBroadcast(ctx)) {
-      return ImplementElementwiseBroadcast(ctx, generated_code);
-    }
-    const ElementwiseAttributes* attr =
-        absl::any_cast<ElementwiseAttributes>(&ctx.node->operation.attributes);
-    if (attr) {
-      auto scalar = absl::get_if<float>(&attr->param);
-      if (scalar) {
-        return ImplementElementwiseWithScalar(ctx, *scalar, generated_code);
-      }
-    }
-    return InvalidArgumentError(
-        "This case is not supported by elementwise with two arguments "
-        "operation");
-  }
-
  private:
   OperationType operation_type_;
 };
diff --git a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise_test.cc b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise_test.cc
index 3316395f5e3..625a09eebf4 100644
--- a/tensorflow/lite/delegates/gpu/gl/kernels/elementwise_test.cc
+++ b/tensorflow/lite/delegates/gpu/gl/kernels/elementwise_test.cc
@@ -36,7 +36,7 @@ TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
   return tensor_ref;
 }
 
-TEST(ElementwiseTest, Abs) {
+TEST(ElementwiseOneArgumentTest, Abs) {
   OperationType op_type = OperationType::ABS;
   const BHWC shape(1, 2, 2, 1);
   SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
@@ -48,7 +48,7 @@ TEST(ElementwiseTest, Abs) {
               Pointwise(FloatNear(1e-6), {0.0, 6.2, 2.0, 4.0}));
 }
 
-TEST(ElementwiseTest, Cos) {
+TEST(ElementwiseOneArgumentTest, Cos) {
   OperationType op_type = OperationType::COS;
   const BHWC shape(1, 2, 2, 1);
   SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
@@ -60,21 +60,7 @@ TEST(ElementwiseTest, Cos) {
               Pointwise(FloatNear(1e-6), {1.0, -1.0, -1.0, 0.540302}));
 }
 
-TEST(ElementwiseTest, Div) {
-  OperationType op_type = OperationType::DIV;
-  const BHWC shape(1, 2, 2, 1);
-  SingleOpModel model(
-      {/*type=*/ToString(op_type), /*attributes=*/{}},
-      /*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
-      /*outputs=*/{GetTensorRef(2, shape)});
-  ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0}));
-  ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, -0.5, 4.0}));
-  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
-  EXPECT_THAT(model.GetOutput(0),
-              Pointwise(FloatNear(1e-6), {0.0, -3.1, -4.0, 1.0}));
-}
-
-TEST(ElementwiseTest, Exp) {
+TEST(ElementwiseOneArgumentTest, Exp) {
   OperationType op_type = OperationType::EXP;
   const BHWC shape(1, 1, 1, 7);
   SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
@@ -90,7 +76,7 @@ TEST(ElementwiseTest, Exp) {
                                    std::exp(-0.01f)}));
 }
 
-TEST(ElementwiseTest, HardSwish) {
+TEST(ElementwiseOneArgumentTest, HardSwish) {
   OperationType op_type = OperationType::HARD_SWISH;
   const BHWC shape(1, 1, 1, 7);
   SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
@@ -104,7 +90,7 @@ TEST(ElementwiseTest, HardSwish) {
                           {0.0f, 0.0f, -0.375f, 0.0f, 1.125f, 3.f, 4.5f}));
 }
 
-TEST(ElementwiseTest, Log) {
+TEST(ElementwiseOneArgumentTest, Log) {
   OperationType op_type = OperationType::LOG;
   const BHWC shape(1, 2, 2, 1);
   SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
@@ -116,7 +102,142 @@ TEST(ElementwiseTest, Log) {
               Pointwise(FloatNear(1e-6), {0.0, 1.14473, 0.0, 0.0}));
 }
 
-TEST(ElementwiseTest, Maximum) {
+TEST(ElementwiseOneArgumentTest, Rsqrt) {
+  OperationType op_type = OperationType::RSQRT;
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
+                      /*inputs=*/{GetTensorRef(0, shape)},
+                      /*outputs=*/{GetTensorRef(1, shape)});
+  ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 4.0, 9.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {1.0, 0.707106, 0.5, 0.333333}));
+}
+
+TEST(ElementwiseOneArgumentTest, Sigmoid) {
+  OperationType op_type = OperationType::SIGMOID;
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
+                      /*inputs=*/{GetTensorRef(0, shape)},
+                      /*outputs=*/{GetTensorRef(1, shape)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.5, 0.002473, 0.880797, 0.982014}));
+}
+
+TEST(ElementwiseOneArgumentTest, Sin) {
+  OperationType op_type = OperationType::SIN;
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
+                      /*inputs=*/{GetTensorRef(0, shape)},
+                      /*outputs=*/{GetTensorRef(1, shape)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.0, 0.0, 0.0, 0.841471}));
+}
+
+TEST(ElementwiseOneArgumentTest, Sqrt) {
+  OperationType op_type = OperationType::SQRT;
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
+                      /*inputs=*/{GetTensorRef(0, shape)},
+                      /*outputs=*/{GetTensorRef(1, shape)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.0, 1.0, 1.414213, 2.0}));
+}
+
+TEST(ElementwiseOneArgumentTest, Square) {
+  OperationType op_type = OperationType::SQUARE;
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
+                      /*inputs=*/{GetTensorRef(0, shape)},
+                      /*outputs=*/{GetTensorRef(1, shape)});
+  ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 0.5, -3.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {1.0, 4.0, 0.25, 9.0}));
+}
+
+TEST(ElementwiseOneArgumentTest, Tanh) {
+  OperationType op_type = OperationType::TANH;
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
+                      /*inputs=*/{GetTensorRef(0, shape)},
+                      /*outputs=*/{GetTensorRef(1, shape)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.0, -0.999987, 0.964027, 0.999329}));
+}
+
+TEST(ElementwiseTwoArgumentsTest, DivElementwise) {
+  OperationType op_type = OperationType::DIV;
+  const BHWC shape(1, 2, 2, 1);
+  SingleOpModel model(
+      {/*type=*/ToString(op_type), /*attributes=*/{}},
+      /*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
+      /*outputs=*/{GetTensorRef(2, shape)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0}));
+  ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, -0.5, 4.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.0, -3.1, -4.0, 1.0}));
+}
+
+TEST(ElementwiseTwoArgumentsTest, DivBroadcast) {
+  OperationType op_type = OperationType::DIV;
+  const BHWC shape0(1, 2, 1, 2);
+  const BHWC shape1(1, 1, 1, 2);
+  SingleOpModel model(
+      {/*type=*/ToString(op_type), /*attributes=*/{}},
+      /*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
+      /*outputs=*/{GetTensorRef(2, shape0)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
+  ASSERT_TRUE(model.PopulateTensor(1, {0.5, 0.2}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.0, 5.0, 4.0, 15.0}));
+}
+
+TEST(ElementwiseTwoArgumentsTest, DivScalar) {
+  OperationType op_type = OperationType::DIV;
+  const BHWC shape0(1, 2, 1, 2);
+  ElementwiseAttributes attr;
+  attr.param = static_cast<float>(0.5);
+  SingleOpModel model({/*type=*/ToString(op_type), attr},
+                      /*inputs=*/{GetTensorRef(0, shape0)},
+                      /*outputs=*/{GetTensorRef(2, shape0)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.0, 2.0, 4.0, 6.0}));
+}
+
+TEST(ElementwiseTwoArgumentsTest, DivConstVector) {
+  OperationType op_type = OperationType::DIV;
+  const BHWC shape0(1, 2, 1, 2);
+
+  ElementwiseAttributes attr;
+  Tensor<Linear, DataType::FLOAT32> param;
+  param.shape = Linear(2);
+  param.id = 1;
+  param.data = {0.4, 0.5};
+  attr.param = std::move(param);
+
+  SingleOpModel model({/*type=*/ToString(op_type), attr},
+                      /*inputs=*/{GetTensorRef(0, shape0)},
+                      /*outputs=*/{GetTensorRef(2, shape0)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.0, 2.0, 5.0, 6.0}));
+}
+
+TEST(ElementwiseTwoArgumentsTest, MaximumElementwise) {
   OperationType op_type = OperationType::MAXIMUM;
   const BHWC shape(1, 2, 2, 1);
   SingleOpModel model(
@@ -130,7 +251,22 @@ TEST(ElementwiseTest, Maximum) {
               Pointwise(FloatNear(1e-6), {1.0, 2.0, 3.0, -2.0}));
 }
 
-TEST(ElementwiseTest, MaximumWithScalar) {
+TEST(ElementwiseTwoArgumentsTest, MaximumBroadcast) {
+  OperationType op_type = OperationType::MAXIMUM;
+  const BHWC shape0(1, 2, 1, 2);
+  const BHWC shape1(1, 1, 1, 2);
+  SingleOpModel model(
+      {/*type=*/ToString(op_type), /*attributes=*/{}},
+      /*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
+      /*outputs=*/{GetTensorRef(2, shape0)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
+  ASSERT_TRUE(model.PopulateTensor(1, {0.5, 0.2}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.5, 1.0, 2.0, 3.0}));
+}
+
+TEST(ElementwiseTwoArgumentsTest, MaximumScalar) {
   OperationType op_type = OperationType::MAXIMUM;
   const BHWC shape(1, 2, 2, 1);
   ElementwiseAttributes attr;
@@ -145,7 +281,27 @@ TEST(ElementwiseTest, MaximumWithScalar) {
               Pointwise(FloatNear(1e-6), {0.0, -1.0, 2.0, -1.0}));
 }
 
-TEST(ElementwiseTest, Minimum) {
+TEST(ElementwiseTwoArgumentsTest, MaximumConstVector) {
+  OperationType op_type = OperationType::MAXIMUM;
+  const BHWC shape0(1, 2, 1, 2);
+
+  ElementwiseAttributes attr;
+  Tensor<Linear, DataType::FLOAT32> param;
+  param.shape = Linear(2);
+  param.id = 1;
+  param.data = {0.4, 0.5};
+  attr.param = std::move(param);
+
+  SingleOpModel model({/*type=*/ToString(op_type), attr},
+                      /*inputs=*/{GetTensorRef(0, shape0)},
+                      /*outputs=*/{GetTensorRef(2, shape0)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.4, 1.0, 2.0, 3.0}));
+}
+
+TEST(ElementwiseTwoArgumentsTest, MinimumElementwise) {
   OperationType op_type = OperationType::MINIMUM;
   const BHWC shape(1, 2, 2, 1);
   SingleOpModel model(
@@ -159,7 +315,22 @@ TEST(ElementwiseTest, Minimum) {
               Pointwise(FloatNear(1e-6), {0.0, -6.2, 2.0, -3.0}));
 }
 
-TEST(ElementwiseTest, MinimumWithScalar) {
+TEST(ElementwiseTwoArgumentsTest, MinimumBroadcast) {
+  OperationType op_type = OperationType::MINIMUM;
+  const BHWC shape0(1, 2, 1, 2);
+  const BHWC shape1(1, 1, 1, 2);
+  SingleOpModel model(
+      {/*type=*/ToString(op_type), /*attributes=*/{}},
+      /*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
+      /*outputs=*/{GetTensorRef(2, shape0)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
+  ASSERT_TRUE(model.PopulateTensor(1, {0.5, 0.2}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.0, 0.2, 0.5, 0.2}));
+}
+
+TEST(ElementwiseTwoArgumentsTest, MinimumScalar) {
   OperationType op_type = OperationType::MINIMUM;
   const BHWC shape(1, 2, 2, 1);
   ElementwiseAttributes attr;
@@ -174,7 +345,27 @@ TEST(ElementwiseTest, MinimumWithScalar) {
               Pointwise(FloatNear(1e-6), {-1.0, -6.2, -1.0, -3.0}));
 }
 
-TEST(ElementwiseTest, Pow) {
+TEST(ElementwiseTwoArgumentsTest, MinimumConstVector) {
+  OperationType op_type = OperationType::MINIMUM;
+  const BHWC shape0(1, 2, 1, 2);
+
+  ElementwiseAttributes attr;
+  Tensor<Linear, DataType::FLOAT32> param;
+  param.shape = Linear(2);
+  param.id = 1;
+  param.data = {0.5, 0.2};
+  attr.param = std::move(param);
+
+  SingleOpModel model({/*type=*/ToString(op_type), attr},
+                      /*inputs=*/{GetTensorRef(0, shape0)},
+                      /*outputs=*/{GetTensorRef(2, shape0)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {0.0, 0.2, 0.5, 0.2}));
+}
+
+TEST(ElementwiseTwoArgumentsTest, PowElementwise) {
   OperationType op_type = OperationType::POW;
   const BHWC shape(1, 2, 2, 1);
   SingleOpModel model(
@@ -188,67 +379,57 @@ TEST(ElementwiseTest, Pow) {
               Pointwise(FloatNear(1e-6), {0.0, 1.0, 8.0, 256.0}));
 }
 
-TEST(ElementwiseTest, Rsqrt) {
-  OperationType op_type = OperationType::RSQRT;
-  const BHWC shape(1, 2, 2, 1);
-  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
-                      /*inputs=*/{GetTensorRef(0, shape)},
-                      /*outputs=*/{GetTensorRef(1, shape)});
-  ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 4.0, 9.0}));
+TEST(ElementwiseTwoArgumentsTest, PowBroadcast) {
+  OperationType op_type = OperationType::POW;
+  const BHWC shape0(1, 2, 1, 2);
+  const BHWC shape1(1, 1, 1, 2);
+  SingleOpModel model(
+      {/*type=*/ToString(op_type), /*attributes=*/{}},
+      /*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
+      /*outputs=*/{GetTensorRef(2, shape0)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
+  ASSERT_TRUE(model.PopulateTensor(1, {2.0, 0.5}));
   ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
   EXPECT_THAT(model.GetOutput(0),
-              Pointwise(FloatNear(1e-6), {1.0, 0.707106, 0.5, 0.333333}));
+              Pointwise(FloatNear(1e-6), {0.0, 1.0, 4.0, 2.0}));
 }
 
-TEST(ElementwiseTest, Sigmoid) {
-  OperationType op_type = OperationType::SIGMOID;
+TEST(ElementwiseTwoArgumentsTest, PowScalar) {
+  OperationType op_type = OperationType::POW;
   const BHWC shape(1, 2, 2, 1);
-  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
-                      /*inputs=*/{GetTensorRef(0, shape)},
-                      /*outputs=*/{GetTensorRef(1, shape)});
-  ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
-  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
-  EXPECT_THAT(model.GetOutput(0),
-              Pointwise(FloatNear(1e-6), {0.5, 0.002473, 0.880797, 0.982014}));
-}
-
-TEST(ElementwiseTest, Sin) {
-  OperationType op_type = OperationType::SIN;
-  const BHWC shape(1, 2, 2, 1);
-  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
-                      /*inputs=*/{GetTensorRef(0, shape)},
-                      /*outputs=*/{GetTensorRef(1, shape)});
-  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1.0}));
-  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
-  EXPECT_THAT(model.GetOutput(0),
-              Pointwise(FloatNear(1e-6), {0.0, 0.0, 0.0, 0.841471}));
-}
-
-TEST(ElementwiseTest, Sqrt) {
-  OperationType op_type = OperationType::SQRT;
-  const BHWC shape(1, 2, 2, 1);
-  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
-                      /*inputs=*/{GetTensorRef(0, shape)},
-                      /*outputs=*/{GetTensorRef(1, shape)});
+  ElementwiseAttributes attr;
+  attr.param = 2.0f;
+  SingleOpModel model(
+      {/*type=*/ToString(op_type), /*attributes=*/std::move(attr)},
+      /*inputs=*/{GetTensorRef(0, shape)},
+      /*outputs=*/{GetTensorRef(2, shape)});
   ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
   ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
   EXPECT_THAT(model.GetOutput(0),
-              Pointwise(FloatNear(1e-6), {0.0, 1.0, 1.414213, 2.0}));
+              Pointwise(FloatNear(1e-6), {0.0, 1.0, 4.0, 16.0}));
 }
 
-TEST(ElementwiseTest, Square) {
-  OperationType op_type = OperationType::SQUARE;
-  const BHWC shape(1, 2, 2, 1);
-  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
-                      /*inputs=*/{GetTensorRef(0, shape)},
-                      /*outputs=*/{GetTensorRef(1, shape)});
-  ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 0.5, -3.0}));
+TEST(ElementwiseTwoArgumentsTest, PowConstVector) {
+  OperationType op_type = OperationType::POW;
+  const BHWC shape0(1, 2, 1, 2);
+
+  ElementwiseAttributes attr;
+  Tensor<Linear, DataType::FLOAT32> param;
+  param.shape = Linear(2);
+  param.id = 1;
+  param.data = {2.0, 0.5};
+  attr.param = std::move(param);
+
+  SingleOpModel model({/*type=*/ToString(op_type), attr},
+                      /*inputs=*/{GetTensorRef(0, shape0)},
+                      /*outputs=*/{GetTensorRef(2, shape0)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
   ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
   EXPECT_THAT(model.GetOutput(0),
-              Pointwise(FloatNear(1e-6), {1.0, 4.0, 0.25, 9.0}));
+              Pointwise(FloatNear(1e-6), {0.0, 1.0, 4.0, 2.0}));
 }
 
-TEST(ElementwiseTest, SquaredDiff) {
+TEST(ElementwiseTwoArgumentsTest, SquaredDiffElementwise) {
   OperationType op_type = OperationType::SQUARED_DIFF;
   const BHWC shape(1, 2, 2, 1);
   SingleOpModel model(
@@ -262,7 +443,56 @@ TEST(ElementwiseTest, SquaredDiff) {
               Pointwise(FloatNear(1e-6), {1.0, 1.0, 9.0, 0.0}));
 }
 
-TEST(ElementwiseTest, Sub) {
+TEST(ElementwiseTwoArgumentsTest, SquaredDiffBroadcast) {
+  OperationType op_type = OperationType::SQUARED_DIFF;
+  const BHWC shape0(1, 2, 1, 2);
+  const BHWC shape1(1, 1, 1, 2);
+  SingleOpModel model(
+      {/*type=*/ToString(op_type), /*attributes=*/{}},
+      /*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
+      /*outputs=*/{GetTensorRef(2, shape0)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
+  ASSERT_TRUE(model.PopulateTensor(1, {-1.0, 5.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {1.0, 16.0, 9.0, 4.0}));
+}
+
+TEST(ElementwiseTwoArgumentsTest, SquaredDiffScalar) {
+  OperationType op_type = OperationType::SQUARED_DIFF;
+  const BHWC shape0(1, 2, 1, 2);
+  ElementwiseAttributes attr;
+  attr.param = static_cast<float>(5.0);
+  SingleOpModel model({/*type=*/ToString(op_type), attr},
+                      /*inputs=*/{GetTensorRef(0, shape0)},
+                      /*outputs=*/{GetTensorRef(2, shape0)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {25.0, 16.0, 9.0, 4.0}));
+}
+
+TEST(ElementwiseTwoArgumentsTest, SquaredDiffConstVector) {
+  OperationType op_type = OperationType::SQUARED_DIFF;
+  const BHWC shape0(1, 2, 1, 2);
+
+  ElementwiseAttributes attr;
+  Tensor<Linear, DataType::FLOAT32> param;
+  param.shape = Linear(2);
+  param.id = 1;
+  param.data = {-1.0, 5.0};
+  attr.param = std::move(param);
+
+  SingleOpModel model({/*type=*/ToString(op_type), attr},
+                      /*inputs=*/{GetTensorRef(0, shape0)},
+                      /*outputs=*/{GetTensorRef(2, shape0)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {1.0, 16.0, 9.0, 4.0}));
+}
+
+TEST(ElementwiseTwoArgumentsTest, SubElementwise) {
   OperationType op_type = OperationType::SUB;
   const BHWC shape(1, 2, 2, 1);
   SingleOpModel model(
@@ -276,16 +506,53 @@ TEST(ElementwiseTest, Sub) {
               Pointwise(FloatNear(1e-6), {-1.0, -8.2, -1.0, 0.0}));
 }
 
-TEST(ElementwiseTest, Tanh) {
-  OperationType op_type = OperationType::TANH;
-  const BHWC shape(1, 2, 2, 1);
-  SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
-                      /*inputs=*/{GetTensorRef(0, shape)},
-                      /*outputs=*/{GetTensorRef(1, shape)});
-  ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
+TEST(ElementwiseTwoArgumentsTest, SubBroadcast) {
+  OperationType op_type = OperationType::SUB;
+  const BHWC shape0(1, 2, 1, 2);
+  const BHWC shape1(1, 1, 1, 2);
+  SingleOpModel model(
+      {/*type=*/ToString(op_type), /*attributes=*/{}},
+      /*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
+      /*outputs=*/{GetTensorRef(2, shape0)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
+  ASSERT_TRUE(model.PopulateTensor(1, {0.3, 0.2}));
   ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
   EXPECT_THAT(model.GetOutput(0),
-              Pointwise(FloatNear(1e-6), {0.0, -0.999987, 0.964027, 0.999329}));
+              Pointwise(FloatNear(1e-6), {-0.3, 0.8, 1.7, 2.8}));
+}
+
+TEST(ElementwiseTwoArgumentsTest, SubScalar) {
+  OperationType op_type = OperationType::SUB;
+  const BHWC shape0(1, 2, 1, 2);
+  ElementwiseAttributes attr;
+  attr.param = static_cast<float>(0.5);
+  SingleOpModel model({/*type=*/ToString(op_type), attr},
+                      /*inputs=*/{GetTensorRef(0, shape0)},
+                      /*outputs=*/{GetTensorRef(2, shape0)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {-0.5, 0.5, 1.5, 2.5}));
+}
+
+TEST(ElementwiseTwoArgumentsTest, SubConstVector) {
+  OperationType op_type = OperationType::SUB;
+  const BHWC shape0(1, 2, 1, 2);
+
+  ElementwiseAttributes attr;
+  Tensor<Linear, DataType::FLOAT32> param;
+  param.shape = Linear(2);
+  param.id = 1;
+  param.data = {0.3, 0.2};
+  attr.param = std::move(param);
+
+  SingleOpModel model({/*type=*/ToString(op_type), attr},
+                      /*inputs=*/{GetTensorRef(0, shape0)},
+                      /*outputs=*/{GetTensorRef(2, shape0)});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
+  ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
+  EXPECT_THAT(model.GetOutput(0),
+              Pointwise(FloatNear(1e-6), {-0.3, 0.8, 1.7, 2.8}));
 }
 
 }  // namespace