Add debug flags to check if XLA cluster inputs/outputs contain NaNs or Infs
PiperOrigin-RevId: 262379591
Commit: 8ea321bf7d (parent: b4d36cdc68)
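In summary, this change adds two BuildXlaOpsPass debugging flags, tf_xla_check_cluster_input_numerics and tf_xla_check_cluster_output_numerics, which wrap every floating point typed input to and output from an XLA cluster in a CheckNumerics node so that a step fails fast when a NaN or Inf crosses the cluster boundary. Three pieces change below: the BuildXlaOpsPass rewrite logic, the flag defaults and registration in AllocateAndParseFlags, and the BuildXlaOpsPassFlags struct. As a usage note (an assumption about the typical workflow, not part of this diff), these flags would normally be toggled through the TF_XLA_FLAGS environment variable, e.g. TF_XLA_FLAGS="--tf_xla_check_cluster_input_numerics=true --tf_xla_check_cluster_output_numerics=true".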
@@ -48,6 +48,19 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+struct DebuggingOpts {
+  // If true, insert Print nodes to print every output from an XLA cluster.
+  bool print_outputs;
+
+  // If true, insert CheckNumerics nodes for every floating point typed input to
+  // an XLA cluster.
+  bool check_input_numerics;
+
+  // If true, insert CheckNumerics nodes for every floating point typed output
+  // from an XLA cluster.
+  bool check_output_numerics;
+};
+
 void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
   std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
                                      old_node->out_edges().end());
@@ -78,7 +91,8 @@ Operation DataToControl(const Scope& scope, Output data) {
 // Replaces each outgoing edge from `old_node` with a merge node that merges in
 // the corresponding output from `new_node`.
 void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node,
-                            bool insert_print_nodes) {
+                            absl::string_view cluster_name,
+                            const DebuggingOpts& debugging_opts) {
   if (!s.status().ok()) {
     return;
   }
@@ -93,23 +107,36 @@ void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node,
     int oidx = e->src_output();
     Output merged_output = merged_outputs[oidx];
     if (merged_output.node() == nullptr) {
-      ops::Merge merge_op(s.WithOpName(absl::StrCat("merge_oidx_", oidx)),
-                          {Output(old_node, oidx), Output(new_node, oidx)});
-      if (insert_print_nodes) {
+      Output new_output(new_node, oidx);
+      if (debugging_opts.print_outputs) {
         string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0";
-        ops::Print print_op(s.WithOpName(absl::StrCat("print_", oidx))
+        ops::Print print_op(s.WithOpName("print_", oidx)
                                 .WithDevice(cpu_device)
                                 .WithAssignedDevice(cpu_device),
-                            merge_op.output, {merge_op.output},
+                            new_output, {new_output},
                             ops::Print::Attrs{}
                                 .Message(absl::StrCat("output ", oidx, " from ",
                                                       old_node->name(), " is "))
                                 .FirstN(1000)
                                 .Summarize(-1));
-        merged_output = merged_outputs[oidx] = print_op;
-      } else {
-        merged_output = merged_outputs[oidx] = merge_op.output;
+        new_output = print_op;
       }
+
+      if (debugging_opts.check_output_numerics &&
+          DataTypeIsFloating(new_output.type())) {
+        ops::CheckNumerics check_numerics_op(
+            s.WithOpName("check_output_", oidx)
+                .WithDevice(new_node->requested_device())
+                .WithAssignedDevice(new_node->assigned_device_name()),
+            new_output,
+            absl::StrCat("CheckNumerics failed for output ", oidx, "(",
+                         new_output.name(), ") from cluster ", cluster_name));
+        new_output = check_numerics_op;
+      }
+
+      ops::Merge merge_op(s.WithOpName("merge_oidx_", oidx),
+                          {Output(old_node, oidx), new_output});
+      merged_output = merged_outputs[oidx] = merge_op.output;
     }
 
     Node* dst = e->dst();
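The hunk above threads each cluster output through an optional Print node and an optional CheckNumerics node (named check_output_<oidx> and placed on the XLA node's device) before it reaches the Merge that joins the compiled and uncompiled paths. For readers unfamiliar with CheckNumerics, the standalone sketch below (my addition, not part of this commit; it assumes the TensorFlow C++ client API under tensorflow/cc) shows the behavior the pass relies on: the op forwards its input unchanged and fails the step with the supplied message when the tensor contains a NaN or Inf.

```cpp
// Standalone sketch, not part of this commit: running a tensor through
// CheckNumerics passes it through untouched, but the step fails with an
// error mentioning the supplied message if the tensor holds NaN or Inf.
#include <limits>
#include <vector>

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/platform/logging.h"

int main() {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  // One of the values is Inf, so fetching the checked output should fail.
  auto x = tensorflow::ops::Const(
      root, {1.0f, std::numeric_limits<float>::infinity()});
  auto checked = tensorflow::ops::CheckNumerics(
      root, x, "CheckNumerics caught a bad value in x");

  tensorflow::ClientSession session(root);
  std::vector<tensorflow::Tensor> outputs;
  tensorflow::Status status = session.Run({checked}, &outputs);
  // Expected: a non-OK status whose message includes the string above.
  LOG(INFO) << status.ToString();
  return 0;
}
```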
@@ -324,11 +351,34 @@ xla::StatusOr<jit::DeviceId> InferDeviceForCluster(
   return result;
 }
 
+std::vector<Output> GetXlaRunArgs(const Scope& s,
+                                  const XlaClusterInfo& cluster_info,
+                                  const DebuggingOpts& debugging_opts) {
+  std::vector<Output> xla_run_args;
+  xla_run_args.reserve(cluster_info.non_constant_inputs.size() +
+                       cluster_info.resource_inputs.size());
+  int input_idx = 0;
+  for (const Output& o : cluster_info.non_constant_inputs) {
+    if (debugging_opts.check_input_numerics && DataTypeIsFloating(o.type())) {
+      ops::CheckNumerics check_numerics_op(
+          s.WithOpName("check_input_", input_idx), o,
+          absl::StrCat("CheckNumerics failed for input ", input_idx, "(",
+                       o.name(), ") into ", cluster_info.function.name()));
+      xla_run_args.push_back(check_numerics_op);
+    } else {
+      xla_run_args.push_back(o);
+    }
+    input_idx++;
+  }
+  absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args));
+  return xla_run_args;
+}
+
 Status ReplaceNodeWithXlaCompileAndXlaRun(
     jit::DeviceInfoCache* device_info_cache,
     const GraphOptimizationPassOptions& options,
     const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
-    bool insert_print_nodes, Graph* g, Node* n) {
+    const DebuggingOpts& debugging_opts, Graph* g, Node* n) {
   XlaClusterInfo cluster_info;
   TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
 
@@ -361,12 +411,12 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
   TF_RETURN_IF_ERROR(
       CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node()));
 
+  std::vector<Output> xla_run_args =
+      GetXlaRunArgs(root, cluster_info, debugging_opts);
+
   if (requires_compilation) {
     // "Strict" compilation: every _XlaCompile invocation must compile the
     // cluster.
-    std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
-    absl::c_copy(cluster_info.resource_inputs,
-                 std::back_inserter(xla_run_args));
     ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
                          xla_compile.key, n->output_types());
 
@@ -391,9 +441,6 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
     Output predicated_compilation_key = s.output_true;
     Output inverse_predicated_compilation_key = s.output_false;
 
-    std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
-    absl::c_copy(cluster_info.resource_inputs,
-                 std::back_inserter(xla_run_args));
     ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
                          predicated_compilation_key, n->output_types());
 
@@ -402,7 +449,7 @@ Status ReplaceNodeWithXlaCompileAndXlaRun(
 
   MergeOutgoingDataEdges(root, /*old_node=*/n,
                          /*new_node=*/xla_run.operation.node(),
-                         insert_print_nodes);
+                         cluster_info.function.name(), debugging_opts);
 
   TF_RETURN_IF_ERROR(root.status());
 
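Hoisting the argument construction into the new GetXlaRunArgs helper removes the copy-paste between the strict-compilation and lazy-compilation branches, and it is also where the input-side checks happen: only floating point typed non-constant inputs (per DataTypeIsFloating) are wrapped in CheckNumerics nodes named check_input_<idx>, while resource inputs and non-float inputs are forwarded to _XlaRun unchanged. On the output side, MergeOutgoingDataEdges now also receives the cluster's function name so the CheckNumerics error messages can identify which cluster produced the bad value.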
@@ -443,15 +490,25 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
       enable_lazy_compilation_
           ? *enable_lazy_compilation_
           : GetBuildXlaOpsPassFlags()->tf_xla_enable_lazy_compilation;
-  bool insert_print_nodes =
-      GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs;
 
   jit::DeviceInfoCache device_info_cache;
+  const BuildXlaOpsPassFlags& flags = *GetBuildXlaOpsPassFlags();
+
+  DebuggingOpts debugging_opts;
+  debugging_opts.print_outputs = flags.tf_xla_print_cluster_outputs;
+  debugging_opts.check_input_numerics =
+      flags.tf_xla_check_cluster_input_numerics;
+  debugging_opts.check_output_numerics =
+      flags.tf_xla_check_cluster_output_numerics;
+
+  VLOG(1) << "print_outputs = " << debugging_opts.print_outputs;
+  VLOG(1) << "check_input_numerics = " << debugging_opts.check_input_numerics;
+  VLOG(1) << "check_output_numerics = " << debugging_opts.check_output_numerics;
 
   for (Node* n : xla_compiled_kernels) {
     TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
         &device_info_cache, options, *options.flib_def,
-        lazy_compilation_enabled, insert_print_nodes, graph, n));
+        lazy_compilation_enabled, debugging_opts, graph, n));
   }
 
   if (VLOG_IS_ON(1)) {
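The pass now collects all three debugging options from BuildXlaOpsPassFlags into a single DebuggingOpts value and logs the effective settings at VLOG(1); those log lines would only appear with verbose logging enabled, e.g. via TF_CPP_MIN_VLOG_LEVEL=1 (an assumption about the surrounding logging setup, not part of this change). The remaining hunks default and register the new flags. As a test-setup sketch, assuming the flags are parsed from the TF_XLA_FLAGS environment variable the first time a Get*PassFlags() accessor runs, the variable has to be set before that first access; the helper name below is hypothetical:

```cpp
// Hypothetical test-setup helper (assumption: BuildXlaOpsPassFlags are parsed
// from TF_XLA_FLAGS on first access, so the variable must be set beforehand).
#include <cstdlib>

void EnableXlaNumericsChecksForTest() {
  setenv("TF_XLA_FLAGS",
         "--tf_xla_check_cluster_input_numerics=true "
         "--tf_xla_check_cluster_output_numerics=true",
         /*overwrite=*/1);
}
```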
@@ -105,6 +105,8 @@ void AllocateAndParseFlags() {
   build_ops_flags = new BuildXlaOpsPassFlags;
   build_ops_flags->tf_xla_enable_lazy_compilation = true;
   build_ops_flags->tf_xla_print_cluster_outputs = false;
+  build_ops_flags->tf_xla_check_cluster_input_numerics = false;
+  build_ops_flags->tf_xla_check_cluster_output_numerics = false;
   build_ops_flags->tf_xla_disable_constant_folding = false;
 
   mark_for_compilation_flags = new MarkForCompilationPassFlags;
@@ -144,6 +146,14 @@ void AllocateAndParseFlags() {
            &build_ops_flags->tf_xla_print_cluster_outputs,
            "If true then insert Print nodes to print out values produced by "
            "XLA clusters."),
+      Flag("tf_xla_check_cluster_input_numerics",
+           &build_ops_flags->tf_xla_check_cluster_input_numerics,
+           "If true then insert CheckNumerics nodes to check all cluster "
+           "inputs."),
+      Flag("tf_xla_check_cluster_output_numerics",
+           &build_ops_flags->tf_xla_check_cluster_output_numerics,
+           "If true then insert CheckNumerics nodes to check all cluster "
+           "outputs."),
 
       Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
            "Switch a device into 'on-demand' mode, where instead of "
@@ -103,6 +103,14 @@ struct BuildXlaOpsPassFlags {
   // clusters. Useful for debugging.
   bool tf_xla_print_cluster_outputs;
 
+  // If true, insert CheckNumerics nodes for every floating point typed input to
+  // an XLA cluster.
+  bool tf_xla_check_cluster_input_numerics;
+
+  // If true, insert CheckNumerics nodes for every floating point typed output
+  // from an XLA cluster.
+  bool tf_xla_check_cluster_output_numerics;
+
   // Disables all constant folding. The primary use for this is for testing to
   // guarantee that tests are run on XLA and not on TF's CPU implementation.
   bool tf_xla_disable_constant_folding;