[Java]: Support ConfigProto and RunOptions when loading SavedModels.
Fixes #18143 Fixes #20769 (Similar to #18716 by @raintung) PiperOrigin-RevId: 204575441
This commit is contained in:
parent
88b656acd4
commit
8aa4179ffa
@ -25,18 +25,86 @@ package org.tensorflow;
|
||||
* protocol buffer</a>).
|
||||
*/
|
||||
public class SavedModelBundle implements AutoCloseable {
|
||||
/** Options for loading a SavedModel. */
|
||||
public static final class Loader {
|
||||
/** Load a <code>SavedModelBundle</code> with the configured options. */
|
||||
public SavedModelBundle load() {
|
||||
return SavedModelBundle.load(exportDir, tags, configProto, runOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets options to use when executing model initialization operations.
|
||||
*
|
||||
* @param options Serialized <a
|
||||
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions
|
||||
* protocol buffer</a>.
|
||||
*/
|
||||
public Loader withRunOptions(byte[] options) {
|
||||
this.runOptions = options;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set configuration of the <code>Session</code> object created when loading the model.
|
||||
*
|
||||
* @param configProto Serialized <a
|
||||
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">ConfigProto
|
||||
* protocol buffer</a>.
|
||||
*/
|
||||
public Loader withConfigProto(byte[] configProto) {
|
||||
this.configProto = configProto;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the set of tags that identify the specific graph in the saved model to load.
|
||||
*
|
||||
* @param tags the tags identifying the specific MetaGraphDef to load.
|
||||
*/
|
||||
public Loader withTags(String... tags) {
|
||||
this.tags = tags;
|
||||
return this;
|
||||
}
|
||||
|
||||
private Loader(String exportDir) {
|
||||
this.exportDir = exportDir;
|
||||
}
|
||||
|
||||
private String exportDir = null;
|
||||
private String[] tags = null;
|
||||
private byte[] configProto = null;
|
||||
private byte[] runOptions = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load a saved model from an export directory. The model that is being loaded should be created
|
||||
* using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model
|
||||
* API</a>.
|
||||
*
|
||||
* <p>This method is a shorthand for:
|
||||
*
|
||||
* <pre>{@code
|
||||
* SavedModelBundler.loader().withTags(tags).load();
|
||||
* }</pre>
|
||||
*
|
||||
* @param exportDir the directory path containing a saved model.
|
||||
* @param tags the tags identifying the specific metagraphdef to load.
|
||||
* @return a bundle containing the graph and associated session.
|
||||
*/
|
||||
public static SavedModelBundle load(String exportDir, String... tags) {
|
||||
return load(exportDir, tags, null);
|
||||
return loader(exportDir).withTags(tags).load();
|
||||
}
|
||||
|
||||
/**
|
||||
* Load a saved model.
|
||||
*
|
||||
* <p/>Returns a <code>Loader</code> object that can set configuration options before actually
|
||||
* loading the model,
|
||||
*
|
||||
* @param exportDir the directory path containing a saved model.
|
||||
*/
|
||||
public static Loader loader(String exportDir) {
|
||||
return new Loader(exportDir);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -95,7 +163,8 @@ public class SavedModelBundle implements AutoCloseable {
|
||||
return new SavedModelBundle(graph, session, metaGraphDef);
|
||||
}
|
||||
|
||||
private static native SavedModelBundle load(String exportDir, String[] tags, byte[] runOptions);
|
||||
private static native SavedModelBundle load(
|
||||
String exportDir, String[] tags, byte[] config, byte[] runOptions);
|
||||
|
||||
static {
|
||||
TensorFlow.init();
|
||||
|
@ -22,12 +22,25 @@ limitations under the License.
|
||||
|
||||
JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
|
||||
JNIEnv* env, jclass clazz, jstring export_dir, jobjectArray tags,
|
||||
jbyteArray run_options) {
|
||||
jbyteArray config, jbyteArray run_options) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
jobject bundle = nullptr;
|
||||
|
||||
// allocate parameters for TF_LoadSessionFromSavedModel
|
||||
TF_SessionOptions* opts = TF_NewSessionOptions();
|
||||
if (config != nullptr) {
|
||||
size_t sz = env->GetArrayLength(config);
|
||||
if (sz > 0) {
|
||||
jbyte* config_data = env->GetByteArrayElements(config, nullptr);
|
||||
TF_SetConfig(opts, static_cast<void*>(config_data), sz, status);
|
||||
env->ReleaseByteArrayElements(config, config_data, JNI_ABORT);
|
||||
if (!throwExceptionIfNotOK(env, status)) {
|
||||
TF_DeleteSessionOptions(opts);
|
||||
TF_DeleteStatus(status);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
TF_Buffer* crun_options = nullptr;
|
||||
if (run_options != nullptr) {
|
||||
size_t sz = env->GetArrayLength(run_options);
|
||||
|
@ -26,10 +26,10 @@ extern "C" {
|
||||
* Class: org_tensorflow_SavedModelBundle
|
||||
* Method: load
|
||||
* Signature:
|
||||
* (Ljava/lang/String;[Ljava/lang/String;[B)Lorg/tensorflow/SavedModelBundle;
|
||||
* (Ljava/lang/String;[Ljava/lang/String;[B;[B)Lorg/tensorflow/SavedModelBundle;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
|
||||
JNIEnv *, jclass, jstring, jobjectArray, jbyteArray);
|
||||
JNIEnv *, jclass, jstring, jobjectArray, jbyteArray, jbyteArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
|
@ -50,4 +50,58 @@ public class SavedModelBundleTest {
|
||||
assertTrue(e.getMessage().contains("Could not find SavedModel"));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void loader() {
|
||||
try (SavedModelBundle bundle = SavedModelBundle.loader(SAVED_MODEL_PATH)
|
||||
.withTags("serve")
|
||||
.withConfigProto(sillyConfigProto())
|
||||
.withRunOptions(sillyRunOptions())
|
||||
.load()) {
|
||||
assertNotNull(bundle.session());
|
||||
assertNotNull(bundle.graph());
|
||||
assertNotNull(bundle.metaGraphDef());
|
||||
}
|
||||
}
|
||||
|
||||
private static byte[] sillyRunOptions() {
|
||||
// Ideally this would use the generated Java sources for protocol buffers
|
||||
// and end up with something like the snippet below. However, generating
|
||||
// the Java files for the .proto files in tensorflow/core:protos_all is
|
||||
// a bit cumbersome in bazel until the proto_library rule is setup.
|
||||
//
|
||||
// See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866
|
||||
// https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362
|
||||
// https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558
|
||||
//
|
||||
// For this test, for now, the use of specific bytes suffices.
|
||||
return new byte[] {0x08, 0x03};
|
||||
/*
|
||||
return org.tensorflow.framework.RunOptions.newBuilder()
|
||||
.setTraceLevel(RunOptions.TraceLevel.FULL_TRACE)
|
||||
.build()
|
||||
.toByteArray();
|
||||
*/
|
||||
}
|
||||
|
||||
public static byte[] sillyConfigProto() {
|
||||
// Ideally this would use the generated Java sources for protocol buffers
|
||||
// and end up with something like the snippet below. However, generating
|
||||
// the Java files for the .proto files in tensorflow/core:protos_all is
|
||||
// a bit cumbersome in bazel until the proto_library rule is setup.
|
||||
//
|
||||
// See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866
|
||||
// https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362
|
||||
// https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558
|
||||
//
|
||||
// For this test, for now, the use of specific bytes suffices.
|
||||
return new byte[] {0x10, 0x01, 0x28, 0x01};
|
||||
/*
|
||||
return org.tensorflow.framework.ConfigProto.newBuilder()
|
||||
.setInterOpParallelismThreads(1)
|
||||
.setIntraOpParallelismThreads(1)
|
||||
.build()
|
||||
.toByteArray();
|
||||
*/
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user