Add debug flags to check if XLA cluster inputs/outputs contain NaNs or Infs
PiperOrigin-RevId: 262379591
This commit is contained in:
parent
b4d36cdc68
commit
8ea321bf7d
@ -48,6 +48,19 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
// Bundle of debugging switches that control which extra diagnostic nodes are
// inserted around an XLA cluster.  Every switch defaults to false so that a
// default-constructed DebuggingOpts never carries indeterminate values.
struct DebuggingOpts {
  // If true, insert Print nodes to print every output from an XLA cluster.
  bool print_outputs = false;

  // If true, insert CheckNumerics nodes for every floating point typed input to
  // an XLA cluster.
  bool check_input_numerics = false;

  // If true, insert CheckNumerics nodes for every floating point typed output
  // from an XLA cluster.
  bool check_output_numerics = false;
};
|
||||
|
||||
void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
|
||||
std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
|
||||
old_node->out_edges().end());
|
||||
@ -78,7 +91,8 @@ Operation DataToControl(const Scope& scope, Output data) {
|
||||
// Replaces each outgoing edge from `old_node` with a merge node that merges in
|
||||
// the corresponding output from `new_node`.
|
||||
void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node,
|
||||
bool insert_print_nodes) {
|
||||
absl::string_view cluster_name,
|
||||
const DebuggingOpts& debugging_opts) {
|
||||
if (!s.status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -93,23 +107,36 @@ void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node,
|
||||
int oidx = e->src_output();
|
||||
Output merged_output = merged_outputs[oidx];
|
||||
if (merged_output.node() == nullptr) {
|
||||
ops::Merge merge_op(s.WithOpName(absl::StrCat("merge_oidx_", oidx)),
|
||||
{Output(old_node, oidx), Output(new_node, oidx)});
|
||||
if (insert_print_nodes) {
|
||||
Output new_output(new_node, oidx);
|
||||
if (debugging_opts.print_outputs) {
|
||||
string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
ops::Print print_op(s.WithOpName(absl::StrCat("print_", oidx))
|
||||
ops::Print print_op(s.WithOpName("print_", oidx)
|
||||
.WithDevice(cpu_device)
|
||||
.WithAssignedDevice(cpu_device),
|
||||
merge_op.output, {merge_op.output},
|
||||
new_output, {new_output},
|
||||
ops::Print::Attrs{}
|
||||
.Message(absl::StrCat("output ", oidx, " from ",
|
||||
old_node->name(), " is "))
|
||||
.FirstN(1000)
|
||||
.Summarize(-1));
|
||||
merged_output = merged_outputs[oidx] = print_op;
|
||||
} else {
|
||||
merged_output = merged_outputs[oidx] = merge_op.output;
|
||||
new_output = print_op;
|
||||
}
|
||||
|
||||
if (debugging_opts.check_output_numerics &&
|
||||
DataTypeIsFloating(new_output.type())) {
|
||||
ops::CheckNumerics check_numerics_op(
|
||||
s.WithOpName("check_output_", oidx)
|
||||
.WithDevice(new_node->requested_device())
|
||||
.WithAssignedDevice(new_node->assigned_device_name()),
|
||||
new_output,
|
||||
absl::StrCat("CheckNumerics failed for output ", oidx, "(",
|
||||
new_output.name(), ") from cluster ", cluster_name));
|
||||
new_output = check_numerics_op;
|
||||
}
|
||||
|
||||
ops::Merge merge_op(s.WithOpName("merge_oidx_", oidx),
|
||||
{Output(old_node, oidx), new_output});
|
||||
merged_output = merged_outputs[oidx] = merge_op.output;
|
||||
}
|
||||
|
||||
Node* dst = e->dst();
|
||||
@ -324,11 +351,34 @@ xla::StatusOr<jit::DeviceId> InferDeviceForCluster(
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<Output> GetXlaRunArgs(const Scope& s,
|
||||
const XlaClusterInfo& cluster_info,
|
||||
const DebuggingOpts& debugging_opts) {
|
||||
std::vector<Output> xla_run_args;
|
||||
xla_run_args.reserve(cluster_info.non_constant_inputs.size() +
|
||||
cluster_info.resource_inputs.size());
|
||||
int input_idx = 0;
|
||||
for (const Output& o : cluster_info.non_constant_inputs) {
|
||||
if (debugging_opts.check_input_numerics && DataTypeIsFloating(o.type())) {
|
||||
ops::CheckNumerics check_numerics_op(
|
||||
s.WithOpName("check_input_", input_idx), o,
|
||||
absl::StrCat("CheckNumerics failed for input ", input_idx, "(",
|
||||
o.name(), ") into ", cluster_info.function.name()));
|
||||
xla_run_args.push_back(check_numerics_op);
|
||||
} else {
|
||||
xla_run_args.push_back(o);
|
||||
}
|
||||
input_idx++;
|
||||
}
|
||||
absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args));
|
||||
return xla_run_args;
|
||||
}
|
||||
|
||||
Status ReplaceNodeWithXlaCompileAndXlaRun(
|
||||
jit::DeviceInfoCache* device_info_cache,
|
||||
const GraphOptimizationPassOptions& options,
|
||||
const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
|
||||
bool insert_print_nodes, Graph* g, Node* n) {
|
||||
const DebuggingOpts& debugging_opts, Graph* g, Node* n) {
|
||||
XlaClusterInfo cluster_info;
|
||||
TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
|
||||
|
||||
@ -361,12 +411,12 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
|
||||
TF_RETURN_IF_ERROR(
|
||||
CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node()));
|
||||
|
||||
std::vector<Output> xla_run_args =
|
||||
GetXlaRunArgs(root, cluster_info, debugging_opts);
|
||||
|
||||
if (requires_compilation) {
|
||||
// "Strict" compilation: every _XlaCompile invocation must compile the
|
||||
// cluster.
|
||||
std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
|
||||
absl::c_copy(cluster_info.resource_inputs,
|
||||
std::back_inserter(xla_run_args));
|
||||
ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
|
||||
xla_compile.key, n->output_types());
|
||||
|
||||
@ -391,9 +441,6 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
|
||||
Output predicated_compilation_key = s.output_true;
|
||||
Output inverse_predicated_compilation_key = s.output_false;
|
||||
|
||||
std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
|
||||
absl::c_copy(cluster_info.resource_inputs,
|
||||
std::back_inserter(xla_run_args));
|
||||
ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
|
||||
predicated_compilation_key, n->output_types());
|
||||
|
||||
@ -402,7 +449,7 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
|
||||
|
||||
MergeOutgoingDataEdges(root, /*old_node=*/n,
|
||||
/*new_node=*/xla_run.operation.node(),
|
||||
insert_print_nodes);
|
||||
cluster_info.function.name(), debugging_opts);
|
||||
|
||||
TF_RETURN_IF_ERROR(root.status());
|
||||
|
||||
@ -443,15 +490,25 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
|
||||
enable_lazy_compilation_
|
||||
? *enable_lazy_compilation_
|
||||
: GetBuildXlaOpsPassFlags()->tf_xla_enable_lazy_compilation;
|
||||
bool insert_print_nodes =
|
||||
GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs;
|
||||
|
||||
jit::DeviceInfoCache device_info_cache;
|
||||
const BuildXlaOpsPassFlags& flags = *GetBuildXlaOpsPassFlags();
|
||||
|
||||
DebuggingOpts debugging_opts;
|
||||
debugging_opts.print_outputs = flags.tf_xla_print_cluster_outputs;
|
||||
debugging_opts.check_input_numerics =
|
||||
flags.tf_xla_check_cluster_input_numerics;
|
||||
debugging_opts.check_output_numerics =
|
||||
flags.tf_xla_check_cluster_output_numerics;
|
||||
|
||||
VLOG(1) << "print_outputs = " << debugging_opts.print_outputs;
|
||||
VLOG(1) << "check_input_numerics = " << debugging_opts.check_input_numerics;
|
||||
VLOG(1) << "check_output_numerics = " << debugging_opts.check_output_numerics;
|
||||
|
||||
for (Node* n : xla_compiled_kernels) {
|
||||
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
|
||||
&device_info_cache, options, *options.flib_def,
|
||||
lazy_compilation_enabled, insert_print_nodes, graph, n));
|
||||
lazy_compilation_enabled, debugging_opts, graph, n));
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(1)) {
|
||||
|
||||
@ -105,6 +105,8 @@ void AllocateAndParseFlags() {
|
||||
build_ops_flags = new BuildXlaOpsPassFlags;
|
||||
build_ops_flags->tf_xla_enable_lazy_compilation = true;
|
||||
build_ops_flags->tf_xla_print_cluster_outputs = false;
|
||||
build_ops_flags->tf_xla_check_cluster_input_numerics = false;
|
||||
build_ops_flags->tf_xla_check_cluster_output_numerics = false;
|
||||
build_ops_flags->tf_xla_disable_constant_folding = false;
|
||||
|
||||
mark_for_compilation_flags = new MarkForCompilationPassFlags;
|
||||
@ -144,6 +146,14 @@ void AllocateAndParseFlags() {
|
||||
&build_ops_flags->tf_xla_print_cluster_outputs,
|
||||
"If true then insert Print nodes to print out values produced by "
|
||||
"XLA clusters."),
|
||||
Flag("tf_xla_check_cluster_input_numerics",
|
||||
&build_ops_flags->tf_xla_check_cluster_input_numerics,
|
||||
"If true then insert CheckNumerics nodes to check all cluster "
|
||||
"inputs."),
|
||||
Flag("tf_xla_check_cluster_output_numerics",
|
||||
&build_ops_flags->tf_xla_check_cluster_output_numerics,
|
||||
"If true then insert CheckNumerics nodes to check all cluster "
|
||||
"outputs."),
|
||||
|
||||
Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
|
||||
"Switch a device into 'on-demand' mode, where instead of "
|
||||
|
||||
@ -103,6 +103,14 @@ struct BuildXlaOpsPassFlags {
|
||||
// clusters. Useful for debugging.
|
||||
bool tf_xla_print_cluster_outputs;
|
||||
|
||||
// If true, insert CheckNumerics nodes for every floating point typed input to
|
||||
// an XLA cluster.
|
||||
bool tf_xla_check_cluster_input_numerics;
|
||||
|
||||
// If true, insert CheckNumerics nodes for every floating point typed output
|
||||
// from an XLA cluster.
|
||||
bool tf_xla_check_cluster_output_numerics;
|
||||
|
||||
// Disables all constant folding. The primary use for this is for testing to
|
||||
// guarantee that tests are run on XLA and not on TF's CPU implementation.
|
||||
bool tf_xla_disable_constant_folding;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user