Add support for tensor_scatter_update (experimental).
PiperOrigin-RevId: 247413911
This commit is contained in:
parent
0aed5bddea
commit
5b8f89b921
@ -22,6 +22,7 @@ import collections
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gen_string_ops
|
||||
from tensorflow.python.ops import list_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
@ -119,9 +120,7 @@ def set_item(target, i, x):
|
||||
if target.dtype == dtypes.variant:
|
||||
return _tf_tensor_list_set_item(target, i, x)
|
||||
else:
|
||||
raise ValueError(
|
||||
'tensor lists are expected to be Tensors with dtype=tf.variant,'
|
||||
' instead found %s' % target)
|
||||
return _tf_tensor_set_item(target, i, x)
|
||||
else:
|
||||
return _py_set_item(target, i, x)
|
||||
|
||||
@ -136,6 +135,11 @@ def _tf_tensor_list_set_item(target, i, x):
|
||||
return list_ops.tensor_list_set_item(target, i, x)
|
||||
|
||||
|
||||
def _tf_tensor_set_item(target, i, x):
|
||||
"""Overload of set_item that stages a Tensor scatter update."""
|
||||
return gen_array_ops.tensor_scatter_update(target, ((i,),), (x,))
|
||||
|
||||
|
||||
def _py_set_item(target, i, x):
|
||||
"""Overload of set_item that executes a Python list modification."""
|
||||
target[i] = x
|
||||
|
Loading…
Reference in New Issue
Block a user