diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 97359f81eee..d63b8146491 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -74,12 +74,44 @@ class UnsortedSegmentReduce : public XlaOpKernel { " vs. ", indices_shape.dim_size(d))); } xla::XlaBuilder* builder = ctx->builder(); + // data shape = [indices_shape, segment_shape] + // buffer shape = [num_segment, segment_shape] + // We now create the buffer shape by reverse engineering data shape into + // indices shape and segment shape. TensorShape buffer_shape = data_shape; buffer_shape.RemoveDimRange(0, indices_shape.dims()); buffer_shape.InsertDim(0, num_segments); + auto buffer = xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes()); + // Build dynamic dim sizes for buffer, as well as whether each dimension + // size is dynamic or static. We build two parts: num_segment part and + // segment_shape part. + std::vector<xla::XlaOp> buffer_dims; + std::vector<bool> buffer_dims_are_dynamic; + // Build the "num_segment" part. + bool num_segments_is_dynamic; + OP_REQUIRES_OK( + ctx, ctx->ResolveInputDynamismIntoPred(2, &num_segments_is_dynamic)); + + buffer_dims.insert(buffer_dims.begin(), ctx->Input(2)); + buffer_dims_are_dynamic.insert(buffer_dims_are_dynamic.begin(), + num_segments_is_dynamic); + // Build the segment shape part. + for (int64 i = indices_shape.dims(); i < data_shape.dims(); ++i) { + buffer_dims.push_back(xla::GetDimensionSize(data, i)); + buffer_dims_are_dynamic.push_back( + ctx->InputXlaShape(0)->is_dynamic_dimension(i)); + } + + for (int64 i = 0; i < buffer_dims.size(); ++i) { + if (buffer_dims_are_dynamic[i]) { + // For each dynamic dimension, call set-dimension-size on it.
+ buffer = xla::SetDimensionSize(buffer, buffer_dims[i], i); + } + } + auto combiner = [this](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) { return Combine(a, b); }; diff --git a/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc b/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc index f7c33e57fa0..fc15d71dfd8 100644 --- a/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc +++ b/tensorflow/core/tpu/kernels/xla/segment_reduction_ops.cc @@ -116,12 +116,44 @@ class UnsortedSegmentSum : public XlaOpKernel { indices_shape.dim_size(d))); } xla::XlaBuilder* builder = ctx->builder(); + // data shape = [indices_shape, segment_shape] + // buffer shape = [num_segment, segment_shape] + // We now create the buffer shape by reverse engineering data shape into + // indices shape and segment shape. TensorShape buffer_shape = data_shape; buffer_shape.RemoveDimRange(0, indices_shape.dims()); buffer_shape.InsertDim(0, num_segments); + auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype_), buffer_shape.dim_sizes()); + // Build dynamic dim sizes for buffer, as well as whether each dimension + // size is dynamic or static. We build two parts: num_segment part and + // segment_shape part. + std::vector<xla::XlaOp> buffer_dims; + std::vector<bool> buffer_dims_are_dynamic; + // Build the "num_segment" part. + bool num_segments_is_dynamic; + OP_REQUIRES_OK( + ctx, ctx->ResolveInputDynamismIntoPred(2, &num_segments_is_dynamic)); + + buffer_dims.insert(buffer_dims.begin(), ctx->Input(2)); + buffer_dims_are_dynamic.insert(buffer_dims_are_dynamic.begin(), + num_segments_is_dynamic); + // Build the segment shape part.
+ for (int64 i = indices_shape.dims(); i < data_shape.dims(); ++i) { + buffer_dims.push_back(xla::GetDimensionSize(data, i)); + buffer_dims_are_dynamic.push_back( + ctx->InputXlaShape(0)->is_dynamic_dimension(i)); + } + + for (int64 i = 0; i < buffer_dims.size(); ++i) { + if (buffer_dims_are_dynamic[i]) { + // For each dynamic dimension, call set-dimension-size on it. + buffer = xla::SetDimensionSize(buffer, buffer_dims[i], i); + } + } + auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) { return a + b; }; diff --git a/tensorflow/python/distribute/custom_training_loop_input_test.py b/tensorflow/python/distribute/custom_training_loop_input_test.py index 3103d73df6f..a835f5e5ac9 100644 --- a/tensorflow/python/distribute/custom_training_loop_input_test.py +++ b/tensorflow/python/distribute/custom_training_loop_input_test.py @@ -632,6 +632,34 @@ class InputIterationTest(test.TestCase, parameterized.TestCase, # This assumes that there are exactly 2 replicas self.assertAllEqual([2, 1], run(next(input_iterator))) + @combinations.generate( + combinations.combine( + distribution=strategy_combinations.multidevice_strategies, + mode=["eager"])) + def testSegmentSumWithDynamicNumberOfSegments(self, distribution): + + def dataset_fn(_): + data = array_ops.zeros(5, dtype=dtypes.int32) + dataset = get_dataset_from_tensor_slices(data) + dataset = dataset.batch(3) + return dataset + + input_iterator = iter( + distribution.experimental_distribute_datasets_from_function(dataset_fn)) + + @def_function.function + def step_fn(example): + segment_ids = array_ops.zeros_like_v2(example) + num_segment = array_ops.shape(example)[0] + # If number of segments is dynamic, output should be a dynamic shape. 
+ return math_ops.unsorted_segment_sum(example, segment_ids, num_segment) + + # This assumes that there are exactly 2 replicas + outputs = distribution.experimental_local_results( + distribution.run(step_fn, args=(next(input_iterator),))) + self.assertAllEqual((3,), outputs[0].shape) + self.assertAllEqual((2,), outputs[1].shape) + @combinations.generate( combinations.combine( distribution=strategy_combinations.multidevice_strategies,