Make elementwise operations with two inputs support all of the cases: elementwise, scalar, broadcast and const vector.

PiperOrigin-RevId: 302066884
Change-Id: I94a7497f006b466cc6d7a1b1fdba090b4ef30a00
This commit is contained in:
A. Unique TensorFlower 2020-03-20 11:32:07 -07:00 committed by TensorFlower Gardener
parent 2fd08c48a3
commit 16051cb33c
4 changed files with 523 additions and 294 deletions

View File

@ -236,6 +236,12 @@ int GetNumberOfRuntimeInputsForNode(const TfLiteContext* context,
return number_of_runtime_inputs;
}
int GetNumberOfConstInputsForNode(const TfLiteContext* context,
const TfLiteNode* tflite_node) {
return tflite_node->inputs->size -
GetNumberOfRuntimeInputsForNode(context, tflite_node);
}
int GetNumberOfRuntimeOutputsForNode(const TfLiteContext* context,
const TfLiteNode* tflite_node) {
int number_of_runtime_outputs = 0;
@ -258,6 +264,42 @@ Status CheckTensorIsAvailable(const TfLiteContext* context,
return OkStatus();
}
Status CheckInputsOutputs(const TfLiteContext* context,
const TfLiteNode* tflite_node, int runtime_inputs,
int outputs) {
int runtime_inputs_from_model =
GetNumberOfRuntimeInputsForNode(context, tflite_node);
if (runtime_inputs_from_model != runtime_inputs) {
return InternalError(absl::StrFormat(
"Expected %d runtime input tensor(s), but node has %d runtime "
"input(s).",
runtime_inputs, runtime_inputs_from_model));
}
int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node);
if (runtime_outputs != outputs) {
return InternalError(
absl::StrFormat("Expected %d output tensor(s), but node has %d "
"output(s).",
outputs, runtime_outputs));
}
return OkStatus();
}
Status CheckInputsConstsOutputs(const TfLiteContext* context,
const TfLiteNode* tflite_node,
int runtime_inputs, int const_inputs,
int outputs) {
int const_inputs_from_model =
GetNumberOfConstInputsForNode(context, tflite_node);
if (const_inputs_from_model != const_inputs) {
return InternalError(absl::StrFormat(
"Expected %d const input tensor(s), but node has %d const "
"input(s).",
const_inputs, const_inputs_from_model));
}
return CheckInputsOutputs(context, tflite_node, runtime_inputs, outputs);
}
class ObjectReader {
public:
ObjectReader(GraphFloat32* graph, TfLiteContext* context,
@ -367,6 +409,13 @@ class ObjectReader {
: nullptr;
}
Status VerifyInputsConstsOutputs(const TfLiteNode* tflite_node,
int runtime_inputs, int const_inputs,
int outputs) {
return CheckInputsConstsOutputs(context_, tflite_node, runtime_inputs,
const_inputs, outputs);
}
private:
GraphFloat32* graph_ = nullptr;
const TfLiteContext* context_ = nullptr;
@ -374,59 +423,6 @@ class ObjectReader {
std::vector<Value<TensorRef<BHWC>>*>* tensor_to_value_;
};
Status CheckInputsOutputs(const TfLiteContext* context,
const TfLiteNode* tflite_node, int inputs,
int outputs) {
int runtime_inputs = GetNumberOfRuntimeInputsForNode(context, tflite_node);
if (runtime_inputs != inputs) {
return InternalError(
absl::StrFormat("Expected %d input tensor(s), but node has %d runtime "
"input(s).",
inputs, runtime_inputs));
}
int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node);
if (runtime_outputs != outputs) {
return InternalError(
absl::StrFormat("Expected %d output tensor(s), but node has %d runtime "
"output(s).",
outputs, runtime_outputs));
}
return OkStatus();
}
// The function checks input tensors including 1 constant tensor.
Status CheckInputsOutputsAllowingOneConstInput(const TfLiteContext* context,
const TfLiteNode* tflite_node,
int inputs, int outputs) {
int number_of_const_inputs = 0;
int number_of_runtime_inputs = 0;
for (int i = 0; i < tflite_node->inputs->size; i++) {
if (IsConstantTensor(&context->tensors[tflite_node->inputs->data[i]])) {
number_of_const_inputs++;
} else {
number_of_runtime_inputs++;
}
}
if (tflite_node->inputs->size != inputs) {
return InternalError(absl::StrFormat(
"Expected %d input tensor(s), but node has %d input(s).", inputs,
tflite_node->inputs->size));
}
if (number_of_const_inputs > 1) {
return InternalError(absl::StrFormat(
"Expected 1 const input tensor, but node has %d const input(s).",
number_of_const_inputs));
}
int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node);
if (runtime_outputs != outputs) {
return InternalError(
absl::StrFormat("Expected %d output tensor(s), but node has %d runtime "
"output(s).",
outputs, runtime_outputs));
}
return OkStatus();
}
// A parser responsible for parsing TFLite operation and adding it to a graph.
class TFLiteOperationParser {
public:
@ -893,8 +889,8 @@ class Conv2DOperationParser : public TFLiteOperationParser {
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
TfLiteConvParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
@ -977,8 +973,8 @@ class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
TfLiteDepthwiseConvParams* tf_options;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
@ -1095,16 +1091,20 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
if (IsOneArgumentOperation()) {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
/*outputs=*/1));
RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
/*runtime_inputs=*/1,
/*const_inputs=*/0,
/*outputs=*/1));
} else if (IsTwoArgumentOperation()) {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/2,
/*outputs=*/1));
RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
/*runtime_inputs=*/2,
/*const_inputs=*/0,
/*outputs=*/1));
} else if (IsTwoArgumentOperationWithConst()) {
RETURN_IF_ERROR(CheckInputsOutputsAllowingOneConstInput(context,
tflite_node,
/*inputs=*/2,
/*outputs=*/1));
RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
/*runtime_inputs=*/1,
/*const_inputs=*/1,
/*outputs=*/1));
} else {
return InvalidArgumentError("Op can only handle 1 or 2 operand(s).");
}
@ -1120,8 +1120,17 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
node->operation.type = ToString(operation_type_);
if (IsOneArgumentOperation()) {
RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
/*runtime_inputs=*/1,
/*const_inputs=*/0,
/*outputs=*/1));
RETURN_IF_ERROR(reader->AddInput(node, 0));
} else if (IsTwoArgumentOperation()) {
RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
/*runtime_inputs=*/2,
/*const_inputs=*/0,
/*outputs=*/1));
if (tflite_node->inputs->size != 2) {
return InvalidArgumentError("Applies only two input tensors");
}
@ -1156,14 +1165,12 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
MaybeFuseActivationToTheSingleOutput(activation, graph, node));
}
} else if (IsTwoArgumentOperationWithConst()) {
RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
/*runtime_inputs=*/1,
/*const_inputs=*/1,
/*outputs=*/1));
ElementwiseAttributes attr;
RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
auto const_vector =
absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
&attr.param);
if (const_vector) {
return InvalidArgumentError("Constant vector is not supported");
}
node->operation.attributes = std::move(attr);
} else {
return InvalidArgumentError("Incorrect operation type passed");
@ -1228,6 +1235,7 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
switch (operation_type_) {
case OperationType::MINIMUM:
case OperationType::MAXIMUM:
case OperationType::SUB:
return true;
default:
return false;
@ -1311,7 +1319,7 @@ class HardSwishOperationParser : public TFLiteOperationParser {
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration*) final {
return CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
/*outputs=*/1);
}
@ -1350,7 +1358,8 @@ class LSTMOperationParser : public TFLiteOperationParser {
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckExactSupportedOpVersion(registration, 2));
// TODO(eignasheva): Fix bad check.
// RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/5,
// RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
// /*runtime_inputs=*/5,
// /*outputs=*/4));
TfLiteLSTMParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
@ -1599,8 +1608,8 @@ class PadOperationParser : public TFLiteOperationParser {
}
}
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
return OkStatus();
}
@ -1648,11 +1657,13 @@ class Pooling2DOperationParser : public TFLiteOperationParser {
TfLitePoolParams* tf_options = nullptr;
auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
if (status.ok()) { // custom case with indices as a second output
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1,
/*outputs=*/2));
} else { // common pooling with 1 output
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1,
/*outputs=*/1));
}
RETURN_IF_ERROR(CheckKernelsAndStrides(
@ -1752,8 +1763,8 @@ class ReshapeOperationParser : public TFLiteOperationParser {
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
// TODO(eignasheva): add shape checking
return OkStatus();
}
@ -1786,8 +1797,8 @@ class Resize2DOperationParser : public TFLiteOperationParser {
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckOnlyUpsamplingIsSupported(context, tflite_node));
bool align_corners;
@ -1974,8 +1985,8 @@ class SoftmaxOperationParser : public TFLiteOperationParser {
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
TfLiteSoftmaxParams* tf_options = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
if (tf_options->beta != 1) {
@ -2018,8 +2029,8 @@ class SpaceToDepthOperationParser : public TFLiteOperationParser {
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
// TODO(impjdi): Dims check.
TfLiteSpaceToDepthParams* s2d_params = nullptr;
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &s2d_params));
@ -2280,8 +2291,8 @@ class TransposeOperationParser : public TFLiteOperationParser {
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
return OkStatus();
}
@ -2317,8 +2328,8 @@ class Unpooling2DOperationParser : public TFLiteOperationParser {
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
TfLitePoolParams* tf_options = nullptr;
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/2, /*outputs=*/1));
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
RETURN_IF_ERROR(CheckKernelsAndStrides(
tf_options->filter_height, tf_options->filter_width,
@ -2445,8 +2456,8 @@ class RoIToTransformMatrixOperationParser : public TFLiteOperationParser {
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
return OkStatus();
}
@ -2478,8 +2489,8 @@ class TransformTensorOperationParser : public TFLiteOperationParser {
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/2, /*outputs=*/1));
return OkStatus();
}
@ -2515,8 +2526,8 @@ class TransformLandmarksOperationParser : public TFLiteOperationParser {
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(
CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1));
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/2, /*outputs=*/1));
return OkStatus();
}
@ -2549,7 +2560,7 @@ class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser {
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
return CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
/*outputs=*/1);
}
@ -2581,7 +2592,7 @@ class MeanOperationParser : public TFLiteOperationParser {
Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
return CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
/*outputs=*/1);
}
@ -2970,7 +2981,6 @@ bool IsAllFloatTensors(const TfLiteContext* context,
}
return true;
}
} // namespace
Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,

View File

@ -198,6 +198,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/gl:node_shader",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include "absl/memory/memory.h"
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
@ -130,89 +131,9 @@ class ElementwiseTwoArguments : public NodeShader {
return true;
}
Status ImplementElementwise(const GenerationContext& ctx,
GeneratedCode* generated_code) const {
std::string source;
switch (operation_type_) {
case OperationType::SUB: {
source = "value_0 -= value_1;";
break;
}
case OperationType::DIV: {
source = "value_0 /= value_1;";
break;
}
case OperationType::MAXIMUM: {
source = "value_0 = max(value_0, value_1);";
break;
}
case OperationType::MINIMUM: {
source = "value_0 = min(value_0, value_1);";
break;
}
case OperationType::POW: {
// From documentation :
// The result is undefined if x<0 or if x=0 and y≤0.
source = "value_0 = pow(value_0, value_1);";
break;
}
case OperationType::SQUARED_DIFF: {
source = "value_0 = (value_0 - value_1) * (value_0 - value_1);";
break;
}
default:
return InvalidArgumentError(
"Incorrect elementwise with two arguments operation type.");
}
*generated_code = {
/*parameters=*/{},
/*objects=*/{},
/*shared_variables=*/{},
/*workload=*/uint3(),
/*workgroup=*/uint3(),
/*source_code=*/source,
/*input=*/IOStructure::AUTO,
/*output=*/IOStructure::AUTO,
};
return OkStatus();
}
Status ImplementElementwiseWithScalar(const GenerationContext& ctx,
const float scalar,
GeneratedCode* generated_code) const {
std::string source;
switch (operation_type_) {
case OperationType::MAXIMUM: {
source = "value_0 = max(value_0, $scalar$);";
break;
}
case OperationType::MINIMUM: {
source = "value_0 = min(value_0, $scalar$);";
break;
}
default:
return InvalidArgumentError(
"Incorrect elementwise with scalar operation type.");
}
*generated_code = {
/*parameters=*/{{"scalar", scalar}},
/*objects=*/{},
/*shared_variables=*/{},
/*workload=*/uint3(),
/*workgroup=*/uint3(),
/*source_code=*/source,
/*input=*/IOStructure::AUTO,
/*output=*/IOStructure::AUTO,
};
return OkStatus();
}
bool IsSupportedBroadcast(const GenerationContext& ctx) const {
auto inputs = ctx.graph->FindInputs(ctx.node->id);
auto outputs = ctx.graph->FindOutputs(ctx.node->id);
if (inputs.size() != 2) {
return false;
}
@ -223,57 +144,87 @@ class ElementwiseTwoArguments : public NodeShader {
return true;
}
Status ImplementElementwiseBroadcast(const GenerationContext& ctx,
GeneratedCode* generated_code) const {
std::string source;
switch (operation_type_) {
case OperationType::SQUARED_DIFF: {
source = R"(
vec4 diff = $input_data_0[gid.x, gid.y, gid.z]$ -
$input_data_1[0, 0, gid.z]$;
value_0 = diff * diff;
)";
break;
Status GenerateCode(const GenerationContext& ctx,
GeneratedCode* generated_code) const final {
std::vector<Variable> parameters;
std::vector<std::pair<std::string, Object>> objects;
std::string argument0, argument1;
if (IsSupportedElemwise(ctx)) {
argument0 = "value_0";
argument1 = "value_1";
} else if (IsSupportedBroadcast(ctx)) {
argument0 = "$input_data_0[gid.x, gid.y, gid.z]$";
argument1 = "$input_data_1[0, 0, gid.z]$";
} else { // Scalar of const vector case
const ElementwiseAttributes* attr = absl::any_cast<ElementwiseAttributes>(
&ctx.node->operation.attributes);
if (!attr) {
return InvalidArgumentError(
"Couldn't read attributes for the scalar of const vector case.");
}
auto* tensor =
absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
&attr->param);
auto* scalar = absl::get_if<float>(&attr->param);
if (!tensor && !scalar) {
return InvalidArgumentError(
"Couldn't read scalar of const vector data from the attributes.");
}
argument0 = "value_0";
if (tensor) {
argument1 = "$const_data[gid.z]$";
objects.push_back({"const_data", MakeReadonlyObject(tensor->data)});
} else {
argument1 = "vec4($const_data$)";
parameters.push_back({"const_data", *scalar});
}
}
std::string source;
switch (operation_type_) {
case OperationType::DIV: {
source = "value_0 = $0/$1;";
break;
}
case OperationType::MAXIMUM: {
source = "value_0 = max($0, $1);";
break;
}
case OperationType::MINIMUM: {
source = "value_0 = min($0, $1);";
break;
}
case OperationType::SQUARED_DIFF: {
source = "value_0 = ($0 - $1) * ($0 - $1);";
break;
}
case OperationType::SUB: {
source = "value_0 = $0 - $1;";
break;
}
case OperationType::POW: {
source = "value_0 = pow($0, $1);";
break;
}
default:
return InvalidArgumentError(
"Incorrect elementwise with two arguments operation type.");
"Incorrect elementwise with scalar operation type.");
}
source = absl::Substitute(source, argument0, argument1);
*generated_code = {
/*parameters=*/{},
/*objects=*/{},
/*parameters=*/std::move(parameters),
/*objects=*/std::move(objects),
/*shared_variables=*/{},
/*workload=*/uint3(),
/*workgroup=*/uint3(),
/*source_code=*/source,
/*input=*/IOStructure::ONLY_DEFINITIONS,
/*input=*/IOStructure::AUTO,
/*output=*/IOStructure::AUTO,
};
return OkStatus();
}
Status GenerateCode(const GenerationContext& ctx,
GeneratedCode* generated_code) const final {
if (IsSupportedElemwise(ctx)) {
return ImplementElementwise(ctx, generated_code);
}
if (IsSupportedBroadcast(ctx)) {
return ImplementElementwiseBroadcast(ctx, generated_code);
}
const ElementwiseAttributes* attr =
absl::any_cast<ElementwiseAttributes>(&ctx.node->operation.attributes);
if (attr) {
auto scalar = absl::get_if<float>(&attr->param);
if (scalar) {
return ImplementElementwiseWithScalar(ctx, *scalar, generated_code);
}
}
return InvalidArgumentError(
"This case is not supported by elementwise with two arguments "
"operation");
}
private:
OperationType operation_type_;
};

View File

@ -36,7 +36,7 @@ TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
return tensor_ref;
}
TEST(ElementwiseTest, Abs) {
TEST(ElementwiseOneArgumentTest, Abs) {
OperationType op_type = OperationType::ABS;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
@ -48,7 +48,7 @@ TEST(ElementwiseTest, Abs) {
Pointwise(FloatNear(1e-6), {0.0, 6.2, 2.0, 4.0}));
}
TEST(ElementwiseTest, Cos) {
TEST(ElementwiseOneArgumentTest, Cos) {
OperationType op_type = OperationType::COS;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
@ -60,21 +60,7 @@ TEST(ElementwiseTest, Cos) {
Pointwise(FloatNear(1e-6), {1.0, -1.0, -1.0, 0.540302}));
}
TEST(ElementwiseTest, Div) {
OperationType op_type = OperationType::DIV;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
/*outputs=*/{GetTensorRef(2, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0}));
ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, -0.5, 4.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, -3.1, -4.0, 1.0}));
}
TEST(ElementwiseTest, Exp) {
TEST(ElementwiseOneArgumentTest, Exp) {
OperationType op_type = OperationType::EXP;
const BHWC shape(1, 1, 1, 7);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
@ -90,7 +76,7 @@ TEST(ElementwiseTest, Exp) {
std::exp(-0.01f)}));
}
TEST(ElementwiseTest, HardSwish) {
TEST(ElementwiseOneArgumentTest, HardSwish) {
OperationType op_type = OperationType::HARD_SWISH;
const BHWC shape(1, 1, 1, 7);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
@ -104,7 +90,7 @@ TEST(ElementwiseTest, HardSwish) {
{0.0f, 0.0f, -0.375f, 0.0f, 1.125f, 3.f, 4.5f}));
}
TEST(ElementwiseTest, Log) {
TEST(ElementwiseOneArgumentTest, Log) {
OperationType op_type = OperationType::LOG;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
@ -116,7 +102,142 @@ TEST(ElementwiseTest, Log) {
Pointwise(FloatNear(1e-6), {0.0, 1.14473, 0.0, 0.0}));
}
TEST(ElementwiseTest, Maximum) {
TEST(ElementwiseOneArgumentTest, Rsqrt) {
OperationType op_type = OperationType::RSQRT;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 4.0, 9.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0, 0.707106, 0.5, 0.333333}));
}
TEST(ElementwiseOneArgumentTest, Sigmoid) {
OperationType op_type = OperationType::SIGMOID;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.5, 0.002473, 0.880797, 0.982014}));
}
TEST(ElementwiseOneArgumentTest, Sin) {
OperationType op_type = OperationType::SIN;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, 0.0, 0.0, 0.841471}));
}
TEST(ElementwiseOneArgumentTest, Sqrt) {
OperationType op_type = OperationType::SQRT;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, 1.0, 1.414213, 2.0}));
}
TEST(ElementwiseOneArgumentTest, Square) {
OperationType op_type = OperationType::SQUARE;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 0.5, -3.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0, 4.0, 0.25, 9.0}));
}
TEST(ElementwiseOneArgumentTest, Tanh) {
OperationType op_type = OperationType::TANH;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, -0.999987, 0.964027, 0.999329}));
}
TEST(ElementwiseTwoArgumentsTest, DivElementwise) {
OperationType op_type = OperationType::DIV;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
/*outputs=*/{GetTensorRef(2, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0}));
ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, -0.5, 4.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, -3.1, -4.0, 1.0}));
}
TEST(ElementwiseTwoArgumentsTest, DivBroadcast) {
OperationType op_type = OperationType::DIV;
const BHWC shape0(1, 2, 1, 2);
const BHWC shape1(1, 1, 1, 2);
SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
/*outputs=*/{GetTensorRef(2, shape0)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
ASSERT_TRUE(model.PopulateTensor(1, {0.5, 0.2}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, 5.0, 4.0, 15.0}));
}
TEST(ElementwiseTwoArgumentsTest, DivScalar) {
OperationType op_type = OperationType::DIV;
const BHWC shape0(1, 2, 1, 2);
ElementwiseAttributes attr;
attr.param = static_cast<float>(0.5);
SingleOpModel model({/*type=*/ToString(op_type), attr},
/*inputs=*/{GetTensorRef(0, shape0)},
/*outputs=*/{GetTensorRef(2, shape0)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, 2.0, 4.0, 6.0}));
}
TEST(ElementwiseTwoArgumentsTest, DivConstVector) {
OperationType op_type = OperationType::DIV;
const BHWC shape0(1, 2, 1, 2);
ElementwiseAttributes attr;
Tensor<Linear, DataType::FLOAT32> param;
param.shape = Linear(2);
param.id = 1;
param.data = {0.4, 0.5};
attr.param = std::move(param);
SingleOpModel model({/*type=*/ToString(op_type), attr},
/*inputs=*/{GetTensorRef(0, shape0)},
/*outputs=*/{GetTensorRef(2, shape0)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, 2.0, 5.0, 6.0}));
}
TEST(ElementwiseTwoArgumentsTest, MaximumElementwise) {
OperationType op_type = OperationType::MAXIMUM;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model(
@ -130,7 +251,22 @@ TEST(ElementwiseTest, Maximum) {
Pointwise(FloatNear(1e-6), {1.0, 2.0, 3.0, -2.0}));
}
TEST(ElementwiseTest, MaximumWithScalar) {
TEST(ElementwiseTwoArgumentsTest, MaximumBroadcast) {
OperationType op_type = OperationType::MAXIMUM;
const BHWC shape0(1, 2, 1, 2);
const BHWC shape1(1, 1, 1, 2);
SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
/*outputs=*/{GetTensorRef(2, shape0)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
ASSERT_TRUE(model.PopulateTensor(1, {0.5, 0.2}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.5, 1.0, 2.0, 3.0}));
}
TEST(ElementwiseTwoArgumentsTest, MaximumScalar) {
OperationType op_type = OperationType::MAXIMUM;
const BHWC shape(1, 2, 2, 1);
ElementwiseAttributes attr;
@ -145,7 +281,27 @@ TEST(ElementwiseTest, MaximumWithScalar) {
Pointwise(FloatNear(1e-6), {0.0, -1.0, 2.0, -1.0}));
}
TEST(ElementwiseTest, Minimum) {
TEST(ElementwiseTwoArgumentsTest, MaximumConstVector) {
OperationType op_type = OperationType::MAXIMUM;
const BHWC shape0(1, 2, 1, 2);
ElementwiseAttributes attr;
Tensor<Linear, DataType::FLOAT32> param;
param.shape = Linear(2);
param.id = 1;
param.data = {0.4, 0.5};
attr.param = std::move(param);
SingleOpModel model({/*type=*/ToString(op_type), attr},
/*inputs=*/{GetTensorRef(0, shape0)},
/*outputs=*/{GetTensorRef(2, shape0)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.4, 1.0, 2.0, 3.0}));
}
TEST(ElementwiseTwoArgumentsTest, MinimumElementwise) {
OperationType op_type = OperationType::MINIMUM;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model(
@ -159,7 +315,22 @@ TEST(ElementwiseTest, Minimum) {
Pointwise(FloatNear(1e-6), {0.0, -6.2, 2.0, -3.0}));
}
TEST(ElementwiseTest, MinimumWithScalar) {
TEST(ElementwiseTwoArgumentsTest, MinimumBroadcast) {
OperationType op_type = OperationType::MINIMUM;
const BHWC shape0(1, 2, 1, 2);
const BHWC shape1(1, 1, 1, 2);
SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
/*outputs=*/{GetTensorRef(2, shape0)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
ASSERT_TRUE(model.PopulateTensor(1, {0.5, 0.2}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, 0.2, 0.5, 0.2}));
}
TEST(ElementwiseTwoArgumentsTest, MinimumScalar) {
OperationType op_type = OperationType::MINIMUM;
const BHWC shape(1, 2, 2, 1);
ElementwiseAttributes attr;
@ -174,7 +345,27 @@ TEST(ElementwiseTest, MinimumWithScalar) {
Pointwise(FloatNear(1e-6), {-1.0, -6.2, -1.0, -3.0}));
}
TEST(ElementwiseTest, Pow) {
TEST(ElementwiseTwoArgumentsTest, MinimumConstVector) {
OperationType op_type = OperationType::MINIMUM;
const BHWC shape0(1, 2, 1, 2);
ElementwiseAttributes attr;
Tensor<Linear, DataType::FLOAT32> param;
param.shape = Linear(2);
param.id = 1;
param.data = {0.5, 0.2};
attr.param = std::move(param);
SingleOpModel model({/*type=*/ToString(op_type), attr},
/*inputs=*/{GetTensorRef(0, shape0)},
/*outputs=*/{GetTensorRef(2, shape0)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, 0.2, 0.5, 0.2}));
}
TEST(ElementwiseTwoArgumentsTest, PowElementwise) {
OperationType op_type = OperationType::POW;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model(
@ -188,67 +379,57 @@ TEST(ElementwiseTest, Pow) {
Pointwise(FloatNear(1e-6), {0.0, 1.0, 8.0, 256.0}));
}
TEST(ElementwiseTest, Rsqrt) {
OperationType op_type = OperationType::RSQRT;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 4.0, 9.0}));
TEST(ElementwiseTwoArgumentsTest, PowBroadcast) {
OperationType op_type = OperationType::POW;
const BHWC shape0(1, 2, 1, 2);
const BHWC shape1(1, 1, 1, 2);
SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
/*outputs=*/{GetTensorRef(2, shape0)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
ASSERT_TRUE(model.PopulateTensor(1, {2.0, 0.5}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0, 0.707106, 0.5, 0.333333}));
Pointwise(FloatNear(1e-6), {0.0, 1.0, 4.0, 2.0}));
}
TEST(ElementwiseTest, Sigmoid) {
OperationType op_type = OperationType::SIGMOID;
TEST(ElementwiseTwoArgumentsTest, PowScalar) {
OperationType op_type = OperationType::POW;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.5, 0.002473, 0.880797, 0.982014}));
}
TEST(ElementwiseTest, Sin) {
OperationType op_type = OperationType::SIN;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, 0.0, 0.0, 0.841471}));
}
TEST(ElementwiseTest, Sqrt) {
OperationType op_type = OperationType::SQRT;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ElementwiseAttributes attr;
attr.param = 2.0f;
SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/std::move(attr)},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(2, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, 1.0, 1.414213, 2.0}));
Pointwise(FloatNear(1e-6), {0.0, 1.0, 4.0, 16.0}));
}
TEST(ElementwiseTest, Square) {
OperationType op_type = OperationType::SQUARE;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 0.5, -3.0}));
TEST(ElementwiseTwoArgumentsTest, PowConstVector) {
OperationType op_type = OperationType::POW;
const BHWC shape0(1, 2, 1, 2);
ElementwiseAttributes attr;
Tensor<Linear, DataType::FLOAT32> param;
param.shape = Linear(2);
param.id = 1;
param.data = {2.0, 0.5};
attr.param = std::move(param);
SingleOpModel model({/*type=*/ToString(op_type), attr},
/*inputs=*/{GetTensorRef(0, shape0)},
/*outputs=*/{GetTensorRef(2, shape0)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0, 4.0, 0.25, 9.0}));
Pointwise(FloatNear(1e-6), {0.0, 1.0, 4.0, 2.0}));
}
TEST(ElementwiseTest, SquaredDiff) {
TEST(ElementwiseTwoArgumentsTest, SquaredDiffElementwise) {
OperationType op_type = OperationType::SQUARED_DIFF;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model(
@ -262,7 +443,56 @@ TEST(ElementwiseTest, SquaredDiff) {
Pointwise(FloatNear(1e-6), {1.0, 1.0, 9.0, 0.0}));
}
TEST(ElementwiseTest, Sub) {
TEST(ElementwiseTwoArgumentsTest, SquaredDiffBroadcast) {
OperationType op_type = OperationType::SQUARED_DIFF;
const BHWC shape0(1, 2, 1, 2);
const BHWC shape1(1, 1, 1, 2);
SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
/*outputs=*/{GetTensorRef(2, shape0)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
ASSERT_TRUE(model.PopulateTensor(1, {-1.0, 5.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0, 16.0, 9.0, 4.0}));
}
TEST(ElementwiseTwoArgumentsTest, SquaredDiffScalar) {
OperationType op_type = OperationType::SQUARED_DIFF;
const BHWC shape0(1, 2, 1, 2);
ElementwiseAttributes attr;
attr.param = static_cast<float>(5.0);
SingleOpModel model({/*type=*/ToString(op_type), attr},
/*inputs=*/{GetTensorRef(0, shape0)},
/*outputs=*/{GetTensorRef(2, shape0)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {25.0, 16.0, 9.0, 4.0}));
}
TEST(ElementwiseTwoArgumentsTest, SquaredDiffConstVector) {
OperationType op_type = OperationType::SQUARED_DIFF;
const BHWC shape0(1, 2, 1, 2);
ElementwiseAttributes attr;
Tensor<Linear, DataType::FLOAT32> param;
param.shape = Linear(2);
param.id = 1;
param.data = {-1.0, 5.0};
attr.param = std::move(param);
SingleOpModel model({/*type=*/ToString(op_type), attr},
/*inputs=*/{GetTensorRef(0, shape0)},
/*outputs=*/{GetTensorRef(2, shape0)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {1.0, 16.0, 9.0, 4.0}));
}
TEST(ElementwiseTwoArgumentsTest, SubElementwise) {
OperationType op_type = OperationType::SUB;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model(
@ -276,16 +506,53 @@ TEST(ElementwiseTest, Sub) {
Pointwise(FloatNear(1e-6), {-1.0, -8.2, -1.0, 0.0}));
}
TEST(ElementwiseTest, Tanh) {
OperationType op_type = OperationType::TANH;
const BHWC shape(1, 2, 2, 1);
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape)},
/*outputs=*/{GetTensorRef(1, shape)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
TEST(ElementwiseTwoArgumentsTest, SubBroadcast) {
OperationType op_type = OperationType::SUB;
const BHWC shape0(1, 2, 1, 2);
const BHWC shape1(1, 1, 1, 2);
SingleOpModel model(
{/*type=*/ToString(op_type), /*attributes=*/{}},
/*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
/*outputs=*/{GetTensorRef(2, shape0)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
ASSERT_TRUE(model.PopulateTensor(1, {0.3, 0.2}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {0.0, -0.999987, 0.964027, 0.999329}));
Pointwise(FloatNear(1e-6), {-0.3, 0.8, 1.7, 2.8}));
}
TEST(ElementwiseTwoArgumentsTest, SubScalar) {
OperationType op_type = OperationType::SUB;
const BHWC shape0(1, 2, 1, 2);
ElementwiseAttributes attr;
attr.param = static_cast<float>(0.5);
SingleOpModel model({/*type=*/ToString(op_type), attr},
/*inputs=*/{GetTensorRef(0, shape0)},
/*outputs=*/{GetTensorRef(2, shape0)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {-0.5, 0.5, 1.5, 2.5}));
}
TEST(ElementwiseTwoArgumentsTest, SubConstVector) {
OperationType op_type = OperationType::SUB;
const BHWC shape0(1, 2, 1, 2);
ElementwiseAttributes attr;
Tensor<Linear, DataType::FLOAT32> param;
param.shape = Linear(2);
param.id = 1;
param.data = {0.3, 0.2};
attr.param = std::move(param);
SingleOpModel model({/*type=*/ToString(op_type), attr},
/*inputs=*/{GetTensorRef(0, shape0)},
/*outputs=*/{GetTensorRef(2, shape0)});
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
EXPECT_THAT(model.GetOutput(0),
Pointwise(FloatNear(1e-6), {-0.3, 0.8, 1.7, 2.8}));
}
} // namespace