Clean up linspace implementation, support num=0 and integer dtypes
PiperOrigin-RevId: 313679328 Change-Id: I610a0428b790de95d04fdde9f884a1493d8329a8
This commit is contained in:
parent
ee53f4e245
commit
eb40237008
@ -169,8 +169,6 @@ class LinearToMelTest(test.TestCase, parameterized.TestCase):
|
||||
return
|
||||
with self.assertRaises(ValueError):
|
||||
mel_ops.linear_to_mel_weight_matrix(num_mel_bins=0)
|
||||
with self.assertRaises(ValueError):
|
||||
mel_ops.linear_to_mel_weight_matrix(num_spectrogram_bins=0)
|
||||
with self.assertRaises(ValueError):
|
||||
mel_ops.linear_to_mel_weight_matrix(sample_rate=0.0)
|
||||
with self.assertRaises(ValueError):
|
||||
|
@ -180,42 +180,38 @@ def linspace_nd(start, stop, num, name=None, axis=0):
|
||||
|
||||
axis = array_ops.where_v2(axis >= 0, axis, ndims + axis)
|
||||
|
||||
# The purpose is to avoid having negative values when repeating.
|
||||
num_fill = gen_math_ops.maximum(num_int - 2, 0)
|
||||
# To avoid having negative values in the range or zero division
|
||||
# the result is sliced in the end so a correct result is returned for
|
||||
# num == 1.
|
||||
# num == 1, and num == 0.
|
||||
n_steps = gen_math_ops.maximum(num_int - 1, 1)
|
||||
delta = (expanded_stop - expanded_start) / cast(n_steps,
|
||||
expanded_stop.dtype)
|
||||
# Re-cast tensors as delta.
|
||||
expanded_start = cast(expanded_start, delta.dtype)
|
||||
expanded_stop = cast(expanded_stop, delta.dtype)
|
||||
# If num < 0, we will throw exception in the range
|
||||
# otherwise use the same div for delta
|
||||
range_end = array_ops.where_v2(num_int > 0, n_steps, -1)
|
||||
range_end = array_ops.where_v2(num_int >= 0, n_steps, -1)
|
||||
# Even though range supports an output dtype, its limited
|
||||
# (e.g. doesn't support half at the moment).
|
||||
num_range = cast(range(1, range_end, dtype=dtypes.int64), start.dtype)
|
||||
shape_range = range(ndims)
|
||||
ones_like_shape_range = array_ops.ones_like(shape_range)
|
||||
axis_tiled = ones_like_shape_range * axis
|
||||
# the purpose is to avoid having negative values when repeating
|
||||
num_fill = gen_math_ops.maximum(num_int - 2, 0)
|
||||
num_tiled = array_ops.ones_like(shape_range) * num_fill
|
||||
ones = array_ops.ones_like(num_tiled)
|
||||
mask = gen_math_ops.equal(axis_tiled, shape_range)
|
||||
# reshape_target is [1. 1. 1. ... 1. num 1. 1. ... 1.], where the index
|
||||
# of num is equal to axis
|
||||
reshape_target = array_ops.where_v2(mask, num_fill, shape)
|
||||
delta_expanded = array_ops.reshape(delta, shape)
|
||||
delta_repeated = array_ops.broadcast_to(delta_expanded, reshape_target)
|
||||
start_repeated = array_ops.broadcast_to(expanded_start, reshape_target)
|
||||
desired_range = cast(range(1, range_end, dtype=dtypes.int64), delta.dtype)
|
||||
mask = gen_math_ops.equal(axis, range(ndims))
|
||||
# desired_range_shape is [1. 1. 1. ... 1. num_fill 1. 1. ... 1.], where the
|
||||
# index of num_fill is equal to axis.
|
||||
desired_range_shape = array_ops.where_v2(mask, num_fill, 1)
|
||||
desired_range = array_ops.reshape(desired_range, desired_range_shape)
|
||||
|
||||
expanded_shape = array_ops.where_v2(mask, num_fill, ones)
|
||||
range_indices = array_ops.reshape(num_range, expanded_shape)
|
||||
tiled_range_indices = array_ops.tile(range_indices, shape)
|
||||
res = start_repeated + delta_repeated * tiled_range_indices
|
||||
res = expanded_start + delta * desired_range
|
||||
|
||||
# Add the start and endpoints to the result, and slice out the desired
|
||||
# portion.
|
||||
all_tensors = (expanded_start, res, expanded_stop)
|
||||
concatenated = array_ops.concat(all_tensors, axis=axis)
|
||||
begin = array_ops.zeros_like(shape)
|
||||
num_slice = ones_like_shape_range * num_int
|
||||
size = array_ops.where_v2(mask, num_slice, shape)
|
||||
size = array_ops.where_v2(mask, num_int, shape)
|
||||
|
||||
return array_ops.slice(concatenated, begin, size)
|
||||
|
||||
|
||||
|
@ -798,12 +798,15 @@ class LinspaceTest(test_util.TensorFlowTestCase):
|
||||
|
||||
shapes = [(), (2,), (2, 2)]
|
||||
|
||||
types = [np.float64, np.int64]
|
||||
|
||||
for start_shape, stop_shape in itertools.product(shapes, repeat=2):
|
||||
for num in [1, 2, 20]:
|
||||
for num in [0, 1, 2, 20]:
|
||||
ndims = max(len(start_shape), len(stop_shape))
|
||||
for axis in range(-ndims, ndims):
|
||||
start = np.ones(start_shape)
|
||||
stop = 10 * np.ones(stop_shape)
|
||||
for dtype in types:
|
||||
start = np.ones(start_shape, dtype)
|
||||
stop = 10 * np.ones(stop_shape, dtype)
|
||||
|
||||
np_ans = np.linspace(start, stop, num, axis=axis)
|
||||
tf_ans = self.evaluate(
|
||||
|
Loading…
Reference in New Issue
Block a user