Add debug flags to check if XLA cluster inputs/outputs contain NaNs or Infs

PiperOrigin-RevId: 262379591
This commit is contained in:
Sanjoy Das 2019-08-08 10:13:37 -07:00 committed by TensorFlower Gardener
parent b4d36cdc68
commit 8ea321bf7d
3 changed files with 95 additions and 20 deletions

View File

@ -48,6 +48,19 @@ limitations under the License.
namespace tensorflow {
namespace {
// Debug-instrumentation switches for BuildXlaOpsPass: which extra nodes
// (Print / CheckNumerics) to splice in around each XLA cluster. Populated
// from BuildXlaOpsPassFlags and threaded through the pass helpers.
struct DebuggingOpts {
// If true, insert Print nodes to print every output from an XLA cluster.
bool print_outputs;
// If true, insert CheckNumerics nodes for every floating point typed input to
// an XLA cluster.
bool check_input_numerics;
// If true, insert CheckNumerics nodes for every floating point typed output
// from an XLA cluster.
bool check_output_numerics;
};
void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
old_node->out_edges().end());
@ -78,7 +91,8 @@ Operation DataToControl(const Scope& scope, Output data) {
// Replaces each outgoing edge from `old_node` with a merge node that merges in
// the corresponding output from `new_node`.
void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node,
bool insert_print_nodes) {
absl::string_view cluster_name,
const DebuggingOpts& debugging_opts) {
if (!s.status().ok()) {
return;
}
@ -93,23 +107,36 @@ void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node,
int oidx = e->src_output();
Output merged_output = merged_outputs[oidx];
if (merged_output.node() == nullptr) {
ops::Merge merge_op(s.WithOpName(absl::StrCat("merge_oidx_", oidx)),
{Output(old_node, oidx), Output(new_node, oidx)});
if (insert_print_nodes) {
Output new_output(new_node, oidx);
if (debugging_opts.print_outputs) {
string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0";
ops::Print print_op(s.WithOpName(absl::StrCat("print_", oidx))
ops::Print print_op(s.WithOpName("print_", oidx)
.WithDevice(cpu_device)
.WithAssignedDevice(cpu_device),
merge_op.output, {merge_op.output},
new_output, {new_output},
ops::Print::Attrs{}
.Message(absl::StrCat("output ", oidx, " from ",
old_node->name(), " is "))
.FirstN(1000)
.Summarize(-1));
merged_output = merged_outputs[oidx] = print_op;
} else {
merged_output = merged_outputs[oidx] = merge_op.output;
new_output = print_op;
}
if (debugging_opts.check_output_numerics &&
DataTypeIsFloating(new_output.type())) {
ops::CheckNumerics check_numerics_op(
s.WithOpName("check_output_", oidx)
.WithDevice(new_node->requested_device())
.WithAssignedDevice(new_node->assigned_device_name()),
new_output,
absl::StrCat("CheckNumerics failed for output ", oidx, "(",
new_output.name(), ") from cluster ", cluster_name));
new_output = check_numerics_op;
}
ops::Merge merge_op(s.WithOpName("merge_oidx_", oidx),
{Output(old_node, oidx), new_output});
merged_output = merged_outputs[oidx] = merge_op.output;
}
Node* dst = e->dst();
@ -324,11 +351,34 @@ xla::StatusOr<jit::DeviceId> InferDeviceForCluster(
return result;
}
// Builds the argument list for the _XlaRun op: the cluster's non-constant
// inputs (each optionally wrapped in a CheckNumerics node when
// `debugging_opts.check_input_numerics` is set and the input is floating
// point), followed by the cluster's resource inputs, which are appended
// unchecked (CheckNumerics does not apply to resource handles).
std::vector<Output> GetXlaRunArgs(const Scope& s,
const XlaClusterInfo& cluster_info,
const DebuggingOpts& debugging_opts) {
std::vector<Output> xla_run_args;
// Reserve the exact final size: one slot per non-constant input plus one per
// resource input.
xla_run_args.reserve(cluster_info.non_constant_inputs.size() +
cluster_info.resource_inputs.size());
int input_idx = 0;
for (const Output& o : cluster_info.non_constant_inputs) {
if (debugging_opts.check_input_numerics && DataTypeIsFloating(o.type())) {
// Route the input through CheckNumerics so NaNs/Infs fed into the
// cluster fail fast with a message naming the input and the cluster
// function.
ops::CheckNumerics check_numerics_op(
s.WithOpName("check_input_", input_idx), o,
absl::StrCat("CheckNumerics failed for input ", input_idx, "(",
o.name(), ") into ", cluster_info.function.name()));
xla_run_args.push_back(check_numerics_op);
} else {
xla_run_args.push_back(o);
}
input_idx++;
}
absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args));
return xla_run_args;
}
Status ReplaceNodeWithXlaCompileAndXlaRun(
jit::DeviceInfoCache* device_info_cache,
const GraphOptimizationPassOptions& options,
const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
bool insert_print_nodes, Graph* g, Node* n) {
const DebuggingOpts& debugging_opts, Graph* g, Node* n) {
XlaClusterInfo cluster_info;
TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
@ -361,12 +411,12 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
TF_RETURN_IF_ERROR(
CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node()));
std::vector<Output> xla_run_args =
GetXlaRunArgs(root, cluster_info, debugging_opts);
if (requires_compilation) {
// "Strict" compilation: every _XlaCompile invocation must compile the
// cluster.
std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
absl::c_copy(cluster_info.resource_inputs,
std::back_inserter(xla_run_args));
ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
xla_compile.key, n->output_types());
@ -391,9 +441,6 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
Output predicated_compilation_key = s.output_true;
Output inverse_predicated_compilation_key = s.output_false;
std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
absl::c_copy(cluster_info.resource_inputs,
std::back_inserter(xla_run_args));
ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
predicated_compilation_key, n->output_types());
@ -402,7 +449,7 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
MergeOutgoingDataEdges(root, /*old_node=*/n,
/*new_node=*/xla_run.operation.node(),
insert_print_nodes);
cluster_info.function.name(), debugging_opts);
TF_RETURN_IF_ERROR(root.status());
@ -443,15 +490,25 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
enable_lazy_compilation_
? *enable_lazy_compilation_
: GetBuildXlaOpsPassFlags()->tf_xla_enable_lazy_compilation;
bool insert_print_nodes =
GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs;
jit::DeviceInfoCache device_info_cache;
const BuildXlaOpsPassFlags& flags = *GetBuildXlaOpsPassFlags();
DebuggingOpts debugging_opts;
debugging_opts.print_outputs = flags.tf_xla_print_cluster_outputs;
debugging_opts.check_input_numerics =
flags.tf_xla_check_cluster_input_numerics;
debugging_opts.check_output_numerics =
flags.tf_xla_check_cluster_output_numerics;
VLOG(1) << "print_outputs = " << debugging_opts.print_outputs;
VLOG(1) << "check_input_numerics = " << debugging_opts.check_input_numerics;
VLOG(1) << "check_output_numerics = " << debugging_opts.check_output_numerics;
for (Node* n : xla_compiled_kernels) {
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
&device_info_cache, options, *options.flib_def,
lazy_compilation_enabled, insert_print_nodes, graph, n));
lazy_compilation_enabled, debugging_opts, graph, n));
}
if (VLOG_IS_ON(1)) {

View File

@ -105,6 +105,8 @@ void AllocateAndParseFlags() {
build_ops_flags = new BuildXlaOpsPassFlags;
build_ops_flags->tf_xla_enable_lazy_compilation = true;
build_ops_flags->tf_xla_print_cluster_outputs = false;
build_ops_flags->tf_xla_check_cluster_input_numerics = false;
build_ops_flags->tf_xla_check_cluster_output_numerics = false;
build_ops_flags->tf_xla_disable_constant_folding = false;
mark_for_compilation_flags = new MarkForCompilationPassFlags;
@ -144,6 +146,14 @@ void AllocateAndParseFlags() {
&build_ops_flags->tf_xla_print_cluster_outputs,
"If true then insert Print nodes to print out values produced by "
"XLA clusters."),
Flag("tf_xla_check_cluster_input_numerics",
&build_ops_flags->tf_xla_check_cluster_input_numerics,
"If true then insert CheckNumerics nodes to check all cluster "
"inputs."),
Flag("tf_xla_check_cluster_output_numerics",
&build_ops_flags->tf_xla_check_cluster_output_numerics,
"If true then insert CheckNumerics nodes to check all cluster "
"outputs."),
Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
"Switch a device into 'on-demand' mode, where instead of "

View File

@ -103,6 +103,14 @@ struct BuildXlaOpsPassFlags {
// clusters. Useful for debugging.
bool tf_xla_print_cluster_outputs;
// If true, insert CheckNumerics nodes for every floating point typed input to
// an XLA cluster.
bool tf_xla_check_cluster_input_numerics;
// If true, insert CheckNumerics nodes for every floating point typed output
// from an XLA cluster.
bool tf_xla_check_cluster_output_numerics;
// Disables all constant folding. The primary use for this is for testing to
// guarantee that tests are run on XLA and not on TF's CPU implementation.
bool tf_xla_disable_constant_folding;