Merge pull request #33257 from nouiz:small2
PiperOrigin-RevId: 274526704
This commit is contained in:
commit
7557b548ca
tensorflow/compiler/xla
@ -120,6 +120,21 @@ class Array4D : public Array<T> {
|
||||
}
|
||||
}
|
||||
|
||||
// Fills all of the {p,x} with the array provided, which specifies {z,y}.
|
||||
void FillWithZY(const Array2D<T>& value) {
|
||||
CHECK_EQ(value.height(), depth());
|
||||
CHECK_EQ(value.width(), height());
|
||||
for (int64 plane = 0; plane < planes(); ++plane) {
|
||||
for (int64 depth = 0; depth < this->depth(); ++depth) {
|
||||
for (int64 height = 0; height < this->height(); ++height) {
|
||||
for (int64 width = 0; width < this->width(); ++width) {
|
||||
(*this)(plane, depth, height, width) = value(depth, height);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fills all of the {x,y} with the array provided, which specifies {p,z}.
|
||||
void FillWithPZ(const Array2D<T>& value) {
|
||||
CHECK_EQ(value.height(), planes());
|
||||
|
@ -40,7 +40,7 @@ class HloPassFix : public Pass {
|
||||
int64 iteration_count = 0;
|
||||
int64 limit =
|
||||
std::max(static_cast<int64>(1000), module->instruction_count());
|
||||
VLOG(3) << "Running HloPassFix.";
|
||||
VLOG(3) << "Running HloPassFix on " << Pass::name();
|
||||
while (changed_this_iteration) {
|
||||
TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module));
|
||||
changed |= changed_this_iteration;
|
||||
|
@ -619,6 +619,9 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
|
||||
<< consumer->ToString();
|
||||
HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer);
|
||||
fusion_instruction->FuseInstruction(producer);
|
||||
if (fusion_instruction != producer && fusion_instruction != consumer) {
|
||||
VLOG(2) << " created new fusion: " << fusion_instruction->ToString();
|
||||
}
|
||||
return fusion_instruction;
|
||||
}
|
||||
|
||||
|
@ -928,8 +928,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
VLOG(2) << StrFormat(
|
||||
"inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
|
||||
HloOpcodeString(opcode), ShapeUtil::HumanString(lhs),
|
||||
ShapeUtil::HumanString(rhs), StrJoin(broadcast_dimensions, ", "));
|
||||
HloOpcodeString(opcode), ShapeUtil::HumanStringWithLayout(lhs),
|
||||
ShapeUtil::HumanStringWithLayout(rhs),
|
||||
StrJoin(broadcast_dimensions, ", "));
|
||||
|
||||
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
|
||||
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
|
||||
|
||||
|
@ -38,6 +38,20 @@ Window MakeWindow(absl::Span<const int64> sizes) {
|
||||
return window;
|
||||
}
|
||||
|
||||
Window MakeWindow(absl::Span<const int64> sizes,
|
||||
absl::Span<const int64> strides) {
|
||||
Window window;
|
||||
CHECK_EQ(sizes.size(), strides.size());
|
||||
for (auto nb = 0; nb < sizes.size(); ++nb) {
|
||||
auto* dimension = window.add_dimensions();
|
||||
dimension->set_size(sizes[nb]);
|
||||
dimension->set_stride(strides[nb]);
|
||||
dimension->set_base_dilation(1);
|
||||
dimension->set_window_dilation(1);
|
||||
}
|
||||
return window;
|
||||
}
|
||||
|
||||
PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes) {
|
||||
PaddingConfig config;
|
||||
for (int64 size : sizes) {
|
||||
|
@ -27,6 +27,10 @@ namespace window_util {
|
||||
// to 1.
|
||||
Window MakeWindow(absl::Span<const int64> sizes);
|
||||
|
||||
// Creates a window with the given sizes in the dimensions and given strides.
|
||||
Window MakeWindow(absl::Span<const int64> sizes,
|
||||
absl::Span<const int64> strides);
|
||||
|
||||
// Creates a padding config with symmetrical padding in each dimension, of value
|
||||
// given by sizes; e.g. {0, 1, 2} would create a R3 padding config that had zero
|
||||
// pixels of padding in dimension 0, one pixel of padding symmetrically, on each
|
||||
|
@ -30,5 +30,14 @@ TEST(WindowUtilTest, HasOverlappingWindowTest) {
|
||||
window_util::HasOverlappingWindow(window_util::MakeWindow({2, 2, 2, 2})));
|
||||
}
|
||||
|
||||
TEST(WindowUtilTest, MakeWindowStrideTest) {
|
||||
// MakeWindow() set a stride of 1 by default.
|
||||
Window w = window_util::MakeWindow({1, 2}, {3, 4});
|
||||
EXPECT_EQ(w.dimensions()[0].size(), 1);
|
||||
EXPECT_EQ(w.dimensions()[1].size(), 2);
|
||||
EXPECT_EQ(w.dimensions()[0].stride(), 3);
|
||||
EXPECT_EQ(w.dimensions()[1].stride(), 4);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user