Merge pull request #33257 from nouiz:small2
PiperOrigin-RevId: 274526704
This commit is contained in:
		
						commit
						7557b548ca
					
				@ -120,6 +120,21 @@ class Array4D : public Array<T> {
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Fills all of the {p,x} with the array provided, which specifies {z,y}.
 | 
			
		||||
  void FillWithZY(const Array2D<T>& value) {
 | 
			
		||||
    CHECK_EQ(value.height(), depth());
 | 
			
		||||
    CHECK_EQ(value.width(), height());
 | 
			
		||||
    for (int64 plane = 0; plane < planes(); ++plane) {
 | 
			
		||||
      for (int64 depth = 0; depth < this->depth(); ++depth) {
 | 
			
		||||
        for (int64 height = 0; height < this->height(); ++height) {
 | 
			
		||||
          for (int64 width = 0; width < this->width(); ++width) {
 | 
			
		||||
            (*this)(plane, depth, height, width) = value(depth, height);
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Fills all of the {x,y} with the array provided, which specifies {p,z}.
 | 
			
		||||
  void FillWithPZ(const Array2D<T>& value) {
 | 
			
		||||
    CHECK_EQ(value.height(), planes());
 | 
			
		||||
 | 
			
		||||
@ -40,7 +40,7 @@ class HloPassFix : public Pass {
 | 
			
		||||
    int64 iteration_count = 0;
 | 
			
		||||
    int64 limit =
 | 
			
		||||
        std::max(static_cast<int64>(1000), module->instruction_count());
 | 
			
		||||
    VLOG(3) << "Running HloPassFix.";
 | 
			
		||||
    VLOG(3) << "Running HloPassFix on " << Pass::name();
 | 
			
		||||
    while (changed_this_iteration) {
 | 
			
		||||
      TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module));
 | 
			
		||||
      changed |= changed_this_iteration;
 | 
			
		||||
 | 
			
		||||
@ -619,6 +619,9 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
 | 
			
		||||
          << consumer->ToString();
 | 
			
		||||
  HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer);
 | 
			
		||||
  fusion_instruction->FuseInstruction(producer);
 | 
			
		||||
  if (fusion_instruction != producer && fusion_instruction != consumer) {
 | 
			
		||||
    VLOG(2) << "       created new fusion: " << fusion_instruction->ToString();
 | 
			
		||||
  }
 | 
			
		||||
  return fusion_instruction;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -928,8 +928,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 | 
			
		||||
    absl::Span<const int64> broadcast_dimensions) {
 | 
			
		||||
  VLOG(2) << StrFormat(
 | 
			
		||||
      "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
 | 
			
		||||
      HloOpcodeString(opcode), ShapeUtil::HumanString(lhs),
 | 
			
		||||
      ShapeUtil::HumanString(rhs), StrJoin(broadcast_dimensions, ", "));
 | 
			
		||||
      HloOpcodeString(opcode), ShapeUtil::HumanStringWithLayout(lhs),
 | 
			
		||||
      ShapeUtil::HumanStringWithLayout(rhs),
 | 
			
		||||
      StrJoin(broadcast_dimensions, ", "));
 | 
			
		||||
 | 
			
		||||
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
 | 
			
		||||
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -38,6 +38,20 @@ Window MakeWindow(absl::Span<const int64> sizes) {
 | 
			
		||||
  return window;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Window MakeWindow(absl::Span<const int64> sizes,
 | 
			
		||||
                  absl::Span<const int64> strides) {
 | 
			
		||||
  Window window;
 | 
			
		||||
  CHECK_EQ(sizes.size(), strides.size());
 | 
			
		||||
  for (auto nb = 0; nb < sizes.size(); ++nb) {
 | 
			
		||||
    auto* dimension = window.add_dimensions();
 | 
			
		||||
    dimension->set_size(sizes[nb]);
 | 
			
		||||
    dimension->set_stride(strides[nb]);
 | 
			
		||||
    dimension->set_base_dilation(1);
 | 
			
		||||
    dimension->set_window_dilation(1);
 | 
			
		||||
  }
 | 
			
		||||
  return window;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes) {
 | 
			
		||||
  PaddingConfig config;
 | 
			
		||||
  for (int64 size : sizes) {
 | 
			
		||||
 | 
			
		||||
@ -27,6 +27,10 @@ namespace window_util {
 | 
			
		||||
// to 1.
 | 
			
		||||
Window MakeWindow(absl::Span<const int64> sizes);
 | 
			
		||||
 | 
			
		||||
// Creates a window with the given sizes in the dimensions and given strides.
 | 
			
		||||
Window MakeWindow(absl::Span<const int64> sizes,
 | 
			
		||||
                  absl::Span<const int64> strides);
 | 
			
		||||
 | 
			
		||||
// Creates a padding config with symmetrical padding in each dimension, of value
 | 
			
		||||
// given by sizes; e.g. {0, 1, 2} would create a R3 padding config that had zero
 | 
			
		||||
// pixels of padding in dimension 0, one pixel of padding symmetrically, on each
 | 
			
		||||
 | 
			
		||||
@ -30,5 +30,14 @@ TEST(WindowUtilTest, HasOverlappingWindowTest) {
 | 
			
		||||
      window_util::HasOverlappingWindow(window_util::MakeWindow({2, 2, 2, 2})));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(WindowUtilTest, MakeWindowStrideTest) {
 | 
			
		||||
  // MakeWindow() set a stride of 1 by default.
 | 
			
		||||
  Window w = window_util::MakeWindow({1, 2}, {3, 4});
 | 
			
		||||
  EXPECT_EQ(w.dimensions()[0].size(), 1);
 | 
			
		||||
  EXPECT_EQ(w.dimensions()[1].size(), 2);
 | 
			
		||||
  EXPECT_EQ(w.dimensions()[0].stride(), 3);
 | 
			
		||||
  EXPECT_EQ(w.dimensions()[1].stride(), 4);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace xla
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user