For custom optimizers, keep a copy of Session ConfigProto_.
PiperOrigin-RevId: 292974333 Change-Id: Ia49fae2149aea5e3e4fec5ba9d4b995406a6a1e4
This commit is contained in:
parent
ab94de29f5
commit
be453c9869
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -29,6 +30,15 @@ class CustomGraphOptimizer : public GraphOptimizer {
|
||||
virtual ~CustomGraphOptimizer() {}
|
||||
virtual Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer*
|
||||
config = nullptr) = 0;
|
||||
// Populates ConfigProto on which the Session is run prior to running Init.
|
||||
Status InitWithConfig(
|
||||
const ConfigProto& config_proto,
|
||||
const tensorflow::RewriterConfig_CustomGraphOptimizer* config = nullptr) {
|
||||
config_proto_ = config_proto;
|
||||
return this->Init(config);
|
||||
}
|
||||
|
||||
ConfigProto config_proto_;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
|
@ -299,8 +299,8 @@ Status MetaOptimizer::InitializeOptimizersByName(
|
||||
|
||||
if (custom_optimizer) {
|
||||
VLOG(2) << "Registered custom graph optimizer: " << optimizer_name;
|
||||
TF_RETURN_IF_ERROR(custom_optimizer->Init(
|
||||
GetCustomGraphOptimizerConfig(optimizer_name)));
|
||||
TF_RETURN_IF_ERROR(custom_optimizer->InitWithConfig(
|
||||
config_proto_, GetCustomGraphOptimizerConfig(optimizer_name)));
|
||||
optimizers->push_back(std::move(custom_optimizer));
|
||||
initialized_custom_optimizers.insert(optimizer_name);
|
||||
} else {
|
||||
@ -326,7 +326,8 @@ Status MetaOptimizer::InitializeCustomGraphOptimizers(
|
||||
if (custom_optimizer) {
|
||||
VLOG(2) << "Registered custom configurable graph optimizer: "
|
||||
<< optimizer_config.name();
|
||||
TF_RETURN_IF_ERROR(custom_optimizer->Init(&optimizer_config));
|
||||
TF_RETURN_IF_ERROR(
|
||||
custom_optimizer->InitWithConfig(config_proto_, &optimizer_config));
|
||||
optimizers->push_back(std::move(custom_optimizer));
|
||||
} else {
|
||||
// If there are no custom optimizers with given name, try to initialize a
|
||||
|
Loading…
Reference in New Issue
Block a user