Clean up for smart reply demo. As folder tensorflow/lite/models/
becomes empty, delete parent folder meanwhile.
The whole project has been moved to: https://github.com/tensorflow/examples/tree/master/lite/examples/smart_reply/android (Please see smart reply web page for details: https://www.tensorflow.org/lite/models/smart_reply/overview) PiperOrigin-RevId: 288242509 Change-Id: Iabce717849f91cd01242634a090c0d936412e834
This commit is contained in:
parent
81fcd56fb3
commit
afbfc6a092
@ -1,12 +0,0 @@
|
|||||||
# Model tests
|
|
||||||
package(
|
|
||||||
default_visibility = ["//visibility:public"],
|
|
||||||
)
|
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
|
||||||
|
|
||||||
exports_files(["LICENSE"])
|
|
||||||
|
|
||||||
exports_files(glob([
|
|
||||||
"testdata/*",
|
|
||||||
]))
|
|
@ -1,117 +0,0 @@
|
|||||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
|
||||||
load("//tensorflow/lite:build_def.bzl", "gen_selected_ops", "tflite_copts")
|
|
||||||
|
|
||||||
package(
|
|
||||||
default_visibility = [
|
|
||||||
"//visibility:public",
|
|
||||||
],
|
|
||||||
licenses = ["notice"], # Apache 2.0
|
|
||||||
)
|
|
||||||
|
|
||||||
exports_files(["LICENSE"])
|
|
||||||
|
|
||||||
gen_selected_ops(
|
|
||||||
name = "smartreply_ops",
|
|
||||||
model = ["@tflite_smartreply//:smartreply.tflite"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "custom_ops",
|
|
||||||
srcs = [
|
|
||||||
"ops/extract_feature.cc",
|
|
||||||
"ops/normalize.cc",
|
|
||||||
"ops/predict.cc",
|
|
||||||
":smartreply_ops",
|
|
||||||
],
|
|
||||||
copts = tflite_copts(),
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/lite:framework",
|
|
||||||
"//tensorflow/lite:string_util",
|
|
||||||
"//tensorflow/lite/kernels:builtin_ops",
|
|
||||||
"//tensorflow/lite/kernels:kernel_util",
|
|
||||||
"//tensorflow/lite/kernels/internal:tensor",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"@com_googlesource_code_re2//:re2",
|
|
||||||
"@farmhash_archive//:farmhash",
|
|
||||||
],
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "predictor_lib",
|
|
||||||
srcs = ["predictor.cc"],
|
|
||||||
hdrs = ["predictor.h"],
|
|
||||||
copts = tflite_copts(),
|
|
||||||
deps = [
|
|
||||||
":custom_ops",
|
|
||||||
"//tensorflow/lite:framework",
|
|
||||||
"//tensorflow/lite:string_util",
|
|
||||||
"//tensorflow/lite/kernels:builtin_ops",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"@com_googlesource_code_re2//:re2",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO(b/118895218): Make this test compatible with oss.
|
|
||||||
tf_cc_test(
|
|
||||||
name = "predictor_test",
|
|
||||||
srcs = ["predictor_test.cc"],
|
|
||||||
data = [
|
|
||||||
"//tensorflow/lite/models:testdata/smartreply_samples.tsv",
|
|
||||||
"@tflite_smartreply//:smartreply.tflite",
|
|
||||||
],
|
|
||||||
tags = ["no_oss"],
|
|
||||||
deps = [
|
|
||||||
":predictor_lib",
|
|
||||||
"//tensorflow/core:test",
|
|
||||||
"//tensorflow/lite:string_util",
|
|
||||||
"//tensorflow/lite/testing:util",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"@com_google_googletest//:gtest",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_test(
|
|
||||||
name = "extract_feature_op_test",
|
|
||||||
size = "small",
|
|
||||||
srcs = ["ops/extract_feature_test.cc"],
|
|
||||||
tags = ["no_oss"],
|
|
||||||
deps = [
|
|
||||||
":custom_ops",
|
|
||||||
"//tensorflow/lite:framework",
|
|
||||||
"//tensorflow/lite/kernels:builtin_ops",
|
|
||||||
"//tensorflow/lite/kernels:test_util",
|
|
||||||
"@com_google_googletest//:gtest",
|
|
||||||
"@farmhash_archive//:farmhash",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_test(
|
|
||||||
name = "normalize_op_test",
|
|
||||||
size = "small",
|
|
||||||
srcs = ["ops/normalize_test.cc"],
|
|
||||||
tags = ["no_oss"],
|
|
||||||
deps = [
|
|
||||||
":custom_ops",
|
|
||||||
"//tensorflow/lite:framework",
|
|
||||||
"//tensorflow/lite:string_util",
|
|
||||||
"//tensorflow/lite/kernels:builtin_ops",
|
|
||||||
"//tensorflow/lite/kernels:test_util",
|
|
||||||
"@com_google_googletest//:gtest",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_test(
|
|
||||||
name = "predict_op_test",
|
|
||||||
size = "small",
|
|
||||||
srcs = ["ops/predict_test.cc"],
|
|
||||||
tags = ["no_oss"],
|
|
||||||
deps = [
|
|
||||||
":custom_ops",
|
|
||||||
"//tensorflow/lite:framework",
|
|
||||||
"//tensorflow/lite:string_util",
|
|
||||||
"//tensorflow/lite/kernels:builtin_ops",
|
|
||||||
"//tensorflow/lite/kernels:test_util",
|
|
||||||
"@com_google_googletest//:gtest",
|
|
||||||
],
|
|
||||||
)
|
|
@ -1,38 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<!--
|
|
||||||
Copyright 2017 The Android Open Source Project
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
-->
|
|
||||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
|
||||||
package="com.example.android.smartreply" >
|
|
||||||
|
|
||||||
<uses-sdk
|
|
||||||
android:minSdkVersion="15"
|
|
||||||
android:targetSdkVersion="24" />
|
|
||||||
|
|
||||||
<application android:label="TfLite SmartReply Demo">
|
|
||||||
<activity
|
|
||||||
android:name="com.example.android.smartreply.MainActivity"
|
|
||||||
android:configChanges="orientation|keyboardHidden|screenSize"
|
|
||||||
android:windowSoftInputMode="stateUnchanged|adjustPan"
|
|
||||||
android:label="TfLite SmartReply Demo"
|
|
||||||
android:screenOrientation="portrait" >
|
|
||||||
<intent-filter>
|
|
||||||
<action android:name="android.intent.action.MAIN" />
|
|
||||||
<category android:name="android.intent.category.LAUNCHER" />
|
|
||||||
</intent-filter>
|
|
||||||
</activity>
|
|
||||||
</application>
|
|
||||||
|
|
||||||
</manifest>
|
|
@ -1,68 +0,0 @@
|
|||||||
load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
|
|
||||||
load(
|
|
||||||
"//tensorflow/lite:build_def.bzl",
|
|
||||||
"tflite_copts",
|
|
||||||
"tflite_jni_binary",
|
|
||||||
)
|
|
||||||
|
|
||||||
package(
|
|
||||||
default_visibility = ["//visibility:public"],
|
|
||||||
licenses = ["notice"], # Apache 2.0
|
|
||||||
)
|
|
||||||
|
|
||||||
filegroup(
|
|
||||||
name = "assets",
|
|
||||||
srcs = [
|
|
||||||
"@tflite_smartreply//:model_files",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
android_binary(
|
|
||||||
name = "SmartReplyDemo",
|
|
||||||
srcs = glob(["java/**/*.java"]),
|
|
||||||
assets = [":assets"],
|
|
||||||
assets_dir = "",
|
|
||||||
custom_package = "com.example.android.smartreply",
|
|
||||||
manifest = "AndroidManifest.xml",
|
|
||||||
nocompress_extensions = [
|
|
||||||
".tflite",
|
|
||||||
],
|
|
||||||
resource_files = glob(["res/**"]),
|
|
||||||
tags = ["manual"],
|
|
||||||
deps = [
|
|
||||||
":smartreply_runtime",
|
|
||||||
"@androidsdk//com.android.support:support-v13-25.2.0",
|
|
||||||
"@androidsdk//com.android.support:support-v4-25.2.0",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "smartreply_runtime",
|
|
||||||
srcs = ["libsmartreply_jni.so"],
|
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
)
|
|
||||||
|
|
||||||
tflite_jni_binary(
|
|
||||||
name = "libsmartreply_jni.so",
|
|
||||||
deps = [
|
|
||||||
":smartreply_jni_lib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "smartreply_jni_lib",
|
|
||||||
srcs = [
|
|
||||||
"smartreply_jni.cc",
|
|
||||||
],
|
|
||||||
copts = tflite_copts(),
|
|
||||||
linkopts = [
|
|
||||||
"-lm",
|
|
||||||
"-ldl",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/lite:framework",
|
|
||||||
"//tensorflow/lite/java/jni",
|
|
||||||
"//tensorflow/lite/models/smartreply:predictor_lib",
|
|
||||||
],
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
@ -1,16 +0,0 @@
|
|||||||
package(
|
|
||||||
default_visibility = ["//visibility:public"],
|
|
||||||
licenses = ["notice"], # Apache 2.0
|
|
||||||
)
|
|
||||||
|
|
||||||
exports_files(glob(["*"]))
|
|
||||||
|
|
||||||
filegroup(
|
|
||||||
name = "assets_files",
|
|
||||||
srcs = glob(
|
|
||||||
["**/*"],
|
|
||||||
exclude = [
|
|
||||||
"BUILD",
|
|
||||||
],
|
|
||||||
),
|
|
||||||
)
|
|
@ -1,16 +0,0 @@
|
|||||||
Ok
|
|
||||||
Yes
|
|
||||||
No
|
|
||||||
👍
|
|
||||||
☺
|
|
||||||
😟
|
|
||||||
❤️
|
|
||||||
Lol
|
|
||||||
Thanks
|
|
||||||
Got it
|
|
||||||
Done
|
|
||||||
Nice
|
|
||||||
I don't know
|
|
||||||
What?
|
|
||||||
Why?
|
|
||||||
What's up?
|
|
@ -1,99 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
package com.example.android.smartreply;
|
|
||||||
|
|
||||||
import android.app.Activity;
|
|
||||||
import android.os.Bundle;
|
|
||||||
import android.os.Handler;
|
|
||||||
import android.util.Log;
|
|
||||||
import android.view.View;
|
|
||||||
import android.widget.Button;
|
|
||||||
import android.widget.EditText;
|
|
||||||
import android.widget.TextView;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The main (and only) activity of this demo app. Displays a text box which updates as messages are
|
|
||||||
* received.
|
|
||||||
*/
|
|
||||||
public class MainActivity extends Activity {
|
|
||||||
private static final String TAG = "SmartReplyDemo";
|
|
||||||
private SmartReplyClient client;
|
|
||||||
|
|
||||||
private Button sendButton;
|
|
||||||
private TextView messageTextView;
|
|
||||||
private EditText messageInput;
|
|
||||||
|
|
||||||
private Handler handler;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void onCreate(Bundle savedInstanceState) {
|
|
||||||
super.onCreate(savedInstanceState);
|
|
||||||
Log.v(TAG, "onCreate");
|
|
||||||
setContentView(R.layout.main_activity);
|
|
||||||
|
|
||||||
client = new SmartReplyClient(getApplicationContext());
|
|
||||||
handler = new Handler();
|
|
||||||
|
|
||||||
sendButton = (Button) findViewById(R.id.send_button);
|
|
||||||
sendButton.setOnClickListener(
|
|
||||||
(View v) -> {
|
|
||||||
send(messageInput.getText().toString());
|
|
||||||
});
|
|
||||||
|
|
||||||
messageTextView = (TextView) findViewById(R.id.message_text);
|
|
||||||
messageInput = (EditText) findViewById(R.id.message_input);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void onStart() {
|
|
||||||
super.onStart();
|
|
||||||
Log.v(TAG, "onStart");
|
|
||||||
handler.post(
|
|
||||||
() -> {
|
|
||||||
client.loadModel();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void onStop() {
|
|
||||||
super.onStop();
|
|
||||||
Log.v(TAG, "onStop");
|
|
||||||
handler.post(
|
|
||||||
() -> {
|
|
||||||
client.unloadModel();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private void send(final String message) {
|
|
||||||
handler.post(
|
|
||||||
() -> {
|
|
||||||
messageTextView.append("Input: " + message + "\n");
|
|
||||||
|
|
||||||
SmartReply[] ans = client.predict(new String[] {message});
|
|
||||||
for (SmartReply reply : ans) {
|
|
||||||
appendMessage("Reply: " + reply.getText());
|
|
||||||
}
|
|
||||||
appendMessage("------");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private void appendMessage(final String message) {
|
|
||||||
handler.post(
|
|
||||||
() -> {
|
|
||||||
messageTextView.append(message + "\n");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,44 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
package com.example.android.smartreply;
|
|
||||||
|
|
||||||
import android.support.annotation.Keep;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* SmartReply contains predicted message, and confidence.
|
|
||||||
*
|
|
||||||
* <p>NOTE: this class used by JNI, class name and constructor should not be obfuscated.
|
|
||||||
*/
|
|
||||||
@Keep
|
|
||||||
public class SmartReply {
|
|
||||||
|
|
||||||
private final String text;
|
|
||||||
private final float score;
|
|
||||||
|
|
||||||
@Keep
|
|
||||||
public SmartReply(String text, float score) {
|
|
||||||
this.text = text;
|
|
||||||
this.score = score;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getText() {
|
|
||||||
return text;
|
|
||||||
}
|
|
||||||
|
|
||||||
public float getScore() {
|
|
||||||
return score;
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,131 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
package com.example.android.smartreply;
|
|
||||||
|
|
||||||
import android.content.Context;
|
|
||||||
import android.content.res.AssetFileDescriptor;
|
|
||||||
import android.support.annotation.Keep;
|
|
||||||
import android.support.annotation.WorkerThread;
|
|
||||||
import android.util.Log;
|
|
||||||
import java.io.BufferedReader;
|
|
||||||
import java.io.FileInputStream;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.InputStreamReader;
|
|
||||||
import java.nio.MappedByteBuffer;
|
|
||||||
import java.nio.channels.FileChannel;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/** Interface to load TfLite model and provide predictions. */
|
|
||||||
public class SmartReplyClient implements AutoCloseable {
|
|
||||||
private static final String TAG = "SmartReplyDemo";
|
|
||||||
private static final String MODEL_PATH = "smartreply.tflite";
|
|
||||||
private static final String BACKOFF_PATH = "backoff_response.txt";
|
|
||||||
private static final String JNI_LIB = "smartreply_jni";
|
|
||||||
|
|
||||||
private final Context context;
|
|
||||||
private long storage;
|
|
||||||
private MappedByteBuffer model;
|
|
||||||
|
|
||||||
private volatile boolean isLibraryLoaded;
|
|
||||||
|
|
||||||
public SmartReplyClient(Context context) {
|
|
||||||
this.context = context;
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean isLoaded() {
|
|
||||||
return storage != 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@WorkerThread
|
|
||||||
public synchronized void loadModel() {
|
|
||||||
if (!isLibraryLoaded) {
|
|
||||||
try {
|
|
||||||
System.loadLibrary(JNI_LIB);
|
|
||||||
isLibraryLoaded = true;
|
|
||||||
} catch (Exception e) {
|
|
||||||
Log.e(TAG, "Failed to load prebuilt smartreply_jni lib", e);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
model = loadModelFile();
|
|
||||||
String[] backoff = loadBackoffList();
|
|
||||||
storage = loadJNI(model, backoff);
|
|
||||||
} catch (IOException e) {
|
|
||||||
Log.e(TAG, "Fail to load model", e);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@WorkerThread
|
|
||||||
public synchronized SmartReply[] predict(String[] input) {
|
|
||||||
if (storage != 0) {
|
|
||||||
return predictJNI(storage, input);
|
|
||||||
} else {
|
|
||||||
return new SmartReply[] {};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@WorkerThread
|
|
||||||
public synchronized void unloadModel() {
|
|
||||||
close();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public synchronized void close() {
|
|
||||||
if (storage != 0) {
|
|
||||||
unloadJNI(storage);
|
|
||||||
storage = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private MappedByteBuffer loadModelFile() throws IOException {
|
|
||||||
try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH);
|
|
||||||
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
|
|
||||||
FileChannel fileChannel = inputStream.getChannel();
|
|
||||||
long startOffset = fileDescriptor.getStartOffset();
|
|
||||||
long declaredLength = fileDescriptor.getDeclaredLength();
|
|
||||||
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private String[] loadBackoffList() throws IOException {
|
|
||||||
List<String> labelList = new ArrayList<String>();
|
|
||||||
try (BufferedReader reader =
|
|
||||||
new BufferedReader(new InputStreamReader(context.getAssets().open(BACKOFF_PATH)))) {
|
|
||||||
String line;
|
|
||||||
while ((line = reader.readLine()) != null) {
|
|
||||||
if (!line.isEmpty()) {
|
|
||||||
labelList.add(line);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
String[] ans = new String[labelList.size()];
|
|
||||||
labelList.toArray(ans);
|
|
||||||
return ans;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Keep
|
|
||||||
private native long loadJNI(MappedByteBuffer buffer, String[] backoff);
|
|
||||||
|
|
||||||
@Keep
|
|
||||||
private native SmartReply[] predictJNI(long storage, String[] text);
|
|
||||||
|
|
||||||
@Keep
|
|
||||||
private native void unloadJNI(long storage);
|
|
||||||
}
|
|
@ -1,44 +0,0 @@
|
|||||||
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
|
|
||||||
xmlns:tools="http://schemas.android.com/tools"
|
|
||||||
android:layout_width="match_parent"
|
|
||||||
android:layout_height="match_parent"
|
|
||||||
android:orientation="vertical">
|
|
||||||
|
|
||||||
<LinearLayout
|
|
||||||
android:layout_width="fill_parent"
|
|
||||||
android:layout_height="0dp"
|
|
||||||
android:padding="5dip"
|
|
||||||
android:layout_weight="3">
|
|
||||||
|
|
||||||
<TextView
|
|
||||||
android:id="@+id/message_text"
|
|
||||||
android:layout_width="fill_parent"
|
|
||||||
android:layout_height="fill_parent"
|
|
||||||
android:scrollbars="vertical"
|
|
||||||
android:gravity="bottom"/>
|
|
||||||
</LinearLayout>
|
|
||||||
|
|
||||||
<LinearLayout
|
|
||||||
android:layout_width="fill_parent"
|
|
||||||
android:layout_height="0dp"
|
|
||||||
android:padding="5dip"
|
|
||||||
android:layout_weight="1">
|
|
||||||
|
|
||||||
<EditText
|
|
||||||
android:id="@+id/message_input"
|
|
||||||
android:layout_width="0dp"
|
|
||||||
android:layout_height="fill_parent"
|
|
||||||
android:layout_weight="6"
|
|
||||||
android:scrollbars="vertical"
|
|
||||||
android:hint="Enter Text"
|
|
||||||
android:gravity="top"
|
|
||||||
android:inputType="text"/>
|
|
||||||
<Button
|
|
||||||
android:id="@+id/send_button"
|
|
||||||
android:layout_width="0dp"
|
|
||||||
android:layout_height="fill_parent"
|
|
||||||
android:layout_weight="2"
|
|
||||||
android:text="Send" />
|
|
||||||
</LinearLayout>
|
|
||||||
|
|
||||||
</LinearLayout>
|
|
@ -1,129 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include <jni.h>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "tensorflow/lite/model.h"
|
|
||||||
#include "tensorflow/lite/models/smartreply/predictor.h"
|
|
||||||
|
|
||||||
const char kIllegalStateException[] = "java/lang/IllegalStateException";
|
|
||||||
|
|
||||||
using tflite::custom::smartreply::GetSegmentPredictions;
|
|
||||||
using tflite::custom::smartreply::PredictorResponse;
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T CheckNotNull(JNIEnv* env, T&& t) {
|
|
||||||
if (t == nullptr) {
|
|
||||||
env->ThrowNew(env->FindClass(kIllegalStateException), "");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return std::forward<T>(t);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::string> jniStringArrayToVector(JNIEnv* env,
|
|
||||||
jobjectArray string_array) {
|
|
||||||
int count = env->GetArrayLength(string_array);
|
|
||||||
std::vector<std::string> result;
|
|
||||||
for (int i = 0; i < count; i++) {
|
|
||||||
auto jstr =
|
|
||||||
reinterpret_cast<jstring>(env->GetObjectArrayElement(string_array, i));
|
|
||||||
const char* raw_str = env->GetStringUTFChars(jstr, JNI_FALSE);
|
|
||||||
result.emplace_back(std::string(raw_str));
|
|
||||||
env->ReleaseStringUTFChars(jstr, raw_str);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct JNIStorage {
|
|
||||||
std::vector<std::string> backoff_list;
|
|
||||||
std::unique_ptr<::tflite::FlatBufferModel> model;
|
|
||||||
};
|
|
||||||
|
|
||||||
extern "C" JNIEXPORT jlong JNICALL
|
|
||||||
Java_com_example_android_smartreply_SmartReplyClient_loadJNI(
|
|
||||||
JNIEnv* env, jobject thiz, jobject model_buffer,
|
|
||||||
jobjectArray backoff_list) {
|
|
||||||
const char* buf =
|
|
||||||
static_cast<char*>(env->GetDirectBufferAddress(model_buffer));
|
|
||||||
jlong capacity = env->GetDirectBufferCapacity(model_buffer);
|
|
||||||
|
|
||||||
JNIStorage* storage = new JNIStorage;
|
|
||||||
storage->model = tflite::FlatBufferModel::BuildFromBuffer(
|
|
||||||
buf, static_cast<size_t>(capacity));
|
|
||||||
storage->backoff_list = jniStringArrayToVector(env, backoff_list);
|
|
||||||
|
|
||||||
if (!storage->model) {
|
|
||||||
delete storage;
|
|
||||||
env->ThrowNew(env->FindClass(kIllegalStateException), "");
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
return reinterpret_cast<jlong>(storage);
|
|
||||||
}
|
|
||||||
|
|
||||||
extern "C" JNIEXPORT jobjectArray JNICALL
|
|
||||||
Java_com_example_android_smartreply_SmartReplyClient_predictJNI(
|
|
||||||
JNIEnv* env, jobject /*thiz*/, jlong storage_ptr, jobjectArray input_text) {
|
|
||||||
// Predict
|
|
||||||
if (storage_ptr == 0) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
JNIStorage* storage = reinterpret_cast<JNIStorage*>(storage_ptr);
|
|
||||||
if (storage == nullptr) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
std::vector<PredictorResponse> responses;
|
|
||||||
GetSegmentPredictions(jniStringArrayToVector(env, input_text),
|
|
||||||
*storage->model, {storage->backoff_list}, &responses);
|
|
||||||
|
|
||||||
// Create a SmartReply[] to return back to Java
|
|
||||||
jclass smart_reply_class = CheckNotNull(
|
|
||||||
env, env->FindClass("com/example/android/smartreply/SmartReply"));
|
|
||||||
if (env->ExceptionCheck()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
jmethodID smart_reply_ctor = CheckNotNull(
|
|
||||||
env,
|
|
||||||
env->GetMethodID(smart_reply_class, "<init>", "(Ljava/lang/String;F)V"));
|
|
||||||
if (env->ExceptionCheck()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
jobjectArray array = CheckNotNull(
|
|
||||||
env, env->NewObjectArray(responses.size(), smart_reply_class, nullptr));
|
|
||||||
if (env->ExceptionCheck()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
for (int i = 0; i < responses.size(); i++) {
|
|
||||||
jstring text =
|
|
||||||
CheckNotNull(env, env->NewStringUTF(responses[i].GetText().data()));
|
|
||||||
if (env->ExceptionCheck()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
jobject reply = env->NewObject(smart_reply_class, smart_reply_ctor, text,
|
|
||||||
responses[i].GetScore());
|
|
||||||
env->SetObjectArrayElement(array, i, reply);
|
|
||||||
}
|
|
||||||
return array;
|
|
||||||
}
|
|
||||||
|
|
||||||
extern "C" JNIEXPORT void JNICALL
|
|
||||||
Java_com_example_android_smartreply_SmartReplyClient_unloadJNI(
|
|
||||||
JNIEnv* env, jobject thiz, jlong storage_ptr) {
|
|
||||||
if (storage_ptr != 0) {
|
|
||||||
JNIStorage* storage = reinterpret_cast<JNIStorage*>(storage_ptr);
|
|
||||||
delete storage;
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,146 +0,0 @@
|
|||||||
# Smart Reply Model
|
|
||||||
|
|
||||||
## What is On-Device Smart Reply Model?
|
|
||||||
|
|
||||||
Smart Replies are contextually relevant, one-touch responses that help the user
|
|
||||||
to reply to an incoming text message (or email) efficiently and effortlessly.
|
|
||||||
Smart Replies have been highly successful across several Google products
|
|
||||||
including
|
|
||||||
[Gmail](https://www.blog.google/products/gmail/save-time-with-smart-reply-in-gmail/),
|
|
||||||
[Inbox](https://www.blog.google/products/gmail/computer-respond-to-this-email/)
|
|
||||||
and
|
|
||||||
[Allo](https://blog.google/products/allo/google-allo-smarter-messaging-app/).
|
|
||||||
|
|
||||||
The On-device Smart Reply model is targeted towards text chat use cases. It has
|
|
||||||
a completely different architecture from its cloud-based counterparts, and is
|
|
||||||
built specifically for memory constraints devices such as phones & watches. It
|
|
||||||
has been successfully used to provide [Smart Replies on Android
|
|
||||||
Wear](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html)
|
|
||||||
to all first- & third-party apps.
|
|
||||||
|
|
||||||
The on-device model comes with several benefits. It is:
|
|
||||||
|
|
||||||
* **Faster**: The model resides on the device and does not require internet
|
|
||||||
connectivity. Thus, the inference is very fast and has an average latency of
|
|
||||||
only a few milliseconds.
|
|
||||||
* **Resource efficient**: The model has a small memory footprint on
|
|
||||||
the device.
|
|
||||||
* **Privacy-friendly**: The user data never leaves the device and this
|
|
||||||
eliminates any privacy restrictions.
|
|
||||||
|
|
||||||
A caveat, though, is that the on-device model has lower triggering rate than its
|
|
||||||
cloud counterparts (triggering rate is the percentage of times the model
|
|
||||||
suggests a response for an incoming message).
|
|
||||||
|
|
||||||
## When to use this Model?
|
|
||||||
|
|
||||||
The On-Device Smart Reply model is aimed towards improving the messaging
|
|
||||||
experience for day-to-day conversational chat messages. We recommend using this
|
|
||||||
model for similar use cases. Some sample messages on which the model does well
|
|
||||||
are provided in this [tsv
|
|
||||||
file](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/testdata/smartreply_samples.tsv)
|
|
||||||
for reference. The file format is:
|
|
||||||
|
|
||||||
```
|
|
||||||
{incoming_message smart_reply1 [smart_reply2] [smart_reply3]}
|
|
||||||
```
|
|
||||||
|
|
||||||
For the current model, we see a triggering rate of about 30-40% for messages
|
|
||||||
which are similar to those provided in the tsv file above.
|
|
||||||
|
|
||||||
In case the model does not trigger any response, the system falls back to
|
|
||||||
suggesting replies from a fixed back-off set that was compiled from popular
|
|
||||||
response intents observed in chat conversations. Some of the fallback responses
|
|
||||||
are `Ok, Yes, No, 👍, ☺`.
|
|
||||||
|
|
||||||
The model can only be used for inference at this time (i.e. it cannot be custom
|
|
||||||
trained). If you are interested to know how the model was trained, please refer
|
|
||||||
to this [blog
|
|
||||||
post](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html)
|
|
||||||
and [research paper](https://arxiv.org/pdf/1708.00630).
|
|
||||||
|
|
||||||
## How to use this Model?
|
|
||||||
|
|
||||||
We have provided a pre-built demo APK that you can download, install and test on
|
|
||||||
your phone
|
|
||||||
([demo APK here](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/SmartReplyDemo.apk)).
|
|
||||||
|
|
||||||
The On-Device Smart Reply demo App works in the following way:
|
|
||||||
|
|
||||||
1. Android app links to the JNI binary with a predictor library.
|
|
||||||
|
|
||||||
2. In the predictor library, `GetSegmentPredictions` is called with a list of input
|
|
||||||
strings.
|
|
||||||
|
|
||||||
2.1 The input string can be 1-3 most recent messages of the conversations in
|
|
||||||
form of string vector. The model will run on these input sentences and
|
|
||||||
provide Smart Replies corresponding to them.
|
|
||||||
|
|
||||||
2.2 The function performs some preprocessing on input data which includes:
|
|
||||||
|
|
||||||
* Sentence splitting: The input message will be split into sentences if
|
|
||||||
message has more than one sentence. Eg: a message like “How are you?
|
|
||||||
Want to grab lunch?” will be broken down into 2 different sentences.
|
|
||||||
* Normalization: The individual sentences will be normalized by converting
|
|
||||||
them into lower cases, removing unnecessary punctuations, etc. Eg: “how
|
|
||||||
are you????” will be converted to “how are you?” (refer for NORMALIZE op
|
|
||||||
for more details).
|
|
||||||
|
|
||||||
The input string content will be converted to tensors.
|
|
||||||
|
|
||||||
2.3 The function then runs the prediction model on the input tensors.
|
|
||||||
|
|
||||||
2.4 The function also performs some post-processing which includes
|
|
||||||
aggregating the model predictions for the input sentences from 2.2 and
|
|
||||||
returning the appropriate responses.
|
|
||||||
|
|
||||||
3. Finally, it gets response(s) from `std::vector<PredictorResponse>`, and
|
|
||||||
returns back to Android app. Responses are sorted in descending order of
|
|
||||||
confidence score.
|
|
||||||
|
|
||||||
## Ops and Functionality Supported
|
|
||||||
|
|
||||||
Following are the ops supported for using On-Device Smart Reply model:
|
|
||||||
|
|
||||||
* **NORMALIZE**
|
|
||||||
|
|
||||||
This is a custom op which normalizes the sentences by:
|
|
||||||
|
|
||||||
* Converting all sentences into lower case.
|
|
||||||
* Removing unnecessary punctuations (eg: “how are you????” → “how are
|
|
||||||
you?”).
|
|
||||||
* Expanding sentences wherever necessary (eg: “ I’m home” → “I am home”).
|
|
||||||
|
|
||||||
* **SKIP_GRAM**
|
|
||||||
|
|
||||||
This is an op inside TensorFlow Lite that converts sentences into a list of
|
|
||||||
skip grams. The configurable parameters are `ngram_size` and
|
|
||||||
`max_skip_size`. For the model provided, the values for these parameters are
|
|
||||||
set to 3 & 2 respectively.
|
|
||||||
|
|
||||||
* **EXTRACT_FEATURES**
|
|
||||||
|
|
||||||
This is a custom op that hashes skip grams to features represented as
|
|
||||||
integers. Longer skip-grams are allocated higher weights.
|
|
||||||
|
|
||||||
* **LSH_PROJECTION**
|
|
||||||
|
|
||||||
This is an op inside TensorFlow Lite that projects input features to a
|
|
||||||
corresponding bit vector space using Locality Sensitive Hashing (LSH).
|
|
||||||
|
|
||||||
* **PREDICT**
|
|
||||||
|
|
||||||
This is a custom op that runs the input features through the projection
|
|
||||||
model (details [here](https://arxiv.org/pdf/1708.00630.pdf)), computes the
|
|
||||||
appropriate response labels along with weights for the projected features,
|
|
||||||
and aggregates the response labels and weights together.
|
|
||||||
|
|
||||||
* **HASHTABLE_LOOKUP**
|
|
||||||
|
|
||||||
This is an op inside TensorFlow Lite that uses label id from predict op and
|
|
||||||
looks up the response text from the given label id.
|
|
||||||
|
|
||||||
## Further Information
|
|
||||||
|
|
||||||
* Open source code
|
|
||||||
[here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/models/smartreply/).
|
|
@ -1,121 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
// Convert a list of strings to integers via hashing.
|
|
||||||
// Input:
|
|
||||||
// Input[0]: A list of ngrams. string[num of input]
|
|
||||||
//
|
|
||||||
// Output:
|
|
||||||
// Output[0]: Hashed features. int32[num of input]
|
|
||||||
// Output[1]: Weights. float[num of input]
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <map>
|
|
||||||
|
|
||||||
#include "tensorflow/lite/context.h"
|
|
||||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
|
||||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
|
||||||
#include "tensorflow/lite/string_util.h"
|
|
||||||
#include <farmhash.h>
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
namespace ops {
|
|
||||||
namespace custom {
|
|
||||||
|
|
||||||
namespace extract {
|
|
||||||
|
|
||||||
static const int kMaxDimension = 1000000;
|
|
||||||
static const std::vector<string> kBlacklistNgram = {"<S>", "<E>", "<S> <E>"};
|
|
||||||
|
|
||||||
bool Equals(const string& x, const tflite::StringRef& strref) {
|
|
||||||
if (strref.len != x.length()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (strref.len > 0) {
|
|
||||||
int r = memcmp(strref.str, x.data(), strref.len);
|
|
||||||
return r == 0;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsValidNgram(const tflite::StringRef& strref) {
|
|
||||||
for (const auto& s : kBlacklistNgram) {
|
|
||||||
if (Equals(s, strref)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|
||||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
|
||||||
int dim = input->dims->data[0];
|
|
||||||
if (dim == 0) {
|
|
||||||
// TFLite non-string output should have size greater than 0.
|
|
||||||
dim = 1;
|
|
||||||
}
|
|
||||||
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteString);
|
|
||||||
TfLiteIntArray* outputSize1 = TfLiteIntArrayCreate(1);
|
|
||||||
TfLiteIntArray* outputSize2 = TfLiteIntArrayCreate(1);
|
|
||||||
outputSize1->data[0] = dim;
|
|
||||||
outputSize2->data[0] = dim;
|
|
||||||
context->ResizeTensor(context, GetOutput(context, node, 0), outputSize1);
|
|
||||||
context->ResizeTensor(context, GetOutput(context, node, 1), outputSize2);
|
|
||||||
return kTfLiteOk;
|
|
||||||
}
|
|
||||||
|
|
||||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|
||||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
|
||||||
int num_strings = tflite::GetStringCount(input);
|
|
||||||
TfLiteTensor* label = GetOutput(context, node, 0);
|
|
||||||
TfLiteTensor* weight = GetOutput(context, node, 1);
|
|
||||||
|
|
||||||
int32_t* label_data = GetTensorData<int32_t>(label);
|
|
||||||
float* weight_data = GetTensorData<float>(weight);
|
|
||||||
|
|
||||||
std::map<int64_t, int> feature_id_counts;
|
|
||||||
for (int i = 0; i < num_strings; i++) {
|
|
||||||
// Use fingerprint of feature name as id.
|
|
||||||
auto strref = tflite::GetString(input, i);
|
|
||||||
if (!IsValidNgram(strref)) {
|
|
||||||
label_data[i] = 0;
|
|
||||||
weight_data[i] = 0;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t feature_id =
|
|
||||||
::util::Fingerprint64(strref.str, strref.len) % kMaxDimension;
|
|
||||||
label_data[i] = static_cast<int32_t>(feature_id);
|
|
||||||
weight_data[i] = std::count(strref.str, strref.str + strref.len, ' ') + 1;
|
|
||||||
}
|
|
||||||
// Explicitly set an empty result to make preceding ops run.
|
|
||||||
if (num_strings == 0) {
|
|
||||||
label_data[0] = 0;
|
|
||||||
weight_data[0] = 0;
|
|
||||||
}
|
|
||||||
return kTfLiteOk;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace extract
|
|
||||||
|
|
||||||
TfLiteRegistration* Register_EXTRACT_FEATURES() {
|
|
||||||
static TfLiteRegistration r = {nullptr, nullptr, extract::Prepare,
|
|
||||||
extract::Eval};
|
|
||||||
return &r;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace custom
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace tflite
|
|
@ -1,100 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
|
||||||
#include "tensorflow/lite/interpreter.h"
|
|
||||||
#include "tensorflow/lite/kernels/register.h"
|
|
||||||
#include "tensorflow/lite/kernels/test_util.h"
|
|
||||||
#include "tensorflow/lite/model.h"
|
|
||||||
#include <farmhash.h>
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
|
|
||||||
namespace ops {
|
|
||||||
namespace custom {
|
|
||||||
TfLiteRegistration* Register_EXTRACT_FEATURES();
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
using ::testing::ElementsAre;
|
|
||||||
|
|
||||||
class ExtractFeatureOpModel : public SingleOpModel {
|
|
||||||
public:
|
|
||||||
explicit ExtractFeatureOpModel(const std::vector<string>& input) {
|
|
||||||
input_ = AddInput(TensorType_STRING);
|
|
||||||
signature_ = AddOutput(TensorType_INT32);
|
|
||||||
weight_ = AddOutput(TensorType_FLOAT32);
|
|
||||||
|
|
||||||
SetCustomOp("ExtractFeatures", {}, Register_EXTRACT_FEATURES);
|
|
||||||
BuildInterpreter({{static_cast<int>(input.size())}});
|
|
||||||
PopulateStringTensor(input_, input);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int> GetSignature() { return ExtractVector<int>(signature_); }
|
|
||||||
std::vector<float> GetWeight() { return ExtractVector<float>(weight_); }
|
|
||||||
|
|
||||||
private:
|
|
||||||
int input_;
|
|
||||||
int signature_;
|
|
||||||
int weight_;
|
|
||||||
};
|
|
||||||
|
|
||||||
int CalcFeature(const string& str) {
|
|
||||||
return ::util::Fingerprint64(str) % 1000000;
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ExtractFeatureOpTest, RegularInput) {
|
|
||||||
ExtractFeatureOpModel m({"<S>", "<S> Hi", "Hi", "Hi !", "!", "! <E>", "<E>"});
|
|
||||||
m.Invoke();
|
|
||||||
EXPECT_THAT(m.GetSignature(),
|
|
||||||
ElementsAre(0, CalcFeature("<S> Hi"), CalcFeature("Hi"),
|
|
||||||
CalcFeature("Hi !"), CalcFeature("!"),
|
|
||||||
CalcFeature("! <E>"), 0));
|
|
||||||
EXPECT_THAT(m.GetWeight(), ElementsAre(0, 2, 1, 2, 1, 2, 0));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ExtractFeatureOpTest, OneInput) {
|
|
||||||
ExtractFeatureOpModel m({"Hi"});
|
|
||||||
m.Invoke();
|
|
||||||
EXPECT_THAT(m.GetSignature(), ElementsAre(CalcFeature("Hi")));
|
|
||||||
EXPECT_THAT(m.GetWeight(), ElementsAre(1));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ExtractFeatureOpTest, ZeroInput) {
|
|
||||||
ExtractFeatureOpModel m({});
|
|
||||||
m.Invoke();
|
|
||||||
EXPECT_THAT(m.GetSignature(), ElementsAre(0));
|
|
||||||
EXPECT_THAT(m.GetWeight(), ElementsAre(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(ExtractFeatureOpTest, AllBlacklistInput) {
|
|
||||||
ExtractFeatureOpModel m({"<S>", "<E>"});
|
|
||||||
m.Invoke();
|
|
||||||
EXPECT_THAT(m.GetSignature(), ElementsAre(0, 0));
|
|
||||||
EXPECT_THAT(m.GetWeight(), ElementsAre(0, 0));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace custom
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace tflite
|
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
::tflite::LogToStderr();
|
|
||||||
::testing::InitGoogleTest(&argc, argv);
|
|
||||||
return RUN_ALL_TESTS();
|
|
||||||
}
|
|
@ -1,108 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
// Normalize the string input.
|
|
||||||
//
|
|
||||||
// Input:
|
|
||||||
// Input[0]: One sentence. string[1]
|
|
||||||
//
|
|
||||||
// Output:
|
|
||||||
// Output[0]: Normalized sentence. string[1]
|
|
||||||
//
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "absl/strings/str_cat.h"
|
|
||||||
#include "absl/strings/strip.h"
|
|
||||||
#include "re2/re2.h"
|
|
||||||
#include "tensorflow/lite/context.h"
|
|
||||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
|
||||||
#include "tensorflow/lite/string_util.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
namespace ops {
|
|
||||||
namespace custom {
|
|
||||||
|
|
||||||
namespace normalize {
|
|
||||||
|
|
||||||
// Predictor transforms.
|
|
||||||
const char kPunctuationsRegex[] = "[.*()\"]";
|
|
||||||
|
|
||||||
const std::map<string, string>* kRegexTransforms =
|
|
||||||
new std::map<string, string>({
|
|
||||||
{"([^\\s]+)n't", "\\1 not"},
|
|
||||||
{"([^\\s]+)'nt", "\\1 not"},
|
|
||||||
{"([^\\s]+)'ll", "\\1 will"},
|
|
||||||
{"([^\\s]+)'re", "\\1 are"},
|
|
||||||
{"([^\\s]+)'ve", "\\1 have"},
|
|
||||||
{"i'm", "i am"},
|
|
||||||
});
|
|
||||||
|
|
||||||
static const char kStartToken[] = "<S>";
|
|
||||||
static const char kEndToken[] = "<E>";
|
|
||||||
static const int32_t kMaxInputChars = 300;
|
|
||||||
|
|
||||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|
||||||
tflite::StringRef input = tflite::GetString(GetInput(context, node, 0), 0);
|
|
||||||
|
|
||||||
string result(absl::AsciiStrToLower(absl::string_view(input.str, input.len)));
|
|
||||||
absl::StripAsciiWhitespace(&result);
|
|
||||||
// Do not remove commas, semi-colons or colons from the sentences as they can
|
|
||||||
// indicate the beginning of a new clause.
|
|
||||||
RE2::GlobalReplace(&result, kPunctuationsRegex, "");
|
|
||||||
RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)([\\s,;:/])",
|
|
||||||
"\\1\\2");
|
|
||||||
RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)$", "\\1");
|
|
||||||
for (auto iter = kRegexTransforms->begin(); iter != kRegexTransforms->end();
|
|
||||||
iter++) {
|
|
||||||
RE2::GlobalReplace(&result, iter->first, iter->second);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Treat questions & interjections as special cases.
|
|
||||||
RE2::GlobalReplace(&result, "([?])+", "\\1");
|
|
||||||
RE2::GlobalReplace(&result, "([!])+", "\\1");
|
|
||||||
RE2::GlobalReplace(&result, "([^?!]+)([?!])", "\\1 \\2 ");
|
|
||||||
RE2::GlobalReplace(&result, "([?!])([?!])", "\\1 \\2");
|
|
||||||
|
|
||||||
RE2::GlobalReplace(&result, "[\\s,:;\\-&'\"]+$", "");
|
|
||||||
RE2::GlobalReplace(&result, "^[\\s,:;\\-&'\"]+", "");
|
|
||||||
absl::StripAsciiWhitespace(&result);
|
|
||||||
|
|
||||||
// Add start and end token.
|
|
||||||
// Truncate input to maximum allowed size.
|
|
||||||
if (result.length() <= kMaxInputChars) {
|
|
||||||
absl::StrAppend(&result, " ", kEndToken);
|
|
||||||
} else {
|
|
||||||
result = result.substr(0, kMaxInputChars);
|
|
||||||
}
|
|
||||||
result = absl::StrCat(kStartToken, " ", result);
|
|
||||||
|
|
||||||
tflite::DynamicBuffer buf;
|
|
||||||
buf.AddString(result.data(), result.length());
|
|
||||||
buf.WriteToTensorAsVector(GetOutput(context, node, 0));
|
|
||||||
return kTfLiteOk;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace normalize
|
|
||||||
|
|
||||||
TfLiteRegistration* Register_NORMALIZE() {
|
|
||||||
static TfLiteRegistration r = {nullptr, nullptr, nullptr, normalize::Eval};
|
|
||||||
return &r;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace custom
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace tflite
|
|
@ -1,90 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
|
||||||
#include "tensorflow/lite/interpreter.h"
|
|
||||||
#include "tensorflow/lite/kernels/register.h"
|
|
||||||
#include "tensorflow/lite/kernels/test_util.h"
|
|
||||||
#include "tensorflow/lite/model.h"
|
|
||||||
#include "tensorflow/lite/string_util.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
|
|
||||||
namespace ops {
|
|
||||||
namespace custom {
|
|
||||||
TfLiteRegistration* Register_NORMALIZE();
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
using ::testing::ElementsAreArray;
|
|
||||||
|
|
||||||
class NormalizeOpModel : public SingleOpModel {
|
|
||||||
public:
|
|
||||||
explicit NormalizeOpModel(const string& input) {
|
|
||||||
input_ = AddInput(TensorType_STRING);
|
|
||||||
output_ = AddOutput(TensorType_STRING);
|
|
||||||
|
|
||||||
SetCustomOp("Normalize", {}, Register_NORMALIZE);
|
|
||||||
BuildInterpreter({{static_cast<int>(input.size())}});
|
|
||||||
PopulateStringTensor(input_, {input});
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<string> GetStringOutput() {
|
|
||||||
TfLiteTensor* output = interpreter_->tensor(output_);
|
|
||||||
int num = GetStringCount(output);
|
|
||||||
std::vector<string> result(num);
|
|
||||||
for (int i = 0; i < num; i++) {
|
|
||||||
auto ref = GetString(output, i);
|
|
||||||
result[i] = string(ref.str, ref.len);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
int input_;
|
|
||||||
int output_;
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST(NormalizeOpTest, RegularInput) {
|
|
||||||
NormalizeOpModel m("I'm good; you're welcome");
|
|
||||||
m.Invoke();
|
|
||||||
EXPECT_THAT(m.GetStringOutput(),
|
|
||||||
ElementsAreArray({"<S> i am good; you are welcome <E>"}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(NormalizeOpTest, OneInput) {
|
|
||||||
NormalizeOpModel m("Hi!!!!");
|
|
||||||
m.Invoke();
|
|
||||||
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"<S> hi ! <E>"}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(NormalizeOpTest, EmptyInput) {
|
|
||||||
NormalizeOpModel m("");
|
|
||||||
m.Invoke();
|
|
||||||
EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({"<S> <E>"}));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace custom
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace tflite
|
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
::tflite::LogToStderr();
|
|
||||||
::testing::InitGoogleTest(&argc, argv);
|
|
||||||
return RUN_ALL_TESTS();
|
|
||||||
}
|
|
@ -1,176 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
// Lookup projected hash signatures in Predictor model,
|
|
||||||
// output predicted labels and weights in decreasing order.
|
|
||||||
//
|
|
||||||
// Input:
|
|
||||||
// Input[0]: A list of hash signatures. int32[num of input]
|
|
||||||
// Input[1]: Hash signature keys in the model. int32[keys of model]
|
|
||||||
// Input[2]: Labels in the model. int32[keys of model, item per entry]
|
|
||||||
// Input[3]: Weights in the model. float[keys of model, item per entry]
|
|
||||||
//
|
|
||||||
// Output:
|
|
||||||
// Output[0]: Predicted labels. int32[num of output]
|
|
||||||
// Output[1]: Predicted weights. float[num of output]
|
|
||||||
//
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <cstdio>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "tensorflow/lite/context.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
namespace ops {
|
|
||||||
namespace custom {
|
|
||||||
|
|
||||||
namespace predict {
|
|
||||||
|
|
||||||
struct PredictOption {
|
|
||||||
int32_t num_output;
|
|
||||||
float weight_threshold;
|
|
||||||
|
|
||||||
static PredictOption* Cast(void* ptr) {
|
|
||||||
return reinterpret_cast<PredictOption*>(ptr);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
bool WeightGreater(const std::pair<int32_t, float>& a,
|
|
||||||
const std::pair<int32_t, float>& b) {
|
|
||||||
return a.second > b.second;
|
|
||||||
}
|
|
||||||
|
|
||||||
void* Init(TfLiteContext* context, const char* custom_option, size_t length) {
|
|
||||||
if (custom_option == nullptr || length != sizeof(PredictOption)) {
|
|
||||||
fprintf(stderr, "No Custom option set\n");
|
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
PredictOption* option = new PredictOption;
|
|
||||||
int offset = 0;
|
|
||||||
option->num_output =
|
|
||||||
*reinterpret_cast<const int32_t*>(custom_option + offset);
|
|
||||||
offset += sizeof(int32_t);
|
|
||||||
option->weight_threshold =
|
|
||||||
*reinterpret_cast<const float*>(custom_option + offset);
|
|
||||||
return reinterpret_cast<void*>(option);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Free(TfLiteContext* context, void* buffer) {
|
|
||||||
delete PredictOption::Cast(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|
||||||
TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
|
|
||||||
TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
|
|
||||||
|
|
||||||
TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]];
|
|
||||||
TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]];
|
|
||||||
TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]];
|
|
||||||
TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]];
|
|
||||||
TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
|
|
||||||
TF_LITE_ENSURE_EQ(context, model_key->type, kTfLiteInt32);
|
|
||||||
TF_LITE_ENSURE_EQ(context, model_label->type, kTfLiteInt32);
|
|
||||||
TF_LITE_ENSURE_EQ(context, model_weight->type, kTfLiteFloat32);
|
|
||||||
TF_LITE_ENSURE_EQ(context, lookup->dims->size, 1);
|
|
||||||
TF_LITE_ENSURE_EQ(context, model_key->dims->size, 1);
|
|
||||||
TF_LITE_ENSURE_EQ(context, model_label->dims->size, 2);
|
|
||||||
TF_LITE_ENSURE_EQ(context, model_weight->dims->size, 2);
|
|
||||||
TF_LITE_ENSURE_EQ(context, model_key->dims->data[0],
|
|
||||||
model_label->dims->data[0]);
|
|
||||||
TF_LITE_ENSURE_EQ(context, model_key->dims->data[0],
|
|
||||||
model_weight->dims->data[0]);
|
|
||||||
TF_LITE_ENSURE_EQ(context, model_label->dims->data[1],
|
|
||||||
model_weight->dims->data[1]);
|
|
||||||
|
|
||||||
PredictOption* option = PredictOption::Cast(node->user_data);
|
|
||||||
TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]];
|
|
||||||
TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]];
|
|
||||||
TF_LITE_ENSURE_EQ(context, output_label->type, kTfLiteInt32);
|
|
||||||
TF_LITE_ENSURE_EQ(context, output_weight->type, kTfLiteFloat32);
|
|
||||||
|
|
||||||
TfLiteIntArray* label_size = TfLiteIntArrayCreate(1);
|
|
||||||
label_size->data[0] = option->num_output;
|
|
||||||
TfLiteIntArray* weight_size = TfLiteIntArrayCreate(1);
|
|
||||||
weight_size->data[0] = option->num_output;
|
|
||||||
TfLiteStatus status =
|
|
||||||
context->ResizeTensor(context, output_label, label_size);
|
|
||||||
if (status != kTfLiteOk) {
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
return context->ResizeTensor(context, output_weight, weight_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|
||||||
TfLiteTensor* lookup = &context->tensors[node->inputs->data[0]];
|
|
||||||
TfLiteTensor* model_key = &context->tensors[node->inputs->data[1]];
|
|
||||||
TfLiteTensor* model_label = &context->tensors[node->inputs->data[2]];
|
|
||||||
TfLiteTensor* model_weight = &context->tensors[node->inputs->data[3]];
|
|
||||||
|
|
||||||
// Aggregate by key
|
|
||||||
std::unordered_map<int32_t, float> aggregation;
|
|
||||||
const int num_input = lookup->dims->data[0];
|
|
||||||
const int num_rows = model_key->dims->data[0];
|
|
||||||
const int items = model_label->dims->data[1];
|
|
||||||
int* model_key_end = model_key->data.i32 + num_rows;
|
|
||||||
|
|
||||||
for (int i = 0; i < num_input; i++) {
|
|
||||||
int* ptr = std::lower_bound(model_key->data.i32, model_key_end,
|
|
||||||
lookup->data.i32[i]);
|
|
||||||
if (ptr != nullptr && ptr != model_key_end && *ptr == lookup->data.i32[i]) {
|
|
||||||
int idx = ptr - model_key->data.i32;
|
|
||||||
for (int j = 0; j < items; j++) {
|
|
||||||
aggregation[model_label->data.i32[idx * items + j]] +=
|
|
||||||
model_weight->data.f[idx * items + j] / num_input;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sort by value
|
|
||||||
std::vector<std::pair<int32_t, float>> sorted_labels(aggregation.begin(),
|
|
||||||
aggregation.end());
|
|
||||||
std::sort(sorted_labels.begin(), sorted_labels.end(), WeightGreater);
|
|
||||||
|
|
||||||
PredictOption* option = PredictOption::Cast(node->user_data);
|
|
||||||
TfLiteTensor* output_label = &context->tensors[node->outputs->data[0]];
|
|
||||||
TfLiteTensor* output_weight = &context->tensors[node->outputs->data[1]];
|
|
||||||
for (int i = 0; i < output_label->dims->data[0]; i++) {
|
|
||||||
if (i >= sorted_labels.size() ||
|
|
||||||
sorted_labels[i].second < option->weight_threshold) {
|
|
||||||
// Set -1 to avoid lookup message with id 0, which is set for backoff.
|
|
||||||
output_label->data.i32[i] = -1;
|
|
||||||
output_weight->data.f[i] = 0.0f;
|
|
||||||
} else {
|
|
||||||
output_label->data.i32[i] = sorted_labels[i].first;
|
|
||||||
output_weight->data.f[i] = sorted_labels[i].second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return kTfLiteOk;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace predict
|
|
||||||
|
|
||||||
TfLiteRegistration* Register_PREDICT() {
|
|
||||||
static TfLiteRegistration r = {predict::Init, predict::Free, predict::Prepare,
|
|
||||||
predict::Eval};
|
|
||||||
return &r;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace custom
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace tflite
|
|
@ -1,183 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
|
||||||
#include "tensorflow/lite/interpreter.h"
|
|
||||||
#include "tensorflow/lite/kernels/register.h"
|
|
||||||
#include "tensorflow/lite/kernels/test_util.h"
|
|
||||||
#include "tensorflow/lite/model.h"
|
|
||||||
#include "tensorflow/lite/string_util.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
|
|
||||||
namespace ops {
|
|
||||||
namespace custom {
|
|
||||||
TfLiteRegistration* Register_PREDICT();
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
using ::testing::ElementsAreArray;
|
|
||||||
|
|
||||||
class PredictOpModel : public SingleOpModel {
|
|
||||||
public:
|
|
||||||
PredictOpModel(std::initializer_list<int> input_signature_shape,
|
|
||||||
std::initializer_list<int> key_shape,
|
|
||||||
std::initializer_list<int> labelweight_shape, int num_output,
|
|
||||||
float threshold) {
|
|
||||||
input_signature_ = AddInput(TensorType_INT32);
|
|
||||||
model_key_ = AddInput(TensorType_INT32);
|
|
||||||
model_label_ = AddInput(TensorType_INT32);
|
|
||||||
model_weight_ = AddInput(TensorType_FLOAT32);
|
|
||||||
output_label_ = AddOutput(TensorType_INT32);
|
|
||||||
output_weight_ = AddOutput(TensorType_FLOAT32);
|
|
||||||
|
|
||||||
std::vector<uint8_t> predict_option;
|
|
||||||
writeInt32(num_output, &predict_option);
|
|
||||||
writeFloat32(threshold, &predict_option);
|
|
||||||
SetCustomOp("Predict", predict_option, Register_PREDICT);
|
|
||||||
BuildInterpreter({{input_signature_shape, key_shape, labelweight_shape,
|
|
||||||
labelweight_shape}});
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetInputSignature(std::initializer_list<int> data) {
|
|
||||||
PopulateTensor<int>(input_signature_, data);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetModelKey(std::initializer_list<int> data) {
|
|
||||||
PopulateTensor<int>(model_key_, data);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetModelLabel(std::initializer_list<int> data) {
|
|
||||||
PopulateTensor<int>(model_label_, data);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetModelWeight(std::initializer_list<float> data) {
|
|
||||||
PopulateTensor<float>(model_weight_, data);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int> GetLabel() { return ExtractVector<int>(output_label_); }
|
|
||||||
std::vector<float> GetWeight() {
|
|
||||||
return ExtractVector<float>(output_weight_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void writeFloat32(float value, std::vector<uint8_t>* data) {
|
|
||||||
union {
|
|
||||||
float v;
|
|
||||||
uint8_t r[4];
|
|
||||||
} float_to_raw;
|
|
||||||
float_to_raw.v = value;
|
|
||||||
for (unsigned char i : float_to_raw.r) {
|
|
||||||
data->push_back(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void writeInt32(int32_t value, std::vector<uint8_t>* data) {
|
|
||||||
union {
|
|
||||||
int32_t v;
|
|
||||||
uint8_t r[4];
|
|
||||||
} int32_to_raw;
|
|
||||||
int32_to_raw.v = value;
|
|
||||||
for (unsigned char i : int32_to_raw.r) {
|
|
||||||
data->push_back(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
int input_signature_;
|
|
||||||
int model_key_;
|
|
||||||
int model_label_;
|
|
||||||
int model_weight_;
|
|
||||||
int output_label_;
|
|
||||||
int output_weight_;
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST(PredictOpTest, AllLabelsAreValid) {
|
|
||||||
PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
|
|
||||||
m.SetInputSignature({1, 3, 7, 9});
|
|
||||||
m.SetModelKey({1, 2, 4, 6, 7});
|
|
||||||
m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
|
|
||||||
m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
|
|
||||||
m.Invoke();
|
|
||||||
EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, 11}));
|
|
||||||
EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0.05})));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(PredictOpTest, MoreLabelsThanRequired) {
|
|
||||||
PredictOpModel m({4}, {5}, {5, 2}, 1, 0.0001);
|
|
||||||
m.SetInputSignature({1, 3, 7, 9});
|
|
||||||
m.SetModelKey({1, 2, 4, 6, 7});
|
|
||||||
m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
|
|
||||||
m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
|
|
||||||
m.Invoke();
|
|
||||||
EXPECT_THAT(m.GetLabel(), ElementsAreArray({12}));
|
|
||||||
EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1})));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(PredictOpTest, OneLabelDoesNotPassThreshold) {
|
|
||||||
PredictOpModel m({4}, {5}, {5, 2}, 2, 0.07);
|
|
||||||
m.SetInputSignature({1, 3, 7, 9});
|
|
||||||
m.SetModelKey({1, 2, 4, 6, 7});
|
|
||||||
m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
|
|
||||||
m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
|
|
||||||
m.Invoke();
|
|
||||||
EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, -1}));
|
|
||||||
EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0})));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(PredictOpTest, NoneLabelPassThreshold) {
|
|
||||||
PredictOpModel m({4}, {5}, {5, 2}, 2, 0.6);
|
|
||||||
m.SetInputSignature({1, 3, 7, 9});
|
|
||||||
m.SetModelKey({1, 2, 4, 6, 7});
|
|
||||||
m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12});
|
|
||||||
m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2});
|
|
||||||
m.Invoke();
|
|
||||||
EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1}));
|
|
||||||
EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0})));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(PredictOpTest, OnlyOneLabelGenerated) {
|
|
||||||
PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
|
|
||||||
m.SetInputSignature({1, 3, 7, 9});
|
|
||||||
m.SetModelKey({1, 2, 4, 6, 7});
|
|
||||||
m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 11, 0});
|
|
||||||
m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0});
|
|
||||||
m.Invoke();
|
|
||||||
EXPECT_THAT(m.GetLabel(), ElementsAreArray({11, -1}));
|
|
||||||
EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.05, 0})));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(PredictOpTest, NoLabelGenerated) {
|
|
||||||
PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001);
|
|
||||||
m.SetInputSignature({5, 3, 7, 9});
|
|
||||||
m.SetModelKey({1, 2, 4, 6, 7});
|
|
||||||
m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 0, 0});
|
|
||||||
m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0, 0});
|
|
||||||
m.Invoke();
|
|
||||||
EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1}));
|
|
||||||
EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0})));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace custom
|
|
||||||
} // namespace ops
|
|
||||||
} // namespace tflite
|
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
::tflite::LogToStderr();
|
|
||||||
::testing::InitGoogleTest(&argc, argv);
|
|
||||||
return RUN_ALL_TESTS();
|
|
||||||
}
|
|
@ -1,117 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/lite/models/smartreply/predictor.h"
|
|
||||||
|
|
||||||
#include "absl/strings/str_split.h"
|
|
||||||
#include "re2/re2.h"
|
|
||||||
#include "tensorflow/lite/interpreter.h"
|
|
||||||
#include "tensorflow/lite/kernels/register.h"
|
|
||||||
#include "tensorflow/lite/model.h"
|
|
||||||
#include "tensorflow/lite/op_resolver.h"
|
|
||||||
#include "tensorflow/lite/string_util.h"
|
|
||||||
|
|
||||||
void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
namespace custom {
|
|
||||||
namespace smartreply {
|
|
||||||
|
|
||||||
// Split sentence into segments (using punctuation).
|
|
||||||
std::vector<std::string> SplitSentence(const std::string& input) {
|
|
||||||
string result(input);
|
|
||||||
|
|
||||||
RE2::GlobalReplace(&result, "([?.!,])+", " \\1");
|
|
||||||
RE2::GlobalReplace(&result, "([?.!,])+\\s+", "\\1\t");
|
|
||||||
RE2::GlobalReplace(&result, "[ ]+", " ");
|
|
||||||
RE2::GlobalReplace(&result, "\t+$", "");
|
|
||||||
|
|
||||||
return absl::StrSplit(result, '\t');
|
|
||||||
}
|
|
||||||
|
|
||||||
// Predict with TfLite model.
|
|
||||||
void ExecuteTfLite(const std::string& sentence,
|
|
||||||
::tflite::Interpreter* interpreter,
|
|
||||||
std::map<std::string, float>* response_map) {
|
|
||||||
{
|
|
||||||
TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]);
|
|
||||||
tflite::DynamicBuffer buf;
|
|
||||||
buf.AddString(sentence.data(), sentence.length());
|
|
||||||
buf.WriteToTensorAsVector(input);
|
|
||||||
interpreter->AllocateTensors();
|
|
||||||
|
|
||||||
interpreter->Invoke();
|
|
||||||
|
|
||||||
TfLiteTensor* messages = interpreter->tensor(interpreter->outputs()[0]);
|
|
||||||
TfLiteTensor* confidence = interpreter->tensor(interpreter->outputs()[1]);
|
|
||||||
|
|
||||||
for (int i = 0; i < confidence->dims->data[0]; i++) {
|
|
||||||
float weight = confidence->data.f[i];
|
|
||||||
auto response_text = tflite::GetString(messages, i);
|
|
||||||
if (response_text.len > 0) {
|
|
||||||
(*response_map)[string(response_text.str, response_text.len)] += weight;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void GetSegmentPredictions(
|
|
||||||
const std::vector<std::string>& input,
|
|
||||||
const ::tflite::FlatBufferModel& model, const SmartReplyConfig& config,
|
|
||||||
std::vector<PredictorResponse>* predictor_responses) {
|
|
||||||
// Initialize interpreter
|
|
||||||
std::unique_ptr<::tflite::Interpreter> interpreter;
|
|
||||||
::tflite::MutableOpResolver resolver;
|
|
||||||
RegisterSelectedOps(&resolver);
|
|
||||||
::tflite::InterpreterBuilder(model, resolver)(&interpreter);
|
|
||||||
|
|
||||||
if (!model.initialized()) {
|
|
||||||
fprintf(stderr, "Failed to mmap model \n");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute Tflite Model
|
|
||||||
std::map<std::string, float> response_map;
|
|
||||||
std::vector<std::string> sentences;
|
|
||||||
for (const std::string& str : input) {
|
|
||||||
std::vector<std::string> splitted_str = SplitSentence(str);
|
|
||||||
sentences.insert(sentences.end(), splitted_str.begin(), splitted_str.end());
|
|
||||||
}
|
|
||||||
for (const auto& sentence : sentences) {
|
|
||||||
ExecuteTfLite(sentence, interpreter.get(), &response_map);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate the result.
|
|
||||||
for (const auto& iter : response_map) {
|
|
||||||
PredictorResponse prediction(iter.first, iter.second);
|
|
||||||
predictor_responses->emplace_back(prediction);
|
|
||||||
}
|
|
||||||
std::sort(predictor_responses->begin(), predictor_responses->end(),
|
|
||||||
[](const PredictorResponse& a, const PredictorResponse& b) {
|
|
||||||
return a.GetScore() > b.GetScore();
|
|
||||||
});
|
|
||||||
|
|
||||||
// Add backoff response.
|
|
||||||
for (const auto& backoff : config.backoff_responses) {
|
|
||||||
if (predictor_responses->size() >= config.num_response) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
predictor_responses->emplace_back(backoff, config.backoff_confidence);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace smartreply
|
|
||||||
} // namespace custom
|
|
||||||
} // namespace tflite
|
|
@ -1,80 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#ifndef TENSORFLOW_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
|
|
||||||
#define TENSORFLOW_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "tensorflow/lite/model.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
namespace custom {
|
|
||||||
namespace smartreply {
|
|
||||||
|
|
||||||
const int kDefaultNumResponse = 10;
|
|
||||||
const float kDefaultBackoffConfidence = 1e-4;
|
|
||||||
|
|
||||||
class PredictorResponse;
|
|
||||||
struct SmartReplyConfig;
|
|
||||||
|
|
||||||
// With a given string as input, predict the response with a Tflite model.
|
|
||||||
// When config.backoff_response is not empty, predictor_responses will be filled
|
|
||||||
// with messagees from backoff response.
|
|
||||||
void GetSegmentPredictions(const std::vector<std::string>& input,
|
|
||||||
const ::tflite::FlatBufferModel& model,
|
|
||||||
const SmartReplyConfig& config,
|
|
||||||
std::vector<PredictorResponse>* predictor_responses);
|
|
||||||
|
|
||||||
// Data object used to hold a single predictor response.
|
|
||||||
// It includes messages, and confidence.
|
|
||||||
class PredictorResponse {
|
|
||||||
public:
|
|
||||||
PredictorResponse(const std::string& response_text, float score) {
|
|
||||||
response_text_ = response_text;
|
|
||||||
prediction_score_ = score;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Accessor methods.
|
|
||||||
const std::string& GetText() const { return response_text_; }
|
|
||||||
float GetScore() const { return prediction_score_; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::string response_text_ = "";
|
|
||||||
float prediction_score_ = 0.0;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Configurations for SmartReply.
|
|
||||||
struct SmartReplyConfig {
|
|
||||||
// Maximum responses to return.
|
|
||||||
int num_response;
|
|
||||||
// Default confidence for backoff responses.
|
|
||||||
float backoff_confidence;
|
|
||||||
// Backoff responses are used when predicted responses cannot fulfill the
|
|
||||||
// list.
|
|
||||||
std::vector<std::string> backoff_responses;
|
|
||||||
|
|
||||||
SmartReplyConfig(const std::vector<std::string>& backoff_responses)
|
|
||||||
: num_response(kDefaultNumResponse),
|
|
||||||
backoff_confidence(kDefaultBackoffConfidence),
|
|
||||||
backoff_responses(backoff_responses) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace smartreply
|
|
||||||
} // namespace custom
|
|
||||||
} // namespace tflite
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_MODELS_SMARTREPLY_PREDICTOR_H_
|
|
@ -1,163 +0,0 @@
|
|||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/lite/models/smartreply/predictor.h"
|
|
||||||
|
|
||||||
#include <fstream>
|
|
||||||
#include <unordered_set>
|
|
||||||
|
|
||||||
#include <gmock/gmock.h>
|
|
||||||
#include <gtest/gtest.h>
|
|
||||||
#include "absl/strings/str_cat.h"
|
|
||||||
#include "absl/strings/str_split.h"
|
|
||||||
#include "tensorflow/core/platform/test.h"
|
|
||||||
#include "tensorflow/lite/string_util.h"
|
|
||||||
#include "tensorflow/lite/testing/util.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
namespace custom {
|
|
||||||
namespace smartreply {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
const char kSamples[] = "smartreply_samples.tsv";
|
|
||||||
|
|
||||||
string GetModelFilePath() {
|
|
||||||
return "external/tflite_smartreply/smartreply.tflite"; // NOLINT
|
|
||||||
}
|
|
||||||
|
|
||||||
string GetSamplesFilePath() {
|
|
||||||
return string(absl::StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/",
|
|
||||||
"lite/models/testdata/", kSamples));
|
|
||||||
}
|
|
||||||
|
|
||||||
MATCHER_P(IncludeAnyResponesIn, expected_response, "contains the response") {
|
|
||||||
bool has_expected_response = false;
|
|
||||||
for (const auto &item : *arg) {
|
|
||||||
const string &response = item.GetText();
|
|
||||||
if (expected_response.find(response) != expected_response.end()) {
|
|
||||||
has_expected_response = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return has_expected_response;
|
|
||||||
}
|
|
||||||
|
|
||||||
class PredictorTest : public ::testing::Test {
|
|
||||||
protected:
|
|
||||||
PredictorTest() {}
|
|
||||||
~PredictorTest() override {}
|
|
||||||
|
|
||||||
void SetUp() override {
|
|
||||||
model_ = tflite::FlatBufferModel::BuildFromFile(GetModelFilePath().c_str());
|
|
||||||
ASSERT_NE(model_.get(), nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<::tflite::FlatBufferModel> model_;
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(PredictorTest, GetSegmentPredictions) {
|
|
||||||
std::vector<PredictorResponse> predictions;
|
|
||||||
|
|
||||||
GetSegmentPredictions({"Welcome"}, *model_, /*config=*/{{}}, &predictions);
|
|
||||||
EXPECT_GT(predictions.size(), 0);
|
|
||||||
|
|
||||||
float max = 0;
|
|
||||||
for (const auto &item : predictions) {
|
|
||||||
if (item.GetScore() > max) {
|
|
||||||
max = item.GetScore();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
EXPECT_GT(max, 0.3);
|
|
||||||
EXPECT_THAT(
|
|
||||||
&predictions,
|
|
||||||
IncludeAnyResponesIn(std::unordered_set<string>({"Thanks very much"})));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(PredictorTest, TestTwoSentences) {
|
|
||||||
std::vector<PredictorResponse> predictions;
|
|
||||||
|
|
||||||
GetSegmentPredictions({"Hello", "How are you?"}, *model_, /*config=*/{{}},
|
|
||||||
&predictions);
|
|
||||||
EXPECT_GT(predictions.size(), 0);
|
|
||||||
|
|
||||||
float max = 0;
|
|
||||||
for (const auto &item : predictions) {
|
|
||||||
if (item.GetScore() > max) {
|
|
||||||
max = item.GetScore();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
EXPECT_GT(max, 0.3);
|
|
||||||
EXPECT_THAT(&predictions, IncludeAnyResponesIn(std::unordered_set<string>(
|
|
||||||
{"Hi, how are you doing?"})));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(PredictorTest, TestBackoff) {
|
|
||||||
std::vector<PredictorResponse> predictions;
|
|
||||||
|
|
||||||
GetSegmentPredictions({"你好"}, *model_, /*config=*/{{}}, &predictions);
|
|
||||||
EXPECT_EQ(predictions.size(), 0);
|
|
||||||
|
|
||||||
// Backoff responses are returned in order.
|
|
||||||
GetSegmentPredictions({"你好"}, *model_, /*config=*/{{"Yes", "Ok"}},
|
|
||||||
&predictions);
|
|
||||||
EXPECT_EQ(predictions.size(), 2);
|
|
||||||
EXPECT_EQ(predictions[0].GetText(), "Yes");
|
|
||||||
EXPECT_EQ(predictions[1].GetText(), "Ok");
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(PredictorTest, BatchTest) {
|
|
||||||
int total_items = 0;
|
|
||||||
int total_responses = 0;
|
|
||||||
int total_triggers = 0;
|
|
||||||
|
|
||||||
string line;
|
|
||||||
std::ifstream fin(GetSamplesFilePath());
|
|
||||||
while (std::getline(fin, line)) {
|
|
||||||
const std::vector<string> fields = absl::StrSplit(line, '\t');
|
|
||||||
if (fields.empty()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse sample file and predict
|
|
||||||
const string &msg = fields[0];
|
|
||||||
std::vector<PredictorResponse> predictions;
|
|
||||||
GetSegmentPredictions({msg}, *model_, /*config=*/{{}}, &predictions);
|
|
||||||
|
|
||||||
// Validate response and generate stats.
|
|
||||||
total_items++;
|
|
||||||
total_responses += predictions.size();
|
|
||||||
if (!predictions.empty()) {
|
|
||||||
total_triggers++;
|
|
||||||
}
|
|
||||||
EXPECT_THAT(&predictions, IncludeAnyResponesIn(std::unordered_set<string>(
|
|
||||||
fields.begin() + 1, fields.end())));
|
|
||||||
}
|
|
||||||
|
|
||||||
EXPECT_EQ(total_triggers, total_items);
|
|
||||||
EXPECT_GE(total_responses, total_triggers);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace smartreply
|
|
||||||
} // namespace custom
|
|
||||||
} // namespace tflite
|
|
||||||
|
|
||||||
int main(int argc, char **argv) {
|
|
||||||
::tflite::LogToStderr();
|
|
||||||
::testing::InitGoogleTest(&argc, argv);
|
|
||||||
return RUN_ALL_TESTS();
|
|
||||||
}
|
|
@ -1,50 +0,0 @@
|
|||||||
any chance ur free tonight Maybe not
|
|
||||||
any updates? No update yet
|
|
||||||
anything i can do to help? No, but thanks No, but thank you No, but thanks for asking
|
|
||||||
be safe. I will be Will do my best Thanks, I will
|
|
||||||
congratulations Thanks thanks Congratulations
|
|
||||||
cool, let me know when you have time Cool Yes very cool Yeah, cool
|
|
||||||
drive safe Thank you, I will Home now I will thanks
|
|
||||||
hang in there, you'll be okay Doing my best Of course we will
|
|
||||||
happy birthday! Hey, thanks
|
|
||||||
happy new year! Wish you the same Thanks and same to you
|
|
||||||
have a safe flight Thanks, love you too Safe travels
|
|
||||||
hey What is up? How it going? Can I help you?
|
|
||||||
hey, got a sec? What is up? How it going? Can I help you?
|
|
||||||
how are you doing? Great and you? I am doing great
|
|
||||||
how are you feeling Feeling okay A little better Much much better
|
|
||||||
how was your weekend? It was real good
|
|
||||||
how you doing Okay and you
|
|
||||||
hugs. So sweet Thanks sweetie Take care of yourself
|
|
||||||
i'm bored Sorry to hear that Join the club No you are not
|
|
||||||
i'm planning on coming next week. let me know if that works. Works Perfect, thanks
|
|
||||||
i'm sick Sorry to hear that
|
|
||||||
i'm so happy for you Thanks me too
|
|
||||||
i'm so hungry Haha me too
|
|
||||||
i'm sorry No I am sorry Why sorry? No worries love
|
|
||||||
i'm sorry, i'm going to have to cancel. No I am sorry Why sorry? No worries love
|
|
||||||
is there anything i can do to help? No, but thanks No, but thanks for asking
|
|
||||||
lunch? Yes coming
|
|
||||||
okay. lemme know as soon as you find out. Any more questions? It is done
|
|
||||||
omg amazing So amazing
|
|
||||||
on my way Okay see you soon Cool, see you soon Oh wow, ok
|
|
||||||
oops, mistexted. Oops Haha, oh well That was funny
|
|
||||||
safe travels. Thanks, love you too Safe travels
|
|
||||||
so sorry So sorry
|
|
||||||
sorry, i can't. No worries at all Sorry what?
|
|
||||||
sorry, i can't do saturday No worries at all
|
|
||||||
thank you so much. You are so welcome You are so very welcome You are most welcome
|
|
||||||
thanks for coming It was my pleasure
|
|
||||||
thanks, this has been great. Glad to help So happy for you
|
|
||||||
tomorrow would be ideal. Yes it would
|
|
||||||
tried calling Try again?
|
|
||||||
ugh, my flight is delayed. Ugh indeed
|
|
||||||
what are you guys up to tonight? Nothing planned
|
|
||||||
what day works best for you Any day
|
|
||||||
what do you want for dinner Your call Whatever is fine
|
|
||||||
what time will you be home? Not sure why
|
|
||||||
where are you?!? At my house
|
|
||||||
wish you were here. I wish the same Me too honey
|
|
||||||
you're amazing You are too You are amazing I am
|
|
||||||
you're marvelous You are too
|
|
||||||
you're the best. I do my best You are the best Well, I try
|
|
Can't render this file because it has a wrong number of fields in line 3.
|
@ -855,16 +855,6 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_http_archive(
|
|
||||||
name = "tflite_smartreply",
|
|
||||||
build_file = clean_dep("//third_party:tflite_smartreply.BUILD"),
|
|
||||||
sha256 = "8980151b85a87a9c1a3bb1ed4748119e4a85abd3cb5744d83da4d4bd0fbeef7c",
|
|
||||||
urls = [
|
|
||||||
"https://storage.googleapis.com/mirror.tensorflow.org/storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip",
|
|
||||||
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_http_archive(
|
tf_http_archive(
|
||||||
name = "tflite_ovic_testdata",
|
name = "tflite_ovic_testdata",
|
||||||
build_file = clean_dep("//third_party:tflite_ovic_testdata.BUILD"),
|
build_file = clean_dep("//third_party:tflite_ovic_testdata.BUILD"),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user