tfdbg: add runtime shape and dtype info to DebugNumericSummary

PiperOrigin-RevId: 157291215
This commit is contained in:
Shanqing Cai 2017-05-26 21:32:19 -07:00 committed by TensorFlower Gardener
parent 4fb2425f8f
commit b4466279a6
9 changed files with 297 additions and 167 deletions
tensorflow
core
python/debug
tensorboard
backend/event_processing
components
tf_backend
tf_graph_common

View File

@ -279,9 +279,9 @@ class DebugNumericSummaryOp : public BaseDebugOp {
// Equal to negative_count + zero_count + positive_count.
int64 non_inf_nan_count = 0;
const TensorShape& input_shape = input.shape();
if (input.IsInitialized()) {
is_initialized = 1;
const TensorShape& input_shape = input.shape();
const T* input_flat = input.template flat<T>().data();
element_count = input_shape.num_elements();
@ -338,7 +338,7 @@ class DebugNumericSummaryOp : public BaseDebugOp {
}
}
TensorShape shape({12});
TensorShape shape({14 + input_shape.dims()});
OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output_tensor));
output_tensor->vec<double>()(0) = static_cast<double>(is_initialized);
output_tensor->vec<double>()(1) = static_cast<double>(element_count);
@ -353,6 +353,13 @@ class DebugNumericSummaryOp : public BaseDebugOp {
output_tensor->vec<double>()(10) = mean;
output_tensor->vec<double>()(11) = variance;
output_tensor->vec<double>()(12) = static_cast<double>(input.dtype());
output_tensor->vec<double>()(13) = static_cast<double>(input_shape.dims());
for (size_t d = 0; d < input_shape.dims(); ++d) {
output_tensor->vec<double>()(14 + d) =
static_cast<double>(input_shape.dim_sizes()[d]);
}
bool mute = mute_if_healthy_ && nan_count == 0 && negative_inf_count == 0 &&
positive_inf_count == 0;
if (!mute) {

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@ -267,21 +268,24 @@ TEST_F(DebugNumericSummaryOpTest, Float_full_house) {
std::numeric_limits<float>::quiet_NaN()});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_DOUBLE, TensorShape({12}));
Tensor expected(allocator(), DT_DOUBLE, TensorShape({15}));
test::FillValues<double>(
&expected,
{1.0, // Is initialized.
18.0, // Total element count.
4.0, // nan count.
2.0, // -inf count.
2.0, // negative number count (excluding -inf).
3.0, // zero count.
2.0, // positive number count (excluding +inf).
5.0, // +inf count.
-3.0, // minimum of non-inf and non-nan elements.
7.0, // maximum of non-inf and non-nan elements.
0.85714285714, // mean of non-inf and non-nan elements.
8.97959183673}); // variance of non-inf and non-nan elements.
{1.0, // Is initialized.
18.0, // Total element count.
4.0, // nan count.
2.0, // -inf count.
2.0, // negative number count (excluding -inf).
3.0, // zero count.
2.0, // positive number count (excluding +inf).
5.0, // +inf count.
-3.0, // minimum of non-inf and non-nan elements.
7.0, // maximum of non-inf and non-nan elements.
0.85714285714, // mean of non-inf and non-nan elements.
8.97959183673, // variance of non-inf and non-nan elements.
static_cast<double>(DT_FLOAT), // dtype.
1.0, // Number of dimensions.
18.0}); // Dimension size.
test::ExpectTensorNear<double>(expected, *GetOutput(0), 1e-8);
}
@ -303,21 +307,24 @@ TEST_F(DebugNumericSummaryOpTest, Double_full_house) {
std::numeric_limits<double>::quiet_NaN()});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_DOUBLE, TensorShape({12}));
Tensor expected(allocator(), DT_DOUBLE, TensorShape({15}));
test::FillValues<double>(
&expected,
{1.0, // Is initialized.
18.0, // Total element count.
4.0, // nan count.
2.0, // -inf count.
2.0, // negative count (excluding -inf).
3.0, // zero count.
2.0, // positive count (excluding +inf).
5.0, // +inf count.
-3.0, // minimum of non-inf and non-nan elements.
7.0, // maximum of non-inf and non-nan elements.
0.85714285714, // mean of non-inf and non-nan elements.
8.97959183673}); // variance of non-inf and non-nan elements.
{1.0, // Is initialized.
18.0, // Total element count.
4.0, // nan count.
2.0, // -inf count.
2.0, // negative count (excluding -inf).
3.0, // zero count.
2.0, // positive count (excluding +inf).
5.0, // +inf count.
-3.0, // minimum of non-inf and non-nan elements.
7.0, // maximum of non-inf and non-nan elements.
0.85714285714, // mean of non-inf and non-nan elements.
8.97959183673, // variance of non-inf and non-nan elements.
static_cast<double>(DT_DOUBLE), // dtype.
1.0, // Number of dimensions.
18.0}); // Dimension size.
test::ExpectTensorNear<double>(expected, *GetOutput(0), 1e-8);
}
@ -328,21 +335,24 @@ TEST_F(DebugNumericSummaryOpTest, Float_only_valid_values) {
{0.0f, 0.0f, -1.0f, 3.0f, 3.0f, 7.0f});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_DOUBLE, TensorShape({12}));
Tensor expected(allocator(), DT_DOUBLE, TensorShape({16}));
test::FillValues<double>(
&expected,
{1.0, // Is initialized.
6.0, // Total element count.
0.0, // nan count.
0.0, // -inf count.
1.0, // negative count (excluding -inf).
2.0, // zero count.
3.0, // positive count (excluding +inf).
0.0, // +inf count.
-1.0, // minimum of non-inf and non-nan elements.
7.0, // maximum of non-inf and non-nan elements.
2.0, // mean of non-inf and non-nan elements.
7.33333333333}); // variance of non-inf and non-nan elements.
{1.0, // Is initialized.
6.0, // Total element count.
0.0, // nan count.
0.0, // -inf count.
1.0, // negative count (excluding -inf).
2.0, // zero count.
3.0, // positive count (excluding +inf).
0.0, // +inf count.
-1.0, // minimum of non-inf and non-nan elements.
7.0, // maximum of non-inf and non-nan elements.
2.0, // mean of non-inf and non-nan elements.
7.33333333333, // variance of non-inf and non-nan elements.
static_cast<double>(DT_FLOAT), // dtype
2.0, // Number of dimensions.
2.0, 3.0}); // Dimensoin sizes.
test::ExpectTensorNear<double>(expected, *GetOutput(0), 1e-8);
}
@ -364,7 +374,6 @@ TEST_F(DebugNumericSummaryOpTest, Float_all_Inf_or_NaN) {
Tensor output_tensor = *GetOutput(0);
const double* output = output_tensor.template flat<double>().data();
Tensor expected(allocator(), DT_DOUBLE, TensorShape({12}));
// Use ASSERT_NEAR below because test::ExpectTensorNear does not work with
// NaNs.
ASSERT_NEAR(1.0, output[0], 1e-8); // Is initialized.
@ -381,6 +390,69 @@ TEST_F(DebugNumericSummaryOpTest, Float_all_Inf_or_NaN) {
ASSERT_EQ(-std::numeric_limits<float>::infinity(), output[9]);
ASSERT_TRUE(Eigen::numext::isnan(output[10]));
ASSERT_TRUE(Eigen::numext::isnan(output[11]));
ASSERT_EQ(static_cast<double>(DT_FLOAT), output[12]);
ASSERT_EQ(2.0, output[13]);
ASSERT_EQ(3.0, output[14]);
ASSERT_EQ(3.0, output[15]);
}
TEST_F(DebugNumericSummaryOpTest, Many_dimensions_tensor_shape) {
TF_ASSERT_OK(Init(DT_FLOAT));
AddInputFromArray<float>(TensorShape({1, 3, 1, 1, 1, 1, 1}),
{std::numeric_limits<float>::quiet_NaN(),
-std::numeric_limits<float>::infinity(), -8.0});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_DOUBLE, TensorShape({21}));
test::FillValues<double>(&expected,
{1.0, // Is initialized.
3.0, // Total element count.
1.0, // nan count.
1.0, // -inf count.
1.0, // negative number count (excluding -inf).
0.0, // zero count.
0.0, // positive number count (excluding +inf).
0.0, // +inf count.
-8.0, // minimum of non-inf and non-nan elements.
-8.0, // maximum of non-inf and non-nan elements.
-8.0, // mean of non-inf and non-nan elements.
0.0, // variance of non-inf and non-nan elements.
static_cast<double>(DT_FLOAT), // dtype.
7.0, // Number of dimensions.
1.0,
3.0,
1.0,
1.0,
1.0,
1.0,
1.0}); // Dimension sizes.
test::ExpectTensorNear<double>(expected, *GetOutput(0), 1e-8);
}
TEST_F(DebugNumericSummaryOpTest, Scalar_tensor_shape) {
TF_ASSERT_OK(Init(DT_FLOAT));
AddInputFromArray<float>(TensorShape({}), {42.0});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_DOUBLE, TensorShape({14}));
test::FillValues<double>(&expected,
{1.0, // Is initialized.
1.0, // Total element count.
0.0, // nan count.
0.0, // -inf count.
0.0, // negative number count (excluding -inf).
0.0, // zero count.
1.0, // positive number count (excluding +inf).
0.0, // +inf count.
42.0, // minimum of non-inf and non-nan elements.
42.0, // maximum of non-inf and non-nan elements.
42.0, // mean of non-inf and non-nan elements.
0.0, // variance of non-inf and non-nan elements.
static_cast<double>(DT_FLOAT), // dtype.
0.0}); // Number of dimensions.
test::ExpectTensorNear<double>(expected, *GetOutput(0), 1e-8);
}
TEST_F(DebugNumericSummaryOpTest, Int16Success) {
@ -388,21 +460,23 @@ TEST_F(DebugNumericSummaryOpTest, Int16Success) {
AddInputFromArray<int16>(TensorShape({4, 1}), {-1, -3, 3, 7});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_DOUBLE, TensorShape({12}));
test::FillValues<double>(
&expected,
{1.0, // Is initialized.
4.0, // Total element count.
0.0, // nan count.
0.0, // -inf count.
2.0, // negative count (excluding -inf).
0.0, // zero count.
2.0, // positive count (excluding +inf).
0.0, // +inf count.
-3.0, // minimum of non-inf and non-nan elements.
7.0, // maximum of non-inf and non-nan elements.
1.5, // mean of non-inf and non-nan elements.
14.75}); // variance of non-inf and non-nan elements.
Tensor expected(allocator(), DT_DOUBLE, TensorShape({16}));
test::FillValues<double>(&expected,
{1.0, // Is initialized.
4.0, // Total element count.
0.0, // nan count.
0.0, // -inf count.
2.0, // negative count (excluding -inf).
0.0, // zero count.
2.0, // positive count (excluding +inf).
0.0, // +inf count.
-3.0, // minimum of non-inf and non-nan elements.
7.0, // maximum of non-inf and non-nan elements.
1.5, // mean of non-inf and non-nan elements.
14.75, // variance of non-inf and non-nan elements.
static_cast<double>(DT_INT16), // dtype.
2.0, // Number of dimensions.
4.0, 1.0}); // Dimension sizes.
test::ExpectTensorNear<double>(expected, *GetOutput(0), 1e-8);
}
@ -412,21 +486,24 @@ TEST_F(DebugNumericSummaryOpTest, Int32Success) {
AddInputFromArray<int32>(TensorShape({2, 3}), {0, 0, -1, 3, 3, 7});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_DOUBLE, TensorShape({12}));
Tensor expected(allocator(), DT_DOUBLE, TensorShape({16}));
test::FillValues<double>(
&expected,
{1.0, // Is initialized.
6.0, // Total element count.
0.0, // nan count.
0.0, // -inf count.
1.0, // negative count (excluding -inf).
2.0, // zero count.
3.0, // positive count (excluding +inf).
0.0, // +inf count.
-1.0, // minimum of non-inf and non-nan elements.
7.0, // maximum of non-inf and non-nan elements.
2.0, // mean of non-inf and non-nan elements.
7.33333333333}); // variance of non-inf and non-nan elements.
{1.0, // Is initialized.
6.0, // Total element count.
0.0, // nan count.
0.0, // -inf count.
1.0, // negative count (excluding -inf).
2.0, // zero count.
3.0, // positive count (excluding +inf).
0.0, // +inf count.
-1.0, // minimum of non-inf and non-nan elements.
7.0, // maximum of non-inf and non-nan elements.
2.0, // mean of non-inf and non-nan elements.
7.33333333333, // variance of non-inf and non-nan elements.
static_cast<double>(DT_INT32), // dtype.
2.0, // Number of dimensions.
2.0, 3.0}); // Dimension sizes.
test::ExpectTensorNear<double>(expected, *GetOutput(0), 1e-8);
}
@ -436,21 +513,23 @@ TEST_F(DebugNumericSummaryOpTest, Int64Success) {
AddInputFromArray<int64>(TensorShape({2, 2, 2}), {0, 0, -1, 3, 3, 7, 0, 0});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_DOUBLE, TensorShape({12}));
test::FillValues<double>(
&expected,
{1.0, // Is initialized.
8.0, // Total element count.
0.0, // nan count.
0.0, // -inf count.
1.0, // negative count (excluding -inf).
4.0, // zero count.
3.0, // positive count (excluding +inf).
0.0, // +inf count.
-1.0, // minimum of non-inf and non-nan elements.
7.0, // maximum of non-inf and non-nan elements.
1.5, // mean of non-inf and non-nan elements.
6.25}); // variance of non-inf and non-nan elements.
Tensor expected(allocator(), DT_DOUBLE, TensorShape({17}));
test::FillValues<double>(&expected,
{1.0, // Is initialized.
8.0, // Total element count.
0.0, // nan count.
0.0, // -inf count.
1.0, // negative count (excluding -inf).
4.0, // zero count.
3.0, // positive count (excluding +inf).
0.0, // +inf count.
-1.0, // minimum of non-inf and non-nan elements.
7.0, // maximum of non-inf and non-nan elements.
1.5, // mean of non-inf and non-nan elements.
6.25, // variance of non-inf and non-nan elements.
static_cast<double>(DT_INT64), // dtype.
3.0, // Number of dimensions.
2.0, 2.0, 2.0}); // Dimension sizes.
test::ExpectTensorNear<double>(expected, *GetOutput(0), 1e-8);
}
@ -460,21 +539,23 @@ TEST_F(DebugNumericSummaryOpTest, UInt8Success) {
AddInputFromArray<uint8>(TensorShape({1, 5}), {0, 10, 30, 30, 70});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_DOUBLE, TensorShape({12}));
test::FillValues<double>(
&expected,
{1.0, // Is initialized.
5.0, // Total element count.
0.0, // nan count.
0.0, // -inf count.
0.0, // negative count (excluding -inf).
1.0, // zero count.
4.0, // positive count (excluding +inf).
0.0, // +inf count.
0.0, // minimum of non-inf and non-nan elements.
70.0, // maximum of non-inf and non-nan elements.
28.0, // mean of non-inf and non-nan elements.
576.0}); // variance of non-inf and non-nan elements.
Tensor expected(allocator(), DT_DOUBLE, TensorShape({16}));
test::FillValues<double>(&expected,
{1.0, // Is initialized.
5.0, // Total element count.
0.0, // nan count.
0.0, // -inf count.
0.0, // negative count (excluding -inf).
1.0, // zero count.
4.0, // positive count (excluding +inf).
0.0, // +inf count.
0.0, // minimum of non-inf and non-nan elements.
70.0, // maximum of non-inf and non-nan elements.
28.0, // mean of non-inf and non-nan elements.
576.0, // variance of non-inf and non-nan elements.
static_cast<double>(DT_UINT8), // dtypes.
2.0, // Number of dimensions.
1.0, 5.0}); // Dimension sizes.
test::ExpectTensorNear<double>(expected, *GetOutput(0), 1e-8);
}
@ -484,21 +565,23 @@ TEST_F(DebugNumericSummaryOpTest, BoolSuccess) {
AddInputFromArray<bool>(TensorShape({2, 3}), {0, 0, 1, 1, 1, 0});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_DOUBLE, TensorShape({12}));
test::FillValues<double>(
&expected,
{1.0, // Is initialized.
6.0, // Total element count.
0.0, // nan count.
0.0, // -inf count.
0.0, // negative count (excluding -inf).
3.0, // zero count.
3.0, // positive count (excluding +inf).
0.0, // +inf count.
0.0, // minimum of non-inf and non-nan elements.
1.0, // maximum of non-inf and non-nan elements.
0.5, // mean of non-inf and non-nan elements.
0.25}); // variance of non-inf and non-nan elements.
Tensor expected(allocator(), DT_DOUBLE, TensorShape({16}));
test::FillValues<double>(&expected,
{1.0, // Is initialized.
6.0, // Total element count.
0.0, // nan count.
0.0, // -inf count.
0.0, // negative count (excluding -inf).
3.0, // zero count.
3.0, // positive count (excluding +inf).
0.0, // +inf count.
0.0, // minimum of non-inf and non-nan elements.
1.0, // maximum of non-inf and non-nan elements.
0.5, // mean of non-inf and non-nan elements.
0.25, // variance of non-inf and non-nan elements.
static_cast<double>(DT_BOOL), // dtype.
2.0, // Number of dimensions.
2.0, 3.0}); // Dimension sizes.
test::ExpectTensorNear<double>(expected, *GetOutput(0), 1e-8);
}
@ -562,21 +645,24 @@ TEST_F(DebugNumericSummaryOpCustomLowerBoundTest, Float_full_house) {
std::numeric_limits<float>::quiet_NaN()});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_DOUBLE, TensorShape({12}));
Tensor expected(allocator(), DT_DOUBLE, TensorShape({15}));
test::FillValues<double>(
&expected,
{1.0, // Is initialized.
18.0, // Total element count.
4.0, // nan count.
3.0, // -inf count.
1.0, // negative number count (excluding -inf).
3.0, // zero count.
2.0, // positive number count (excluding +inf).
5.0, // +inf count.
-3.0, // minimum of non-inf and non-nan elements.
7.0, // maximum of non-inf and non-nan elements.
0.85714285714, // mean of non-inf and non-nan elements.
8.97959183673}); // variance of non-inf and non-nan elements.
{1.0, // Is initialized.
18.0, // Total element count.
4.0, // nan count.
3.0, // -inf count.
1.0, // negative number count (excluding -inf).
3.0, // zero count.
2.0, // positive number count (excluding +inf).
5.0, // +inf count.
-3.0, // minimum of non-inf and non-nan elements.
7.0, // maximum of non-inf and non-nan elements.
0.85714285714, // mean of non-inf and non-nan elements.
8.97959183673, // variance of non-inf and non-nan elements.
static_cast<double>(DT_FLOAT), // dtype.
1.0, // Number of dimensions.
18.0}); // Dimension sizes.
test::ExpectTensorNear<double>(expected, *GetOutput(0), 1e-8);
}
@ -600,21 +686,24 @@ TEST_F(DebugNumericSummaryOpCustomLowerUpperBoundsTest, Int32Success) {
AddInputFromArray<int32>(TensorShape({2, 3}), {0, 0, -1, 3, 3, 7});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_DOUBLE, TensorShape({12}));
Tensor expected(allocator(), DT_DOUBLE, TensorShape({16}));
test::FillValues<double>(
&expected,
{1.0, // Is initialized.
6.0, // Total element count.
0.0, // nan count.
1.0, // -inf count.
0.0, // negative count (excluding -inf).
2.0, // zero count.
2.0, // positive count (excluding +inf).
1.0, // +inf count.
-1.0, // minimum of non-inf and non-nan elements.
7.0, // maximum of non-inf and non-nan elements.
2.0, // mean of non-inf and non-nan elements.
7.33333333333}); // variance of non-inf and non-nan elements.
{1.0, // Is initialized.
6.0, // Total element count.
0.0, // nan count.
1.0, // -inf count.
0.0, // negative count (excluding -inf).
2.0, // zero count.
2.0, // positive count (excluding +inf).
1.0, // +inf count.
-1.0, // minimum of non-inf and non-nan elements.
7.0, // maximum of non-inf and non-nan elements.
2.0, // mean of non-inf and non-nan elements.
7.33333333333, // variance of non-inf and non-nan elements.
static_cast<double>(DT_INT32), // dtype.
2.0, // Number of dimensions.
2.0, 3.0}); // Dimension sizes.
test::ExpectTensorNear<double>(expected, *GetOutput(0), 1e-8);
}

View File

@ -151,7 +151,8 @@ Debug Numeric Summary Op.
Provide a basic summary of numeric value types, range and distribution.
input: Input tensor, non-Reference type, float or double.
output: A double tensor of shape [12], the elements of which are:
output: A double tensor of shape [14 + nDimensions], where nDimensions is the
the number of dimensions of the tensor's shape. The elements of output are:
[0]: is initialized (1.0) or not (0.0).
[1]: total number of elements
[2]: NaN element count
@ -173,6 +174,10 @@ Output elements [1:8] are all zero, if the tensor is uninitialized.
If uninitialized or no such element exists: NaN.
[11]: variance of all non-inf and non-NaN elements.
If uninitialized or no such element exists: NaN.
[12]: Data type of the tensor encoded as an enum integer. See the DataType
proto for more details.
[13]: Number of dimensions of the tensor (ndims).
[14+]: Sizes of the dimensions.
tensor_name: Name of the input tensor.
debug_urls: List of URLs to debug targets, e.g.,

View File

@ -1189,7 +1189,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
self.assertAllClose([[
1.0, 18.0, 4.0, 2.0, 2.0, 3.0, 2.0, 5.0, -3.0, 7.0, 0.85714286,
8.97959184
8.97959184, 1.0, 1.0, 18.0
]], dump.get_tensors("numeric_summary/a/read", 0, "DebugNumericSummary"))
def testDebugNumericSummaryOnUninitializedTensorGivesCorrectResult(self):
@ -1217,6 +1217,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
"DebugNumericSummary")[0]
self.assertAllClose([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
numeric_summary[0:8])
# Check dtype (index 12), ndims (index 13) and dimension sizes (index
# 14+).
self.assertAllClose([1.0, 1.0, 1.0], numeric_summary[12:])
self.assertTrue(np.isinf(numeric_summary[8]))
self.assertGreater(numeric_summary[8], 0.0)
self.assertTrue(np.isinf(numeric_summary[9]))
@ -1349,14 +1352,14 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
# debug ops with mute_if_healthy=false attribute during validation.
self.assertEqual(2, dump.size)
self.assertAllClose(
[[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, np.inf, -np.inf, np.nan,
np.nan]],
dump.get_tensors("x", 0, "DebugNumericSummary"))
self.assertAllClose(
[[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, np.inf, -np.inf, np.nan,
np.nan]],
dump.get_tensors("y", 0, "DebugNumericSummary"))
self.assertAllClose([[
1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, np.inf, -np.inf, np.nan,
np.nan, 1.0, 0.0
]], dump.get_tensors("x", 0, "DebugNumericSummary"))
self.assertAllClose([[
1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, np.inf, -np.inf, np.nan,
np.nan, 1.0, 0.0
]], dump.get_tensors("y", 0, "DebugNumericSummary"))
# Another run with the default mute_if_healthy (false) value should
# dump all the tensors.
@ -1402,9 +1405,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
# debug ops with mute_if_healthy=false attribute during validation.
self.assertEqual(1, dump.size)
self.assertAllClose(
[[1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 12.0, 20.0, 16.0, 16.0]],
dump.get_tensors("x", 0, "DebugNumericSummary"))
self.assertAllClose([[
1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 12.0, 20.0, 16.0, 16.0, 1.0,
1.0, 2.0]], dump.get_tensors("x", 0, "DebugNumericSummary"))
def testDebugQueueOpsDoesNotoErrorOut(self):
with session.Session() as sess:

View File

@ -208,7 +208,7 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
dump = debug_data.DebugDumpDir(dump_dirs[0])
self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))
self.assertEqual(12,
self.assertEqual(14,
len(dump.get_tensors("v", 0, "DebugNumericSummary")[0]))
def testDumpingWithWatchFnWithNonDefaultDebugOpsWorks(self):
@ -235,7 +235,7 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
dump = debug_data.DebugDumpDir(dump_dirs[0])
self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))
self.assertEqual(12,
self.assertEqual(14,
len(dump.get_tensors("v", 0, "DebugNumericSummary")[0]))
dumped_nodes = [dump.node_name for dump in dump.dumped_tensor_data]

View File

@ -26,6 +26,7 @@ import numpy as np
import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.tensorboard.backend.event_processing import directory_watcher
from tensorflow.tensorboard.backend.event_processing import event_file_loader
@ -36,8 +37,8 @@ namedtuple = collections.namedtuple
ScalarEvent = namedtuple('ScalarEvent', ['wall_time', 'step', 'value'])
HealthPillEvent = namedtuple('HealthPillEvent', [
'wall_time', 'step', 'device_name', 'node_name', 'output_slot', 'value'
])
'wall_time', 'step', 'device_name', 'node_name', 'output_slot', 'dtype',
'shape', 'value'])
CompressedHistogramEvent = namedtuple('CompressedHistogramEvent',
['wall_time', 'step',
@ -685,7 +686,7 @@ class EventAccumulator(object):
device_name: The name of the node's device.
node_name: The name of the node for this health pill.
output_slot: The output slot for this health pill.
elements: An ND array of 12 floats. The elements of the health pill.
elements: An ND array of 20 floats. The elements of the health pill.
"""
# Key by the node name for fast retrieval of health pills by node name. The
# array is cast to a list so that it is JSON-able. The debugger data plugin
@ -697,6 +698,8 @@ class EventAccumulator(object):
device_name=device_name,
node_name=node_name,
output_slot=output_slot,
dtype=repr(dtypes.as_dtype(elements[12])),
shape=list(elements[14:]),
value=list(elements)))
def _Purge(self, event, by_tags):

View File

@ -281,10 +281,14 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
"""HealthPills should be properly inserted into EventAccumulator."""
gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
health_pill_elements_1 = list(range(1, 13)) + [
float(types_pb2.DT_FLOAT), 2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0]
gen.AddHealthPill(13371337, 41, '/job:localhost/replica:0/task:0/cpu:0',
'Add', 0, range(1, 13))
'Add', 0, health_pill_elements_1)
health_pill_elements_2 = list(range(42, 54)) + [
float(types_pb2.DT_DOUBLE), 2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0]
gen.AddHealthPill(13381338, 42, '/job:localhost/replica:0/task:0/gpu:0',
'Add', 1, range(42, 54))
'Add', 1, health_pill_elements_2)
acc.Reload()
# Retrieve the health pills for each node name.
@ -297,7 +301,9 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
device_name='/job:localhost/replica:0/task:0/cpu:0',
node_name='Add',
output_slot=0,
value=range(1, 13)), gotten_events[0])
dtype='tf.float32',
shape=[1, 2],
value=health_pill_elements_1), gotten_events[0])
self._compareHealthPills(
ea.HealthPillEvent(
wall_time=13381338,
@ -305,15 +311,21 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
step=42,
node_name='Add',
output_slot=1,
value=range(42, 54)), gotten_events[1])
dtype='tf.float64',
shape=[3, 4],
value=health_pill_elements_2), gotten_events[1])
def testGetOpsWithHealthPills(self):
gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen)
health_pill_elements_1 = list(range(1, 13)) + [
float(types_pb2.DT_FLOAT), 2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0]
gen.AddHealthPill(13371337, 41, '/job:localhost/replica:0/task:0/cpu:0',
'Add', 0, range(1, 13))
'Add', 0, health_pill_elements_1)
health_pill_elements_2 = list(range(42, 54)) + [
float(types_pb2.DT_DOUBLE), 2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0]
gen.AddHealthPill(13381338, 42, '/job:localhost/replica:0/task:0/cpu:0',
'MatMul', 1, range(42, 54))
'MatMul', 1, health_pill_elements_2)
acc.Reload()
self.assertItemsEqual(['Add', 'MatMul'], acc.GetOpsWithHealthPills())

View File

@ -87,6 +87,8 @@ export interface HealthPill {
device_name: string;
node_name: string;
output_slot: number;
dtype: string;
shape: number[];
value: number[];
}

View File

@ -83,6 +83,8 @@ module tf.graph.scene {
device_name: string;
node_name: string;
output_slot: number;
dtype: string;
shape: number[];
value: number[];
wall_time: number;
step: number;
@ -541,6 +543,12 @@ function _addHealthPill(
}
const deviceName = healthPill.device_name;
const dtypeName = healthPill.dtype;
let shapeStr = '(';
for (const dimSize of healthPill.shape) {
shapeStr += dimSize + ',';
}
shapeStr += ')';
let lastHealthPillData = healthPill.value;
// For now, we only visualize the 6 values that summarize counts of tensor
@ -610,11 +618,12 @@ function _addHealthPill(
// Show a title with specific counts on hover.
let titleSvg = document.createElementNS(svgNamespace, 'title');
titleSvg.textContent = 'Device: ' + deviceName + '\n\n' +
'#(elements): ' + totalCount + '\n\n' +
titleSvg.textContent = 'Device: ' + deviceName + '\ndtype: ' + dtypeName +
'\nshape: ' + shapeStr + '\n\n#(elements): ' + totalCount + '\n' +
titleOnHoverTextEntries.join(', ') + '\n\nmin: ' + minVal +
', max: ' + maxVal + '\nmean: ' + meanVal + ', stddev: ' + stddevVal;
healthPillGroup.appendChild(titleSvg);
// TODO(cais): Make the tooltip content prettier.
// Center this health pill just right above the node for the op.
let healthPillX = nodeInfo.x - healthPillWidth / 2;