Support a broadcast case of elementwise op and fix broadcast case of add op for Android.
PiperOrigin-RevId: 290131982 Change-Id: I197c766556a66b7d31ab20125667396fb7bab5ad
This commit is contained in:
parent
2aa9c418da
commit
f3daa69cbc
@ -47,6 +47,7 @@ class Add : public NodeShader {
|
||||
inputs[0]->tensor.shape != inputs[1]->tensor.shape &&
|
||||
inputs[1]->tensor.shape.h == 1 && inputs[1]->tensor.shape.w == 1 &&
|
||||
inputs[0]->tensor.shape.c == inputs[1]->tensor.shape.c) {
|
||||
// TODO(b/147771327): investigate why input_data_1[gid.z] worked before
|
||||
*generated_code = {
|
||||
/*parameters=*/{},
|
||||
/*objects=*/{},
|
||||
@ -54,8 +55,8 @@ class Add : public NodeShader {
|
||||
/*workload=*/uint3(),
|
||||
/*workgroup=*/uint3(),
|
||||
/*source_code=*/
|
||||
"value_0 = $input_data_1[gid.z]$ + $input_data_0[gid.x, gid.y, "
|
||||
"gid.z]$;",
|
||||
"value_0 = $input_data_0[gid.x, gid.y, gid.z]$ + "
|
||||
" $input_data_1[gid.z]$;",
|
||||
/*input=*/IOStructure::ONLY_DEFINITIONS,
|
||||
/*output=*/IOStructure::AUTO,
|
||||
};
|
||||
|
@ -108,7 +108,8 @@ class ElementwiseTwoArguments : public NodeShader {
|
||||
public:
|
||||
explicit ElementwiseTwoArguments(OperationType operation_type)
|
||||
: operation_type_(operation_type) {}
|
||||
static bool IsSupported(const GenerationContext& ctx) {
|
||||
|
||||
bool IsSupportedElemwise(const GenerationContext& ctx) const {
|
||||
auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||
|
||||
// Implementation supports concatenation of 2 tensors only.
|
||||
@ -123,16 +124,11 @@ class ElementwiseTwoArguments : public NodeShader {
|
||||
if (shape0 != shape1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Status GenerateCode(const GenerationContext& ctx,
|
||||
GeneratedCode* generated_code) const final {
|
||||
if (!IsSupported(ctx)) {
|
||||
return InvalidArgumentError(
|
||||
"This case is not supported by subtract operation");
|
||||
}
|
||||
Status ImplementElementwise(const GenerationContext& ctx,
|
||||
GeneratedCode* generated_code) const {
|
||||
std::string source;
|
||||
switch (operation_type_) {
|
||||
case OperationType::SUB: {
|
||||
@ -171,6 +167,62 @@ class ElementwiseTwoArguments : public NodeShader {
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
bool IsSupportedBroadcast(const GenerationContext& ctx) const {
|
||||
auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||
auto outputs = ctx.graph->FindOutputs(ctx.node->id);
|
||||
|
||||
if (inputs.size() != 2) {
|
||||
return false;
|
||||
}
|
||||
if (inputs[1]->tensor.shape.h != 1 || inputs[1]->tensor.shape.w != 1 ||
|
||||
inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Status ImplementElementwiseBroadcast(const GenerationContext& ctx,
|
||||
GeneratedCode* generated_code) const {
|
||||
std::string source;
|
||||
switch (operation_type_) {
|
||||
case OperationType::SQUARED_DIFF: {
|
||||
source = R"(
|
||||
vec4 diff = $input_data_0[gid.x, gid.y, gid.z]$ -
|
||||
$input_data_1[0, 0, gid.z]$;
|
||||
value_0 = diff * diff;
|
||||
)";
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
return InvalidArgumentError(
|
||||
"Incorrect elementwise with two arguments operation type.");
|
||||
}
|
||||
*generated_code = {
|
||||
/*parameters=*/{},
|
||||
/*objects=*/{},
|
||||
/*shared_variables=*/{},
|
||||
/*workload=*/uint3(),
|
||||
/*workgroup=*/uint3(),
|
||||
/*source_code=*/source,
|
||||
/*input=*/IOStructure::ONLY_DEFINITIONS,
|
||||
/*output=*/IOStructure::AUTO,
|
||||
};
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
Status GenerateCode(const GenerationContext& ctx,
|
||||
GeneratedCode* generated_code) const final {
|
||||
if (IsSupportedElemwise(ctx)) {
|
||||
return ImplementElementwise(ctx, generated_code);
|
||||
}
|
||||
if (IsSupportedBroadcast(ctx)) {
|
||||
return ImplementElementwiseBroadcast(ctx, generated_code);
|
||||
}
|
||||
return InvalidArgumentError(
|
||||
"This case is not supported by subtract operation");
|
||||
}
|
||||
|
||||
private:
|
||||
OperationType operation_type_;
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user