Track usage of reference variables in TF graphs.
Adds a gauge to track usage of reference variables. This is checked on Graph import to MLIR; thus the gauge is only set if the MLIR-based bridge is enabled. PiperOrigin-RevId: 338165779 Change-Id: Ic807b66e1b4678f3651e662547334fb6e654f463
This commit is contained in:
parent
19c8b34112
commit
573ded613d
@ -104,6 +104,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
@ -130,6 +131,10 @@ using stream_executor::port::StatusOr;
|
||||
|
||||
namespace {
|
||||
|
||||
auto* reference_variable_gauge = tensorflow::monitoring::Gauge<bool, 0>::New(
|
||||
"/tensorflow/core/uses_reference_variables",
|
||||
"Tracks if reference variables are used anywhere in the graph");
|
||||
|
||||
constexpr char kTpuReplicateAttr[] = "_tpu_replicate";
|
||||
|
||||
bool IsOutputShapesAttribute(const AttrValue& attr_value,
|
||||
@ -2057,6 +2062,11 @@ class GraphDefImporter : public ImporterBase {
|
||||
llvm::StringRef func_name);
|
||||
|
||||
private:
|
||||
// Checks if a Module contains any ref variables in any operation operands
|
||||
// or results, including checking Block arguments and operations within
|
||||
// regions.
|
||||
static bool ModuleContainsRefType(mlir::ModuleOp module);
|
||||
|
||||
explicit GraphDefImporter(
|
||||
const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
|
||||
const GraphImportConfig& specs, mlir::ModuleOp module,
|
||||
@ -2092,6 +2102,38 @@ class GraphDefImporter : public ImporterBase {
|
||||
absl::InlinedVector<Node*, 4>* control_ret_nodes);
|
||||
};
|
||||
|
||||
bool IsTensorFlowRefType(mlir::Type ty) {
|
||||
return mlir::getElementTypeOrSelf(ty).isa<mlir::TF::TensorFlowRefType>();
|
||||
}
|
||||
|
||||
bool OpHasRefTypeOperandOrResult(mlir::Operation* op) {
|
||||
// Check op operands.
|
||||
for (mlir::Type ty : op->getOperandTypes())
|
||||
if (IsTensorFlowRefType(ty)) return true;
|
||||
// Check op results.
|
||||
for (mlir::Type ty : op->getResultTypes())
|
||||
if (IsTensorFlowRefType(ty)) return true;
|
||||
// Check all block arguments within any regions the op has.
|
||||
for (mlir::Region& region : op->getRegions())
|
||||
for (mlir::Block& block : region)
|
||||
for (auto& arg : block.getArguments())
|
||||
if (IsTensorFlowRefType(arg.getType())) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GraphDefImporter::ModuleContainsRefType(mlir::ModuleOp module) {
|
||||
// If walk is interrupted at any point, that means a ref variable was found.
|
||||
// At this point, we've confirmed existence of a ref variable and don't need
|
||||
// to continue looking.
|
||||
return module
|
||||
.walk([&](mlir::Operation* op) {
|
||||
if (OpHasRefTypeOperandOrResult(op))
|
||||
return mlir::WalkResult::interrupt();
|
||||
return mlir::WalkResult::advance();
|
||||
})
|
||||
.wasInterrupted();
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
||||
mlir::MLIRContext* context, const Graph& graph,
|
||||
const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
|
||||
@ -2189,6 +2231,13 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
||||
TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
|
||||
func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs));
|
||||
|
||||
// Check if there are any reference variables in the module.
|
||||
bool contains_ref_var = ModuleContainsRefType(*module);
|
||||
reference_variable_gauge->GetCell()->Set(contains_ref_var);
|
||||
if (contains_ref_var) {
|
||||
VLOG(1) << "Graph contains one or more reference variables";
|
||||
}
|
||||
|
||||
// Mark main function public, others private.
|
||||
for (auto function : module.get().getOps<mlir::FuncOp>()) {
|
||||
auto visibility = function.getName() == func_name
|
||||
|
Loading…
x
Reference in New Issue
Block a user