Fix segment_reduction to support dynamic dims correctly.
Previously it just ignored dynamic dimensions.

PiperOrigin-RevId: 327861140
Change-Id: Icfe9a6293cc28ca2b811b1810e790f4c62e1e4a3
parent e98f54f469
commit f8d80a78a3
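For context, a minimal sketch of the scenario this change targets (illustrative only, not part of the diff): num_segments is computed from a data-dependent shape inside a compiled function, so the XLA output must carry a dynamic leading dimension instead of being padded to a static upper bound. This assumes a TF 2.x build where tf.function accepts jit_compile=True (older releases spell it experimental_compile):

    import tensorflow as tf

    @tf.function(jit_compile=True)  # XLA-compiled step
    def dynamic_segment_sum(example):
      # All elements map to segment 0; num_segments equals the (possibly
      # partial) batch size, so it is only known at run time.
      segment_ids = tf.zeros_like(example)
      num_segments = tf.shape(example)[0]
      return tf.math.unsorted_segment_sum(example, segment_ids, num_segments)

    print(dynamic_segment_sum(tf.zeros([2], dtype=tf.int32)))  # shape (2,)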
@@ -74,12 +74,44 @@ class UnsortedSegmentReduce : public XlaOpKernel {
                              " vs. ", indices_shape.dim_size(d)));
    }
    xla::XlaBuilder* builder = ctx->builder();
    // data shape = [indices_shape, segment_shape]
    // buffer shape = [num_segment, segment_shape]
    // We now create the buffer shape by reverse engineering data shape into
    // indices shape and segment shape.
    TensorShape buffer_shape = data_shape;
    buffer_shape.RemoveDimRange(0, indices_shape.dims());
    buffer_shape.InsertDim(0, num_segments);

    auto buffer =
        xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes());

    // Build dynamic dim sizes for buffer, as well as whether each dimension
    // size is dynamic or static. We build two parts: num_segment part and
    // segment_shape part.
    std::vector<xla::XlaOp> buffer_dims;
    std::vector<bool> buffer_dims_are_dynamic;
    // Build the "num_segment" part.
    bool num_segments_is_dynamic;
    OP_REQUIRES_OK(
        ctx, ctx->ResolveInputDynamismIntoPred(2, &num_segments_is_dynamic));

    buffer_dims.insert(buffer_dims.begin(), ctx->Input(2));
    buffer_dims_are_dynamic.insert(buffer_dims_are_dynamic.begin(),
                                   num_segments_is_dynamic);
    // Build the segment shape part.
    for (int64 i = indices_shape.dims(); i < data_shape.dims(); ++i) {
      buffer_dims.push_back(xla::GetDimensionSize(data, i));
      buffer_dims_are_dynamic.push_back(
          ctx->InputXlaShape(0)->is_dynamic_dimension(i));
    }

    for (int64 i = 0; i < buffer_dims.size(); ++i) {
      if (buffer_dims_are_dynamic[i]) {
        // For each dynamic dimension, call set-dimension-size on it.
        buffer = xla::SetDimensionSize(buffer, buffer_dims[i], i);
      }
    }

    auto combiner = [this](xla::XlaOp a, xla::XlaOp b,
                           xla::XlaBuilder* builder) { return Combine(a, b); };
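The buffer-shape bookkeeping above, restated as a plain-Python sketch (hypothetical helper, purely illustrative): the data shape is the concatenation of the indices shape and the per-segment shape, and the buffer swaps the indices part for a single num_segments dimension.

    def buffer_shape(data_shape, indices_rank, num_segments):
      # data shape   = indices_shape + segment_shape
      # buffer shape = [num_segments] + segment_shape
      segment_shape = list(data_shape[indices_rank:])
      return [num_segments] + segment_shape

    # e.g. data [3, 4, 5] with rank-2 indices and 7 segments -> buffer [7, 5]
    assert buffer_shape([3, 4, 5], indices_rank=2, num_segments=7) == [7, 5]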
@@ -116,12 +116,44 @@ class UnsortedSegmentSum : public XlaOpKernel {
                                    indices_shape.dim_size(d)));
    }
    xla::XlaBuilder* builder = ctx->builder();
    // data shape = [indices_shape, segment_shape]
    // buffer shape = [num_segment, segment_shape]
    // We now create the buffer shape by reverse engineering data shape into
    // indices shape and segment shape.
    TensorShape buffer_shape = data_shape;
    buffer_shape.RemoveDimRange(0, indices_shape.dims());
    buffer_shape.InsertDim(0, num_segments);

    auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype_),
                                 buffer_shape.dim_sizes());

    // Build dynamic dim sizes for buffer, as well as whether each dimension
    // size is dynamic or static. We build two parts: num_segment part and
    // segment_shape part.
    std::vector<xla::XlaOp> buffer_dims;
    std::vector<bool> buffer_dims_are_dynamic;
    // Build the "num_segment" part.
    bool num_segments_is_dynamic;
    OP_REQUIRES_OK(
        ctx, ctx->ResolveInputDynamismIntoPred(2, &num_segments_is_dynamic));

    buffer_dims.insert(buffer_dims.begin(), ctx->Input(2));
    buffer_dims_are_dynamic.insert(buffer_dims_are_dynamic.begin(),
                                   num_segments_is_dynamic);
    // Build the segment shape part.
    for (int64 i = indices_shape.dims(); i < data_shape.dims(); ++i) {
      buffer_dims.push_back(xla::GetDimensionSize(data, i));
      buffer_dims_are_dynamic.push_back(
          ctx->InputXlaShape(0)->is_dynamic_dimension(i));
    }

    for (int64 i = 0; i < buffer_dims.size(); ++i) {
      if (buffer_dims_are_dynamic[i]) {
        // For each dynamic dimension, call set-dimension-size on it.
        buffer = xla::SetDimensionSize(buffer, buffer_dims[i], i);
      }
    }

    auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) {
      return a + b;
    };
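The same buffer construction is shared by the sibling reductions built on this kernel (e.g. the prod/min/max variants), so they pick up the dynamic-dimension handling as well. A quick eager check of one of them, with num_segments only known at run time (illustrative; values worked out by hand):

    import tensorflow as tf

    data = tf.constant([1, 2, 3, 4, 5])
    ids = tf.constant([0, 0, 1, 1, 1])
    num_segments = tf.reduce_max(ids) + 1  # computed at run time
    # segment 0: 1 * 2 = 2, segment 1: 3 * 4 * 5 = 60
    print(tf.math.unsorted_segment_prod(data, ids, num_segments))  # [ 2 60]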
@@ -632,6 +632,34 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([2, 1], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testSegmentSumWithDynamicNumberOfSegments(self, distribution):

    def dataset_fn(_):
      data = array_ops.zeros(5, dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(3)
      return dataset

    input_iterator = iter(
        distribution.experimental_distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def step_fn(example):
      segment_ids = array_ops.zeros_like_v2(example)
      num_segment = array_ops.shape(example)[0]
      # If number of segments is dynamic, output should be a dynamic shape.
      return math_ops.unsorted_segment_sum(example, segment_ids, num_segment)

    # This assumes that there are exactly 2 replicas
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((3,), outputs[0].shape)
    self.assertAllEqual((2,), outputs[1].shape)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
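As a usage note on the test above (illustrative eager run of a single replica step, not part of the diff): per the assertions, one replica sees a full batch of three and the other a partial batch of two, so num_segment is 2 on the second replica and the compiled output there must have shape (2,) rather than a padded (3,).

    import tensorflow as tf

    example = tf.zeros([2], dtype=tf.int32)   # the partial batch
    segment_ids = tf.zeros_like(example)
    num_segments = tf.shape(example)[0]       # 2 at run time
    out = tf.math.unsorted_segment_sum(example, segment_ids, num_segments)
    print(out.shape)  # (2,)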