Track usage of reference variables in TF graphs.

Adds a gauge to track usage of reference variables. This is checked on Graph import to MLIR; thus the gauge is only set if the MLIR-based bridge is enabled.

PiperOrigin-RevId: 338165779
Change-Id: Ic807b66e1b4678f3651e662547334fb6e654f463
This commit is contained in:
Lucy Fox 2020-10-20 17:00:30 -07:00 committed by TensorFlower Gardener
parent 19c8b34112
commit 573ded613d

View File

@ -104,6 +104,7 @@ limitations under the License.
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/path.h"
@ -130,6 +131,10 @@ using stream_executor::port::StatusOr;
namespace {
// Process-wide metric, set once per Graph -> MLIR import, recording whether
// the imported graph uses TF1-style reference variables anywhere. Only
// populated when the MLIR-based bridge performs the import.
auto* reference_variable_gauge = tensorflow::monitoring::Gauge<bool, 0>::New(
"/tensorflow/core/uses_reference_variables",
"Tracks if reference variables are used anywhere in the graph");
constexpr char kTpuReplicateAttr[] = "_tpu_replicate";
bool IsOutputShapesAttribute(const AttrValue& attr_value,
@ -2057,6 +2062,11 @@ class GraphDefImporter : public ImporterBase {
llvm::StringRef func_name);
private:
// Checks if a Module contains any ref variables in any operation operands
// or results, including checking Block arguments and operations within
// regions.
static bool ModuleContainsRefType(mlir::ModuleOp module);
explicit GraphDefImporter(
const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
const GraphImportConfig& specs, mlir::ModuleOp module,
@ -2092,6 +2102,38 @@ class GraphDefImporter : public ImporterBase {
absl::InlinedVector<Node*, 4>* control_ret_nodes);
};
// Returns true if `ty` is a TensorFlow reference type. For shaped/container
// types the check is applied to the element type instead.
bool IsTensorFlowRefType(mlir::Type ty) {
  mlir::Type element_ty = mlir::getElementTypeOrSelf(ty);
  return element_ty.isa<mlir::TF::TensorFlowRefType>();
}
// Returns true if any operand or result of `op` has a TensorFlow reference
// element type, or if any block argument of any region nested inside `op`
// does.
bool OpHasRefTypeOperandOrResult(mlir::Operation* op) {
  // Shared scan over a range of types; used for both operands and results.
  auto any_ref_type = [](auto&& types) {
    for (mlir::Type type : types)
      if (IsTensorFlowRefType(type)) return true;
    return false;
  };
  if (any_ref_type(op->getOperandTypes())) return true;
  if (any_ref_type(op->getResultTypes())) return true;
  // Also inspect block arguments of every region the op owns.
  for (mlir::Region& region : op->getRegions()) {
    for (mlir::Block& block : region) {
      for (mlir::BlockArgument arg : block.getArguments()) {
        if (IsTensorFlowRefType(arg.getType())) return true;
      }
    }
  }
  return false;
}
// Walks every operation in `module` and stops (interrupts) as soon as one
// with a reference-typed operand, result, or block argument is found. An
// interrupted walk therefore means the module contains a ref variable, and
// no further operations need to be inspected.
bool GraphDefImporter::ModuleContainsRefType(mlir::ModuleOp module) {
  // The callback captures nothing, so use an empty capture list.
  auto result = module.walk([](mlir::Operation* op) {
    return OpHasRefTypeOperandOrResult(op) ? mlir::WalkResult::interrupt()
                                           : mlir::WalkResult::advance();
  });
  return result.wasInterrupted();
}
StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
mlir::MLIRContext* context, const Graph& graph,
const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
@ -2189,6 +2231,13 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs));
// Check if there are any reference variables in the module.
bool contains_ref_var = ModuleContainsRefType(*module);
reference_variable_gauge->GetCell()->Set(contains_ref_var);
if (contains_ref_var) {
VLOG(1) << "Graph contains one or more reference variables";
}
// Mark main function public, others private.
for (auto function : module.get().getOps<mlir::FuncOp>()) {
auto visibility = function.getName() == func_name