diff --git a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
index c8b9126f033..49594e6b47b 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
@@ -25,18 +25,86 @@ package org.tensorflow;
* protocol buffer).
*/
public class SavedModelBundle implements AutoCloseable {
+ /** Options for loading a SavedModel. */
+ public static final class Loader {
+ /** Load a SavedModelBundle
with the configured options. */
+ public SavedModelBundle load() {
+ return SavedModelBundle.load(exportDir, tags, configProto, runOptions);
+ }
+
+ /**
+ * Sets options to use when executing model initialization operations.
+ *
+ * @param options Serialized RunOptions
+ * protocol buffer.
+ */
+ public Loader withRunOptions(byte[] options) {
+ this.runOptions = options;
+ return this;
+ }
+
+ /**
+ * Set configuration of the Session
object created when loading the model.
+ *
+ * @param configProto Serialized ConfigProto
+ * protocol buffer.
+ */
+ public Loader withConfigProto(byte[] configProto) {
+ this.configProto = configProto;
+ return this;
+ }
+
+ /**
+ * Sets the set of tags that identify the specific graph in the saved model to load.
+ *
+ * @param tags the tags identifying the specific MetaGraphDef to load.
+ */
+ public Loader withTags(String... tags) {
+ this.tags = tags;
+ return this;
+ }
+
+ private Loader(String exportDir) {
+ this.exportDir = exportDir;
+ }
+
+ private String exportDir = null;
+ private String[] tags = null;
+ private byte[] configProto = null;
+ private byte[] runOptions = null;
+ }
/**
* Load a saved model from an export directory. The model that is being loaded should be created
* using the Saved Model
* API.
*
+ *
This method is a shorthand for: + * + *
{@code + * SavedModelBundler.loader().withTags(tags).load(); + * }+ * * @param exportDir the directory path containing a saved model. * @param tags the tags identifying the specific metagraphdef to load. * @return a bundle containing the graph and associated session. */ public static SavedModelBundle load(String exportDir, String... tags) { - return load(exportDir, tags, null); + return loader(exportDir).withTags(tags).load(); + } + + /** + * Load a saved model. + * + * Returns a
Loader
object that can set configuration options before actually
+ * loading the model,
+ *
+ * @param exportDir the directory path containing a saved model.
+ */
+ public static Loader loader(String exportDir) {
+ return new Loader(exportDir);
}
/**
@@ -95,7 +163,8 @@ public class SavedModelBundle implements AutoCloseable {
return new SavedModelBundle(graph, session, metaGraphDef);
}
- private static native SavedModelBundle load(String exportDir, String[] tags, byte[] runOptions);
+ private static native SavedModelBundle load(
+ String exportDir, String[] tags, byte[] config, byte[] runOptions);
static {
TensorFlow.init();
diff --git a/tensorflow/java/src/main/native/saved_model_bundle_jni.cc b/tensorflow/java/src/main/native/saved_model_bundle_jni.cc
index de6382a79c4..68999fb2da8 100644
--- a/tensorflow/java/src/main/native/saved_model_bundle_jni.cc
+++ b/tensorflow/java/src/main/native/saved_model_bundle_jni.cc
@@ -22,12 +22,25 @@ limitations under the License.
JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
JNIEnv* env, jclass clazz, jstring export_dir, jobjectArray tags,
- jbyteArray run_options) {
+ jbyteArray config, jbyteArray run_options) {
TF_Status* status = TF_NewStatus();
jobject bundle = nullptr;
// allocate parameters for TF_LoadSessionFromSavedModel
TF_SessionOptions* opts = TF_NewSessionOptions();
+ if (config != nullptr) {
+ size_t sz = env->GetArrayLength(config);
+ if (sz > 0) {
+ jbyte* config_data = env->GetByteArrayElements(config, nullptr);
+ TF_SetConfig(opts, static_cast