[TF2XLA] Preserve the dynamic dimension (-1) when building a reshape.
This CL handles the problem of a [4] -> [2, 2] reshape, where 4 is a dynamic dimension.
- The producer of reshape's second operand needs to return "-1", indicating that a dimension is dynamic.
- Change GetDynamicSize's return type to S32 so we can support '-1'.
- Constant folding has been changed to return -1 when a dimension is dynamic and a special flag is passed to the op kernel.
- Resurrect the "dynamic dimension inference" feature in the XLA builder. Expect some breakage, as this feature is not heavily exercised.

PiperOrigin-RevId: 267465833
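For context, a minimal sketch (illustrative only, not part of this CL) of the builder-level pattern the change targets: reshaping a one-dimensional input whose dimension 0 is dynamic, with GetDimensionSize now producing S32 and ReshapeWithInferredDimension carrying the dynamic-dimension hint. The parameter shape and the function name are assumptions made for the example.

// Hypothetical sketch; names and shapes are illustrative.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaComputation BuildDynamicReshapeExample(xla::XlaBuilder* b) {
  // f32[<=4]: dimension 0 has an upper bound of 4 and is marked dynamic.
  xla::Shape arg_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  arg_shape.set_dynamic_dimension(0, true);
  xla::XlaOp arg = xla::Parameter(b, 0, arg_shape, "arg");

  // GetDimensionSize now yields an S32 scalar, which lets constant folding
  // report "-1" for a dynamic dimension when the op kernel asks for it.
  xla::XlaOp dynamic_size = xla::GetDimensionSize(arg, 0);
  (void)dynamic_size;

  // [4] -> [2, 2]; inferred_dimension tells dynamic dimension inference
  // which output dimension should stay dynamic.
  xla::ReshapeWithInferredDimension(arg, /*new_sizes=*/{2, 2},
                                    /*inferred_dimension=*/0);
  return b->Build().ConsumeValueOrDie();
}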
@@ -106,13 +106,34 @@ class ReshapeOp : public XlaOpKernel {
                                        " values, but the requested shape has ",
                                        shape.num_elements()));

    VLOG(1) << "Reshape from " << input_shape.DebugString() << " to "
    VLOG(2) << "Reshape from " << input_shape.DebugString() << " to "
            << shape.DebugString() << ", unknown_index=" << unknown_index;

    shape_input.clear();
    // Run get input again, this time with dynamic dimension represented as
    // "-1"
    ctx->set_dynamic_dimension_is_minus_one(true);
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input));

    int dynamic_dimension = -1;

    for (int d = 0; d < num_dims; ++d) {
      const int32 size = shape_input[d];
      if (size == -1) {
        if (dynamic_dimension == -1) {
          dynamic_dimension = d;
        } else {
          if (unknown_index != d) {
            dynamic_dimension = d;
          }
        }
      }
    }

    // Pass unknown_index to Xla::Reshape as a hint for dynamic shape inference
    // in XLA to know which output dimension is dynamic.
    ctx->SetOutput(0, xla::ReshapeWithInferredDimension(
                          ctx->Input(0), shape.dim_sizes(), unknown_index));
                          ctx->Input(0), shape.dim_sizes(), dynamic_dimension));
  }
};

@@ -119,7 +119,9 @@ class SizeOp : public XlaOpKernel {
    xla::XlaBuilder* builder = ctx->builder();
    auto size = xla::One(builder, xla::U32);
    for (int64 i = 0; i < rank; ++i) {
      size = xla::Mul(size, xla::GetDimensionSize(ctx->Input(0), i));
      size = xla::Mul(
          size, xla::ConvertElementType(xla::GetDimensionSize(ctx->Input(0), i),
                                        xla::U32));
    }
    size = xla::ConvertElementType(size, ctx->output_xla_type(0));
    ctx->SetOutput(0, size);

@@ -173,6 +173,7 @@ Status BuildComputation(
  xla::OpMetadata retval_metadata;
  retval_metadata.set_op_name("XLA_Retvals");
  builder->SetOpMetadata(retval_metadata);
  VLOG(1) << "Building new computation";
  auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });

  // Builds a no-op XLA computation. We need to set the sharding of outputs, but

@@ -915,6 +916,9 @@ Status XlaCompiler::BuildArguments(
      const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
      for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) {
        int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second);
        VLOG(1) << "Setting dynamic binding " << i << " -> "
                << dynamic_size_param_index;

        TF_RETURN_IF_ERROR(builder->SetDynamicBinding(
            /*dynamic_size_param_num=*/0, {dynamic_size_param_index},
            /*target_param_num=*/0, /*target_param_index=*/{i},

@@ -1170,7 +1174,7 @@ Status XlaCompiler::CompileGraph(
    std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
    absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
    CompilationResult* result) {
  VLOG(1) << "Executing graph symbolically to populate XlaBuilder.";
  VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name;

  TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
      graph.get(), options_.flib_def, local_flib_def_.get()));
@@ -102,7 +102,7 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
}

xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
    xla::Client* client) const {
    xla::Client* client, bool dynamic_dimension_is_minus_one) const {
  switch (kind()) {
    case Kind::kConstant:
      return {constant_value()};

@@ -122,7 +122,8 @@ xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
  if (!is_constant) return {absl::nullopt};

  TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
                      handle().builder()->BuildConstantSubGraph(handle()));
                      handle().builder()->BuildConstantSubGraph(
                          handle(), dynamic_dimension_is_minus_one));

  TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());

@@ -97,7 +97,7 @@ class XlaExpression {
  // optional if it cannot be resolved. Returns an error if passed a resource
  // expression.
  xla::StatusOr<absl::optional<Tensor>> ResolveConstant(
      xla::Client* client) const;
      xla::Client* client, bool dynamic_dimension_is_minus_one = false) const;

  // Returns the shape of the tensor.
  // The shape of a resource is the shape of a resource handle (i.e., a scalar),

@@ -31,7 +31,7 @@ limitations under the License.
namespace tensorflow {

XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
    : context_(context) {}
    : context_(context), dynamic_dimension_is_minus_one_(false) {}

bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
  return context_->ValidateInputsAreSameShape(op);

@@ -166,7 +166,7 @@ Status XlaOpKernelContext::ConstantInputReshaped(
    xla::Literal* constant_literal) {
  XlaExpression e = InputExpression(index);
  xla::StatusOr<absl::optional<Tensor>> constant_or_status =
      e.ResolveConstant(compiler()->client());
      e.ResolveConstant(compiler()->client(), dynamic_dimension_is_minus_one_);
  if (!constant_or_status.ok()) {
    Status status = constant_or_status.status();
    errors::AppendToMessage(&status, "while evaluating input ", index, " of ",

@@ -202,6 +202,17 @@ class XlaOpKernelContext {
  Status GetVariableTypeAndShape(int index, DataType* type,
                                 TensorShape* shape) const;

  // When dynamic_dimension_is_minus_one is set, querying a dynamic dimension
  // returns "-1", this is useful when the underlying ops expect explicit
  // dynamic index like reshape.
  void set_dynamic_dimension_is_minus_one(bool value) {
    dynamic_dimension_is_minus_one_ = value;
  }

  bool dynamic_dimension_is_minus_one() const {
    return dynamic_dimension_is_minus_one_;
  }

  // Reads the current value of the resouce variable referred to by input
  // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the
  // variable. Returns an error if the variable has not been initialized, or if

@@ -280,6 +291,7 @@ class XlaOpKernelContext {
                               xla::Literal* constant_literal);

  OpKernelContext* const context_;
  bool dynamic_dimension_is_minus_one_;
};

}  // namespace tensorflow
@@ -213,16 +213,10 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
      // TODO(b/32495713): We aren't checking the called computations.
      break;
    case HloOpcode::kGetDimensionSize: {
      int64 dimension_number = instr.dimensions(0);
      const HloInstructionProto& operand =
          *(LookUpInstructionByHandle(instr.operand_ids(0)).ValueOrDie());
      Shape operand_shape(operand.shape());
      if (operand_shape.is_dynamic_dimension(dimension_number)) {
        *is_constant = false;
      }
      // DimensionSize is always considered constant in XLA -- If a dynamic
      // dimension is presented, uint_max is returned.
      break;
    }

    // Non functional ops.
    case HloOpcode::kRng:
    case HloOpcode::kAllReduce:

@@ -268,8 +262,8 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num,
      for (int64 index : target_param_index) {
        param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index);
      }
      // TODO(b/121223198): Set `is_dynamic` to the parameter shape when XLA
      // backend can handle dynamic dimensions.
      param_shape_ptr->set_dynamic_dimension(target_dim_num,
                                             /*is_dynamic=*/true);
      *instr.mutable_shape() = param_shape.ToProto();
    }
  }

@@ -435,6 +429,7 @@ StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
  for (int64 dim : broadcast_dimensions) {
    instr.add_dimensions(dim);
  }

  return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand});
}

@@ -468,11 +463,21 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
          << operand_shape << "; output_shape: " << output_shape;
    }
  }

  Shape reshaped_shape =
      ShapeUtil::MakeShape(operand_shape.element_type(), reshaped_dimensions);

  std::vector<std::pair<int64, int64>> unmodified_dims =
      ShapeUtil::DimensionsUnmodifiedByReshape(operand_shape, reshaped_shape);

  for (auto& unmodified : unmodified_dims) {
    if (operand_shape.is_dynamic_dimension(unmodified.first)) {
      reshaped_shape.set_dynamic_dimension(unmodified.second, true);
    }
  }

  // Eliminate the size one dimensions.
  TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand,
                      Reshape(ShapeUtil::MakeShape(operand_shape.element_type(),
                                                   reshaped_dimensions),
                              operand));
  TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, Reshape(reshaped_shape, operand));
  // Broadcast 'reshape' up to the larger size.
  return InDimBroadcast(broadcast_shape, reshaped_operand,
                        broadcast_dimensions);

@@ -2428,7 +2433,7 @@ StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand) const {
}

StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
    const XlaOp& root_op) {
    XlaOp root_op, bool dynamic_dimension_is_minus_one) {
  TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
  if (!is_constant) {
    auto op_status = LookUpInstruction(root_op);

@@ -2483,9 +2488,12 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                          LookUpInstructionByHandle(operand_handle));

      TF_RET_CHECK(!operand_proto->shape().is_dynamic_dimension(dimension));
      auto constant_dimension_size =
          static_cast<uint32>(operand_proto->shape().dimensions(dimension));
      int32 constant_dimension_size = -1;
      if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
            dynamic_dimension_is_minus_one)) {
        constant_dimension_size =
            static_cast<int32>(operand_proto->shape().dimensions(dimension));
      }

      Literal literal = LiteralUtil::CreateR0(constant_dimension_size);

@@ -258,7 +258,8 @@ class XlaBuilder {
  // compile-time constant (see `IsConstant`), returns an error.
  //
  // This will copy the needed ops/computations to the subgraph.
  StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op);
  StatusOr<XlaComputation> BuildConstantSubGraph(
      XlaOp root_op, bool dynamic_dimension_is_uint_max = false);

  // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous

@@ -917,10 +917,7 @@ TEST_F(XlaBuilderTest, DynamicSelectNotCompatible) {
  auto gte1 = GetTupleElement(p0, 1);  // f32[4,5,<=6]
  Select(pred, gte0, gte1);
  Status status = BuildHloModule(&b).status();
  ASSERT_IS_NOT_OK(status);
  EXPECT_THAT(status.error_message(),
              ::testing::HasSubstr("Operands to select must be the same shape; "
                                   "got f32[4,<=5,6] and f32[4,5,<=6]"));
  ASSERT_IS_OK(status);
}

TEST_F(XlaBuilderTest, DynamicTranspose) {
@@ -105,8 +105,8 @@ class BatchNormExpanderVisitor : public DfsHloRewriteVisitor {
      HloInstruction* operand, int64 feature_index,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          add_instruction) {
    auto elements_per_feature_u32 = add_instruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(1)));
    auto elements_per_feature_s32 = add_instruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));

    for (int64 i = 0; i < operand->shape().rank(); ++i) {
      if (i == feature_index) {

@@ -114,15 +114,15 @@ class BatchNormExpanderVisitor : public DfsHloRewriteVisitor {
      }
      auto dynamic_dimension_size =
          add_instruction(HloInstruction::CreateGetDimensionSize(
              ShapeUtil::MakeShape(U32, {}), operand, i));
      elements_per_feature_u32 = add_instruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeShape(U32, {}), HloOpcode::kMultiply,
          dynamic_dimension_size, elements_per_feature_u32));
              ShapeUtil::MakeShape(S32, {}), operand, i));
      elements_per_feature_s32 = add_instruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply,
          dynamic_dimension_size, elements_per_feature_s32));
    }

    return HloInstruction::CreateConvert(
        ShapeUtil::MakeShape(operand->shape().element_type(), {}),
        elements_per_feature_u32);
        elements_per_feature_s32);
  }

  // Current HloComputation instance the BatchNormExpander is
@@ -463,7 +463,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
                reshape->shape().dimensions(0) / operand->shape().dimensions(0);
            HloInstruction* multiplier_hlo =
                hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
                    LiteralUtil::CreateR0<uint32>(multiplier)));
                    LiteralUtil::CreateR0<int32>(multiplier)));

            HloInstruction* new_dynamic_size =
                hlo->parent()->AddInstruction(HloInstruction::CreateBinary(

@@ -638,7 +638,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
                reshape->shape().dimensions(dynamic_dimension);
            HloInstruction* divisor_hlo =
                hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
                    LiteralUtil::CreateR0<uint32>(divisor)));
                    LiteralUtil::CreateR0<int32>(divisor)));

            HloInstruction* new_dynamic_size =
                hlo->parent()->AddInstruction(HloInstruction::CreateBinary(

@@ -94,7 +94,7 @@ class DynamicDimensionInferenceTest : public HloTestBase {

  std::unique_ptr<HloModule> module_;
  std::unique_ptr<DynamicDimensionInference> inference_;
  const Shape scalar_shape_ = ShapeUtil::MakeShape(U32, {});
  const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {});
};

TEST_F(DynamicDimensionInferenceTest, ParamTest) {

@@ -557,7 +557,7 @@ TEST_F(DynamicDimensionInferenceTest, ReshapeTestMajorDimension) {
  EXPECT_NE(inference_->GetDynamicSize(reshape, {}, 0), nullptr);
  const Literal& multiplier =
      inference_->GetDynamicSize(reshape, {}, 0)->operand(1)->literal();
  LiteralTestUtil::ExpectR0Equal<uint32>(10, multiplier);
  LiteralTestUtil::ExpectR0Equal<int32>(10, multiplier);
}

TEST_F(DynamicDimensionInferenceTest, GatherTest) {

@@ -895,7 +895,7 @@ TEST_F(DynamicDimensionInferenceTest, DynamicSliceTest) {
  std::vector<HloInstruction*> params;
  for (int i = 0; i < 2; ++i) {
    params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
        i + 2, ShapeUtil::MakeShape(U32, {}), "slice_indices")));
        i + 2, ShapeUtil::MakeShape(S32, {}), "slice_indices")));
  }

  auto* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(

@@ -997,7 +997,7 @@ TEST_F(DynamicDimensionInferenceTest, DynamicSliceSingleElementTest) {
  std::vector<HloInstruction*> params;
  for (int i = 0; i < 2; ++i) {
    params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
        i + 2, ShapeUtil::MakeShape(U32, {}), "slice_indices")));
        i + 2, ShapeUtil::MakeShape(S32, {}), "slice_indices")));
  }

  auto* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(

@@ -164,7 +164,7 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
          // mask and pad value.
          //
          const Shape mask_shape =
              ShapeUtil::ChangeElementType(operand->shape(), xla::U32);
              ShapeUtil::ChangeElementType(operand->shape(), xla::S32);
          const Shape pred_shape =
              ShapeUtil::ChangeElementType(operand->shape(), xla::PRED);
          HloInstruction* iota = computation->AddInstruction(

@@ -65,7 +65,7 @@ class DynamicPadderTest : public HloTestBase {
  }

  std::unique_ptr<HloModule> module_;
  const Shape scalar_shape_ = ShapeUtil::MakeShape(U32, {});
  const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {});
};

TEST_F(DynamicPadderTest, ReduceTest) {
@@ -410,9 +410,9 @@ Status HloEvaluator::HandleGetDimensionSize(
  }

  const Shape& shape = get_dimension_size->operand(0)->shape();
  Literal output(ShapeUtil::MakeShape(U32, {}));
  Literal output(ShapeUtil::MakeShape(S32, {}));
  output.PopulateWithValue(
      static_cast<uint32>(shape.dimensions(get_dimension_size->dimension())));
      static_cast<int32>(shape.dimensions(get_dimension_size->dimension())));
  evaluated_[get_dimension_size] = std::move(output);
  return Status::OK();
}

@@ -4154,13 +4154,13 @@ TEST_F(HloEvaluatorTest, GetDimensionSize) {
HloModule Test

ENTRY main {
  size = u32[] parameter(0)
  size = s32[] parameter(0)

  data = s32[4] parameter(1)

  sum = s32[4] add(data, data)

  ROOT dynamic_size = u32[] get-dimension-size(sum), dimensions={0}
  ROOT dynamic_size = s32[] get-dimension-size(sum), dimensions={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

@@ -4174,12 +4174,12 @@ ENTRY main {
                          DynamicDimensionInference::Run(m_.get()));

  evaluator_.set_dynamic_dimension_inference(&dynamic_dimension_inference);
  Literal size_arg = LiteralUtil::CreateR0<uint32>(3);
  Literal size_arg = LiteralUtil::CreateR0<int32>(3);
  Literal data_arg = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&size_arg, &data_arg}));

  EXPECT_EQ(actual.GetFirstElement<uint32>(), static_cast<uint32>(3));
  EXPECT_EQ(actual.GetFirstElement<int32>(), static_cast<int32>(3));
}

// Check that we get a useful error if we pass inputs of the wrong shape.
@@ -37,8 +37,10 @@ StatusOr<bool> ReplaceGetSize(
  TF_ASSIGN_OR_RETURN(auto legal_shape,
                      ShapeInference::InferGetDimensionSizeShape(
                          instr->operand(0)->shape(), instr->dimension()));
  TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape));
  TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), U32));
  TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape))
      << "instr->shape() " << instr->shape().ToString() << " , "
      << "legal_shape " << legal_shape.ToString();
  TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32));
  HloInstruction* operand = instr->mutable_operand(0);
  int64 dim = instr->dimension();
  HloInstruction* dynamic_size =

@@ -46,9 +48,9 @@ StatusOr<bool> ReplaceGetSize(
  if (dynamic_size != nullptr) {
    TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
  } else {
    uint32 size = instr->operand(0)->shape().dimensions(dim);
    int32 size = instr->operand(0)->shape().dimensions(dim);
    HloInstruction* new_instr = computation->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(size)));
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
    TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
  }
  return true;

@@ -44,9 +44,9 @@ TEST_F(HloGetDimensionSizeRewriterTest, Ok) {
HloModule _
ENTRY gds {
  p = s32[3,4] parameter(0)
  size0 = u32[] get-dimension-size(p), dimensions={0}
  size1 = u32[] get-dimension-size(p), dimensions={1}
  ROOT mul = u32[] multiply(size0, size1)
  size0 = s32[] get-dimension-size(p), dimensions={0}
  size1 = s32[] get-dimension-size(p), dimensions={1}
  ROOT mul = s32[] multiply(size0, size1)
})")
                    .ValueOrDie();
  HloGetDimensionSizeRewriter pass;

@@ -72,7 +72,7 @@ TEST_F(HloGetDimensionSizeRewriterTest, IllegalDimension) {
HloModule _
ENTRY gds {
  p = f32[2,5] parameter(0)
  ROOT gds = u32[] get-dimension-size(p), dimensions={2}
  ROOT gds = s32[] get-dimension-size(p), dimensions={2}
})")
                    .ValueOrDie();
  HloGetDimensionSizeRewriter pass;
@@ -652,9 +652,7 @@ Status ValidateDotDimensionNumbers(
    const int64 rhs_contracting_dimension =
        dimension_numbers.rhs_contracting_dimensions(i);
    if (lhs.dimensions(lhs_contracting_dimension) !=
            rhs.dimensions(rhs_contracting_dimension) ||
        lhs.is_dynamic_dimension(lhs_contracting_dimension) !=
            rhs.is_dynamic_dimension(rhs_contracting_dimension)) {
        rhs.dimensions(rhs_contracting_dimension)) {
      return fail("Contracting dimension sizes do not match.");
    }
  }

@@ -668,10 +666,7 @@ Status ValidateDotDimensionNumbers(
  // Check that batch dimension numbers and sizes match.
  for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
    if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
            rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i)) ||
        lhs.is_dynamic_dimension(dimension_numbers.lhs_batch_dimensions(i)) !=
            rhs.is_dynamic_dimension(
                dimension_numbers.rhs_batch_dimensions(i))) {
        rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
      return fail("Batch dimension sizes must match for lhs/rhs.");
    }
  }

@@ -726,13 +721,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
  for (int64 i = 0; i < lhs.rank(); ++i) {
    if (lhs.dimensions(i) == rhs.dimensions(i)) {
      output_dimensions[i] = lhs.dimensions(i);
      output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i);
    } else if (lhs.dimensions(i) == 1) {
      output_dimensions[i] = rhs.dimensions(i);
      output_dimensions_is_dynamic[i] = rhs.is_dynamic_dimension(i);
    } else if (rhs.dimensions(i) == 1) {
      output_dimensions[i] = lhs.dimensions(i);
      output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i);
    } else {
      return InvalidArgument(
          "Binary op %s with incompatible shapes: %s and %s.",

@@ -740,6 +732,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
          ShapeUtil::HumanString(rhs));
    }
  }

  // Merge dynamic dimensions from two shapes.
  for (int64 i = 0; i < rhs.rank(); ++i) {
    if (rhs.is_dynamic_dimension(i) || lhs.is_dynamic_dimension(i)) {
      output_dimensions_is_dynamic[i] = true;
    }
  }

  return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
                              output_dimensions, output_dimensions_is_dynamic);
}

@@ -888,11 +888,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
  if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
    // If the shapes are the same other than layout, the output shape is the
    // same (elementwise op).
    return ShapeUtil::ChangeElementType(
    Shape result = ShapeUtil::ChangeElementType(
        lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
  }

  if (lhs.rank() == rhs.rank()) {
    for (int64 i = 0; i < rhs.rank(); ++i) {
      if (rhs.is_dynamic_dimension(i)) {
        result.set_dynamic_dimension(i, true);
      }
    }

    return result;

  } else if (lhs.rank() == rhs.rank()) {
    return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
  } else {
    // Ranks do not match, so perform InDim broadcasting using

@@ -2201,14 +2208,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,

  // TODO(b/119580730): Remove this restriction when very large dimension size
  // is needed.
  if (shape.dimensions(dimension) > std::numeric_limits<uint32>::max()) {
  if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
    return InvalidArgument(
        "GetDimensionSize's input shape is %s, the %dth dimension exceeds the "
        "UINT_MAX limit.",
        "INT_MAX limit.",
        ShapeUtil::HumanString(shape), dimension);
  }

  return ShapeUtil::MakeShape(U32, {});
  return ShapeUtil::MakeShape(S32, {});
}

/* static */ StatusOr<Window> ShapeInference::InferWindowFromDimensions(

@@ -2324,7 +2331,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
    sizes.push_back((limit_index - start_index + stride - 1) / stride);
  }

  return ShapeUtil::MakeShape(arg.element_type(), sizes);
  std::vector<bool> is_dynamic(arg.rank());
  for (int64 i = 0; i < arg.dimensions_size(); ++i) {
    is_dynamic[i] = arg.is_dynamic_dimension(i);
  }

  return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic);
}

/* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
@@ -507,17 +507,23 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}

/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
  return Shape::Equal().IgnoreLayout()(lhs, rhs);
  return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs);
}

/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
                                                           const Shape& rhs) {
  return Shape::Equal().IgnoreElementType().IgnoreLayout()(lhs, rhs);
  return Shape::Equal()
      .IgnoreDynamicDimension()
      .IgnoreElementType()
      .IgnoreLayout()(lhs, rhs);
}

/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
                                                           const Shape& rhs) {
  return Shape::Equal().IgnoreFpPrecision().IgnoreLayout()(lhs, rhs);
  return Shape::Equal()
      .IgnoreDynamicDimension()
      .IgnoreFpPrecision()
      .IgnoreLayout()(lhs, rhs);
}

/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,

@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"

#include <numeric>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"

@@ -195,7 +196,7 @@ TEST(ShapeUtilTest, CompatibleDynamicShapes) {

  EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_a));
  EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_b));
  EXPECT_FALSE(ShapeUtil::Compatible(shape_a, shape_c));
  EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_c));
}

TEST(ShapeUtilTest, CompatibleTuples) {
@@ -737,7 +737,7 @@ def _pad_all_input(inputs, padded_shapes):
              padding_map.padding_arg_index = real_shape_idx
              padding_maps.append(padding_map)
            real_shapes[core_idx].append(
                math_ops.cast(input_shape_tensor[i], dtypes.uint32))
                math_ops.cast(input_shape_tensor[i], dtypes.int32))

        paddings = []
        for i, s in enumerate(padded_shape.dims):