Added benchmarks for RaggedTensor.to_dense. Also added a few new test cases and improved readability of a few error messages.
PiperOrigin-RevId: 270302695
This commit is contained in:
parent
f380896421
commit
2b36d416c3
@ -527,10 +527,14 @@ TEST_F(RaggedTensorToTensorOpUnknownShapeTest, ValueRowIDs) {
|
||||
INFER_OK(*op_, "?;[6,2];?;[];[6]", "[?,?,2]");
|
||||
INFER_OK(*op_, "?;[6,2];[2];[];[6]", "[?,?,2]");
|
||||
INFER_OK(*op_, "?;[6,2,7];[2,7];[];[6]", "[?,?,2,7]");
|
||||
INFER_ERROR("default_value_shape and value_shape do not match", *op_,
|
||||
"?;[6,2];[3];[];[6]");
|
||||
INFER_ERROR("default_value_shape and value_shape do not match", *op_,
|
||||
"?;[6,2,1,2];[2,2];[];[6]");
|
||||
INFER_ERROR(
|
||||
"default_value.shape=[3] and rt_input.flat_values.shape=[6,2] "
|
||||
"are incompatible",
|
||||
*op_, "?;[6,2];[3];[];[6]");
|
||||
INFER_ERROR(
|
||||
"default_value.shape=[2,2] and rt_input.flat_values.shape="
|
||||
"[6,2,1,2] are incompatible",
|
||||
*op_, "?;[6,2,1,2];[2,2];[];[6]");
|
||||
INFER_ERROR("must be a vector", *op_, "?;[6];[];[];[3,6]");
|
||||
INFER_ERROR("must be a scalar", *op_, "?;[6];[];[7];[3]");
|
||||
}
|
||||
|
@ -91,10 +91,11 @@ tensorflow::Status CombineRaggedTensorToTensorShapes(
|
||||
}
|
||||
// At this point, value_shape and output_shape have known ranks.
|
||||
if (ragged_rank + value_shape.dim_size() != output_shape->dim_size()) {
|
||||
return InvalidArgument("Value shape (", value_shape.DebugString(),
|
||||
"), ragged_rank(", ragged_rank, ") and shape(",
|
||||
shape.DebugString(),
|
||||
") do not have a consistent number of dimensions");
|
||||
return InvalidArgument(
|
||||
"rt_input.shape and shape=", TensorShape::DebugString(shape),
|
||||
" are incompatible: rt_input.rank = ",
|
||||
ragged_rank + value_shape.dim_size(),
|
||||
" but shape.rank = ", output_shape->dim_size());
|
||||
}
|
||||
|
||||
for (int i = 1; i < value_shape.dim_size(); ++i) {
|
||||
@ -105,7 +106,11 @@ tensorflow::Status CombineRaggedTensorToTensorShapes(
|
||||
if (value_dim.size() >= 0) {
|
||||
if (output_shape_dim->size() >= 0) {
|
||||
if (output_shape_dim->size() != value_dim.size()) {
|
||||
return InvalidArgument("Value and shape dimension are inconsistent.");
|
||||
return InvalidArgument(
|
||||
"rt_input.shape and shape=", TensorShape::DebugString(shape),
|
||||
" are incompatible: rt_input.shape[", i + ragged_rank,
|
||||
"] = ", value_dim.size(), " but shape[", i + ragged_rank,
|
||||
"] = ", output_shape_dim->size());
|
||||
}
|
||||
} else {
|
||||
output_shape_dim->set_size(value_dim.size());
|
||||
@ -132,28 +137,29 @@ tensorflow::Status ValidateDefaultValueShape(
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
if (default_value_shape.dim_size() > value_shape.dim_size()) {
|
||||
// TODO(martinz): This constraint is unnecessary. The
|
||||
// default value could have as many dimensions as shape. If there is a
|
||||
// discrepancy, it will be picked up when we broadcast the default value.
|
||||
// For now, I'll relax the constraint only slightly.
|
||||
int default_ndims = default_value_shape.dim_size();
|
||||
int values_ndims = value_shape.dim_size();
|
||||
if (default_ndims >= values_ndims) {
|
||||
return InvalidArgument(
|
||||
"default_value_shape must have no more dimensions than the value. "
|
||||
"default_value_shape: ",
|
||||
default_value_shape.DebugString(),
|
||||
" default_value_shape.dim_size(): ", default_value_shape.dim_size(),
|
||||
" value_shape: ", value_shape.DebugString(),
|
||||
" value_shape.dim_size(): ", value_shape.dim_size());
|
||||
"default_value.shape=", TensorShape::DebugString(default_value_shape),
|
||||
" and rt_input.flat_values.shape=",
|
||||
TensorShape::DebugString(value_shape),
|
||||
" are incompatible: default_value.rank = ", default_ndims,
|
||||
" must be less than rt_input.flat_values.rank = ", values_ndims);
|
||||
}
|
||||
for (int i = 0;
|
||||
i < std::min(default_value_shape.dim_size(), value_shape.dim_size() - 1);
|
||||
++i) {
|
||||
if (default_value_shape.dim(i).size() >= 0 &&
|
||||
value_shape.dim(i + 1).size() >= 0 &&
|
||||
default_value_shape.dim(i).size() != 1 &&
|
||||
default_value_shape.dim(i).size() != value_shape.dim(i + 1).size()) {
|
||||
for (int i = 0; i < std::min(default_ndims, values_ndims - 1); ++i) {
|
||||
int default_dim = default_value_shape.dim(i).size();
|
||||
int value_dim = value_shape.dim(i + 1).size();
|
||||
if (default_dim >= 0 && value_dim >= 0 && default_dim != 1 &&
|
||||
default_dim != value_dim) {
|
||||
return InvalidArgument(
|
||||
"default_value_shape and value_shape do not match on dimension ", i);
|
||||
"default_value.shape=", TensorShape::DebugString(default_value_shape),
|
||||
" and rt_input.flat_values.shape=",
|
||||
TensorShape::DebugString(value_shape),
|
||||
" are incompatible: default_value.shape[",
|
||||
i - default_value_shape.dim_size(), "] = ", default_dim,
|
||||
" but rt_input.flat_values.shape[",
|
||||
i - default_value_shape.dim_size(), "] = ", value_dim);
|
||||
}
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
|
@ -1,4 +1,5 @@
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
load("//tensorflow/tools/test:performance.bzl", "tf_py_logged_benchmark")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -669,6 +670,11 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_logged_benchmark(
|
||||
name = "ragged_to_tensor_op_benchmark",
|
||||
target = "//tensorflow/python/ops/ragged:ragged_to_tensor_op_test",
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ragged_segment_op_test",
|
||||
srcs = ["ragged_segment_op_test.py"],
|
||||
|
@ -18,20 +18,70 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_conversion_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
|
||||
from tensorflow.python.platform import benchmark
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
def make_placeholder(t):
|
||||
return array_ops.placeholder_with_default(t, None)
|
||||
|
||||
|
||||
def rebuild_ragged_tensor_with_value_rowids(rt, feed_dict=None, sess=None):
|
||||
"""Returns a copy of `rt`, built using `from_value_rowids`.
|
||||
|
||||
This ensures that RaggedTensor._cached_value_rowids is populated, which
|
||||
triggers a different code-path for converting ragged tensors to tensors.
|
||||
|
||||
If `feed_dict` and `sess` are specified, then build the new `RaggedTensor`
|
||||
using placeholder tensors, and populate a feed dictionary that can be used
|
||||
to feed the placeholders.
|
||||
|
||||
Args:
|
||||
rt: The RaggedTensor to copy.
|
||||
feed_dict: If specified, then build the new `RaggedTensor` using
|
||||
placeholders, and populate this dict with entries to feed those
|
||||
placeholders.
|
||||
sess: A session used to evaluate tensors; required if feed_dict is
|
||||
specified.
|
||||
|
||||
Returns:
|
||||
A copy of `rt`, built using `from_value_rowids`.
|
||||
"""
|
||||
if isinstance(rt, ragged_tensor.RaggedTensor):
|
||||
values = rebuild_ragged_tensor_with_value_rowids(rt.values, feed_dict, sess)
|
||||
rowids = rt.value_rowids()
|
||||
nrows = rt.nrows()
|
||||
if feed_dict is not None:
|
||||
rowids_ph = make_placeholder(rowids)
|
||||
nrows_ph = make_placeholder(nrows)
|
||||
feed_dict[rowids_ph] = sess.run(rowids)
|
||||
feed_dict[nrows_ph] = sess.run(nrows)
|
||||
rowids, nrows = rowids_ph, nrows_ph
|
||||
return ragged_tensor.RaggedTensor.from_value_rowids(values, rowids, nrows)
|
||||
else:
|
||||
if feed_dict is not None:
|
||||
rt_ph = make_placeholder(rt)
|
||||
feed_dict[rt_ph] = sess.run(rt)
|
||||
rt = rt_ph
|
||||
return rt
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
@ -212,30 +262,77 @@ class RaggedTensorToTensorOpNewTest(test_util.TensorFlowTestCase,
|
||||
ragged_rank=None,
|
||||
default=None,
|
||||
expected_shape=None):
|
||||
rt = ragged_factory_ops.constant(rt_input, ragged_rank=ragged_rank)
|
||||
dt = ragged_conversion_ops.ragged_to_dense(rt, default_value=default)
|
||||
rt1 = ragged_factory_ops.constant(rt_input, ragged_rank=ragged_rank)
|
||||
dt1 = ragged_conversion_ops.ragged_to_dense(rt1, default_value=default)
|
||||
rt2 = rebuild_ragged_tensor_with_value_rowids(rt1)
|
||||
dt2 = ragged_conversion_ops.ragged_to_dense(rt2, default_value=default)
|
||||
|
||||
self.assertIsInstance(dt, ops.Tensor)
|
||||
self.assertEqual(rt.dtype, dt.dtype)
|
||||
self.assertTrue(dt.shape.is_compatible_with(rt.shape))
|
||||
if expected_shape is not None:
|
||||
expected = np.ndarray(expected_shape, buffer=np.array(expected))
|
||||
self.assertAllEqual(dt, expected)
|
||||
for (rt, dt) in [(rt1, dt1), (rt2, dt2)]:
|
||||
self.assertIsInstance(dt, ops.Tensor)
|
||||
self.assertEqual(rt.dtype, dt.dtype)
|
||||
self.assertTrue(dt.shape.is_compatible_with(rt.shape))
|
||||
if expected_shape is not None:
|
||||
expected = np.ndarray(expected_shape, buffer=np.array(expected))
|
||||
self.assertAllEqual(dt, expected)
|
||||
|
||||
@parameterized.parameters(
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'rt_input': [[1, 2, 3]],
|
||||
'default': 'a',
|
||||
'error': (TypeError, '.*'),
|
||||
}, {
|
||||
'error_type': TypeError,
|
||||
'error': r"Expected int32 passed to parameter 'default_value'|"
|
||||
r"Cannot convert 'a' to EagerTensor of dtype int32",
|
||||
},
|
||||
{
|
||||
'rt_input': [[1, 2, 3]],
|
||||
'default': 'b',
|
||||
'error': (TypeError, '.*'),
|
||||
})
|
||||
def testError(self, rt_input, default, error, ragged_rank=None):
|
||||
'default': [0],
|
||||
'error': r'default_value\.shape=\[1\] and '
|
||||
r'rt_input\.flat_values\.shape=\[3\] are incompatible: '
|
||||
r'default_value\.rank = 1 must be less than '
|
||||
r'rt_input\.flat_values\.rank = 1'
|
||||
},
|
||||
{
|
||||
'rt_input': [[[1, 2], [3, 4]], [[5, 6]]],
|
||||
'ragged_rank': 1,
|
||||
'default': [7, 8, 9],
|
||||
'error': r'default_value\.shape=\[3\] and '
|
||||
r'rt_input\.flat_values\.shape=\[3,2\] are incompatible: '
|
||||
r'default_value\.shape\[-1\] = 3 but '
|
||||
r'rt_input\.flat_values\.shape\[-1\] = 2'
|
||||
},
|
||||
{
|
||||
'rt_input': [[1, 2, 3]],
|
||||
'shape': [3, 3, 3],
|
||||
'error': r'rt_input\.shape and shape=\[.,.,.\] are incompatible: '
|
||||
r'rt_input\.rank = 2 but shape\.rank = 3'
|
||||
},
|
||||
{
|
||||
'rt_input': [[[1, 2, 3]]],
|
||||
'ragged_rank': 1,
|
||||
'shape': [1, 1, 4],
|
||||
'error': r'rt_input\.shape and shape=\[1,1,4\] are incompatible: '
|
||||
r'rt_input\.shape\[2\] = 3 but shape\[2\] = 4'
|
||||
},
|
||||
])
|
||||
def testError(self,
|
||||
rt_input,
|
||||
error,
|
||||
error_type=(ValueError, errors.InvalidArgumentError),
|
||||
default=None,
|
||||
ragged_rank=None,
|
||||
shape=None):
|
||||
|
||||
rt = ragged_factory_ops.constant(rt_input, ragged_rank=ragged_rank)
|
||||
with self.assertRaisesRegexp(error[0], error[1]):
|
||||
ragged_conversion_ops.ragged_to_dense(rt, default_value=default)
|
||||
with self.assertRaisesRegexp(error_type, error):
|
||||
self.evaluate(
|
||||
ragged_conversion_ops.ragged_to_dense(
|
||||
rt, default_value=default, shape=shape))
|
||||
rt_placeholder = nest.map_structure(
|
||||
make_placeholder, rt, expand_composites=True)
|
||||
with self.assertRaisesRegexp(error_type, error):
|
||||
self.evaluate(
|
||||
ragged_conversion_ops.ragged_to_dense(
|
||||
rt_placeholder, default_value=default, shape=shape))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
@ -454,7 +551,7 @@ class RaggedToTensorOpAdditionalTests(test_util.TensorFlowTestCase):
|
||||
input_data = ragged_factory_ops.constant([[[[1, 2], [3, 4]]], []],
|
||||
ragged_rank=1)
|
||||
# This placeholder has a 2 x 1 dimension.
|
||||
default_value = array_ops.placeholder_with_default([[5], [6]], shape=None)
|
||||
default_value = make_placeholder([[5], [6]])
|
||||
actual = ragged_conversion_ops.ragged_to_dense(
|
||||
input_data, default_value=default_value)
|
||||
expected = [[[[1, 2], [3, 4]]], [[[5, 5], [6, 6]]]]
|
||||
@ -489,5 +586,145 @@ class RaggedToTensorOpAdditionalTests(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual(actual, [[3, 3, 3], [3, 3, 3]])
|
||||
|
||||
|
||||
class RaggedToDenseBenchmark(googletest.Benchmark):
|
||||
|
||||
# Configurations to test. See `run_benchmark` for config param docs.
|
||||
CONFIGS = [
|
||||
{'shape': [10, 10]},
|
||||
{'shape': [10, 1000]},
|
||||
{'shape': [1000, 10]},
|
||||
{'shape': [1000, 10], 'fill': [1, 0.95]}, # Mostly full.
|
||||
{'shape': [1000, 10], 'fill': [1, 0.05]}, # Mostly empty.
|
||||
{'shape': [1000, 10], 'dtype': dtypes.string},
|
||||
{'shape': [1000, 10], 'dtype': dtypes.int64},
|
||||
{'shape': [100, 100]},
|
||||
{'shape': [50, 50, 32]},
|
||||
{'shape': [100, 100, 100], 'min_iters': 100},
|
||||
{'shape': [1000, 1000], 'min_iters': 100},
|
||||
{'shape': [10, 10, 10, 10, 10]},
|
||||
{'shape': [10, 10, 10, 10, 10], 'ragged_rank': 1},
|
||||
{'shape': [10, 10, 10, 10, 10], 'ragged_rank': 2},
|
||||
{'shape': [50, 50, 32], 'ragged_rank': 1, 'default_shape': [32]},
|
||||
{'shape': [200, 50, 32], 'ragged_rank': 1, 'default_shape': [32]}
|
||||
] # pyformat: disable
|
||||
|
||||
def run_benchmark(self,
|
||||
shape=(100, 100),
|
||||
ragged_rank=None,
|
||||
dtype=dtypes.float32,
|
||||
fill=None,
|
||||
default_shape=(),
|
||||
output_shape=None,
|
||||
min_iters=1000):
|
||||
"""Run a benchmark with the specified configuraiton parameters.
|
||||
|
||||
Args:
|
||||
shape: Bounding box for the input ragged tensor.
|
||||
ragged_rank: Ragged rank for the input ragged tensor. Defauts to
|
||||
`len(shape)-1`.
|
||||
dtype: Data type for the input ragged tensor.
|
||||
fill: How full each dimension should be (0-1). Corresponds 1:1 with
|
||||
`shape`. Defaults to 0.8 for each dimension.
|
||||
default_shape: Shape for the default (padding) value.
|
||||
output_shape: Output shape -- ragged tensor will be padded or cropped to
|
||||
this shape.
|
||||
min_iters: Minimum iterations for benchmark.
|
||||
"""
|
||||
if ragged_rank is None:
|
||||
ragged_rank = len(shape) - 1
|
||||
if fill is None:
|
||||
fill = [0.8 for _ in shape]
|
||||
|
||||
# Build the inputs for the op.
|
||||
rt_input = self._generateRaggedTensor(shape, ragged_rank, dtype, fill)
|
||||
default_value = constant_op.constant(
|
||||
self._generateRaggedTensor(default_shape, 0, dtype), dtype=dtype)
|
||||
|
||||
mbs = np.prod(shape) / (2**20)
|
||||
with session.Session(config=benchmark.benchmark_config()) as sess:
|
||||
extras = {
|
||||
'shape': shape,
|
||||
'ragged_rank': ragged_rank,
|
||||
'dtype': dtype,
|
||||
'fill': fill,
|
||||
'default_shape': default_shape
|
||||
}
|
||||
rt = ragged_factory_ops.constant(rt_input, dtype, ragged_rank=ragged_rank)
|
||||
|
||||
# Inputs for with_splits:
|
||||
splits_rt_placeholder = ragged_factory_ops.placeholder(
|
||||
dtype, ragged_rank, shape[ragged_rank + 1:])
|
||||
splits_feed_dict = {splits_rt_placeholder: sess.run(rt)}
|
||||
|
||||
# Inputs for with_rowids:
|
||||
rowids_feed_dict = {}
|
||||
rowids_rt_placeholder = rebuild_ragged_tensor_with_value_rowids(
|
||||
rt, rowids_feed_dict, sess)
|
||||
|
||||
# Common arguments for benchmarks:
|
||||
run_op_benchmark_kwargs = dict(
|
||||
sess=sess,
|
||||
store_memory_usage=True,
|
||||
min_iters=min_iters,
|
||||
burn_iters=max(5, min_iters // 10),
|
||||
mbs=mbs,
|
||||
extras=extras)
|
||||
|
||||
ragged_to_dense_with_splits = ragged_conversion_ops.ragged_to_dense(
|
||||
splits_rt_placeholder, default_value=default_value)
|
||||
self.run_op_benchmark(
|
||||
op_or_tensor=ragged_to_dense_with_splits.op,
|
||||
name='ragged_to_dense_with_splits',
|
||||
feed_dict=splits_feed_dict,
|
||||
**run_op_benchmark_kwargs)
|
||||
|
||||
ragged_to_tensor_with_splits = splits_rt_placeholder.to_tensor(
|
||||
default_value=default_value)
|
||||
self.run_op_benchmark(
|
||||
op_or_tensor=ragged_to_tensor_with_splits.op,
|
||||
name='ragged_to_tensor_with_splits',
|
||||
feed_dict=splits_feed_dict,
|
||||
**run_op_benchmark_kwargs)
|
||||
|
||||
ragged_to_dense_with_rowids = ragged_conversion_ops.ragged_to_dense(
|
||||
rowids_rt_placeholder, default_value=default_value)
|
||||
self.run_op_benchmark(
|
||||
op_or_tensor=ragged_to_dense_with_rowids.op,
|
||||
name='ragged_to_dense_with_rowids',
|
||||
feed_dict=rowids_feed_dict,
|
||||
**run_op_benchmark_kwargs)
|
||||
|
||||
ragged_to_tensor_with_rowids = rowids_rt_placeholder.to_tensor(
|
||||
default_value=default_value)
|
||||
self.run_op_benchmark(
|
||||
op_or_tensor=ragged_to_tensor_with_rowids.op,
|
||||
name='ragged_to_tensor_with_rowids',
|
||||
feed_dict=rowids_feed_dict,
|
||||
**run_op_benchmark_kwargs)
|
||||
|
||||
def _generateRaggedTensor(self, shape, ragged_rank, dtype, fill=None, axis=0):
|
||||
if axis == len(shape):
|
||||
value = random.random()
|
||||
if dtype == dtypes.string:
|
||||
value = str(value)
|
||||
if dtype.is_integer:
|
||||
value = int(value * 1000)
|
||||
return value
|
||||
if axis == 0 or axis > ragged_rank:
|
||||
slice_size = shape[axis]
|
||||
else:
|
||||
slice_size = (np.random.geometric(fill[axis], shape[axis]) == 1).sum()
|
||||
return [
|
||||
self._generateRaggedTensor(shape, ragged_rank, dtype, fill, axis + 1)
|
||||
for _ in range(slice_size)
|
||||
]
|
||||
|
||||
def benchmark_ragged_to_dense(self):
|
||||
random.seed(5)
|
||||
for config in self.CONFIGS:
|
||||
self.run_benchmark(**config)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user