diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
index 1265ff9138a..61695d532d1 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -48,6 +48,19 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+struct DebuggingOpts {
+  // If true, insert Print nodes to print every output from an XLA cluster.
+  bool print_outputs;
+
+  // If true, insert CheckNumerics nodes for every floating point typed input
+  // to an XLA cluster.
+  bool check_input_numerics;
+
+  // If true, insert CheckNumerics nodes for every floating point typed output
+  // from an XLA cluster.
+  bool check_output_numerics;
+};
+
 void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
   std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
                                      old_node->out_edges().end());
@@ -78,7 +91,8 @@ Operation DataToControl(const Scope& scope, Output data) {
 // Replaces each outgoing edge from `old_node` with a merge node that merges in
 // the corresponding output from `new_node`.
 void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node,
-                            bool insert_print_nodes) {
+                            absl::string_view cluster_name,
+                            const DebuggingOpts& debugging_opts) {
   if (!s.status().ok()) {
     return;
   }
@@ -93,23 +107,36 @@ void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node,
     int oidx = e->src_output();
     Output merged_output = merged_outputs[oidx];
     if (merged_output.node() == nullptr) {
-      ops::Merge merge_op(s.WithOpName(absl::StrCat("merge_oidx_", oidx)),
-                          {Output(old_node, oidx), Output(new_node, oidx)});
-      if (insert_print_nodes) {
+      Output new_output(new_node, oidx);
+      if (debugging_opts.print_outputs) {
         string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0";
-        ops::Print print_op(s.WithOpName(absl::StrCat("print_", oidx))
+        ops::Print print_op(s.WithOpName("print_", oidx)
                                 .WithDevice(cpu_device)
                                 .WithAssignedDevice(cpu_device),
-                            merge_op.output, {merge_op.output},
+                            new_output, {new_output},
                             ops::Print::Attrs{}
                                 .Message(absl::StrCat("output ", oidx, " from ",
                                                       old_node->name(), " is "))
                                 .FirstN(1000)
                                 .Summarize(-1));
-        merged_output = merged_outputs[oidx] = print_op;
-      } else {
-        merged_output = merged_outputs[oidx] = merge_op.output;
+        new_output = print_op;
       }
+
+      if (debugging_opts.check_output_numerics &&
+          DataTypeIsFloating(new_output.type())) {
+        ops::CheckNumerics check_numerics_op(
+            s.WithOpName("check_output_", oidx)
+                .WithDevice(new_node->requested_device())
+                .WithAssignedDevice(new_node->assigned_device_name()),
+            new_output,
+            absl::StrCat("CheckNumerics failed for output ", oidx, "(",
+                         new_output.name(), ") from cluster ", cluster_name));
+        new_output = check_numerics_op;
+      }
+
+      ops::Merge merge_op(s.WithOpName("merge_oidx_", oidx),
+                          {Output(old_node, oidx), new_output});
+      merged_output = merged_outputs[oidx] = merge_op.output;
     }
 
     Node* dst = e->dst();
@@ -324,11 +351,34 @@ xla::StatusOr<jit::DeviceId> InferDeviceForCluster(
   return result;
 }
 
+std::vector<Output> GetXlaRunArgs(const Scope& s,
+                                  const XlaClusterInfo& cluster_info,
+                                  const DebuggingOpts& debugging_opts) {
+  std::vector<Output> xla_run_args;
+  xla_run_args.reserve(cluster_info.non_constant_inputs.size() +
+                       cluster_info.resource_inputs.size());
+  int input_idx = 0;
+  for (const Output& o : cluster_info.non_constant_inputs) {
+    if (debugging_opts.check_input_numerics && DataTypeIsFloating(o.type())) {
+      ops::CheckNumerics check_numerics_op(
+          s.WithOpName("check_input_", input_idx), o,
+          absl::StrCat("CheckNumerics failed for input ", input_idx, "(",
+                       o.name(), ") into ", cluster_info.function.name()));
+      xla_run_args.push_back(check_numerics_op);
+    } else {
+      xla_run_args.push_back(o);
+    }
+    input_idx++;
+  }
+  absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args));
+  return xla_run_args;
+}
+
 Status ReplaceNodeWithXlaCompileAndXlaRun(
     jit::DeviceInfoCache* device_info_cache,
     const GraphOptimizationPassOptions& options,
     const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
-    bool insert_print_nodes, Graph* g, Node* n) {
+    const DebuggingOpts& debugging_opts, Graph* g, Node* n) {
   XlaClusterInfo cluster_info;
   TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
 
@@ -361,12 +411,12 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
   TF_RETURN_IF_ERROR(
       CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node()));
 
+  std::vector<Output> xla_run_args =
+      GetXlaRunArgs(root, cluster_info, debugging_opts);
+
   if (requires_compilation) {
     // "Strict" compilation: every _XlaCompile invocation must compile the
     // cluster.
-    std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
-    absl::c_copy(cluster_info.resource_inputs,
-                 std::back_inserter(xla_run_args));
     ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
                          xla_compile.key, n->output_types());
 
@@ -391,9 +441,6 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
   Output predicated_compilation_key = s.output_true;
   Output inverse_predicated_compilation_key = s.output_false;
 
-  std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
-  absl::c_copy(cluster_info.resource_inputs,
-               std::back_inserter(xla_run_args));
   ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
                        predicated_compilation_key, n->output_types());
 
@@ -402,7 +449,7 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
 
   MergeOutgoingDataEdges(root, /*old_node=*/n,
                          /*new_node=*/xla_run.operation.node(),
-                         insert_print_nodes);
+                         cluster_info.function.name(), debugging_opts);
 
   TF_RETURN_IF_ERROR(root.status());
 
@@ -443,15 +490,25 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
       enable_lazy_compilation_ ? *enable_lazy_compilation_
                                : GetBuildXlaOpsPassFlags()
                                      ->tf_xla_enable_lazy_compilation;
-  bool insert_print_nodes =
-      GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs;
 
   jit::DeviceInfoCache device_info_cache;
 
+  const BuildXlaOpsPassFlags& flags = *GetBuildXlaOpsPassFlags();
+  DebuggingOpts debugging_opts;
+  debugging_opts.print_outputs = flags.tf_xla_print_cluster_outputs;
+  debugging_opts.check_input_numerics =
+      flags.tf_xla_check_cluster_input_numerics;
+  debugging_opts.check_output_numerics =
+      flags.tf_xla_check_cluster_output_numerics;
+  VLOG(1) << "print_outputs = " << debugging_opts.print_outputs;
+  VLOG(1) << "check_input_numerics = " << debugging_opts.check_input_numerics;
+  VLOG(1) << "check_output_numerics = "
+          << debugging_opts.check_output_numerics;
+
   for (Node* n : xla_compiled_kernels) {
     TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
         &device_info_cache, options, *options.flib_def,
-        lazy_compilation_enabled, insert_print_nodes, graph, n));
+        lazy_compilation_enabled, debugging_opts, graph, n));
   }
 
   if (VLOG_IS_ON(1)) {
diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc
index f69a28b71b3..53f9b70c876 100644
--- a/tensorflow/compiler/jit/flags.cc
+++ b/tensorflow/compiler/jit/flags.cc
@@ -105,6 +105,8 @@ void AllocateAndParseFlags() {
   build_ops_flags = new BuildXlaOpsPassFlags;
   build_ops_flags->tf_xla_enable_lazy_compilation = true;
   build_ops_flags->tf_xla_print_cluster_outputs = false;
+  build_ops_flags->tf_xla_check_cluster_input_numerics = false;
+  build_ops_flags->tf_xla_check_cluster_output_numerics = false;
   build_ops_flags->tf_xla_disable_constant_folding = false;
 
   mark_for_compilation_flags = new MarkForCompilationPassFlags;
@@ -144,6 +146,14 @@ void AllocateAndParseFlags() {
            &build_ops_flags->tf_xla_print_cluster_outputs,
            "If true then insert Print nodes to print out values produced by "
            "XLA clusters."),
+      Flag("tf_xla_check_cluster_input_numerics",
+           &build_ops_flags->tf_xla_check_cluster_input_numerics,
+           "If true then insert CheckNumerics nodes to check all cluster "
+           "inputs."),
+      Flag("tf_xla_check_cluster_output_numerics",
+           &build_ops_flags->tf_xla_check_cluster_output_numerics,
+           "If true then insert CheckNumerics nodes to check all cluster "
+           "outputs."),
       Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
            "Switch a device into 'on-demand' mode, where instead of "
diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h
index 91e93f30d11..9307874133c 100644
--- a/tensorflow/compiler/jit/flags.h
+++ b/tensorflow/compiler/jit/flags.h
@@ -103,6 +103,14 @@ struct BuildXlaOpsPassFlags {
   // clusters. Useful for debugging.
   bool tf_xla_print_cluster_outputs;
 
+  // If true, insert CheckNumerics nodes for every floating point typed input
+  // to an XLA cluster.
+  bool tf_xla_check_cluster_input_numerics;
+
+  // If true, insert CheckNumerics nodes for every floating point typed output
+  // from an XLA cluster.
+  bool tf_xla_check_cluster_output_numerics;
+
   // Disables all constant folding. The primary use for this is for testing to
   // guarantee that tests are run on XLA and not on TF's CPU implementation.
   bool tf_xla_disable_constant_folding;
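
Usage note (not part of the diff): like the existing `tf_xla_print_cluster_outputs` flag, the two new flags are parsed by `AllocateAndParseFlags()` from the `TF_XLA_FLAGS` environment variable, so they must be set before TensorFlow first reads its JIT flags. A minimal Python sketch follows; the flag names come from this diff, while the surrounding script structure (and the use of `--tf_xla_auto_jit=2` to turn on auto-clustering) is illustrative:

```python
import os

# The BuildXlaOpsPass flags are read from TF_XLA_FLAGS the first time the
# XLA bridge parses them, so set the variable before importing TensorFlow.
# --tf_xla_auto_jit=2 enables auto-clustering so that XLA clusters (and
# hence the inserted CheckNumerics nodes) actually exist.
os.environ["TF_XLA_FLAGS"] = (
    "--tf_xla_auto_jit=2 "
    "--tf_xla_check_cluster_input_numerics=true "
    "--tf_xla_check_cluster_output_numerics=true"
)

import tensorflow as tf  # must happen after TF_XLA_FLAGS is set
```

With these flags on, a NaN or Inf flowing into or out of a cluster fails the step with the "CheckNumerics failed for input/output ..." message built in `GetXlaRunArgs()` / `MergeOutgoingDataEdges()` above, identifying the offending cluster input or output by index and name.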