Support a broadcast case of elementwise op and fix broadcast case of add op for Android.

PiperOrigin-RevId: 290131982
Change-Id: I197c766556a66b7d31ab20125667396fb7bab5ad
This commit is contained in:
A. Unique TensorFlower 2020-01-16 13:25:28 -08:00 committed by TensorFlower Gardener
parent 2aa9c418da
commit f3daa69cbc
2 changed files with 63 additions and 10 deletions

View File

@ -47,6 +47,7 @@ class Add : public NodeShader {
inputs[0]->tensor.shape != inputs[1]->tensor.shape &&
inputs[1]->tensor.shape.h == 1 && inputs[1]->tensor.shape.w == 1 &&
inputs[0]->tensor.shape.c == inputs[1]->tensor.shape.c) {
// TODO(b/147771327): investigate why input_data_1[gid.z] worked before
*generated_code = {
/*parameters=*/{},
/*objects=*/{},
@ -54,8 +55,8 @@ class Add : public NodeShader {
/*workload=*/uint3(),
/*workgroup=*/uint3(),
/*source_code=*/
"value_0 = $input_data_1[gid.z]$ + $input_data_0[gid.x, gid.y, "
"gid.z]$;",
"value_0 = $input_data_0[gid.x, gid.y, gid.z]$ + "
" $input_data_1[gid.z]$;",
/*input=*/IOStructure::ONLY_DEFINITIONS,
/*output=*/IOStructure::AUTO,
};

View File

@ -108,7 +108,8 @@ class ElementwiseTwoArguments : public NodeShader {
public:
explicit ElementwiseTwoArguments(OperationType operation_type)
: operation_type_(operation_type) {}
static bool IsSupported(const GenerationContext& ctx) {
bool IsSupportedElemwise(const GenerationContext& ctx) const {
auto inputs = ctx.graph->FindInputs(ctx.node->id);
// Implementation supports concatenation of 2 tensors only.
@ -123,16 +124,11 @@ class ElementwiseTwoArguments : public NodeShader {
if (shape0 != shape1) {
return false;
}
return true;
}
Status GenerateCode(const GenerationContext& ctx,
GeneratedCode* generated_code) const final {
if (!IsSupported(ctx)) {
return InvalidArgumentError(
"This case is not supported by subtract operation");
}
Status ImplementElementwise(const GenerationContext& ctx,
GeneratedCode* generated_code) const {
std::string source;
switch (operation_type_) {
case OperationType::SUB: {
@ -171,6 +167,62 @@ class ElementwiseTwoArguments : public NodeShader {
return OkStatus();
}
bool IsSupportedBroadcast(const GenerationContext& ctx) const {
auto inputs = ctx.graph->FindInputs(ctx.node->id);
auto outputs = ctx.graph->FindOutputs(ctx.node->id);
if (inputs.size() != 2) {
return false;
}
if (inputs[1]->tensor.shape.h != 1 || inputs[1]->tensor.shape.w != 1 ||
inputs[0]->tensor.shape.c != inputs[1]->tensor.shape.c) {
return false;
}
return true;
}
Status ImplementElementwiseBroadcast(const GenerationContext& ctx,
GeneratedCode* generated_code) const {
std::string source;
switch (operation_type_) {
case OperationType::SQUARED_DIFF: {
source = R"(
vec4 diff = $input_data_0[gid.x, gid.y, gid.z]$ -
$input_data_1[0, 0, gid.z]$;
value_0 = diff * diff;
)";
break;
}
default:
return InvalidArgumentError(
"Incorrect elementwise with two arguments operation type.");
}
*generated_code = {
/*parameters=*/{},
/*objects=*/{},
/*shared_variables=*/{},
/*workload=*/uint3(),
/*workgroup=*/uint3(),
/*source_code=*/source,
/*input=*/IOStructure::ONLY_DEFINITIONS,
/*output=*/IOStructure::AUTO,
};
return OkStatus();
}
Status GenerateCode(const GenerationContext& ctx,
GeneratedCode* generated_code) const final {
if (IsSupportedElemwise(ctx)) {
return ImplementElementwise(ctx, generated_code);
}
if (IsSupportedBroadcast(ctx)) {
return ImplementElementwiseBroadcast(ctx, generated_code);
}
return InvalidArgumentError(
"This case is not supported by subtract operation");
}
private:
OperationType operation_type_;
};