Bug fix; op output validation for aggregate stats and best feature split
PiperOrigin-RevId: 266181494
This commit is contained in:
		
							parent
							
								
									b32c5e1e35
								
							
						
					
					
						commit
						c3d844efc8
					
				| @ -126,10 +126,18 @@ REGISTER_OP("BoostedTreesCalculateBestFeatureSplit") | ||||
|       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape)); | ||||
|       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape)); | ||||
|       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape)); | ||||
|       ShapeHandle output_shape = c->MakeShape({c->UnknownDim()}); | ||||
|       for (int i = 0; i < 7; ++i) { | ||||
|         c->set_output(i, output_shape); | ||||
|       } | ||||
|       ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()}); | ||||
|       c->set_output(0, rank_1_output_shape); | ||||
|       c->set_output(1, rank_1_output_shape); | ||||
|       c->set_output(2, rank_1_output_shape); | ||||
|       c->set_output(3, rank_1_output_shape); | ||||
|       c->set_output(6, rank_1_output_shape); | ||||
|       int logits_dimension; | ||||
|       TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension)); | ||||
|       ShapeHandle contribs_output_shape = | ||||
|           c->MakeShape({c->UnknownDim(), logits_dimension}); | ||||
|       c->set_output(4, contribs_output_shape); | ||||
|       c->set_output(5, contribs_output_shape); | ||||
|       return Status::OK(); | ||||
|     }); | ||||
| 
 | ||||
| @ -168,10 +176,18 @@ REGISTER_OP("BoostedTreesSparseCalculateBestFeatureSplit") | ||||
|       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape)); | ||||
|       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused_shape)); | ||||
|       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused_shape)); | ||||
|       ShapeHandle output_shape = c->MakeShape({-1}); | ||||
|       for (int i = 0; i < 7; ++i) { | ||||
|         c->set_output(i, output_shape); | ||||
|       } | ||||
|       ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()}); | ||||
|       c->set_output(0, rank_1_output_shape); | ||||
|       c->set_output(1, rank_1_output_shape); | ||||
|       c->set_output(2, rank_1_output_shape); | ||||
|       c->set_output(3, rank_1_output_shape); | ||||
|       c->set_output(6, rank_1_output_shape); | ||||
|       int logits_dimension; | ||||
|       TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension)); | ||||
|       ShapeHandle contribs_output_shape = | ||||
|           c->MakeShape({c->UnknownDim(), logits_dimension}); | ||||
|       c->set_output(4, contribs_output_shape); | ||||
|       c->set_output(5, contribs_output_shape); | ||||
|       return Status::OK(); | ||||
|     }); | ||||
| 
 | ||||
| @ -301,8 +317,9 @@ REGISTER_OP("BoostedTreesAggregateStats") | ||||
|       DimensionHandle feature_dim = c->Dim(c->input(3), 1); | ||||
|       DimensionHandle stats_dim; | ||||
|       TF_RETURN_IF_ERROR(c->Add(logits_dim, hessian_dim, &stats_dim)); | ||||
|       c->set_output( | ||||
|           0, c->MakeShape({max_splits, feature_dim, num_buckets, stats_dim})); | ||||
|       c->set_output(0, c->MakeShape({max_splits, feature_dim, | ||||
|                                      num_buckets + 1,  // +1 for missing bucket.
 | ||||
|                                      stats_dim})); | ||||
|       return Status::OK(); | ||||
|     }); | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user