diff --git a/tensorflow/python/autograph/operators/slices.py b/tensorflow/python/autograph/operators/slices.py index 2b7f5ad9226..af4074cc55a 100644 --- a/tensorflow/python/autograph/operators/slices.py +++ b/tensorflow/python/autograph/operators/slices.py @@ -22,6 +22,7 @@ import collections from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_string_ops from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops @@ -119,9 +120,7 @@ def set_item(target, i, x): if target.dtype == dtypes.variant: return _tf_tensor_list_set_item(target, i, x) else: - raise ValueError( - 'tensor lists are expected to be Tensors with dtype=tf.variant,' - ' instead found %s' % target) + return _tf_tensor_set_item(target, i, x) else: return _py_set_item(target, i, x) @@ -136,6 +135,11 @@ def _tf_tensor_list_set_item(target, i, x): return list_ops.tensor_list_set_item(target, i, x) +def _tf_tensor_set_item(target, i, x): + """Overload of set_item that stages a Tensor scatter update.""" + return gen_array_ops.tensor_scatter_update(target, ((i,),), (x,)) + + def _py_set_item(target, i, x): """Overload of set_item that executes a Python list modification.""" target[i] = x