Exposes the already-implemented variable scatter_div
, scatter_mul
, scatter_min
, and scatter_max
operations in the variables API. Also fixes some documentation, and adds tests that directly check the var.scatter_* methods.
PiperOrigin-RevId: 251565839
This commit is contained in:
parent
1535bc638c
commit
76171c3cba
@ -456,6 +456,74 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
|||||||
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
|
||||||
self.assertEqual(self.evaluate(read), [[6]])
|
self.assertEqual(self.evaluate(read), [[6]])
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testScatterAddVariableMethod(self):
|
||||||
|
v = resource_variable_ops.ResourceVariable([0.0, 1.5], name="add")
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
self.evaluate(
|
||||||
|
v.scatter_add(ops.IndexedSlices(indices=[1], values=[2.5])))
|
||||||
|
self.assertAllEqual([0.0, 4.0], self.evaluate(v))
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testScatterSubVariableMethod(self):
|
||||||
|
v = resource_variable_ops.ResourceVariable([0.0, 2.5], name="sub")
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
self.evaluate(
|
||||||
|
v.scatter_sub(ops.IndexedSlices(indices=[1], values=[1.5])))
|
||||||
|
self.assertAllEqual([0.0, 1.0], self.evaluate(v))
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testScatterMaxVariableMethod(self):
|
||||||
|
v = resource_variable_ops.ResourceVariable([0.0, 4.0], name="max1")
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
self.evaluate(
|
||||||
|
v.scatter_max(ops.IndexedSlices(indices=[1], values=[5.0])))
|
||||||
|
self.assertAllEqual([0.0, 5.0], self.evaluate(v))
|
||||||
|
|
||||||
|
v = resource_variable_ops.ResourceVariable([0.0, 3.5], name="max2")
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
self.evaluate(
|
||||||
|
v.scatter_max(ops.IndexedSlices(indices=[1], values=[2.0])))
|
||||||
|
self.assertAllEqual([0.0, 3.5], self.evaluate(v))
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testScatterMinVariableMethod(self):
|
||||||
|
v = resource_variable_ops.ResourceVariable([0.0, 4.0], name="min1")
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
self.evaluate(
|
||||||
|
v.scatter_min(ops.IndexedSlices(indices=[1], values=[5.0])))
|
||||||
|
self.assertAllEqual([0.0, 4.0], self.evaluate(v))
|
||||||
|
|
||||||
|
v = resource_variable_ops.ResourceVariable([0.0, 3.5], name="min2")
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
self.evaluate(
|
||||||
|
v.scatter_min(ops.IndexedSlices(indices=[1], values=[2.0])))
|
||||||
|
self.assertAllEqual([0.0, 2.0], self.evaluate(v))
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testScatterMulVariableMethod(self):
|
||||||
|
v = resource_variable_ops.ResourceVariable([0.0, 4.0], name="mul")
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
self.evaluate(
|
||||||
|
v.scatter_mul(ops.IndexedSlices(indices=[1], values=[3.0])))
|
||||||
|
self.assertAllEqual([0.0, 12.0], self.evaluate(v))
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testScatterDivVariableMethod(self):
|
||||||
|
v = resource_variable_ops.ResourceVariable([0.0, 6.0], name="div")
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
self.evaluate(
|
||||||
|
v.scatter_div(ops.IndexedSlices(indices=[1], values=[2.0])))
|
||||||
|
self.assertAllEqual([0.0, 3.0], self.evaluate(v))
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testScatterUpdateVariableMethod(self):
|
||||||
|
v = resource_variable_ops.ResourceVariable([0.0, 6.0], name="update")
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
self.evaluate(
|
||||||
|
v.scatter_update(ops.IndexedSlices(indices=[1], values=[3.0])))
|
||||||
|
self.assertAllEqual([0.0, 3.0], self.evaluate(v))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testScatterUpdateString(self):
|
def testScatterUpdateString(self):
|
||||||
handle = resource_variable_ops.var_handle_op(
|
handle = resource_variable_ops.var_handle_op(
|
||||||
|
@ -179,10 +179,10 @@ class ScatterTest(test.TestCase):
|
|||||||
np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
|
np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
|
||||||
np_scatter(new, indices, updates)
|
np_scatter(new, indices, updates)
|
||||||
# Scatter via tensorflow
|
# Scatter via tensorflow
|
||||||
ref = variables.VariableV1(old, use_resource=False)
|
ref = variables.Variable(old)
|
||||||
ref.initializer.run()
|
self.evaluate(ref.initializer)
|
||||||
tf_scatter(ref, indices, updates).eval()
|
self.evaluate(tf_scatter(ref, indices, updates))
|
||||||
self.assertAllClose(ref.eval(), new)
|
self.assertAllClose(self.evaluate(ref), new)
|
||||||
|
|
||||||
def _VariableRankTests(self,
|
def _VariableRankTests(self,
|
||||||
tf_scatter,
|
tf_scatter,
|
||||||
@ -192,157 +192,125 @@ class ScatterTest(test.TestCase):
|
|||||||
if tf_scatter != state_ops.scatter_div:
|
if tf_scatter != state_ops.scatter_div:
|
||||||
vtypes.append(np.int32)
|
vtypes.append(np.int32)
|
||||||
|
|
||||||
if (tf_scatter == state_ops.scatter_min or
|
|
||||||
tf_scatter == state_ops.scatter_max):
|
|
||||||
vtypes.append(np.float16)
|
|
||||||
|
|
||||||
for vtype in vtypes:
|
for vtype in vtypes:
|
||||||
for itype in (np.int32, np.int64):
|
for itype in (np.int32, np.int64):
|
||||||
self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices,
|
self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices,
|
||||||
updates_are_scalar)
|
updates_are_scalar)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testVariableRankUpdate(self):
|
def testVariableRankUpdate(self):
|
||||||
self._VariableRankTests(state_ops.scatter_update, False)
|
self._VariableRankTests(state_ops.scatter_update, False)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testVariableRankAdd(self):
|
def testVariableRankAdd(self):
|
||||||
self._VariableRankTests(state_ops.scatter_add, False)
|
self._VariableRankTests(state_ops.scatter_add, False)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testVariableRankSub(self):
|
def testVariableRankSub(self):
|
||||||
self._VariableRankTests(state_ops.scatter_sub, False)
|
self._VariableRankTests(state_ops.scatter_sub, False)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testVariableRankMul(self):
|
def testVariableRankMul(self):
|
||||||
self._VariableRankTests(state_ops.scatter_mul, False)
|
self._VariableRankTests(state_ops.scatter_mul, False)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testVariableRankDiv(self):
|
def testVariableRankDiv(self):
|
||||||
self._VariableRankTests(state_ops.scatter_div, False)
|
self._VariableRankTests(state_ops.scatter_div, False)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testVariableRankMin(self):
|
def testVariableRankMin(self):
|
||||||
self._VariableRankTests(state_ops.scatter_min, False)
|
self._VariableRankTests(state_ops.scatter_min, False)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testVariableRankMax(self):
|
def testVariableRankMax(self):
|
||||||
self._VariableRankTests(state_ops.scatter_max, False)
|
self._VariableRankTests(state_ops.scatter_max, False)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testRepeatIndicesAdd(self):
|
def testRepeatIndicesAdd(self):
|
||||||
self._VariableRankTests(state_ops.scatter_add, True)
|
self._VariableRankTests(state_ops.scatter_add, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testRepeatIndicesSub(self):
|
def testRepeatIndicesSub(self):
|
||||||
self._VariableRankTests(state_ops.scatter_sub, True)
|
self._VariableRankTests(state_ops.scatter_sub, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testRepeatIndicesMul(self):
|
def testRepeatIndicesMul(self):
|
||||||
self._VariableRankTests(state_ops.scatter_mul, True)
|
self._VariableRankTests(state_ops.scatter_mul, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testRepeatIndicesDiv(self):
|
def testRepeatIndicesDiv(self):
|
||||||
self._VariableRankTests(state_ops.scatter_div, True)
|
self._VariableRankTests(state_ops.scatter_div, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testRepeatIndicesMin(self):
|
def testRepeatIndicesMin(self):
|
||||||
self._VariableRankTests(state_ops.scatter_min, True)
|
self._VariableRankTests(state_ops.scatter_min, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testRepeatIndicesMax(self):
|
def testRepeatIndicesMax(self):
|
||||||
self._VariableRankTests(state_ops.scatter_max, True)
|
self._VariableRankTests(state_ops.scatter_max, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testVariableRankUpdateScalar(self):
|
def testVariableRankUpdateScalar(self):
|
||||||
self._VariableRankTests(state_ops.scatter_update, False, True)
|
self._VariableRankTests(state_ops.scatter_update, False, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testVariableRankAddScalar(self):
|
def testVariableRankAddScalar(self):
|
||||||
self._VariableRankTests(state_ops.scatter_add, False, True)
|
self._VariableRankTests(state_ops.scatter_add, False, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testVariableRankSubScalar(self):
|
def testVariableRankSubScalar(self):
|
||||||
self._VariableRankTests(state_ops.scatter_sub, False, True)
|
self._VariableRankTests(state_ops.scatter_sub, False, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testVariableRankMulScalar(self):
|
def testVariableRankMulScalar(self):
|
||||||
self._VariableRankTests(state_ops.scatter_mul, False, True)
|
self._VariableRankTests(state_ops.scatter_mul, False, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testVariableRankDivScalar(self):
|
def testVariableRankDivScalar(self):
|
||||||
self._VariableRankTests(state_ops.scatter_div, False, True)
|
self._VariableRankTests(state_ops.scatter_div, False, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testVariableRankMinScalar(self):
|
def testVariableRankMinScalar(self):
|
||||||
self._VariableRankTests(state_ops.scatter_min, False, True)
|
self._VariableRankTests(state_ops.scatter_min, False, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testVariableRankMaxScalar(self):
|
def testVariableRankMaxScalar(self):
|
||||||
self._VariableRankTests(state_ops.scatter_max, False, True)
|
self._VariableRankTests(state_ops.scatter_max, False, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testRepeatIndicesAddScalar(self):
|
def testRepeatIndicesAddScalar(self):
|
||||||
self._VariableRankTests(state_ops.scatter_add, True, True)
|
self._VariableRankTests(state_ops.scatter_add, True, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testRepeatIndicesSubScalar(self):
|
def testRepeatIndicesSubScalar(self):
|
||||||
self._VariableRankTests(state_ops.scatter_sub, True, True)
|
self._VariableRankTests(state_ops.scatter_sub, True, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testRepeatIndicesMulScalar(self):
|
def testRepeatIndicesMulScalar(self):
|
||||||
self._VariableRankTests(state_ops.scatter_mul, True, True)
|
self._VariableRankTests(state_ops.scatter_mul, True, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testRepeatIndicesDivScalar(self):
|
def testRepeatIndicesDivScalar(self):
|
||||||
self._VariableRankTests(state_ops.scatter_div, True, True)
|
self._VariableRankTests(state_ops.scatter_div, True, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testRepeatIndicesMinScalar(self):
|
def testRepeatIndicesMinScalar(self):
|
||||||
self._VariableRankTests(state_ops.scatter_min, True, True)
|
self._VariableRankTests(state_ops.scatter_min, True, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testRepeatIndicesMaxScalar(self):
|
def testRepeatIndicesMaxScalar(self):
|
||||||
self._VariableRankTests(state_ops.scatter_max, True, True)
|
self._VariableRankTests(state_ops.scatter_max, True, True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testBooleanScatterUpdate(self):
|
def testBooleanScatterUpdate(self):
|
||||||
if not test.is_gpu_available():
|
if not test.is_gpu_available():
|
||||||
with self.session(use_gpu=False) as session:
|
with self.session(use_gpu=False):
|
||||||
var = variables.Variable([True, False])
|
var = variables.Variable([True, False])
|
||||||
update0 = state_ops.scatter_update(var, 1, True)
|
update0 = state_ops.scatter_update(var, 1, True)
|
||||||
update1 = state_ops.scatter_update(
|
update1 = state_ops.scatter_update(
|
||||||
var, constant_op.constant(
|
var, constant_op.constant(
|
||||||
0, dtype=dtypes.int64), False)
|
0, dtype=dtypes.int64), False)
|
||||||
var.initializer.run()
|
self.evaluate(var.initializer)
|
||||||
|
|
||||||
session.run([update0, update1])
|
self.evaluate([update0, update1])
|
||||||
|
|
||||||
self.assertAllEqual([False, True], self.evaluate(var))
|
self.assertAllEqual([False, True], self.evaluate(var))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testScatterOutOfRangeCpu(self):
|
def testScatterOutOfRangeCpu(self):
|
||||||
for op, _ in _TF_OPS_TO_NUMPY.items():
|
for op, _ in _TF_OPS_TO_NUMPY.items():
|
||||||
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
|
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
|
||||||
updates = np.array([-3, -4, -5]).astype(np.float32)
|
updates = np.array([-3, -4, -5]).astype(np.float32)
|
||||||
if not test.is_gpu_available():
|
if not test.is_gpu_available():
|
||||||
with self.session(use_gpu=False):
|
with self.session(use_gpu=False):
|
||||||
ref = variables.VariableV1(params, use_resource=False)
|
ref = variables.Variable(params)
|
||||||
ref.initializer.run()
|
self.evaluate(ref.initializer)
|
||||||
|
|
||||||
# Indices all in range, no problem.
|
# Indices all in range, no problem.
|
||||||
indices = np.array([2, 0, 5])
|
indices = np.array([2, 0, 5])
|
||||||
op(ref, indices, updates).eval()
|
self.evaluate(op(ref, indices, updates))
|
||||||
|
|
||||||
# Test some out of range errors.
|
# Test some out of range errors.
|
||||||
indices = np.array([-1, 0, 5])
|
indices = np.array([-1, 0, 5])
|
||||||
with self.assertRaisesOpError(
|
with self.assertRaisesOpError(
|
||||||
r'indices\[0\] = -1 is not in \[0, 6\)'):
|
r'indices\[0\] = -1 is not in \[0, 6\)'):
|
||||||
op(ref, indices, updates).eval()
|
self.evaluate(op(ref, indices, updates))
|
||||||
|
|
||||||
indices = np.array([2, 0, 6])
|
indices = np.array([2, 0, 6])
|
||||||
with self.assertRaisesOpError(r'indices\[2\] = 6 is not in \[0, 6\)'):
|
with self.assertRaisesOpError(r'indices\[2\] = 6 is not in \[0, 6\)'):
|
||||||
op(ref, indices, updates).eval()
|
self.evaluate(op(ref, indices, updates))
|
||||||
|
|
||||||
# TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.
|
# TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.
|
||||||
def _disabledTestScatterOutOfRangeGpu(self):
|
def _disabledTestScatterOutOfRangeGpu(self):
|
||||||
@ -355,7 +323,7 @@ class ScatterTest(test.TestCase):
|
|||||||
# We don't test the implementation; just test there's no failures.
|
# We don't test the implementation; just test there's no failures.
|
||||||
with test_util.force_gpu():
|
with test_util.force_gpu():
|
||||||
ref = variables.Variable(params)
|
ref = variables.Variable(params)
|
||||||
ref.initializer.run()
|
self.evaluate(ref.initializer)
|
||||||
|
|
||||||
# Indices all in range, no problem.
|
# Indices all in range, no problem.
|
||||||
indices = np.array([2, 0, 5])
|
indices = np.array([2, 0, 5])
|
||||||
|
@ -1161,10 +1161,10 @@ class ResourceVariable(variables.VariableV1):
|
|||||||
distribute_strategy=self._distribute_strategy), ()
|
distribute_strategy=self._distribute_strategy), ()
|
||||||
|
|
||||||
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
|
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
|
||||||
"""Subtracts `IndexedSlices` from this variable.
|
"""Subtracts `tf.IndexedSlices` from this variable.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sparse_delta: `IndexedSlices` to be subtracted from this variable.
|
sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
|
||||||
use_locking: If `True`, use locking during the operation.
|
use_locking: If `True`, use locking during the operation.
|
||||||
name: the name of the operation.
|
name: the name of the operation.
|
||||||
|
|
||||||
@ -1173,40 +1173,126 @@ class ResourceVariable(variables.VariableV1):
|
|||||||
the scattered subtraction has completed.
|
the scattered subtraction has completed.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
"""
|
"""
|
||||||
if not isinstance(sparse_delta, ops.IndexedSlices):
|
if not isinstance(sparse_delta, ops.IndexedSlices):
|
||||||
raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
||||||
return self._lazy_read(gen_resource_variable_ops.resource_scatter_sub(
|
return self._lazy_read(gen_resource_variable_ops.resource_scatter_sub(
|
||||||
self.handle, sparse_delta.indices,
|
self.handle, sparse_delta.indices,
|
||||||
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
|
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
|
||||||
|
|
||||||
def scatter_add(self, sparse_delta, use_locking=False, name=None):
|
def scatter_add(self, sparse_delta, use_locking=False, name=None):
|
||||||
"""Adds `IndexedSlices` from this variable.
|
"""Adds `tf.IndexedSlices` to this variable.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sparse_delta: `IndexedSlices` to be added to this variable.
|
sparse_delta: `tf.IndexedSlices` to be added to this variable.
|
||||||
use_locking: If `True`, use locking during the operation.
|
use_locking: If `True`, use locking during the operation.
|
||||||
name: the name of the operation.
|
name: the name of the operation.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `Tensor` that will hold the new value of this variable after
|
A `Tensor` that will hold the new value of this variable after
|
||||||
the scattered subtraction has completed.
|
the scattered addition has completed.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
"""
|
"""
|
||||||
if not isinstance(sparse_delta, ops.IndexedSlices):
|
if not isinstance(sparse_delta, ops.IndexedSlices):
|
||||||
raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
||||||
return self._lazy_read(gen_resource_variable_ops.resource_scatter_add(
|
return self._lazy_read(gen_resource_variable_ops.resource_scatter_add(
|
||||||
self.handle, sparse_delta.indices,
|
self.handle, sparse_delta.indices,
|
||||||
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
|
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
|
||||||
|
|
||||||
def scatter_update(self, sparse_delta, use_locking=False, name=None):
|
def scatter_max(self, sparse_delta, use_locking=False, name=None):
|
||||||
"""Assigns `IndexedSlices` to this variable.
|
"""Updates this variable with the max of `tf.IndexedSlices` and itself.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sparse_delta: `IndexedSlices` to be assigned to this variable.
|
sparse_delta: `tf.IndexedSlices` to use as an argument of max
|
||||||
|
with this variable.
|
||||||
|
use_locking: If `True`, use locking during the operation.
|
||||||
|
name: the name of the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` that will hold the new value of this variable after
|
||||||
|
the scattered maximization has completed.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
|
"""
|
||||||
|
if not isinstance(sparse_delta, ops.IndexedSlices):
|
||||||
|
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
||||||
|
return self._lazy_read(gen_resource_variable_ops.resource_scatter_max(
|
||||||
|
self.handle, sparse_delta.indices,
|
||||||
|
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
|
||||||
|
|
||||||
|
def scatter_min(self, sparse_delta, use_locking=False, name=None):
|
||||||
|
"""Updates this variable with the min of `tf.IndexedSlices` and itself.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sparse_delta: `tf.IndexedSlices` to use as an argument of min
|
||||||
|
with this variable.
|
||||||
|
use_locking: If `True`, use locking during the operation.
|
||||||
|
name: the name of the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` that will hold the new value of this variable after
|
||||||
|
the scattered minimization has completed.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
|
"""
|
||||||
|
if not isinstance(sparse_delta, ops.IndexedSlices):
|
||||||
|
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
||||||
|
return self._lazy_read(gen_resource_variable_ops.resource_scatter_min(
|
||||||
|
self.handle, sparse_delta.indices,
|
||||||
|
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
|
||||||
|
|
||||||
|
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
|
||||||
|
"""Multiply this variable by `tf.IndexedSlices`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sparse_delta: `tf.IndexedSlices` to multiply this variable by.
|
||||||
|
use_locking: If `True`, use locking during the operation.
|
||||||
|
name: the name of the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` that will hold the new value of this variable after
|
||||||
|
the scattered multiplication has completed.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
|
"""
|
||||||
|
if not isinstance(sparse_delta, ops.IndexedSlices):
|
||||||
|
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
||||||
|
return self._lazy_read(gen_resource_variable_ops.resource_scatter_mul(
|
||||||
|
self.handle, sparse_delta.indices,
|
||||||
|
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
|
||||||
|
|
||||||
|
def scatter_div(self, sparse_delta, use_locking=False, name=None):
|
||||||
|
"""Divide this variable by `tf.IndexedSlices`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sparse_delta: `tf.IndexedSlices` to divide this variable by.
|
||||||
|
use_locking: If `True`, use locking during the operation.
|
||||||
|
name: the name of the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` that will hold the new value of this variable after
|
||||||
|
the scattered division has completed.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
|
"""
|
||||||
|
if not isinstance(sparse_delta, ops.IndexedSlices):
|
||||||
|
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
||||||
|
return self._lazy_read(gen_resource_variable_ops.resource_scatter_div(
|
||||||
|
self.handle, sparse_delta.indices,
|
||||||
|
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
|
||||||
|
|
||||||
|
def scatter_update(self, sparse_delta, use_locking=False, name=None):
|
||||||
|
"""Assigns `tf.IndexedSlices` to this variable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
|
||||||
use_locking: If `True`, use locking during the operation.
|
use_locking: If `True`, use locking during the operation.
|
||||||
name: the name of the operation.
|
name: the name of the operation.
|
||||||
|
|
||||||
@ -1215,16 +1301,16 @@ class ResourceVariable(variables.VariableV1):
|
|||||||
the scattered subtraction has completed.
|
the scattered subtraction has completed.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
"""
|
"""
|
||||||
if not isinstance(sparse_delta, ops.IndexedSlices):
|
if not isinstance(sparse_delta, ops.IndexedSlices):
|
||||||
raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
||||||
return self._lazy_read(gen_resource_variable_ops.resource_scatter_update(
|
return self._lazy_read(gen_resource_variable_ops.resource_scatter_update(
|
||||||
self.handle, sparse_delta.indices,
|
self.handle, sparse_delta.indices,
|
||||||
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
|
ops.convert_to_tensor(sparse_delta.values, self.dtype), name=name))
|
||||||
|
|
||||||
def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
|
def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
|
||||||
"""Assigns `IndexedSlices` to this variable batch-wise.
|
"""Assigns `tf.IndexedSlices` to this variable batch-wise.
|
||||||
|
|
||||||
Analogous to `batch_gather`. This assumes that this variable and the
|
Analogous to `batch_gather`. This assumes that this variable and the
|
||||||
sparse_delta IndexedSlices have a series of leading dimensions that are the
|
sparse_delta IndexedSlices have a series of leading dimensions that are the
|
||||||
@ -1257,7 +1343,7 @@ class ResourceVariable(variables.VariableV1):
|
|||||||
efficient than this implementation.
|
efficient than this implementation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sparse_delta: `IndexedSlices` to be assigned to this variable.
|
sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
|
||||||
use_locking: If `True`, use locking during the operation.
|
use_locking: If `True`, use locking during the operation.
|
||||||
name: the name of the operation.
|
name: the name of the operation.
|
||||||
|
|
||||||
@ -1266,8 +1352,10 @@ class ResourceVariable(variables.VariableV1):
|
|||||||
the scattered subtraction has completed.
|
the scattered subtraction has completed.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(sparse_delta, ops.IndexedSlices):
|
||||||
|
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
||||||
return self._lazy_read(state_ops.batch_scatter_update(
|
return self._lazy_read(state_ops.batch_scatter_update(
|
||||||
self, sparse_delta.indices, sparse_delta.values,
|
self, sparse_delta.indices, sparse_delta.values,
|
||||||
use_locking=use_locking, name=name))
|
use_locking=use_locking, name=name))
|
||||||
@ -1317,9 +1405,6 @@ class ResourceVariable(variables.VariableV1):
|
|||||||
Returns:
|
Returns:
|
||||||
A `Tensor` that will hold the new value of this variable after
|
A `Tensor` that will hold the new value of this variable after
|
||||||
the scattered subtraction has completed.
|
the scattered subtraction has completed.
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
|
||||||
"""
|
"""
|
||||||
return self._lazy_read(gen_state_ops.resource_scatter_nd_sub(
|
return self._lazy_read(gen_state_ops.resource_scatter_nd_sub(
|
||||||
self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
|
self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
|
||||||
@ -1370,9 +1455,6 @@ class ResourceVariable(variables.VariableV1):
|
|||||||
Returns:
|
Returns:
|
||||||
A `Tensor` that will hold the new value of this variable after
|
A `Tensor` that will hold the new value of this variable after
|
||||||
the scattered subtraction has completed.
|
the scattered subtraction has completed.
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
|
||||||
"""
|
"""
|
||||||
return self._lazy_read(gen_state_ops.resource_scatter_nd_add(
|
return self._lazy_read(gen_state_ops.resource_scatter_nd_add(
|
||||||
self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
|
self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
|
||||||
@ -1423,9 +1505,6 @@ class ResourceVariable(variables.VariableV1):
|
|||||||
Returns:
|
Returns:
|
||||||
A `Tensor` that will hold the new value of this variable after
|
A `Tensor` that will hold the new value of this variable after
|
||||||
the scattered subtraction has completed.
|
the scattered subtraction has completed.
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
|
||||||
"""
|
"""
|
||||||
return self._lazy_read(gen_state_ops.resource_scatter_nd_update(
|
return self._lazy_read(gen_state_ops.resource_scatter_nd_update(
|
||||||
self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
|
self.handle, indices, ops.convert_to_tensor(updates, self.dtype),
|
||||||
|
@ -645,12 +645,12 @@ def scatter_mul(ref, indices, updates, use_locking=False, name=None):
|
|||||||
Returns:
|
Returns:
|
||||||
A mutable `Tensor`. Has the same type as `ref`.
|
A mutable `Tensor`. Has the same type as `ref`.
|
||||||
"""
|
"""
|
||||||
return gen_state_ops.scatter_mul(
|
if ref.dtype._is_ref_dtype:
|
||||||
ref=ref,
|
return gen_state_ops.scatter_mul(ref, indices, updates,
|
||||||
indices=indices,
|
use_locking=use_locking, name=name)
|
||||||
updates=updates,
|
return ref._lazy_read(gen_resource_variable_ops.resource_scatter_mul( # pylint: disable=protected-access
|
||||||
use_locking=use_locking,
|
ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
|
||||||
name=name)
|
name=name))
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["scatter_div"])
|
@tf_export(v1=["scatter_div"])
|
||||||
@ -697,12 +697,12 @@ def scatter_div(ref, indices, updates, use_locking=False, name=None):
|
|||||||
Returns:
|
Returns:
|
||||||
A mutable `Tensor`. Has the same type as `ref`.
|
A mutable `Tensor`. Has the same type as `ref`.
|
||||||
"""
|
"""
|
||||||
return gen_state_ops.scatter_div(
|
if ref.dtype._is_ref_dtype:
|
||||||
ref=ref,
|
return gen_state_ops.scatter_div(ref, indices, updates,
|
||||||
indices=indices,
|
use_locking=use_locking, name=name)
|
||||||
updates=updates,
|
return ref._lazy_read(gen_resource_variable_ops.resource_scatter_div( # pylint: disable=protected-access
|
||||||
use_locking=use_locking,
|
ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
|
||||||
name=name)
|
name=name))
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["scatter_max"])
|
@tf_export(v1=["scatter_max"])
|
||||||
@ -752,12 +752,12 @@ def scatter_max(ref, indices, updates, use_locking=False, name=None):
|
|||||||
Returns:
|
Returns:
|
||||||
A mutable `Tensor`. Has the same type as `ref`.
|
A mutable `Tensor`. Has the same type as `ref`.
|
||||||
"""
|
"""
|
||||||
return gen_state_ops.scatter_max(
|
if ref.dtype._is_ref_dtype:
|
||||||
ref=ref,
|
return gen_state_ops.scatter_max(ref, indices, updates,
|
||||||
indices=indices,
|
use_locking=use_locking, name=name)
|
||||||
updates=updates,
|
return ref._lazy_read(gen_resource_variable_ops.resource_scatter_max( # pylint: disable=protected-access
|
||||||
use_locking=use_locking,
|
ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
|
||||||
name=name)
|
name=name))
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["scatter_min"])
|
@tf_export(v1=["scatter_min"])
|
||||||
@ -807,12 +807,12 @@ def scatter_min(ref, indices, updates, use_locking=False, name=None):
|
|||||||
Returns:
|
Returns:
|
||||||
A mutable `Tensor`. Has the same type as `ref`.
|
A mutable `Tensor`. Has the same type as `ref`.
|
||||||
"""
|
"""
|
||||||
return gen_state_ops.scatter_min(
|
if ref.dtype._is_ref_dtype:
|
||||||
ref=ref,
|
return gen_state_ops.scatter_min(ref, indices, updates,
|
||||||
indices=indices,
|
use_locking=use_locking, name=name)
|
||||||
updates=updates,
|
return ref._lazy_read(gen_resource_variable_ops.resource_scatter_min( # pylint: disable=protected-access
|
||||||
use_locking=use_locking,
|
ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
|
||||||
name=name)
|
name=name))
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["batch_scatter_update"])
|
@tf_export(v1=["batch_scatter_update"])
|
||||||
|
@ -656,10 +656,10 @@ class Variable(six.with_metaclass(VariableMetaclass,
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
|
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
|
||||||
"""Subtracts `IndexedSlices` from this variable.
|
"""Subtracts `tf.IndexedSlices` from this variable.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sparse_delta: `IndexedSlices` to be subtracted from this variable.
|
sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
|
||||||
use_locking: If `True`, use locking during the operation.
|
use_locking: If `True`, use locking during the operation.
|
||||||
name: the name of the operation.
|
name: the name of the operation.
|
||||||
|
|
||||||
@ -668,15 +668,15 @@ class Variable(six.with_metaclass(VariableMetaclass,
|
|||||||
the scattered subtraction has completed.
|
the scattered subtraction has completed.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def scatter_add(self, sparse_delta, use_locking=False, name=None):
|
def scatter_add(self, sparse_delta, use_locking=False, name=None):
|
||||||
"""Adds `IndexedSlices` to this variable.
|
"""Adds `tf.IndexedSlices` to this variable.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sparse_delta: `IndexedSlices` to be assigned to this variable.
|
sparse_delta: `tf.IndexedSlices` to be added to this variable.
|
||||||
use_locking: If `True`, use locking during the operation.
|
use_locking: If `True`, use locking during the operation.
|
||||||
name: the name of the operation.
|
name: the name of the operation.
|
||||||
|
|
||||||
@ -685,15 +685,85 @@ class Variable(six.with_metaclass(VariableMetaclass,
|
|||||||
the scattered addition has completed.
|
the scattered addition has completed.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def scatter_max(self, sparse_delta, use_locking=False, name=None):
|
||||||
|
"""Updates this variable with the max of `tf.IndexedSlices` and itself.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sparse_delta: `tf.IndexedSlices` to use as an argument of max
|
||||||
|
with this variable.
|
||||||
|
use_locking: If `True`, use locking during the operation.
|
||||||
|
name: the name of the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` that will hold the new value of this variable after
|
||||||
|
the scattered maximization has completed.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def scatter_min(self, sparse_delta, use_locking=False, name=None):
|
||||||
|
"""Updates this variable with the min of `tf.IndexedSlices` and itself.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sparse_delta: `tf.IndexedSlices` to use as an argument of min
|
||||||
|
with this variable.
|
||||||
|
use_locking: If `True`, use locking during the operation.
|
||||||
|
name: the name of the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` that will hold the new value of this variable after
|
||||||
|
the scattered minimization has completed.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
|
||||||
|
"""Multiply this variable by `tf.IndexedSlices`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sparse_delta: `tf.IndexedSlices` to multiply this variable by.
|
||||||
|
use_locking: If `True`, use locking during the operation.
|
||||||
|
name: the name of the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` that will hold the new value of this variable after
|
||||||
|
the scattered multiplication has completed.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def scatter_div(self, sparse_delta, use_locking=False, name=None):
|
||||||
|
"""Divide this variable by `tf.IndexedSlices`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sparse_delta: `tf.IndexedSlices` to divide this variable by.
|
||||||
|
use_locking: If `True`, use locking during the operation.
|
||||||
|
name: the name of the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` that will hold the new value of this variable after
|
||||||
|
the scattered division has completed.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def scatter_update(self, sparse_delta, use_locking=False, name=None):
|
def scatter_update(self, sparse_delta, use_locking=False, name=None):
|
||||||
"""Assigns `IndexedSlices` to this variable.
|
"""Assigns `tf.IndexedSlices` to this variable.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sparse_delta: `IndexedSlices` to be assigned to this variable.
|
sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
|
||||||
use_locking: If `True`, use locking during the operation.
|
use_locking: If `True`, use locking during the operation.
|
||||||
name: the name of the operation.
|
name: the name of the operation.
|
||||||
|
|
||||||
@ -702,12 +772,12 @@ class Variable(six.with_metaclass(VariableMetaclass,
|
|||||||
the scattered assignment has completed.
|
the scattered assignment has completed.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
|
def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
|
||||||
"""Assigns `IndexedSlices` to this variable batch-wise.
|
"""Assigns `tf.IndexedSlices` to this variable batch-wise.
|
||||||
|
|
||||||
Analogous to `batch_gather`. This assumes that this variable and the
|
Analogous to `batch_gather`. This assumes that this variable and the
|
||||||
sparse_delta IndexedSlices have a series of leading dimensions that are the
|
sparse_delta IndexedSlices have a series of leading dimensions that are the
|
||||||
@ -740,7 +810,7 @@ class Variable(six.with_metaclass(VariableMetaclass,
|
|||||||
efficient than this implementation.
|
efficient than this implementation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sparse_delta: `IndexedSlices` to be assigned to this variable.
|
sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
|
||||||
use_locking: If `True`, use locking during the operation.
|
use_locking: If `True`, use locking during the operation.
|
||||||
name: the name of the operation.
|
name: the name of the operation.
|
||||||
|
|
||||||
@ -749,7 +819,7 @@ class Variable(six.with_metaclass(VariableMetaclass,
|
|||||||
the scattered assignment has completed.
|
the scattered assignment has completed.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -798,9 +868,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
|
|||||||
Returns:
|
Returns:
|
||||||
A `Tensor` that will hold the new value of this variable after
|
A `Tensor` that will hold the new value of this variable after
|
||||||
the scattered subtraction has completed.
|
the scattered subtraction has completed.
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -849,9 +916,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
|
|||||||
Returns:
|
Returns:
|
||||||
A `Tensor` that will hold the new value of this variable after
|
A `Tensor` that will hold the new value of this variable after
|
||||||
the scattered addition has completed.
|
the scattered addition has completed.
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -900,9 +964,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
|
|||||||
Returns:
|
Returns:
|
||||||
A `Tensor` that will hold the new value of this variable after
|
A `Tensor` that will hold the new value of this variable after
|
||||||
the scattered assignment has completed.
|
the scattered assignment has completed.
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -2000,10 +2061,10 @@ class RefVariable(VariableV1):
|
|||||||
return assign.op
|
return assign.op
|
||||||
|
|
||||||
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
|
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
|
||||||
"""Subtracts `IndexedSlices` from this variable.
|
"""Subtracts `tf.IndexedSlices` from this variable.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sparse_delta: `IndexedSlices` to be subtracted from this variable.
|
sparse_delta: `tf.IndexedSlices` to be subtracted from this variable.
|
||||||
use_locking: If `True`, use locking during the operation.
|
use_locking: If `True`, use locking during the operation.
|
||||||
name: the name of the operation.
|
name: the name of the operation.
|
||||||
|
|
||||||
@ -2012,10 +2073,10 @@ class RefVariable(VariableV1):
|
|||||||
the scattered subtraction has completed.
|
the scattered subtraction has completed.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
"""
|
"""
|
||||||
if not isinstance(sparse_delta, ops.IndexedSlices):
|
if not isinstance(sparse_delta, ops.IndexedSlices):
|
||||||
raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
||||||
return gen_state_ops.scatter_sub(
|
return gen_state_ops.scatter_sub(
|
||||||
self._variable,
|
self._variable,
|
||||||
sparse_delta.indices,
|
sparse_delta.indices,
|
||||||
@ -2024,10 +2085,10 @@ class RefVariable(VariableV1):
|
|||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
def scatter_add(self, sparse_delta, use_locking=False, name=None):
|
def scatter_add(self, sparse_delta, use_locking=False, name=None):
|
||||||
"""Adds `IndexedSlices` from this variable.
|
"""Adds `tf.IndexedSlices` to this variable.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sparse_delta: `IndexedSlices` to be added to this variable.
|
sparse_delta: `tf.IndexedSlices` to be added to this variable.
|
||||||
use_locking: If `True`, use locking during the operation.
|
use_locking: If `True`, use locking during the operation.
|
||||||
name: the name of the operation.
|
name: the name of the operation.
|
||||||
|
|
||||||
@ -2036,10 +2097,10 @@ class RefVariable(VariableV1):
|
|||||||
the scattered addition has completed.
|
the scattered addition has completed.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
"""
|
"""
|
||||||
if not isinstance(sparse_delta, ops.IndexedSlices):
|
if not isinstance(sparse_delta, ops.IndexedSlices):
|
||||||
raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
||||||
return gen_state_ops.scatter_add(
|
return gen_state_ops.scatter_add(
|
||||||
self._variable,
|
self._variable,
|
||||||
sparse_delta.indices,
|
sparse_delta.indices,
|
||||||
@ -2047,11 +2108,109 @@ class RefVariable(VariableV1):
|
|||||||
use_locking=use_locking,
|
use_locking=use_locking,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
def scatter_update(self, sparse_delta, use_locking=False, name=None):
|
def scatter_max(self, sparse_delta, use_locking=False, name=None):
|
||||||
"""Assigns `IndexedSlices` to this variable.
|
"""Updates this variable with the max of `tf.IndexedSlices` and itself.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sparse_delta: `IndexedSlices` to be assigned to this variable.
|
sparse_delta: `tf.IndexedSlices` to use as an argument of max
|
||||||
|
with this variable.
|
||||||
|
use_locking: If `True`, use locking during the operation.
|
||||||
|
name: the name of the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` that will hold the new value of this variable after
|
||||||
|
the scattered maximization has completed.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
|
"""
|
||||||
|
if not isinstance(sparse_delta, ops.IndexedSlices):
|
||||||
|
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
||||||
|
return gen_state_ops.scatter_max(
|
||||||
|
self._variable,
|
||||||
|
sparse_delta.indices,
|
||||||
|
sparse_delta.values,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name)
|
||||||
|
|
||||||
|
def scatter_min(self, sparse_delta, use_locking=False, name=None):
|
||||||
|
"""Updates this variable with the min of `tf.IndexedSlices` and itself.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sparse_delta: `tf.IndexedSlices` to use as an argument of min
|
||||||
|
with this variable.
|
||||||
|
use_locking: If `True`, use locking during the operation.
|
||||||
|
name: the name of the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` that will hold the new value of this variable after
|
||||||
|
the scattered minimization has completed.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
|
"""
|
||||||
|
if not isinstance(sparse_delta, ops.IndexedSlices):
|
||||||
|
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
||||||
|
return gen_state_ops.scatter_min(
|
||||||
|
self._variable,
|
||||||
|
sparse_delta.indices,
|
||||||
|
sparse_delta.values,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name)
|
||||||
|
|
||||||
|
def scatter_mul(self, sparse_delta, use_locking=False, name=None):
|
||||||
|
"""Multiply this variable by `tf.IndexedSlices`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sparse_delta: `tf.IndexedSlices` to multiply this variable by.
|
||||||
|
use_locking: If `True`, use locking during the operation.
|
||||||
|
name: the name of the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` that will hold the new value of this variable after
|
||||||
|
the scattered multiplication has completed.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
|
"""
|
||||||
|
if not isinstance(sparse_delta, ops.IndexedSlices):
|
||||||
|
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
||||||
|
return gen_state_ops.scatter_mul(
|
||||||
|
self._variable,
|
||||||
|
sparse_delta.indices,
|
||||||
|
sparse_delta.values,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name)
|
||||||
|
|
||||||
|
def scatter_div(self, sparse_delta, use_locking=False, name=None):
|
||||||
|
"""Divide this variable by `tf.IndexedSlices`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sparse_delta: `tf.IndexedSlices` to divide this variable by.
|
||||||
|
use_locking: If `True`, use locking during the operation.
|
||||||
|
name: the name of the operation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` that will hold the new value of this variable after
|
||||||
|
the scattered division has completed.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
|
"""
|
||||||
|
if not isinstance(sparse_delta, ops.IndexedSlices):
|
||||||
|
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
||||||
|
return gen_state_ops.scatter_div(
|
||||||
|
self._variable,
|
||||||
|
sparse_delta.indices,
|
||||||
|
sparse_delta.values,
|
||||||
|
use_locking=use_locking,
|
||||||
|
name=name)
|
||||||
|
|
||||||
|
def scatter_update(self, sparse_delta, use_locking=False, name=None):
|
||||||
|
"""Assigns `tf.IndexedSlices` to this variable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
|
||||||
use_locking: If `True`, use locking during the operation.
|
use_locking: If `True`, use locking during the operation.
|
||||||
name: the name of the operation.
|
name: the name of the operation.
|
||||||
|
|
||||||
@ -2060,10 +2219,10 @@ class RefVariable(VariableV1):
|
|||||||
the scattered assignment has completed.
|
the scattered assignment has completed.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
"""
|
"""
|
||||||
if not isinstance(sparse_delta, ops.IndexedSlices):
|
if not isinstance(sparse_delta, ops.IndexedSlices):
|
||||||
raise ValueError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
raise TypeError("sparse_delta is not IndexedSlices: %s" % sparse_delta)
|
||||||
return gen_state_ops.scatter_update(
|
return gen_state_ops.scatter_update(
|
||||||
self._variable,
|
self._variable,
|
||||||
sparse_delta.indices,
|
sparse_delta.indices,
|
||||||
@ -2072,7 +2231,7 @@ class RefVariable(VariableV1):
|
|||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
|
def batch_scatter_update(self, sparse_delta, use_locking=False, name=None):
|
||||||
"""Assigns `IndexedSlices` to this variable batch-wise.
|
"""Assigns `tf.IndexedSlices` to this variable batch-wise.
|
||||||
|
|
||||||
Analogous to `batch_gather`. This assumes that this variable and the
|
Analogous to `batch_gather`. This assumes that this variable and the
|
||||||
sparse_delta IndexedSlices have a series of leading dimensions that are the
|
sparse_delta IndexedSlices have a series of leading dimensions that are the
|
||||||
@ -2105,7 +2264,7 @@ class RefVariable(VariableV1):
|
|||||||
efficient than this implementation.
|
efficient than this implementation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sparse_delta: `IndexedSlices` to be assigned to this variable.
|
sparse_delta: `tf.IndexedSlices` to be assigned to this variable.
|
||||||
use_locking: If `True`, use locking during the operation.
|
use_locking: If `True`, use locking during the operation.
|
||||||
name: the name of the operation.
|
name: the name of the operation.
|
||||||
|
|
||||||
@ -2114,7 +2273,7 @@ class RefVariable(VariableV1):
|
|||||||
the scattered assignment has completed.
|
the scattered assignment has completed.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
TypeError: if `sparse_delta` is not an `IndexedSlices`.
|
||||||
"""
|
"""
|
||||||
return state_ops.batch_scatter_update(
|
return state_ops.batch_scatter_update(
|
||||||
self, sparse_delta.indices, sparse_delta.values,
|
self, sparse_delta.indices, sparse_delta.values,
|
||||||
@ -2165,9 +2324,6 @@ class RefVariable(VariableV1):
|
|||||||
Returns:
|
Returns:
|
||||||
A `Tensor` that will hold the new value of this variable after
|
A `Tensor` that will hold the new value of this variable after
|
||||||
the scattered subtraction has completed.
|
the scattered subtraction has completed.
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
|
||||||
"""
|
"""
|
||||||
return gen_state_ops.scatter_nd_sub(
|
return gen_state_ops.scatter_nd_sub(
|
||||||
self._variable, indices, updates, use_locking=True, name=name)
|
self._variable, indices, updates, use_locking=True, name=name)
|
||||||
@ -2217,9 +2373,6 @@ class RefVariable(VariableV1):
|
|||||||
Returns:
|
Returns:
|
||||||
A `Tensor` that will hold the new value of this variable after
|
A `Tensor` that will hold the new value of this variable after
|
||||||
the scattered addition has completed.
|
the scattered addition has completed.
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
|
||||||
"""
|
"""
|
||||||
return gen_state_ops.scatter_nd_add(
|
return gen_state_ops.scatter_nd_add(
|
||||||
self._variable, indices, updates, use_locking=True, name=name)
|
self._variable, indices, updates, use_locking=True, name=name)
|
||||||
@ -2269,9 +2422,6 @@ class RefVariable(VariableV1):
|
|||||||
Returns:
|
Returns:
|
||||||
A `Tensor` that will hold the new value of this variable after
|
A `Tensor` that will hold the new value of this variable after
|
||||||
the scattered assignment has completed.
|
the scattered assignment has completed.
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if `sparse_delta` is not an `IndexedSlices`.
|
|
||||||
"""
|
"""
|
||||||
return gen_state_ops.scatter_nd_update(
|
return gen_state_ops.scatter_nd_update(
|
||||||
self._variable, indices, updates, use_locking=True, name=name)
|
self._variable, indices, updates, use_locking=True, name=name)
|
||||||
|
@ -112,6 +112,22 @@ tf_class {
|
|||||||
name: "scatter_add"
|
name: "scatter_add"
|
||||||
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "scatter_div"
|
||||||
|
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "scatter_max"
|
||||||
|
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "scatter_min"
|
||||||
|
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "scatter_mul"
|
||||||
|
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "scatter_nd_add"
|
name: "scatter_nd_add"
|
||||||
argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -111,6 +111,22 @@ tf_class {
|
|||||||
name: "scatter_add"
|
name: "scatter_add"
|
||||||
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "scatter_div"
|
||||||
|
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "scatter_max"
|
||||||
|
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "scatter_min"
|
||||||
|
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "scatter_mul"
|
||||||
|
argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "scatter_nd_add"
|
name: "scatter_nd_add"
|
||||||
argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
Loading…
x
Reference in New Issue
Block a user