- Bug fix in ShapeInference
- CommaSeparatedString and VectorString added to xla_util.h - ReferenceUtil can now do more general Pad ops. Change: 153782516
This commit is contained in:
parent
7dba5ab874
commit
c1bd0fe248
@ -649,4 +649,39 @@ ReferenceUtil::ReduceToRowArray2D(
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ Array4D<float> ReferenceUtil::PadArray4D(
|
||||||
|
const Array4D<float>& operand, const PaddingConfig& padding,
|
||||||
|
const float pad) {
|
||||||
|
CHECK_EQ(padding.dimensions_size(), 4);
|
||||||
|
|
||||||
|
const std::vector<int64> input_bounds = {operand.n1(), operand.n2(),
|
||||||
|
operand.n3(), operand.n4()};
|
||||||
|
std::vector<int64> pad_low(4);
|
||||||
|
std::vector<int64> pad_high(4);
|
||||||
|
std::vector<int64> output_bounds(4);
|
||||||
|
for (int64 i = 0; i < 4; ++i) {
|
||||||
|
pad_low[i] = padding.dimensions(i).edge_padding_low();
|
||||||
|
pad_high[i] = padding.dimensions(i).edge_padding_high();
|
||||||
|
CHECK_EQ(padding.dimensions(i).interior_padding(), 0) << "not implemented";
|
||||||
|
|
||||||
|
output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
Array4D<float> result(output_bounds[0], output_bounds[1], output_bounds[2],
|
||||||
|
output_bounds[3]);
|
||||||
|
result.Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
bool in_low_padding = indices[i] < pad_low[i];
|
||||||
|
bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
|
||||||
|
if (in_low_padding || in_high_padding) {
|
||||||
|
*value = pad;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*value = operand(indices[0] - pad_low[0], indices[1] - pad_low[1],
|
||||||
|
indices[2] - pad_low[2], indices[3] - pad_low[3]);
|
||||||
|
});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -395,6 +395,11 @@ class ReferenceUtil {
|
|||||||
const Array2D<float>& operand, const PaddingConfig& padding,
|
const Array2D<float>& operand, const PaddingConfig& padding,
|
||||||
const float pad);
|
const float pad);
|
||||||
|
|
||||||
|
// Returns the result of a 4D pad on an input array.
|
||||||
|
static Array4D<float> PadArray4D(const Array4D<float>& operand,
|
||||||
|
const PaddingConfig& padding,
|
||||||
|
const float pad);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil);
|
TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil);
|
||||||
};
|
};
|
||||||
|
@ -309,6 +309,10 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
|||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"the rank of the operand and the padding configuration do not match.");
|
"the rank of the operand and the padding configuration do not match.");
|
||||||
}
|
}
|
||||||
|
if (operand_shape.element_type() != padding_value_shape.element_type()) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"the element types of the operands to pad do not match");
|
||||||
|
}
|
||||||
std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape));
|
std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape));
|
||||||
for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
|
for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
|
||||||
dimensions[i] = operand_shape.dimensions(i) +
|
dimensions[i] = operand_shape.dimensions(i) +
|
||||||
@ -338,7 +342,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
|||||||
|
|
||||||
// Check if both element types are the same.
|
// Check if both element types are the same.
|
||||||
if (lhs.element_type() != rhs.element_type()) {
|
if (lhs.element_type() != rhs.element_type()) {
|
||||||
return fail("element types mismatch");
|
return fail("element types do not match");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 ||
|
if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 ||
|
||||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||||
#include "tensorflow/core/lib/math/math_util.h"
|
#include "tensorflow/core/lib/math/math_util.h"
|
||||||
#include "tensorflow/core/lib/strings/numbers.h"
|
#include "tensorflow/core/lib/strings/numbers.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
@ -200,6 +201,46 @@ int64 PositionInContainer(const Container& container, int64 value) {
|
|||||||
std::find(container.begin(), container.end(), value));
|
std::find(container.begin(), container.end(), value));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Formats the container as a comma-separated string. StrAppend must support
|
||||||
|
// appending the elements of the container. Prefix is prepended and suffix is
|
||||||
|
// appended to the returned string.
|
||||||
|
template <typename Container>
|
||||||
|
string CommaSeparatedString(const Container& c, const char* prefix = "",
|
||||||
|
const char* suffix = "") {
|
||||||
|
// Not using Join() since the implementation here is simple anyway and this
|
||||||
|
// avoids copying the string to append prefix.
|
||||||
|
string comma_separated = prefix;
|
||||||
|
const char* separator = "";
|
||||||
|
for (const auto& entry : c) {
|
||||||
|
tensorflow::strings::StrAppend(&comma_separated, separator, entry);
|
||||||
|
separator = ", ";
|
||||||
|
}
|
||||||
|
comma_separated += suffix;
|
||||||
|
return comma_separated;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overload needed to allow the container to be an initializer list. The default
|
||||||
|
// type for T makes an empty initializer list work as well.
|
||||||
|
template <typename T = int>
|
||||||
|
string CommaSeparatedString(const std::initializer_list<T>& c,
|
||||||
|
const char* prefix = "", const char* suffix = "") {
|
||||||
|
return CommaSeparatedString<std::initializer_list<T>>(c, prefix, suffix);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Formats the container in the mathematical notation for a vector, e.g. (1, 3,
|
||||||
|
// 7). StrAppend must support appending the elements of c.
|
||||||
|
template <typename Container>
|
||||||
|
string VectorString(const Container& c) {
|
||||||
|
return CommaSeparatedString(c, "(", ")");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overload needed to allow the container to be an initializer list. The default
|
||||||
|
// type for T makes an empty initializer list work as well.
|
||||||
|
template <typename T = int>
|
||||||
|
string VectorString(const std::initializer_list<T>& c) {
|
||||||
|
return VectorString<std::initializer_list<T>>(c);
|
||||||
|
}
|
||||||
|
|
||||||
// Returns a PaddingConfig object that represents no padding for the given rank.
|
// Returns a PaddingConfig object that represents no padding for the given rank.
|
||||||
PaddingConfig MakeNoPaddingConfig(int64 rank);
|
PaddingConfig MakeNoPaddingConfig(int64 rank);
|
||||||
|
|
||||||
|
@ -80,6 +80,26 @@ TEST(UtilTest, HumanReadableNumFlopsExample) {
|
|||||||
ASSERT_EQ("1.00GFLOP/s", HumanReadableNumFlops(1e9, 1e9));
|
ASSERT_EQ("1.00GFLOP/s", HumanReadableNumFlops(1e9, 1e9));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(UtilTest, CommaSeparatedString) {
|
||||||
|
EXPECT_EQ(CommaSeparatedString({}), "");
|
||||||
|
EXPECT_EQ(CommaSeparatedString({"hello world"}), "hello world");
|
||||||
|
EXPECT_EQ(CommaSeparatedString({1, 57, 2}, "foo", "bar"), "foo1, 57, 2bar");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UtilTest, VectorString) {
|
||||||
|
std::list<int64> empty_list;
|
||||||
|
EXPECT_EQ(VectorString(empty_list), "()");
|
||||||
|
|
||||||
|
std::vector<float> float_vector = {5.5};
|
||||||
|
EXPECT_EQ(VectorString(float_vector), "(5.5)");
|
||||||
|
|
||||||
|
std::set<const char*> string_set = {"a", "b"};
|
||||||
|
EXPECT_EQ(VectorString(string_set), "(a, b)");
|
||||||
|
|
||||||
|
EXPECT_EQ(VectorString({}), "()");
|
||||||
|
EXPECT_EQ(VectorString({1, 57, 2}), "(1, 57, 2)");
|
||||||
|
}
|
||||||
|
|
||||||
TEST(UtilTest, LogLines) {
|
TEST(UtilTest, LogLines) {
|
||||||
// Just make sure this code runs (not verifying the output).
|
// Just make sure this code runs (not verifying the output).
|
||||||
LogLines(tensorflow::INFO, "hello\n\nworld", __FILE__, __LINE__);
|
LogLines(tensorflow::INFO, "hello\n\nworld", __FILE__, __LINE__);
|
||||||
|
Loading…
Reference in New Issue
Block a user