[TF2XLA] Preserve the dynamic dimension (-1) when building a reshape.
This CL handles the problem of a [4] -> [2, 2] reshape, where 4 is a dynamic dimension (a standalone sketch of the resulting kernel logic follows the commit metadata below).
- The producer of reshape's second operand needs to return "-1", indicating that a dimension is dynamic.
- Change GetDynamicSize's return type to S32 so we can support '-1'.
- Constant folding has been changed to return -1 when a dimension is dynamic and a special flag is passed to an op kernel.
- Resurrect the "dynamic dimension inference" feature in the XLA builder. Expect some breakage, as this feature is not heavily exercised.

PiperOrigin-RevId: 267465833
commit 792abd2eaf
parent d6eab240b9
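The [4] -> [2, 2] case works out as follows once the shape operand is read back with dynamic dimensions reported as -1: the kernel scans the requested shape for a -1 entry and passes that index to xla::ReshapeWithInferredDimension. Below is a minimal standalone sketch of that scan (illustrative names, not the TF2XLA code itself).

// Minimal standalone sketch of the dimension scan described above; the
// function and variable names are illustrative, not the real TF2XLA API.
#include <cstdint>
#include <iostream>
#include <vector>

// Returns the index of the dynamic (-1) output dimension, preferring a
// dimension other than `unknown_index` when several entries are -1, or -1
// if the target shape is fully static.
int FindDynamicDimension(const std::vector<int32_t>& shape_input,
                         int unknown_index) {
  int dynamic_dimension = -1;
  for (int d = 0; d < static_cast<int>(shape_input.size()); ++d) {
    if (shape_input[d] == -1) {
      if (dynamic_dimension == -1 || unknown_index != d) {
        dynamic_dimension = d;
      }
    }
  }
  return dynamic_dimension;
}

int main() {
  // A [4] -> [2, 2] reshape where the first output dimension is dynamic:
  // the shape operand reads back as {-1, 2} once the -1 flag is enabled.
  std::vector<int32_t> shape_input = {-1, 2};
  std::cout << FindDynamicDimension(shape_input, /*unknown_index=*/-1)
            << "\n";  // prints 0
  return 0;
}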
@@ -106,13 +106,34 @@ class ReshapeOp : public XlaOpKernel {
                    " values, but the requested shape has ",
                    shape.num_elements()));

    VLOG(1) << "Reshape from " << input_shape.DebugString() << " to "
    VLOG(2) << "Reshape from " << input_shape.DebugString() << " to "
            << shape.DebugString() << ", unknown_index=" << unknown_index;

    shape_input.clear();
    // Run get input again, this time with dynamic dimension represented as
    // "-1"
    ctx->set_dynamic_dimension_is_minus_one(true);
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input));

    int dynamic_dimension = -1;

    for (int d = 0; d < num_dims; ++d) {
      const int32 size = shape_input[d];
      if (size == -1) {
        if (dynamic_dimension == -1) {
          dynamic_dimension = d;
        } else {
          if (unknown_index != d) {
            dynamic_dimension = d;
          }
        }
      }
    }

    // Pass unknown_index to Xla::Reshape as a hint for dynamic shape inference
    // in XLA to know which output dimension is dynamic.
    ctx->SetOutput(0, xla::ReshapeWithInferredDimension(
                          ctx->Input(0), shape.dim_sizes(), unknown_index));
                          ctx->Input(0), shape.dim_sizes(), dynamic_dimension));
  }
};
@@ -119,7 +119,9 @@ class SizeOp : public XlaOpKernel {
    xla::XlaBuilder* builder = ctx->builder();
    auto size = xla::One(builder, xla::U32);
    for (int64 i = 0; i < rank; ++i) {
      size = xla::Mul(size, xla::GetDimensionSize(ctx->Input(0), i));
      size = xla::Mul(
          size, xla::ConvertElementType(xla::GetDimensionSize(ctx->Input(0), i),
                                        xla::U32));
    }
    size = xla::ConvertElementType(size, ctx->output_xla_type(0));
    ctx->SetOutput(0, size);
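Since GetDimensionSize now returns S32, SizeOp converts each per-dimension size to U32 before folding it into the running product. Below is a standalone sketch of that accumulation, using plain integers instead of the XLA builder ops.

#include <cstdint>
#include <iostream>
#include <vector>

// Multiplies signed per-dimension sizes into an unsigned element count,
// converting each term explicitly, as the SizeOp change does via
// xla::ConvertElementType before xla::Mul.
uint32_t NumElements(const std::vector<int32_t>& dim_sizes) {
  uint32_t size = 1;
  for (int32_t d : dim_sizes) {
    size *= static_cast<uint32_t>(d);  // explicit S32 -> U32 conversion
  }
  return size;
}

int main() {
  std::cout << NumElements({2, 3, 4}) << "\n";  // prints 24
  return 0;
}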
@@ -173,6 +173,7 @@ Status BuildComputation(
  xla::OpMetadata retval_metadata;
  retval_metadata.set_op_name("XLA_Retvals");
  builder->SetOpMetadata(retval_metadata);
  VLOG(1) << "Building new computation";
  auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });

  // Builds a no-op XLA computation. We need to set the sharding of outputs, but

@@ -915,6 +916,9 @@ Status XlaCompiler::BuildArguments(
    const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
    for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) {
      int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second);
      VLOG(1) << "Setting dynamic binding " << i << " -> "
              << dynamic_size_param_index;

      TF_RETURN_IF_ERROR(builder->SetDynamicBinding(
          /*dynamic_size_param_num=*/0, {dynamic_size_param_index},
          /*target_param_num=*/0, /*target_param_index=*/{i},

@@ -1170,7 +1174,7 @@ Status XlaCompiler::CompileGraph(
    std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
    absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
    CompilationResult* result) {
  VLOG(1) << "Executing graph symbolically to populate XlaBuilder.";
  VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name;

  TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
      graph.get(), options_.flib_def, local_flib_def_.get()));

@@ -102,7 +102,7 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
}

xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
    xla::Client* client) const {
    xla::Client* client, bool dynamic_dimension_is_minus_one) const {
  switch (kind()) {
    case Kind::kConstant:
      return {constant_value()};

@@ -122,7 +122,8 @@ xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
  if (!is_constant) return {absl::nullopt};

  TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
                      handle().builder()->BuildConstantSubGraph(handle()));
                      handle().builder()->BuildConstantSubGraph(
                          handle(), dynamic_dimension_is_minus_one));

  TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());

@@ -97,7 +97,7 @@ class XlaExpression {
  // optional if it cannot be resolved. Returns an error if passed a resource
  // expression.
  xla::StatusOr<absl::optional<Tensor>> ResolveConstant(
      xla::Client* client) const;
      xla::Client* client, bool dynamic_dimension_is_minus_one = false) const;

  // Returns the shape of the tensor.
  // The shape of a resource is the shape of a resource handle (i.e., a scalar),

@@ -31,7 +31,7 @@ limitations under the License.
namespace tensorflow {

XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
    : context_(context) {}
    : context_(context), dynamic_dimension_is_minus_one_(false) {}

bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
  return context_->ValidateInputsAreSameShape(op);

@@ -166,7 +166,7 @@ Status XlaOpKernelContext::ConstantInputReshaped(
    xla::Literal* constant_literal) {
  XlaExpression e = InputExpression(index);
  xla::StatusOr<absl::optional<Tensor>> constant_or_status =
      e.ResolveConstant(compiler()->client());
      e.ResolveConstant(compiler()->client(), dynamic_dimension_is_minus_one_);
  if (!constant_or_status.ok()) {
    Status status = constant_or_status.status();
    errors::AppendToMessage(&status, "while evaluating input ", index, " of ",

@@ -202,6 +202,17 @@ class XlaOpKernelContext {
  Status GetVariableTypeAndShape(int index, DataType* type,
                                 TensorShape* shape) const;

  // When dynamic_dimension_is_minus_one is set, querying a dynamic dimension
  // returns "-1", this is useful when the underlying ops expect explicit
  // dynamic index like reshape.
  void set_dynamic_dimension_is_minus_one(bool value) {
    dynamic_dimension_is_minus_one_ = value;
  }

  bool dynamic_dimension_is_minus_one() const {
    return dynamic_dimension_is_minus_one_;
  }

  // Reads the current value of the resouce variable referred to by input
  // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the
  // variable. Returns an error if the variable has not been initialized, or if

@@ -280,6 +291,7 @@ class XlaOpKernelContext {
      xla::Literal* constant_literal);

  OpKernelContext* const context_;
  bool dynamic_dimension_is_minus_one_;
};

} // namespace tensorflow
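The intended usage pattern for the new flag, pieced together from the calls that appear in this diff (the surrounding kernel boilerplate and the exact declaration of shape_input are assumptions), is to read the shape operand once with the default behavior and once with dynamic dimensions reported as -1:

// Fragment as it would appear inside an XlaOpKernel::Compile body; both
// calls below are taken from the reshape kernel change above.
std::vector<int64> shape_input;  // assumed declaration, as in the reshape kernel
// First read: dynamic dimensions come back folded to their static bound.
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input));

// Second read: ask constant folding to report dynamic dimensions as -1.
shape_input.clear();
ctx->set_dynamic_dimension_is_minus_one(true);
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input));
// Entries equal to -1 now mark the dynamic dimensions of the shape operand.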
@@ -213,16 +213,10 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
      // TODO(b/32495713): We aren't checking the called computations.
      break;
    case HloOpcode::kGetDimensionSize: {
      int64 dimension_number = instr.dimensions(0);
      const HloInstructionProto& operand =
          *(LookUpInstructionByHandle(instr.operand_ids(0)).ValueOrDie());
      Shape operand_shape(operand.shape());
      if (operand_shape.is_dynamic_dimension(dimension_number)) {
        *is_constant = false;
      }
      // DimensionSize is always considered constant in XLA -- If a dynamic
      // dimension is presented, uint_max is returned.
      break;
    }

    // Non functional ops.
    case HloOpcode::kRng:
    case HloOpcode::kAllReduce:

@@ -268,8 +262,8 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num,
      for (int64 index : target_param_index) {
        param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index);
      }
      // TODO(b/121223198): Set `is_dynamic` to the parameter shape when XLA
      // backend can handle dynamic dimensions.
      param_shape_ptr->set_dynamic_dimension(target_dim_num,
                                             /*is_dynamic=*/true);
      *instr.mutable_shape() = param_shape.ToProto();
    }
  }

@@ -435,6 +429,7 @@ StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
  for (int64 dim : broadcast_dimensions) {
    instr.add_dimensions(dim);
  }

  return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand});
}

@@ -468,11 +463,21 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
          << operand_shape << "; output_shape: " << output_shape;
    }
  }

  Shape reshaped_shape =
      ShapeUtil::MakeShape(operand_shape.element_type(), reshaped_dimensions);

  std::vector<std::pair<int64, int64>> unmodified_dims =
      ShapeUtil::DimensionsUnmodifiedByReshape(operand_shape, reshaped_shape);

  for (auto& unmodified : unmodified_dims) {
    if (operand_shape.is_dynamic_dimension(unmodified.first)) {
      reshaped_shape.set_dynamic_dimension(unmodified.second, true);
    }
  }

  // Eliminate the size one dimensions.
  TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand,
                      Reshape(ShapeUtil::MakeShape(operand_shape.element_type(),
                                                   reshaped_dimensions),
                              operand));
  TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, Reshape(reshaped_shape, operand));
  // Broadcast 'reshape' up to the larger size.
  return InDimBroadcast(broadcast_shape, reshaped_operand,
                        broadcast_dimensions);

@@ -2428,7 +2433,7 @@ StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand) const {
}

StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
    const XlaOp& root_op) {
    XlaOp root_op, bool dynamic_dimension_is_minus_one) {
  TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
  if (!is_constant) {
    auto op_status = LookUpInstruction(root_op);

@@ -2483,9 +2488,12 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
      TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                          LookUpInstructionByHandle(operand_handle));

      TF_RET_CHECK(!operand_proto->shape().is_dynamic_dimension(dimension));
      auto constant_dimension_size =
          static_cast<uint32>(operand_proto->shape().dimensions(dimension));
      int32 constant_dimension_size = -1;
      if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
            dynamic_dimension_is_minus_one)) {
        constant_dimension_size =
            static_cast<int32>(operand_proto->shape().dimensions(dimension));
      }

      Literal literal = LiteralUtil::CreateR0(constant_dimension_size);
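The folding rule above can be condensed into a standalone sketch (an illustrative struct stands in for the operand's shape proto): a kGetDimensionSize instruction folds to its static size unless the dimension is dynamic and the caller asked for -1.

#include <cstdint>
#include <iostream>

// Illustrative stand-in for the shape query done on the operand proto.
struct DimInfo {
  int64_t size;
  bool is_dynamic;
};

// Mirrors the new folding logic: -1 for a dynamic dimension when the flag
// is set, otherwise the static dimension size.
int32_t FoldGetDimensionSize(const DimInfo& dim,
                             bool dynamic_dimension_is_minus_one) {
  int32_t constant_dimension_size = -1;
  if (!(dim.is_dynamic && dynamic_dimension_is_minus_one)) {
    constant_dimension_size = static_cast<int32_t>(dim.size);
  }
  return constant_dimension_size;
}

int main() {
  std::cout << FoldGetDimensionSize({4, true}, true) << "\n";   // prints -1
  std::cout << FoldGetDimensionSize({4, true}, false) << "\n";  // prints 4
  std::cout << FoldGetDimensionSize({4, false}, true) << "\n";  // prints 4
  return 0;
}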
@@ -258,7 +258,8 @@ class XlaBuilder {
  // compile-time constant (see `IsConstant`), returns an error.
  //
  // This will copy the needed ops/computations to the subgraph.
  StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op);
  StatusOr<XlaComputation> BuildConstantSubGraph(
      XlaOp root_op, bool dynamic_dimension_is_uint_max = false);

  // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous

@@ -917,10 +917,7 @@ TEST_F(XlaBuilderTest, DynamicSelectNotCompatible) {
  auto gte1 = GetTupleElement(p0, 1);  // f32[4,5,<=6]
  Select(pred, gte0, gte1);
  Status status = BuildHloModule(&b).status();
  ASSERT_IS_NOT_OK(status);
  EXPECT_THAT(status.error_message(),
              ::testing::HasSubstr("Operands to select must be the same shape; "
                                   "got f32[4,<=5,6] and f32[4,5,<=6]"));
  ASSERT_IS_OK(status);
}

TEST_F(XlaBuilderTest, DynamicTranspose) {

@@ -105,8 +105,8 @@ class BatchNormExpanderVisitor : public DfsHloRewriteVisitor {
      HloInstruction* operand, int64 feature_index,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          add_instruction) {
    auto elements_per_feature_u32 = add_instruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(1)));
    auto elements_per_feature_s32 = add_instruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));

    for (int64 i = 0; i < operand->shape().rank(); ++i) {
      if (i == feature_index) {

@@ -114,15 +114,15 @@ class BatchNormExpanderVisitor : public DfsHloRewriteVisitor {
      }
      auto dynamic_dimension_size =
          add_instruction(HloInstruction::CreateGetDimensionSize(
              ShapeUtil::MakeShape(U32, {}), operand, i));
      elements_per_feature_u32 = add_instruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeShape(U32, {}), HloOpcode::kMultiply,
          dynamic_dimension_size, elements_per_feature_u32));
              ShapeUtil::MakeShape(S32, {}), operand, i));
      elements_per_feature_s32 = add_instruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply,
          dynamic_dimension_size, elements_per_feature_s32));
    }

    return HloInstruction::CreateConvert(
        ShapeUtil::MakeShape(operand->shape().element_type(), {}),
        elements_per_feature_u32);
        elements_per_feature_s32);
  }

  // Current HloComputation instance the BatchNormExpander is

@@ -463,7 +463,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
            reshape->shape().dimensions(0) / operand->shape().dimensions(0);
        HloInstruction* multiplier_hlo =
            hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
                LiteralUtil::CreateR0<uint32>(multiplier)));
                LiteralUtil::CreateR0<int32>(multiplier)));

        HloInstruction* new_dynamic_size =
            hlo->parent()->AddInstruction(HloInstruction::CreateBinary(

@@ -638,7 +638,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
            reshape->shape().dimensions(dynamic_dimension);
        HloInstruction* divisor_hlo =
            hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
                LiteralUtil::CreateR0<uint32>(divisor)));
                LiteralUtil::CreateR0<int32>(divisor)));

        HloInstruction* new_dynamic_size =
            hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
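For reshapes, dynamic dimension inference rescales the dynamic size rather than copying it: when the dynamic dimension is folded into a larger output dimension the pass multiplies by the static ratio, and when it is split into a smaller one it divides. The constants it emits now carry int32 values. Below is a worked standalone sketch of that arithmetic (not the HLO pass itself).

#include <cstdint>
#include <iostream>

// new_size = old_size * (reshape_dim / operand_dim) when the output
// dimension is larger, and old_size / (operand_dim / reshape_dim) when it
// is smaller; the pass emits these as S32 multiply/divide instructions.
int32_t RescaleDynamicSize(int32_t old_dynamic_size, int64_t operand_dim,
                           int64_t reshape_dim) {
  if (reshape_dim >= operand_dim) {
    const int32_t multiplier = static_cast<int32_t>(reshape_dim / operand_dim);
    return old_dynamic_size * multiplier;
  }
  const int32_t divisor = static_cast<int32_t>(operand_dim / reshape_dim);
  return old_dynamic_size / divisor;
}

int main() {
  // [x, 10] (bound 5 on x) reshaped to [50]: dynamic size 3 becomes 30.
  std::cout << RescaleDynamicSize(3, /*operand_dim=*/5, /*reshape_dim=*/50)
            << "\n";  // prints 30
  return 0;
}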
@@ -94,7 +94,7 @@ class DynamicDimensionInferenceTest : public HloTestBase {

  std::unique_ptr<HloModule> module_;
  std::unique_ptr<DynamicDimensionInference> inference_;
  const Shape scalar_shape_ = ShapeUtil::MakeShape(U32, {});
  const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {});
};

TEST_F(DynamicDimensionInferenceTest, ParamTest) {

@@ -557,7 +557,7 @@ TEST_F(DynamicDimensionInferenceTest, ReshapeTestMajorDimension) {
  EXPECT_NE(inference_->GetDynamicSize(reshape, {}, 0), nullptr);
  const Literal& multiplier =
      inference_->GetDynamicSize(reshape, {}, 0)->operand(1)->literal();
  LiteralTestUtil::ExpectR0Equal<uint32>(10, multiplier);
  LiteralTestUtil::ExpectR0Equal<int32>(10, multiplier);
}

TEST_F(DynamicDimensionInferenceTest, GatherTest) {

@@ -895,7 +895,7 @@ TEST_F(DynamicDimensionInferenceTest, DynamicSliceTest) {
  std::vector<HloInstruction*> params;
  for (int i = 0; i < 2; ++i) {
    params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
        i + 2, ShapeUtil::MakeShape(U32, {}), "slice_indices")));
        i + 2, ShapeUtil::MakeShape(S32, {}), "slice_indices")));
  }

  auto* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(

@@ -997,7 +997,7 @@ TEST_F(DynamicDimensionInferenceTest, DynamicSliceSingleElementTest) {
  std::vector<HloInstruction*> params;
  for (int i = 0; i < 2; ++i) {
    params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
        i + 2, ShapeUtil::MakeShape(U32, {}), "slice_indices")));
        i + 2, ShapeUtil::MakeShape(S32, {}), "slice_indices")));
  }

  auto* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(

@@ -164,7 +164,7 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
      // mask and pad value.
      //
      const Shape mask_shape =
          ShapeUtil::ChangeElementType(operand->shape(), xla::U32);
          ShapeUtil::ChangeElementType(operand->shape(), xla::S32);
      const Shape pred_shape =
          ShapeUtil::ChangeElementType(operand->shape(), xla::PRED);
      HloInstruction* iota = computation->AddInstruction(

@@ -65,7 +65,7 @@ class DynamicPadderTest : public HloTestBase {
  }

  std::unique_ptr<HloModule> module_;
  const Shape scalar_shape_ = ShapeUtil::MakeShape(U32, {});
  const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {});
};

TEST_F(DynamicPadderTest, ReduceTest) {

@@ -410,9 +410,9 @@ Status HloEvaluator::HandleGetDimensionSize(
  }

  const Shape& shape = get_dimension_size->operand(0)->shape();
  Literal output(ShapeUtil::MakeShape(U32, {}));
  Literal output(ShapeUtil::MakeShape(S32, {}));
  output.PopulateWithValue(
      static_cast<uint32>(shape.dimensions(get_dimension_size->dimension())));
      static_cast<int32>(shape.dimensions(get_dimension_size->dimension())));
  evaluated_[get_dimension_size] = std::move(output);
  return Status::OK();
}

@@ -4154,13 +4154,13 @@ TEST_F(HloEvaluatorTest, GetDimensionSize) {
HloModule Test

ENTRY main {
  size = u32[] parameter(0)
  size = s32[] parameter(0)

  data = s32[4] parameter(1)

  sum = s32[4] add(data, data)

  ROOT dynamic_size = u32[] get-dimension-size(sum), dimensions={0}
  ROOT dynamic_size = s32[] get-dimension-size(sum), dimensions={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

@@ -4174,12 +4174,12 @@ ENTRY main {
                          DynamicDimensionInference::Run(m_.get()));

  evaluator_.set_dynamic_dimension_inference(&dynamic_dimension_inference);
  Literal size_arg = LiteralUtil::CreateR0<uint32>(3);
  Literal size_arg = LiteralUtil::CreateR0<int32>(3);
  Literal data_arg = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&size_arg, &data_arg}));

  EXPECT_EQ(actual.GetFirstElement<uint32>(), static_cast<uint32>(3));
  EXPECT_EQ(actual.GetFirstElement<int32>(), static_cast<int32>(3));
}

// Check that we get a useful error if we pass inputs of the wrong shape.

@@ -37,8 +37,10 @@ StatusOr<bool> ReplaceGetSize(
  TF_ASSIGN_OR_RETURN(auto legal_shape,
                      ShapeInference::InferGetDimensionSizeShape(
                          instr->operand(0)->shape(), instr->dimension()));
  TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape));
  TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), U32));
  TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape))
      << "instr->shape() " << instr->shape().ToString() << " , "
      << "legal_shape " << legal_shape.ToString();
  TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32));
  HloInstruction* operand = instr->mutable_operand(0);
  int64 dim = instr->dimension();
  HloInstruction* dynamic_size =

@@ -46,9 +48,9 @@ StatusOr<bool> ReplaceGetSize(
  if (dynamic_size != nullptr) {
    TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
  } else {
    uint32 size = instr->operand(0)->shape().dimensions(dim);
    int32 size = instr->operand(0)->shape().dimensions(dim);
    HloInstruction* new_instr = computation->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(size)));
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
    TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
  }
  return true;

@@ -44,9 +44,9 @@ TEST_F(HloGetDimensionSizeRewriterTest, Ok) {
HloModule _
ENTRY gds {
  p = s32[3,4] parameter(0)
  size0 = u32[] get-dimension-size(p), dimensions={0}
  size1 = u32[] get-dimension-size(p), dimensions={1}
  ROOT mul = u32[] multiply(size0, size1)
  size0 = s32[] get-dimension-size(p), dimensions={0}
  size1 = s32[] get-dimension-size(p), dimensions={1}
  ROOT mul = s32[] multiply(size0, size1)
})")
                    .ValueOrDie();
  HloGetDimensionSizeRewriter pass;

@@ -72,7 +72,7 @@ TEST_F(HloGetDimensionSizeRewriterTest, IllegalDimension) {
HloModule _
ENTRY gds {
  p = f32[2,5] parameter(0)
  ROOT gds = u32[] get-dimension-size(p), dimensions={2}
  ROOT gds = s32[] get-dimension-size(p), dimensions={2}
})")
                    .ValueOrDie();
  HloGetDimensionSizeRewriter pass;

@@ -652,9 +652,7 @@ Status ValidateDotDimensionNumbers(
    const int64 rhs_contracting_dimension =
        dimension_numbers.rhs_contracting_dimensions(i);
    if (lhs.dimensions(lhs_contracting_dimension) !=
            rhs.dimensions(rhs_contracting_dimension) ||
        lhs.is_dynamic_dimension(lhs_contracting_dimension) !=
            rhs.is_dynamic_dimension(rhs_contracting_dimension)) {
        rhs.dimensions(rhs_contracting_dimension)) {
      return fail("Contracting dimension sizes do not match.");
    }
  }

@@ -668,10 +666,7 @@ Status ValidateDotDimensionNumbers(
  // Check that batch dimension numbers and sizes match.
  for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
    if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
            rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i)) ||
        lhs.is_dynamic_dimension(dimension_numbers.lhs_batch_dimensions(i)) !=
            rhs.is_dynamic_dimension(
                dimension_numbers.rhs_batch_dimensions(i))) {
        rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
      return fail("Batch dimension sizes must match for lhs/rhs.");
    }
  }

@@ -726,13 +721,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
  for (int64 i = 0; i < lhs.rank(); ++i) {
    if (lhs.dimensions(i) == rhs.dimensions(i)) {
      output_dimensions[i] = lhs.dimensions(i);
      output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i);
    } else if (lhs.dimensions(i) == 1) {
      output_dimensions[i] = rhs.dimensions(i);
      output_dimensions_is_dynamic[i] = rhs.is_dynamic_dimension(i);
    } else if (rhs.dimensions(i) == 1) {
      output_dimensions[i] = lhs.dimensions(i);
      output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i);
    } else {
      return InvalidArgument(
          "Binary op %s with incompatible shapes: %s and %s.",

@@ -740,6 +732,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
          ShapeUtil::HumanString(rhs));
    }
  }

  // Merge dynamic dimensions from two shapes.
  for (int64 i = 0; i < rhs.rank(); ++i) {
    if (rhs.is_dynamic_dimension(i) || lhs.is_dynamic_dimension(i)) {
      output_dimensions_is_dynamic[i] = true;
    }
  }

  return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
                              output_dimensions, output_dimensions_is_dynamic);
}
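The merge loop added above takes the union of the two operands' dynamic bits, so an elementwise result dimension is dynamic as soon as either input dimension is. Below is a standalone sketch, with plain bool vectors standing in for the shapes' per-dimension dynamic flags.

#include <iostream>
#include <vector>

// Unions the per-dimension dynamic flags of two same-rank shapes, as the
// added loop in InferDegenerateDimensionBroadcastShape does.
std::vector<bool> MergeDynamicDimensions(const std::vector<bool>& lhs,
                                         const std::vector<bool>& rhs) {
  std::vector<bool> merged(lhs.size(), false);
  for (size_t i = 0; i < lhs.size(); ++i) {
    merged[i] = lhs[i] || rhs[i];
  }
  return merged;
}

int main() {
  auto merged = MergeDynamicDimensions({true, false}, {false, false});
  std::cout << merged[0] << merged[1] << "\n";  // prints 10
  return 0;
}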
@@ -888,11 +888,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
  if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
    // If the shapes are the same other than layout, the output shape is the
    // same (elementwise op).
    return ShapeUtil::ChangeElementType(
    Shape result = ShapeUtil::ChangeElementType(
        lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
  }

    if (lhs.rank() == rhs.rank()) {
      for (int64 i = 0; i < rhs.rank(); ++i) {
        if (rhs.is_dynamic_dimension(i)) {
          result.set_dynamic_dimension(i, true);
        }
      }

      return result;

  } else if (lhs.rank() == rhs.rank()) {
    return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
  } else {
    // Ranks do not match, so perform InDim broadcasting using

@@ -2201,14 +2208,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,

  // TODO(b/119580730): Remove this restriction when very large dimension size
  // is needed.
  if (shape.dimensions(dimension) > std::numeric_limits<uint32>::max()) {
  if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
    return InvalidArgument(
        "GetDimensionSize's input shape is %s, the %dth dimension exceeds the "
        "UINT_MAX limit.",
        "INT_MAX limit.",
        ShapeUtil::HumanString(shape), dimension);
  }

  return ShapeUtil::MakeShape(U32, {});
  return ShapeUtil::MakeShape(S32, {});
}

/* static */ StatusOr<Window> ShapeInference::InferWindowFromDimensions(

@@ -2324,7 +2331,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
    sizes.push_back((limit_index - start_index + stride - 1) / stride);
  }

  return ShapeUtil::MakeShape(arg.element_type(), sizes);
  std::vector<bool> is_dynamic(arg.rank());
  for (int64 i = 0; i < arg.dimensions_size(); ++i) {
    is_dynamic[i] = arg.is_dynamic_dimension(i);
  }

  return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic);
}

/* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
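Slice shape inference now carries the operand's per-dimension dynamic flags into the result instead of dropping them. Below is a minimal standalone sketch of how the sizes and flags are gathered together (plain containers instead of xla::Shape).

#include <cstdint>
#include <iostream>
#include <vector>

struct SliceDims {
  std::vector<int64_t> sizes;
  std::vector<bool> is_dynamic;
};

// Computes strided slice sizes and copies the operand's dynamic flags,
// mirroring the InferSliceShape change above.
SliceDims InferSliceDims(const std::vector<int64_t>& starts,
                         const std::vector<int64_t>& limits,
                         const std::vector<int64_t>& strides,
                         const std::vector<bool>& operand_is_dynamic) {
  SliceDims result;
  for (size_t i = 0; i < starts.size(); ++i) {
    result.sizes.push_back((limits[i] - starts[i] + strides[i] - 1) /
                           strides[i]);
    result.is_dynamic.push_back(operand_is_dynamic[i]);
  }
  return result;
}

int main() {
  SliceDims d = InferSliceDims({0}, {5}, {2}, {true});
  std::cout << d.sizes[0] << " dynamic=" << d.is_dynamic[0]
            << "\n";  // prints "3 dynamic=1"
  return 0;
}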
@@ -507,17 +507,23 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}

/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
  return Shape::Equal().IgnoreLayout()(lhs, rhs);
  return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs);
}

/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
                                                           const Shape& rhs) {
  return Shape::Equal().IgnoreElementType().IgnoreLayout()(lhs, rhs);
  return Shape::Equal()
      .IgnoreDynamicDimension()
      .IgnoreElementType()
      .IgnoreLayout()(lhs, rhs);
}

/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
                                                           const Shape& rhs) {
  return Shape::Equal().IgnoreFpPrecision().IgnoreLayout()(lhs, rhs);
  return Shape::Equal()
      .IgnoreDynamicDimension()
      .IgnoreFpPrecision()
      .IgnoreLayout()(lhs, rhs);
}

/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,

@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"

#include <numeric>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"

@@ -195,7 +196,7 @@ TEST(ShapeUtilTest, CompatibleDynamicShapes) {

  EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_a));
  EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_b));
  EXPECT_FALSE(ShapeUtil::Compatible(shape_a, shape_c));
  EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_c));
}

TEST(ShapeUtilTest, CompatibleTuples) {

@@ -737,7 +737,7 @@ def _pad_all_input(inputs, padded_shapes):
          padding_map.padding_arg_index = real_shape_idx
          padding_maps.append(padding_map)
          real_shapes[core_idx].append(
              math_ops.cast(input_shape_tensor[i], dtypes.uint32))
              math_ops.cast(input_shape_tensor[i], dtypes.int32))

      paddings = []
      for i, s in enumerate(padded_shape.dims):