Exposes the already-implemented variable scatter_div, scatter_mul, scatter_min, and scatter_max operations in the variables API. Also fixes some documentation and adds tests that directly check the var.scatter_* methods.

PiperOrigin-RevId: 251565839
A. Unique TensorFlower 2019-06-04 19:43:31 -07:00 committed by TensorFlower Gardener
parent 1535bc638c
commit 76171c3cba
7 changed files with 438 additions and 141 deletions
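
For context, a minimal usage sketch of the methods this change exposes, assuming TF 2.x-style eager execution (the starting value and operands loosely mirror the new tests below):

import tensorflow as tf

v = tf.Variable([0.0, 3.5])
v.scatter_min(tf.IndexedSlices(values=[2.0], indices=[1]))  # v -> [0.0, 2.0]
v.scatter_mul(tf.IndexedSlices(values=[3.0], indices=[1]))  # v -> [0.0, 6.0]
v.scatter_div(tf.IndexedSlices(values=[2.0], indices=[1]))  # v -> [0.0, 3.0]
v.scatter_max(tf.IndexedSlices(values=[5.0], indices=[1]))  # v -> [0.0, 5.0]
print(v.numpy())  # [0. 5.]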


@@ -456,6 +456,74 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[6]])
@test_util.run_in_graph_and_eager_modes
def testScatterAddVariableMethod(self):
v = resource_variable_ops.ResourceVariable([0.0, 1.5], name="add")
self.evaluate(variables.global_variables_initializer())
self.evaluate(
v.scatter_add(ops.IndexedSlices(indices=[1], values=[2.5])))
self.assertAllEqual([0.0, 4.0], self.evaluate(v))
@test_util.run_in_graph_and_eager_modes
def testScatterSubVariableMethod(self):
v = resource_variable_ops.ResourceVariable([0.0, 2.5], name="sub")
self.evaluate(variables.global_variables_initializer())
self.evaluate(
v.scatter_sub(ops.IndexedSlices(indices=[1], values=[1.5])))
self.assertAllEqual([0.0, 1.0], self.evaluate(v))
@test_util.run_in_graph_and_eager_modes
def testScatterMaxVariableMethod(self):
v = resource_variable_ops.ResourceVariable([0.0, 4.0], name="max1")
self.evaluate(variables.global_variables_initializer())
self.evaluate(
v.scatter_max(ops.IndexedSlices(indices=[1], values=[5.0])))
self.assertAllEqual([0.0, 5.0], self.evaluate(v))
v = resource_variable_ops.ResourceVariable([0.0, 3.5], name="max2")
self.evaluate(variables.global_variables_initializer())
self.evaluate(
v.scatter_max(ops.IndexedSlices(indices=[1], values=[2.0])))
self.assertAllEqual([0.0, 3.5], self.evaluate(v))
@test_util.run_in_graph_and_eager_modes
def testScatterMinVariableMethod(self):
v = resource_variable_ops.ResourceVariable([0.0, 4.0], name="min1")
self.evaluate(variables.global_variables_initializer())
self.evaluate(
v.scatter_min(ops.IndexedSlices(indices=[1], values=[5.0])))
self.assertAllEqual([0.0, 4.0], self.evaluate(v))
v = resource_variable_ops.ResourceVariable([0.0, 3.5], name="min2")
self.evaluate(variables.global_variables_initializer())
self.evaluate(
v.scatter_min(ops.IndexedSlices(indices=[1], values=[2.0])))
self.assertAllEqual([0.0, 2.0], self.evaluate(v))
@test_util.run_in_graph_and_eager_modes
def testScatterMulVariableMethod(self):
v = resource_variable_ops.ResourceVariable([0.0, 4.0], name="mul")
self.evaluate(variables.global_variables_initializer())
self.evaluate(
v.scatter_mul(ops.IndexedSlices(indices=[1], values=[3.0])))
self.assertAllEqual([0.0, 12.0], self.evaluate(v))
@test_util.run_in_graph_and_eager_modes
def testScatterDivVariableMethod(self):
v = resource_variable_ops.ResourceVariable([0.0, 6.0], name="div")
self.evaluate(variables.global_variables_initializer())
self.evaluate(
v.scatter_div(ops.IndexedSlices(indices=[1], values=[2.0])))
self.assertAllEqual([0.0, 3.0], self.evaluate(v))
@test_util.run_in_graph_and_eager_modes
def testScatterUpdateVariableMethod(self):
v = resource_variable_ops.ResourceVariable([0.0, 6.0], name="update")
self.evaluate(variables.global_variables_initializer())
self.evaluate(
v.scatter_update(ops.IndexedSlices(indices=[1], values=[3.0])))
self.assertAllEqual([0.0, 3.0], self.evaluate(v))
@test_util.run_deprecated_v1
def testScatterUpdateString(self):
handle = resource_variable_ops.var_handle_op(


@@ -179,10 +179,10 @@ class ScatterTest(test.TestCase):
np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
np_scatter(new, indices, updates)
# Scatter via tensorflow
ref = variables.VariableV1(old, use_resource=False)
ref.initializer.run()
tf_scatter(ref, indices, updates).eval()
self.assertAllClose(ref.eval(), new)
ref = variables.Variable(old)
self.evaluate(ref.initializer)
self.evaluate(tf_scatter(ref, indices, updates))
self.assertAllClose(self.evaluate(ref), new)
def _VariableRankTests(self,
tf_scatter,
@@ -192,157 +192,125 @@ class ScatterTest(test.TestCase):
if tf_scatter != state_ops.scatter_div:
vtypes.append(np.int32)
if (tf_scatter == state_ops.scatter_min or
tf_scatter == state_ops.scatter_max):
vtypes.append(np.float16)
for vtype in vtypes:
for itype in (np.int32, np.int64):
self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices,
updates_are_scalar)
@test_util.run_deprecated_v1
def testVariableRankUpdate(self):
self._VariableRankTests(state_ops.scatter_update, False)
@test_util.run_deprecated_v1
def testVariableRankAdd(self):
self._VariableRankTests(state_ops.scatter_add, False)
@test_util.run_deprecated_v1
def testVariableRankSub(self):
self._VariableRankTests(state_ops.scatter_sub, False)
@test_util.run_deprecated_v1
def testVariableRankMul(self):
self._VariableRankTests(state_ops.scatter_mul, False)
@test_util.run_deprecated_v1
def testVariableRankDiv(self):
self._VariableRankTests(state_ops.scatter_div, False)
@test_util.run_deprecated_v1
def testVariableRankMin(self):
self._VariableRankTests(state_ops.scatter_min, False)
@test_util.run_deprecated_v1
def testVariableRankMax(self):
self._VariableRankTests(state_ops.scatter_max, False)
@test_util.run_deprecated_v1
def testRepeatIndicesAdd(self):
self._VariableRankTests(state_ops.scatter_add, True)
@test_util.run_deprecated_v1
def testRepeatIndicesSub(self):
self._VariableRankTests(state_ops.scatter_sub, True)
@test_util.run_deprecated_v1
def testRepeatIndicesMul(self):
self._VariableRankTests(state_ops.scatter_mul, True)
@test_util.run_deprecated_v1
def testRepeatIndicesDiv(self):
self._VariableRankTests(state_ops.scatter_div, True)
@test_util.run_deprecated_v1
def testRepeatIndicesMin(self):
self._VariableRankTests(state_ops.scatter_min, True)
@test_util.run_deprecated_v1
def testRepeatIndicesMax(self):
self._VariableRankTests(state_ops.scatter_max, True)
@test_util.run_deprecated_v1
def testVariableRankUpdateScalar(self):
self._VariableRankTests(state_ops.scatter_update, False, True)
@test_util.run_deprecated_v1
def testVariableRankAddScalar(self):
self._VariableRankTests(state_ops.scatter_add, False, True)
@test_util.run_deprecated_v1
def testVariableRankSubScalar(self):
self._VariableRankTests(state_ops.scatter_sub, False, True)
@test_util.run_deprecated_v1
def testVariableRankMulScalar(self):
self._VariableRankTests(state_ops.scatter_mul, False, True)
@test_util.run_deprecated_v1
def testVariableRankDivScalar(self):
self._VariableRankTests(state_ops.scatter_div, False, True)
@test_util.run_deprecated_v1
def testVariableRankMinScalar(self):
self._VariableRankTests(state_ops.scatter_min, False, True)
@test_util.run_deprecated_v1
def testVariableRankMaxScalar(self):
self._VariableRankTests(state_ops.scatter_max, False, True)
@test_util.run_deprecated_v1
def testRepeatIndicesAddScalar(self):
self._VariableRankTests(state_ops.scatter_add, True, True)
@test_util.run_deprecated_v1
def testRepeatIndicesSubScalar(self):
self._VariableRankTests(state_ops.scatter_sub, True, True)
@test_util.run_deprecated_v1
def testRepeatIndicesMulScalar(self):
self._VariableRankTests(state_ops.scatter_mul, True, True)
@test_util.run_deprecated_v1
def testRepeatIndicesDivScalar(self):
self._VariableRankTests(state_ops.scatter_div, True, True)
@test_util.run_deprecated_v1
def testRepeatIndicesMinScalar(self):
self._VariableRankTests(state_ops.scatter_min, True, True)
@test_util.run_deprecated_v1
def testRepeatIndicesMaxScalar(self):
self._VariableRankTests(state_ops.scatter_max, True, True)
@test_util.run_deprecated_v1
def testBooleanScatterUpdate(self):
if not test.is_gpu_available():
with self.session(use_gpu=False) as session:
with self.session(use_gpu=False):
var = variables.Variable([True, False])
update0 = state_ops.scatter_update(var, 1, True)
update1 = state_ops.scatter_update(
var, constant_op.constant(
0, dtype=dtypes.int64), False)
var.initializer.run()
self.evaluate(var.initializer)
session.run([update0, update1])
self.evaluate([update0, update1])
self.assertAllEqual([False, True], self.evaluate(var))
@test_util.run_deprecated_v1
def testScatterOutOfRangeCpu(self):
for op, _ in _TF_OPS_TO_NUMPY.items():
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
updates = np.array([-3, -4, -5]).astype(np.float32)
if not test.is_gpu_available():
with self.session(use_gpu=False):
ref = variables.VariableV1(params, use_resource=False)
ref.initializer.run()
ref = variables.Variable(params)
self.evaluate(ref.initializer)
# Indices all in range, no problem.
indices = np.array([2, 0, 5])
op(ref, indices, updates).eval()
self.evaluate(op(ref, indices, updates))
# Test some out of range errors.
indices = np.array([-1, 0, 5])
with self.assertRaisesOpError(
r'indices\[0\] = -1 is not in \[0, 6\)'):
op(ref, indices, updates).eval()
self.evaluate(op(ref, indices, updates))
indices = np.array([2, 0, 6])
with self.assertRaisesOpError(r'indices\[2\] = 6 is not in \[0, 6\)'):
op(ref, indices, updates).eval()
self.evaluate(op(ref, indices, updates))
# TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.
def _disabledTestScatterOutOfRangeGpu(self):
@@ -355,7 +323,7 @@ class ScatterTest(test.TestCase):
# We don't test the implementation; just test there's no failures.
with test_util.force_gpu():
ref = variables.Variable(params)
ref.initializer.run()
self.evaluate(ref.initializer)
# Indices all in range, no problem.
indices = np.array([2, 0, 5])


@@ -1161,10 +1161,10 @@ class ResourceVariable(variables.VariableV1):
distribute_strategy=self._distribute_strategy), ()
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
"""Subtracts `IndexedSlices` from this variable.
"""Subtracts `tf.IndexedSlices` from this variable.
Args:
sparse_delta: `IndexedSlices` to be subtracted from this variable.
sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
@@ -1173,40 +1173,126 @@ class ResourceVariable(variables.VariableV1):
the scattered subtraction has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
if not isinstance(sparse_delta, ops.IndexedSlices):
raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
return self._lazy_read(gen_resource_variable_ops.resource_scatter_sub(
self.handle, sparse_delta.indices,
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
def scatter_add(self, sparse_delta, use_locking=False, name=None):
"""Adds `IndexedSlices` from this variable.
"""Adds `tf.IndexedSlices` to this variable.
Args:
sparse_delta: `IndexedSlices` to be added to this variable.
sparse_delta: `tf.IndexedSlices` to be added to this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered subtraction has completed.
the scattered addition has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
if not isinstance(sparse_delta, ops.IndexedSlices):
raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
return self._lazy_read(gen_resource_variable_ops.resource_scatter_add(
self.handle, sparse_delta.indices,
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
def scatter_update(self, sparse_delta, use_locking=False, name=None):
"""Assigns `IndexedSlices` to this variable.
def scatter_max(self, sparse_delta, use_locking=False, name=None):
"""Updates this variable with the max of `tf.IndexedSlices` and itself.
Args:
sparse_delta: `IndexedSlices` to be assigned to this variable.
sparse_delta: `tf.IndexedSlices` to use as an argument of max
with this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered maximization has completed.
Raises:
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
if not isinstance(sparse_delta, ops.IndexedSlices):
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
return self._lazy_read(gen_resource_variable_ops.resource_scatter_max(
self.handle, sparse_delta.indices,
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
def scatter_min(self, sparse_delta, use_locking=False, name=None):
"""Updates this variable with the min of `tf.IndexedSlices` and itself.
Args:
sparse_delta: `tf.IndexedSlices` to use as an argument of min
with this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered minimization has completed.
Raises:
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
if not isinstance(sparse_delta, ops.IndexedSlices):
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
return self._lazy_read(gen_resource_variable_ops.resource_scatter_min(
self.handle, sparse_delta.indices,
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
"""Multiply this variable by `tf.IndexedSlices`.
Args:
sparse_delta: `tf.IndexedSlices` to multiply this variable by.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered multiplication has completed.
Raises:
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
if not isinstance(sparse_delta, ops.IndexedSlices):
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
return self._lazy_read(gen_resource_variable_ops.resource_scatter_mul(
self.handle, sparse_delta.indices,
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
def scatter_div(self, sparse_delta, use_locking=False, name=None):
"""Divide this variable by `tf.IndexedSlices`.
Args:
sparse_delta: `tf.IndexedSlices` to divide this variable by.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered division has completed.
Raises:
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
if not isinstance(sparse_delta, ops.IndexedSlices):
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
return self._lazy_read(gen_resource_variable_ops.resource_scatter_div(
self.handle, sparse_delta.indices,
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
def scatter_update(self, sparse_delta, use_locking=False, name=None):
"""Assigns `tf.IndexedSlices` to this variable.
Args:
sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
@@ -1215,16 +1301,16 @@ class ResourceVariable(variables.VariableV1):
the scattered subtraction has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
if not isinstance(sparse_delta, ops.IndexedSlices):
raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
return self._lazy_read(gen_resource_variable_ops.resource_scatter_update(
self.handle, sparse_delta.indices,
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
"""Assigns `IndexedSlices` to this variable batch-wise.
"""Assigns `tf.IndexedSlices` to this variable batch-wise.
Analogous to `batch_gather`. This assumes that this variable and the
sparse_delta IndexedSlices have a series of leading dimensions that are the
@@ -1257,7 +1343,7 @@ class ResourceVariable(variables.VariableV1):
efficient than this implementation.
Args:
sparse_delta: `IndexedSlices` to be assigned to this variable.
sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
@@ -1266,8 +1352,10 @@ class ResourceVariable(variables.VariableV1):
the scattered subtraction has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
if not isinstance(sparse_delta, ops.IndexedSlices):
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
return self._lazy_read(state_ops.batch_scatter_update(
self, sparse_delta.indices, sparse_delta.values,
use_locking=use_locking, name=name))
@@ -1317,9 +1405,6 @@ class ResourceVariable(variables.VariableV1):
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered subtraction has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
"""
return self._lazy_read(gen_state_ops.resource_scatter_nd_sub(
self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
@@ -1370,9 +1455,6 @@ class ResourceVariable(variables.VariableV1):
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered subtraction has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
"""
return self._lazy_read(gen_state_ops.resource_scatter_nd_add(
self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
@@ -1423,9 +1505,6 @@ class ResourceVariable(variables.VariableV1):
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered subtraction has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
"""
return self._lazy_read(gen_state_ops.resource_scatter_nd_update(
self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
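
One behavioral detail worth calling out: the type check on `sparse_delta` now raises `TypeError` instead of `ValueError`. A small sketch of the new contract, assuming eager execution:

import tensorflow as tf

v = tf.Variable([0.0, 1.0])
try:
  v.scatter_add([1.0, 2.0])  # a plain list, not a tf.IndexedSlices
except TypeError as e:
  print(e)  # sparse_delta is not IndexedSlices: [1.0, 2.0]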


@@ -645,12 +645,12 @@ def scatter_mul(ref, indices, updates, use_locking=False, name=None):
Returns:
A mutable `Tensor`. Has the same type as `ref`.
"""
return gen_state_ops.scatter_mul(
ref=ref,
indices=indices,
updates=updates,
use_locking=use_locking,
name=name)
if ref.dtype._is_ref_dtype:
return gen_state_ops.scatter_mul(ref, indices, updates,
use_locking=use_locking, name=name)
return ref._lazy_read(gen_resource_variable_ops.resource_scatter_mul( # pylint: disable=protected-access
ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
name=name))
@tf_export(v1=["scatter_div"])
@@ -697,12 +697,12 @@ def scatter_div(ref, indices, updates, use_locking=False, name=None):
Returns:
A mutable `Tensor`. Has the same type as `ref`.
"""
return gen_state_ops.scatter_div(
ref=ref,
indices=indices,
updates=updates,
use_locking=use_locking,
name=name)
if ref.dtype._is_ref_dtype:
return gen_state_ops.scatter_div(ref, indices, updates,
use_locking=use_locking, name=name)
return ref._lazy_read(gen_resource_variable_ops.resource_scatter_div( # pylint: disable=protected-access
ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
name=name))
@tf_export(v1=["scatter_max"])
@@ -752,12 +752,12 @@ def scatter_max(ref, indices, updates, use_locking=False, name=None):
Returns:
A mutable `Tensor`. Has the same type as `ref`.
"""
return gen_state_ops.scatter_max(
ref=ref,
indices=indices,
updates=updates,
use_locking=use_locking,
name=name)
if ref.dtype._is_ref_dtype:
return gen_state_ops.scatter_max(ref, indices, updates,
use_locking=use_locking, name=name)
return ref._lazy_read(gen_resource_variable_ops.resource_scatter_max( # pylint: disable=protected-access
ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
name=name))
@tf_export(v1=["scatter_min"])
@@ -807,12 +807,12 @@ def scatter_min(ref, indices, updates, use_locking=False, name=None):
Returns:
A mutable `Tensor`. Has the same type as `ref`.
"""
return gen_state_ops.scatter_min(
ref=ref,
indices=indices,
updates=updates,
use_locking=use_locking,
name=name)
if ref.dtype._is_ref_dtype:
return gen_state_ops.scatter_min(ref, indices, updates,
use_locking=use_locking, name=name)
return ref._lazy_read(gen_resource_variable_ops.resource_scatter_min( # pylint: disable=protected-access
ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
name=name))
@tf_export(v1=["batch_scatter_update"])
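
The state_ops changes above make the v1 `tf.scatter_mul` / `tf.scatter_div` / `tf.scatter_max` / `tf.scatter_min` endpoints dispatch on the variable kind: ref-dtype variables keep the legacy gen_state_ops kernels, while resource variables are routed to the resource_scatter_* ops via `_lazy_read`. A sketch of the observable behavior, assuming graph mode under `tf.compat.v1`:

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
params = np.array([1.0, 6.0], np.float32)

ref_v = tf.Variable(params, use_resource=False)  # legacy ref-variable kernel
res_v = tf.Variable(params, use_resource=True)   # new resource_scatter_* path

out_ref = tf.scatter_mul(ref_v, [1], [2.0])
out_res = tf.scatter_mul(res_v, [1], [2.0])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # Both dispatch paths produce the same result: [1., 12.]
  print(sess.run([out_ref, out_res]))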


@@ -656,10 +656,10 @@ class Variable(six.with_metaclass(VariableMetaclass,
raise NotImplementedError
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
"""Subtracts `IndexedSlices` from this variable.
"""Subtracts `tf.IndexedSlices` from this variable.
Args:
sparse_delta: `IndexedSlices` to be subtracted from this variable.
sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
@@ -668,15 +668,15 @@ class Variable(six.with_metaclass(VariableMetaclass,
the scattered subtraction has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
raise NotImplementedError
def scatter_add(self, sparse_delta, use_locking=False, name=None):
"""Adds `IndexedSlices` to this variable.
"""Adds `tf.IndexedSlices` to this variable.
Args:
sparse_delta: `IndexedSlices` to be assigned to this variable.
sparse_delta: `tf.IndexedSlices` to be added to this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
@@ -685,15 +685,85 @@ class Variable(six.with_metaclass(VariableMetaclass,
the scattered addition has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
raise NotImplementedError
def scatter_max(self, sparse_delta, use_locking=False, name=None):
"""Updates this variable with the max of `tf.IndexedSlices` and itself.
Args:
sparse_delta: `tf.IndexedSlices` to use as an argument of max
with this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered maximization has completed.
Raises:
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
raise NotImplementedError
def scatter_min(self, sparse_delta, use_locking=False, name=None):
"""Updates this variable with the min of `tf.IndexedSlices` and itself.
Args:
sparse_delta: `tf.IndexedSlices` to use as an argument of min
with this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered minimization has completed.
Raises:
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
raise NotImplementedError
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
"""Multiply this variable by `tf.IndexedSlices`.
Args:
sparse_delta: `tf.IndexedSlices` to multiply this variable by.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered multiplication has completed.
Raises:
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
raise NotImplementedError
def scatter_div(self, sparse_delta, use_locking=False, name=None):
"""Divide this variable by `tf.IndexedSlices`.
Args:
sparse_delta: `tf.IndexedSlices` to divide this variable by.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered division has completed.
Raises:
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
raise NotImplementedError
def scatter_update(self, sparse_delta, use_locking=False, name=None):
"""Assigns `IndexedSlices` to this variable.
"""Assigns `tf.IndexedSlices` to this variable.
Args:
sparse_delta: `IndexedSlices` to be assigned to this variable.
sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
@@ -702,12 +772,12 @@ class Variable(six.with_metaclass(VariableMetaclass,
the scattered assignment has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
raise NotImplementedError
def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
"""Assigns `IndexedSlices` to this variable batch-wise.
"""Assigns `tf.IndexedSlices` to this variable batch-wise.
Analogous to `batch_gather`. This assumes that this variable and the
sparse_delta IndexedSlices have a series of leading dimensions that are the
@@ -740,7 +810,7 @@ class Variable(six.with_metaclass(VariableMetaclass,
efficient than this implementation.
Args:
sparse_delta: `IndexedSlices` to be assigned to this variable.
sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
@@ -749,7 +819,7 @@ class Variable(six.with_metaclass(VariableMetaclass,
the scattered assignment has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
raise NotImplementedError
@@ -798,9 +868,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered subtraction has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
"""
raise NotImplementedError
@@ -849,9 +916,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered addition has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
"""
raise NotImplementedError
@@ -900,9 +964,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered assignment has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
"""
raise NotImplementedError
@@ -2000,10 +2061,10 @@ class RefVariable(VariableV1):
return assign.op
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
"""Subtracts `IndexedSlices` from this variable.
"""Subtracts `tf.IndexedSlices` from this variable.
Args:
sparse_delta: `IndexedSlices` to be subtracted from this variable.
sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
@@ -2012,10 +2073,10 @@ class RefVariable(VariableV1):
the scattered subtraction has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
if not isinstance(sparse_delta, ops.IndexedSlices):
raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
return gen_state_ops.scatter_sub(
self._variable,
sparse_delta.indices,
@@ -2024,10 +2085,10 @@ class RefVariable(VariableV1):
name=name)
def scatter_add(self, sparse_delta, use_locking=False, name=None):
"""Adds `IndexedSlices` from this variable.
"""Adds `tf.IndexedSlices` to this variable.
Args:
sparse_delta: `IndexedSlices` to be added to this variable.
sparse_delta: `tf.IndexedSlices` to be added to this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
@@ -2036,10 +2097,10 @@ class RefVariable(VariableV1):
the scattered addition has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
if not isinstance(sparse_delta, ops.IndexedSlices):
raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
return gen_state_ops.scatter_add(
self._variable,
sparse_delta.indices,
@@ -2047,11 +2108,109 @@ class RefVariable(VariableV1):
use_locking=use_locking,
name=name)
def scatter_update(self, sparse_delta, use_locking=False, name=None):
"""Assigns `IndexedSlices` to this variable.
def scatter_max(self, sparse_delta, use_locking=False, name=None):
"""Updates this variable with the max of `tf.IndexedSlices` and itself.
Args:
sparse_delta: `IndexedSlices` to be assigned to this variable.
sparse_delta: `tf.IndexedSlices` to use as an argument of max
with this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered maximization has completed.
Raises:
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
if not isinstance(sparse_delta, ops.IndexedSlices):
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
return gen_state_ops.scatter_max(
self._variable,
sparse_delta.indices,
sparse_delta.values,
use_locking=use_locking,
name=name)
def scatter_min(self, sparse_delta, use_locking=False, name=None):
"""Updates this variable with the min of `tf.IndexedSlices` and itself.
Args:
sparse_delta: `tf.IndexedSlices` to use as an argument of min
with this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered minimization has completed.
Raises:
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
if not isinstance(sparse_delta, ops.IndexedSlices):
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
return gen_state_ops.scatter_min(
self._variable,
sparse_delta.indices,
sparse_delta.values,
use_locking=use_locking,
name=name)
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
"""Multiply this variable by `tf.IndexedSlices`.
Args:
sparse_delta: `tf.IndexedSlices` to multiply this variable by.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered multiplication has completed.
Raises:
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
if not isinstance(sparse_delta, ops.IndexedSlices):
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
return gen_state_ops.scatter_mul(
self._variable,
sparse_delta.indices,
sparse_delta.values,
use_locking=use_locking,
name=name)
def scatter_div(self, sparse_delta, use_locking=False, name=None):
"""Divide this variable by `tf.IndexedSlices`.
Args:
sparse_delta: `tf.IndexedSlices` to divide this variable by.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered division has completed.
Raises:
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
if not isinstance(sparse_delta, ops.IndexedSlices):
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
return gen_state_ops.scatter_div(
self._variable,
sparse_delta.indices,
sparse_delta.values,
use_locking=use_locking,
name=name)
def scatter_update(self, sparse_delta, use_locking=False, name=None):
"""Assigns `tf.IndexedSlices` to this variable.
Args:
sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
@@ -2060,10 +2219,10 @@ class RefVariable(VariableV1):
the scattered assignment has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
if not isinstance(sparse_delta, ops.IndexedSlices):
raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
return gen_state_ops.scatter_update(
self._variable,
sparse_delta.indices,
@@ -2072,7 +2231,7 @@ class RefVariable(VariableV1):
name=name)
def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
"""Assigns `IndexedSlices` to this variable batch-wise.
"""Assigns `tf.IndexedSlices` to this variable batch-wise.
Analogous to `batch_gather`. This assumes that this variable and the
sparse_delta IndexedSlices have a series of leading dimensions that are the
@@ -2105,7 +2264,7 @@ class RefVariable(VariableV1):
efficient than this implementation.
Args:
sparse_delta: `IndexedSlices` to be assigned to this variable.
sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
use_locking: If `True`, use locking during the operation.
name: the name of the operation.
@@ -2114,7 +2273,7 @@ class RefVariable(VariableV1):
the scattered assignment has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
TypeError: if `sparse_delta` is not an `IndexedSlices`.
"""
return state_ops.batch_scatter_update(
self, sparse_delta.indices, sparse_delta.values,
@@ -2165,9 +2324,6 @@ class RefVariable(VariableV1):
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered subtraction has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
"""
return gen_state_ops.scatter_nd_sub(
self._variable, indices, updates, use_locking=True, name=name)
@@ -2217,9 +2373,6 @@ class RefVariable(VariableV1):
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered addition has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
"""
return gen_state_ops.scatter_nd_add(
self._variable, indices, updates, use_locking=True, name=name)
@@ -2269,9 +2422,6 @@ class RefVariable(VariableV1):
Returns:
A `Tensor` that will hold the new value of this variable after
the scattered assignment has completed.
Raises:
ValueError: if `sparse_delta` is not an `IndexedSlices`.
"""
return gen_state_ops.scatter_nd_update(
self._variable, indices, updates, use_locking=True, name=name)


@@ -112,6 +112,22 @@ tf_class {
name: "scatter_add"
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "scatter_div"
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "scatter_max"
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "scatter_min"
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "scatter_mul"
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "scatter_nd_add"
argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "


@@ -111,6 +111,22 @@ tf_class {
name: "scatter_add"
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "scatter_div"
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "scatter_max"
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "scatter_min"
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "scatter_mul"
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "scatter_nd_add"
argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "