[TF2XLA] Preserve the dynamic dimension (-1) when building a reshape.

This CL handles the case of a [4] -> [2, 2] reshape, where the 4 is a dynamic dimension.

- We need the producer of reshape's second operand (the target shape) to return "-1" to indicate that a dimension is dynamic.
- Change GetDimensionSize's return type to S32 so that '-1' can be represented.
- Constant folding is changed to return -1 for a dynamic dimension when a special flag is set on the op kernel.
- Resurrect the "dynamic dimension inference" feature in the XLA builder. Expect some breakage, as this feature is not heavily exercised. A minimal sketch of the resulting kernel-side flow follows below.
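
To make the intent concrete, here is a minimal, standalone sketch (a hypothetical helper, not code from this CL) of how the TF2XLA reshape kernel is expected to pick the dynamic output dimension once the shape operand reports dynamic entries as -1; it mirrors the ReshapeOp change in the diff below.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical standalone helper mirroring the ReshapeOp logic: among shape
// entries reported as -1, prefer a dimension that is not the statically known
// unknown_index; returns -1 when no dimension is dynamic.
int PickDynamicDimension(const std::vector<int32_t>& shape_input,
                         int unknown_index) {
  int dynamic_dimension = -1;
  for (int d = 0; d < static_cast<int>(shape_input.size()); ++d) {
    if (shape_input[d] != -1) continue;
    if (dynamic_dimension == -1 || d != unknown_index) {
      dynamic_dimension = d;
    }
  }
  return dynamic_dimension;
}

int main() {
  // [4] -> [2, 2] where the 4 is dynamic: the second ConstantInputAsIntVector
  // pass reports the dynamic output dimension as -1.
  std::vector<int32_t> shape_input = {-1, 2};
  std::cout << PickDynamicDimension(shape_input, /*unknown_index=*/-1) << "\n";
  // Prints 0; this index is then passed to xla::ReshapeWithInferredDimension.
  return 0;
}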

PiperOrigin-RevId: 267465833
Yunxing Dai 2019-09-05 15:14:07 -07:00 committed by TensorFlower Gardener
parent d6eab240b9
commit 792abd2eaf
23 changed files with 150 additions and 83 deletions

View File

@ -106,13 +106,34 @@ class ReshapeOp : public XlaOpKernel {
" values, but the requested shape has ",
shape.num_elements()));
VLOG(1) << "Reshape from " << input_shape.DebugString() << " to "
VLOG(2) << "Reshape from " << input_shape.DebugString() << " to "
<< shape.DebugString() << ", unknown_index=" << unknown_index;
shape_input.clear();
// Run get input again, this time with dynamic dimension represented as
// "-1"
ctx->set_dynamic_dimension_is_minus_one(true);
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input));
int dynamic_dimension = -1;
for (int d = 0; d < num_dims; ++d) {
const int32 size = shape_input[d];
if (size == -1) {
if (dynamic_dimension == -1) {
dynamic_dimension = d;
} else {
if (unknown_index != d) {
dynamic_dimension = d;
}
}
}
}
// Pass unknown_index to Xla::Reshape as a hint for dynamic shape inference
// in XLA to know which output dimension is dynamic.
ctx->SetOutput(0, xla::ReshapeWithInferredDimension(
ctx->Input(0), shape.dim_sizes(), unknown_index));
ctx->Input(0), shape.dim_sizes(), dynamic_dimension));
}
};

View File

@ -119,7 +119,9 @@ class SizeOp : public XlaOpKernel {
xla::XlaBuilder* builder = ctx->builder();
auto size = xla::One(builder, xla::U32);
for (int64 i = 0; i < rank; ++i) {
size = xla::Mul(size, xla::GetDimensionSize(ctx->Input(0), i));
size = xla::Mul(
size, xla::ConvertElementType(xla::GetDimensionSize(ctx->Input(0), i),
xla::U32));
}
size = xla::ConvertElementType(size, ctx->output_xla_type(0));
ctx->SetOutput(0, size);

View File

@ -173,6 +173,7 @@ Status BuildComputation(
xla::OpMetadata retval_metadata;
retval_metadata.set_op_name("XLA_Retvals");
builder->SetOpMetadata(retval_metadata);
VLOG(1) << "Building new computation";
auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });
// Builds a no-op XLA computation. We need to set the sharding of outputs, but
@ -915,6 +916,9 @@ Status XlaCompiler::BuildArguments(
const XlaCompiler::Argument& arg = args[input_to_args->at(i)];
for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) {
int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second);
VLOG(1) << "Setting dynamic binding " << i << " -> "
<< dynamic_size_param_index;
TF_RETURN_IF_ERROR(builder->SetDynamicBinding(
/*dynamic_size_param_num=*/0, {dynamic_size_param_index},
/*target_param_num=*/0, /*target_param_index=*/{i},
@ -1170,7 +1174,7 @@ Status XlaCompiler::CompileGraph(
std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
CompilationResult* result) {
VLOG(1) << "Executing graph symbolically to populate XlaBuilder.";
VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name;
TF_RETURN_IF_ERROR(PropagateConstIntoFunctionalNodes(
graph.get(), options_.flib_def, local_flib_def_.get()));

View File

@ -102,7 +102,7 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
}
xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
xla::Client* client) const {
xla::Client* client, bool dynamic_dimension_is_minus_one) const {
switch (kind()) {
case Kind::kConstant:
return {constant_value()};
@ -122,7 +122,8 @@ xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
if (!is_constant) return {absl::nullopt};
TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
handle().builder()->BuildConstantSubGraph(handle()));
handle().builder()->BuildConstantSubGraph(
handle(), dynamic_dimension_is_minus_one));
TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());

View File

@ -97,7 +97,7 @@ class XlaExpression {
// optional if it cannot be resolved. Returns an error if passed a resource
// expression.
xla::StatusOr<absl::optional<Tensor>> ResolveConstant(
xla::Client* client) const;
xla::Client* client, bool dynamic_dimension_is_minus_one = false) const;
// Returns the shape of the tensor.
// The shape of a resource is the shape of a resource handle (i.e., a scalar),

View File

@ -31,7 +31,7 @@ limitations under the License.
namespace tensorflow {
XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
: context_(context) {}
: context_(context), dynamic_dimension_is_minus_one_(false) {}
bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
return context_->ValidateInputsAreSameShape(op);
@ -166,7 +166,7 @@ Status XlaOpKernelContext::ConstantInputReshaped(
xla::Literal* constant_literal) {
XlaExpression e = InputExpression(index);
xla::StatusOr<absl::optional<Tensor>> constant_or_status =
e.ResolveConstant(compiler()->client());
e.ResolveConstant(compiler()->client(), dynamic_dimension_is_minus_one_);
if (!constant_or_status.ok()) {
Status status = constant_or_status.status();
errors::AppendToMessage(&status, "while evaluating input ", index, " of ",

View File

@ -202,6 +202,17 @@ class XlaOpKernelContext {
Status GetVariableTypeAndShape(int index, DataType* type,
TensorShape* shape) const;
// When dynamic_dimension_is_minus_one is set, querying a dynamic dimension
// returns "-1", this is useful when the underlying ops expect explicit
// dynamic index like reshape.
void set_dynamic_dimension_is_minus_one(bool value) {
dynamic_dimension_is_minus_one_ = value;
}
bool dynamic_dimension_is_minus_one() const {
return dynamic_dimension_is_minus_one_;
}
// Reads the current value of the resouce variable referred to by input
// `index`. If `shape` is not nullptr, sets `*shape` to the shape of the
// variable. Returns an error if the variable has not been initialized, or if
@ -280,6 +291,7 @@ class XlaOpKernelContext {
xla::Literal* constant_literal);
OpKernelContext* const context_;
bool dynamic_dimension_is_minus_one_;
};
} // namespace tensorflow

View File

@ -213,16 +213,10 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
// TODO(b/32495713): We aren't checking the called computations.
break;
case HloOpcode::kGetDimensionSize: {
int64 dimension_number = instr.dimensions(0);
const HloInstructionProto& operand =
*(LookUpInstructionByHandle(instr.operand_ids(0)).ValueOrDie());
Shape operand_shape(operand.shape());
if (operand_shape.is_dynamic_dimension(dimension_number)) {
*is_constant = false;
}
// DimensionSize is always considered constant in XLA -- If a dynamic
// dimension is presented, uint_max is returned.
break;
}
// Non functional ops.
case HloOpcode::kRng:
case HloOpcode::kAllReduce:
@ -268,8 +262,8 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num,
for (int64 index : target_param_index) {
param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(index);
}
// TODO(b/121223198): Set `is_dynamic` to the parameter shape when XLA
// backend can handle dynamic dimensions.
param_shape_ptr->set_dynamic_dimension(target_dim_num,
/*is_dynamic=*/true);
*instr.mutable_shape() = param_shape.ToProto();
}
}
@ -435,6 +429,7 @@ StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
for (int64 dim : broadcast_dimensions) {
instr.add_dimensions(dim);
}
return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand});
}
@ -468,11 +463,21 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
<< operand_shape << "; output_shape: " << output_shape;
}
}
Shape reshaped_shape =
ShapeUtil::MakeShape(operand_shape.element_type(), reshaped_dimensions);
std::vector<std::pair<int64, int64>> unmodified_dims =
ShapeUtil::DimensionsUnmodifiedByReshape(operand_shape, reshaped_shape);
for (auto& unmodified : unmodified_dims) {
if (operand_shape.is_dynamic_dimension(unmodified.first)) {
reshaped_shape.set_dynamic_dimension(unmodified.second, true);
}
}
// Eliminate the size one dimensions.
TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand,
Reshape(ShapeUtil::MakeShape(operand_shape.element_type(),
reshaped_dimensions),
operand));
TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, Reshape(reshaped_shape, operand));
// Broadcast 'reshape' up to the larger size.
return InDimBroadcast(broadcast_shape, reshaped_operand,
broadcast_dimensions);
@ -2428,7 +2433,7 @@ StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand) const {
}
StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
const XlaOp& root_op) {
XlaOp root_op, bool dynamic_dimension_is_minus_one) {
TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
if (!is_constant) {
auto op_status = LookUpInstruction(root_op);
@ -2483,9 +2488,12 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
LookUpInstructionByHandle(operand_handle));
TF_RET_CHECK(!operand_proto->shape().is_dynamic_dimension(dimension));
auto constant_dimension_size =
static_cast<uint32>(operand_proto->shape().dimensions(dimension));
int32 constant_dimension_size = -1;
if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
dynamic_dimension_is_minus_one)) {
constant_dimension_size =
static_cast<int32>(operand_proto->shape().dimensions(dimension));
}
Literal literal = LiteralUtil::CreateR0(constant_dimension_size);
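
For reference, a standalone sketch of the folding rule this hunk implements (a hypothetical helper, not part of the CL): a get-dimension-size over a dynamic dimension folds to the static bound unless the caller asked for dynamic dimensions to be reported as -1.

#include <cstdint>
#include <iostream>

// Hypothetical model of the decision made in BuildConstantSubGraph above.
int32_t FoldDimensionSize(int64_t static_size, bool is_dynamic,
                          bool dynamic_dimension_is_minus_one) {
  if (is_dynamic && dynamic_dimension_is_minus_one) {
    return -1;  // Reported back to kernels such as ReshapeOp.
  }
  return static_cast<int32_t>(static_size);
}

int main() {
  std::cout << FoldDimensionSize(4, /*is_dynamic=*/true, /*minus_one=*/false)
            << "\n";  // 4: the static bound is folded.
  std::cout << FoldDimensionSize(4, /*is_dynamic=*/true, /*minus_one=*/true)
            << "\n";  // -1: the dimension is reported as dynamic.
  return 0;
}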

View File

@ -258,7 +258,8 @@ class XlaBuilder {
// compile-time constant (see `IsConstant`), returns an error.
//
// This will copy the needed ops/computations to the subgraph.
StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op);
StatusOr<XlaComputation> BuildConstantSubGraph(
XlaOp root_op, bool dynamic_dimension_is_uint_max = false);
// Returns the first error that was encountered while building the
// computation. When an error is encountered, by default we return a vacuous

View File

@ -917,10 +917,7 @@ TEST_F(XlaBuilderTest, DynamicSelectNotCompatible) {
auto gte1 = GetTupleElement(p0, 1); // f32[4,5,<=6]
Select(pred, gte0, gte1);
Status status = BuildHloModule(&b).status();
ASSERT_IS_NOT_OK(status);
EXPECT_THAT(status.error_message(),
::testing::HasSubstr("Operands to select must be the same shape; "
"got f32[4,<=5,6] and f32[4,5,<=6]"));
ASSERT_IS_OK(status);
}
TEST_F(XlaBuilderTest, DynamicTranspose) {

View File

@ -105,8 +105,8 @@ class BatchNormExpanderVisitor : public DfsHloRewriteVisitor {
HloInstruction* operand, int64 feature_index,
const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
add_instruction) {
auto elements_per_feature_u32 = add_instruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(1)));
auto elements_per_feature_s32 = add_instruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
for (int64 i = 0; i < operand->shape().rank(); ++i) {
if (i == feature_index) {
@ -114,15 +114,15 @@ class BatchNormExpanderVisitor : public DfsHloRewriteVisitor {
}
auto dynamic_dimension_size =
add_instruction(HloInstruction::CreateGetDimensionSize(
ShapeUtil::MakeShape(U32, {}), operand, i));
elements_per_feature_u32 = add_instruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(U32, {}), HloOpcode::kMultiply,
dynamic_dimension_size, elements_per_feature_u32));
ShapeUtil::MakeShape(S32, {}), operand, i));
elements_per_feature_s32 = add_instruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply,
dynamic_dimension_size, elements_per_feature_s32));
}
return HloInstruction::CreateConvert(
ShapeUtil::MakeShape(operand->shape().element_type(), {}),
elements_per_feature_u32);
elements_per_feature_s32);
}
// Current HloComputation instance the BatchNormExpander is
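
As context for why the accumulator type matters here, the expander is building the element count per feature; a standalone sketch of that arithmetic (a hypothetical helper, not part of the CL):

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the quantity BatchNormExpander builds above: the product of all
// dimension sizes except the feature dimension, accumulated in signed 32-bit.
int32_t ElementsPerFeature(const std::vector<int64_t>& dims,
                           int64_t feature_index) {
  int32_t elements = 1;
  for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
    if (i == feature_index) continue;
    elements *= static_cast<int32_t>(dims[i]);
  }
  return elements;
}

int main() {
  // f32[8,4,4,16] with feature_index=3 -> 8 * 4 * 4 = 128 elements per feature.
  std::cout << ElementsPerFeature({8, 4, 4, 16}, /*feature_index=*/3) << "\n";
  return 0;
}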

View File

@ -463,7 +463,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
reshape->shape().dimensions(0) / operand->shape().dimensions(0);
HloInstruction* multiplier_hlo =
hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<uint32>(multiplier)));
LiteralUtil::CreateR0<int32>(multiplier)));
HloInstruction* new_dynamic_size =
hlo->parent()->AddInstruction(HloInstruction::CreateBinary(
@ -638,7 +638,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) {
reshape->shape().dimensions(dynamic_dimension);
HloInstruction* divisor_hlo =
hlo->parent()->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<uint32>(divisor)));
LiteralUtil::CreateR0<int32>(divisor)));
HloInstruction* new_dynamic_size =
hlo->parent()->AddInstruction(HloInstruction::CreateBinary(

View File

@ -94,7 +94,7 @@ class DynamicDimensionInferenceTest : public HloTestBase {
std::unique_ptr<HloModule> module_;
std::unique_ptr<DynamicDimensionInference> inference_;
const Shape scalar_shape_ = ShapeUtil::MakeShape(U32, {});
const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {});
};
TEST_F(DynamicDimensionInferenceTest, ParamTest) {
@ -557,7 +557,7 @@ TEST_F(DynamicDimensionInferenceTest, ReshapeTestMajorDimension) {
EXPECT_NE(inference_->GetDynamicSize(reshape, {}, 0), nullptr);
const Literal& multiplier =
inference_->GetDynamicSize(reshape, {}, 0)->operand(1)->literal();
LiteralTestUtil::ExpectR0Equal<uint32>(10, multiplier);
LiteralTestUtil::ExpectR0Equal<int32>(10, multiplier);
}
TEST_F(DynamicDimensionInferenceTest, GatherTest) {
@ -895,7 +895,7 @@ TEST_F(DynamicDimensionInferenceTest, DynamicSliceTest) {
std::vector<HloInstruction*> params;
for (int i = 0; i < 2; ++i) {
params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
i + 2, ShapeUtil::MakeShape(U32, {}), "slice_indices")));
i + 2, ShapeUtil::MakeShape(S32, {}), "slice_indices")));
}
auto* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
@ -997,7 +997,7 @@ TEST_F(DynamicDimensionInferenceTest, DynamicSliceSingleElementTest) {
std::vector<HloInstruction*> params;
for (int i = 0; i < 2; ++i) {
params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
i + 2, ShapeUtil::MakeShape(U32, {}), "slice_indices")));
i + 2, ShapeUtil::MakeShape(S32, {}), "slice_indices")));
}
auto* slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(

View File

@ -164,7 +164,7 @@ StatusOr<bool> DynamicPadder::Run(HloModule* module) {
// mask and pad value.
//
const Shape mask_shape =
ShapeUtil::ChangeElementType(operand->shape(), xla::U32);
ShapeUtil::ChangeElementType(operand->shape(), xla::S32);
const Shape pred_shape =
ShapeUtil::ChangeElementType(operand->shape(), xla::PRED);
HloInstruction* iota = computation->AddInstruction(

View File

@ -65,7 +65,7 @@ class DynamicPadderTest : public HloTestBase {
}
std::unique_ptr<HloModule> module_;
const Shape scalar_shape_ = ShapeUtil::MakeShape(U32, {});
const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {});
};
TEST_F(DynamicPadderTest, ReduceTest) {

View File

@ -410,9 +410,9 @@ Status HloEvaluator::HandleGetDimensionSize(
}
const Shape& shape = get_dimension_size->operand(0)->shape();
Literal output(ShapeUtil::MakeShape(U32, {}));
Literal output(ShapeUtil::MakeShape(S32, {}));
output.PopulateWithValue(
static_cast<uint32>(shape.dimensions(get_dimension_size->dimension())));
static_cast<int32>(shape.dimensions(get_dimension_size->dimension())));
evaluated_[get_dimension_size] = std::move(output);
return Status::OK();
}

View File

@ -4154,13 +4154,13 @@ TEST_F(HloEvaluatorTest, GetDimensionSize) {
HloModule Test
ENTRY main {
size = u32[] parameter(0)
size = s32[] parameter(0)
data = s32[4] parameter(1)
sum = s32[4] add(data, data)
ROOT dynamic_size = u32[] get-dimension-size(sum), dimensions={0}
ROOT dynamic_size = s32[] get-dimension-size(sum), dimensions={0}
}
)";
TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
@ -4174,12 +4174,12 @@ ENTRY main {
DynamicDimensionInference::Run(m_.get()));
evaluator_.set_dynamic_dimension_inference(&dynamic_dimension_inference);
Literal size_arg = LiteralUtil::CreateR0<uint32>(3);
Literal size_arg = LiteralUtil::CreateR0<int32>(3);
Literal data_arg = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&size_arg, &data_arg}));
EXPECT_EQ(actual.GetFirstElement<uint32>(), static_cast<uint32>(3));
EXPECT_EQ(actual.GetFirstElement<int32>(), static_cast<int32>(3));
}
// Check that we get a useful error if we pass inputs of the wrong shape.

View File

@ -37,8 +37,10 @@ StatusOr<bool> ReplaceGetSize(
TF_ASSIGN_OR_RETURN(auto legal_shape,
ShapeInference::InferGetDimensionSizeShape(
instr->operand(0)->shape(), instr->dimension()));
TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape));
TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), U32));
TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape))
<< "instr->shape() " << instr->shape().ToString() << " , "
<< "legal_shape " << legal_shape.ToString();
TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32));
HloInstruction* operand = instr->mutable_operand(0);
int64 dim = instr->dimension();
HloInstruction* dynamic_size =
@ -46,9 +48,9 @@ StatusOr<bool> ReplaceGetSize(
if (dynamic_size != nullptr) {
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size));
} else {
uint32 size = instr->operand(0)->shape().dimensions(dim);
int32 size = instr->operand(0)->shape().dimensions(dim);
HloInstruction* new_instr = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(size)));
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(size)));
TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr));
}
return true;

View File

@ -44,9 +44,9 @@ TEST_F(HloGetDimensionSizeRewriterTest, Ok) {
HloModule _
ENTRY gds {
p = s32[3,4] parameter(0)
size0 = u32[] get-dimension-size(p), dimensions={0}
size1 = u32[] get-dimension-size(p), dimensions={1}
ROOT mul = u32[] multiply(size0, size1)
size0 = s32[] get-dimension-size(p), dimensions={0}
size1 = s32[] get-dimension-size(p), dimensions={1}
ROOT mul = s32[] multiply(size0, size1)
})")
.ValueOrDie();
HloGetDimensionSizeRewriter pass;
@ -72,7 +72,7 @@ TEST_F(HloGetDimensionSizeRewriterTest, IllegalDimension) {
HloModule _
ENTRY gds {
p = f32[2,5] parameter(0)
ROOT gds = u32[] get-dimension-size(p), dimensions={2}
ROOT gds = s32[] get-dimension-size(p), dimensions={2}
})")
.ValueOrDie();
HloGetDimensionSizeRewriter pass;

View File

@ -652,9 +652,7 @@ Status ValidateDotDimensionNumbers(
const int64 rhs_contracting_dimension =
dimension_numbers.rhs_contracting_dimensions(i);
if (lhs.dimensions(lhs_contracting_dimension) !=
rhs.dimensions(rhs_contracting_dimension) ||
lhs.is_dynamic_dimension(lhs_contracting_dimension) !=
rhs.is_dynamic_dimension(rhs_contracting_dimension)) {
rhs.dimensions(rhs_contracting_dimension)) {
return fail("Contracting dimension sizes do not match.");
}
}
@ -668,10 +666,7 @@ Status ValidateDotDimensionNumbers(
// Check that batch dimension numbers and sizes match.
for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
if (lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i)) ||
lhs.is_dynamic_dimension(dimension_numbers.lhs_batch_dimensions(i)) !=
rhs.is_dynamic_dimension(
dimension_numbers.rhs_batch_dimensions(i))) {
rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
return fail("Batch dimension sizes must match for lhs/rhs.");
}
}
@ -726,13 +721,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
for (int64 i = 0; i < lhs.rank(); ++i) {
if (lhs.dimensions(i) == rhs.dimensions(i)) {
output_dimensions[i] = lhs.dimensions(i);
output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i);
} else if (lhs.dimensions(i) == 1) {
output_dimensions[i] = rhs.dimensions(i);
output_dimensions_is_dynamic[i] = rhs.is_dynamic_dimension(i);
} else if (rhs.dimensions(i) == 1) {
output_dimensions[i] = lhs.dimensions(i);
output_dimensions_is_dynamic[i] = lhs.is_dynamic_dimension(i);
} else {
return InvalidArgument(
"Binary op %s with incompatible shapes: %s and %s.",
@ -740,6 +732,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::HumanString(rhs));
}
}
// Merge dynamic dimensions from two shapes.
for (int64 i = 0; i < rhs.rank(); ++i) {
if (rhs.is_dynamic_dimension(i) || lhs.is_dynamic_dimension(i)) {
output_dimensions_is_dynamic[i] = true;
}
}
return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
output_dimensions, output_dimensions_is_dynamic);
}
@ -888,11 +888,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
// If the shapes are the same other than layout, the output shape is the
// same (elementwise op).
return ShapeUtil::ChangeElementType(
Shape result = ShapeUtil::ChangeElementType(
lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
}
if (lhs.rank() == rhs.rank()) {
for (int64 i = 0; i < rhs.rank(); ++i) {
if (rhs.is_dynamic_dimension(i)) {
result.set_dynamic_dimension(i, true);
}
}
return result;
} else if (lhs.rank() == rhs.rank()) {
return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
} else {
// Ranks do not match, so perform InDim broadcasting using
@ -2201,14 +2208,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
// TODO(b/119580730): Remove this restriction when very large dimension size
// is needed.
if (shape.dimensions(dimension) > std::numeric_limits<uint32>::max()) {
if (shape.dimensions(dimension) > std::numeric_limits<int32>::max()) {
return InvalidArgument(
"GetDimensionSize's input shape is %s, the %dth dimension exceeds the "
"UINT_MAX limit.",
"INT_MAX limit.",
ShapeUtil::HumanString(shape), dimension);
}
return ShapeUtil::MakeShape(U32, {});
return ShapeUtil::MakeShape(S32, {});
}
/* static */ StatusOr<Window> ShapeInference::InferWindowFromDimensions(
@ -2324,7 +2331,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
sizes.push_back((limit_index - start_index + stride - 1) / stride);
}
return ShapeUtil::MakeShape(arg.element_type(), sizes);
std::vector<bool> is_dynamic(arg.rank());
for (int64 i = 0; i < arg.dimensions_size(); ++i) {
is_dynamic[i] = arg.is_dynamic_dimension(i);
}
return ShapeUtil::MakeShape(arg.element_type(), sizes, is_dynamic);
}
/* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
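
The shape-inference tweaks above follow a simple rule: for elementwise ops a result dimension is dynamic if it is dynamic in either operand, and a slice keeps the dynamic flags of its operand. A standalone sketch of the merge rule (a hypothetical helper, not part of the CL):

#include <iostream>
#include <vector>

// An output dimension is dynamic if it is dynamic in either input shape.
std::vector<bool> MergeDynamicDims(const std::vector<bool>& lhs,
                                   const std::vector<bool>& rhs) {
  std::vector<bool> out(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) out[i] = lhs[i] || rhs[i];
  return out;
}

int main() {
  // f32[2,<=3] op f32[2,3] -> f32[2,<=3]: dimension 1 stays dynamic.
  std::vector<bool> merged = MergeDynamicDims({false, true}, {false, false});
  for (bool d : merged) std::cout << d << " ";  // prints "0 1"
  std::cout << "\n";
  return 0;
}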

View File

@ -507,17 +507,23 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
return Shape::Equal().IgnoreLayout()(lhs, rhs);
return Shape::Equal().IgnoreDynamicDimension().IgnoreLayout()(lhs, rhs);
}
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
const Shape& rhs) {
return Shape::Equal().IgnoreElementType().IgnoreLayout()(lhs, rhs);
return Shape::Equal()
.IgnoreDynamicDimension()
.IgnoreElementType()
.IgnoreLayout()(lhs, rhs);
}
/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
const Shape& rhs) {
return Shape::Equal().IgnoreFpPrecision().IgnoreLayout()(lhs, rhs);
return Shape::Equal()
.IgnoreDynamicDimension()
.IgnoreFpPrecision()
.IgnoreLayout()(lhs, rhs);
}
/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include <numeric>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
@ -195,7 +196,7 @@ TEST(ShapeUtilTest, CompatibleDynamicShapes) {
EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_a));
EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_b));
EXPECT_FALSE(ShapeUtil::Compatible(shape_a, shape_c));
EXPECT_TRUE(ShapeUtil::Compatible(shape_a, shape_c));
}
TEST(ShapeUtilTest, CompatibleTuples) {

View File

@ -737,7 +737,7 @@ def _pad_all_input(inputs, padded_shapes):
padding_map.padding_arg_index = real_shape_idx
padding_maps.append(padding_map)
real_shapes[core_idx].append(
math_ops.cast(input_shape_tensor[i], dtypes.uint32))
math_ops.cast(input_shape_tensor[i], dtypes.int32))
paddings = []
for i, s in enumerate(padded_shape.dims):