From 60adca7bce76c76cab2f872426cecf4b63a64526 Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Sun, 7 May 2017 14:18:47 -0800 Subject: [PATCH] [XLA] Make LiteralUtil::EachCellAsString handle all ranks LiteralUtil::EachCellAsString was hard coded to handle only up to rank 4. This simply reworks it to handle all ranks. Change: 155331612 --- tensorflow/compiler/xla/literal_util.cc | 45 ++++--------------------- 1 file changed, 6 insertions(+), 39 deletions(-) diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 4ec695b2f59..3b3e004864a 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -736,47 +736,14 @@ template const Literal& literal, const std::function indices, const string& value)>& per_cell) { - if (ShapeUtil::Rank(literal.shape()) == 1) { - for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { - per_cell({i0}, GetAsString(literal, {i0})); - } + if (ShapeUtil::HasZeroElements(literal.shape())) { return; } - - if (ShapeUtil::Rank(literal.shape()) == 2) { - for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { - for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) { - per_cell({i0, i1}, GetAsString(literal, {i0, i1})); - } - } - return; - } - - if (ShapeUtil::Rank(literal.shape()) == 3) { - for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { - for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) { - for (int64 i2 = 0; i2 < literal.shape().dimensions(2); ++i2) { - per_cell({i0, i1, i2}, GetAsString(literal, {i0, i1, i2})); - } - } - } - return; - } - - if (ShapeUtil::Rank(literal.shape()) == 4) { - for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) { - for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) { - for (int64 i2 = 0; i2 < literal.shape().dimensions(2); ++i2) { - for (int64 i3 = 0; i3 < literal.shape().dimensions(3); ++i3) { - per_cell({i0, i1, i2, i3}, GetAsString(literal, {i0, i1, i2, i3})); - } - } - } - } - return; - } - - LOG(FATAL) << "unhandled rank: " << ShapeUtil::Rank(literal.shape()); + std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( + literal.shape(), /*linear_index=*/0); + do { + per_cell(indices, GetAsString(literal, indices)); + } while (IndexUtil::BumpIndices(literal.shape(), &indices)); } namespace {