[Java]: Support ConfigProto and RunOptions when loading SavedModels.
Fixes #18143 Fixes #20769 (Similar to #18716 by @raintung) PiperOrigin-RevId: 204575441
This commit is contained in:
parent
88b656acd4
commit
8aa4179ffa
@ -25,18 +25,86 @@ package org.tensorflow;
|
|||||||
* protocol buffer</a>).
|
* protocol buffer</a>).
|
||||||
*/
|
*/
|
||||||
public class SavedModelBundle implements AutoCloseable {
|
public class SavedModelBundle implements AutoCloseable {
|
||||||
|
/** Options for loading a SavedModel. */
|
||||||
|
public static final class Loader {
|
||||||
|
/** Load a <code>SavedModelBundle</code> with the configured options. */
|
||||||
|
public SavedModelBundle load() {
|
||||||
|
return SavedModelBundle.load(exportDir, tags, configProto, runOptions);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets options to use when executing model initialization operations.
|
||||||
|
*
|
||||||
|
* @param options Serialized <a
|
||||||
|
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions
|
||||||
|
* protocol buffer</a>.
|
||||||
|
*/
|
||||||
|
public Loader withRunOptions(byte[] options) {
|
||||||
|
this.runOptions = options;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set configuration of the <code>Session</code> object created when loading the model.
|
||||||
|
*
|
||||||
|
* @param configProto Serialized <a
|
||||||
|
* href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">ConfigProto
|
||||||
|
* protocol buffer</a>.
|
||||||
|
*/
|
||||||
|
public Loader withConfigProto(byte[] configProto) {
|
||||||
|
this.configProto = configProto;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the set of tags that identify the specific graph in the saved model to load.
|
||||||
|
*
|
||||||
|
* @param tags the tags identifying the specific MetaGraphDef to load.
|
||||||
|
*/
|
||||||
|
public Loader withTags(String... tags) {
|
||||||
|
this.tags = tags;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Loader(String exportDir) {
|
||||||
|
this.exportDir = exportDir;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String exportDir = null;
|
||||||
|
private String[] tags = null;
|
||||||
|
private byte[] configProto = null;
|
||||||
|
private byte[] runOptions = null;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Load a saved model from an export directory. The model that is being loaded should be created
|
* Load a saved model from an export directory. The model that is being loaded should be created
|
||||||
* using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model
|
* using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model
|
||||||
* API</a>.
|
* API</a>.
|
||||||
*
|
*
|
||||||
|
* <p>This method is a shorthand for:
|
||||||
|
*
|
||||||
|
* <pre>{@code
|
||||||
|
* SavedModelBundler.loader().withTags(tags).load();
|
||||||
|
* }</pre>
|
||||||
|
*
|
||||||
* @param exportDir the directory path containing a saved model.
|
* @param exportDir the directory path containing a saved model.
|
||||||
* @param tags the tags identifying the specific metagraphdef to load.
|
* @param tags the tags identifying the specific metagraphdef to load.
|
||||||
* @return a bundle containing the graph and associated session.
|
* @return a bundle containing the graph and associated session.
|
||||||
*/
|
*/
|
||||||
public static SavedModelBundle load(String exportDir, String... tags) {
|
public static SavedModelBundle load(String exportDir, String... tags) {
|
||||||
return load(exportDir, tags, null);
|
return loader(exportDir).withTags(tags).load();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load a saved model.
|
||||||
|
*
|
||||||
|
* <p/>Returns a <code>Loader</code> object that can set configuration options before actually
|
||||||
|
* loading the model,
|
||||||
|
*
|
||||||
|
* @param exportDir the directory path containing a saved model.
|
||||||
|
*/
|
||||||
|
public static Loader loader(String exportDir) {
|
||||||
|
return new Loader(exportDir);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -95,7 +163,8 @@ public class SavedModelBundle implements AutoCloseable {
|
|||||||
return new SavedModelBundle(graph, session, metaGraphDef);
|
return new SavedModelBundle(graph, session, metaGraphDef);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static native SavedModelBundle load(String exportDir, String[] tags, byte[] runOptions);
|
private static native SavedModelBundle load(
|
||||||
|
String exportDir, String[] tags, byte[] config, byte[] runOptions);
|
||||||
|
|
||||||
static {
|
static {
|
||||||
TensorFlow.init();
|
TensorFlow.init();
|
||||||
|
@ -22,12 +22,25 @@ limitations under the License.
|
|||||||
|
|
||||||
JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
|
JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
|
||||||
JNIEnv* env, jclass clazz, jstring export_dir, jobjectArray tags,
|
JNIEnv* env, jclass clazz, jstring export_dir, jobjectArray tags,
|
||||||
jbyteArray run_options) {
|
jbyteArray config, jbyteArray run_options) {
|
||||||
TF_Status* status = TF_NewStatus();
|
TF_Status* status = TF_NewStatus();
|
||||||
jobject bundle = nullptr;
|
jobject bundle = nullptr;
|
||||||
|
|
||||||
// allocate parameters for TF_LoadSessionFromSavedModel
|
// allocate parameters for TF_LoadSessionFromSavedModel
|
||||||
TF_SessionOptions* opts = TF_NewSessionOptions();
|
TF_SessionOptions* opts = TF_NewSessionOptions();
|
||||||
|
if (config != nullptr) {
|
||||||
|
size_t sz = env->GetArrayLength(config);
|
||||||
|
if (sz > 0) {
|
||||||
|
jbyte* config_data = env->GetByteArrayElements(config, nullptr);
|
||||||
|
TF_SetConfig(opts, static_cast<void*>(config_data), sz, status);
|
||||||
|
env->ReleaseByteArrayElements(config, config_data, JNI_ABORT);
|
||||||
|
if (!throwExceptionIfNotOK(env, status)) {
|
||||||
|
TF_DeleteSessionOptions(opts);
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
TF_Buffer* crun_options = nullptr;
|
TF_Buffer* crun_options = nullptr;
|
||||||
if (run_options != nullptr) {
|
if (run_options != nullptr) {
|
||||||
size_t sz = env->GetArrayLength(run_options);
|
size_t sz = env->GetArrayLength(run_options);
|
||||||
|
@ -26,10 +26,10 @@ extern "C" {
|
|||||||
* Class: org_tensorflow_SavedModelBundle
|
* Class: org_tensorflow_SavedModelBundle
|
||||||
* Method: load
|
* Method: load
|
||||||
* Signature:
|
* Signature:
|
||||||
* (Ljava/lang/String;[Ljava/lang/String;[B)Lorg/tensorflow/SavedModelBundle;
|
* (Ljava/lang/String;[Ljava/lang/String;[B;[B)Lorg/tensorflow/SavedModelBundle;
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
|
JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
|
||||||
JNIEnv *, jclass, jstring, jobjectArray, jbyteArray);
|
JNIEnv *, jclass, jstring, jobjectArray, jbyteArray, jbyteArray);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
|
@ -50,4 +50,58 @@ public class SavedModelBundleTest {
|
|||||||
assertTrue(e.getMessage().contains("Could not find SavedModel"));
|
assertTrue(e.getMessage().contains("Could not find SavedModel"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void loader() {
|
||||||
|
try (SavedModelBundle bundle = SavedModelBundle.loader(SAVED_MODEL_PATH)
|
||||||
|
.withTags("serve")
|
||||||
|
.withConfigProto(sillyConfigProto())
|
||||||
|
.withRunOptions(sillyRunOptions())
|
||||||
|
.load()) {
|
||||||
|
assertNotNull(bundle.session());
|
||||||
|
assertNotNull(bundle.graph());
|
||||||
|
assertNotNull(bundle.metaGraphDef());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static byte[] sillyRunOptions() {
|
||||||
|
// Ideally this would use the generated Java sources for protocol buffers
|
||||||
|
// and end up with something like the snippet below. However, generating
|
||||||
|
// the Java files for the .proto files in tensorflow/core:protos_all is
|
||||||
|
// a bit cumbersome in bazel until the proto_library rule is setup.
|
||||||
|
//
|
||||||
|
// See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866
|
||||||
|
// https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362
|
||||||
|
// https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558
|
||||||
|
//
|
||||||
|
// For this test, for now, the use of specific bytes suffices.
|
||||||
|
return new byte[] {0x08, 0x03};
|
||||||
|
/*
|
||||||
|
return org.tensorflow.framework.RunOptions.newBuilder()
|
||||||
|
.setTraceLevel(RunOptions.TraceLevel.FULL_TRACE)
|
||||||
|
.build()
|
||||||
|
.toByteArray();
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
|
||||||
|
public static byte[] sillyConfigProto() {
|
||||||
|
// Ideally this would use the generated Java sources for protocol buffers
|
||||||
|
// and end up with something like the snippet below. However, generating
|
||||||
|
// the Java files for the .proto files in tensorflow/core:protos_all is
|
||||||
|
// a bit cumbersome in bazel until the proto_library rule is setup.
|
||||||
|
//
|
||||||
|
// See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866
|
||||||
|
// https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362
|
||||||
|
// https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558
|
||||||
|
//
|
||||||
|
// For this test, for now, the use of specific bytes suffices.
|
||||||
|
return new byte[] {0x10, 0x01, 0x28, 0x01};
|
||||||
|
/*
|
||||||
|
return org.tensorflow.framework.ConfigProto.newBuilder()
|
||||||
|
.setInterOpParallelismThreads(1)
|
||||||
|
.setIntraOpParallelismThreads(1)
|
||||||
|
.build()
|
||||||
|
.toByteArray();
|
||||||
|
*/
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user