Support a broadcast case of elementwise op and fix broadcast case of add op for Android.
PiperOrigin-RevId: 290131982 Change-Id: I197c766556a66b7d31ab20125667396fb7bab5ad
This commit is contained in:
parent
2aa9c418da
commit
f3daa69cbc
@ -47,6 +47,7 @@ class Add : public NodeShader {
|
|||||||
inputs[0]->tensor.shape != inputs[1]->tensor.shape &&
|
inputs[0]->tensor.shape != inputs[1]->tensor.shape &&
|
||||||
inputs[1]->tensor.shape.h == 1 && inputs[1]->tensor.shape.w == 1 &&
|
inputs[1]->tensor.shape.h == 1 && inputs[1]->tensor.shape.w == 1 &&
|
||||||
inputs[0]->tensor.shape.c == inputs[1]->tensor.shape.c) {
|
inputs[0]->tensor.shape.c == inputs[1]->tensor.shape.c) {
|
||||||
|
// TODO(b/147771327): investigate why input_data_1[gid.z] worked before
|
||||||
*generated_code = {
|
*generated_code = {
|
||||||
/*parameters=*/{},
|
/*parameters=*/{},
|
||||||
/*objects=*/{},
|
/*objects=*/{},
|
||||||
@ -54,8 +55,8 @@ class Add : public NodeShader {
|
|||||||
/*workload=*/uint3(),
|
/*workload=*/uint3(),
|
||||||
/*workgroup=*/uint3(),
|
/*workgroup=*/uint3(),
|
||||||
/*source_code=*/
|
/*source_code=*/
|
||||||
"value_0 = $input_data_1[gid.z]$ + $input_data_0[gid.x, gid.y, "
|
"value_0 = $input_data_0[gid.x, gid.y, gid.z]$ + "
|
||||||
"gid.z]$;",
|
" $input_data_1[gid.z]$;",
|
||||||
/*input=*/IOStructure::ONLY_DEFINITIONS,
|
/*input=*/IOStructure::ONLY_DEFINITIONS,
|
||||||
/*output=*/IOStructure::AUTO,
|
/*output=*/IOStructure::AUTO,
|
||||||
};
|
};
|
||||||
|
@ -108,7 +108,8 @@ class ElementwiseTwoArguments : public NodeShader {
|
|||||||
public:
|
public:
|
||||||
explicit ElementwiseTwoArguments(OperationType operation_type)
|
explicit ElementwiseTwoArguments(OperationType operation_type)
|
||||||
: operation_type_(operation_type) {}
|
: operation_type_(operation_type) {}
|
||||||
static bool IsSupported(const GenerationContext& ctx) {
|
|
||||||
|
bool IsSupportedElemwise(const GenerationContext& ctx) const {
|
||||||
auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||||
|
|
||||||
// Implementation supports concatenation of 2 tensors only.
|
// Implementation supports concatenation of 2 tensors only.
|
||||||
@ -123,16 +124,11 @@ class ElementwiseTwoArguments : public NodeShader {
|
|||||||
if (shape0 != shape1) {
|
if (shape0 != shape1) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GenerateCode(const GenerationContext& ctx,
|
Status ImplementElementwise(const GenerationContext& ctx,
|
||||||
GeneratedCode* generated_code) const final {
|
GeneratedCode* generated_code) const {
|
||||||
if (!IsSupported(ctx)) {
|
|
||||||
return InvalidArgumentError(
|
|
||||||
"This case is not supported by subtract operation");
|
|
||||||
}
|
|
||||||
std::string source;
|
std::string source;
|
||||||
switch (operation_type_) {
|
switch (operation_type_) {
|
||||||
case OperationType::SUB: {
|
case OperationType::SUB: {
|
||||||
@ -171,6 +167,62 @@ class ElementwiseTwoArguments : public NodeShader {
|
|||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsSupportedBroadcast(const GenerationContext& ctx) const {
|
||||||
|
auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||||
|
auto outputs = ctx.graph->FindOutputs(ctx.node->id);
|
||||||
|
|
||||||
|
if (inputs.size() != 2) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (inputs[1]->tensor.shape.h != 1 || inputs[1]->tensor.shape.w != 1 ||
|
||||||
|
inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ImplementElementwiseBroadcast(const GenerationContext& ctx,
|
||||||
|
GeneratedCode* generated_code) const {
|
||||||
|
std::string source;
|
||||||
|
switch (operation_type_) {
|
||||||
|
case OperationType::SQUARED_DIFF: {
|
||||||
|
source = R"(
|
||||||
|
vec4 diff = $input_data_0[gid.x, gid.y, gid.z]$ -
|
||||||
|
$input_data_1[0, 0, gid.z]$;
|
||||||
|
value_0 = diff * diff;
|
||||||
|
)";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return InvalidArgumentError(
|
||||||
|
"Incorrect elementwise with two arguments operation type.");
|
||||||
|
}
|
||||||
|
*generated_code = {
|
||||||
|
/*parameters=*/{},
|
||||||
|
/*objects=*/{},
|
||||||
|
/*shared_variables=*/{},
|
||||||
|
/*workload=*/uint3(),
|
||||||
|
/*workgroup=*/uint3(),
|
||||||
|
/*source_code=*/source,
|
||||||
|
/*input=*/IOStructure::ONLY_DEFINITIONS,
|
||||||
|
/*output=*/IOStructure::AUTO,
|
||||||
|
};
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GenerateCode(const GenerationContext& ctx,
|
||||||
|
GeneratedCode* generated_code) const final {
|
||||||
|
if (IsSupportedElemwise(ctx)) {
|
||||||
|
return ImplementElementwise(ctx, generated_code);
|
||||||
|
}
|
||||||
|
if (IsSupportedBroadcast(ctx)) {
|
||||||
|
return ImplementElementwiseBroadcast(ctx, generated_code);
|
||||||
|
}
|
||||||
|
return InvalidArgumentError(
|
||||||
|
"This case is not supported by subtract operation");
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OperationType operation_type_;
|
OperationType operation_type_;
|
||||||
};
|
};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user