Merge pull request #33257 from nouiz:small2
PiperOrigin-RevId: 274526704
This commit is contained in:
		
						commit
						7557b548ca
					
				| @ -120,6 +120,21 @@ class Array4D : public Array<T> { | |||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   // Fills all of the {p,x} with the array provided, which specifies {z,y}.
 | ||||||
|  |   void FillWithZY(const Array2D<T>& value) { | ||||||
|  |     CHECK_EQ(value.height(), depth()); | ||||||
|  |     CHECK_EQ(value.width(), height()); | ||||||
|  |     for (int64 plane = 0; plane < planes(); ++plane) { | ||||||
|  |       for (int64 depth = 0; depth < this->depth(); ++depth) { | ||||||
|  |         for (int64 height = 0; height < this->height(); ++height) { | ||||||
|  |           for (int64 width = 0; width < this->width(); ++width) { | ||||||
|  |             (*this)(plane, depth, height, width) = value(depth, height); | ||||||
|  |           } | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   // Fills all of the {x,y} with the array provided, which specifies {p,z}.
 |   // Fills all of the {x,y} with the array provided, which specifies {p,z}.
 | ||||||
|   void FillWithPZ(const Array2D<T>& value) { |   void FillWithPZ(const Array2D<T>& value) { | ||||||
|     CHECK_EQ(value.height(), planes()); |     CHECK_EQ(value.height(), planes()); | ||||||
|  | |||||||
| @ -40,7 +40,7 @@ class HloPassFix : public Pass { | |||||||
|     int64 iteration_count = 0; |     int64 iteration_count = 0; | ||||||
|     int64 limit = |     int64 limit = | ||||||
|         std::max(static_cast<int64>(1000), module->instruction_count()); |         std::max(static_cast<int64>(1000), module->instruction_count()); | ||||||
|     VLOG(3) << "Running HloPassFix."; |     VLOG(3) << "Running HloPassFix on " << Pass::name(); | ||||||
|     while (changed_this_iteration) { |     while (changed_this_iteration) { | ||||||
|       TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module)); |       TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module)); | ||||||
|       changed |= changed_this_iteration; |       changed |= changed_this_iteration; | ||||||
|  | |||||||
| @ -619,6 +619,9 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, | |||||||
|           << consumer->ToString(); |           << consumer->ToString(); | ||||||
|   HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer); |   HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer); | ||||||
|   fusion_instruction->FuseInstruction(producer); |   fusion_instruction->FuseInstruction(producer); | ||||||
|  |   if (fusion_instruction != producer && fusion_instruction != consumer) { | ||||||
|  |     VLOG(2) << "       created new fusion: " << fusion_instruction->ToString(); | ||||||
|  |   } | ||||||
|   return fusion_instruction; |   return fusion_instruction; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -928,8 +928,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, | |||||||
|     absl::Span<const int64> broadcast_dimensions) { |     absl::Span<const int64> broadcast_dimensions) { | ||||||
|   VLOG(2) << StrFormat( |   VLOG(2) << StrFormat( | ||||||
|       "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", |       "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", | ||||||
|       HloOpcodeString(opcode), ShapeUtil::HumanString(lhs), |       HloOpcodeString(opcode), ShapeUtil::HumanStringWithLayout(lhs), | ||||||
|       ShapeUtil::HumanString(rhs), StrJoin(broadcast_dimensions, ", ")); |       ShapeUtil::HumanStringWithLayout(rhs), | ||||||
|  |       StrJoin(broadcast_dimensions, ", ")); | ||||||
|  | 
 | ||||||
|   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); |   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); | ||||||
|   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); |   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -38,6 +38,20 @@ Window MakeWindow(absl::Span<const int64> sizes) { | |||||||
|   return window; |   return window; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | Window MakeWindow(absl::Span<const int64> sizes, | ||||||
|  |                   absl::Span<const int64> strides) { | ||||||
|  |   Window window; | ||||||
|  |   CHECK_EQ(sizes.size(), strides.size()); | ||||||
|  |   for (auto nb = 0; nb < sizes.size(); ++nb) { | ||||||
|  |     auto* dimension = window.add_dimensions(); | ||||||
|  |     dimension->set_size(sizes[nb]); | ||||||
|  |     dimension->set_stride(strides[nb]); | ||||||
|  |     dimension->set_base_dilation(1); | ||||||
|  |     dimension->set_window_dilation(1); | ||||||
|  |   } | ||||||
|  |   return window; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes) { | PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes) { | ||||||
|   PaddingConfig config; |   PaddingConfig config; | ||||||
|   for (int64 size : sizes) { |   for (int64 size : sizes) { | ||||||
|  | |||||||
| @ -27,6 +27,10 @@ namespace window_util { | |||||||
| // to 1.
 | // to 1.
 | ||||||
| Window MakeWindow(absl::Span<const int64> sizes); | Window MakeWindow(absl::Span<const int64> sizes); | ||||||
| 
 | 
 | ||||||
|  | // Creates a window with the given sizes in the dimensions and given strides.
 | ||||||
|  | Window MakeWindow(absl::Span<const int64> sizes, | ||||||
|  |                   absl::Span<const int64> strides); | ||||||
|  | 
 | ||||||
| // Creates a padding config with symmetrical padding in each dimension, of value
 | // Creates a padding config with symmetrical padding in each dimension, of value
 | ||||||
| // given by sizes; e.g. {0, 1, 2} would create a R3 padding config that had zero
 | // given by sizes; e.g. {0, 1, 2} would create a R3 padding config that had zero
 | ||||||
| // pixels of padding in dimension 0, one pixel of padding symmetrically, on each
 | // pixels of padding in dimension 0, one pixel of padding symmetrically, on each
 | ||||||
|  | |||||||
| @ -30,5 +30,14 @@ TEST(WindowUtilTest, HasOverlappingWindowTest) { | |||||||
|       window_util::HasOverlappingWindow(window_util::MakeWindow({2, 2, 2, 2}))); |       window_util::HasOverlappingWindow(window_util::MakeWindow({2, 2, 2, 2}))); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST(WindowUtilTest, MakeWindowStrideTest) { | ||||||
|  |   // MakeWindow() set a stride of 1 by default.
 | ||||||
|  |   Window w = window_util::MakeWindow({1, 2}, {3, 4}); | ||||||
|  |   EXPECT_EQ(w.dimensions()[0].size(), 1); | ||||||
|  |   EXPECT_EQ(w.dimensions()[1].size(), 2); | ||||||
|  |   EXPECT_EQ(w.dimensions()[0].stride(), 3); | ||||||
|  |   EXPECT_EQ(w.dimensions()[1].stride(), 4); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| }  // namespace xla
 | }  // namespace xla
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user