From 1c81ed11a2d81a550d1c4161aca6ab155156805a Mon Sep 17 00:00:00 2001 From: Xunkai Zhang Date: Mon, 4 May 2020 10:09:41 -0700 Subject: [PATCH] [tfls.util] Add Category class and conversion from TensorLabel to List. PiperOrigin-RevId: 309767075 Change-Id: I02245a1cc066af45260d0faac27cb920956fc71d --- .../lite/support/label/Category.java | 62 +++++++++++++++++++ .../lite/support/label/TensorLabel.java | 40 +++++++++++- 2 files changed, 99 insertions(+), 3 deletions(-) create mode 100644 tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/label/Category.java diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/label/Category.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/label/Category.java new file mode 100644 index 00000000000..ea369c3ac12 --- /dev/null +++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/label/Category.java @@ -0,0 +1,62 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.label; + +import java.util.Objects; + +/** + * Category is a util class, contains a label and a float value. Typically it's used as result of + * classification tasks. + */ +public final class Category { + private final String label; + private final float score; + + /** Constructs a Category. */ + public Category(String label, float score) { + this.label = label; + this.score = score; + } + + /** Gets the reference of category's label. */ + public String getLabel() { + return label; + } + + /** Gets the score of the category. */ + public float getScore() { + return score; + } + + @Override + public boolean equals(Object o) { + if (o instanceof Category) { + Category other = (Category) o; + return (other.getLabel().equals(this.label) && other.getScore() == this.score); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hash(label, score); + } + + @Override + public String toString() { + return ""; + } +} diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java index 8c27995c0f7..10763a1a065 100644 --- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java +++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/label/TensorLabel.java @@ -17,6 +17,7 @@ package org.tensorflow.lite.support.label; import android.content.Context; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; @@ -150,17 +151,19 @@ public class TensorLabel { * than 1, and the axis should be effectively the last axis (which means every sub tensor * specified by this axis should have a flat size of 1). * - * @throws IllegalArgumentException if size of a sub tensor on each label is not 1. + *

{@link TensorLabel#getCategoryList()} is an alternative API to get the result. + * + * @throws IllegalStateException if size of a sub tensor on each label is not 1. */ @NonNull public Map getMapWithFloatValue() { int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer); - SupportPreconditions.checkArgument( + SupportPreconditions.checkState( labeledAxis == shape.length - 1, "get a map is only valid when the only labeled axis is the last one."); List labels = axisLabels.get(labeledAxis); float[] data = tensorBuffer.getFloatArray(); - SupportPreconditions.checkArgument(labels.size() == data.length); + SupportPreconditions.checkState(labels.size() == data.length); Map result = new LinkedHashMap<>(); int i = 0; for (String label : labels) { @@ -170,6 +173,37 @@ public class TensorLabel { return result; } + /** + * Gets a list of {@link Category} from the {@link TensorLabel} object. + * + *

The axis of label should be effectively the last axis (which means every sub tensor + * specified by this axis should have a flat size of 1), so that each labelled sub tensor could be + * converted into a float value score. Example: A {@link TensorLabel} with shape {@code {2, 5, 3}} + * and axis 2 is valid. If axis is 1 or 0, it cannot be converted into a {@link Category}. + * + *

{@link TensorLabel#getMapWithFloatValue()} is an alternative but returns a {@link Map} as + * the result. + * + * @throws IllegalStateException if size of a sub tensor on each label is not 1. + */ + @NonNull + public List getCategoryList() { + int labeledAxis = getFirstAxisWithSizeGreaterThanOne(tensorBuffer); + SupportPreconditions.checkState( + labeledAxis == shape.length - 1, + "get a Category list is only valid when the only labeled axis is the last one."); + List labels = axisLabels.get(labeledAxis); + float[] data = tensorBuffer.getFloatArray(); + SupportPreconditions.checkState(labels.size() == data.length); + List result = new ArrayList<>(); + int i = 0; + for (String label : labels) { + result.add(new Category(label, data[i])); + i += 1; + } + return result; + } + private static int getFirstAxisWithSizeGreaterThanOne(@NonNull TensorBuffer tensorBuffer) { int[] shape = tensorBuffer.getShape(); for (int i = 0; i < shape.length; i++) {