TFLite GPU: Rename SOFT_MAX and SoftMax to SOFTMAX and Softmax, respectively.
PiperOrigin-RevId: 260970382
commit 39db34c731
parent 49002f2e95
@@ -1737,7 +1737,7 @@ class SoftmaxOperationParser : public TFLiteOperationParser {
                const TfLiteRegistration* registration, GraphFloat32* graph,
                ObjectReader* reader) final {
     Node* node = graph->NewNode();
-    node->operation.type = ToString(OperationType::SOFT_MAX);
+    node->operation.type = ToString(OperationType::SOFTMAX);
     RETURN_IF_ERROR(reader->AddInput(node, 0));
     RETURN_IF_ERROR(reader->AddOutputs(node));
 

@@ -1753,8 +1753,7 @@ class SoftmaxOperationParser : public TFLiteOperationParser {
       // auto mul_node = reader->NewPassthroughNode(node);
       // mul_node->operation.type = ToString(OperationType::MUL);
     }
-    // TODO(impjdi): Rename to SoftmaxAttributes.
-    SoftMaxAttributes attr;
+    SoftmaxAttributes attr;
     attr.axis = Axis::CHANNELS;  // always by channels
     node->operation.attributes = attr;
     return OkStatus();

@@ -106,8 +106,8 @@ std::string ToString(enum OperationType op) {
       return "sin";
     case OperationType::SLICE:
       return "slice";
-    case OperationType::SOFT_MAX:
-      return "soft_max";
+    case OperationType::SOFTMAX:
+      return "softmax";
     case OperationType::SPACE_TO_BATCH:
       return "space_to_batch";
     case OperationType::SQRT:

@@ -158,7 +158,7 @@ OperationType OperationTypeFromString(const std::string& name) {
       {"sigmoid", OperationType::SIGMOID},
       {"sin", OperationType::SIN},
       {"slice", OperationType::SLICE},
-      {"soft_max", OperationType::SOFT_MAX},
+      {"softmax", OperationType::SOFTMAX},
       {"sqrt", OperationType::SQRT},
       {"square", OperationType::SQUARE},
       {"subtract", OperationType::SUB},
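
The two name tables above have to stay in sync for operation names to round-trip. A hypothetical spot check, not part of the commit, assuming the operations header these hunks modify is included:

// Hypothetical round-trip check; ToString and OperationTypeFromString
// are the functions changed in the two hunks above.
#include <cassert>

void CheckSoftmaxNameRoundTrip() {
  assert(ToString(OperationType::SOFTMAX) == "softmax");
  assert(OperationTypeFromString("softmax") == OperationType::SOFTMAX);
}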

@@ -63,7 +63,7 @@ enum class OperationType {
   SIGMOID,
   SIN,
   SLICE,
-  SOFT_MAX,
+  SOFTMAX,
   SPACE_TO_BATCH,
   SQRT,
   SQUARE,

@@ -239,7 +239,7 @@ struct PReLUAttributes {
       alpha;
 };
 
-struct SoftMaxAttributes {
+struct SoftmaxAttributes {
   Axis axis = Axis::UNKNOWN;
 };
 

@@ -86,7 +86,7 @@ class Registry : public NodeShader {
     insert_op(Type::RELU, NewReLUNodeShader);
     insert_op(Type::RESHAPE, NewReshapeNodeShader);
     insert_op(Type::SLICE, NewSliceNodeShader);
-    insert_op(Type::SOFT_MAX, NewSoftMaxNodeShader);
+    insert_op(Type::SOFTMAX, NewSoftmaxNodeShader);
     insert_op(Type::UPSAMPLE_2D, NewUpsamplingNodeShader);
 
     insert_elementwise_op(Type::ABS);

@@ -33,14 +33,14 @@ namespace gpu {
 namespace gl {
 namespace {
 
-class SoftMax : public NodeShader {
+class Softmax : public NodeShader {
  public:
   Status GenerateCode(const GenerationContext& ctx,
                       GeneratedCode* generated_code) const final {
-    auto input = ctx.graph->FindInputs(ctx.node->id)[0];
-    auto output = ctx.graph->FindOutputs(ctx.node->id)[0];
-    auto attr =
-        absl::any_cast<SoftMaxAttributes>(ctx.node->operation.attributes);
+    const auto* input = ctx.graph->FindInputs(ctx.node->id)[0];
+    const auto* output = ctx.graph->FindOutputs(ctx.node->id)[0];
+    const auto& attr = absl::any_cast<const SoftmaxAttributes&>(
+        ctx.node->operation.attributes);
     if (input->tensor.shape != output->tensor.shape) {
       return InvalidArgumentError("Input and output shape does not match");
     }
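
Note the semantics of the cast rewrite above: absl::any_cast<SoftmaxAttributes> by value returns a copy of the stored attributes, while absl::any_cast<const SoftmaxAttributes&> binds a reference to the stored object and avoids the copy. A minimal standalone sketch of the difference, using std::any (absl::any_cast behaves the same way; the struct here is a stand-in):

#include <any>
#include <iostream>

struct Attributes { int axis = 0; };  // stand-in for SoftmaxAttributes

int main() {
  std::any stored = Attributes{2};
  // By-value cast: returns a copy of the stored struct.
  Attributes copy = std::any_cast<Attributes>(stored);
  // Const-reference cast: binds directly to the stored object; no copy.
  const Attributes& ref = std::any_cast<const Attributes&>(stored);
  std::cout << copy.axis << " " << ref.axis << "\n";  // prints: 2 2
}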

@@ -89,8 +89,8 @@ class SoftMax : public NodeShader {
 
 }  // namespace
 
-std::unique_ptr<NodeShader> NewSoftMaxNodeShader() {
-  return absl::make_unique<SoftMax>();
+std::unique_ptr<NodeShader> NewSoftmaxNodeShader() {
+  return absl::make_unique<Softmax>();
 }
 
 }  // namespace gl

@@ -25,7 +25,7 @@ namespace tflite {
 namespace gpu {
 namespace gl {
 
-std::unique_ptr<NodeShader> NewSoftMaxNodeShader();
+std::unique_ptr<NodeShader> NewSoftmaxNodeShader();
 
 }  // namespace gl
 }  // namespace gpu

@@ -31,7 +31,7 @@ namespace gpu {
 namespace gl {
 namespace {
 
-TEST(SoftmaxTest, WorksForChannelsAxis) {
+TEST(SoftmaxTest, Softmax) {
   TensorRef<BHWC> input;
   input.type = DataType::FLOAT32;
   input.ref = 0;

@@ -42,13 +42,13 @@ TEST(SoftmaxTest, WorksForChannelsAxis) {
   output.ref = 1;
   output.shape = BHWC(1, 2, 2, 1);
 
-  SoftMaxAttributes attr;
+  SoftmaxAttributes attr;
   attr.axis = Axis::CHANNELS;
 
-  SingleOpModel model({ToString(OperationType::SOFT_MAX), attr}, {input},
+  SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input},
                       {output});
   ASSERT_TRUE(model.PopulateTensor(0, {0.1, 0.2, 0.1, 0.2}));
-  ASSERT_OK(model.Invoke(*NewSoftMaxNodeShader()));
+  ASSERT_OK(model.Invoke(*NewSoftmaxNodeShader()));
   EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {1, 1, 1, 1}));
 }
 
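
Why the renamed test expects all ones: the tensors have shape BHWC(1, 2, 2, 1), i.e. a single channel per pixel, and softmax over a one-element axis is identically 1:

\[ \mathrm{softmax}(x)_c = \frac{e^{x_c}}{\sum_{c'} e^{x_{c'}}}, \qquad C = 1 \;\Rightarrow\; \mathrm{softmax}(x)_0 = \frac{e^{x_0}}{e^{x_0}} = 1. \]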

@@ -63,15 +63,13 @@ TEST(SoftmaxTest, DoesNotWorkForHeightAxis) {
   output.ref = 1;
   output.shape = BHWC(1, 2, 2, 1);
 
-  SoftMaxAttributes attr;
+  SoftmaxAttributes attr;
   attr.axis = Axis::HEIGHT;
 
-  SingleOpModel model({ToString(OperationType::SOFT_MAX), attr}, {input},
+  SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input},
                       {output});
-  ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
-  ASSERT_THAT(
-      model.Invoke(*NewSoftMaxNodeShader()).message(),
-      testing::HasSubstr("Softmax is only supported for channels axis."));
+  ASSERT_TRUE(model.PopulateTensor(0, {0.1, 0.2, 0.3, 0.4}));
+  EXPECT_FALSE(model.Invoke(*NewSoftmaxNodeShader()).ok());
 }
 
 TEST(SoftmaxTest, DoesNotWorkForWidthAxis) {

@@ -85,15 +83,40 @@ TEST(SoftmaxTest, DoesNotWorkForWidthAxis) {
   output.ref = 1;
   output.shape = BHWC(1, 2, 2, 1);
 
-  SoftMaxAttributes attr;
+  SoftmaxAttributes attr;
   attr.axis = Axis::WIDTH;
 
-  SingleOpModel model({ToString(OperationType::SOFT_MAX), attr}, {input},
+  SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input},
                       {output});
-  ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
-  ASSERT_THAT(
-      model.Invoke(*NewSoftMaxNodeShader()).message(),
-      testing::HasSubstr("Softmax is only supported for channels axis."));
+  ASSERT_TRUE(model.PopulateTensor(0, {0.1, 0.2, 0.3, 0.4}));
+  EXPECT_FALSE(model.Invoke(*NewSoftmaxNodeShader()).ok());
 }
 
+TEST(SoftmaxTest, Softmax1x1) {
+  TensorRef<BHWC> input;
+  input.type = DataType::FLOAT32;
+  input.ref = 0;
+  input.shape = BHWC(1, 1, 1, 4);
+
+  TensorRef<BHWC> output;
+  output.type = DataType::FLOAT32;
+  output.ref = 1;
+  output.shape = BHWC(1, 1, 1, 4);
+
+  SoftmaxAttributes attr;
+  attr.axis = Axis::CHANNELS;
+
+  const double sum =
+      std::exp(0.1) + std::exp(0.2) + std::exp(0.3) + std::exp(0.4);
+
+  SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input},
+                      {output});
+  ASSERT_TRUE(model.PopulateTensor(0, {0.1, 0.2, 0.3, 0.4}));
+  ASSERT_OK(model.Invoke(*NewSoftmaxNodeShader()));
+  EXPECT_THAT(
+      model.GetOutput(0),
+      Pointwise(FloatNear(1e-6), {std::exp(0.1) / sum, std::exp(0.2) / sum,
+                                  std::exp(0.3) / sum, std::exp(0.4) / sum}));
+}
+
 }  // namespace
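
A worked check of the expected values in the new Softmax1x1 test (arithmetic only, not part of the commit): with inputs {0.1, 0.2, 0.3, 0.4},

\[ s = e^{0.1} + e^{0.2} + e^{0.3} + e^{0.4} \approx 5.16826, \qquad \frac{1}{s}\,\{e^{0.1}, e^{0.2}, e^{0.3}, e^{0.4}\} \approx \{0.2138,\ 0.2363,\ 0.2612,\ 0.2887\}, \]

which sums to 1, matching what the Pointwise(FloatNear(1e-6), ...) matcher verifies.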

@@ -220,9 +220,9 @@ Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
         Slice(node_id, inputs[0], outputs[0],
               absl::any_cast<SliceAttributes>(node->operation.attributes));
         break;
-      case OperationType::SOFT_MAX: {
+      case OperationType::SOFTMAX: {
        auto attr =
-            absl::any_cast<SoftMaxAttributes>(node->operation.attributes);
+            absl::any_cast<SoftmaxAttributes>(node->operation.attributes);
        if (attr.axis != Axis::CHANNELS) {
          return UnimplementedError("Softmax supports only CHANNELS dimension");
        }