Merge pull request #24282 from yongtang:15692-draw_bounding_boxes-colors
PiperOrigin-RevId: 242188454
This commit is contained in:
commit
e3eec7a4da
@ -0,0 +1,43 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "DrawBoundingBoxesV2"
|
||||||
|
in_arg {
|
||||||
|
name: "images"
|
||||||
|
description: <<END
|
||||||
|
4-D with shape `[batch, height, width, depth]`. A batch of images.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "boxes"
|
||||||
|
description: <<END
|
||||||
|
3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding
|
||||||
|
boxes.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
in_arg {
|
||||||
|
name: "colors"
|
||||||
|
description: <<END
|
||||||
|
2-D. A list of RGBA colors to cycle through for the boxes.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
out_arg {
|
||||||
|
name: "output"
|
||||||
|
description: <<END
|
||||||
|
4-D with the same shape as `images`. The batch of input images with
|
||||||
|
bounding boxes drawn on the images.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
summary: "Draw bounding boxes on a batch of images."
|
||||||
|
description: <<END
|
||||||
|
Outputs a copy of `images` but draws on top of the pixels zero or more bounding
|
||||||
|
boxes specified by the locations in `boxes`. The coordinates of the each
|
||||||
|
bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The
|
||||||
|
bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
|
||||||
|
height of the underlying image.
|
||||||
|
|
||||||
|
For example, if an image is 100 x 200 pixels (height x width) and the bounding
|
||||||
|
box is `[0.1, 0.2, 0.5, 0.9]`, the upper-left and bottom-right coordinates of
|
||||||
|
the bounding box will be `(40, 10)` to `(100, 50)` (in (x,y) coordinates).
|
||||||
|
|
||||||
|
Parts of the bounding box may fall outside the image.
|
||||||
|
END
|
||||||
|
}
|
@ -1,6 +1,4 @@
|
|||||||
op {
|
op {
|
||||||
graph_op_name: "DrawBoundingBoxes"
|
graph_op_name: "DrawBoundingBoxes"
|
||||||
endpoint {
|
visibility: HIDDEN
|
||||||
name: "image.draw_bounding_boxes"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "DrawBoundingBoxesV2"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -25,6 +25,30 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
std::vector<std::vector<float>> DefaultColorTable(int depth) {
|
||||||
|
std::vector<std::vector<float>> color_table;
|
||||||
|
color_table.emplace_back(std::vector<float>({1, 1, 0, 1})); // 0: yellow
|
||||||
|
color_table.emplace_back(std::vector<float>({0, 0, 1, 1})); // 1: blue
|
||||||
|
color_table.emplace_back(std::vector<float>({1, 0, 0, 1})); // 2: red
|
||||||
|
color_table.emplace_back(std::vector<float>({0, 1, 0, 1})); // 3: lime
|
||||||
|
color_table.emplace_back(std::vector<float>({0.5, 0, 0.5, 1})); // 4: purple
|
||||||
|
color_table.emplace_back(std::vector<float>({0.5, 0.5, 0, 1})); // 5: olive
|
||||||
|
color_table.emplace_back(std::vector<float>({0.5, 0, 0, 1})); // 6: maroon
|
||||||
|
color_table.emplace_back(std::vector<float>({0, 0, 0.5, 1})); // 7: navy blue
|
||||||
|
color_table.emplace_back(std::vector<float>({0, 1, 1, 1})); // 8: aqua
|
||||||
|
color_table.emplace_back(std::vector<float>({1, 0, 1, 1})); // 9: fuchsia
|
||||||
|
|
||||||
|
if (depth == 1) {
|
||||||
|
for (int64 i = 0; i < color_table.size(); i++) {
|
||||||
|
color_table[i][0] = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return color_table;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
class DrawBoundingBoxesOp : public OpKernel {
|
class DrawBoundingBoxesOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
@ -52,30 +76,31 @@ class DrawBoundingBoxesOp : public OpKernel {
|
|||||||
const int64 batch_size = images.dim_size(0);
|
const int64 batch_size = images.dim_size(0);
|
||||||
const int64 height = images.dim_size(1);
|
const int64 height = images.dim_size(1);
|
||||||
const int64 width = images.dim_size(2);
|
const int64 width = images.dim_size(2);
|
||||||
const int64 color_table_length = 10;
|
std::vector<std::vector<float>> color_table;
|
||||||
|
if (context->num_inputs() == 3) {
|
||||||
|
const Tensor& colors_tensor = context->input(2);
|
||||||
|
OP_REQUIRES(context, colors_tensor.shape().dims() == 2,
|
||||||
|
errors::InvalidArgument("colors must be a 2-D matrix",
|
||||||
|
colors_tensor.shape().DebugString()));
|
||||||
|
OP_REQUIRES(context, colors_tensor.shape().dim_size(1) >= depth,
|
||||||
|
errors::InvalidArgument("colors must have equal or more ",
|
||||||
|
"channels than the image provided: ",
|
||||||
|
colors_tensor.shape().DebugString()));
|
||||||
|
if (colors_tensor.NumElements() != 0) {
|
||||||
|
color_table.clear();
|
||||||
|
|
||||||
// 0: yellow
|
auto colors = colors_tensor.matrix<float>();
|
||||||
// 1: blue
|
for (int64 i = 0; i < colors.dimension(0); i++) {
|
||||||
// 2: red
|
std::vector<float> color_value(4);
|
||||||
// 3: lime
|
for (int64 j = 0; j < 4; j++) {
|
||||||
// 4: purple
|
color_value[j] = colors(i, j);
|
||||||
// 5: olive
|
|
||||||
// 6: maroon
|
|
||||||
// 7: navy blue
|
|
||||||
// 8: aqua
|
|
||||||
// 9: fuchsia
|
|
||||||
float color_table[color_table_length][4] = {
|
|
||||||
{1, 1, 0, 1}, {0, 0, 1, 1}, {1, 0, 0, 1}, {0, 1, 0, 1},
|
|
||||||
{0.5, 0, 0.5, 1}, {0.5, 0.5, 0, 1}, {0.5, 0, 0, 1}, {0, 0, 0.5, 1},
|
|
||||||
{0, 1, 1, 1}, {1, 0, 1, 1},
|
|
||||||
};
|
|
||||||
|
|
||||||
// Reset first color channel to 1 if image is GRY.
|
|
||||||
// For GRY images, this means all bounding boxes will be white.
|
|
||||||
if (depth == 1) {
|
|
||||||
for (int64 i = 0; i < color_table_length; i++) {
|
|
||||||
color_table[i][0] = 1;
|
|
||||||
}
|
}
|
||||||
|
color_table.emplace_back(color_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (color_table.empty()) {
|
||||||
|
color_table = DefaultColorTable(depth);
|
||||||
}
|
}
|
||||||
Tensor* output;
|
Tensor* output;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
@ -90,7 +115,7 @@ class DrawBoundingBoxesOp : public OpKernel {
|
|||||||
const int64 num_boxes = boxes.dim_size(1);
|
const int64 num_boxes = boxes.dim_size(1);
|
||||||
const auto tboxes = boxes.tensor<T, 3>();
|
const auto tboxes = boxes.tensor<T, 3>();
|
||||||
for (int64 bb = 0; bb < num_boxes; ++bb) {
|
for (int64 bb = 0; bb < num_boxes; ++bb) {
|
||||||
int64 color_index = bb % color_table_length;
|
int64 color_index = bb % color_table.size();
|
||||||
const int64 min_box_row =
|
const int64 min_box_row =
|
||||||
static_cast<float>(tboxes(b, bb, 0)) * (height - 1);
|
static_cast<float>(tboxes(b, bb, 0)) * (height - 1);
|
||||||
const int64 min_box_row_clamp = std::max<int64>(min_box_row, int64{0});
|
const int64 min_box_row_clamp = std::max<int64>(min_box_row, int64{0});
|
||||||
@ -179,6 +204,9 @@ class DrawBoundingBoxesOp : public OpKernel {
|
|||||||
#define REGISTER_CPU_KERNEL(T) \
|
#define REGISTER_CPU_KERNEL(T) \
|
||||||
REGISTER_KERNEL_BUILDER( \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
Name("DrawBoundingBoxes").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
Name("DrawBoundingBoxes").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||||
|
DrawBoundingBoxesOp<T>); \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("DrawBoundingBoxesV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||||
DrawBoundingBoxesOp<T>);
|
DrawBoundingBoxesOp<T>);
|
||||||
TF_CALL_half(REGISTER_CPU_KERNEL);
|
TF_CALL_half(REGISTER_CPU_KERNEL);
|
||||||
TF_CALL_float(REGISTER_CPU_KERNEL);
|
TF_CALL_float(REGISTER_CPU_KERNEL);
|
||||||
|
@ -601,6 +601,17 @@ REGISTER_OP("DrawBoundingBoxes")
|
|||||||
return shape_inference::UnchangedShape(c);
|
return shape_inference::UnchangedShape(c);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
REGISTER_OP("DrawBoundingBoxesV2")
|
||||||
|
.Input("images: T")
|
||||||
|
.Input("boxes: float")
|
||||||
|
.Input("colors: float")
|
||||||
|
.Output("output: T")
|
||||||
|
.Attr("T: {float, half} = DT_FLOAT")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
|
||||||
|
});
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
REGISTER_OP("SampleDistortedBoundingBox")
|
REGISTER_OP("SampleDistortedBoundingBox")
|
||||||
.Input("image_size: T")
|
.Input("image_size: T")
|
||||||
|
@ -54,15 +54,18 @@ class DrawBoundingBoxOpTest(test.TestCase):
|
|||||||
image[height - 1, 0:width, 0:depth] = color
|
image[height - 1, 0:width, 0:depth] = color
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def _testDrawBoundingBoxColorCycling(self, img):
|
def _testDrawBoundingBoxColorCycling(self, img, colors=None):
|
||||||
"""Tests if cycling works appropriately.
|
"""Tests if cycling works appropriately.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
img: 3-D numpy image on which to draw.
|
img: 3-D numpy image on which to draw.
|
||||||
"""
|
"""
|
||||||
|
color_table = colors
|
||||||
|
if colors is None:
|
||||||
# THIS TABLE MUST MATCH draw_bounding_box_op.cc
|
# THIS TABLE MUST MATCH draw_bounding_box_op.cc
|
||||||
color_table = np.asarray([[1, 1, 0, 1], [0, 0, 1, 1], [1, 0, 0, 1],
|
color_table = np.asarray([[1, 1, 0, 1], [0, 0, 1, 1], [1, 0, 0, 1],
|
||||||
[0, 1, 0, 1], [0.5, 0, 0.5, 1], [0.5, 0.5, 0, 1],
|
[0, 1, 0, 1], [0.5, 0, 0.5,
|
||||||
|
1], [0.5, 0.5, 0, 1],
|
||||||
[0.5, 0, 0, 1], [0, 0, 0.5, 1], [0, 1, 1, 1],
|
[0.5, 0, 0, 1], [0, 0, 0.5, 1], [0, 1, 1, 1],
|
||||||
[1, 0, 1, 1]])
|
[1, 0, 1, 1]])
|
||||||
assert len(img.shape) == 3
|
assert len(img.shape) == 3
|
||||||
@ -85,9 +88,9 @@ class DrawBoundingBoxOpTest(test.TestCase):
|
|||||||
image = ops.convert_to_tensor(image)
|
image = ops.convert_to_tensor(image)
|
||||||
image = image_ops_impl.convert_image_dtype(image, dtypes.float32)
|
image = image_ops_impl.convert_image_dtype(image, dtypes.float32)
|
||||||
image = array_ops.expand_dims(image, 0)
|
image = array_ops.expand_dims(image, 0)
|
||||||
image = image_ops.draw_bounding_boxes(image, bboxes)
|
image = image_ops.draw_bounding_boxes(image, bboxes, colors=colors)
|
||||||
with self.cached_session(use_gpu=False) as sess:
|
with self.cached_session(use_gpu=False) as sess:
|
||||||
op_drawn_image = np.squeeze(self.evaluate(image), 0)
|
op_drawn_image = np.squeeze(sess.run(image), 0)
|
||||||
self.assertAllEqual(test_drawn_image, op_drawn_image)
|
self.assertAllEqual(test_drawn_image, op_drawn_image)
|
||||||
|
|
||||||
def testDrawBoundingBoxRGBColorCycling(self):
|
def testDrawBoundingBoxRGBColorCycling(self):
|
||||||
@ -105,6 +108,20 @@ class DrawBoundingBoxOpTest(test.TestCase):
|
|||||||
image = np.zeros([4, 4, 1], "float32")
|
image = np.zeros([4, 4, 1], "float32")
|
||||||
self._testDrawBoundingBoxColorCycling(image)
|
self._testDrawBoundingBoxColorCycling(image)
|
||||||
|
|
||||||
|
def testDrawBoundingBoxRGBColorCyclingWithColors(self):
|
||||||
|
"""Test if RGB color cycling works correctly with provided colors."""
|
||||||
|
image = np.zeros([10, 10, 3], "float32")
|
||||||
|
colors = np.asarray([[1, 1, 0, 1], [0, 0, 1, 1], [0.5, 0, 0.5, 1],
|
||||||
|
[0.5, 0.5, 0, 1], [0, 1, 1, 1], [1, 0, 1, 1]])
|
||||||
|
self._testDrawBoundingBoxColorCycling(image, colors=colors)
|
||||||
|
|
||||||
|
def testDrawBoundingBoxRGBAColorCyclingWithColors(self):
|
||||||
|
"""Test if RGBA color cycling works correctly with provided colors."""
|
||||||
|
image = np.zeros([10, 10, 4], "float32")
|
||||||
|
colors = np.asarray([[0.5, 0, 0.5, 1], [0.5, 0.5, 0, 1], [0.5, 0, 0, 1],
|
||||||
|
[0, 0, 0.5, 1]])
|
||||||
|
self._testDrawBoundingBoxColorCycling(image, colors=colors)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -3355,7 +3355,6 @@ def crop_and_resize_v1( # pylint: disable=missing-docstring
|
|||||||
|
|
||||||
crop_and_resize_v1.__doc__ = gen_image_ops.crop_and_resize.__doc__
|
crop_and_resize_v1.__doc__ = gen_image_ops.crop_and_resize.__doc__
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=['image.extract_glimpse'])
|
@tf_export(v1=['image.extract_glimpse'])
|
||||||
def extract_glimpse(
|
def extract_glimpse(
|
||||||
input, # pylint: disable=redefined-builtin
|
input, # pylint: disable=redefined-builtin
|
||||||
@ -3555,3 +3554,65 @@ def combined_non_max_suppression(boxes,
|
|||||||
return gen_image_ops.combined_non_max_suppression(
|
return gen_image_ops.combined_non_max_suppression(
|
||||||
boxes, scores, max_output_size_per_class, max_total_size, iou_threshold,
|
boxes, scores, max_output_size_per_class, max_total_size, iou_threshold,
|
||||||
score_threshold, pad_per_class)
|
score_threshold, pad_per_class)
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('image.draw_bounding_boxes', v1=[])
|
||||||
|
def draw_bounding_boxes_v2(images, boxes, colors, name=None):
|
||||||
|
"""Draw bounding boxes on a batch of images.
|
||||||
|
|
||||||
|
Outputs a copy of `images` but draws on top of the pixels zero or more
|
||||||
|
bounding boxes specified by the locations in `boxes`. The coordinates of the
|
||||||
|
each bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`.
|
||||||
|
The bounding box coordinates are floats in `[0.0, 1.0]` relative to the width
|
||||||
|
and height of the underlying image.
|
||||||
|
|
||||||
|
For example, if an image is 100 x 200 pixels (height x width) and the bounding
|
||||||
|
box is `[0.1, 0.2, 0.5, 0.9]`, the upper-left and bottom-right coordinates of
|
||||||
|
the bounding box will be `(40, 10)` to `(180, 50)` (in (x,y) coordinates).
|
||||||
|
|
||||||
|
Parts of the bounding box may fall outside the image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: A `Tensor`. Must be one of the following types: `float32`, `half`.
|
||||||
|
4-D with shape `[batch, height, width, depth]`. A batch of images.
|
||||||
|
boxes: A `Tensor` of type `float32`. 3-D with shape `[batch,
|
||||||
|
num_bounding_boxes, 4]` containing bounding boxes.
|
||||||
|
colors: A `Tensor` of type `float32`. 2-D. A list of RGBA colors to cycle
|
||||||
|
through for the boxes.
|
||||||
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor`. Has the same type as `images`.
|
||||||
|
"""
|
||||||
|
if colors is None and not compat.forward_compatible(2019, 5, 1):
|
||||||
|
return gen_image_ops.draw_bounding_boxes(images, boxes, name)
|
||||||
|
return gen_image_ops.draw_bounding_boxes_v2(images, boxes, colors, name)
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export(v1=['image.draw_bounding_boxes'])
|
||||||
|
def draw_bounding_boxes(images, boxes, name=None, colors=None):
|
||||||
|
"""Draw bounding boxes on a batch of images.
|
||||||
|
|
||||||
|
Outputs a copy of `images` but draws on top of the pixels zero or more
|
||||||
|
bounding boxes specified by the locations in `boxes`. The coordinates of the
|
||||||
|
each bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`.
|
||||||
|
The bounding box coordinates are floats in `[0.0, 1.0]` relative to the width
|
||||||
|
and height of the underlying image.
|
||||||
|
|
||||||
|
For example, if an image is 100 x 200 pixels (height x width) and the bounding
|
||||||
|
box is `[0.1, 0.2, 0.5, 0.9]`, the upper-left and bottom-right coordinates of
|
||||||
|
the bounding box will be `(40, 10)` to `(180, 50)` (in (x,y) coordinates).
|
||||||
|
|
||||||
|
Parts of the bounding box may fall outside the image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: A `Tensor`. Must be one of the following types: `float32`, `half`.
|
||||||
|
4-D with shape `[batch, height, width, depth]`. A batch of images.
|
||||||
|
boxes: A `Tensor` of type `float32`. 3-D with shape `[batch,
|
||||||
|
num_bounding_boxes, 4]` containing bounding boxes.
|
||||||
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor`. Has the same type as `images`.
|
||||||
|
"""
|
||||||
|
return draw_bounding_boxes_v2(images, boxes, colors, name)
|
||||||
|
@ -74,7 +74,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "draw_bounding_boxes"
|
name: "draw_bounding_boxes"
|
||||||
argspec: "args=[\'images\', \'boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'images\', \'boxes\', \'name\', \'colors\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "encode_jpeg"
|
name: "encode_jpeg"
|
||||||
|
@ -936,6 +936,10 @@ tf_module {
|
|||||||
name: "DrawBoundingBoxes"
|
name: "DrawBoundingBoxes"
|
||||||
argspec: "args=[\'images\', \'boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'images\', \'boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "DrawBoundingBoxesV2"
|
||||||
|
argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "DynamicPartition"
|
name: "DynamicPartition"
|
||||||
argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -74,7 +74,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "draw_bounding_boxes"
|
name: "draw_bounding_boxes"
|
||||||
argspec: "args=[\'images\', \'boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "encode_jpeg"
|
name: "encode_jpeg"
|
||||||
|
@ -936,6 +936,10 @@ tf_module {
|
|||||||
name: "DrawBoundingBoxes"
|
name: "DrawBoundingBoxes"
|
||||||
argspec: "args=[\'images\', \'boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'images\', \'boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "DrawBoundingBoxesV2"
|
||||||
|
argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "DynamicPartition"
|
name: "DynamicPartition"
|
||||||
argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'data\', \'partitions\', \'num_partitions\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
Loading…
Reference in New Issue
Block a user