Merge pull request #33257 from nouiz:small2
PiperOrigin-RevId: 274526704
This commit is contained in:
commit
7557b548ca
@ -120,6 +120,21 @@ class Array4D : public Array<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fills all of the {p,x} with the array provided, which specifies {z,y}.
|
||||||
|
void FillWithZY(const Array2D<T>& value) {
|
||||||
|
CHECK_EQ(value.height(), depth());
|
||||||
|
CHECK_EQ(value.width(), height());
|
||||||
|
for (int64 plane = 0; plane < planes(); ++plane) {
|
||||||
|
for (int64 depth = 0; depth < this->depth(); ++depth) {
|
||||||
|
for (int64 height = 0; height < this->height(); ++height) {
|
||||||
|
for (int64 width = 0; width < this->width(); ++width) {
|
||||||
|
(*this)(plane, depth, height, width) = value(depth, height);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Fills all of the {x,y} with the array provided, which specifies {p,z}.
|
// Fills all of the {x,y} with the array provided, which specifies {p,z}.
|
||||||
void FillWithPZ(const Array2D<T>& value) {
|
void FillWithPZ(const Array2D<T>& value) {
|
||||||
CHECK_EQ(value.height(), planes());
|
CHECK_EQ(value.height(), planes());
|
||||||
|
|||||||
@ -40,7 +40,7 @@ class HloPassFix : public Pass {
|
|||||||
int64 iteration_count = 0;
|
int64 iteration_count = 0;
|
||||||
int64 limit =
|
int64 limit =
|
||||||
std::max(static_cast<int64>(1000), module->instruction_count());
|
std::max(static_cast<int64>(1000), module->instruction_count());
|
||||||
VLOG(3) << "Running HloPassFix.";
|
VLOG(3) << "Running HloPassFix on " << Pass::name();
|
||||||
while (changed_this_iteration) {
|
while (changed_this_iteration) {
|
||||||
TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module));
|
TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module));
|
||||||
changed |= changed_this_iteration;
|
changed |= changed_this_iteration;
|
||||||
|
|||||||
@ -619,6 +619,9 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
|
|||||||
<< consumer->ToString();
|
<< consumer->ToString();
|
||||||
HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer);
|
HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer);
|
||||||
fusion_instruction->FuseInstruction(producer);
|
fusion_instruction->FuseInstruction(producer);
|
||||||
|
if (fusion_instruction != producer && fusion_instruction != consumer) {
|
||||||
|
VLOG(2) << " created new fusion: " << fusion_instruction->ToString();
|
||||||
|
}
|
||||||
return fusion_instruction;
|
return fusion_instruction;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -928,8 +928,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
|||||||
absl::Span<const int64> broadcast_dimensions) {
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
VLOG(2) << StrFormat(
|
VLOG(2) << StrFormat(
|
||||||
"inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
|
"inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
|
||||||
HloOpcodeString(opcode), ShapeUtil::HumanString(lhs),
|
HloOpcodeString(opcode), ShapeUtil::HumanStringWithLayout(lhs),
|
||||||
ShapeUtil::HumanString(rhs), StrJoin(broadcast_dimensions, ", "));
|
ShapeUtil::HumanStringWithLayout(rhs),
|
||||||
|
StrJoin(broadcast_dimensions, ", "));
|
||||||
|
|
||||||
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
|
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
|
||||||
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
|
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
|
||||||
|
|
||||||
|
|||||||
@ -38,6 +38,20 @@ Window MakeWindow(absl::Span<const int64> sizes) {
|
|||||||
return window;
|
return window;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Window MakeWindow(absl::Span<const int64> sizes,
|
||||||
|
absl::Span<const int64> strides) {
|
||||||
|
Window window;
|
||||||
|
CHECK_EQ(sizes.size(), strides.size());
|
||||||
|
for (auto nb = 0; nb < sizes.size(); ++nb) {
|
||||||
|
auto* dimension = window.add_dimensions();
|
||||||
|
dimension->set_size(sizes[nb]);
|
||||||
|
dimension->set_stride(strides[nb]);
|
||||||
|
dimension->set_base_dilation(1);
|
||||||
|
dimension->set_window_dilation(1);
|
||||||
|
}
|
||||||
|
return window;
|
||||||
|
}
|
||||||
|
|
||||||
PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes) {
|
PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes) {
|
||||||
PaddingConfig config;
|
PaddingConfig config;
|
||||||
for (int64 size : sizes) {
|
for (int64 size : sizes) {
|
||||||
|
|||||||
@ -27,6 +27,10 @@ namespace window_util {
|
|||||||
// to 1.
|
// to 1.
|
||||||
Window MakeWindow(absl::Span<const int64> sizes);
|
Window MakeWindow(absl::Span<const int64> sizes);
|
||||||
|
|
||||||
|
// Creates a window with the given sizes in the dimensions and given strides.
|
||||||
|
Window MakeWindow(absl::Span<const int64> sizes,
|
||||||
|
absl::Span<const int64> strides);
|
||||||
|
|
||||||
// Creates a padding config with symmetrical padding in each dimension, of value
|
// Creates a padding config with symmetrical padding in each dimension, of value
|
||||||
// given by sizes; e.g. {0, 1, 2} would create a R3 padding config that had zero
|
// given by sizes; e.g. {0, 1, 2} would create a R3 padding config that had zero
|
||||||
// pixels of padding in dimension 0, one pixel of padding symmetrically, on each
|
// pixels of padding in dimension 0, one pixel of padding symmetrically, on each
|
||||||
|
|||||||
@ -30,5 +30,14 @@ TEST(WindowUtilTest, HasOverlappingWindowTest) {
|
|||||||
window_util::HasOverlappingWindow(window_util::MakeWindow({2, 2, 2, 2})));
|
window_util::HasOverlappingWindow(window_util::MakeWindow({2, 2, 2, 2})));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(WindowUtilTest, MakeWindowStrideTest) {
|
||||||
|
// MakeWindow() set a stride of 1 by default.
|
||||||
|
Window w = window_util::MakeWindow({1, 2}, {3, 4});
|
||||||
|
EXPECT_EQ(w.dimensions()[0].size(), 1);
|
||||||
|
EXPECT_EQ(w.dimensions()[1].size(), 2);
|
||||||
|
EXPECT_EQ(w.dimensions()[0].stride(), 3);
|
||||||
|
EXPECT_EQ(w.dimensions()[1].stride(), 4);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user