[tfls.util] Add Category class and conversion from TensorLabel to List<Category>.
PiperOrigin-RevId: 309767075 Change-Id: I02245a1cc066af45260d0faac27cb920956fc71d
This commit is contained in:
parent
1411a67438
commit
1c81ed11a2
tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/label
@ -0,0 +1,62 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.support.label;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Category is a util class, contains a label and a float value. Typically it's used as result of
|
||||
* classification tasks.
|
||||
*/
|
||||
public final class Category {
|
||||
private final String label;
|
||||
private final float score;
|
||||
|
||||
/** Constructs a Category. */
|
||||
public Category(String label, float score) {
|
||||
this.label = label;
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
/** Gets the reference of category's label. */
|
||||
public String getLabel() {
|
||||
return label;
|
||||
}
|
||||
|
||||
/** Gets the score of the category. */
|
||||
public float getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (o instanceof Category) {
|
||||
Category other = (Category) o;
|
||||
return (other.getLabel().equals(this.label) && other.getScore() == this.score);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(label, score);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "<Category \"" + label + "\" (score=" + score + ")>";
|
||||
}
|
||||
}
|
@ -17,6 +17,7 @@ package org.tensorflow.lite.support.label;
|
||||
|
||||
import android.content.Context;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
@ -150,17 +151,19 @@ public class TensorLabel {
|
||||
* than 1, and the axis should be effectively the last axis (which means every sub tensor
|
||||
* specified by this axis should have a flat size of 1).
|
||||
*
|
||||
* @throws IllegalArgumentException if size of a sub tensor on each label is not 1.
|
||||
* <p>{@link TensorLabel#getCategoryList()} is an alternative API to get the result.
|
||||
*
|
||||
* @throws IllegalStateException if size of a sub tensor on each label is not 1.
|
||||
*/
|
||||
@NonNull
|
||||
public Map<String, Float> getMapWithFloatValue() {
|
||||
int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
|
||||
SupportPreconditions.checkArgument(
|
||||
SupportPreconditions.checkState(
|
||||
labeledAxis == shape.length - 1,
|
||||
"get a <String, Scalar> map is only valid when the only labeled axis is the last one.");
|
||||
List<String> labels = axisLabels.get(labeledAxis);
|
||||
float[] data = tensorBuffer.getFloatArray();
|
||||
SupportPreconditions.checkArgument(labels.size() == data.length);
|
||||
SupportPreconditions.checkState(labels.size() == data.length);
|
||||
Map<String, Float> result = new LinkedHashMap<>();
|
||||
int i = 0;
|
||||
for (String label : labels) {
|
||||
@ -170,6 +173,37 @@ public class TensorLabel {
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a list of {@link Category} from the {@link TensorLabel} object.
|
||||
*
|
||||
* <p>The axis of label should be effectively the last axis (which means every sub tensor
|
||||
* specified by this axis should have a flat size of 1), so that each labelled sub tensor could be
|
||||
* converted into a float value score. Example: A {@link TensorLabel} with shape {@code {2, 5, 3}}
|
||||
* and axis 2 is valid. If axis is 1 or 0, it cannot be converted into a {@link Category}.
|
||||
*
|
||||
* <p>{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as
|
||||
* the result.
|
||||
*
|
||||
* @throws IllegalStateException if size of a sub tensor on each label is not 1.
|
||||
*/
|
||||
@NonNull
|
||||
public List<Category> getCategoryList() {
|
||||
int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
|
||||
SupportPreconditions.checkState(
|
||||
labeledAxis == shape.length - 1,
|
||||
"get a Category list is only valid when the only labeled axis is the last one.");
|
||||
List<String> labels = axisLabels.get(labeledAxis);
|
||||
float[] data = tensorBuffer.getFloatArray();
|
||||
SupportPreconditions.checkState(labels.size() == data.length);
|
||||
List<Category> result = new ArrayList<>();
|
||||
int i = 0;
|
||||
for (String label : labels) {
|
||||
result.add(new Category(label, data[i]));
|
||||
i += 1;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) {
|
||||
int[] shape = tensorBuffer.getShape();
|
||||
for (int i = 0; i < shape.length; i++) {
|
||||
|
Loading…
Reference in New Issue
Block a user