[tfdbg] tf.debugging.enable_check_numerics() uses CheckNumericsV2 op

- Add the CPU and GPU implementations of CheckNumericsV2Op
- CheckNumericsV2Op inherits from CheckNumericsOp, but adds the ability to
  distinguish between -`inf` and +`inf`.

PiperOrigin-RevId: 281195662
Change-Id: I7e07b9338ec25b0c0e2a6b131eaf59bf1e530317
A. Unique TensorFlower 2019-11-18 17:58:45 -08:00 committed by TensorFlower Gardener
parent 2adcf83790
commit 67cf49948c
12 changed files with 69 additions and 350 deletions
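
As a quick orientation (a minimal sketch, not part of this diff), the user-facing behavior the new op backs looks roughly like the following, assuming a TF 2.x build where tf.debugging.enable_check_numerics is available; the helper name divide is made up for illustration:

import tensorflow as tf

tf.debugging.enable_check_numerics()

@tf.function
def divide(x, y):  # hypothetical helper, for illustration only
  return x / y

try:
  divide(tf.constant([-1.0, 1.0]), tf.constant([0.0, 0.0]))  # produces -Inf and +Inf
except tf.errors.InvalidArgumentError as e:
  print(e.message)  # expected to mention both "-Inf" and "+Inf", per the tests below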

View File

@@ -1,17 +0,0 @@
op {
graph_op_name: "CheckNumericsV2"
visibility: HIDDEN
attr {
name: "message"
description: <<END
Prefix of the error message.
END
}
summary: "Checks a tensor for NaN, -Inf and +Inf values."
description: <<END
When run, reports an `InvalidArgument` error if `tensor` has any values
that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
Unlike CheckNumerics (V1), CheckNumericsV2 distinguishes -Inf and +Inf in the
errors it throws.
END
}
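
Because visibility: HIDDEN keeps the op out of the generated public tf.* surface, the natural way to poke at it directly is the raw-op endpoint listed in the API goldens further down. A small sketch, assuming eager TF 2.x:

import tensorflow as tf

ok = tf.raw_ops.CheckNumericsV2(tensor=tf.constant([1.0, 2.0]), message="demo")
print(ok.numpy())  # [1. 2.] -- finite tensors pass through as-is

try:
  tf.raw_ops.CheckNumericsV2(tensor=tf.constant([float("-inf"), 1.0]), message="demo")
except tf.errors.InvalidArgumentError as e:
  print(e.message)  # expected to carry the "demo" prefix and report "-Inf"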

View File

@@ -1,4 +0,0 @@
op {
graph_op_name: "CheckNumericsV2"
visibility: HIDDEN
}

View File

@@ -1,4 +0,0 @@
op {
graph_op_name: "CheckNumericsV2"
visibility: HIDDEN
}

View File

@@ -50,25 +50,10 @@ struct CheckNumericsLaunch {
extern template struct CheckNumericsLaunch<Eigen::half>;
extern template struct CheckNumericsLaunch<float>;
extern template struct CheckNumericsLaunch<double>;
template <typename T>
struct CheckNumericsLaunchV2 {
void Run(const GPUDevice& d, const T* data, int size,
int abnormal_detected[3]);
};
extern template struct CheckNumericsLaunchV2<Eigen::half>;
extern template struct CheckNumericsLaunchV2<float>;
extern template struct CheckNumericsLaunchV2<double>;
#endif
namespace {
const int kInfBit = 0x01;
const int kNaNBit = 0x02;
const int kNegativeInfBit = 0x04;
const int kPositiveInfBit = 0x08;
template <typename Device, typename T>
class CheckNumericsOp;
@@ -92,34 +77,19 @@ class CheckNumericsOp<CPUDevice, T> : public OpKernel {
const T* data = in.data();
const int64 size = in.size();
// Check to see if any element of the tensor is NaN or Inf.
int fp_props = std::accumulate(
data, data + size, 0,
[this](const int x, const T& y) { return checkFloatingElement(x, y); });
if (fp_props != 0) {
const string& status = getErrorString(fp_props);
if (!status.empty()) {
context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
status, " values"));
}
}
}
protected:
virtual int checkFloatingElement(const int x, const T& y) {
int fp_props =
std::accumulate(data, data + size, 0, [](const int x, const T& y) {
int result = x;
if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) {
// Do nothing: common case.
} else {
if (Eigen::numext::isinf(y)) {
// Do nothing: common case
} else if (Eigen::numext::isinf(y)) {
result |= kInfBit;
} else if (Eigen::numext::isnan(y)) {
result |= kNaNBit;
}
}
return result;
}
virtual const string getErrorString(const int fp_props) {
});
if (fp_props != 0) {
string status;
if ((fp_props & kInfBit) && (fp_props & kNaNBit)) {
status = "Inf and NaN";
@@ -131,59 +101,17 @@ class CheckNumericsOp<CPUDevice, T> : public OpKernel {
status = "NaN";
}
}
return status;
if (!status.empty()) {
context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
status, " values"));
}
}
}
private:
string message_;
};
template <typename Device, typename T>
class CheckNumericsV2Op;
// Partial specialization for CPU: v2.
// The v2 op differs from the v1 in that it distinguishes -inf and +inf.
template <typename T>
class CheckNumericsV2Op<CPUDevice, T> : public CheckNumericsOp<CPUDevice, T> {
public:
explicit CheckNumericsV2Op(OpKernelConstruction* context)
: CheckNumericsOp<CPUDevice, T>(context) {}
protected:
int checkFloatingElement(const int x, const T& y) override {
int result = x;
if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) {
// Do nothing: common case.
} else {
if (Eigen::numext::isinf(y)) {
result |= y < static_cast<T>(0.) ? kNegativeInfBit : kPositiveInfBit;
} else if (Eigen::numext::isnan(y)) {
result |= kNaNBit;
}
}
return result;
}
const string getErrorString(const int fp_props) override {
std::vector<string> anomalies;
if (fp_props & kNegativeInfBit) {
anomalies.push_back("-Inf");
}
if (fp_props & kPositiveInfBit) {
anomalies.push_back("+Inf");
}
if (fp_props & kNaNBit) {
anomalies.push_back("NaN");
}
if (anomalies.size() == 3) {
return strings::StrCat(anomalies[0], ", ", anomalies[1], ", and ",
anomalies[2]);
} else if (anomalies.size() == 2) {
return strings::StrCat(anomalies[0], " and ", anomalies[1]);
} else {
return anomalies[0];
}
}
static const int kInfBit = 0x01;
static const int kNaNBit = 0x02;
};
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -210,8 +138,8 @@ class CheckNumericsOp<GPUDevice, T> : public AsyncOpKernel {
auto input = context->input(0).flat<T>();
// Allocate and initialize the elements to hold the check results
const int abnormal_detected_size = 2;
Tensor abnormal_detected;
const int abnormal_detected_size = getAnomalyIndicatorSize();
OP_REQUIRES_OK(context, context->allocate_temp(
DT_INT32, TensorShape({abnormal_detected_size}),
&abnormal_detected));
@@ -228,7 +156,7 @@ class CheckNumericsOp<GPUDevice, T> : public AsyncOpKernel {
// Call the GPU kernels for the numerical checks
const Device& d = context->eigen_device<Device>();
RunKernel(d, input.data(), input.size(),
CheckNumericsLaunch<T>().Run(d, input.data(), input.size(),
abnormal_detected.flat<int>().data());
// Copy the results from device to host
@@ -262,97 +190,42 @@ class CheckNumericsOp<GPUDevice, T> : public AsyncOpKernel {
se::rocm::ScopedActivateExecutorContext scoped_activation{
stream->parent()};
#endif
TTypes<const int>::Vec abnormal_detected_host_flat =
abnormal_detected_host.flat<int>();
auto abnormal_detected_host_flat = abnormal_detected_host.flat<int>();
int is_nan = abnormal_detected_host_flat(0);
int is_inf = abnormal_detected_host_flat(1);
abnormal_detected_ref.Unref();
checkForAnomalies(context, abnormal_detected_host_flat);
if (is_nan || is_inf) {
string status;
LOG(ERROR) << "abnormal_detected_host @"
<< abnormal_detected_host_flat.data() << " = {" << is_nan
<< ", " << is_inf << "} " << message_;
// Results should always be 1 or 0. If we see anything else then
// there has been some GPU memory corruption.
CHECK_GE(is_nan, 0);
CHECK_GE(is_inf, 0);
CHECK_LE(is_nan, 1);
CHECK_LE(is_inf, 1);
if (is_nan && is_inf) {
status = "Inf and NaN";
} else if (is_nan) {
status = "NaN";
} else if (is_inf) {
status = "Inf";
}
context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
status, " values"));
}
done();
};
context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
stream, std::move(check_cb));
}
protected:
virtual int getAnomalyIndicatorSize() { return 2; }
virtual void RunKernel(const GPUDevice& d, const T* data, int size,
int* abnormal_detected) {
CheckNumericsLaunch<T>().Run(d, data, size, abnormal_detected);
}
virtual void checkForAnomalies(
OpKernelContext* context,
const TTypes<const int>::Vec& abnormality_indicators) {
const int is_nan = abnormality_indicators(0);
const int is_inf = abnormality_indicators(1);
if (is_nan || is_inf) {
LOG(ERROR) << "abnormal_detected_host @" << abnormality_indicators.data()
<< " = {" << is_nan << ", " << is_inf << "} " << message_;
string anomalies;
if (is_nan && is_inf) {
anomalies = "Inf and NaN";
} else if (is_nan) {
anomalies = "NaN";
} else if (is_inf) {
anomalies = "Inf";
}
context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
anomalies, " values"));
}
}
private:
string message_;
};
template <typename T>
class CheckNumericsV2Op<GPUDevice, T> : public CheckNumericsOp<GPUDevice, T> {
public:
CheckNumericsV2Op(OpKernelConstruction* context)
: CheckNumericsOp<GPUDevice, T>(context) {}
protected:
int getAnomalyIndicatorSize() override { return 3; }
void RunKernel(const GPUDevice& d, const T* data, int size,
int* abnormal_detected) override {
CheckNumericsLaunchV2<T>().Run(d, data, size, abnormal_detected);
}
void checkForAnomalies(
OpKernelContext* context,
const TTypes<const int>::Vec& abnormality_indicators) override {
const int is_nan = abnormality_indicators(0);
const int is_negative_inf = abnormality_indicators(1);
const int is_positive_inf = abnormality_indicators(2);
if (is_negative_inf || is_positive_inf || is_nan) {
std::vector<string> anomalies;
if (is_negative_inf) {
anomalies.push_back("-Inf");
}
if (is_positive_inf) {
anomalies.push_back("+Inf");
}
if (is_nan) {
anomalies.push_back("NaN");
}
string all_anomalies;
if (anomalies.size() == 3) {
all_anomalies = strings::StrCat(anomalies[0], ", ", anomalies[1],
", and ", anomalies[2]);
} else if (anomalies.size() == 2) {
all_anomalies = strings::StrCat(anomalies[0], " and ", anomalies[1]);
} else {
all_anomalies = anomalies[0];
}
context->SetStatus(errors::InvalidArgument(
this->message_, " : Tensor had ", all_anomalies, " values"));
}
}
static const int abnormal_detected_size = 3;
};
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace
@@ -366,15 +239,6 @@ TF_CALL_bfloat16(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#define REGISTER_V2_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("CheckNumericsV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
CheckNumericsV2Op<CPUDevice, T>);
TF_CALL_half(REGISTER_V2_CPU_KERNEL);
TF_CALL_bfloat16(REGISTER_V2_CPU_KERNEL);
TF_CALL_float(REGISTER_V2_CPU_KERNEL);
TF_CALL_double(REGISTER_V2_CPU_KERNEL);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(
Name("CheckNumerics").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
@@ -385,16 +249,6 @@ REGISTER_KERNEL_BUILDER(
REGISTER_KERNEL_BUILDER(
Name("CheckNumerics").Device(DEVICE_GPU).TypeConstraint<double>("T"),
CheckNumericsOp<GPUDevice, double>);
REGISTER_KERNEL_BUILDER(
Name("CheckNumericsV2").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
CheckNumericsV2Op<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
Name("CheckNumericsV2").Device(DEVICE_GPU).TypeConstraint<float>("T"),
CheckNumericsV2Op<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(
Name("CheckNumericsV2").Device(DEVICE_GPU).TypeConstraint<double>("T"),
CheckNumericsV2Op<GPUDevice, double>);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow
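
As a reading aid for the CPU kernel above, here is a rough Python paraphrase (an illustration, not the kernel itself) of the flag accumulation and error-string assembly, reusing the bit values of kNaNBit, kNegativeInfBit, and kPositiveInfBit:

import math
from functools import reduce

K_NAN, K_NEG_INF, K_POS_INF = 0x02, 0x04, 0x08  # mirrors kNaNBit / kNegativeInfBit / kPositiveInfBit

def check_element(flags, y):
  if math.isfinite(y):
    return flags                                        # common case: nothing to record
  if math.isinf(y):
    return flags | (K_NEG_INF if y < 0 else K_POS_INF)  # sign-aware, the V2 behavior
  return flags | K_NAN                                   # only NaN remains

def error_string(flags):
  anomalies = [name for bit, name in ((K_NEG_INF, "-Inf"), (K_POS_INF, "+Inf"), (K_NAN, "NaN"))
               if flags & bit]
  if len(anomalies) == 3:
    return "{}, {}, and {}".format(*anomalies)
  return " and ".join(anomalies)

flags = reduce(check_element, [float("-inf"), float("inf"), float("nan")], 0)
print(error_string(flags))  # -Inf, +Inf, and NaN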

View File

@@ -54,29 +54,6 @@ __global__ void CheckNumericsKernel(const T* __restrict__ data, int size,
}
}
// V2 of CheckNumericsKernel for GPU.
// Unlike CheckNumericsKernel (V1), this kernel distinguishes -Inf and +Inf.
// The 3 elements of `abnormal_detected` are used to signify NaN, -Inf and +Inf,
// respectively.
template <typename T>
__global__ void CheckNumericsKernelV2(const T* __restrict__ data, int size,
int abnormal_detected[3]) {
const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const int32 total_thread_count = gridDim.x * blockDim.x;
int32 offset = thread_id;
while (offset < size) {
if (isnan(data[offset])) {
abnormal_detected[0] = 1;
}
if (isinf(data[offset])) {
abnormal_detected[data[offset] < static_cast<T>(0.f) ? 1 : 2] = 1;
}
offset += total_thread_count;
}
}
} // namespace
// A simple launch pad to launch the Cuda kernels that checks the numerical
@@ -99,24 +76,5 @@ template struct CheckNumericsLaunch<Eigen::half>;
template struct CheckNumericsLaunch<float>;
template struct CheckNumericsLaunch<double>;
template <typename T>
struct CheckNumericsLaunchV2 {
void Run(const GPUDevice& d, const T* data, int size,
int abnormal_detected[3]) {
const int32 block_size = d.maxGpuThreadsPerBlock();
const int32 num_blocks =
(d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
block_size;
TF_CHECK_OK(GpuLaunchKernel(CheckNumericsKernelV2<T>, num_blocks,
block_size, 0, d.stream(), data, size,
abnormal_detected));
}
};
template struct CheckNumericsLaunchV2<Eigen::half>;
template struct CheckNumericsLaunchV2<float>;
template struct CheckNumericsLaunchV2<double>;
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
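
The V2 GPU kernel above fills a three-slot indicator array instead of V1's two slots: slot 0 flags NaN, slot 1 flags -Inf, slot 2 flags +Inf. A NumPy sketch of that layout, for illustration only (the kernel itself checks isinf plus the sign rather than calling helpers):

import numpy as np

def indicators(data):
  abnormal = np.zeros(3, dtype=np.int32)
  abnormal[0] = int(np.isnan(data).any())     # slot 0: any NaN
  abnormal[1] = int(np.isneginf(data).any())  # slot 1: any -Inf
  abnormal[2] = int(np.isposinf(data).any())  # slot 2: any +Inf
  return abnormal

print(indicators(np.array([1.0, -np.inf, np.nan])))  # [1 1 0]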

View File

@@ -1349,15 +1349,6 @@ REGISTER_OP("CheckNumerics")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape);
// --------------------------------------------------------------------------
REGISTER_OP("CheckNumericsV2")
.Input("tensor: T")
.Output("output: T")
.Attr("T: {bfloat16, half, float, double}")
.Attr("message: string")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape);
// --------------------------------------------------------------------------
REGISTER_OP("Reshape")
.Input("tensor: T")
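
The CheckNumericsV2 registration shown above uses shape_inference::UnchangedShape and marks the op stateful, so the output is the input passed straight through. A quick sketch of what that looks like from Python, assuming the raw-op endpoint:

import tensorflow as tf

x = tf.ones([2, 3])
y = tf.raw_ops.CheckNumericsV2(tensor=x, message="shape check")
print(y.shape, y.dtype)  # (2, 3) <dtype: 'float32'> -- shape and dtype are unchanged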

View File

@@ -246,7 +246,7 @@ class CheckNumericsCallback(object):
for slot, output in enumerate(outputs):
if (output.dtype.is_floating and
(op_type_bytes, slot) not in IGNORE_OP_OUTPUTS):
checked_output = array_ops.check_numerics_v2(
checked_output = array_ops.check_numerics(
# TF v2 has automatic control dependencies added to stateful async
# ops, which allows us to run check_numerics asynchronously.
# In the above case we use debug_summary to reduce all output
@@ -268,7 +268,7 @@ class CheckNumericsCallback(object):
instrumented_outputs.append(output)
return instrumented_outputs
else:
if op_type_bytes == b"CheckNumericsV2":
if op_type_bytes == b"CheckNumerics":
# TODO(b/140334369): Remove this special casing logic once op_callback.
# automatically prevents infinite recursion in eager mode.
return None
@@ -276,10 +276,14 @@ class CheckNumericsCallback(object):
for slot, output in enumerate(outputs):
if (output.dtype.is_floating and
(op_type_bytes, slot) not in IGNORE_OP_OUTPUTS):
array_ops.check_numerics_v2(
array_ops.check_numerics(
output,
get_check_numerics_error_message(
slot, len(outputs), op_type, output, inputs,
slot,
len(outputs),
op_type,
output,
inputs,
stack_height_limit=self._stack_height_limit,
path_length_limit=self._path_length_limit))
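
A condensed, standalone sketch of the callback logic in this file (the names below are made up, and the real callback also consults IGNORE_OP_OUTPUTS and builds a richer message): floating-point outputs get wrapped in a numerics check, while the check op's own outputs are skipped to avoid infinite recursion in eager mode.

import tensorflow as tf

def instrument_outputs(op_type, outputs, message):
  if op_type in ("CheckNumerics", "CheckNumericsV2"):
    return None  # never instrument the check op itself -- that would recurse
  checked = []
  for slot, output in enumerate(outputs):
    if output.dtype.is_floating:
      checked.append(tf.debugging.check_numerics(
          output, message="%s (output slot %d)" % (message, slot)))
    else:
      checked.append(output)
  return checked

outs = instrument_outputs("RealDiv", [tf.constant([1.0, 2.0])], "demo")
print(outs[0].numpy())  # [1. 2.]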

View File

@@ -22,7 +22,6 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -130,51 +129,6 @@ class NumericsTest(test.TestCase):
r"or `tf.while_loop\(\)`\."):
numerics.add_check_numerics_ops()
def testCheckNumericsV2OpNegativeAndPositveInf(self):
"""Test that CheckNumericsV2 op distinguishes negative and positive infs."""
with self.session(graph=ops.Graph()):
t1 = constant_op.constant([-1.0, 1.0])
t2 = constant_op.constant([0.0, 0.0])
checked = array_ops.check_numerics_v2(
t1 / t2, message="pass through test")
caught = None
try:
self.evaluate(checked)
except errors.InvalidArgumentError as error:
caught = error
self.assertIn("had -Inf and +Inf values", caught.message)
self.assertIn("pass through test", caught.message)
def testCheckNumericsV2OpNegativeAndPositveInfAndNaN(self):
"""CheckNumericsV2 op distinguishes - & + infs when nan is present."""
with self.session(graph=ops.Graph()):
t1 = constant_op.constant([-1.0, 1.0, 0.0])
t2 = constant_op.constant([0.0, 0.0, 0.0])
checked = array_ops.check_numerics_v2(
t1 / t2, message="pass through test")
caught = None
try:
self.evaluate(checked)
except errors.InvalidArgumentError as error:
caught = error
self.assertIn("had -Inf, +Inf, and NaN values", caught.message)
self.assertIn("pass through test", caught.message)
def testCheckNumericsV2PositveInfAndNaN(self):
"""Test that CheckNumericsV2 op shows sign of inf when nan is present."""
with self.session(graph=ops.Graph()):
t1 = constant_op.constant([0.0, 1.0])
t2 = constant_op.constant([0.0, 0.0])
checked = array_ops.check_numerics_v2(
t1 / t2, message="pass through test")
caught = None
try:
self.evaluate(checked)
except errors.InvalidArgumentError as error:
caught = error
self.assertIn("had +Inf and NaN values", caught.message)
self.assertIn("pass through test", caught.message)
if __name__ == "__main__":
test.main()

View File

@@ -710,15 +710,6 @@ def _CheckNumericsGrad(op, grad):
op.get_attr("message"))
@ops.RegisterGradient("CheckNumericsV2")
def _CheckNumericsV2Grad(op, grad):
"""Gradient for check_numerics op."""
return array_ops.check_numerics_v2(
grad,
"Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
op.get_attr("message"))
@ops.RegisterGradient("PlaceholderWithDefault")
@ops.RegisterGradient("Identity")
def _IdGrad(_, grad):
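
The gradient registered above re-checks the incoming gradient, so NaNs or Infs that show up only in the backward pass are reported as well. A small eager-mode sketch using the public check_numerics wrapper (which op it lowers to depends on the TF version):

import tensorflow as tf

x = tf.constant([2.0, 3.0])
with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.debugging.check_numerics(x * x, message="forward check")
# The backward pass runs the registered gradient, which wraps the incoming
# gradient in another numerics check before passing it along.
print(tape.gradient(y, x).numpy())  # [4. 6.]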

View File

@@ -2556,8 +2556,8 @@ def zeros(shape, dtype=dtypes.float32, name=None):
[0, 0, 0, 0]], dtype=int32)>
Args:
shape: A `list` of integers, a `tuple` of integers, or
a 1-D `Tensor` of type `int32`.
shape: A `list` of integers, a `tuple` of integers, or a 1-D `Tensor` of
type `int32`.
dtype: The DType of an element in the resulting `Tensor`.
name: Optional string. A name for the operation.
@@ -2787,8 +2787,8 @@ def ones(shape, dtype=dtypes.float32, name=None):
[1, 1, 1, 1]], dtype=int32)>
Args:
shape: A `list` of integers, a `tuple` of integers, or
a 1-D `Tensor` of type `int32`.
shape: A `list` of integers, a `tuple` of integers, or a 1-D `Tensor` of
type `int32`.
dtype: Optional DType of an element in the resulting `Tensor`. Default is
`tf.float32`.
name: Optional string. A name for the operation.
@@ -4760,8 +4760,8 @@ def quantize(
axis=axis)
@tf_export("quantization.dequantize", v1=["quantization.dequantize",
"dequantize"])
@tf_export(
"quantization.dequantize", v1=["quantization.dequantize", "dequantize"])
@deprecation.deprecated_endpoints("dequantize")
def dequantize( # pylint: disable=missing-docstring
input, # pylint: disable=redefined-builtin

View File

@@ -644,10 +644,6 @@ tf_module {
name: "CheckNumerics"
argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "CheckNumericsV2"
argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "Cholesky"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -644,10 +644,6 @@ tf_module {
name: "CheckNumerics"
argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "CheckNumericsV2"
argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "Cholesky"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "