[tfdbg] tf.debugging.enable_check_numerics() uses CheckNumericsV2 op
- Add the CPU and GPU implementationsof CheckNumericsV2Op - CheckNumericsV2Op inherits from CheckNumercsOp, but has the new feature of distinguishing -/+ `inf`s. PiperOrigin-RevId: 281195662 Change-Id: I7e07b9338ec25b0c0e2a6b131eaf59bf1e530317
This commit is contained in:
parent
2adcf83790
commit
67cf49948c
@ -1,17 +0,0 @@
|
||||
op {
|
||||
graph_op_name: "CheckNumericsV2"
|
||||
visibility: HIDDEN
|
||||
attr {
|
||||
name: "message"
|
||||
description: <<END
|
||||
Prefix of the error message.
|
||||
END
|
||||
}
|
||||
summary: "Checks a tensor for NaN, -Inf and +Inf values."
|
||||
description: <<END
|
||||
When run, reports an `InvalidArgument` error if `tensor` has any values
|
||||
that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
|
||||
Unlike CheckNumerics (V1), CheckNumericsV2 distinguishes -Inf and +Inf in the
|
||||
errors it throws.
|
||||
END
|
||||
}
|
@ -1,4 +0,0 @@
|
||||
op {
|
||||
graph_op_name: "CheckNumericsV2"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -1,4 +0,0 @@
|
||||
op {
|
||||
graph_op_name: "CheckNumericsV2"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -50,25 +50,10 @@ struct CheckNumericsLaunch {
|
||||
extern template struct CheckNumericsLaunch<Eigen::half>;
|
||||
extern template struct CheckNumericsLaunch<float>;
|
||||
extern template struct CheckNumericsLaunch<double>;
|
||||
|
||||
template <typename T>
|
||||
struct CheckNumericsLaunchV2 {
|
||||
void Run(const GPUDevice& d, const T* data, int size,
|
||||
int abnormal_detected[3]);
|
||||
};
|
||||
|
||||
extern template struct CheckNumericsLaunchV2<Eigen::half>;
|
||||
extern template struct CheckNumericsLaunchV2<float>;
|
||||
extern template struct CheckNumericsLaunchV2<double>;
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
const int kInfBit = 0x01;
|
||||
const int kNaNBit = 0x02;
|
||||
const int kNegativeInfBit = 0x04;
|
||||
const int kPositiveInfBit = 0x08;
|
||||
|
||||
template <typename Device, typename T>
|
||||
class CheckNumericsOp;
|
||||
|
||||
@ -92,34 +77,19 @@ class CheckNumericsOp<CPUDevice, T> : public OpKernel {
|
||||
const T* data = in.data();
|
||||
const int64 size = in.size();
|
||||
// Check to see if any element of the tensor is NaN or Inf.
|
||||
int fp_props = std::accumulate(
|
||||
data, data + size, 0,
|
||||
[this](const int x, const T& y) { return checkFloatingElement(x, y); });
|
||||
if (fp_props != 0) {
|
||||
const string& status = getErrorString(fp_props);
|
||||
if (!status.empty()) {
|
||||
context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
|
||||
status, " values"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual int checkFloatingElement(const int x, const T& y) {
|
||||
int fp_props =
|
||||
std::accumulate(data, data + size, 0, [](const int x, const T& y) {
|
||||
int result = x;
|
||||
if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) {
|
||||
// Do nothing: common case.
|
||||
} else {
|
||||
if (Eigen::numext::isinf(y)) {
|
||||
// Do nothing: common case
|
||||
} else if (Eigen::numext::isinf(y)) {
|
||||
result |= kInfBit;
|
||||
} else if (Eigen::numext::isnan(y)) {
|
||||
result |= kNaNBit;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
virtual const string getErrorString(const int fp_props) {
|
||||
});
|
||||
if (fp_props != 0) {
|
||||
string status;
|
||||
if ((fp_props & kInfBit) && (fp_props & kNaNBit)) {
|
||||
status = "Inf and NaN";
|
||||
@ -131,59 +101,17 @@ class CheckNumericsOp<CPUDevice, T> : public OpKernel {
|
||||
status = "NaN";
|
||||
}
|
||||
}
|
||||
return status;
|
||||
if (!status.empty()) {
|
||||
context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
|
||||
status, " values"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
string message_;
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
class CheckNumericsV2Op;
|
||||
|
||||
// Partial specialization for CPU: v2.
|
||||
// The v2 op differs from the v1 in that it distinguishes -inf and +inf.
|
||||
template <typename T>
|
||||
class CheckNumericsV2Op<CPUDevice, T> : public CheckNumericsOp<CPUDevice, T> {
|
||||
public:
|
||||
explicit CheckNumericsV2Op(OpKernelConstruction* context)
|
||||
: CheckNumericsOp<CPUDevice, T>(context) {}
|
||||
|
||||
protected:
|
||||
int checkFloatingElement(const int x, const T& y) override {
|
||||
int result = x;
|
||||
if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) {
|
||||
// Do nothing: common case.
|
||||
} else {
|
||||
if (Eigen::numext::isinf(y)) {
|
||||
result |= y < static_cast<T>(0.) ? kNegativeInfBit : kPositiveInfBit;
|
||||
} else if (Eigen::numext::isnan(y)) {
|
||||
result |= kNaNBit;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
const string getErrorString(const int fp_props) override {
|
||||
std::vector<string> anomalies;
|
||||
if (fp_props & kNegativeInfBit) {
|
||||
anomalies.push_back("-Inf");
|
||||
}
|
||||
if (fp_props & kPositiveInfBit) {
|
||||
anomalies.push_back("+Inf");
|
||||
}
|
||||
if (fp_props & kNaNBit) {
|
||||
anomalies.push_back("NaN");
|
||||
}
|
||||
if (anomalies.size() == 3) {
|
||||
return strings::StrCat(anomalies[0], ", ", anomalies[1], ", and ",
|
||||
anomalies[2]);
|
||||
} else if (anomalies.size() == 2) {
|
||||
return strings::StrCat(anomalies[0], " and ", anomalies[1]);
|
||||
} else {
|
||||
return anomalies[0];
|
||||
}
|
||||
}
|
||||
static const int kInfBit = 0x01;
|
||||
static const int kNaNBit = 0x02;
|
||||
};
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
@ -210,8 +138,8 @@ class CheckNumericsOp<GPUDevice, T> : public AsyncOpKernel {
|
||||
auto input = context->input(0).flat<T>();
|
||||
|
||||
// Allocate and initialize the elements to hold the check results
|
||||
const int abnormal_detected_size = 2;
|
||||
Tensor abnormal_detected;
|
||||
const int abnormal_detected_size = getAnomalyIndicatorSize();
|
||||
OP_REQUIRES_OK(context, context->allocate_temp(
|
||||
DT_INT32, TensorShape({abnormal_detected_size}),
|
||||
&abnormal_detected));
|
||||
@ -228,7 +156,7 @@ class CheckNumericsOp<GPUDevice, T> : public AsyncOpKernel {
|
||||
|
||||
// Call the GPU kernels for the numerical checks
|
||||
const Device& d = context->eigen_device<Device>();
|
||||
RunKernel(d, input.data(), input.size(),
|
||||
CheckNumericsLaunch<T>().Run(d, input.data(), input.size(),
|
||||
abnormal_detected.flat<int>().data());
|
||||
|
||||
// Copy the results from device to host
|
||||
@ -262,97 +190,42 @@ class CheckNumericsOp<GPUDevice, T> : public AsyncOpKernel {
|
||||
se::rocm::ScopedActivateExecutorContext scoped_activation{
|
||||
stream->parent()};
|
||||
#endif
|
||||
TTypes<const int>::Vec abnormal_detected_host_flat =
|
||||
abnormal_detected_host.flat<int>();
|
||||
auto abnormal_detected_host_flat = abnormal_detected_host.flat<int>();
|
||||
int is_nan = abnormal_detected_host_flat(0);
|
||||
int is_inf = abnormal_detected_host_flat(1);
|
||||
abnormal_detected_ref.Unref();
|
||||
checkForAnomalies(context, abnormal_detected_host_flat);
|
||||
if (is_nan || is_inf) {
|
||||
string status;
|
||||
LOG(ERROR) << "abnormal_detected_host @"
|
||||
<< abnormal_detected_host_flat.data() << " = {" << is_nan
|
||||
<< ", " << is_inf << "} " << message_;
|
||||
|
||||
// Results should always be 1 or 0. If we see anything else then
|
||||
// there has been some GPU memory corruption.
|
||||
CHECK_GE(is_nan, 0);
|
||||
CHECK_GE(is_inf, 0);
|
||||
CHECK_LE(is_nan, 1);
|
||||
CHECK_LE(is_inf, 1);
|
||||
|
||||
if (is_nan && is_inf) {
|
||||
status = "Inf and NaN";
|
||||
} else if (is_nan) {
|
||||
status = "NaN";
|
||||
} else if (is_inf) {
|
||||
status = "Inf";
|
||||
}
|
||||
context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
|
||||
status, " values"));
|
||||
}
|
||||
done();
|
||||
};
|
||||
context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
|
||||
stream, std::move(check_cb));
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual int getAnomalyIndicatorSize() { return 2; }
|
||||
|
||||
virtual void RunKernel(const GPUDevice& d, const T* data, int size,
|
||||
int* abnormal_detected) {
|
||||
CheckNumericsLaunch<T>().Run(d, data, size, abnormal_detected);
|
||||
}
|
||||
|
||||
virtual void checkForAnomalies(
|
||||
OpKernelContext* context,
|
||||
const TTypes<const int>::Vec& abnormality_indicators) {
|
||||
const int is_nan = abnormality_indicators(0);
|
||||
const int is_inf = abnormality_indicators(1);
|
||||
if (is_nan || is_inf) {
|
||||
LOG(ERROR) << "abnormal_detected_host @" << abnormality_indicators.data()
|
||||
<< " = {" << is_nan << ", " << is_inf << "} " << message_;
|
||||
|
||||
string anomalies;
|
||||
if (is_nan && is_inf) {
|
||||
anomalies = "Inf and NaN";
|
||||
} else if (is_nan) {
|
||||
anomalies = "NaN";
|
||||
} else if (is_inf) {
|
||||
anomalies = "Inf";
|
||||
}
|
||||
context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
|
||||
anomalies, " values"));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
string message_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class CheckNumericsV2Op<GPUDevice, T> : public CheckNumericsOp<GPUDevice, T> {
|
||||
public:
|
||||
CheckNumericsV2Op(OpKernelConstruction* context)
|
||||
: CheckNumericsOp<GPUDevice, T>(context) {}
|
||||
|
||||
protected:
|
||||
int getAnomalyIndicatorSize() override { return 3; }
|
||||
|
||||
void RunKernel(const GPUDevice& d, const T* data, int size,
|
||||
int* abnormal_detected) override {
|
||||
CheckNumericsLaunchV2<T>().Run(d, data, size, abnormal_detected);
|
||||
}
|
||||
|
||||
void checkForAnomalies(
|
||||
OpKernelContext* context,
|
||||
const TTypes<const int>::Vec& abnormality_indicators) override {
|
||||
const int is_nan = abnormality_indicators(0);
|
||||
const int is_negative_inf = abnormality_indicators(1);
|
||||
const int is_positive_inf = abnormality_indicators(2);
|
||||
if (is_negative_inf || is_positive_inf || is_nan) {
|
||||
std::vector<string> anomalies;
|
||||
if (is_negative_inf) {
|
||||
anomalies.push_back("-Inf");
|
||||
}
|
||||
if (is_positive_inf) {
|
||||
anomalies.push_back("+Inf");
|
||||
}
|
||||
if (is_nan) {
|
||||
anomalies.push_back("NaN");
|
||||
}
|
||||
string all_anomalies;
|
||||
if (anomalies.size() == 3) {
|
||||
all_anomalies = strings::StrCat(anomalies[0], ", ", anomalies[1],
|
||||
", and ", anomalies[2]);
|
||||
} else if (anomalies.size() == 2) {
|
||||
all_anomalies = strings::StrCat(anomalies[0], " and ", anomalies[1]);
|
||||
} else {
|
||||
all_anomalies = anomalies[0];
|
||||
}
|
||||
context->SetStatus(errors::InvalidArgument(
|
||||
this->message_, " : Tensor had ", all_anomalies, " values"));
|
||||
}
|
||||
}
|
||||
|
||||
static const int abnormal_detected_size = 3;
|
||||
};
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace
|
||||
@ -366,15 +239,6 @@ TF_CALL_bfloat16(REGISTER_CPU_KERNEL);
|
||||
TF_CALL_float(REGISTER_CPU_KERNEL);
|
||||
TF_CALL_double(REGISTER_CPU_KERNEL);
|
||||
|
||||
#define REGISTER_V2_CPU_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("CheckNumericsV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
CheckNumericsV2Op<CPUDevice, T>);
|
||||
TF_CALL_half(REGISTER_V2_CPU_KERNEL);
|
||||
TF_CALL_bfloat16(REGISTER_V2_CPU_KERNEL);
|
||||
TF_CALL_float(REGISTER_V2_CPU_KERNEL);
|
||||
TF_CALL_double(REGISTER_V2_CPU_KERNEL);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("CheckNumerics").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
|
||||
@ -385,16 +249,6 @@ REGISTER_KERNEL_BUILDER(
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("CheckNumerics").Device(DEVICE_GPU).TypeConstraint<double>("T"),
|
||||
CheckNumericsOp<GPUDevice, double>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("CheckNumericsV2").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
|
||||
CheckNumericsV2Op<GPUDevice, Eigen::half>);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("CheckNumericsV2").Device(DEVICE_GPU).TypeConstraint<float>("T"),
|
||||
CheckNumericsV2Op<GPUDevice, float>);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("CheckNumericsV2").Device(DEVICE_GPU).TypeConstraint<double>("T"),
|
||||
CheckNumericsV2Op<GPUDevice, double>);
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -54,29 +54,6 @@ __global__ void CheckNumericsKernel(const T* __restrict__ data, int size,
|
||||
}
|
||||
}
|
||||
|
||||
// V2 of CheckNumericsKernel for GPU.
|
||||
// Unlike CheckNumericsKernel (V1), this kernel disinguishes -Inf and +Inf.
|
||||
// The 3 elements of `abnormal_detected` are used to signify NaN, -Inf and +Inf,
|
||||
// respectively.
|
||||
template <typename T>
|
||||
__global__ void CheckNumericsKernelV2(const T* __restrict__ data, int size,
|
||||
int abnormal_detected[3]) {
|
||||
const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int32 total_thread_count = gridDim.x * blockDim.x;
|
||||
|
||||
int32 offset = thread_id;
|
||||
|
||||
while (offset < size) {
|
||||
if (isnan(data[offset])) {
|
||||
abnormal_detected[0] = 1;
|
||||
}
|
||||
if (isinf(data[offset])) {
|
||||
abnormal_detected[data[offset] < static_cast<T>(0.f) ? 1 : 2] = 1;
|
||||
}
|
||||
offset += total_thread_count;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// A simple launch pad to launch the Cuda kernels that checks the numerical
|
||||
@ -99,24 +76,5 @@ template struct CheckNumericsLaunch<Eigen::half>;
|
||||
template struct CheckNumericsLaunch<float>;
|
||||
template struct CheckNumericsLaunch<double>;
|
||||
|
||||
template <typename T>
|
||||
struct CheckNumericsLaunchV2 {
|
||||
void Run(const GPUDevice& d, const T* data, int size,
|
||||
int abnormal_detected[3]) {
|
||||
const int32 block_size = d.maxGpuThreadsPerBlock();
|
||||
const int32 num_blocks =
|
||||
(d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
|
||||
block_size;
|
||||
|
||||
TF_CHECK_OK(GpuLaunchKernel(CheckNumericsKernelV2<T>, num_blocks,
|
||||
block_size, 0, d.stream(), data, size,
|
||||
abnormal_detected));
|
||||
}
|
||||
};
|
||||
|
||||
template struct CheckNumericsLaunchV2<Eigen::half>;
|
||||
template struct CheckNumericsLaunchV2<float>;
|
||||
template struct CheckNumericsLaunchV2<double>;
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
@ -1349,15 +1349,6 @@ REGISTER_OP("CheckNumerics")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("CheckNumericsV2")
|
||||
.Input("tensor: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: {bfloat16, half, float, double}")
|
||||
.Attr("message: string")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("Reshape")
|
||||
.Input("tensor: T")
|
||||
|
@ -246,7 +246,7 @@ class CheckNumericsCallback(object):
|
||||
for slot, output in enumerate(outputs):
|
||||
if (output.dtype.is_floating and
|
||||
(op_type_bytes, slot) not in IGNORE_OP_OUTPUTS):
|
||||
checked_output = array_ops.check_numerics_v2(
|
||||
checked_output = array_ops.check_numerics(
|
||||
# TF v2 has automatic control dependencies added to stateful async
|
||||
# ops, which allows us to run check_numerics asynchronously.
|
||||
# In the above case we use debug_summary to reduce all output
|
||||
@ -268,7 +268,7 @@ class CheckNumericsCallback(object):
|
||||
instrumented_outputs.append(output)
|
||||
return instrumented_outputs
|
||||
else:
|
||||
if op_type_bytes == b"CheckNumericsV2":
|
||||
if op_type_bytes == b"CheckNumerics":
|
||||
# TODO(b/140334369): Remove this special casing logic once op_callback.
|
||||
# automatically prevents infinite recursion in eager mode.
|
||||
return None
|
||||
@ -276,10 +276,14 @@ class CheckNumericsCallback(object):
|
||||
for slot, output in enumerate(outputs):
|
||||
if (output.dtype.is_floating and
|
||||
(op_type_bytes, slot) not in IGNORE_OP_OUTPUTS):
|
||||
array_ops.check_numerics_v2(
|
||||
array_ops.check_numerics(
|
||||
output,
|
||||
get_check_numerics_error_message(
|
||||
slot, len(outputs), op_type, output, inputs,
|
||||
slot,
|
||||
len(outputs),
|
||||
op_type,
|
||||
output,
|
||||
inputs,
|
||||
stack_height_limit=self._stack_height_limit,
|
||||
path_length_limit=self._path_length_limit))
|
||||
|
||||
|
@ -22,7 +22,6 @@ import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -130,51 +129,6 @@ class NumericsTest(test.TestCase):
|
||||
r"or `tf.while_loop\(\)`\."):
|
||||
numerics.add_check_numerics_ops()
|
||||
|
||||
def testCheckNumericsV2OpNegativeAndPositveInf(self):
|
||||
"""Test that CheckNumericsV2 op distinguishes negative and positive infs."""
|
||||
with self.session(graph=ops.Graph()):
|
||||
t1 = constant_op.constant([-1.0, 1.0])
|
||||
t2 = constant_op.constant([0.0, 0.0])
|
||||
checked = array_ops.check_numerics_v2(
|
||||
t1 / t2, message="pass through test")
|
||||
caught = None
|
||||
try:
|
||||
self.evaluate(checked)
|
||||
except errors.InvalidArgumentError as error:
|
||||
caught = error
|
||||
self.assertIn("had -Inf and +Inf values", caught.message)
|
||||
self.assertIn("pass through test", caught.message)
|
||||
|
||||
def testCheckNumericsV2OpNegativeAndPositveInfAndNaN(self):
|
||||
"""CheckNumericsV2 op distinguishes - & + infs when nan is present."""
|
||||
with self.session(graph=ops.Graph()):
|
||||
t1 = constant_op.constant([-1.0, 1.0, 0.0])
|
||||
t2 = constant_op.constant([0.0, 0.0, 0.0])
|
||||
checked = array_ops.check_numerics_v2(
|
||||
t1 / t2, message="pass through test")
|
||||
caught = None
|
||||
try:
|
||||
self.evaluate(checked)
|
||||
except errors.InvalidArgumentError as error:
|
||||
caught = error
|
||||
self.assertIn("had -Inf, +Inf, and NaN values", caught.message)
|
||||
self.assertIn("pass through test", caught.message)
|
||||
|
||||
def testCheckNumericsV2PositveInfAndNaN(self):
|
||||
"""Test that CheckNumericsV2 op shows sign of inf when nan is present."""
|
||||
with self.session(graph=ops.Graph()):
|
||||
t1 = constant_op.constant([0.0, 1.0])
|
||||
t2 = constant_op.constant([0.0, 0.0])
|
||||
checked = array_ops.check_numerics_v2(
|
||||
t1 / t2, message="pass through test")
|
||||
caught = None
|
||||
try:
|
||||
self.evaluate(checked)
|
||||
except errors.InvalidArgumentError as error:
|
||||
caught = error
|
||||
self.assertIn("had +Inf and NaN values", caught.message)
|
||||
self.assertIn("pass through test", caught.message)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -710,15 +710,6 @@ def _CheckNumericsGrad(op, grad):
|
||||
op.get_attr("message"))
|
||||
|
||||
|
||||
@ops.RegisterGradient("CheckNumericsV2")
|
||||
def _CheckNumericsV2Grad(op, grad):
|
||||
"""Gradient for check_numerics op."""
|
||||
return array_ops.check_numerics_v2(
|
||||
grad,
|
||||
"Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
|
||||
op.get_attr("message"))
|
||||
|
||||
|
||||
@ops.RegisterGradient("PlaceholderWithDefault")
|
||||
@ops.RegisterGradient("Identity")
|
||||
def _IdGrad(_, grad):
|
||||
|
@ -2556,8 +2556,8 @@ def zeros(shape, dtype=dtypes.float32, name=None):
|
||||
[0, 0, 0, 0]], dtype=int32)>
|
||||
|
||||
Args:
|
||||
shape: A `list` of integers, a `tuple` of integers, or
|
||||
a 1-D `Tensor` of type `int32`.
|
||||
shape: A `list` of integers, a `tuple` of integers, or a 1-D `Tensor` of
|
||||
type `int32`.
|
||||
dtype: The DType of an element in the resulting `Tensor`.
|
||||
name: Optional string. A name for the operation.
|
||||
|
||||
@ -2787,8 +2787,8 @@ def ones(shape, dtype=dtypes.float32, name=None):
|
||||
[1, 1, 1, 1]], dtype=int32)>
|
||||
|
||||
Args:
|
||||
shape: A `list` of integers, a `tuple` of integers, or
|
||||
a 1-D `Tensor` of type `int32`.
|
||||
shape: A `list` of integers, a `tuple` of integers, or a 1-D `Tensor` of
|
||||
type `int32`.
|
||||
dtype: Optional DType of an element in the resulting `Tensor`. Default is
|
||||
`tf.float32`.
|
||||
name: Optional string. A name for the operation.
|
||||
@ -4760,8 +4760,8 @@ def quantize(
|
||||
axis=axis)
|
||||
|
||||
|
||||
@tf_export("quantization.dequantize", v1=["quantization.dequantize",
|
||||
"dequantize"])
|
||||
@tf_export(
|
||||
"quantization.dequantize", v1=["quantization.dequantize", "dequantize"])
|
||||
@deprecation.deprecated_endpoints("dequantize")
|
||||
def dequantize( # pylint: disable=missing-docstring
|
||||
input, # pylint: disable=redefined-builtin
|
||||
|
@ -644,10 +644,6 @@ tf_module {
|
||||
name: "CheckNumerics"
|
||||
argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CheckNumericsV2"
|
||||
argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Cholesky"
|
||||
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -644,10 +644,6 @@ tf_module {
|
||||
name: "CheckNumerics"
|
||||
argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CheckNumericsV2"
|
||||
argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Cholesky"
|
||||
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user