diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h b/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h index ab9af5acff4..b7a6029846d 100644 --- a/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h +++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { @@ -29,6 +30,15 @@ class CustomGraphOptimizer : public GraphOptimizer { virtual ~CustomGraphOptimizer() {} virtual Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* config = nullptr) = 0; + // Populates ConfigProto on which the Session is run prior to running Init. + Status InitWithConfig( + const ConfigProto& config_proto, + const tensorflow::RewriterConfig_CustomGraphOptimizer* config = nullptr) { + config_proto_ = config_proto; + return this->Init(config); + } + + ConfigProto config_proto_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 6c6f9944d9a..e4562357e91 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -299,8 +299,8 @@ Status MetaOptimizer::InitializeOptimizersByName( if (custom_optimizer) { VLOG(2) << "Registered custom graph optimizer: " << optimizer_name; - TF_RETURN_IF_ERROR(custom_optimizer->Init( - GetCustomGraphOptimizerConfig(optimizer_name))); + TF_RETURN_IF_ERROR(custom_optimizer->InitWithConfig( + config_proto_, GetCustomGraphOptimizerConfig(optimizer_name))); optimizers->push_back(std::move(custom_optimizer)); initialized_custom_optimizers.insert(optimizer_name); } else { @@ -326,7 +326,8 @@ Status MetaOptimizer::InitializeCustomGraphOptimizers( if (custom_optimizer) { VLOG(2) << "Registered custom configurable graph optimizer: " << optimizer_config.name(); - TF_RETURN_IF_ERROR(custom_optimizer->Init(&optimizer_config)); + TF_RETURN_IF_ERROR( + custom_optimizer->InitWithConfig(config_proto_, &optimizer_config)); optimizers->push_back(std::move(custom_optimizer)); } else { // If there are no custom optimizers with given name, try to initialize a