TFLite GPU: Rename SOFT_MAX and SoftMax to SOFTMAX and Softmax, respectively.
PiperOrigin-RevId: 260970382
This commit is contained in:
parent
49002f2e95
commit
39db34c731
@ -1737,7 +1737,7 @@ class SoftmaxOperationParser : public TFLiteOperationParser {
|
|||||||
const TfLiteRegistration* registration, GraphFloat32* graph,
|
const TfLiteRegistration* registration, GraphFloat32* graph,
|
||||||
ObjectReader* reader) final {
|
ObjectReader* reader) final {
|
||||||
Node* node = graph->NewNode();
|
Node* node = graph->NewNode();
|
||||||
node->operation.type = ToString(OperationType::SOFT_MAX);
|
node->operation.type = ToString(OperationType::SOFTMAX);
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
RETURN_IF_ERROR(reader->AddInput(node, 0));
|
||||||
RETURN_IF_ERROR(reader->AddOutputs(node));
|
RETURN_IF_ERROR(reader->AddOutputs(node));
|
||||||
|
|
||||||
@ -1753,8 +1753,7 @@ class SoftmaxOperationParser : public TFLiteOperationParser {
|
|||||||
// auto mul_node = reader->NewPassthroughNode(node);
|
// auto mul_node = reader->NewPassthroughNode(node);
|
||||||
// mul_node->operation.type = ToString(OperationType::MUL);
|
// mul_node->operation.type = ToString(OperationType::MUL);
|
||||||
}
|
}
|
||||||
// TODO(impjdi): Rename to SoftmaxAttributes.
|
SoftmaxAttributes attr;
|
||||||
SoftMaxAttributes attr;
|
|
||||||
attr.axis = Axis::CHANNELS; // always by channels
|
attr.axis = Axis::CHANNELS; // always by channels
|
||||||
node->operation.attributes = attr;
|
node->operation.attributes = attr;
|
||||||
return OkStatus();
|
return OkStatus();
|
||||||
|
@ -106,8 +106,8 @@ std::string ToString(enum OperationType op) {
|
|||||||
return "sin";
|
return "sin";
|
||||||
case OperationType::SLICE:
|
case OperationType::SLICE:
|
||||||
return "slice";
|
return "slice";
|
||||||
case OperationType::SOFT_MAX:
|
case OperationType::SOFTMAX:
|
||||||
return "soft_max";
|
return "softmax";
|
||||||
case OperationType::SPACE_TO_BATCH:
|
case OperationType::SPACE_TO_BATCH:
|
||||||
return "space_to_batch";
|
return "space_to_batch";
|
||||||
case OperationType::SQRT:
|
case OperationType::SQRT:
|
||||||
@ -158,7 +158,7 @@ OperationType OperationTypeFromString(const std::string& name) {
|
|||||||
{"sigmoid", OperationType::SIGMOID},
|
{"sigmoid", OperationType::SIGMOID},
|
||||||
{"sin", OperationType::SIN},
|
{"sin", OperationType::SIN},
|
||||||
{"slice", OperationType::SLICE},
|
{"slice", OperationType::SLICE},
|
||||||
{"soft_max", OperationType::SOFT_MAX},
|
{"softmax", OperationType::SOFTMAX},
|
||||||
{"sqrt", OperationType::SQRT},
|
{"sqrt", OperationType::SQRT},
|
||||||
{"square", OperationType::SQUARE},
|
{"square", OperationType::SQUARE},
|
||||||
{"subtract", OperationType::SUB},
|
{"subtract", OperationType::SUB},
|
||||||
|
@ -63,7 +63,7 @@ enum class OperationType {
|
|||||||
SIGMOID,
|
SIGMOID,
|
||||||
SIN,
|
SIN,
|
||||||
SLICE,
|
SLICE,
|
||||||
SOFT_MAX,
|
SOFTMAX,
|
||||||
SPACE_TO_BATCH,
|
SPACE_TO_BATCH,
|
||||||
SQRT,
|
SQRT,
|
||||||
SQUARE,
|
SQUARE,
|
||||||
@ -239,7 +239,7 @@ struct PReLUAttributes {
|
|||||||
alpha;
|
alpha;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct SoftMaxAttributes {
|
struct SoftmaxAttributes {
|
||||||
Axis axis = Axis::UNKNOWN;
|
Axis axis = Axis::UNKNOWN;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ class Registry : public NodeShader {
|
|||||||
insert_op(Type::RELU, NewReLUNodeShader);
|
insert_op(Type::RELU, NewReLUNodeShader);
|
||||||
insert_op(Type::RESHAPE, NewReshapeNodeShader);
|
insert_op(Type::RESHAPE, NewReshapeNodeShader);
|
||||||
insert_op(Type::SLICE, NewSliceNodeShader);
|
insert_op(Type::SLICE, NewSliceNodeShader);
|
||||||
insert_op(Type::SOFT_MAX, NewSoftMaxNodeShader);
|
insert_op(Type::SOFTMAX, NewSoftmaxNodeShader);
|
||||||
insert_op(Type::UPSAMPLE_2D, NewUpsamplingNodeShader);
|
insert_op(Type::UPSAMPLE_2D, NewUpsamplingNodeShader);
|
||||||
|
|
||||||
insert_elementwise_op(Type::ABS);
|
insert_elementwise_op(Type::ABS);
|
||||||
|
@ -33,14 +33,14 @@ namespace gpu {
|
|||||||
namespace gl {
|
namespace gl {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
class SoftMax : public NodeShader {
|
class Softmax : public NodeShader {
|
||||||
public:
|
public:
|
||||||
Status GenerateCode(const GenerationContext& ctx,
|
Status GenerateCode(const GenerationContext& ctx,
|
||||||
GeneratedCode* generated_code) const final {
|
GeneratedCode* generated_code) const final {
|
||||||
auto input = ctx.graph->FindInputs(ctx.node->id)[0];
|
const auto* input = ctx.graph->FindInputs(ctx.node->id)[0];
|
||||||
auto output = ctx.graph->FindOutputs(ctx.node->id)[0];
|
const auto* output = ctx.graph->FindOutputs(ctx.node->id)[0];
|
||||||
auto attr =
|
const auto& attr = absl::any_cast<const SoftmaxAttributes&>(
|
||||||
absl::any_cast<SoftMaxAttributes>(ctx.node->operation.attributes);
|
ctx.node->operation.attributes);
|
||||||
if (input->tensor.shape != output->tensor.shape) {
|
if (input->tensor.shape != output->tensor.shape) {
|
||||||
return InvalidArgumentError("Input and output shape does not match");
|
return InvalidArgumentError("Input and output shape does not match");
|
||||||
}
|
}
|
||||||
@ -89,8 +89,8 @@ class SoftMax : public NodeShader {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<NodeShader> NewSoftMaxNodeShader() {
|
std::unique_ptr<NodeShader> NewSoftmaxNodeShader() {
|
||||||
return absl::make_unique<SoftMax>();
|
return absl::make_unique<Softmax>();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gl
|
} // namespace gl
|
||||||
|
@ -25,7 +25,7 @@ namespace tflite {
|
|||||||
namespace gpu {
|
namespace gpu {
|
||||||
namespace gl {
|
namespace gl {
|
||||||
|
|
||||||
std::unique_ptr<NodeShader> NewSoftMaxNodeShader();
|
std::unique_ptr<NodeShader> NewSoftmaxNodeShader();
|
||||||
|
|
||||||
} // namespace gl
|
} // namespace gl
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
@ -31,7 +31,7 @@ namespace gpu {
|
|||||||
namespace gl {
|
namespace gl {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
TEST(SoftmaxTest, WorksForChannelsAxis) {
|
TEST(SoftmaxTest, Softmax) {
|
||||||
TensorRef<BHWC> input;
|
TensorRef<BHWC> input;
|
||||||
input.type = DataType::FLOAT32;
|
input.type = DataType::FLOAT32;
|
||||||
input.ref = 0;
|
input.ref = 0;
|
||||||
@ -42,13 +42,13 @@ TEST(SoftmaxTest, WorksForChannelsAxis) {
|
|||||||
output.ref = 1;
|
output.ref = 1;
|
||||||
output.shape = BHWC(1, 2, 2, 1);
|
output.shape = BHWC(1, 2, 2, 1);
|
||||||
|
|
||||||
SoftMaxAttributes attr;
|
SoftmaxAttributes attr;
|
||||||
attr.axis = Axis::CHANNELS;
|
attr.axis = Axis::CHANNELS;
|
||||||
|
|
||||||
SingleOpModel model({ToString(OperationType::SOFT_MAX), attr}, {input},
|
SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input},
|
||||||
{output});
|
{output});
|
||||||
ASSERT_TRUE(model.PopulateTensor(0, {0.1, 0.2, 0.1, 0.2}));
|
ASSERT_TRUE(model.PopulateTensor(0, {0.1, 0.2, 0.1, 0.2}));
|
||||||
ASSERT_OK(model.Invoke(*NewSoftMaxNodeShader()));
|
ASSERT_OK(model.Invoke(*NewSoftmaxNodeShader()));
|
||||||
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {1, 1, 1, 1}));
|
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {1, 1, 1, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -63,15 +63,13 @@ TEST(SoftmaxTest, DoesNotWorkForHeightAxis) {
|
|||||||
output.ref = 1;
|
output.ref = 1;
|
||||||
output.shape = BHWC(1, 2, 2, 1);
|
output.shape = BHWC(1, 2, 2, 1);
|
||||||
|
|
||||||
SoftMaxAttributes attr;
|
SoftmaxAttributes attr;
|
||||||
attr.axis = Axis::HEIGHT;
|
attr.axis = Axis::HEIGHT;
|
||||||
|
|
||||||
SingleOpModel model({ToString(OperationType::SOFT_MAX), attr}, {input},
|
SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input},
|
||||||
{output});
|
{output});
|
||||||
ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
|
ASSERT_TRUE(model.PopulateTensor(0, {0.1, 0.2, 0.3, 0.4}));
|
||||||
ASSERT_THAT(
|
EXPECT_FALSE(model.Invoke(*NewSoftmaxNodeShader()).ok());
|
||||||
model.Invoke(*NewSoftMaxNodeShader()).message(),
|
|
||||||
testing::HasSubstr("Softmax is only supported for channels axis."));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(SoftmaxTest, DoesNotWorkForWidthAxis) {
|
TEST(SoftmaxTest, DoesNotWorkForWidthAxis) {
|
||||||
@ -85,15 +83,40 @@ TEST(SoftmaxTest, DoesNotWorkForWidthAxis) {
|
|||||||
output.ref = 1;
|
output.ref = 1;
|
||||||
output.shape = BHWC(1, 2, 2, 1);
|
output.shape = BHWC(1, 2, 2, 1);
|
||||||
|
|
||||||
SoftMaxAttributes attr;
|
SoftmaxAttributes attr;
|
||||||
attr.axis = Axis::WIDTH;
|
attr.axis = Axis::WIDTH;
|
||||||
|
|
||||||
SingleOpModel model({ToString(OperationType::SOFT_MAX), attr}, {input},
|
SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input},
|
||||||
{output});
|
{output});
|
||||||
ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
|
ASSERT_TRUE(model.PopulateTensor(0, {0.1, 0.2, 0.3, 0.4}));
|
||||||
ASSERT_THAT(
|
EXPECT_FALSE(model.Invoke(*NewSoftmaxNodeShader()).ok());
|
||||||
model.Invoke(*NewSoftMaxNodeShader()).message(),
|
}
|
||||||
testing::HasSubstr("Softmax is only supported for channels axis."));
|
|
||||||
|
TEST(SoftmaxTest, Softmax1x1) {
|
||||||
|
TensorRef<BHWC> input;
|
||||||
|
input.type = DataType::FLOAT32;
|
||||||
|
input.ref = 0;
|
||||||
|
input.shape = BHWC(1, 1, 1, 4);
|
||||||
|
|
||||||
|
TensorRef<BHWC> output;
|
||||||
|
output.type = DataType::FLOAT32;
|
||||||
|
output.ref = 1;
|
||||||
|
output.shape = BHWC(1, 1, 1, 4);
|
||||||
|
|
||||||
|
SoftmaxAttributes attr;
|
||||||
|
attr.axis = Axis::CHANNELS;
|
||||||
|
|
||||||
|
const double sum =
|
||||||
|
std::exp(0.1) + std::exp(0.2) + std::exp(0.3) + std::exp(0.4);
|
||||||
|
|
||||||
|
SingleOpModel model({ToString(OperationType::SOFTMAX), attr}, {input},
|
||||||
|
{output});
|
||||||
|
ASSERT_TRUE(model.PopulateTensor(0, {0.1, 0.2, 0.3, 0.4}));
|
||||||
|
ASSERT_OK(model.Invoke(*NewSoftmaxNodeShader()));
|
||||||
|
EXPECT_THAT(
|
||||||
|
model.GetOutput(0),
|
||||||
|
Pointwise(FloatNear(1e-6), {std::exp(0.1) / sum, std::exp(0.2) / sum,
|
||||||
|
std::exp(0.3) / sum, std::exp(0.4) / sum}));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -220,9 +220,9 @@ Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
|
|||||||
Slice(node_id, inputs[0], outputs[0],
|
Slice(node_id, inputs[0], outputs[0],
|
||||||
absl::any_cast<SliceAttributes>(node->operation.attributes));
|
absl::any_cast<SliceAttributes>(node->operation.attributes));
|
||||||
break;
|
break;
|
||||||
case OperationType::SOFT_MAX: {
|
case OperationType::SOFTMAX: {
|
||||||
auto attr =
|
auto attr =
|
||||||
absl::any_cast<SoftMaxAttributes>(node->operation.attributes);
|
absl::any_cast<SoftmaxAttributes>(node->operation.attributes);
|
||||||
if (attr.axis != Axis::CHANNELS) {
|
if (attr.axis != Axis::CHANNELS) {
|
||||||
return UnimplementedError("Softmax supports only CHANNELS dimension");
|
return UnimplementedError("Softmax supports only CHANNELS dimension");
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user