Improve performance of clip_by_value when the input is IndexedSlices
This commit is contained in:
parent
e3ea52501e
commit
1b323d7010
@ -150,6 +150,41 @@ class ClipTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertAllClose(np_ans, tf_ans)
|
self.assertAllClose(np_ans, tf_ans)
|
||||||
|
|
||||||
|
def _testClipIndexedSlicesByValue(self, values, indices, shape,
|
||||||
|
clip_value_min, clip_value_max, expected):
|
||||||
|
with self.session(use_gpu=True) as sess:
|
||||||
|
values = constant_op.constant(values)
|
||||||
|
indices = constant_op.constant(indices)
|
||||||
|
shape = constant_op.constant(shape)
|
||||||
|
# IndexedSlices mode
|
||||||
|
indixed_slices = ops.IndexedSlices(values, indices, shape)
|
||||||
|
clipped = clip_ops.clip_by_value(indixed_slices, clip_value_min,
|
||||||
|
clip_value_max)
|
||||||
|
# clipped should be IndexedSlices
|
||||||
|
self.assertIsInstance(clipped, ops.IndexedSlices)
|
||||||
|
|
||||||
|
self.assertAllClose(clipped.values, expected)
|
||||||
|
|
||||||
|
def testClipByValueWithIndexedSlicesClipped(self):
|
||||||
|
values = [[[-3.0, 0.0, 0.0], [4.0, 0.0, 0.0]],
|
||||||
|
[[0.0, 2.0, 0.0], [0.0, 0.0, -1.0]]]
|
||||||
|
indices = [2, 6]
|
||||||
|
shape = [10, 2, 3]
|
||||||
|
# [-2.0, 2.0]
|
||||||
|
self._testClipIndexedSlicesByValue(values, indices, shape, -2.0, 2.0,
|
||||||
|
[[[-2.0, 0.0, 0.0], [2.0, 0.0, 0.0]],
|
||||||
|
[[0.0, 2.0, 0.0], [0.0, 0.0, -1.0]]])
|
||||||
|
# [1.0, 2.0]
|
||||||
|
self._testClipIndexedSlicesByValue(values, indices, shape, 1.0, 2.0,
|
||||||
|
[[[1.0, 1.0, 1.0], [2.0, 1.0, 1.0]],
|
||||||
|
[[1.0, 2.0, 1.0], [1.0, 1.0, 1.0]]])
|
||||||
|
# [-2.0, -1.0]
|
||||||
|
self._testClipIndexedSlicesByValue(values, indices, shape, -2.0, -1.0,
|
||||||
|
[[[-2.0, -1.0, -1.0],
|
||||||
|
[-1.0, -1.0, -1.0]],
|
||||||
|
[[-1.0, -1.0, -1.0],
|
||||||
|
[-1.0, -1.0, -1.0]]])
|
||||||
|
|
||||||
# ClipByNorm tests
|
# ClipByNorm tests
|
||||||
def testClipByNormClipped(self):
|
def testClipByNormClipped(self):
|
||||||
# Norm clipping when clip_norm < 5
|
# Norm clipping when clip_norm < 5
|
||||||
|
@ -50,7 +50,7 @@ def clip_by_value(t, clip_value_min, clip_value_max,
|
|||||||
correct results.
|
correct results.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
t: A `Tensor`.
|
t: A `Tensor` or `IndexedSlices`.
|
||||||
clip_value_min: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape
|
clip_value_min: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape
|
||||||
as `t`. The minimum value to clip by.
|
as `t`. The minimum value to clip by.
|
||||||
clip_value_max: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape
|
clip_value_max: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape
|
||||||
@ -58,7 +58,7 @@ def clip_by_value(t, clip_value_min, clip_value_max,
|
|||||||
name: A name for the operation (optional).
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A clipped `Tensor`.
|
A clipped `Tensor` or `IndexedSlices`.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the clip tensors would trigger array broadcasting
|
ValueError: If the clip tensors would trigger array broadcasting
|
||||||
@ -66,16 +66,20 @@ def clip_by_value(t, clip_value_min, clip_value_max,
|
|||||||
"""
|
"""
|
||||||
with ops.name_scope(name, "clip_by_value",
|
with ops.name_scope(name, "clip_by_value",
|
||||||
[t, clip_value_min, clip_value_max]) as name:
|
[t, clip_value_min, clip_value_max]) as name:
|
||||||
t = ops.convert_to_tensor(t, name="t")
|
values = ops.convert_to_tensor(
|
||||||
|
t.values if isinstance(t, ops.IndexedSlices) else t, name="t")
|
||||||
|
|
||||||
# Go through list of tensors, for each value in each tensor clip
|
# Go through list of tensors, for each value in each tensor clip
|
||||||
t_min = math_ops.minimum(t, clip_value_max)
|
t_min = math_ops.minimum(values, clip_value_max)
|
||||||
# Assert that the shape is compatible with the initial shape,
|
# Assert that the shape is compatible with the initial shape,
|
||||||
# to prevent unintentional broadcasting.
|
# to prevent unintentional broadcasting.
|
||||||
_ = t.shape.merge_with(t_min.shape)
|
_ = values.shape.merge_with(t_min.shape)
|
||||||
|
|
||||||
t_max = math_ops.maximum(t_min, clip_value_min, name=name)
|
t_max = math_ops.maximum(t_min, clip_value_min, name=name)
|
||||||
_ = t.shape.merge_with(t_max.shape)
|
_ = values.shape.merge_with(t_max.shape)
|
||||||
|
|
||||||
|
if isinstance(t, ops.IndexedSlices):
|
||||||
|
t_max = ops.IndexedSlices(t_max, t.indices, t.dense_shape)
|
||||||
|
|
||||||
return t_max
|
return t_max
|
||||||
# TODO(scottzhu): switch to use new implmentation in 2 weeks.
|
# TODO(scottzhu): switch to use new implmentation in 2 weeks.
|
||||||
|
Loading…
Reference in New Issue
Block a user