Clean up the smart reply demo. Since the folder tensorflow/lite/models/ becomes empty as a result, delete the parent folder as well.

The whole project has been moved to: https://github.com/tensorflow/examples/tree/master/lite/examples/smart_reply/android

(Please see the smart reply web page for details: https://www.tensorflow.org/lite/models/smart_reply/overview)

PiperOrigin-RevId: 288242509
Change-Id: Iabce717849f91cd01242634a090c0d936412e834
Tian Lin 2020-01-05 21:53:27 -08:00 committed by TensorFlower Gardener
parent 81fcd56fb3
commit afbfc6a092
23 changed files with 0 additions and 2058 deletions

View File

@@ -1,12 +0,0 @@
# Model tests
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
exports_files(glob([
"testdata/*",
]))

View File

@@ -1,117 +0,0 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/lite:build_def.bzl", "gen_selected_ops", "tflite_copts")
package(
default_visibility = [
"//visibility:public",
],
licenses = ["notice"], # Apache 2.0
)
exports_files(["LICENSE"])
gen_selected_ops(
name = "smartreply_ops",
model = ["@tflite_smartreply//:smartreply.tflite"],
)
cc_library(
name = "custom_ops",
srcs = [
"ops/extract_feature.cc",
"ops/normalize.cc",
"ops/predict.cc",
":smartreply_ops",
],
copts = tflite_copts(),
deps = [
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels/internal:tensor",
"@com_google_absl//absl/strings",
"@com_googlesource_code_re2//:re2",
"@farmhash_archive//:farmhash",
],
alwayslink = 1,
)
cc_library(
name = "predictor_lib",
srcs = ["predictor.cc"],
hdrs = ["predictor.h"],
copts = tflite_copts(),
deps = [
":custom_ops",
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
"//tensorflow/lite/kernels:builtin_ops",
"@com_google_absl//absl/strings",
"@com_googlesource_code_re2//:re2",
],
)
# TODO(b/118895218): Make this test compatible with oss.
tf_cc_test(
name = "predictor_test",
srcs = ["predictor_test.cc"],
data = [
"//tensorflow/lite/models:testdata/smartreply_samples.tsv",
"@tflite_smartreply//:smartreply.tflite",
],
tags = ["no_oss"],
deps = [
":predictor_lib",
"//tensorflow/core:test",
"//tensorflow/lite:string_util",
"//tensorflow/lite/testing:util",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
],
)
cc_test(
name = "extract_feature_op_test",
size = "small",
srcs = ["ops/extract_feature_test.cc"],
tags = ["no_oss"],
deps = [
":custom_ops",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
"@farmhash_archive//:farmhash",
],
)
cc_test(
name = "normalize_op_test",
size = "small",
srcs = ["ops/normalize_test.cc"],
tags = ["no_oss"],
deps = [
":custom_ops",
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
)
cc_test(
name = "predict_op_test",
size = "small",
srcs = ["ops/predict_test.cc"],
tags = ["no_oss"],
deps = [
":custom_ops",
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
)

View File

@@ -1,38 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
Copyright 2017 The Android Open Source Project
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.example.android.smartreply" >
<uses-sdk
android:minSdkVersion="15"
android:targetSdkVersion="24" />
<application android:label="TfLite SmartReply Demo">
<activity
android:name="com.example.android.smartreply.MainActivity"
android:configChanges="orientation|keyboardHidden|screenSize"
android:windowSoftInputMode="stateUnchanged|adjustPan"
android:label="TfLite SmartReply Demo"
android:screenOrientation="portrait" >
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
</manifest>

View File

@@ -1,68 +0,0 @@
load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
load(
"//tensorflow/lite:build_def.bzl",
"tflite_copts",
"tflite_jni_binary",
)
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
filegroup(
name = "assets",
srcs = [
"@tflite_smartreply//:model_files",
],
)
android_binary(
name = "SmartReplyDemo",
srcs = glob(["java/**/*.java"]),
assets = [":assets"],
assets_dir = "",
custom_package = "com.example.android.smartreply",
manifest = "AndroidManifest.xml",
nocompress_extensions = [
".tflite",
],
resource_files = glob(["res/**"]),
tags = ["manual"],
deps = [
":smartreply_runtime",
"@androidsdk//com.android.support:support-v13-25.2.0",
"@androidsdk//com.android.support:support-v4-25.2.0",
],
)
cc_library(
name = "smartreply_runtime",
srcs = ["libsmartreply_jni.so"],
visibility = ["//visibility:public"],
)
tflite_jni_binary(
name = "libsmartreply_jni.so",
deps = [
":smartreply_jni_lib",
],
)
cc_library(
name = "smartreply_jni_lib",
srcs = [
"smartreply_jni.cc",
],
copts = tflite_copts(),
linkopts = [
"-lm",
"-ldl",
],
deps = [
"//tensorflow/lite:framework",
"//tensorflow/lite/java/jni",
"//tensorflow/lite/models/smartreply:predictor_lib",
],
alwayslink = 1,
)

View File

@@ -1,16 +0,0 @@
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
exports_files(glob(["*"]))
filegroup(
name = "assets_files",
srcs = glob(
["**/*"],
exclude = [
"BUILD",
],
),
)

View File

@@ -1,16 +0,0 @@
Ok
Yes
No
👍
😟
❤️
Lol
Thanks
Got it
Done
Nice
I don't know
What?
Why?
What's up?

View File

@@ -1,99 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.example.android.smartreply;
import android.app.Activity;
import android.os.Bundle;
import android.os.Handler;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.EditText;
import android.widget.TextView;
/**
* The main (and only) activity of this demo app. Displays a text box which updates as messages are
* received.
*/
public class MainActivity extends Activity {
private static final String TAG = "SmartReplyDemo";
private SmartReplyClient client;
private Button sendButton;
private TextView messageTextView;
private EditText messageInput;
private Handler handler;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
Log.v(TAG, "onCreate");
setContentView(R.layout.main_activity);
client = new SmartReplyClient(getApplicationContext());
handler = new Handler();
sendButton = (Button) findViewById(R.id.send_button);
sendButton.setOnClickListener(
(View v) -> {
send(messageInput.getText().toString());
});
messageTextView = (TextView) findViewById(R.id.message_text);
messageInput = (EditText) findViewById(R.id.message_input);
}
@Override
protected void onStart() {
super.onStart();
Log.v(TAG, "onStart");
handler.post(
() -> {
client.loadModel();
});
}
@Override
protected void onStop() {
super.onStop();
Log.v(TAG, "onStop");
handler.post(
() -> {
client.unloadModel();
});
}
private void send(final String message) {
handler.post(
() -> {
messageTextView.append("Input: " + message + "\n");
SmartReply[] ans = client.predict(new String[] {message});
for (SmartReply reply : ans) {
appendMessage("Reply: " + reply.getText());
}
appendMessage("------");
});
}
private void appendMessage(final String message) {
handler.post(
() -> {
messageTextView.append(message + "\n");
});
}
}

View File

@@ -1,44 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.example.android.smartreply;
import android.support.annotation.Keep;
/**
* SmartReply contains the predicted message and its confidence score.
*
* <p>NOTE: this class is used by JNI; the class name and constructor should not be obfuscated.
*/
@Keep
public class SmartReply {
private final String text;
private final float score;
@Keep
public SmartReply(String text, float score) {
this.text = text;
this.score = score;
}
public String getText() {
return text;
}
public float getScore() {
return score;
}
}

View File

@@ -1,131 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.example.android.smartreply;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.support.annotation.Keep;
import android.support.annotation.WorkerThread;
import android.util.Log;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;
/** Interface to load TfLite model and provide predictions. */
public class SmartReplyClient implements AutoCloseable {
private static final String TAG = "SmartReplyDemo";
private static final String MODEL_PATH = "smartreply.tflite";
private static final String BACKOFF_PATH = "backoff_response.txt";
private static final String JNI_LIB = "smartreply_jni";
private final Context context;
private long storage;
private MappedByteBuffer model;
private volatile boolean isLibraryLoaded;
public SmartReplyClient(Context context) {
this.context = context;
}
public boolean isLoaded() {
return storage != 0;
}
@WorkerThread
public synchronized void loadModel() {
if (!isLibraryLoaded) {
try {
System.loadLibrary(JNI_LIB);
isLibraryLoaded = true;
} catch (Exception e) {
Log.e(TAG, "Failed to load prebuilt smartreply_jni lib", e);
return;
}
}
try {
model = loadModelFile();
String[] backoff = loadBackoffList();
storage = loadJNI(model, backoff);
} catch (IOException e) {
Log.e(TAG, "Fail to load model", e);
return;
}
}
@WorkerThread
public synchronized SmartReply[] predict(String[] input) {
if (storage != 0) {
return predictJNI(storage, input);
} else {
return new SmartReply[] {};
}
}
@WorkerThread
public synchronized void unloadModel() {
close();
}
@Override
public synchronized void close() {
if (storage != 0) {
unloadJNI(storage);
storage = 0;
}
}
private MappedByteBuffer loadModelFile() throws IOException {
try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
}
private String[] loadBackoffList() throws IOException {
List<String> labelList = new ArrayList<String>();
try (BufferedReader reader =
new BufferedReader(new InputStreamReader(context.getAssets().open(BACKOFF_PATH)))) {
String line;
while ((line = reader.readLine()) != null) {
if (!line.isEmpty()) {
labelList.add(line);
}
}
}
String[] ans = new String[labelList.size()];
labelList.toArray(ans);
return ans;
}
@Keep
private native long loadJNI(MappedByteBuffer buffer, String[] backoff);
@Keep
private native SmartReply[] predictJNI(long storage, String[] text);
@Keep
private native void unloadJNI(long storage);
}

View File

@@ -1,44 +0,0 @@
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical">
<LinearLayout
android:layout_width="fill_parent"
android:layout_height="0dp"
android:padding="5dip"
android:layout_weight="3">
<TextView
android:id="@+id/message_text"
android:layout_width="fill_parent"
android:layout_height="fill_parent"
android:scrollbars="vertical"
android:gravity="bottom"/>
</LinearLayout>
<LinearLayout
android:layout_width="fill_parent"
android:layout_height="0dp"
android:padding="5dip"
android:layout_weight="1">
<EditText
android:id="@+id/message_input"
android:layout_width="0dp"
android:layout_height="fill_parent"
android:layout_weight="6"
android:scrollbars="vertical"
android:hint="Enter Text"
android:gravity="top"
android:inputType="text"/>
<Button
android:id="@+id/send_button"
android:layout_width="0dp"
android:layout_height="fill_parent"
android:layout_weight="2"
android:text="Send" />
</LinearLayout>
</LinearLayout>

View File

@@ -1,129 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <jni.h>
#include <utility>
#include <vector>
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/models/smartreply/predictor.h"
const char kIllegalStateException[] = "java/lang/IllegalStateException";
using tflite::custom::smartreply::GetSegmentPredictions;
using tflite::custom::smartreply::PredictorResponse;
template <typename T>
T CheckNotNull(JNIEnv* env, T&& t) {
if (t == nullptr) {
env->ThrowNew(env->FindClass(kIllegalStateException), "");
return nullptr;
}
return std::forward<T>(t);
}
std::vector<std::string> jniStringArrayToVector(JNIEnv* env,
jobjectArray string_array) {
int count = env->GetArrayLength(string_array);
std::vector<std::string> result;
for (int i = 0; i < count; i++) {
auto jstr =
reinterpret_cast<jstring>(env->GetObjectArrayElement(string_array, i));
const char* raw_str = env->GetStringUTFChars(jstr, JNI_FALSE);
result.emplace_back(std::string(raw_str));
env->ReleaseStringUTFChars(jstr, raw_str);
}
return result;
}
struct JNIStorage {
std::vector<std::string> backoff_list;
std::unique_ptr<::tflite::FlatBufferModel> model;
};
extern "C" JNIEXPORT jlong JNICALL
Java_com_example_android_smartreply_SmartReplyClient_loadJNI(
JNIEnv* env, jobject thiz, jobject model_buffer,
jobjectArray backoff_list) {
const char* buf =
static_cast<char*>(env->GetDirectBufferAddress(model_buffer));
jlong capacity = env->GetDirectBufferCapacity(model_buffer);
JNIStorage* storage = new JNIStorage;
storage->model = tflite::FlatBufferModel::BuildFromBuffer(
buf, static_cast<size_t>(capacity));
storage->backoff_list = jniStringArrayToVector(env, backoff_list);
if (!storage->model) {
delete storage;
env->ThrowNew(env->FindClass(kIllegalStateException), "");
return 0;
}
return reinterpret_cast<jlong>(storage);
}
extern "C" JNIEXPORT jobjectArray JNICALL
Java_com_example_android_smartreply_SmartReplyClient_predictJNI(
JNIEnv* env, jobject /*thiz*/, jlong storage_ptr, jobjectArray input_text) {
// Predict
if (storage_ptr == 0) {
return nullptr;
}
JNIStorage* storage = reinterpret_cast<JNIStorage*>(storage_ptr);
if (storage == nullptr) {
return nullptr;
}
std::vector<PredictorResponse> responses;
GetSegmentPredictions(jniStringArrayToVector(env, input_text),
*storage->model, {storage->backoff_list}, &responses);
// Create a SmartReply[] to return back to Java
jclass smart_reply_class = CheckNotNull(
env, env->FindClass("com/example/android/smartreply/SmartReply"));
if (env->ExceptionCheck()) {
return nullptr;
}
jmethodID smart_reply_ctor = CheckNotNull(
env,
env->GetMethodID(smart_reply_class, "<init>", "(Ljava/lang/String;F)V"));
if (env->ExceptionCheck()) {
return nullptr;
}
jobjectArray array = CheckNotNull(
env, env->NewObjectArray(responses.size(), smart_reply_class, nullptr));
if (env->ExceptionCheck()) {
return nullptr;
}
for (int i = 0; i < responses.size(); i++) {
jstring text =
CheckNotNull(env, env->NewStringUTF(responses[i].GetText().data()));
if (env->ExceptionCheck()) {
return nullptr;
}
jobject reply = env->NewObject(smart_reply_class, smart_reply_ctor, text,
responses[i].GetScore());
env->SetObjectArrayElement(array, i, reply);
}
return array;
}
extern "C" JNIEXPORT void JNICALL
Java_com_example_android_smartreply_SmartReplyClient_unloadJNI(
JNIEnv* env, jobject thiz, jlong storage_ptr) {
if (storage_ptr != 0) {
JNIStorage* storage = reinterpret_cast<JNIStorage*>(storage_ptr);
delete storage;
}
}

View File

@@ -1,146 +0,0 @@
# Smart Reply Model
## What is the On-Device Smart Reply Model?
Smart Replies are contextually relevant, one-touch responses that help the user
to reply to an incoming text message (or email) efficiently and effortlessly.
Smart Replies have been highly successful across several Google products
including
[Gmail](https://www.blog.google/products/gmail/save-time-with-smart-reply-in-gmail/),
[Inbox](https://www.blog.google/products/gmail/computer-respond-to-this-email/)
and
[Allo](https://blog.google/products/allo/google-allo-smarter-messaging-app/).
The On-device Smart Reply model is targeted towards text chat use cases. It has
a completely different architecture from its cloud-based counterparts, and is
built specifically for memory-constrained devices such as phones & watches. It
has been successfully used to provide [Smart Replies on Android
Wear](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html)
to all first- & third-party apps.
The on-device model comes with several benefits. It is:
* **Faster**: The model resides on the device and does not require internet
connectivity. Thus, the inference is very fast and has an average latency of
only a few milliseconds.
* **Resource efficient**: The model has a small memory footprint on
the device.
* **Privacy-friendly**: The user data never leaves the device, which
eliminates many privacy concerns.
A caveat, though, is that the on-device model has a lower triggering rate than its
cloud counterparts (triggering rate is the percentage of times the model
suggests a response for an incoming message).
## When to use this Model?
The On-Device Smart Reply model is aimed towards improving the messaging
experience for day-to-day conversational chat messages. We recommend using this
model for similar use cases. Some sample messages on which the model does well
are provided in this [tsv
file](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/testdata/smartreply_samples.tsv)
for reference. The file format is:
```
{incoming_message smart_reply1 [smart_reply2] [smart_reply3]}
```
For the current model, we see a triggering rate of about 30-40% for messages
which are similar to those provided in the tsv file above.
In case the model does not trigger any response, the system falls back to
suggesting replies from a fixed back-off set that was compiled from popular
response intents observed in chat conversations. Some of the fallback responses
are `Ok, Yes, No, 👍, ☺`.
The model can only be used for inference at this time (i.e. it cannot be custom
trained). If you are interested in how the model was trained, please refer
to this [blog
post](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html)
and [research paper](https://arxiv.org/pdf/1708.00630).
## How to use this Model?
We have provided a pre-built demo APK that you can download, install and test on
your phone
([demo APK here](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/SmartReplyDemo.apk)).
The On-Device Smart Reply demo App works in the following way:
1. The Android app links to the JNI binary with a predictor library.
2. In the predictor library, `GetSegmentPredictions` is called with a list of input
strings.
2.1 The input can be the 1-3 most recent messages of the conversation, in
the form of a string vector. The model will run on these input sentences and
provide Smart Replies corresponding to them.
2.2 The function performs some preprocessing on the input data, which includes:
* Sentence splitting: The input message will be split into sentences if the
message has more than one sentence. E.g. a message like “How are you?
Want to grab lunch?” will be broken down into 2 different sentences.
* Normalization: The individual sentences will be normalized by converting
them to lower case, removing unnecessary punctuation, etc. E.g. “how
are you????” will be converted to “how are you?” (refer to the NORMALIZE op
for more details).
The input string content will be converted to tensors.
2.3 The function then runs the prediction model on the input tensors.
2.4 The function also performs some post-processing which includes
aggregating the model predictions for the input sentences from 2.2 and
returning the appropriate responses.
3. Finally, it gets the response(s) from `std::vector<PredictorResponse>` and
returns them to the Android app. Responses are sorted in descending order of
confidence score.
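For reference, below is a minimal C++ sketch of calling the predictor library
directly, outside the Android app. The model path, input strings and backoff
responses are illustrative, and it assumes the `predictor_lib` target (and the
custom ops it registers) is linked into the binary:

```
#include <cstdio>
#include <string>
#include <vector>

#include "tensorflow/lite/model.h"
#include "tensorflow/lite/models/smartreply/predictor.h"

int main() {
  // Path is illustrative; point it at the downloaded smartreply.tflite.
  auto model = tflite::FlatBufferModel::BuildFromFile("smartreply.tflite");
  if (model == nullptr) return 1;

  // Backoff responses (illustrative) are appended when the model does not
  // trigger enough replies on its own.
  tflite::custom::smartreply::SmartReplyConfig config({"Ok", "Yes", "No"});

  std::vector<tflite::custom::smartreply::PredictorResponse> responses;
  tflite::custom::smartreply::GetSegmentPredictions(
      {"Hi", "How are you?"}, *model, config, &responses);

  // Responses come back sorted in descending order of confidence score.
  for (const auto& response : responses) {
    std::printf("%s (%.3f)\n", response.GetText().c_str(), response.GetScore());
  }
  return 0;
}
```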
## Ops and Functionality Supported
The following ops are used by the On-Device Smart Reply model:
* **NORMALIZE**
This is a custom op which normalizes the sentences by:
* Converting all sentences into lower case.
* Removing unnecessary punctuation (e.g. “how are you????” → “how are
you?”).
* Expanding contractions wherever necessary (e.g. “Im home” → “I am home”).
* **SKIP_GRAM**
This is an op inside TensorFlow Lite that converts sentences into a list of
skip grams. The configurable parameters are `ngram_size` and
`max_skip_size`. For the model provided, the values for these parameters are
set to 3 & 2 respectively.
* **EXTRACT_FEATURES**
This is a custom op that hashes skip grams to features represented as
integers. Longer skip-grams are allocated higher weights.
* **LSH_PROJECTION**
This is an op inside TensorFlow Lite that projects input features to a
corresponding bit vector space using Locality Sensitive Hashing (LSH).
* **PREDICT**
This is a custom op that runs the input features through the projection
model (details [here](https://arxiv.org/pdf/1708.00630.pdf)), computes the
appropriate response labels along with weights for the projected features,
and aggregates the response labels and weights together.
* **HASHTABLE_LOOKUP**
This is an op inside TensorFlow Lite that takes the label ids produced by the
PREDICT op and looks up the corresponding response text for each label id.
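In the open-source build, registration of these ops is generated by the
`gen_selected_ops` rule shown in the BUILD file above. A hand-written
equivalent for the three custom ops would look roughly like the sketch below;
the registered names (`Normalize`, `ExtractFeatures`, `Predict`) follow the
unit tests and are an assumption about the names stored in the .tflite graph,
and the builtin ops (SKIP_GRAM, LSH_PROJECTION, HASHTABLE_LOOKUP) still need
to come from the builtin op resolver:

```
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/op_resolver.h"

namespace tflite {
namespace ops {
namespace custom {
// Registrations defined in ops/normalize.cc, ops/extract_feature.cc and
// ops/predict.cc above.
TfLiteRegistration* Register_NORMALIZE();
TfLiteRegistration* Register_EXTRACT_FEATURES();
TfLiteRegistration* Register_PREDICT();
}  // namespace custom
}  // namespace ops
}  // namespace tflite

// Adds the smart reply custom ops to a resolver so the interpreter can
// resolve the corresponding nodes when loading the model. Builtin ops still
// need to be registered separately (e.g. via BuiltinOpResolver).
void RegisterSmartReplyOps(::tflite::MutableOpResolver* resolver) {
  resolver->AddCustom("Normalize",
                      ::tflite::ops::custom::Register_NORMALIZE());
  resolver->AddCustom("ExtractFeatures",
                      ::tflite::ops::custom::Register_EXTRACT_FEATURES());
  resolver->AddCustom("Predict",
                      ::tflite::ops::custom::Register_PREDICT());
}
```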
## Further Information
* Open source code
[here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/smartreply/).

View File

@@ -1,121 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Convert a list of strings to integers via hashing.
// Input:
// Input[0]: A list of ngrams. string[num of input]
//
// Output:
// Output[0]: Hashed features. int32[num of input]
// Output[1]: Weights. float[num of input]
#include <algorithm>
#include <map>
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"
#include <farmhash.h>
namespace tflite {
namespace ops {
namespace custom {
namespace extract {
static const int kMaxDimension = 1000000;
static const std::vector<string> kBlacklistNgram = {"<S>", "<E>", "<S> <E>"};
bool Equals(const string& x, const tflite::StringRef& strref) {
if (strref.len != x.length()) {
return false;
}
if (strref.len > 0) {
int r = memcmp(strref.str, x.data(), strref.len);
return r == 0;
}
return true;
}
bool IsValidNgram(const tflite::StringRef& strref) {
for (const auto& s : kBlacklistNgram) {
if (Equals(s, strref)) {
return false;
}
}
return true;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
int dim = input->dims->data[0];
if (dim == 0) {
// TFLite non-string output should have size greater than 0.
dim = 1;
}
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteString);
TfLiteIntArray* outputSize1 = TfLiteIntArrayCreate(1);
TfLiteIntArray* outputSize2 = TfLiteIntArrayCreate(1);
outputSize1->data[0] = dim;
outputSize2->data[0] = dim;
context->ResizeTensor(context, GetOutput(context, node, 0), outputSize1);
context->ResizeTensor(context, GetOutput(context, node, 1), outputSize2);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
int num_strings = tflite::GetStringCount(input);
TfLiteTensor* label = GetOutput(context, node, 0);
TfLiteTensor* weight = GetOutput(context, node, 1);
int32_t* label_data = GetTensorData<int32_t>(label);
float* weight_data = GetTensorData<float>(weight);
std::map<int64_t, int> feature_id_counts;
for (int i = 0; i < num_strings; i++) {
// Use fingerprint of feature name as id.
auto strref = tflite::GetString(input, i);
if (!IsValidNgram(strref)) {
label_data[i] = 0;
weight_data[i] = 0;
continue;
}
int64_t feature_id =
::util::Fingerprint64(strref.str, strref.len) % kMaxDimension;
label_data[i] = static_cast<int32_t>(feature_id);
weight_data[i] = std::count(strref.str, strref.str + strref.len, ' ') + 1;
}
// Explicitly set an empty result to make preceding ops run.
if (num_strings == 0) {
label_data[0] = 0;
weight_data[0] = 0;
}
return kTfLiteOk;
}
} // namespace extract
TfLiteRegistration* Register_EXTRACT_FEATURES() {
static TfLiteRegistration r = {nullptr, nullptr, extract::Prepare,
extract::Eval};
return &r;
}
} // namespace custom
} // namespace ops
} // namespace tflite

View File

@@ -1,100 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
#include <farmhash.h>
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_EXTRACT_FEATURES();
namespace {
using ::testing::ElementsAre;
class ExtractFeatureOpModel : public SingleOpModel {
public:
explicit ExtractFeatureOpModel(const std::vector<string>& input) {
input_ = AddInput(TensorType_STRING);
signature_ = AddOutput(TensorType_INT32);
weight_ = AddOutput(TensorType_FLOAT32);
SetCustomOp("ExtractFeatures", {}, Register_EXTRACT_FEATURES);
BuildInterpreter({{static_cast<int>(input.size())}});
PopulateStringTensor(input_, input);
}
std::vector<int> GetSignature() { return ExtractVector<int>(signature_); }
std::vector<float> GetWeight() { return ExtractVector<float>(weight_); }
private:
int input_;
int signature_;
int weight_;
};
int CalcFeature(const string& str) {
return ::util::Fingerprint64(str) % 1000000;
}
TEST(ExtractFeatureOpTest, RegularInput) {
ExtractFeatureOpModel m({"<S>", "<S> Hi", "Hi", "Hi !", "!", "! <E>", "<E>"});
m.Invoke();
EXPECT_THAT(m.GetSignature(),
ElementsAre(0, CalcFeature("<S> Hi"), CalcFeature("Hi"),
CalcFeature("Hi !"), CalcFeature("!"),
CalcFeature("! <E>"), 0));
EXPECT_THAT(m.GetWeight(), ElementsAre(0, 2, 1, 2, 1, 2, 0));
}
TEST(ExtractFeatureOpTest, OneInput) {
ExtractFeatureOpModel m({"Hi"});
m.Invoke();
EXPECT_THAT(m.GetSignature(), ElementsAre(CalcFeature("Hi")));
EXPECT_THAT(m.GetWeight(), ElementsAre(1));
}
TEST(ExtractFeatureOpTest, ZeroInput) {
ExtractFeatureOpModel m({});
m.Invoke();
EXPECT_THAT(m.GetSignature(), ElementsAre(0));
EXPECT_THAT(m.GetWeight(), ElementsAre(0));
}
TEST(ExtractFeatureOpTest, AllBlacklistInput) {
ExtractFeatureOpModel m({"<S>", "<E>"});
m.Invoke();
EXPECT_THAT(m.GetSignature(), ElementsAre(0, 0));
EXPECT_THAT(m.GetWeight(), ElementsAre(0, 0));
}
} // namespace
} // namespace custom
} // namespace ops
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@@ -1,108 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Normalize the string input.
//
// Input:
// Input[0]: One sentence. string[1]
//
// Output:
// Output[0]: Normalized sentence. string[1]
//
#include <algorithm>
#include <string>
#include "absl/strings/str_cat.h"
#include "absl/strings/strip.h"
#include "re2/re2.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"
namespace tflite {
namespace ops {
namespace custom {
namespace normalize {
// Predictor transforms.
const char kPunctuationsRegex[] = "[.*()\"]";
const std::map<string, string>* kRegexTransforms =
new std::map<string, string>({
{"([^\\s]+)n't", "\\1 not"},
{"([^\\s]+)'nt", "\\1 not"},
{"([^\\s]+)'ll", "\\1 will"},
{"([^\\s]+)'re", "\\1 are"},
{"([^\\s]+)'ve", "\\1 have"},
{"i'm", "i am"},
});
static const char kStartToken[] = "<S>";
static const char kEndToken[] = "<E>";
static const int32_t kMaxInputChars = 300;
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::StringRef input = tflite::GetString(GetInput(context, node, 0), 0);
string result(absl::AsciiStrToLower(absl::string_view(input.str, input.len)));
absl::StripAsciiWhitespace(&result);
// Do not remove commas, semi-colons or colons from the sentences as they can
// indicate the beginning of a new clause.
RE2::GlobalReplace(&result, kPunctuationsRegex, "");
RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)([\\s,;:/])",
"\\1\\2");
RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)$", "\\1");
for (auto iter = kRegexTransforms->begin(); iter != kRegexTransforms->end();
iter++) {
RE2::GlobalReplace(&result, iter->first, iter->second);
}
// Treat questions & interjections as special cases.
RE2::GlobalReplace(&result, "([?])+", "\\1");
RE2::GlobalReplace(&result, "([!])+", "\\1");
RE2::GlobalReplace(&result, "([^?!]+)([?!])", "\\1 \\2 ");
RE2::GlobalReplace(&result, "([?!])([?!])", "\\1 \\2");
RE2::GlobalReplace(&result, "[\\s,:;\\-&'\"]+$", "");
RE2::GlobalReplace(&result, "^[\\s,:;\\-&'\"]+", "");
absl::StripAsciiWhitespace(&result);
// Add start and end token.
// Truncate input to maximum allowed size.
if (result.length() <= kMaxInputChars) {
absl::StrAppend(&result, " ", kEndToken);
} else {
result = result.substr(0, kMaxInputChars);
}
result = absl::StrCat(kStartToken, " ", result);
tflite::DynamicBuffer buf;
buf.AddString(result.data(), result.length());
buf.WriteToTensorAsVector(GetOutput(context, node, 0));
return kTfLiteOk;
}
} // namespace normalize
TfLiteRegistration* Register_NORMALIZE() {
static TfLiteRegistration r = {nullptr, nullptr, nullptr, normalize::Eval};
return &r;
}
} // namespace custom
} // namespace ops
} // namespace tflite

View File

@@ -1,90 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_NORMALIZE();
namespace {
using ::testing::ElementsAreArray;
class NormalizeOpModel : public SingleOpModel {
public:
explicit NormalizeOpModel(const string& input) {
input_ = AddInput(TensorType_STRING);
output_ = AddOutput(TensorType_STRING);
SetCustomOp("Normalize", {}, Register_NORMALIZE);
BuildInterpreter({{static_cast<int>(input.size())}});
PopulateStringTensor(input_, {input});
}
std::vector<string> GetStringOutput() {
TfLiteTensor* output = interpreter_->tensor(output_);
int num = GetStringCount(output);
std::vector<string> result(num);
for (int i = 0; i < num; i++) {
auto ref = GetString(output, i);
result[i] = string(ref.str, ref.len);
}
return result;
}
private:
int input_;
int output_;
};
TEST(NormalizeOpTest, RegularInput) {
NormalizeOpModel m("I'm good; you're welcome");
m.Invoke();
EXPECT_THAT(m.GetStringOutput(),
ElementsAreArray({"<S> i am good; you are welcome <E>"}));
}
TEST(NormalizeOpTest, OneInput) {
NormalizeOpModel m("Hi!!!!");
m.Invoke();
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"<S> hi ! <E>"}));
}
TEST(NormalizeOpTest, EmptyInput) {
NormalizeOpModel m("");
m.Invoke();
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"<S> <E>"}));
}
} // namespace
} // namespace custom
} // namespace ops
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@@ -1,176 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Lookup projected hash signatures in Predictor model,
// output predicted labels and weights in decreasing order.
//
// Input:
// Input[0]: A list of hash signatures. int32[num of input]
// Input[1]: Hash signature keys in the model. int32[keys of model]
// Input[2]: Labels in the model. int32[keys of model, item per entry]
// Input[3]: Weights in the model. float[keys of model, item per entry]
//
// Output:
// Output[0]: Predicted labels. int32[num of output]
// Output[1]: Predicted weights. float[num of output]
//
#include <algorithm>
#include <cstdlib>
#include <cstdio>
#include <unordered_map>
#include <vector>
#include "tensorflow/lite/context.h"
namespace tflite {
namespace ops {
namespace custom {
namespace predict {
struct PredictOption {
int32_t num_output;
float weight_threshold;
static PredictOption* Cast(void* ptr) {
return reinterpret_cast<PredictOption*>(ptr);
}
};
bool WeightGreater(const std::pair<int32_t, float>& a,
const std::pair<int32_t, float>& b) {
return a.second > b.second;
}
void* Init(TfLiteContext* context, const char* custom_option, size_t length) {
if (custom_option == nullptr || length != sizeof(PredictOption)) {
fprintf(stderr, "No Custom option set\n");
exit(1);
}
PredictOption* option = new PredictOption;
int offset = 0;
option->num_output =
*reinterpret_cast<const int32_t*>(custom_option + offset);
offset += sizeof(int32_t);
option->weight_threshold =
*reinterpret_cast<const float*>(custom_option + offset);
return reinterpret_cast<void*>(option);
}
void Free(TfLiteContext* context, void* buffer) {
delete PredictOption::Cast(buffer);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]];
TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]];
TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]];
TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]];
TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, model_key->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, model_label->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, model_weight->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, lookup->dims->size, 1);
TF_LITE_ENSURE_EQ(context, model_key->dims->size, 1);
TF_LITE_ENSURE_EQ(context, model_label->dims->size, 2);
TF_LITE_ENSURE_EQ(context, model_weight->dims->size, 2);
TF_LITE_ENSURE_EQ(context, model_key->dims->data[0],
model_label->dims->data[0]);
TF_LITE_ENSURE_EQ(context, model_key->dims->data[0],
model_weight->dims->data[0]);
TF_LITE_ENSURE_EQ(context, model_label->dims->data[1],
model_weight->dims->data[1]);
PredictOption* option = PredictOption::Cast(node->user_data);
TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]];
TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]];
TF_LITE_ENSURE_EQ(context, output_label->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, output_weight->type, kTfLiteFloat32);
TfLiteIntArray* label_size = TfLiteIntArrayCreate(1);
label_size->data[0] = option->num_output;
TfLiteIntArray* weight_size = TfLiteIntArrayCreate(1);
weight_size->data[0] = option->num_output;
TfLiteStatus status =
context->ResizeTensor(context, output_label, label_size);
if (status != kTfLiteOk) {
return status;
}
return context->ResizeTensor(context, output_weight, weight_size);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]];
TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]];
TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]];
TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]];
// Aggregate by key
std::unordered_map<int32_t, float> aggregation;
const int num_input = lookup->dims->data[0];
const int num_rows = model_key->dims->data[0];
const int items = model_label->dims->data[1];
int* model_key_end = model_key->data.i32 + num_rows;
for (int i = 0; i < num_input; i++) {
int* ptr = std::lower_bound(model_key->data.i32, model_key_end,
lookup->data.i32[i]);
if (ptr != nullptr && ptr != model_key_end && *ptr == lookup->data.i32[i]) {
int idx = ptr - model_key->data.i32;
for (int j = 0; j < items; j++) {
aggregation[model_label->data.i32[idx * items + j]] +=
model_weight->data.f[idx * items + j] / num_input;
}
}
}
// Sort by value
std::vector<std::pair<int32_t, float>> sorted_labels(aggregation.begin(),
aggregation.end());
std::sort(sorted_labels.begin(), sorted_labels.end(), WeightGreater);
PredictOption* option = PredictOption::Cast(node->user_data);
TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]];
TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]];
for (int i = 0; i < output_label->dims->data[0]; i++) {
if (i >= sorted_labels.size() ||
sorted_labels[i].second < option->weight_threshold) {
// Set -1 to avoid looking up the message with id 0, which is reserved for backoff.
output_label->data.i32[i] = -1;
output_weight->data.f[i] = 0.0f;
} else {
output_label->data.i32[i] = sorted_labels[i].first;
output_weight->data.f[i] = sorted_labels[i].second;
}
}
return kTfLiteOk;
}
} // namespace predict
TfLiteRegistration* Register_PREDICT() {
static TfLiteRegistration r = {predict::Init, predict::Free, predict::Prepare,
predict::Eval};
return &r;
}
} // namespace custom
} // namespace ops
} // namespace tflite

View File

@@ -1,183 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_PREDICT();
namespace {
using ::testing::ElementsAreArray;
class PredictOpModel : public SingleOpModel {
public:
PredictOpModel(std::initializer_list<int> input_signature_shape,
std::initializer_list<int> key_shape,
std::initializer_list<int> labelweight_shape, int num_output,
float threshold) {
input_signature_ = AddInput(TensorType_INT32);
model_key_ = AddInput(TensorType_INT32);
model_label_ = AddInput(TensorType_INT32);
model_weight_ = AddInput(TensorType_FLOAT32);
output_label_ = AddOutput(TensorType_INT32);
output_weight_ = AddOutput(TensorType_FLOAT32);
std::vector<uint8_t> predict_option;
writeInt32(num_output, &predict_option);
writeFloat32(threshold, &predict_option);
SetCustomOp("Predict", predict_option, Register_PREDICT);
BuildInterpreter({{input_signature_shape, key_shape, labelweight_shape,
labelweight_shape}});
}
void SetInputSignature(std::initializer_list<int> data) {
PopulateTensor<int>(input_signature_, data);
}
void SetModelKey(std::initializer_list<int> data) {
PopulateTensor<int>(model_key_, data);
}
void SetModelLabel(std::initializer_list<int> data) {
PopulateTensor<int>(model_label_, data);
}
void SetModelWeight(std::initializer_list<float> data) {
PopulateTensor<float>(model_weight_, data);
}
std::vector<int> GetLabel() { return ExtractVector<int>(output_label_); }
std::vector<float> GetWeight() {
return ExtractVector<float>(output_weight_);
}
void writeFloat32(float value, std::vector<uint8_t>* data) {
union {
float v;
uint8_t r[4];
} float_to_raw;
float_to_raw.v = value;
for (unsigned char i : float_to_raw.r) {
data->push_back(i);
}
}
void writeInt32(int32_t value, std::vector<uint8_t>* data) {
union {
int32_t v;
uint8_t r[4];
} int32_to_raw;
int32_to_raw.v = value;
for (unsigned char i : int32_to_raw.r) {
data->push_back(i);
}
}
private:
int input_signature_;
int model_key_;
int model_label_;
int model_weight_;
int output_label_;
int output_weight_;
};
TEST(PredictOpTest, AllLabelsAreValid) {
PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
m.SetInputSignature({1, 3, 7, 9});
m.SetModelKey({1, 2, 4, 6, 7});
m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
m.Invoke();
EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, 11}));
EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0.05})));
}
TEST(PredictOpTest, MoreLabelsThanRequired) {
PredictOpModel m({4}, {5}, {5, 2}, 1, 0.0001);
m.SetInputSignature({1, 3, 7, 9});
m.SetModelKey({1, 2, 4, 6, 7});
m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
m.Invoke();
EXPECT_THAT(m.GetLabel(), ElementsAreArray({12}));
EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1})));
}
TEST(PredictOpTest, OneLabelDoesNotPassThreshold) {
PredictOpModel m({4}, {5}, {5, 2}, 2, 0.07);
m.SetInputSignature({1, 3, 7, 9});
m.SetModelKey({1, 2, 4, 6, 7});
m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
m.Invoke();
EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, -1}));
EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0})));
}
TEST(PredictOpTest, NoneLabelPassThreshold) {
PredictOpModel m({4}, {5}, {5, 2}, 2, 0.6);
m.SetInputSignature({1, 3, 7, 9});
m.SetModelKey({1, 2, 4, 6, 7});
m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
m.Invoke();
EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1}));
EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0})));
}
TEST(PredictOpTest, OnlyOneLabelGenerated) {
PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
m.SetInputSignature({1, 3, 7, 9});
m.SetModelKey({1, 2, 4, 6, 7});
m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 11, 0});
m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0});
m.Invoke();
EXPECT_THAT(m.GetLabel(), ElementsAreArray({11, -1}));
EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.05, 0})));
}
TEST(PredictOpTest, NoLabelGenerated) {
PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
m.SetInputSignature({5, 3, 7, 9});
m.SetModelKey({1, 2, 4, 6, 7});
m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 0, 0});
m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0, 0});
m.Invoke();
EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1}));
EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0})));
}
} // namespace
} // namespace custom
} // namespace ops
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@@ -1,117 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/models/smartreply/predictor.h"
#include "absl/strings/str_split.h"
#include "re2/re2.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/string_util.h"
void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
namespace tflite {
namespace custom {
namespace smartreply {
// Split sentence into segments (using punctuation).
std::vector<std::string> SplitSentence(const std::string& input) {
string result(input);
RE2::GlobalReplace(&result, "([?.!,])+", " \\1");
RE2::GlobalReplace(&result, "([?.!,])+\\s+", "\\1\t");
RE2::GlobalReplace(&result, "[ ]+", " ");
RE2::GlobalReplace(&result, "\t+$", "");
return absl::StrSplit(result, '\t');
}
// Predict with TfLite model.
void ExecuteTfLite(const std::string& sentence,
::tflite::Interpreter* interpreter,
std::map<std::string, float>* response_map) {
{
TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]);
tflite::DynamicBuffer buf;
buf.AddString(sentence.data(), sentence.length());
buf.WriteToTensorAsVector(input);
interpreter->AllocateTensors();
interpreter->Invoke();
TfLiteTensor* messages = interpreter->tensor(interpreter->outputs()[0]);
TfLiteTensor* confidence = interpreter->tensor(interpreter->outputs()[1]);
for (int i = 0; i < confidence->dims->data[0]; i++) {
float weight = confidence->data.f[i];
auto response_text = tflite::GetString(messages, i);
if (response_text.len > 0) {
(*response_map)[string(response_text.str, response_text.len)] += weight;
}
}
}
}
void GetSegmentPredictions(
const std::vector<std::string>& input,
const ::tflite::FlatBufferModel& model, const SmartReplyConfig& config,
std::vector<PredictorResponse>* predictor_responses) {
// Initialize interpreter
std::unique_ptr<::tflite::Interpreter> interpreter;
::tflite::MutableOpResolver resolver;
RegisterSelectedOps(&resolver);
::tflite::InterpreterBuilder(model, resolver)(&interpreter);
if (!model.initialized()) {
fprintf(stderr, "Failed to mmap model \n");
return;
}
// Execute Tflite Model
std::map<std::string, float> response_map;
std::vector<std::string> sentences;
for (const std::string& str : input) {
std::vector<std::string> splitted_str = SplitSentence(str);
sentences.insert(sentences.end(), splitted_str.begin(), splitted_str.end());
}
for (const auto& sentence : sentences) {
ExecuteTfLite(sentence, interpreter.get(), &response_map);
}
// Generate the result.
for (const auto& iter : response_map) {
PredictorResponse prediction(iter.first, iter.second);
predictor_responses->emplace_back(prediction);
}
std::sort(predictor_responses->begin(), predictor_responses->end(),
[](const PredictorResponse& a, const PredictorResponse& b) {
return a.GetScore() > b.GetScore();
});
// Add backoff response.
for (const auto& backoff : config.backoff_responses) {
if (predictor_responses->size() >= config.num_response) {
break;
}
predictor_responses->emplace_back(backoff, config.backoff_confidence);
}
}
} // namespace smartreply
} // namespace custom
} // namespace tflite

View File

@@ -1,80 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
#define TENSORFLOW_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
#include <string>
#include <vector>
#include "tensorflow/lite/model.h"
namespace tflite {
namespace custom {
namespace smartreply {
const int kDefaultNumResponse = 10;
const float kDefaultBackoffConfidence = 1e-4;
class PredictorResponse;
struct SmartReplyConfig;
// With a given string as input, predict the response with a Tflite model.
// When config.backoff_responses is not empty, predictor_responses will be filled
// with messages from the backoff responses.
void GetSegmentPredictions(const std::vector<std::string>& input,
const ::tflite::FlatBufferModel& model,
const SmartReplyConfig& config,
std::vector<PredictorResponse>* predictor_responses);
// Data object used to hold a single predictor response.
// It includes the response text and a confidence score.
class PredictorResponse {
public:
PredictorResponse(const std::string& response_text, float score) {
response_text_ = response_text;
prediction_score_ = score;
}
// Accessor methods.
const std::string& GetText() const { return response_text_; }
float GetScore() const { return prediction_score_; }
private:
std::string response_text_ = "";
float prediction_score_ = 0.0;
};
// Configurations for SmartReply.
struct SmartReplyConfig {
// Maximum responses to return.
int num_response;
// Default confidence for backoff responses.
float backoff_confidence;
// Backoff responses are used when predicted responses cannot fulfill the
// list.
std::vector<std::string> backoff_responses;
SmartReplyConfig(const std::vector<std::string>& backoff_responses)
: num_response(kDefaultNumResponse),
backoff_confidence(kDefaultBackoffConfidence),
backoff_responses(backoff_responses) {}
};
} // namespace smartreply
} // namespace custom
} // namespace tflite
#endif // TENSORFLOW_LITE_MODELS_SMARTREPLY_PREDICTOR_H_

View File

@@ -1,163 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/models/smartreply/predictor.h"
#include <fstream>
#include <unordered_set>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/testing/util.h"
namespace tflite {
namespace custom {
namespace smartreply {
namespace {
const char kSamples[] = "smartreply_samples.tsv";
string GetModelFilePath() {
return "external/tflite_smartreply/smartreply.tflite"; // NOLINT
}
string GetSamplesFilePath() {
return string(absl::StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/",
"lite/models/testdata/", kSamples));
}
MATCHER_P(IncludeAnyResponesIn, expected_response, "contains the response") {
bool has_expected_response = false;
for (const auto &item : *arg) {
const string &response = item.GetText();
if (expected_response.find(response) != expected_response.end()) {
has_expected_response = true;
break;
}
}
return has_expected_response;
}
class PredictorTest : public ::testing::Test {
protected:
PredictorTest() {}
~PredictorTest() override {}
void SetUp() override {
model_ = tflite::FlatBufferModel::BuildFromFile(GetModelFilePath().c_str());
ASSERT_NE(model_.get(), nullptr);
}
std::unique_ptr<::tflite::FlatBufferModel> model_;
};
TEST_F(PredictorTest, GetSegmentPredictions) {
std::vector<PredictorResponse> predictions;
GetSegmentPredictions({"Welcome"}, *model_, /*config=*/{{}}, &predictions);
EXPECT_GT(predictions.size(), 0);
float max = 0;
for (const auto &item : predictions) {
if (item.GetScore() > max) {
max = item.GetScore();
}
}
EXPECT_GT(max, 0.3);
EXPECT_THAT(
&predictions,
IncludeAnyResponesIn(std::unordered_set<string>({"Thanks very much"})));
}
TEST_F(PredictorTest, TestTwoSentences) {
std::vector<PredictorResponse> predictions;
GetSegmentPredictions({"Hello", "How are you?"}, *model_, /*config=*/{{}},
&predictions);
EXPECT_GT(predictions.size(), 0);
float max = 0;
for (const auto &item : predictions) {
if (item.GetScore() > max) {
max = item.GetScore();
}
}
EXPECT_GT(max, 0.3);
EXPECT_THAT(&predictions, IncludeAnyResponesIn(std::unordered_set<string>(
{"Hi, how are you doing?"})));
}
TEST_F(PredictorTest, TestBackoff) {
std::vector<PredictorResponse> predictions;
GetSegmentPredictions({"你好"}, *model_, /*config=*/{{}}, &predictions);
EXPECT_EQ(predictions.size(), 0);
// Backoff responses are returned in order.
GetSegmentPredictions({"你好"}, *model_, /*config=*/{{"Yes", "Ok"}},
&predictions);
EXPECT_EQ(predictions.size(), 2);
EXPECT_EQ(predictions[0].GetText(), "Yes");
EXPECT_EQ(predictions[1].GetText(), "Ok");
}
TEST_F(PredictorTest, BatchTest) {
int total_items = 0;
int total_responses = 0;
int total_triggers = 0;
string line;
std::ifstream fin(GetSamplesFilePath());
while (std::getline(fin, line)) {
const std::vector<string> fields = absl::StrSplit(line, '\t');
if (fields.empty()) {
continue;
}
// Parse sample file and predict
const string &msg = fields[0];
std::vector<PredictorResponse> predictions;
GetSegmentPredictions({msg}, *model_, /*config=*/{{}}, &predictions);
// Validate response and generate stats.
total_items++;
total_responses += predictions.size();
if (!predictions.empty()) {
total_triggers++;
}
EXPECT_THAT(&predictions, IncludeAnyResponesIn(std::unordered_set<string>(
fields.begin() + 1, fields.end())));
}
EXPECT_EQ(total_triggers, total_items);
EXPECT_GE(total_responses, total_triggers);
}
} // namespace
} // namespace smartreply
} // namespace custom
} // namespace tflite
int main(int argc, char **argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@@ -1,50 +0,0 @@
any chance ur free tonight Maybe not
any updates? No update yet
anything i can do to help? No, but thanks No, but thank you No, but thanks for asking
be safe. I will be Will do my best Thanks, I will
congratulations Thanks thanks Congratulations
cool, let me know when you have time Cool Yes very cool Yeah, cool
drive safe Thank you, I will Home now I will thanks
hang in there, you'll be okay Doing my best Of course we will
happy birthday! Hey, thanks
happy new year! Wish you the same Thanks and same to you
have a safe flight Thanks, love you too Safe travels
hey What is up? How it going? Can I help you?
hey, got a sec? What is up? How it going? Can I help you?
how are you doing? Great and you? I am doing great
how are you feeling Feeling okay A little better Much much better
how was your weekend? It was real good
how you doing Okay and you
hugs. So sweet Thanks sweetie Take care of yourself
i'm bored Sorry to hear that Join the club No you are not
i'm planning on coming next week. let me know if that works. Works Perfect, thanks
i'm sick Sorry to hear that
i'm so happy for you Thanks me too
i'm so hungry Haha me too
i'm sorry No I am sorry Why sorry? No worries love
i'm sorry, i'm going to have to cancel. No I am sorry Why sorry? No worries love
is there anything i can do to help? No, but thanks No, but thanks for asking
lunch? Yes coming
okay. lemme know as soon as you find out. Any more questions? It is done
omg amazing So amazing
on my way Okay see you soon Cool, see you soon Oh wow, ok
oops, mistexted. Oops Haha, oh well That was funny
safe travels. Thanks, love you too Safe travels
so sorry So sorry
sorry, i can't. No worries at all Sorry what?
sorry, i can't do saturday No worries at all
thank you so much. You are so welcome You are so very welcome You are most welcome
thanks for coming It was my pleasure
thanks, this has been great. Glad to help So happy for you
tomorrow would be ideal. Yes it would
tried calling Try again?
ugh, my flight is delayed. Ugh indeed
what are you guys up to tonight? Nothing planned
what day works best for you Any day
what do you want for dinner Your call Whatever is fine
what time will you be home? Not sure why
where are you?!? At my house
wish you were here. I wish the same Me too honey
you're amazing You are too You are amazing I am
you're marvelous You are too
you're the best. I do my best You are the best Well, I try

View File

@@ -855,16 +855,6 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
],
)
tf_http_archive(
name = "tflite_smartreply",
build_file = clean_dep("//third_party:tflite_smartreply.BUILD"),
sha256 = "8980151b85a87a9c1a3bb1ed4748119e4a85abd3cb5744d83da4d4bd0fbeef7c",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip",
],
)
tf_http_archive(
name = "tflite_ovic_testdata",
build_file = clean_dep("//third_party:tflite_ovic_testdata.BUILD"),