Build demo app for SmartReply
PiperOrigin-RevId: 177559103
This commit is contained in:
parent
370e521762
commit
6b6244c401
@ -223,11 +223,12 @@ def gen_selected_ops(name, model):
|
||||
"""
|
||||
out = name + "_registration.cc"
|
||||
tool = "//tensorflow/contrib/lite/tools:generate_op_registrations"
|
||||
tflite_path = "//tensorflow/contrib/lite"
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = [model],
|
||||
outs = [out],
|
||||
cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s)")
|
||||
% (tool, model, out),
|
||||
cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s")
|
||||
% (tool, model, out, tflite_path[2:]),
|
||||
tools = [tool],
|
||||
)
|
||||
|
@ -1,7 +1,92 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops")
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
gen_selected_ops(
|
||||
name = "smartreply_ops",
|
||||
model = "@tflite_smartreply//:smartreply.tflite",
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "custom_ops",
|
||||
srcs = [
|
||||
"ops/extract_feature.cc",
|
||||
"ops/normalize.cc",
|
||||
"ops/predict.cc",
|
||||
":smartreply_ops",
|
||||
],
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite:string_util",
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
"//tensorflow/contrib/lite/tools:mutable_op_resolver",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_googlesource_code_re2//:re2",
|
||||
"@farmhash_archive//:farmhash",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "predictor_lib",
|
||||
srcs = ["predictor.cc"],
|
||||
hdrs = ["predictor.h"],
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
":custom_ops",
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite:string_util",
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
"//tensorflow/contrib/lite/tools:mutable_op_resolver",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_googlesource_code_re2//:re2",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "extract_feature_op_test",
|
||||
size = "small",
|
||||
srcs = ["ops/extract_feature_test.cc"],
|
||||
deps = [
|
||||
":custom_ops",
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
"//tensorflow/contrib/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
"@farmhash_archive//:farmhash",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "normalize_op_test",
|
||||
size = "small",
|
||||
srcs = ["ops/normalize_test.cc"],
|
||||
deps = [
|
||||
":custom_ops",
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite:string_util",
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
"//tensorflow/contrib/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "predict_op_test",
|
||||
size = "small",
|
||||
srcs = ["ops/predict_test.cc"],
|
||||
deps = [
|
||||
":custom_ops",
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite:string_util",
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
"//tensorflow/contrib/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -0,0 +1,38 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!--
|
||||
Copyright 2017 The Android Open Source Project
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.example.android.smartreply" >
|
||||
|
||||
<uses-sdk
|
||||
android:minSdkVersion="15"
|
||||
android:targetSdkVersion="24" />
|
||||
|
||||
<application android:label="TfLite SmartReply Demo">
|
||||
<activity
|
||||
android:name="com.example.android.smartreply.MainActivity"
|
||||
android:configChanges="orientation|keyboardHidden|screenSize"
|
||||
android:windowSoftInputMode="stateUnchanged|adjustPan"
|
||||
android:label="TfLite SmartReply Demo"
|
||||
android:screenOrientation="portrait" >
|
||||
<intent-filter>
|
||||
<action android:name="android.intent.action.MAIN" />
|
||||
<category android:name="android.intent.category.LAUNCHER" />
|
||||
</intent-filter>
|
||||
</activity>
|
||||
</application>
|
||||
|
||||
</manifest>
|
@ -0,0 +1,65 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load(
|
||||
"//tensorflow/contrib/lite:build_def.bzl",
|
||||
"tflite_copts",
|
||||
"tflite_jni_binary",
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "assets",
|
||||
srcs = [
|
||||
"@tflite_smartreply//:model_files",
|
||||
],
|
||||
)
|
||||
|
||||
android_binary(
|
||||
name = "SmartReplyDemo",
|
||||
srcs = glob(["java/**/*.java"]),
|
||||
assets = [":assets"],
|
||||
assets_dir = "",
|
||||
custom_package = "com.example.android.smartreply",
|
||||
manifest = "AndroidManifest.xml",
|
||||
nocompress_extensions = [
|
||||
".tflite",
|
||||
],
|
||||
resource_files = glob(["res/**"]),
|
||||
tags = ["manual"],
|
||||
deps = [
|
||||
":smartreply_runtime",
|
||||
"@androidsdk//com.android.support:support-v13-25.2.0",
|
||||
"@androidsdk//com.android.support:support-v4-25.2.0",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "smartreply_runtime",
|
||||
srcs = ["libsmartreply_jni.so"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
tflite_jni_binary(
|
||||
name = "libsmartreply_jni.so",
|
||||
deps = [
|
||||
":smartreply_jni_lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "smartreply_jni_lib",
|
||||
srcs = [
|
||||
"smartreply_jni.cc",
|
||||
],
|
||||
copts = tflite_copts(),
|
||||
linkopts = [
|
||||
"-lm",
|
||||
"-ldl",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite/models/smartreply:predictor_lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
@ -0,0 +1,15 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(glob(["*"]))
|
||||
|
||||
filegroup(
|
||||
name = "assets_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"BUILD",
|
||||
],
|
||||
),
|
||||
)
|
@ -0,0 +1,16 @@
|
||||
Ok
|
||||
Yes
|
||||
No
|
||||
👍
|
||||
☺
|
||||
😟
|
||||
❤️
|
||||
Lol
|
||||
Thanks
|
||||
Got it
|
||||
Done
|
||||
Nice
|
||||
I don't know
|
||||
What?
|
||||
Why?
|
||||
What's up?
|
@ -0,0 +1,99 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package com.example.android.smartreply;
|
||||
|
||||
import android.app.Activity;
|
||||
import android.os.Bundle;
|
||||
import android.os.Handler;
|
||||
import android.util.Log;
|
||||
import android.view.View;
|
||||
import android.widget.Button;
|
||||
import android.widget.EditText;
|
||||
import android.widget.TextView;
|
||||
|
||||
/**
|
||||
* The main (and only) activity of this demo app. Displays a text box which updates as messages are
|
||||
* received.
|
||||
*/
|
||||
public class MainActivity extends Activity {
|
||||
private static final String TAG = "SmartReplyDemo";
|
||||
private SmartReplyClient client;
|
||||
|
||||
private Button sendButton;
|
||||
private TextView messageTextView;
|
||||
private EditText messageInput;
|
||||
|
||||
private Handler handler;
|
||||
|
||||
@Override
|
||||
protected void onCreate(Bundle savedInstanceState) {
|
||||
super.onCreate(savedInstanceState);
|
||||
Log.v(TAG, "onCreate");
|
||||
setContentView(R.layout.main_activity);
|
||||
|
||||
client = new SmartReplyClient(getApplicationContext());
|
||||
handler = new Handler();
|
||||
|
||||
sendButton = (Button) findViewById(R.id.send_button);
|
||||
sendButton.setOnClickListener(
|
||||
(View v) -> {
|
||||
send(messageInput.getText().toString());
|
||||
});
|
||||
|
||||
messageTextView = (TextView) findViewById(R.id.message_text);
|
||||
messageInput = (EditText) findViewById(R.id.message_input);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void onStart() {
|
||||
super.onStart();
|
||||
Log.v(TAG, "onStart");
|
||||
handler.post(
|
||||
() -> {
|
||||
client.loadModel();
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void onStop() {
|
||||
super.onStop();
|
||||
Log.v(TAG, "onStop");
|
||||
handler.post(
|
||||
() -> {
|
||||
client.unloadModel();
|
||||
});
|
||||
}
|
||||
|
||||
private void send(final String message) {
|
||||
handler.post(
|
||||
() -> {
|
||||
messageTextView.append("Input: " + message + "\n");
|
||||
|
||||
SmartReply[] ans = client.predict(new String[] {message});
|
||||
for (SmartReply reply : ans) {
|
||||
appendMessage("Reply: " + reply.getText());
|
||||
}
|
||||
appendMessage("------");
|
||||
});
|
||||
}
|
||||
|
||||
private void appendMessage(final String message) {
|
||||
handler.post(
|
||||
() -> {
|
||||
messageTextView.append(message + "\n");
|
||||
});
|
||||
}
|
||||
}
|
@ -0,0 +1,44 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package com.example.android.smartreply;
|
||||
|
||||
import android.support.annotation.Keep;
|
||||
|
||||
/**
|
||||
* SmartReply contains predicted message, and confidence.
|
||||
*
|
||||
* <p>NOTE: this class used by JNI, class name and constructor should not be obfuscated.
|
||||
*/
|
||||
@Keep
|
||||
public class SmartReply {
|
||||
|
||||
private final String text;
|
||||
private final float score;
|
||||
|
||||
@Keep
|
||||
public SmartReply(String text, float score) {
|
||||
this.text = text;
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
public String getText() {
|
||||
return text;
|
||||
}
|
||||
|
||||
public float getScore() {
|
||||
return score;
|
||||
}
|
||||
}
|
@ -0,0 +1,129 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package com.example.android.smartreply;
|
||||
|
||||
import android.content.Context;
|
||||
import android.content.res.AssetFileDescriptor;
|
||||
import android.support.annotation.Keep;
|
||||
import android.support.annotation.WorkerThread;
|
||||
import android.util.Log;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.nio.MappedByteBuffer;
|
||||
import java.nio.channels.FileChannel;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/** Interface to load TfLite model and provide predictions. */
|
||||
public class SmartReplyClient implements AutoCloseable {
|
||||
private static final String TAG = "SmartReplyDemo";
|
||||
private static final String MODEL_PATH = "smartreply.tflite";
|
||||
private static final String BACKOFF_PATH = "backoff_response.txt";
|
||||
private static final String JNI_LIB = "smartreply_jni";
|
||||
|
||||
private final Context context;
|
||||
private long storage;
|
||||
private MappedByteBuffer model;
|
||||
|
||||
private volatile boolean isLibraryLoaded;
|
||||
|
||||
public SmartReplyClient(Context context) {
|
||||
this.context = context;
|
||||
}
|
||||
|
||||
public boolean isLoaded() {
|
||||
return storage != 0;
|
||||
}
|
||||
|
||||
@WorkerThread
|
||||
public synchronized void loadModel() {
|
||||
if (!isLibraryLoaded) {
|
||||
System.loadLibrary(JNI_LIB);
|
||||
isLibraryLoaded = true;
|
||||
}
|
||||
|
||||
try {
|
||||
model = loadModelFile();
|
||||
String[] backoff = loadBackoffList();
|
||||
storage = loadJNI(model, backoff);
|
||||
} catch (IOException e) {
|
||||
Log.e(TAG, "Fail to load model", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
@WorkerThread
|
||||
public synchronized SmartReply[] predict(String[] input) {
|
||||
if (storage != 0) {
|
||||
return predictJNI(storage, input);
|
||||
} else {
|
||||
return new SmartReply[] {};
|
||||
}
|
||||
}
|
||||
|
||||
@WorkerThread
|
||||
public synchronized void unloadModel() {
|
||||
close();
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized void close() {
|
||||
if (storage != 0) {
|
||||
unloadJNI(storage);
|
||||
storage = 0;
|
||||
}
|
||||
}
|
||||
|
||||
private MappedByteBuffer loadModelFile() throws IOException {
|
||||
AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH);
|
||||
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
|
||||
try {
|
||||
FileChannel fileChannel = inputStream.getChannel();
|
||||
long startOffset = fileDescriptor.getStartOffset();
|
||||
long declaredLength = fileDescriptor.getDeclaredLength();
|
||||
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
|
||||
} finally {
|
||||
inputStream.close();
|
||||
}
|
||||
}
|
||||
|
||||
private String[] loadBackoffList() throws IOException {
|
||||
List<String> labelList = new ArrayList<String>();
|
||||
BufferedReader reader =
|
||||
new BufferedReader(new InputStreamReader(context.getAssets().open(BACKOFF_PATH)));
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
if (!line.isEmpty()) {
|
||||
labelList.add(line);
|
||||
}
|
||||
}
|
||||
reader.close();
|
||||
String[] ans = new String[labelList.size()];
|
||||
labelList.toArray(ans);
|
||||
return ans;
|
||||
}
|
||||
|
||||
@Keep
|
||||
private native long loadJNI(MappedByteBuffer buffer, String[] backoff);
|
||||
|
||||
@Keep
|
||||
private native SmartReply[] predictJNI(long storage, String[] text);
|
||||
|
||||
@Keep
|
||||
private native void unloadJNI(long storage);
|
||||
}
|
@ -0,0 +1,44 @@
|
||||
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:tools="http://schemas.android.com/tools"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="match_parent"
|
||||
android:orientation="vertical">
|
||||
|
||||
<LinearLayout
|
||||
android:layout_width="fill_parent"
|
||||
android:layout_height="0dp"
|
||||
android:padding="5dip"
|
||||
android:layout_weight="3">
|
||||
|
||||
<TextView
|
||||
android:id="@+id/message_text"
|
||||
android:layout_width="fill_parent"
|
||||
android:layout_height="fill_parent"
|
||||
android:scrollbars="vertical"
|
||||
android:gravity="bottom"/>
|
||||
</LinearLayout>
|
||||
|
||||
<LinearLayout
|
||||
android:layout_width="fill_parent"
|
||||
android:layout_height="0dp"
|
||||
android:padding="5dip"
|
||||
android:layout_weight="1">
|
||||
|
||||
<EditText
|
||||
android:id="@+id/message_input"
|
||||
android:layout_width="0dp"
|
||||
android:layout_height="fill_parent"
|
||||
android:layout_weight="6"
|
||||
android:scrollbars="vertical"
|
||||
android:hint="Enter Text"
|
||||
android:gravity="top"
|
||||
android:inputType="text"/>
|
||||
<Button
|
||||
android:id="@+id/send_button"
|
||||
android:layout_width="0dp"
|
||||
android:layout_height="fill_parent"
|
||||
android:layout_weight="2"
|
||||
android:text="Send" />
|
||||
</LinearLayout>
|
||||
|
||||
</LinearLayout>
|
@ -0,0 +1,129 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <jni.h>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/contrib/lite/model.h"
|
||||
#include "tensorflow/contrib/lite/models/smartreply/predictor.h"
|
||||
|
||||
const char kIllegalStateException[] = "java/lang/IllegalStateException";
|
||||
|
||||
using tflite::custom::smartreply::GetSegmentPredictions;
|
||||
using tflite::custom::smartreply::PredictorResponse;
|
||||
|
||||
template <typename T>
|
||||
T CheckNotNull(JNIEnv* env, T&& t) {
|
||||
if (t == nullptr) {
|
||||
env->ThrowNew(env->FindClass(kIllegalStateException), "");
|
||||
return nullptr;
|
||||
}
|
||||
return std::forward<T>(t);
|
||||
}
|
||||
|
||||
std::vector<std::string> jniStringArrayToVector(JNIEnv* env,
|
||||
jobjectArray string_array) {
|
||||
int count = env->GetArrayLength(string_array);
|
||||
std::vector<std::string> result;
|
||||
for (int i = 0; i < count; i++) {
|
||||
auto jstr =
|
||||
reinterpret_cast<jstring>(env->GetObjectArrayElement(string_array, i));
|
||||
const char* raw_str = env->GetStringUTFChars(jstr, JNI_FALSE);
|
||||
result.emplace_back(std::string(raw_str));
|
||||
env->ReleaseStringUTFChars(jstr, raw_str);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
struct JNIStorage {
|
||||
std::vector<std::string> backoff_list;
|
||||
std::unique_ptr<::tflite::FlatBufferModel> model;
|
||||
};
|
||||
|
||||
extern "C" JNIEXPORT jlong JNICALL
|
||||
Java_com_example_android_smartreply_SmartReplyClient_loadJNI(
|
||||
JNIEnv* env, jobject thiz, jobject model_buffer,
|
||||
jobjectArray backoff_list) {
|
||||
const char* buf =
|
||||
static_cast<char*>(env->GetDirectBufferAddress(model_buffer));
|
||||
jlong capacity = env->GetDirectBufferCapacity(model_buffer);
|
||||
|
||||
JNIStorage* storage = new JNIStorage;
|
||||
storage->model = tflite::FlatBufferModel::BuildFromBuffer(
|
||||
buf, static_cast<size_t>(capacity));
|
||||
storage->backoff_list = jniStringArrayToVector(env, backoff_list);
|
||||
|
||||
if (!storage->model) {
|
||||
delete storage;
|
||||
env->ThrowNew(env->FindClass(kIllegalStateException), "");
|
||||
return 0;
|
||||
}
|
||||
return reinterpret_cast<jlong>(storage);
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT jobjectArray JNICALL
|
||||
Java_com_example_android_smartreply_SmartReplyClient_predictJNI(
|
||||
JNIEnv* env, jobject /*thiz*/, jlong storage_ptr, jobjectArray input_text) {
|
||||
// Predict
|
||||
if (storage_ptr == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
JNIStorage* storage = reinterpret_cast<JNIStorage*>(storage_ptr);
|
||||
if (storage == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<PredictorResponse> responses;
|
||||
GetSegmentPredictions(jniStringArrayToVector(env, input_text),
|
||||
*storage->model, {storage->backoff_list}, &responses);
|
||||
|
||||
// Create a SmartReply[] to return back to Java
|
||||
jclass smart_reply_class = CheckNotNull(
|
||||
env, env->FindClass("com/example/android/smartreply/SmartReply"));
|
||||
if (env->ExceptionCheck()) {
|
||||
return nullptr;
|
||||
}
|
||||
jmethodID smart_reply_ctor = CheckNotNull(
|
||||
env,
|
||||
env->GetMethodID(smart_reply_class, "<init>", "(Ljava/lang/String;F)V"));
|
||||
if (env->ExceptionCheck()) {
|
||||
return nullptr;
|
||||
}
|
||||
jobjectArray array = CheckNotNull(
|
||||
env, env->NewObjectArray(responses.size(), smart_reply_class, nullptr));
|
||||
if (env->ExceptionCheck()) {
|
||||
return nullptr;
|
||||
}
|
||||
for (int i = 0; i < responses.size(); i++) {
|
||||
jstring text =
|
||||
CheckNotNull(env, env->NewStringUTF(responses[i].GetText().data()));
|
||||
if (env->ExceptionCheck()) {
|
||||
return nullptr;
|
||||
}
|
||||
jobject reply = env->NewObject(smart_reply_class, smart_reply_ctor, text,
|
||||
responses[i].GetScore());
|
||||
env->SetObjectArrayElement(array, i, reply);
|
||||
}
|
||||
return array;
|
||||
}
|
||||
|
||||
extern "C" JNIEXPORT void JNICALL
|
||||
Java_com_example_android_smartreply_SmartReplyClient_unloadJNI(
|
||||
JNIEnv* env, jobject thiz, jlong storage_ptr) {
|
||||
if (storage_ptr != 0) {
|
||||
JNIStorage* storage = reinterpret_cast<JNIStorage*>(storage_ptr);
|
||||
delete storage;
|
||||
}
|
||||
}
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include "re2/re2.h"
|
||||
|
||||
#include "tensorflow/contrib/lite/context.h"
|
||||
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/contrib/lite/string_util.h"
|
||||
@ -81,7 +81,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TfLiteTensor* label = GetOutput(context, node, 0);
|
||||
TfLiteTensor* weight = GetOutput(context, node, 1);
|
||||
|
||||
std::map<int64, int> feature_id_counts;
|
||||
std::map<int64_t, int> feature_id_counts;
|
||||
for (int i = 0; i < num_strings; i++) {
|
||||
// Use fingerprint of feature name as id.
|
||||
auto strref = tflite::GetString(input, i);
|
||||
@ -91,10 +91,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int64 feature_id =
|
||||
int64_t feature_id =
|
||||
::util::Fingerprint64(strref.str, strref.len) % kMaxDimension;
|
||||
|
||||
label->data.i32[i] = static_cast<int32>(feature_id);
|
||||
label->data.i32[i] = static_cast<int32_t>(feature_id);
|
||||
weight->data.f[i] =
|
||||
std::count(strref.str, strref.str + strref.len, ' ') + 1;
|
||||
}
|
||||
|
@ -21,7 +21,10 @@ limitations under the License.
|
||||
// Output:
|
||||
// Output[0]: Normalized sentence. string[1]
|
||||
//
|
||||
#include "absl/strings/ascii.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/strip.h"
|
||||
#include "re2/re2.h"
|
||||
@ -50,7 +53,7 @@ const std::map<string, string>* kRegexTransforms =
|
||||
|
||||
static const char kStartToken[] = "<S>";
|
||||
static const char kEndToken[] = "<E>";
|
||||
static const int32 kMaxInputChars = 300;
|
||||
static const int32_t kMaxInputChars = 300;
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::StringRef input = tflite::GetString(GetInput(context, node, 0), 0);
|
||||
|
@ -30,7 +30,7 @@ namespace custom {
|
||||
namespace smartreply {
|
||||
|
||||
// Split sentence into segments (using punctuation).
|
||||
std::vector<string> SplitSentence(const string& input) {
|
||||
std::vector<std::string> SplitSentence(const std::string& input) {
|
||||
string result(input);
|
||||
|
||||
RE2::GlobalReplace(&result, "([?.!,])+", " \\1");
|
||||
@ -38,12 +38,13 @@ std::vector<string> SplitSentence(const string& input) {
|
||||
RE2::GlobalReplace(&result, "[ ]+", " ");
|
||||
RE2::GlobalReplace(&result, "\t+$", "");
|
||||
|
||||
return strings::Split(result, '\t');
|
||||
return absl::StrSplit(result, '\t');
|
||||
}
|
||||
|
||||
// Predict with TfLite model.
|
||||
void ExecuteTfLite(const string& sentence, ::tflite::Interpreter* interpreter,
|
||||
std::map<string, float>* response_map) {
|
||||
void ExecuteTfLite(const std::string& sentence,
|
||||
::tflite::Interpreter* interpreter,
|
||||
std::map<std::string, float>* response_map) {
|
||||
{
|
||||
TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]);
|
||||
tflite::DynamicBuffer buf;
|
||||
@ -67,8 +68,8 @@ void ExecuteTfLite(const string& sentence, ::tflite::Interpreter* interpreter,
|
||||
}
|
||||
|
||||
void GetSegmentPredictions(
|
||||
const std::vector<string>& input, const ::tflite::FlatBufferModel& model,
|
||||
const SmartReplyConfig& config,
|
||||
const std::vector<std::string>& input,
|
||||
const ::tflite::FlatBufferModel& model, const SmartReplyConfig& config,
|
||||
std::vector<PredictorResponse>* predictor_responses) {
|
||||
// Initialize interpreter
|
||||
std::unique_ptr<::tflite::Interpreter> interpreter;
|
||||
@ -82,10 +83,10 @@ void GetSegmentPredictions(
|
||||
}
|
||||
|
||||
// Execute Tflite Model
|
||||
std::map<string, float> response_map;
|
||||
std::vector<string> sentences;
|
||||
for (const string& str : input) {
|
||||
std::vector<string> splitted_str = SplitSentence(str);
|
||||
std::map<std::string, float> response_map;
|
||||
std::vector<std::string> sentences;
|
||||
for (const std::string& str : input) {
|
||||
std::vector<std::string> splitted_str = SplitSentence(str);
|
||||
sentences.insert(sentences.end(), splitted_str.begin(), splitted_str.end());
|
||||
}
|
||||
for (const auto& sentence : sentences) {
|
||||
|
@ -34,7 +34,7 @@ struct SmartReplyConfig;
|
||||
// With a given string as input, predict the response with a Tflite model.
|
||||
// When config.backoff_response is not empty, predictor_responses will be filled
|
||||
// with messagees from backoff response.
|
||||
void GetSegmentPredictions(const std::vector<string>& input,
|
||||
void GetSegmentPredictions(const std::vector<std::string>& input,
|
||||
const ::tflite::FlatBufferModel& model,
|
||||
const SmartReplyConfig& config,
|
||||
std::vector<PredictorResponse>* predictor_responses);
|
||||
@ -43,17 +43,17 @@ void GetSegmentPredictions(const std::vector<string>& input,
|
||||
// It includes messages, and confidence.
|
||||
class PredictorResponse {
|
||||
public:
|
||||
PredictorResponse(const string& response_text, float score) {
|
||||
PredictorResponse(const std::string& response_text, float score) {
|
||||
response_text_ = response_text;
|
||||
prediction_score_ = score;
|
||||
}
|
||||
|
||||
// Accessor methods.
|
||||
const string& GetText() const { return response_text_; }
|
||||
const std::string& GetText() const { return response_text_; }
|
||||
float GetScore() const { return prediction_score_; }
|
||||
|
||||
private:
|
||||
string response_text_ = "";
|
||||
std::string response_text_ = "";
|
||||
float prediction_score_ = 0.0;
|
||||
};
|
||||
|
||||
@ -65,9 +65,9 @@ struct SmartReplyConfig {
|
||||
float backoff_confidence;
|
||||
// Backoff responses are used when predicted responses cannot fulfill the
|
||||
// list.
|
||||
const std::vector<string>& backoff_responses;
|
||||
const std::vector<std::string>& backoff_responses;
|
||||
|
||||
SmartReplyConfig(std::vector<string> backoff_responses)
|
||||
SmartReplyConfig(std::vector<std::string> backoff_responses)
|
||||
: num_response(kDefaultNumResponse),
|
||||
backoff_confidence(kDefaultBackoffConfidence),
|
||||
backoff_responses(backoff_responses) {}
|
||||
|
@ -18,12 +18,12 @@ limitations under the License.
|
||||
#include <fstream>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "base/logging.h"
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "tensorflow/contrib/lite/models/test_utils.h"
|
||||
#include "tensorflow/contrib/lite/string_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace custom {
|
||||
@ -65,7 +65,6 @@ TEST_F(PredictorTest, GetSegmentPredictions) {
|
||||
|
||||
float max = 0;
|
||||
for (const auto &item : predictions) {
|
||||
LOG(INFO) << "Response: " << item.GetText();
|
||||
if (item.GetScore() > max) {
|
||||
max = item.GetScore();
|
||||
}
|
||||
@ -86,7 +85,6 @@ TEST_F(PredictorTest, TestTwoSentences) {
|
||||
|
||||
float max = 0;
|
||||
for (const auto &item : predictions) {
|
||||
LOG(INFO) << "Response: " << item.GetText();
|
||||
if (item.GetScore() > max) {
|
||||
max = item.GetScore();
|
||||
}
|
||||
@ -119,7 +117,7 @@ TEST_F(PredictorTest, BatchTest) {
|
||||
string line;
|
||||
std::ifstream fin(StrCat(TestDataPath(), "/", kSamples));
|
||||
while (std::getline(fin, line)) {
|
||||
const std::vector<string> &fields = strings::Split(line, '\t');
|
||||
const std::vector<string> fields = absl::StrSplit(line, '\t');
|
||||
if (fields.empty()) {
|
||||
continue;
|
||||
}
|
||||
@ -139,9 +137,8 @@ TEST_F(PredictorTest, BatchTest) {
|
||||
fields.begin() + 1, fields.end())));
|
||||
}
|
||||
|
||||
LOG(INFO) << "Responses: " << total_responses << " / " << total_items;
|
||||
LOG(INFO) << "Triggers: " << total_triggers << " / " << total_items;
|
||||
EXPECT_EQ(total_triggers, total_items);
|
||||
EXPECT_GE(total_responses, total_triggers);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -13,6 +13,7 @@ tf_cc_binary(
|
||||
"//tensorflow/contrib/lite/tools:gen_op_registration",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -13,30 +13,50 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/strip.h"
|
||||
#include "tensorflow/contrib/lite/tools/gen_op_registration.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
const char kInputModelFlag[] = "input_model";
|
||||
const char kOutputRegistrationFlag[] = "output_registration";
|
||||
const char kTfLitePathFlag[] = "tflite_path";
|
||||
|
||||
using tensorflow::Flag;
|
||||
using tensorflow::Flags;
|
||||
using tensorflow::string;
|
||||
|
||||
void ParseFlagAndInit(int argc, char** argv, string* input_model,
|
||||
string* output_registration, string* tflite_path) {
|
||||
std::vector<tensorflow::Flag> flag_list = {
|
||||
Flag(kInputModelFlag, input_model, "path to the tflite model"),
|
||||
Flag(kOutputRegistrationFlag, output_registration,
|
||||
"filename for generated registration code"),
|
||||
Flag(kTfLitePathFlag, tflite_path, "Path to tensorflow lite dir"),
|
||||
};
|
||||
|
||||
Flags::Parse(&argc, argv, flag_list);
|
||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
void GenerateFileContent(const string& filename,
|
||||
void GenerateFileContent(const std::string& tflite_path,
|
||||
const std::string& filename,
|
||||
const std::vector<string>& builtin_ops,
|
||||
const std::vector<string>& custom_ops) {
|
||||
std::ofstream fout(filename);
|
||||
|
||||
fout << "#include "
|
||||
"\"third_party/tensorflow/contrib/lite/model.h\"\n";
|
||||
fout << "#include "
|
||||
"\"third_party/tensorflow/contrib/lite/tools/mutable_op_resolver.h\"\n";
|
||||
fout << "#include \"" << tflite_path << "/model.h\"\n";
|
||||
fout << "#include \"" << tflite_path << "/tools/mutable_op_resolver.h\"\n";
|
||||
|
||||
fout << "namespace tflite {\n";
|
||||
fout << "namespace ops {\n";
|
||||
if (!builtin_ops.empty()) {
|
||||
@ -78,22 +98,20 @@ void GenerateFileContent(const string& filename,
|
||||
int main(int argc, char** argv) {
|
||||
string input_model;
|
||||
string output_registration;
|
||||
std::vector<tensorflow::Flag> flag_list = {
|
||||
Flag("input_model", &input_model, "path to the tflite model"),
|
||||
Flag("output_registration", &output_registration,
|
||||
"filename for generated registration code"),
|
||||
};
|
||||
Flags::Parse(&argc, argv, flag_list);
|
||||
string tflite_path;
|
||||
ParseFlagAndInit(argc, argv, &input_model, &output_registration,
|
||||
&tflite_path);
|
||||
|
||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||
std::vector<string> builtin_ops;
|
||||
std::vector<string> custom_ops;
|
||||
|
||||
std::ifstream fin(input_model);
|
||||
std::stringstream content;
|
||||
content << fin.rdbuf();
|
||||
const ::tflite::Model* model = ::tflite::GetModel(content.str().data());
|
||||
// Need to store content data first, otherwise, it won't work in bazel.
|
||||
string content_str = content.str();
|
||||
const ::tflite::Model* model = ::tflite::GetModel(content_str.data());
|
||||
::tflite::ReadOpsFromModel(model, &builtin_ops, &custom_ops);
|
||||
GenerateFileContent(output_registration, builtin_ops, custom_ops);
|
||||
GenerateFileContent(tflite_path, output_registration, builtin_ops,
|
||||
custom_ops);
|
||||
return 0;
|
||||
}
|
||||
|
@ -46,7 +46,7 @@ class MutableOpResolver : public OpResolver {
|
||||
void AddCustom(const char* name, TfLiteRegistration* registration);
|
||||
|
||||
private:
|
||||
std::map<tflite::BuiltinOperator, TfLiteRegistration*> builtins_;
|
||||
std::map<int, TfLiteRegistration*> builtins_;
|
||||
std::map<std::string, TfLiteRegistration*> custom_ops_;
|
||||
};
|
||||
|
||||
|
@ -207,11 +207,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
|
||||
native.http_archive(
|
||||
name = "com_googlesource_code_re2",
|
||||
urls = [
|
||||
"https://mirror.bazel.build/github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz",
|
||||
"https://github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz",
|
||||
"https://mirror.bazel.build/github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz",
|
||||
"https://github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz",
|
||||
|
||||
],
|
||||
sha256 = "bd63550101e056427c9e7ff12a408c1c8b74e9803f393ca916b2926fc2c4906f",
|
||||
strip_prefix = "re2-b94b7cd42e9f02673cd748c1ac1d16db4052514c",
|
||||
sha256 = "e57eeb837ac40b5be37b2c6197438766e73343ffb32368efea793dfd8b28653b",
|
||||
strip_prefix = "re2-26cd968b735e227361c9703683266f01e5df7857",
|
||||
)
|
||||
|
||||
native.http_archive(
|
||||
@ -800,3 +801,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
|
||||
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
|
||||
],
|
||||
)
|
||||
|
||||
native.new_http_archive(
|
||||
name = "tflite_smartreply",
|
||||
build_file = str(Label("//third_party:tflite_smartreply.BUILD")),
|
||||
sha256 = "8980151b85a87a9c1a3bb1ed4748119e4a85abd3cb5744d83da4d4bd0fbeef7c",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip"
|
||||
],
|
||||
)
|
||||
|
13
third_party/tflite_smartreply.BUILD
vendored
Normal file
13
third_party/tflite_smartreply.BUILD
vendored
Normal file
@ -0,0 +1,13 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
filegroup(
|
||||
name = "model_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"BUILD",
|
||||
],
|
||||
),
|
||||
)
|
Loading…
Reference in New Issue
Block a user