Build demo app for SmartReply

PiperOrigin-RevId: 177559103
A. Unique TensorFlower 2017-11-30 23:58:26 -08:00 committed by TensorFlower Gardener
parent 370e521762
commit 6b6244c401
21 changed files with 758 additions and 51 deletions

View File

@@ -223,11 +223,12 @@ def gen_selected_ops(name, model):
"""
out = name + "_registration.cc"
tool = "//tensorflow/contrib/lite/tools:generate_op_registrations"
tflite_path = "//tensorflow/contrib/lite"
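# tflite_path[2:] below strips the leading "//" so the tool receives a plain
# repo-relative path for the generated #include lines.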
native.genrule(
name = name,
srcs = [model],
outs = [out],
cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s)")
% (tool, model, out),
cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s")
% (tool, model, out, tflite_path[2:]),
tools = [tool],
)

View File

@@ -1,7 +1,92 @@
package(default_visibility = ["//visibility:public"])
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops")
licenses(["notice"]) # Apache 2.0
gen_selected_ops(
name = "smartreply_ops",
model = "@tflite_smartreply//:smartreply.tflite",
)
cc_library(
name = "custom_ops",
srcs = [
"ops/extract_feature.cc",
"ops/normalize.cc",
"ops/predict.cc",
":smartreply_ops",
],
copts = tflite_copts(),
deps = [
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/tools:mutable_op_resolver",
"@com_google_absl//absl/strings",
"@com_googlesource_code_re2//:re2",
"@farmhash_archive//:farmhash",
],
)
cc_library(
name = "predictor_lib",
srcs = ["predictor.cc"],
hdrs = ["predictor.h"],
copts = tflite_copts(),
deps = [
":custom_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/tools:mutable_op_resolver",
"@com_google_absl//absl/strings",
"@com_googlesource_code_re2//:re2",
],
)
cc_test(
name = "extract_feature_op_test",
size = "small",
srcs = ["ops/extract_feature_test.cc"],
deps = [
":custom_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
"@farmhash_archive//:farmhash",
],
)
cc_test(
name = "normalize_op_test",
size = "small",
srcs = ["ops/normalize_test.cc"],
deps = [
":custom_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
)
cc_test(
name = "predict_op_test",
size = "small",
srcs = ["ops/predict_test.cc"],
deps = [
":custom_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@@ -0,0 +1,38 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
Copyright 2017 The Android Open Source Project
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.example.android.smartreply" >
<uses-sdk
android:minSdkVersion="15"
android:targetSdkVersion="24" />
<application android:label="TfLite SmartReply Demo">
<activity
android:name="com.example.android.smartreply.MainActivity"
android:configChanges="orientation|keyboardHidden|screenSize"
android:windowSoftInputMode="stateUnchanged|adjustPan"
android:label="TfLite SmartReply Demo"
android:screenOrientation="portrait" >
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
</manifest>

View File

@@ -0,0 +1,65 @@
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
load(
"//tensorflow/contrib/lite:build_def.bzl",
"tflite_copts",
"tflite_jni_binary",
)
filegroup(
name = "assets",
srcs = [
"@tflite_smartreply//:model_files",
],
)
android_binary(
name = "SmartReplyDemo",
srcs = glob(["java/**/*.java"]),
assets = [":assets"],
assets_dir = "",
custom_package = "com.example.android.smartreply",
manifest = "AndroidManifest.xml",
nocompress_extensions = [
".tflite",
],
resource_files = glob(["res/**"]),
tags = ["manual"],
deps = [
":smartreply_runtime",
"@androidsdk//com.android.support:support-v13-25.2.0",
"@androidsdk//com.android.support:support-v4-25.2.0",
],
)
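# Wraps the prebuilt JNI shared library so the android_binary above can
# depend on it like an ordinary cc_library.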
cc_library(
name = "smartreply_runtime",
srcs = ["libsmartreply_jni.so"],
visibility = ["//visibility:public"],
)
tflite_jni_binary(
name = "libsmartreply_jni.so",
deps = [
":smartreply_jni_lib",
],
)
cc_library(
name = "smartreply_jni_lib",
srcs = [
"smartreply_jni.cc",
],
copts = tflite_copts(),
linkopts = [
"-lm",
"-ldl",
],
deps = [
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/models/smartreply:predictor_lib",
],
alwayslink = 1,
)

View File

@@ -0,0 +1,15 @@
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
exports_files(glob(["*"]))
filegroup(
name = "assets_files",
srcs = glob(
["**/*"],
exclude = [
"BUILD",
],
),
)

View File

@@ -0,0 +1,16 @@
Ok
Yes
No
👍
😟
❤️
Lol
Thanks
Got it
Done
Nice
I don't know
What?
Why?
What's up?

View File

@@ -0,0 +1,99 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.example.android.smartreply;
import android.app.Activity;
import android.os.Bundle;
import android.os.Handler;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.EditText;
import android.widget.TextView;
/**
* The main (and only) activity of this demo app. Displays a text box which updates as messages are
* received.
*/
public class MainActivity extends Activity {
private static final String TAG = "SmartReplyDemo";
private SmartReplyClient client;
private Button sendButton;
private TextView messageTextView;
private EditText messageInput;
private Handler handler;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
Log.v(TAG, "onCreate");
setContentView(R.layout.main_activity);
client = new SmartReplyClient(getApplicationContext());
handler = new Handler();
sendButton = (Button) findViewById(R.id.send_button);
sendButton.setOnClickListener(
(View v) -> {
send(messageInput.getText().toString());
});
messageTextView = (TextView) findViewById(R.id.message_text);
messageInput = (EditText) findViewById(R.id.message_input);
}
@Override
protected void onStart() {
super.onStart();
Log.v(TAG, "onStart");
handler.post(
() -> {
client.loadModel();
});
}
@Override
protected void onStop() {
super.onStop();
Log.v(TAG, "onStop");
handler.post(
() -> {
client.unloadModel();
});
}
private void send(final String message) {
handler.post(
() -> {
messageTextView.append("Input: " + message + "\n");
SmartReply[] ans = client.predict(new String[] {message});
for (SmartReply reply : ans) {
appendMessage("Reply: " + reply.getText());
}
appendMessage("------");
});
}
private void appendMessage(final String message) {
handler.post(
() -> {
messageTextView.append(message + "\n");
});
}
}

View File

@@ -0,0 +1,44 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.example.android.smartreply;
import android.support.annotation.Keep;
/**
* SmartReply contains a predicted message and its confidence score.
*
* <p>NOTE: this class is used by JNI; its class name and constructor must not be obfuscated.
*/
@Keep
public class SmartReply {
private final String text;
private final float score;
@Keep
public SmartReply(String text, float score) {
this.text = text;
this.score = score;
}
public String getText() {
return text;
}
public float getScore() {
return score;
}
}

View File

@@ -0,0 +1,129 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.example.android.smartreply;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.support.annotation.Keep;
import android.support.annotation.WorkerThread;
import android.util.Log;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;
/** Interface to load the TfLite model and provide predictions. */
public class SmartReplyClient implements AutoCloseable {
private static final String TAG = "SmartReplyDemo";
private static final String MODEL_PATH = "smartreply.tflite";
private static final String BACKOFF_PATH = "backoff_response.txt";
private static final String JNI_LIB = "smartreply_jni";
private final Context context;
private long storage;
private MappedByteBuffer model;
private volatile boolean isLibraryLoaded;
public SmartReplyClient(Context context) {
this.context = context;
}
public boolean isLoaded() {
return storage != 0;
}
@WorkerThread
public synchronized void loadModel() {
if (!isLibraryLoaded) {
System.loadLibrary(JNI_LIB);
isLibraryLoaded = true;
}
try {
model = loadModelFile();
String[] backoff = loadBackoffList();
storage = loadJNI(model, backoff);
} catch (IOException e) {
Log.e(TAG, "Failed to load model", e);
return;
}
}
@WorkerThread
public synchronized SmartReply[] predict(String[] input) {
if (storage != 0) {
return predictJNI(storage, input);
} else {
return new SmartReply[] {};
}
}
@WorkerThread
public synchronized void unloadModel() {
close();
}
@Override
public synchronized void close() {
if (storage != 0) {
unloadJNI(storage);
storage = 0;
}
}
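// Memory-maps the model file from the APK assets so TfLite can read it
// without copying. This relies on the .tflite asset being packaged
// uncompressed (see nocompress_extensions in the demo BUILD above).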
private MappedByteBuffer loadModelFile() throws IOException {
AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
try {
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
} finally {
inputStream.close();
}
}
private String[] loadBackoffList() throws IOException {
List<String> labelList = new ArrayList<String>();
BufferedReader reader =
new BufferedReader(new InputStreamReader(context.getAssets().open(BACKOFF_PATH)));
String line;
while ((line = reader.readLine()) != null) {
if (!line.isEmpty()) {
labelList.add(line);
}
}
reader.close();
String[] ans = new String[labelList.size()];
labelList.toArray(ans);
return ans;
}
@Keep
private native long loadJNI(MappedByteBuffer buffer, String[] backoff);
@Keep
private native SmartReply[] predictJNI(long storage, String[] text);
@Keep
private native void unloadJNI(long storage);
}

View File

@@ -0,0 +1,44 @@
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical">
<LinearLayout
android:layout_width="fill_parent"
android:layout_height="0dp"
android:padding="5dip"
android:layout_weight="3">
<TextView
android:id="@+id/message_text"
android:layout_width="fill_parent"
android:layout_height="fill_parent"
android:scrollbars="vertical"
android:gravity="bottom"/>
</LinearLayout>
<LinearLayout
android:layout_width="fill_parent"
android:layout_height="0dp"
android:padding="5dip"
android:layout_weight="1">
<EditText
android:id="@+id/message_input"
android:layout_width="0dp"
android:layout_height="fill_parent"
android:layout_weight="6"
android:scrollbars="vertical"
android:hint="Enter Text"
android:gravity="top"
android:inputType="text"/>
<Button
android:id="@+id/send_button"
android:layout_width="0dp"
android:layout_height="fill_parent"
android:layout_weight="2"
android:text="Send" />
</LinearLayout>
</LinearLayout>

View File

@@ -0,0 +1,129 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <jni.h>
#include <utility>
#include <vector>
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/models/smartreply/predictor.h"
const char kIllegalStateException[] = "java/lang/IllegalStateException";
using tflite::custom::smartreply::GetSegmentPredictions;
using tflite::custom::smartreply::PredictorResponse;
template <typename T>
T CheckNotNull(JNIEnv* env, T&& t) {
if (t == nullptr) {
env->ThrowNew(env->FindClass(kIllegalStateException), "");
return nullptr;
}
return std::forward<T>(t);
}
std::vector<std::string> jniStringArrayToVector(JNIEnv* env,
jobjectArray string_array) {
int count = env->GetArrayLength(string_array);
std::vector<std::string> result;
for (int i = 0; i < count; i++) {
auto jstr =
reinterpret_cast<jstring>(env->GetObjectArrayElement(string_array, i));
const char* raw_str = env->GetStringUTFChars(jstr, JNI_FALSE);
result.emplace_back(std::string(raw_str));
env->ReleaseStringUTFChars(jstr, raw_str);
}
return result;
}
struct JNIStorage {
std::vector<std::string> backoff_list;
std::unique_ptr<::tflite::FlatBufferModel> model;
};
extern "C" JNIEXPORT jlong JNICALL
Java_com_example_android_smartreply_SmartReplyClient_loadJNI(
JNIEnv* env, jobject thiz, jobject model_buffer,
jobjectArray backoff_list) {
const char* buf =
static_cast<char*>(env->GetDirectBufferAddress(model_buffer));
jlong capacity = env->GetDirectBufferCapacity(model_buffer);
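// The new JNIStorage is handed back to Java as an opaque jlong handle;
// unloadJNI() below is responsible for deleting it.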
JNIStorage* storage = new JNIStorage;
storage->model = tflite::FlatBufferModel::BuildFromBuffer(
buf, static_cast<size_t>(capacity));
storage->backoff_list = jniStringArrayToVector(env, backoff_list);
if (!storage->model) {
delete storage;
env->ThrowNew(env->FindClass(kIllegalStateException), "");
return 0;
}
return reinterpret_cast<jlong>(storage);
}
extern "C" JNIEXPORT jobjectArray JNICALL
Java_com_example_android_smartreply_SmartReplyClient_predictJNI(
JNIEnv* env, jobject /*thiz*/, jlong storage_ptr, jobjectArray input_text) {
// Predict
if (storage_ptr == 0) {
return nullptr;
}
JNIStorage* storage = reinterpret_cast<JNIStorage*>(storage_ptr);
if (storage == nullptr) {
return nullptr;
}
std::vector<PredictorResponse> responses;
GetSegmentPredictions(jniStringArrayToVector(env, input_text),
*storage->model, {storage->backoff_list}, &responses);
// Create a SmartReply[] to return to Java
jclass smart_reply_class = CheckNotNull(
env, env->FindClass("com/example/android/smartreply/SmartReply"));
if (env->ExceptionCheck()) {
return nullptr;
}
jmethodID smart_reply_ctor = CheckNotNull(
env,
env->GetMethodID(smart_reply_class, "<init>", "(Ljava/lang/String;F)V"));
if (env->ExceptionCheck()) {
return nullptr;
}
jobjectArray array = CheckNotNull(
env, env->NewObjectArray(responses.size(), smart_reply_class, nullptr));
if (env->ExceptionCheck()) {
return nullptr;
}
for (int i = 0; i < responses.size(); i++) {
jstring text =
CheckNotNull(env, env->NewStringUTF(responses[i].GetText().data()));
if (env->ExceptionCheck()) {
return nullptr;
}
jobject reply = env->NewObject(smart_reply_class, smart_reply_ctor, text,
responses[i].GetScore());
env->SetObjectArrayElement(array, i, reply);
}
return array;
}
extern "C" JNIEXPORT void JNICALL
Java_com_example_android_smartreply_SmartReplyClient_unloadJNI(
JNIEnv* env, jobject thiz, jlong storage_ptr) {
if (storage_ptr != 0) {
JNIStorage* storage = reinterpret_cast<JNIStorage*>(storage_ptr);
delete storage;
}
}

View File

@@ -23,7 +23,7 @@ limitations under the License.
#include <algorithm>
#include <map>
#include "re2/re2.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/string_util.h"
@@ -81,7 +81,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* label = GetOutput(context, node, 0);
TfLiteTensor* weight = GetOutput(context, node, 1);
std::map<int64, int> feature_id_counts;
std::map<int64_t, int> feature_id_counts;
for (int i = 0; i < num_strings; i++) {
// Use fingerprint of feature name as id.
auto strref = tflite::GetString(input, i);
@@ -91,10 +91,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
continue;
}
int64 feature_id =
int64_t feature_id =
::util::Fingerprint64(strref.str, strref.len) % kMaxDimension;
label->data.i32[i] = static_cast<int32>(feature_id);
label->data.i32[i] = static_cast<int32_t>(feature_id);
weight->data.f[i] =
std::count(strref.str, strref.str + strref.len, ' ') + 1;
}

View File

@@ -21,7 +21,10 @@ limitations under the License.
// Output:
// Output[0]: Normalized sentence. string[1]
//
#include "absl/strings/ascii.h"
#include <algorithm>
#include <string>
#include "absl/strings/str_cat.h"
#include "absl/strings/strip.h"
#include "re2/re2.h"
@@ -50,7 +53,7 @@ const std::map<string, string>* kRegexTransforms =
static const char kStartToken[] = "<S>";
static const char kEndToken[] = "<E>";
static const int32 kMaxInputChars = 300;
static const int32_t kMaxInputChars = 300;
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::StringRef input = tflite::GetString(GetInput(context, node, 0), 0);

View File

@@ -30,7 +30,7 @@ namespace custom {
namespace smartreply {
// Split sentence into segments (using punctuation).
std::vector<string> SplitSentence(const string& input) {
std::vector<std::string> SplitSentence(const std::string& input) {
string result(input);
RE2::GlobalReplace(&result, "([?.!,])+", " \\1");
@@ -38,12 +38,13 @@ std::vector<string> SplitSentence(const string& input) {
RE2::GlobalReplace(&result, "[ ]+", " ");
RE2::GlobalReplace(&result, "\t+$", "");
return strings::Split(result, '\t');
return absl::StrSplit(result, '\t');
}
// Predict with TfLite model.
void ExecuteTfLite(const string& sentence, ::tflite::Interpreter* interpreter,
std::map<string, float>* response_map) {
void ExecuteTfLite(const std::string& sentence,
::tflite::Interpreter* interpreter,
std::map<std::string, float>* response_map) {
{
TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]);
tflite::DynamicBuffer buf;
@@ -67,8 +68,8 @@ void ExecuteTfLite(const string& sentence, ::tflite::Interpreter* interpreter,
}
void GetSegmentPredictions(
const std::vector<string>& input, const ::tflite::FlatBufferModel& model,
const SmartReplyConfig& config,
const std::vector<std::string>& input,
const ::tflite::FlatBufferModel& model, const SmartReplyConfig& config,
std::vector<PredictorResponse>* predictor_responses) {
// Initialize interpreter
std::unique_ptr<::tflite::Interpreter> interpreter;
@@ -82,10 +83,10 @@ void GetSegmentPredictions(
}
// Execute the TfLite model
std::map<string, float> response_map;
std::vector<string> sentences;
for (const string& str : input) {
std::vector<string> splitted_str = SplitSentence(str);
std::map<std::string, float> response_map;
std::vector<std::string> sentences;
for (const std::string& str : input) {
std::vector<std::string> splitted_str = SplitSentence(str);
sentences.insert(sentences.end(), splitted_str.begin(), splitted_str.end());
}
for (const auto& sentence : sentences) {

View File

@@ -34,7 +34,7 @@ struct SmartReplyConfig;
// Given a string as input, predict the response with a TfLite model.
// When config.backoff_responses is not empty, predictor_responses will be
// filled with messages from the backoff responses.
void GetSegmentPredictions(const std::vector<string>& input,
void GetSegmentPredictions(const std::vector<std::string>& input,
const ::tflite::FlatBufferModel& model,
const SmartReplyConfig& config,
std::vector<PredictorResponse>* predictor_responses);
@@ -43,17 +43,17 @@ void GetSegmentPredictions(const std::vector<string>& input,
// It includes the message text and a confidence score.
class PredictorResponse {
public:
PredictorResponse(const string& response_text, float score) {
PredictorResponse(const std::string& response_text, float score) {
response_text_ = response_text;
prediction_score_ = score;
}
// Accessor methods.
const string& GetText() const { return response_text_; }
const std::string& GetText() const { return response_text_; }
float GetScore() const { return prediction_score_; }
private:
string response_text_ = "";
std::string response_text_ = "";
float prediction_score_ = 0.0;
};
@@ -65,9 +65,9 @@ struct SmartReplyConfig {
float backoff_confidence;
// Backoff responses are used when the predicted responses cannot fill the
// requested list.
const std::vector<string>& backoff_responses;
const std::vector<std::string>& backoff_responses;
SmartReplyConfig(std::vector<string> backoff_responses)
SmartReplyConfig(const std::vector<std::string>& backoff_responses)
: num_response(kDefaultNumResponse),
backoff_confidence(kDefaultBackoffConfidence),
backoff_responses(backoff_responses) {}
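Outside the Android demo, the predictor can also be driven directly from C++. A minimal sketch, assuming a hypothetical on-disk model path and backoff strings (GetSegmentPredictions, SmartReplyConfig, and PredictorResponse are as declared above; note the config holds its backoff list by reference, so the named vector must outlive the call):

#include <cstdio>
#include <string>
#include <vector>

#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/models/smartreply/predictor.h"

int main() {
  // Hypothetical path; the demo app loads the same file from APK assets.
  auto model = tflite::FlatBufferModel::BuildFromFile("smartreply.tflite");
  if (!model) return 1;

  // Backoff replies used when the model's own predictions are too weak.
  std::vector<std::string> backoff = {"Ok", "Yes", "No"};

  std::vector<tflite::custom::smartreply::PredictorResponse> responses;
  tflite::custom::smartreply::GetSegmentPredictions(
      {"Hi, how are you?"}, *model, {backoff}, &responses);

  for (const auto& reply : responses) {
    std::printf("%s (%.3f)\n", reply.GetText().c_str(), reply.GetScore());
  }
  return 0;
}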

View File

@@ -18,12 +18,12 @@ limitations under the License.
#include <fstream>
#include <unordered_set>
#include "base/logging.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "tensorflow/contrib/lite/models/test_utils.h"
#include "tensorflow/contrib/lite/string_util.h"
namespace tflite {
namespace custom {
@@ -65,7 +65,6 @@ TEST_F(PredictorTest, GetSegmentPredictions) {
float max = 0;
for (const auto &item : predictions) {
LOG(INFO) << "Response: " << item.GetText();
if (item.GetScore() > max) {
max = item.GetScore();
}
@@ -86,7 +85,6 @@ TEST_F(PredictorTest, TestTwoSentences) {
float max = 0;
for (const auto &item : predictions) {
LOG(INFO) << "Response: " << item.GetText();
if (item.GetScore() > max) {
max = item.GetScore();
}
@@ -119,7 +117,7 @@ TEST_F(PredictorTest, BatchTest) {
string line;
std::ifstream fin(StrCat(TestDataPath(), "/", kSamples));
while (std::getline(fin, line)) {
const std::vector<string> &fields = strings::Split(line, '\t');
const std::vector<string> fields = absl::StrSplit(line, '\t');
if (fields.empty()) {
continue;
}
@@ -139,9 +137,8 @@ TEST_F(PredictorTest, BatchTest) {
fields.begin() + 1, fields.end())));
}
LOG(INFO) << "Responses: " << total_responses << " / " << total_items;
LOG(INFO) << "Triggers: " << total_triggers << " / " << total_items;
EXPECT_EQ(total_triggers, total_items);
EXPECT_GE(total_responses, total_triggers);
}
} // namespace

View File

@@ -13,6 +13,7 @@ tf_cc_binary(
"//tensorflow/contrib/lite/tools:gen_op_registration",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
],
)

View File

@@ -13,30 +13,50 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cassert>
#include <fstream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
#include "absl/strings/strip.h"
#include "tensorflow/contrib/lite/tools/gen_op_registration.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
const char kInputModelFlag[] = "input_model";
const char kOutputRegistrationFlag[] = "output_registration";
const char kTfLitePathFlag[] = "tflite_path";
using tensorflow::Flag;
using tensorflow::Flags;
using tensorflow::string;
void ParseFlagAndInit(int argc, char** argv, string* input_model,
string* output_registration, string* tflite_path) {
std::vector<tensorflow::Flag> flag_list = {
Flag(kInputModelFlag, input_model, "path to the tflite model"),
Flag(kOutputRegistrationFlag, output_registration,
"filename for generated registration code"),
Flag(kTfLitePathFlag, tflite_path, "Path to the TensorFlow Lite directory"),
};
Flags::Parse(&argc, argv, flag_list);
tensorflow::port::InitMain(argv[0], &argc, &argv);
}
namespace {
void GenerateFileContent(const string& filename,
void GenerateFileContent(const std::string& tflite_path,
const std::string& filename,
const std::vector<string>& builtin_ops,
const std::vector<string>& custom_ops) {
std::ofstream fout(filename);
fout << "#include "
"\"third_party/tensorflow/contrib/lite/model.h\"\n";
fout << "#include "
"\"third_party/tensorflow/contrib/lite/tools/mutable_op_resolver.h\"\n";
fout << "#include \"" << tflite_path << "/model.h\"\n";
fout << "#include \"" << tflite_path << "/tools/mutable_op_resolver.h\"\n";
fout << "namespace tflite {\n";
fout << "namespace ops {\n";
if (!builtin_ops.empty()) {
@@ -78,22 +98,20 @@ void GenerateFileContent(const string& filename,
int main(int argc, char** argv) {
string input_model;
string output_registration;
std::vector<tensorflow::Flag> flag_list = {
Flag("input_model", &input_model, "path to the tflite model"),
Flag("output_registration", &output_registration,
"filename for generated registration code"),
};
Flags::Parse(&argc, argv, flag_list);
string tflite_path;
ParseFlagAndInit(argc, argv, &input_model, &output_registration,
&tflite_path);
tensorflow::port::InitMain(argv[0], &argc, &argv);
std::vector<string> builtin_ops;
std::vector<string> custom_ops;
std::ifstream fin(input_model);
std::stringstream content;
content << fin.rdbuf();
const ::tflite::Model* model = ::tflite::GetModel(content.str().data());
// Store the content in a named string first: content.str() returns a
// temporary, and its buffer must stay alive while the Model created below
// is in use.
string content_str = content.str();
const ::tflite::Model* model = ::tflite::GetModel(content_str.data());
::tflite::ReadOpsFromModel(model, &builtin_ops, &custom_ops);
GenerateFileContent(output_registration, builtin_ops, custom_ops);
GenerateFileContent(tflite_path, output_registration, builtin_ops,
custom_ops);
return 0;
}
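With the new --tflite_path flag, the generated registration file uses repo-relative includes instead of hard-coded third_party paths. A sketch of roughly what a generated smartreply_ops_registration.cc comes out as (the op names and the RegisterSelectedOps entry point are assumptions; the part of GenerateFileContent that emits them is elided from this diff):

#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"

namespace tflite {
namespace ops {
namespace custom {
// One forward declaration per custom op found in the model (assumed names).
TfLiteRegistration* Register_NORMALIZE();
TfLiteRegistration* Register_EXTRACT_FEATURES();
TfLiteRegistration* Register_PREDICT();
}  // namespace custom
}  // namespace ops

void RegisterSelectedOps(::tflite::MutableOpResolver* resolver) {
  resolver->AddCustom("Normalize", ops::custom::Register_NORMALIZE());
  resolver->AddCustom("ExtractFeatures",
                      ops::custom::Register_EXTRACT_FEATURES());
  resolver->AddCustom("Predict", ops::custom::Register_PREDICT());
}
}  // namespace tflite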

View File

@@ -46,7 +46,7 @@ class MutableOpResolver : public OpResolver {
void AddCustom(const char* name, TfLiteRegistration* registration);
private:
std::map<tflite::BuiltinOperator, TfLiteRegistration*> builtins_;
std::map<int, TfLiteRegistration*> builtins_;
std::map<std::string, TfLiteRegistration*> custom_ops_;
};
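A minimal sketch of how such a resolver feeds interpreter construction (RegisterSelectedOps is the generated entry point sketched above; InterpreterBuilder is the standard TfLite construction path):

#include <memory>

#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"

namespace tflite {
// Assumed to be provided by the generated registration file.
void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
}  // namespace tflite

std::unique_ptr<tflite::Interpreter> BuildInterpreter(
    const tflite::FlatBufferModel& model) {
  tflite::MutableOpResolver resolver;
  tflite::RegisterSelectedOps(&resolver);
  std::unique_ptr<tflite::Interpreter> interpreter;
  // The builder looks up every op in the model through the resolver.
  tflite::InterpreterBuilder(model, resolver)(&interpreter);
  return interpreter;
}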

View File

@@ -207,11 +207,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
native.http_archive(
name = "com_googlesource_code_re2",
urls = [
"https://mirror.bazel.build/github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz",
"https://github.com/google/re2/archive/b94b7cd42e9f02673cd748c1ac1d16db4052514c.tar.gz",
"https://mirror.bazel.build/github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz",
"https://github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz",
],
sha256 = "bd63550101e056427c9e7ff12a408c1c8b74e9803f393ca916b2926fc2c4906f",
strip_prefix = "re2-b94b7cd42e9f02673cd748c1ac1d16db4052514c",
sha256 = "e57eeb837ac40b5be37b2c6197438766e73343ffb32368efea793dfd8b28653b",
strip_prefix = "re2-26cd968b735e227361c9703683266f01e5df7857",
)
native.http_archive(
@@ -800,3 +801,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
],
)
native.new_http_archive(
name = "tflite_smartreply",
build_file = str(Label("//third_party:tflite_smartreply.BUILD")),
sha256 = "8980151b85a87a9c1a3bb1ed4748119e4a85abd3cb5744d83da4d4bd0fbeef7c",
urls = [
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip"
],
)

third_party/tflite_smartreply.BUILD (vendored, new file)
View File

@@ -0,0 +1,13 @@
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
filegroup(
name = "model_files",
srcs = glob(
["**/*"],
exclude = [
"BUILD",
],
),
)