Some models are using "TF2" but also using Session and passing a ConfigProto. This is TF1 code running on the TF2 MLIR bridge. The TF2 MLIR bridge assumed that this was not possible. This CL updates the TF2 version of the MLIR bridge to support a ConfigProto passed in via Session. It also disables MLIR bridge presubmit testing of saved_model_test.py, because this fix reveals that the test is actually broken: it uses TF2 control flow but loads a model with Session. PiperOrigin-RevId: 335915669. Change-Id: Ib50bef389449ce0011878dd50b73856e9c520289
74 lines
2.9 KiB
C++
74 lines
2.9 KiB
C++
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
|
|
|
|
#include <string>
|
|
|
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
|
|
#include "tensorflow/core/lib/monitoring/gauge.h"
|
|
#include "tensorflow/core/public/session_options.h"
|
|
|
|
namespace tensorflow {
|
|
|
|
auto* mlir_bridge_gauge_v1 = monitoring::Gauge<bool, 0>::New(
|
|
"/tensorflow/config/experimental/enable_mlir_bridge_gauge_v1",
|
|
"Tracks usage of the MLIR-based TF2XLA bridge among TF1 models");
|
|
auto* mlir_bridge_gauge_v2 = monitoring::Gauge<bool, 0>::New(
|
|
"/tensorflow/config/experimental/enable_mlir_bridge_gauge_v2",
|
|
"Tracks usage of the MLIR-based TF2XLA bridge among TF2 models");
|
|
|
|
// This runs the first phase of the "bridge", transforming the graph in a form
|
|
// that can be executed with delegation of some computations to an accelerator.
|
|
// This builds on the model of XLA where a subset of the graph is encapsulated
|
|
// and attached to a "compile" operation, whose result is fed to an "execute"
|
|
// operation. The kernel for these operations is responsible to lower the
|
|
// encapsulated graph to a particular device.
|
|
Status MlirBridgePass::Run(const ConfigProto& config_proto,
|
|
mlir::ModuleOp module) {
|
|
if (!IsEnabled(config_proto)) {
|
|
VLOG(0) << "Skipping MLIR TPU Bridge, session flag not enabled";
|
|
mlir_bridge_gauge_v2->GetCell()->Set(false);
|
|
return Status::OK();
|
|
}
|
|
|
|
VLOG(0) << "Running MLIR TPU Bridge";
|
|
mlir_bridge_gauge_v2->GetCell()->Set(true);
|
|
TF_RETURN_IF_ERROR(
|
|
mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1)));
|
|
|
|
return Status::OK();
|
|
}
|
|
// Runs the V1-compat flavour of the MLIR TPU bridge on `module`.
//
// Function graphs are left untouched because MlirBridgePass handles them.
// For all other graphs the bridge runs only when IsEnabled() accepts the
// ConfigProto carried by `options.session_options`; the decision is recorded
// in mlir_bridge_gauge_v1.
//
// Returns OK when the bridge is skipped, otherwise the status produced by
// mlir::TFTPU::TPUBridgeV1Compat.
Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
                                   mlir::ModuleOp module) {
  // Skip function graphs as MlirBridgePass will be used instead.
  if (options.is_function_graph) return Status::OK();

  // `session_options` is a pointer; guard against callers that did not
  // populate it instead of dereferencing null. Without a ConfigProto there is
  // no flag that could enable the bridge, so treat this case as "disabled".
  if (options.session_options == nullptr) {
    VLOG(0) << "Skipping MLIR TPU Bridge V1 Compat, no session options "
               "provided";
    mlir_bridge_gauge_v1->GetCell()->Set(false);
    return Status::OK();
  }

  if (!IsEnabled(options.session_options->config)) {
    VLOG(0) << "Skipping MLIR TPU Bridge V1 Compat, session flag not enabled";
    mlir_bridge_gauge_v1->GetCell()->Set(false);
    return Status::OK();
  }

  VLOG(0) << "Running MLIR TPU Bridge V1 Compat";
  mlir_bridge_gauge_v1->GetCell()->Set(true);
  // Bridge-internal logging is turned on only at verbosity level 1 or above.
  TF_RETURN_IF_ERROR(
      mlir::TFTPU::TPUBridgeV1Compat(module, /*enable_logging=*/VLOG_IS_ON(1)));

  return Status::OK();
}
|
|
|
|
} // namespace tensorflow
|