Make elementwise operations with two inputs support all of the cases: elementwise, scalar, broadcast and const vector.
PiperOrigin-RevId: 302066884 Change-Id: I94a7497f006b466cc6d7a1b1fdba090b4ef30a00
This commit is contained in:
parent
2fd08c48a3
commit
16051cb33c
@ -236,6 +236,12 @@ int GetNumberOfRuntimeInputsForNode(const TfLiteContext* context,
|
|||||||
return number_of_runtime_inputs;
|
return number_of_runtime_inputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int GetNumberOfConstInputsForNode(const TfLiteContext* context,
|
||||||
|
const TfLiteNode* tflite_node) {
|
||||||
|
return tflite_node->inputs->size -
|
||||||
|
GetNumberOfRuntimeInputsForNode(context, tflite_node);
|
||||||
|
}
|
||||||
|
|
||||||
int GetNumberOfRuntimeOutputsForNode(const TfLiteContext* context,
|
int GetNumberOfRuntimeOutputsForNode(const TfLiteContext* context,
|
||||||
const TfLiteNode* tflite_node) {
|
const TfLiteNode* tflite_node) {
|
||||||
int number_of_runtime_outputs = 0;
|
int number_of_runtime_outputs = 0;
|
||||||
@ -258,6 +264,42 @@ Status CheckTensorIsAvailable(const TfLiteContext* context,
|
|||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status CheckInputsOutputs(const TfLiteContext* context,
|
||||||
|
const TfLiteNode* tflite_node, int runtime_inputs,
|
||||||
|
int outputs) {
|
||||||
|
int runtime_inputs_from_model =
|
||||||
|
GetNumberOfRuntimeInputsForNode(context, tflite_node);
|
||||||
|
if (runtime_inputs_from_model != runtime_inputs) {
|
||||||
|
return InternalError(absl::StrFormat(
|
||||||
|
"Expected %d runtime input tensor(s), but node has %d runtime "
|
||||||
|
"input(s).",
|
||||||
|
runtime_inputs, runtime_inputs_from_model));
|
||||||
|
}
|
||||||
|
int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node);
|
||||||
|
if (runtime_outputs != outputs) {
|
||||||
|
return InternalError(
|
||||||
|
absl::StrFormat("Expected %d output tensor(s), but node has %d "
|
||||||
|
"output(s).",
|
||||||
|
outputs, runtime_outputs));
|
||||||
|
}
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status CheckInputsConstsOutputs(const TfLiteContext* context,
|
||||||
|
const TfLiteNode* tflite_node,
|
||||||
|
int runtime_inputs, int const_inputs,
|
||||||
|
int outputs) {
|
||||||
|
int const_inputs_from_model =
|
||||||
|
GetNumberOfConstInputsForNode(context, tflite_node);
|
||||||
|
if (const_inputs_from_model != const_inputs) {
|
||||||
|
return InternalError(absl::StrFormat(
|
||||||
|
"Expected %d const input tensor(s), but node has %d const "
|
||||||
|
"input(s).",
|
||||||
|
const_inputs, const_inputs_from_model));
|
||||||
|
}
|
||||||
|
return CheckInputsOutputs(context, tflite_node, runtime_inputs, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
class ObjectReader {
|
class ObjectReader {
|
||||||
public:
|
public:
|
||||||
ObjectReader(GraphFloat32* graph, TfLiteContext* context,
|
ObjectReader(GraphFloat32* graph, TfLiteContext* context,
|
||||||
@ -367,6 +409,13 @@ class ObjectReader {
|
|||||||
: nullptr;
|
: nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status VerifyInputsConstsOutputs(const TfLiteNode* tflite_node,
|
||||||
|
int runtime_inputs, int const_inputs,
|
||||||
|
int outputs) {
|
||||||
|
return CheckInputsConstsOutputs(context_, tflite_node, runtime_inputs,
|
||||||
|
const_inputs, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
GraphFloat32* graph_ = nullptr;
|
GraphFloat32* graph_ = nullptr;
|
||||||
const TfLiteContext* context_ = nullptr;
|
const TfLiteContext* context_ = nullptr;
|
||||||
@ -374,59 +423,6 @@ class ObjectReader {
|
|||||||
std::vector<Value<TensorRef<BHWC>>*>* tensor_to_value_;
|
std::vector<Value<TensorRef<BHWC>>*>* tensor_to_value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
Status CheckInputsOutputs(const TfLiteContext* context,
|
|
||||||
const TfLiteNode* tflite_node, int inputs,
|
|
||||||
int outputs) {
|
|
||||||
int runtime_inputs = GetNumberOfRuntimeInputsForNode(context, tflite_node);
|
|
||||||
if (runtime_inputs != inputs) {
|
|
||||||
return InternalError(
|
|
||||||
absl::StrFormat("Expected %d input tensor(s), but node has %d runtime "
|
|
||||||
"input(s).",
|
|
||||||
inputs, runtime_inputs));
|
|
||||||
}
|
|
||||||
int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node);
|
|
||||||
if (runtime_outputs != outputs) {
|
|
||||||
return InternalError(
|
|
||||||
absl::StrFormat("Expected %d output tensor(s), but node has %d runtime "
|
|
||||||
"output(s).",
|
|
||||||
outputs, runtime_outputs));
|
|
||||||
}
|
|
||||||
return OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
// The function checks input tensors including 1 constant tensor.
|
|
||||||
Status CheckInputsOutputsAllowingOneConstInput(const TfLiteContext* context,
|
|
||||||
const TfLiteNode* tflite_node,
|
|
||||||
int inputs, int outputs) {
|
|
||||||
int number_of_const_inputs = 0;
|
|
||||||
int number_of_runtime_inputs = 0;
|
|
||||||
for (int i = 0; i < tflite_node->inputs->size; i++) {
|
|
||||||
if (IsConstantTensor(&context->tensors[tflite_node->inputs->data[i]])) {
|
|
||||||
number_of_const_inputs++;
|
|
||||||
} else {
|
|
||||||
number_of_runtime_inputs++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (tflite_node->inputs->size != inputs) {
|
|
||||||
return InternalError(absl::StrFormat(
|
|
||||||
"Expected %d input tensor(s), but node has %d input(s).", inputs,
|
|
||||||
tflite_node->inputs->size));
|
|
||||||
}
|
|
||||||
if (number_of_const_inputs > 1) {
|
|
||||||
return InternalError(absl::StrFormat(
|
|
||||||
"Expected 1 const input tensor, but node has %d const input(s).",
|
|
||||||
number_of_const_inputs));
|
|
||||||
}
|
|
||||||
int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node);
|
|
||||||
if (runtime_outputs != outputs) {
|
|
||||||
return InternalError(
|
|
||||||
absl::StrFormat("Expected %d output tensor(s), but node has %d runtime "
|
|
||||||
"output(s).",
|
|
||||||
outputs, runtime_outputs));
|
|
||||||
}
|
|
||||||
return OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
// A parser responsible for parsing TFLite operation and adding it to a graph.
|
// A parser responsible for parsing TFLite operation and adding it to a graph.
|
||||||
class TFLiteOperationParser {
|
class TFLiteOperationParser {
|
||||||
public:
|
public:
|
||||||
@ -893,8 +889,8 @@ class Conv2DOperationParser : public TFLiteOperationParser {
|
|||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
|
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
|
/*runtime_inputs=*/1, /*outputs=*/1));
|
||||||
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
|
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
|
||||||
TfLiteConvParams* tf_options = nullptr;
|
TfLiteConvParams* tf_options = nullptr;
|
||||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||||
@ -977,8 +973,8 @@ class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
|
|||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
|
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
|
/*runtime_inputs=*/1, /*outputs=*/1));
|
||||||
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
|
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
|
||||||
TfLiteDepthwiseConvParams* tf_options;
|
TfLiteDepthwiseConvParams* tf_options;
|
||||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||||
@ -1095,16 +1091,20 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
|
|||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
||||||
if (IsOneArgumentOperation()) {
|
if (IsOneArgumentOperation()) {
|
||||||
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
|
RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
|
||||||
/*outputs=*/1));
|
/*runtime_inputs=*/1,
|
||||||
|
/*const_inputs=*/0,
|
||||||
|
/*outputs=*/1));
|
||||||
} else if (IsTwoArgumentOperation()) {
|
} else if (IsTwoArgumentOperation()) {
|
||||||
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/2,
|
RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
|
||||||
/*outputs=*/1));
|
/*runtime_inputs=*/2,
|
||||||
|
/*const_inputs=*/0,
|
||||||
|
/*outputs=*/1));
|
||||||
} else if (IsTwoArgumentOperationWithConst()) {
|
} else if (IsTwoArgumentOperationWithConst()) {
|
||||||
RETURN_IF_ERROR(CheckInputsOutputsAllowingOneConstInput(context,
|
RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
|
||||||
tflite_node,
|
/*runtime_inputs=*/1,
|
||||||
/*inputs=*/2,
|
/*const_inputs=*/1,
|
||||||
/*outputs=*/1));
|
/*outputs=*/1));
|
||||||
} else {
|
} else {
|
||||||
return InvalidArgumentError("Op can only handle 1 or 2 operand(s).");
|
return InvalidArgumentError("Op can only handle 1 or 2 operand(s).");
|
||||||
}
|
}
|
||||||
@ -1120,8 +1120,17 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
|
|||||||
node->operation.type = ToString(operation_type_);
|
node->operation.type = ToString(operation_type_);
|
||||||
|
|
||||||
if (IsOneArgumentOperation()) {
|
if (IsOneArgumentOperation()) {
|
||||||
|
RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
|
||||||
|
/*runtime_inputs=*/1,
|
||||||
|
/*const_inputs=*/0,
|
||||||
|
/*outputs=*/1));
|
||||||
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||||
} else if (IsTwoArgumentOperation()) {
|
} else if (IsTwoArgumentOperation()) {
|
||||||
|
RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
|
||||||
|
/*runtime_inputs=*/2,
|
||||||
|
/*const_inputs=*/0,
|
||||||
|
/*outputs=*/1));
|
||||||
if (tflite_node->inputs->size != 2) {
|
if (tflite_node->inputs->size != 2) {
|
||||||
return InvalidArgumentError("Applies only two input tensors");
|
return InvalidArgumentError("Applies only two input tensors");
|
||||||
}
|
}
|
||||||
@ -1156,14 +1165,12 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
|
|||||||
MaybeFuseActivationToTheSingleOutput(activation, graph, node));
|
MaybeFuseActivationToTheSingleOutput(activation, graph, node));
|
||||||
}
|
}
|
||||||
} else if (IsTwoArgumentOperationWithConst()) {
|
} else if (IsTwoArgumentOperationWithConst()) {
|
||||||
|
RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
|
||||||
|
/*runtime_inputs=*/1,
|
||||||
|
/*const_inputs=*/1,
|
||||||
|
/*outputs=*/1));
|
||||||
ElementwiseAttributes attr;
|
ElementwiseAttributes attr;
|
||||||
RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
|
RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
|
||||||
auto const_vector =
|
|
||||||
absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
|
|
||||||
&attr.param);
|
|
||||||
if (const_vector) {
|
|
||||||
return InvalidArgumentError("Constant vector is not supported");
|
|
||||||
}
|
|
||||||
node->operation.attributes = std::move(attr);
|
node->operation.attributes = std::move(attr);
|
||||||
} else {
|
} else {
|
||||||
return InvalidArgumentError("Incorrect operation type passed");
|
return InvalidArgumentError("Incorrect operation type passed");
|
||||||
@ -1228,6 +1235,7 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
|
|||||||
switch (operation_type_) {
|
switch (operation_type_) {
|
||||||
case OperationType::MINIMUM:
|
case OperationType::MINIMUM:
|
||||||
case OperationType::MAXIMUM:
|
case OperationType::MAXIMUM:
|
||||||
|
case OperationType::SUB:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
@ -1311,7 +1319,7 @@ class HardSwishOperationParser : public TFLiteOperationParser {
|
|||||||
Status IsSupported(const TfLiteContext* context,
|
Status IsSupported(const TfLiteContext* context,
|
||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration*) final {
|
const TfLiteRegistration*) final {
|
||||||
return CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
|
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
|
||||||
/*outputs=*/1);
|
/*outputs=*/1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1350,7 +1358,8 @@ class LSTMOperationParser : public TFLiteOperationParser {
|
|||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
RETURN_IF_ERROR(CheckExactSupportedOpVersion(registration, 2));
|
RETURN_IF_ERROR(CheckExactSupportedOpVersion(registration, 2));
|
||||||
// TODO(eignasheva): Fix bad check.
|
// TODO(eignasheva): Fix bad check.
|
||||||
// RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/5,
|
// RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
|
// /*runtime_inputs=*/5,
|
||||||
// /*outputs=*/4));
|
// /*outputs=*/4));
|
||||||
TfLiteLSTMParams* tf_options = nullptr;
|
TfLiteLSTMParams* tf_options = nullptr;
|
||||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||||
@ -1599,8 +1608,8 @@ class PadOperationParser : public TFLiteOperationParser {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
|
/*runtime_inputs=*/1, /*outputs=*/1));
|
||||||
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
|
RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
@ -1648,11 +1657,13 @@ class Pooling2DOperationParser : public TFLiteOperationParser {
|
|||||||
TfLitePoolParams* tf_options = nullptr;
|
TfLitePoolParams* tf_options = nullptr;
|
||||||
auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
|
auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
|
||||||
if (status.ok()) { // custom case with indices as a second output
|
if (status.ok()) { // custom case with indices as a second output
|
||||||
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
|
/*runtime_inputs=*/1,
|
||||||
/*outputs=*/2));
|
/*outputs=*/2));
|
||||||
} else { // common pooling with 1 output
|
} else { // common pooling with 1 output
|
||||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||||
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
|
/*runtime_inputs=*/1,
|
||||||
/*outputs=*/1));
|
/*outputs=*/1));
|
||||||
}
|
}
|
||||||
RETURN_IF_ERROR(CheckKernelsAndStrides(
|
RETURN_IF_ERROR(CheckKernelsAndStrides(
|
||||||
@ -1752,8 +1763,8 @@ class ReshapeOperationParser : public TFLiteOperationParser {
|
|||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
|
/*runtime_inputs=*/1, /*outputs=*/1));
|
||||||
// TODO(eignasheva): add shape checking
|
// TODO(eignasheva): add shape checking
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
@ -1786,8 +1797,8 @@ class Resize2DOperationParser : public TFLiteOperationParser {
|
|||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
|
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
|
/*runtime_inputs=*/1, /*outputs=*/1));
|
||||||
|
|
||||||
RETURN_IF_ERROR(CheckOnlyUpsamplingIsSupported(context, tflite_node));
|
RETURN_IF_ERROR(CheckOnlyUpsamplingIsSupported(context, tflite_node));
|
||||||
bool align_corners;
|
bool align_corners;
|
||||||
@ -1974,8 +1985,8 @@ class SoftmaxOperationParser : public TFLiteOperationParser {
|
|||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
|
/*runtime_inputs=*/1, /*outputs=*/1));
|
||||||
TfLiteSoftmaxParams* tf_options = nullptr;
|
TfLiteSoftmaxParams* tf_options = nullptr;
|
||||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
|
||||||
if (tf_options->beta != 1) {
|
if (tf_options->beta != 1) {
|
||||||
@ -2018,8 +2029,8 @@ class SpaceToDepthOperationParser : public TFLiteOperationParser {
|
|||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
|
/*runtime_inputs=*/1, /*outputs=*/1));
|
||||||
// TODO(impjdi): Dims check.
|
// TODO(impjdi): Dims check.
|
||||||
TfLiteSpaceToDepthParams* s2d_params = nullptr;
|
TfLiteSpaceToDepthParams* s2d_params = nullptr;
|
||||||
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &s2d_params));
|
RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &s2d_params));
|
||||||
@ -2280,8 +2291,8 @@ class TransposeOperationParser : public TFLiteOperationParser {
|
|||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
|
/*runtime_inputs=*/1, /*outputs=*/1));
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2317,8 +2328,8 @@ class Unpooling2DOperationParser : public TFLiteOperationParser {
|
|||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
TfLitePoolParams* tf_options = nullptr;
|
TfLitePoolParams* tf_options = nullptr;
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1));
|
/*runtime_inputs=*/2, /*outputs=*/1));
|
||||||
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
|
RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
|
||||||
RETURN_IF_ERROR(CheckKernelsAndStrides(
|
RETURN_IF_ERROR(CheckKernelsAndStrides(
|
||||||
tf_options->filter_height, tf_options->filter_width,
|
tf_options->filter_height, tf_options->filter_width,
|
||||||
@ -2445,8 +2456,8 @@ class RoIToTransformMatrixOperationParser : public TFLiteOperationParser {
|
|||||||
Status IsSupported(const TfLiteContext* context,
|
Status IsSupported(const TfLiteContext* context,
|
||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
CheckInputsOutputs(context, tflite_node, /*inputs=*/1, /*outputs=*/1));
|
/*runtime_inputs=*/1, /*outputs=*/1));
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2478,8 +2489,8 @@ class TransformTensorOperationParser : public TFLiteOperationParser {
|
|||||||
Status IsSupported(const TfLiteContext* context,
|
Status IsSupported(const TfLiteContext* context,
|
||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1));
|
/*runtime_inputs=*/2, /*outputs=*/1));
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2515,8 +2526,8 @@ class TransformLandmarksOperationParser : public TFLiteOperationParser {
|
|||||||
Status IsSupported(const TfLiteContext* context,
|
Status IsSupported(const TfLiteContext* context,
|
||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
RETURN_IF_ERROR(
|
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
|
||||||
CheckInputsOutputs(context, tflite_node, /*inputs=*/2, /*outputs=*/1));
|
/*runtime_inputs=*/2, /*outputs=*/1));
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2549,7 +2560,7 @@ class Landmarks2TransformMatrixOperationParser : public TFLiteOperationParser {
|
|||||||
Status IsSupported(const TfLiteContext* context,
|
Status IsSupported(const TfLiteContext* context,
|
||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
return CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
|
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
|
||||||
/*outputs=*/1);
|
/*outputs=*/1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2581,7 +2592,7 @@ class MeanOperationParser : public TFLiteOperationParser {
|
|||||||
Status IsSupported(const TfLiteContext* context,
|
Status IsSupported(const TfLiteContext* context,
|
||||||
const TfLiteNode* tflite_node,
|
const TfLiteNode* tflite_node,
|
||||||
const TfLiteRegistration* registration) final {
|
const TfLiteRegistration* registration) final {
|
||||||
return CheckInputsOutputs(context, tflite_node, /*inputs=*/1,
|
return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
|
||||||
/*outputs=*/1);
|
/*outputs=*/1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2970,7 +2981,6 @@ bool IsAllFloatTensors(const TfLiteContext* context,
|
|||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
|
Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
|
||||||
|
@ -198,6 +198,7 @@ cc_library(
|
|||||||
"//tensorflow/lite/delegates/gpu/common:types",
|
"//tensorflow/lite/delegates/gpu/common:types",
|
||||||
"//tensorflow/lite/delegates/gpu/gl:node_shader",
|
"//tensorflow/lite/delegates/gpu/gl:node_shader",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
|
#include "absl/strings/substitute.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||||
|
|
||||||
@ -130,89 +131,9 @@ class ElementwiseTwoArguments : public NodeShader {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ImplementElementwise(const GenerationContext& ctx,
|
|
||||||
GeneratedCode* generated_code) const {
|
|
||||||
std::string source;
|
|
||||||
switch (operation_type_) {
|
|
||||||
case OperationType::SUB: {
|
|
||||||
source = "value_0 -= value_1;";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case OperationType::DIV: {
|
|
||||||
source = "value_0 /= value_1;";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case OperationType::MAXIMUM: {
|
|
||||||
source = "value_0 = max(value_0, value_1);";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case OperationType::MINIMUM: {
|
|
||||||
source = "value_0 = min(value_0, value_1);";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case OperationType::POW: {
|
|
||||||
// From documentation :
|
|
||||||
// The result is undefined if x<0 or if x=0 and y≤0.
|
|
||||||
source = "value_0 = pow(value_0, value_1);";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case OperationType::SQUARED_DIFF: {
|
|
||||||
source = "value_0 = (value_0 - value_1) * (value_0 - value_1);";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
return InvalidArgumentError(
|
|
||||||
"Incorrect elementwise with two arguments operation type.");
|
|
||||||
}
|
|
||||||
*generated_code = {
|
|
||||||
/*parameters=*/{},
|
|
||||||
/*objects=*/{},
|
|
||||||
/*shared_variables=*/{},
|
|
||||||
/*workload=*/uint3(),
|
|
||||||
/*workgroup=*/uint3(),
|
|
||||||
/*source_code=*/source,
|
|
||||||
/*input=*/IOStructure::AUTO,
|
|
||||||
/*output=*/IOStructure::AUTO,
|
|
||||||
};
|
|
||||||
return OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status ImplementElementwiseWithScalar(const GenerationContext& ctx,
|
|
||||||
const float scalar,
|
|
||||||
GeneratedCode* generated_code) const {
|
|
||||||
std::string source;
|
|
||||||
switch (operation_type_) {
|
|
||||||
case OperationType::MAXIMUM: {
|
|
||||||
source = "value_0 = max(value_0, $scalar$);";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case OperationType::MINIMUM: {
|
|
||||||
source = "value_0 = min(value_0, $scalar$);";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
return InvalidArgumentError(
|
|
||||||
"Incorrect elementwise with scalar operation type.");
|
|
||||||
}
|
|
||||||
*generated_code = {
|
|
||||||
/*parameters=*/{{"scalar", scalar}},
|
|
||||||
/*objects=*/{},
|
|
||||||
/*shared_variables=*/{},
|
|
||||||
/*workload=*/uint3(),
|
|
||||||
/*workgroup=*/uint3(),
|
|
||||||
/*source_code=*/source,
|
|
||||||
/*input=*/IOStructure::AUTO,
|
|
||||||
/*output=*/IOStructure::AUTO,
|
|
||||||
};
|
|
||||||
return OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsSupportedBroadcast(const GenerationContext& ctx) const {
|
bool IsSupportedBroadcast(const GenerationContext& ctx) const {
|
||||||
auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||||
auto outputs = ctx.graph->FindOutputs(ctx.node->id);
|
auto outputs = ctx.graph->FindOutputs(ctx.node->id);
|
||||||
|
|
||||||
if (inputs.size() != 2) {
|
if (inputs.size() != 2) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -223,57 +144,87 @@ class ElementwiseTwoArguments : public NodeShader {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ImplementElementwiseBroadcast(const GenerationContext& ctx,
|
Status GenerateCode(const GenerationContext& ctx,
|
||||||
GeneratedCode* generated_code) const {
|
GeneratedCode* generated_code) const final {
|
||||||
std::string source;
|
std::vector<Variable> parameters;
|
||||||
switch (operation_type_) {
|
std::vector<std::pair<std::string, Object>> objects;
|
||||||
case OperationType::SQUARED_DIFF: {
|
std::string argument0, argument1;
|
||||||
source = R"(
|
if (IsSupportedElemwise(ctx)) {
|
||||||
vec4 diff = $input_data_0[gid.x, gid.y, gid.z]$ -
|
argument0 = "value_0";
|
||||||
$input_data_1[0, 0, gid.z]$;
|
argument1 = "value_1";
|
||||||
value_0 = diff * diff;
|
} else if (IsSupportedBroadcast(ctx)) {
|
||||||
)";
|
argument0 = "$input_data_0[gid.x, gid.y, gid.z]$";
|
||||||
break;
|
argument1 = "$input_data_1[0, 0, gid.z]$";
|
||||||
|
} else { // Scalar of const vector case
|
||||||
|
const ElementwiseAttributes* attr = absl::any_cast<ElementwiseAttributes>(
|
||||||
|
&ctx.node->operation.attributes);
|
||||||
|
if (!attr) {
|
||||||
|
return InvalidArgumentError(
|
||||||
|
"Couldn't read attributes for the scalar of const vector case.");
|
||||||
|
}
|
||||||
|
auto* tensor =
|
||||||
|
absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
|
||||||
|
&attr->param);
|
||||||
|
auto* scalar = absl::get_if<float>(&attr->param);
|
||||||
|
if (!tensor && !scalar) {
|
||||||
|
return InvalidArgumentError(
|
||||||
|
"Couldn't read scalar of const vector data from the attributes.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
argument0 = "value_0";
|
||||||
|
if (tensor) {
|
||||||
|
argument1 = "$const_data[gid.z]$";
|
||||||
|
objects.push_back({"const_data", MakeReadonlyObject(tensor->data)});
|
||||||
|
} else {
|
||||||
|
argument1 = "vec4($const_data$)";
|
||||||
|
parameters.push_back({"const_data", *scalar});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string source;
|
||||||
|
switch (operation_type_) {
|
||||||
|
case OperationType::DIV: {
|
||||||
|
source = "value_0 = $0/$1;";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case OperationType::MAXIMUM: {
|
||||||
|
source = "value_0 = max($0, $1);";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case OperationType::MINIMUM: {
|
||||||
|
source = "value_0 = min($0, $1);";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case OperationType::SQUARED_DIFF: {
|
||||||
|
source = "value_0 = ($0 - $1) * ($0 - $1);";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case OperationType::SUB: {
|
||||||
|
source = "value_0 = $0 - $1;";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case OperationType::POW: {
|
||||||
|
source = "value_0 = pow($0, $1);";
|
||||||
|
break;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return InvalidArgumentError(
|
return InvalidArgumentError(
|
||||||
"Incorrect elementwise with two arguments operation type.");
|
"Incorrect elementwise with scalar operation type.");
|
||||||
}
|
}
|
||||||
|
source = absl::Substitute(source, argument0, argument1);
|
||||||
*generated_code = {
|
*generated_code = {
|
||||||
/*parameters=*/{},
|
/*parameters=*/std::move(parameters),
|
||||||
/*objects=*/{},
|
/*objects=*/std::move(objects),
|
||||||
/*shared_variables=*/{},
|
/*shared_variables=*/{},
|
||||||
/*workload=*/uint3(),
|
/*workload=*/uint3(),
|
||||||
/*workgroup=*/uint3(),
|
/*workgroup=*/uint3(),
|
||||||
/*source_code=*/source,
|
/*source_code=*/source,
|
||||||
/*input=*/IOStructure::ONLY_DEFINITIONS,
|
/*input=*/IOStructure::AUTO,
|
||||||
/*output=*/IOStructure::AUTO,
|
/*output=*/IOStructure::AUTO,
|
||||||
};
|
};
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GenerateCode(const GenerationContext& ctx,
|
|
||||||
GeneratedCode* generated_code) const final {
|
|
||||||
if (IsSupportedElemwise(ctx)) {
|
|
||||||
return ImplementElementwise(ctx, generated_code);
|
|
||||||
}
|
|
||||||
if (IsSupportedBroadcast(ctx)) {
|
|
||||||
return ImplementElementwiseBroadcast(ctx, generated_code);
|
|
||||||
}
|
|
||||||
const ElementwiseAttributes* attr =
|
|
||||||
absl::any_cast<ElementwiseAttributes>(&ctx.node->operation.attributes);
|
|
||||||
if (attr) {
|
|
||||||
auto scalar = absl::get_if<float>(&attr->param);
|
|
||||||
if (scalar) {
|
|
||||||
return ImplementElementwiseWithScalar(ctx, *scalar, generated_code);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return InvalidArgumentError(
|
|
||||||
"This case is not supported by elementwise with two arguments "
|
|
||||||
"operation");
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OperationType operation_type_;
|
OperationType operation_type_;
|
||||||
};
|
};
|
||||||
|
@ -36,7 +36,7 @@ TensorRef<BHWC> GetTensorRef(int ref, const BHWC& shape) {
|
|||||||
return tensor_ref;
|
return tensor_ref;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ElementwiseTest, Abs) {
|
TEST(ElementwiseOneArgumentTest, Abs) {
|
||||||
OperationType op_type = OperationType::ABS;
|
OperationType op_type = OperationType::ABS;
|
||||||
const BHWC shape(1, 2, 2, 1);
|
const BHWC shape(1, 2, 2, 1);
|
||||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
@ -48,7 +48,7 @@ TEST(ElementwiseTest, Abs) {
|
|||||||
Pointwise(FloatNear(1e-6), {0.0, 6.2, 2.0, 4.0}));
|
Pointwise(FloatNear(1e-6), {0.0, 6.2, 2.0, 4.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ElementwiseTest, Cos) {
|
TEST(ElementwiseOneArgumentTest, Cos) {
|
||||||
OperationType op_type = OperationType::COS;
|
OperationType op_type = OperationType::COS;
|
||||||
const BHWC shape(1, 2, 2, 1);
|
const BHWC shape(1, 2, 2, 1);
|
||||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
@ -60,21 +60,7 @@ TEST(ElementwiseTest, Cos) {
|
|||||||
Pointwise(FloatNear(1e-6), {1.0, -1.0, -1.0, 0.540302}));
|
Pointwise(FloatNear(1e-6), {1.0, -1.0, -1.0, 0.540302}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ElementwiseTest, Div) {
|
TEST(ElementwiseOneArgumentTest, Exp) {
|
||||||
OperationType op_type = OperationType::DIV;
|
|
||||||
const BHWC shape(1, 2, 2, 1);
|
|
||||||
SingleOpModel model(
|
|
||||||
{/*type=*/ToString(op_type), /*attributes=*/{}},
|
|
||||||
/*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
|
|
||||||
/*outputs=*/{GetTensorRef(2, shape)});
|
|
||||||
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0}));
|
|
||||||
ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, -0.5, 4.0}));
|
|
||||||
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
|
||||||
EXPECT_THAT(model.GetOutput(0),
|
|
||||||
Pointwise(FloatNear(1e-6), {0.0, -3.1, -4.0, 1.0}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ElementwiseTest, Exp) {
|
|
||||||
OperationType op_type = OperationType::EXP;
|
OperationType op_type = OperationType::EXP;
|
||||||
const BHWC shape(1, 1, 1, 7);
|
const BHWC shape(1, 1, 1, 7);
|
||||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
@ -90,7 +76,7 @@ TEST(ElementwiseTest, Exp) {
|
|||||||
std::exp(-0.01f)}));
|
std::exp(-0.01f)}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ElementwiseTest, HardSwish) {
|
TEST(ElementwiseOneArgumentTest, HardSwish) {
|
||||||
OperationType op_type = OperationType::HARD_SWISH;
|
OperationType op_type = OperationType::HARD_SWISH;
|
||||||
const BHWC shape(1, 1, 1, 7);
|
const BHWC shape(1, 1, 1, 7);
|
||||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
@ -104,7 +90,7 @@ TEST(ElementwiseTest, HardSwish) {
|
|||||||
{0.0f, 0.0f, -0.375f, 0.0f, 1.125f, 3.f, 4.5f}));
|
{0.0f, 0.0f, -0.375f, 0.0f, 1.125f, 3.f, 4.5f}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ElementwiseTest, Log) {
|
TEST(ElementwiseOneArgumentTest, Log) {
|
||||||
OperationType op_type = OperationType::LOG;
|
OperationType op_type = OperationType::LOG;
|
||||||
const BHWC shape(1, 2, 2, 1);
|
const BHWC shape(1, 2, 2, 1);
|
||||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
@ -116,7 +102,142 @@ TEST(ElementwiseTest, Log) {
|
|||||||
Pointwise(FloatNear(1e-6), {0.0, 1.14473, 0.0, 0.0}));
|
Pointwise(FloatNear(1e-6), {0.0, 1.14473, 0.0, 0.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ElementwiseTest, Maximum) {
|
TEST(ElementwiseOneArgumentTest, Rsqrt) {
|
||||||
|
OperationType op_type = OperationType::RSQRT;
|
||||||
|
const BHWC shape(1, 2, 2, 1);
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape)},
|
||||||
|
/*outputs=*/{GetTensorRef(1, shape)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 4.0, 9.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {1.0, 0.707106, 0.5, 0.333333}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseOneArgumentTest, Sigmoid) {
|
||||||
|
OperationType op_type = OperationType::SIGMOID;
|
||||||
|
const BHWC shape(1, 2, 2, 1);
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape)},
|
||||||
|
/*outputs=*/{GetTensorRef(1, shape)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {0.5, 0.002473, 0.880797, 0.982014}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseOneArgumentTest, Sin) {
|
||||||
|
OperationType op_type = OperationType::SIN;
|
||||||
|
const BHWC shape(1, 2, 2, 1);
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape)},
|
||||||
|
/*outputs=*/{GetTensorRef(1, shape)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {0.0, 0.0, 0.0, 0.841471}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseOneArgumentTest, Sqrt) {
|
||||||
|
OperationType op_type = OperationType::SQRT;
|
||||||
|
const BHWC shape(1, 2, 2, 1);
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape)},
|
||||||
|
/*outputs=*/{GetTensorRef(1, shape)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {0.0, 1.0, 1.414213, 2.0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseOneArgumentTest, Square) {
|
||||||
|
OperationType op_type = OperationType::SQUARE;
|
||||||
|
const BHWC shape(1, 2, 2, 1);
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape)},
|
||||||
|
/*outputs=*/{GetTensorRef(1, shape)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 0.5, -3.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {1.0, 4.0, 0.25, 9.0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseOneArgumentTest, Tanh) {
|
||||||
|
OperationType op_type = OperationType::TANH;
|
||||||
|
const BHWC shape(1, 2, 2, 1);
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape)},
|
||||||
|
/*outputs=*/{GetTensorRef(1, shape)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {0.0, -0.999987, 0.964027, 0.999329}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseTwoArgumentsTest, DivElementwise) {
|
||||||
|
OperationType op_type = OperationType::DIV;
|
||||||
|
const BHWC shape(1, 2, 2, 1);
|
||||||
|
SingleOpModel model(
|
||||||
|
{/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape), GetTensorRef(1, shape)},
|
||||||
|
/*outputs=*/{GetTensorRef(2, shape)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.2, 2.0, 4.0}));
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(1, {1.0, 2.0, -0.5, 4.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {0.0, -3.1, -4.0, 1.0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseTwoArgumentsTest, DivBroadcast) {
|
||||||
|
OperationType op_type = OperationType::DIV;
|
||||||
|
const BHWC shape0(1, 2, 1, 2);
|
||||||
|
const BHWC shape1(1, 1, 1, 2);
|
||||||
|
SingleOpModel model(
|
||||||
|
{/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
|
||||||
|
/*outputs=*/{GetTensorRef(2, shape0)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(1, {0.5, 0.2}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {0.0, 5.0, 4.0, 15.0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseTwoArgumentsTest, DivScalar) {
|
||||||
|
OperationType op_type = OperationType::DIV;
|
||||||
|
const BHWC shape0(1, 2, 1, 2);
|
||||||
|
ElementwiseAttributes attr;
|
||||||
|
attr.param = static_cast<float>(0.5);
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), attr},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape0)},
|
||||||
|
/*outputs=*/{GetTensorRef(2, shape0)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {0.0, 2.0, 4.0, 6.0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseTwoArgumentsTest, DivConstVector) {
|
||||||
|
OperationType op_type = OperationType::DIV;
|
||||||
|
const BHWC shape0(1, 2, 1, 2);
|
||||||
|
|
||||||
|
ElementwiseAttributes attr;
|
||||||
|
Tensor<Linear, DataType::FLOAT32> param;
|
||||||
|
param.shape = Linear(2);
|
||||||
|
param.id = 1;
|
||||||
|
param.data = {0.4, 0.5};
|
||||||
|
attr.param = std::move(param);
|
||||||
|
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), attr},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape0)},
|
||||||
|
/*outputs=*/{GetTensorRef(2, shape0)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {0.0, 2.0, 5.0, 6.0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseTwoArgumentsTest, MaximumElementwise) {
|
||||||
OperationType op_type = OperationType::MAXIMUM;
|
OperationType op_type = OperationType::MAXIMUM;
|
||||||
const BHWC shape(1, 2, 2, 1);
|
const BHWC shape(1, 2, 2, 1);
|
||||||
SingleOpModel model(
|
SingleOpModel model(
|
||||||
@ -130,7 +251,22 @@ TEST(ElementwiseTest, Maximum) {
|
|||||||
Pointwise(FloatNear(1e-6), {1.0, 2.0, 3.0, -2.0}));
|
Pointwise(FloatNear(1e-6), {1.0, 2.0, 3.0, -2.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ElementwiseTest, MaximumWithScalar) {
|
TEST(ElementwiseTwoArgumentsTest, MaximumBroadcast) {
|
||||||
|
OperationType op_type = OperationType::MAXIMUM;
|
||||||
|
const BHWC shape0(1, 2, 1, 2);
|
||||||
|
const BHWC shape1(1, 1, 1, 2);
|
||||||
|
SingleOpModel model(
|
||||||
|
{/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
|
||||||
|
/*outputs=*/{GetTensorRef(2, shape0)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(1, {0.5, 0.2}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {0.5, 1.0, 2.0, 3.0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseTwoArgumentsTest, MaximumScalar) {
|
||||||
OperationType op_type = OperationType::MAXIMUM;
|
OperationType op_type = OperationType::MAXIMUM;
|
||||||
const BHWC shape(1, 2, 2, 1);
|
const BHWC shape(1, 2, 2, 1);
|
||||||
ElementwiseAttributes attr;
|
ElementwiseAttributes attr;
|
||||||
@ -145,7 +281,27 @@ TEST(ElementwiseTest, MaximumWithScalar) {
|
|||||||
Pointwise(FloatNear(1e-6), {0.0, -1.0, 2.0, -1.0}));
|
Pointwise(FloatNear(1e-6), {0.0, -1.0, 2.0, -1.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ElementwiseTest, Minimum) {
|
TEST(ElementwiseTwoArgumentsTest, MaximumConstVector) {
|
||||||
|
OperationType op_type = OperationType::MAXIMUM;
|
||||||
|
const BHWC shape0(1, 2, 1, 2);
|
||||||
|
|
||||||
|
ElementwiseAttributes attr;
|
||||||
|
Tensor<Linear, DataType::FLOAT32> param;
|
||||||
|
param.shape = Linear(2);
|
||||||
|
param.id = 1;
|
||||||
|
param.data = {0.4, 0.5};
|
||||||
|
attr.param = std::move(param);
|
||||||
|
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), attr},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape0)},
|
||||||
|
/*outputs=*/{GetTensorRef(2, shape0)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {0.4, 1.0, 2.0, 3.0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseTwoArgumentsTest, MinimumElementwise) {
|
||||||
OperationType op_type = OperationType::MINIMUM;
|
OperationType op_type = OperationType::MINIMUM;
|
||||||
const BHWC shape(1, 2, 2, 1);
|
const BHWC shape(1, 2, 2, 1);
|
||||||
SingleOpModel model(
|
SingleOpModel model(
|
||||||
@ -159,7 +315,22 @@ TEST(ElementwiseTest, Minimum) {
|
|||||||
Pointwise(FloatNear(1e-6), {0.0, -6.2, 2.0, -3.0}));
|
Pointwise(FloatNear(1e-6), {0.0, -6.2, 2.0, -3.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ElementwiseTest, MinimumWithScalar) {
|
TEST(ElementwiseTwoArgumentsTest, MinimumBroadcast) {
|
||||||
|
OperationType op_type = OperationType::MINIMUM;
|
||||||
|
const BHWC shape0(1, 2, 1, 2);
|
||||||
|
const BHWC shape1(1, 1, 1, 2);
|
||||||
|
SingleOpModel model(
|
||||||
|
{/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
|
||||||
|
/*outputs=*/{GetTensorRef(2, shape0)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(1, {0.5, 0.2}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {0.0, 0.2, 0.5, 0.2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseTwoArgumentsTest, MinimumScalar) {
|
||||||
OperationType op_type = OperationType::MINIMUM;
|
OperationType op_type = OperationType::MINIMUM;
|
||||||
const BHWC shape(1, 2, 2, 1);
|
const BHWC shape(1, 2, 2, 1);
|
||||||
ElementwiseAttributes attr;
|
ElementwiseAttributes attr;
|
||||||
@ -174,7 +345,27 @@ TEST(ElementwiseTest, MinimumWithScalar) {
|
|||||||
Pointwise(FloatNear(1e-6), {-1.0, -6.2, -1.0, -3.0}));
|
Pointwise(FloatNear(1e-6), {-1.0, -6.2, -1.0, -3.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ElementwiseTest, Pow) {
|
TEST(ElementwiseTwoArgumentsTest, MinimumConstVector) {
|
||||||
|
OperationType op_type = OperationType::MINIMUM;
|
||||||
|
const BHWC shape0(1, 2, 1, 2);
|
||||||
|
|
||||||
|
ElementwiseAttributes attr;
|
||||||
|
Tensor<Linear, DataType::FLOAT32> param;
|
||||||
|
param.shape = Linear(2);
|
||||||
|
param.id = 1;
|
||||||
|
param.data = {0.5, 0.2};
|
||||||
|
attr.param = std::move(param);
|
||||||
|
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), attr},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape0)},
|
||||||
|
/*outputs=*/{GetTensorRef(2, shape0)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {0.0, 0.2, 0.5, 0.2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseTwoArgumentsTest, PowElementwise) {
|
||||||
OperationType op_type = OperationType::POW;
|
OperationType op_type = OperationType::POW;
|
||||||
const BHWC shape(1, 2, 2, 1);
|
const BHWC shape(1, 2, 2, 1);
|
||||||
SingleOpModel model(
|
SingleOpModel model(
|
||||||
@ -188,67 +379,57 @@ TEST(ElementwiseTest, Pow) {
|
|||||||
Pointwise(FloatNear(1e-6), {0.0, 1.0, 8.0, 256.0}));
|
Pointwise(FloatNear(1e-6), {0.0, 1.0, 8.0, 256.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ElementwiseTest, Rsqrt) {
|
TEST(ElementwiseTwoArgumentsTest, PowBroadcast) {
|
||||||
OperationType op_type = OperationType::RSQRT;
|
OperationType op_type = OperationType::POW;
|
||||||
const BHWC shape(1, 2, 2, 1);
|
const BHWC shape0(1, 2, 1, 2);
|
||||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
const BHWC shape1(1, 1, 1, 2);
|
||||||
/*inputs=*/{GetTensorRef(0, shape)},
|
SingleOpModel model(
|
||||||
/*outputs=*/{GetTensorRef(1, shape)});
|
{/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 4.0, 9.0}));
|
/*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
|
||||||
|
/*outputs=*/{GetTensorRef(2, shape0)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(1, {2.0, 0.5}));
|
||||||
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
EXPECT_THAT(model.GetOutput(0),
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
Pointwise(FloatNear(1e-6), {1.0, 0.707106, 0.5, 0.333333}));
|
Pointwise(FloatNear(1e-6), {0.0, 1.0, 4.0, 2.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ElementwiseTest, Sigmoid) {
|
TEST(ElementwiseTwoArgumentsTest, PowScalar) {
|
||||||
OperationType op_type = OperationType::SIGMOID;
|
OperationType op_type = OperationType::POW;
|
||||||
const BHWC shape(1, 2, 2, 1);
|
const BHWC shape(1, 2, 2, 1);
|
||||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
ElementwiseAttributes attr;
|
||||||
/*inputs=*/{GetTensorRef(0, shape)},
|
attr.param = 2.0f;
|
||||||
/*outputs=*/{GetTensorRef(1, shape)});
|
SingleOpModel model(
|
||||||
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
|
{/*type=*/ToString(op_type), /*attributes=*/std::move(attr)},
|
||||||
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
/*inputs=*/{GetTensorRef(0, shape)},
|
||||||
EXPECT_THAT(model.GetOutput(0),
|
/*outputs=*/{GetTensorRef(2, shape)});
|
||||||
Pointwise(FloatNear(1e-6), {0.5, 0.002473, 0.880797, 0.982014}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ElementwiseTest, Sin) {
|
|
||||||
OperationType op_type = OperationType::SIN;
|
|
||||||
const BHWC shape(1, 2, 2, 1);
|
|
||||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
|
||||||
/*inputs=*/{GetTensorRef(0, shape)},
|
|
||||||
/*outputs=*/{GetTensorRef(1, shape)});
|
|
||||||
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 3.1415926, -3.1415926, 1.0}));
|
|
||||||
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
|
||||||
EXPECT_THAT(model.GetOutput(0),
|
|
||||||
Pointwise(FloatNear(1e-6), {0.0, 0.0, 0.0, 0.841471}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ElementwiseTest, Sqrt) {
|
|
||||||
OperationType op_type = OperationType::SQRT;
|
|
||||||
const BHWC shape(1, 2, 2, 1);
|
|
||||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
|
||||||
/*inputs=*/{GetTensorRef(0, shape)},
|
|
||||||
/*outputs=*/{GetTensorRef(1, shape)});
|
|
||||||
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
|
||||||
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
EXPECT_THAT(model.GetOutput(0),
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
Pointwise(FloatNear(1e-6), {0.0, 1.0, 1.414213, 2.0}));
|
Pointwise(FloatNear(1e-6), {0.0, 1.0, 4.0, 16.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ElementwiseTest, Square) {
|
TEST(ElementwiseTwoArgumentsTest, PowConstVector) {
|
||||||
OperationType op_type = OperationType::SQUARE;
|
OperationType op_type = OperationType::POW;
|
||||||
const BHWC shape(1, 2, 2, 1);
|
const BHWC shape0(1, 2, 1, 2);
|
||||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
|
||||||
/*inputs=*/{GetTensorRef(0, shape)},
|
ElementwiseAttributes attr;
|
||||||
/*outputs=*/{GetTensorRef(1, shape)});
|
Tensor<Linear, DataType::FLOAT32> param;
|
||||||
ASSERT_TRUE(model.PopulateTensor(0, {1.0, 2.0, 0.5, -3.0}));
|
param.shape = Linear(2);
|
||||||
|
param.id = 1;
|
||||||
|
param.data = {2.0, 0.5};
|
||||||
|
attr.param = std::move(param);
|
||||||
|
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), attr},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape0)},
|
||||||
|
/*outputs=*/{GetTensorRef(2, shape0)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 4.0}));
|
||||||
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
EXPECT_THAT(model.GetOutput(0),
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
Pointwise(FloatNear(1e-6), {1.0, 4.0, 0.25, 9.0}));
|
Pointwise(FloatNear(1e-6), {0.0, 1.0, 4.0, 2.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ElementwiseTest, SquaredDiff) {
|
TEST(ElementwiseTwoArgumentsTest, SquaredDiffElementwise) {
|
||||||
OperationType op_type = OperationType::SQUARED_DIFF;
|
OperationType op_type = OperationType::SQUARED_DIFF;
|
||||||
const BHWC shape(1, 2, 2, 1);
|
const BHWC shape(1, 2, 2, 1);
|
||||||
SingleOpModel model(
|
SingleOpModel model(
|
||||||
@ -262,7 +443,56 @@ TEST(ElementwiseTest, SquaredDiff) {
|
|||||||
Pointwise(FloatNear(1e-6), {1.0, 1.0, 9.0, 0.0}));
|
Pointwise(FloatNear(1e-6), {1.0, 1.0, 9.0, 0.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ElementwiseTest, Sub) {
|
TEST(ElementwiseTwoArgumentsTest, SquaredDiffBroadcast) {
|
||||||
|
OperationType op_type = OperationType::SQUARED_DIFF;
|
||||||
|
const BHWC shape0(1, 2, 1, 2);
|
||||||
|
const BHWC shape1(1, 1, 1, 2);
|
||||||
|
SingleOpModel model(
|
||||||
|
{/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
|
||||||
|
/*outputs=*/{GetTensorRef(2, shape0)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(1, {-1.0, 5.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {1.0, 16.0, 9.0, 4.0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseTwoArgumentsTest, SquaredDiffScalar) {
|
||||||
|
OperationType op_type = OperationType::SQUARED_DIFF;
|
||||||
|
const BHWC shape0(1, 2, 1, 2);
|
||||||
|
ElementwiseAttributes attr;
|
||||||
|
attr.param = static_cast<float>(5.0);
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), attr},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape0)},
|
||||||
|
/*outputs=*/{GetTensorRef(2, shape0)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {25.0, 16.0, 9.0, 4.0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseTwoArgumentsTest, SquaredDiffConstVector) {
|
||||||
|
OperationType op_type = OperationType::SQUARED_DIFF;
|
||||||
|
const BHWC shape0(1, 2, 1, 2);
|
||||||
|
|
||||||
|
ElementwiseAttributes attr;
|
||||||
|
Tensor<Linear, DataType::FLOAT32> param;
|
||||||
|
param.shape = Linear(2);
|
||||||
|
param.id = 1;
|
||||||
|
param.data = {-1.0, 5.0};
|
||||||
|
attr.param = std::move(param);
|
||||||
|
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), attr},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape0)},
|
||||||
|
/*outputs=*/{GetTensorRef(2, shape0)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {1.0, 16.0, 9.0, 4.0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseTwoArgumentsTest, SubElementwise) {
|
||||||
OperationType op_type = OperationType::SUB;
|
OperationType op_type = OperationType::SUB;
|
||||||
const BHWC shape(1, 2, 2, 1);
|
const BHWC shape(1, 2, 2, 1);
|
||||||
SingleOpModel model(
|
SingleOpModel model(
|
||||||
@ -276,16 +506,53 @@ TEST(ElementwiseTest, Sub) {
|
|||||||
Pointwise(FloatNear(1e-6), {-1.0, -8.2, -1.0, 0.0}));
|
Pointwise(FloatNear(1e-6), {-1.0, -8.2, -1.0, 0.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ElementwiseTest, Tanh) {
|
TEST(ElementwiseTwoArgumentsTest, SubBroadcast) {
|
||||||
OperationType op_type = OperationType::TANH;
|
OperationType op_type = OperationType::SUB;
|
||||||
const BHWC shape(1, 2, 2, 1);
|
const BHWC shape0(1, 2, 1, 2);
|
||||||
SingleOpModel model({/*type=*/ToString(op_type), /*attributes=*/{}},
|
const BHWC shape1(1, 1, 1, 2);
|
||||||
/*inputs=*/{GetTensorRef(0, shape)},
|
SingleOpModel model(
|
||||||
/*outputs=*/{GetTensorRef(1, shape)});
|
{/*type=*/ToString(op_type), /*attributes=*/{}},
|
||||||
ASSERT_TRUE(model.PopulateTensor(0, {0.0, -6.0, 2.0, 4.0}));
|
/*inputs=*/{GetTensorRef(0, shape0), GetTensorRef(1, shape1)},
|
||||||
|
/*outputs=*/{GetTensorRef(2, shape0)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(1, {0.3, 0.2}));
|
||||||
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
EXPECT_THAT(model.GetOutput(0),
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
Pointwise(FloatNear(1e-6), {0.0, -0.999987, 0.964027, 0.999329}));
|
Pointwise(FloatNear(1e-6), {-0.3, 0.8, 1.7, 2.8}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseTwoArgumentsTest, SubScalar) {
|
||||||
|
OperationType op_type = OperationType::SUB;
|
||||||
|
const BHWC shape0(1, 2, 1, 2);
|
||||||
|
ElementwiseAttributes attr;
|
||||||
|
attr.param = static_cast<float>(0.5);
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), attr},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape0)},
|
||||||
|
/*outputs=*/{GetTensorRef(2, shape0)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {-0.5, 0.5, 1.5, 2.5}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ElementwiseTwoArgumentsTest, SubConstVector) {
|
||||||
|
OperationType op_type = OperationType::SUB;
|
||||||
|
const BHWC shape0(1, 2, 1, 2);
|
||||||
|
|
||||||
|
ElementwiseAttributes attr;
|
||||||
|
Tensor<Linear, DataType::FLOAT32> param;
|
||||||
|
param.shape = Linear(2);
|
||||||
|
param.id = 1;
|
||||||
|
param.data = {0.3, 0.2};
|
||||||
|
attr.param = std::move(param);
|
||||||
|
|
||||||
|
SingleOpModel model({/*type=*/ToString(op_type), attr},
|
||||||
|
/*inputs=*/{GetTensorRef(0, shape0)},
|
||||||
|
/*outputs=*/{GetTensorRef(2, shape0)});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.0, 1.0, 2.0, 3.0}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewElementwiseNodeShader(op_type)));
|
||||||
|
EXPECT_THAT(model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {-0.3, 0.8, 1.7, 2.8}));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
Loading…
Reference in New Issue
Block a user