Add support for tensor_scatter_update (experimental).

PiperOrigin-RevId: 247413911
This commit is contained in:
Dan Moldovan 2019-05-09 06:37:00 -07:00 committed by TensorFlower Gardener
parent 0aed5bddea
commit 5b8f89b921

View File

@ -22,6 +22,7 @@ import collections
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import tensor_array_ops
@ -119,9 +120,7 @@ def set_item(target, i, x):
if target.dtype == dtypes.variant:
return _tf_tensor_list_set_item(target, i, x)
else:
raise ValueError(
'tensor lists are expected to be Tensors with dtype=tf.variant,'
' instead found %s' % target)
return _tf_tensor_set_item(target, i, x)
else:
return _py_set_item(target, i, x)
@ -136,6 +135,11 @@ def _tf_tensor_list_set_item(target, i, x):
return list_ops.tensor_list_set_item(target, i, x)
def _tf_tensor_set_item(target, i, x):
"""Overload of set_item that stages a Tensor scatter update."""
return gen_array_ops.tensor_scatter_update(target, ((i,),), (x,))
def _py_set_item(target, i, x):
"""Overload of set_item that executes a Python list modification."""
target[i] = x