Add utiliry fct window_util::MakeWindow(sizes, strides)

This commit is contained in:
Frederic Bastien 2019-10-08 12:00:28 -07:00
parent 539de53a75
commit 392d6c2da3
3 changed files with 25 additions and 0 deletions

View File

@ -38,6 +38,19 @@ Window MakeWindow(absl::Span<const int64> sizes) {
return window; return window;
} }
Window MakeWindow(absl::Span<const int64> sizes, absl::Span<const int64> strides) {
Window window;
CHECK_EQ(sizes.size(), strides.size());
for (auto nb=0; nb < sizes.size(); ++nb) {
auto* dimension = window.add_dimensions();
dimension->set_size(sizes[nb]);
dimension->set_stride(strides[nb]);
dimension->set_base_dilation(1);
dimension->set_window_dilation(1);
}
return window;
}
PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes) { PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes) {
PaddingConfig config; PaddingConfig config;
for (int64 size : sizes) { for (int64 size : sizes) {

View File

@ -27,6 +27,9 @@ namespace window_util {
// to 1. // to 1.
Window MakeWindow(absl::Span<const int64> sizes); Window MakeWindow(absl::Span<const int64> sizes);
// Creates a window with the given sizes in the dimensions and given strides.
Window MakeWindow(absl::Span<const int64> sizes, absl::Span<const int64> strides);
// Creates a padding config with symmetrical padding in each dimension, of value // Creates a padding config with symmetrical padding in each dimension, of value
// given by sizes; e.g. {0, 1, 2} would create a R3 padding config that had zero // given by sizes; e.g. {0, 1, 2} would create a R3 padding config that had zero
// pixels of padding in dimension 0, one pixel of padding symmetrically, on each // pixels of padding in dimension 0, one pixel of padding symmetrically, on each

View File

@ -30,5 +30,14 @@ TEST(WindowUtilTest, HasOverlappingWindowTest) {
window_util::HasOverlappingWindow(window_util::MakeWindow({2, 2, 2, 2}))); window_util::HasOverlappingWindow(window_util::MakeWindow({2, 2, 2, 2})));
} }
TEST(WindowUtilTest, MakeWindowStrideTest) {
// MakeWindow() set a stride of 1 by default.
Window w = window_util::MakeWindow({1, 2}, {3, 4});
EXPECT_EQ(w.dimensions()[0].size(), 1);
EXPECT_EQ(w.dimensions()[1].size(), 2);
EXPECT_EQ(w.dimensions()[0].stride(), 3);
EXPECT_EQ(w.dimensions()[1].stride(), 4);
}
} // namespace } // namespace
} // namespace xla } // namespace xla