[tfls.util] Add Category class and conversion from TensorLabel to List<Category>.

PiperOrigin-RevId: 309767075
Change-Id: I02245a1cc066af45260d0faac27cb920956fc71d
This commit is contained in:
Xunkai Zhang 2020-05-04 10:09:41 -07:00 committed by TensorFlower Gardener
parent 1411a67438
commit 1c81ed11a2
2 changed files with 99 additions and 3 deletions
tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/label

View File

@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.support.label;
import java.util.Objects;
/**
* Category is a util class, contains a label and a float value. Typically it's used as result of
* classification tasks.
*/
public final class Category {
private final String label;
private final float score;
/** Constructs a Category. */
public Category(String label, float score) {
this.label = label;
this.score = score;
}
/** Gets the reference of category's label. */
public String getLabel() {
return label;
}
/** Gets the score of the category. */
public float getScore() {
return score;
}
@Override
public boolean equals(Object o) {
if (o instanceof Category) {
Category other = (Category) o;
return (other.getLabel().equals(this.label) && other.getScore() == this.score);
}
return false;
}
@Override
public int hashCode() {
return Objects.hash(label, score);
}
@Override
public String toString() {
return "<Category \"" + label + "\" (score=" + score + ")>";
}
}

View File

@ -17,6 +17,7 @@ package org.tensorflow.lite.support.label;
import android.content.Context;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
@ -150,17 +151,19 @@ public class TensorLabel {
* than 1, and the axis should be effectively the last axis (which means every sub tensor
* specified by this axis should have a flat size of 1).
*
* @throws IllegalArgumentException if size of a sub tensor on each label is not 1.
* <p>{@link TensorLabel#getCategoryList()} is an alternative API to get the result.
*
* @throws IllegalStateException if size of a sub tensor on each label is not 1.
*/
@NonNull
public Map<String, Float> getMapWithFloatValue() {
int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
SupportPreconditions.checkArgument(
SupportPreconditions.checkState(
labeledAxis == shape.length - 1,
"get a <String, Scalar> map is only valid when the only labeled axis is the last one.");
List<String> labels = axisLabels.get(labeledAxis);
float[] data = tensorBuffer.getFloatArray();
SupportPreconditions.checkArgument(labels.size() == data.length);
SupportPreconditions.checkState(labels.size() == data.length);
Map<String, Float> result = new LinkedHashMap<>();
int i = 0;
for (String label : labels) {
@ -170,6 +173,37 @@ public class TensorLabel {
return result;
}
/**
* Gets a list of {@link Category} from the {@link TensorLabel} object.
*
* <p>The axis of label should be effectively the last axis (which means every sub tensor
* specified by this axis should have a flat size of 1), so that each labelled sub tensor could be
* converted into a float value score. Example: A {@link TensorLabel} with shape {@code {2, 5, 3}}
* and axis 2 is valid. If axis is 1 or 0, it cannot be converted into a {@link Category}.
*
* <p>{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as
* the result.
*
* @throws IllegalStateException if size of a sub tensor on each label is not 1.
*/
@NonNull
public List<Category> getCategoryList() {
int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer);
SupportPreconditions.checkState(
labeledAxis == shape.length - 1,
"get a Category list is only valid when the only labeled axis is the last one.");
List<String> labels = axisLabels.get(labeledAxis);
float[] data = tensorBuffer.getFloatArray();
SupportPreconditions.checkState(labels.size() == data.length);
List<Category> result = new ArrayList<>();
int i = 0;
for (String label : labels) {
result.add(new Category(label, data[i]));
i += 1;
}
return result;
}
private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) {
int[] shape = tensorBuffer.getShape();
for (int i = 0; i < shape.length; i++) {