[Java]: Support ConfigProto and RunOptions when loading SavedModels.

Fixes #18143
Fixes #20769

(Similar to #18716 by @raintung)

PiperOrigin-RevId: 204575441
This commit is contained in:
Asim Shankar 2018-07-14 00:15:49 -07:00 committed by TensorFlower Gardener
parent 88b656acd4
commit 8aa4179ffa
4 changed files with 141 additions and 5 deletions

View File

@ -25,18 +25,86 @@ package org.tensorflow;
* protocol buffer</a>).
*/
public class SavedModelBundle implements AutoCloseable {
/** Options for loading a SavedModel. */
public static final class Loader {
/** Load a <code>SavedModelBundle</code> with the configured options. */
public SavedModelBundle load() {
return SavedModelBundle.load(exportDir, tags, configProto, runOptions);
}
/**
* Sets options to use when executing model initialization operations.
*
* @param options Serialized <a
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions
* protocol buffer</a>.
*/
public Loader withRunOptions(byte[] options) {
this.runOptions = options;
return this;
}
/**
* Set configuration of the <code>Session</code> object created when loading the model.
*
* @param configProto Serialized <a
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">ConfigProto
* protocol buffer</a>.
*/
public Loader withConfigProto(byte[] configProto) {
this.configProto = configProto;
return this;
}
/**
* Sets the set of tags that identify the specific graph in the saved model to load.
*
* @param tags the tags identifying the specific MetaGraphDef to load.
*/
public Loader withTags(String... tags) {
this.tags = tags;
return this;
}
private Loader(String exportDir) {
this.exportDir = exportDir;
}
private String exportDir = null;
private String[] tags = null;
private byte[] configProto = null;
private byte[] runOptions = null;
}
/**
* Load a saved model from an export directory. The model that is being loaded should be created
* using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model
* API</a>.
*
* <p>This method is a shorthand for:
*
* <pre>{@code
* SavedModelBundler.loader().withTags(tags).load();
* }</pre>
*
* @param exportDir the directory path containing a saved model.
* @param tags the tags identifying the specific metagraphdef to load.
* @return a bundle containing the graph and associated session.
*/
public static SavedModelBundle load(String exportDir, String... tags) {
return load(exportDir, tags, null);
return loader(exportDir).withTags(tags).load();
}
/**
* Load a saved model.
*
* <p/>Returns a <code>Loader</code> object that can set configuration options before actually
* loading the model,
*
* @param exportDir the directory path containing a saved model.
*/
public static Loader loader(String exportDir) {
return new Loader(exportDir);
}
/**
@ -95,7 +163,8 @@ public class SavedModelBundle implements AutoCloseable {
return new SavedModelBundle(graph, session, metaGraphDef);
}
private static native SavedModelBundle load(String exportDir, String[] tags, byte[] runOptions);
private static native SavedModelBundle load(
String exportDir, String[] tags, byte[] config, byte[] runOptions);
static {
TensorFlow.init();

View File

@ -22,12 +22,25 @@ limitations under the License.
JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
JNIEnv* env, jclass clazz, jstring export_dir, jobjectArray tags,
jbyteArray run_options) {
jbyteArray config, jbyteArray run_options) {
TF_Status* status = TF_NewStatus();
jobject bundle = nullptr;
// allocate parameters for TF_LoadSessionFromSavedModel
TF_SessionOptions* opts = TF_NewSessionOptions();
if (config != nullptr) {
size_t sz = env->GetArrayLength(config);
if (sz > 0) {
jbyte* config_data = env->GetByteArrayElements(config, nullptr);
TF_SetConfig(opts, static_cast<void*>(config_data), sz, status);
env->ReleaseByteArrayElements(config, config_data, JNI_ABORT);
if (!throwExceptionIfNotOK(env, status)) {
TF_DeleteSessionOptions(opts);
TF_DeleteStatus(status);
return nullptr;
}
}
}
TF_Buffer* crun_options = nullptr;
if (run_options != nullptr) {
size_t sz = env->GetArrayLength(run_options);

View File

@ -26,10 +26,10 @@ extern "C" {
* Class: org_tensorflow_SavedModelBundle
* Method: load
* Signature:
* (Ljava/lang/String;[Ljava/lang/String;[B)Lorg/tensorflow/SavedModelBundle;
* (Ljava/lang/String;[Ljava/lang/String;[B;[B)Lorg/tensorflow/SavedModelBundle;
*/
JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
JNIEnv *, jclass, jstring, jobjectArray, jbyteArray);
JNIEnv *, jclass, jstring, jobjectArray, jbyteArray, jbyteArray);
#ifdef __cplusplus
} // extern "C"

View File

@ -50,4 +50,58 @@ public class SavedModelBundleTest {
assertTrue(e.getMessage().contains("Could not find SavedModel"));
}
}
@Test
public void loader() {
try (SavedModelBundle bundle = SavedModelBundle.loader(SAVED_MODEL_PATH)
.withTags("serve")
.withConfigProto(sillyConfigProto())
.withRunOptions(sillyRunOptions())
.load()) {
assertNotNull(bundle.session());
assertNotNull(bundle.graph());
assertNotNull(bundle.metaGraphDef());
}
}
private static byte[] sillyRunOptions() {
// Ideally this would use the generated Java sources for protocol buffers
// and end up with something like the snippet below. However, generating
// the Java files for the .proto files in tensorflow/core:protos_all is
// a bit cumbersome in bazel until the proto_library rule is setup.
//
// See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866
// https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362
// https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558
//
// For this test, for now, the use of specific bytes suffices.
return new byte[] {0x08, 0x03};
/*
return org.tensorflow.framework.RunOptions.newBuilder()
.setTraceLevel(RunOptions.TraceLevel.FULL_TRACE)
.build()
.toByteArray();
*/
}
public static byte[] sillyConfigProto() {
// Ideally this would use the generated Java sources for protocol buffers
// and end up with something like the snippet below. However, generating
// the Java files for the .proto files in tensorflow/core:protos_all is
// a bit cumbersome in bazel until the proto_library rule is setup.
//
// See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866
// https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362
// https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558
//
// For this test, for now, the use of specific bytes suffices.
return new byte[] {0x10, 0x01, 0x28, 0x01};
/*
return org.tensorflow.framework.ConfigProto.newBuilder()
.setInterOpParallelismThreads(1)
.setIntraOpParallelismThreads(1)
.build()
.toByteArray();
*/
}
}