Support scatter_nd_update for string dtypes

Fixes #30642

PiperOrigin-RevId: 268956189
This commit is contained in:
Alexandre Passos 2019-09-13 12:20:10 -07:00 committed by TensorFlower Gardener
parent 80880afc82
commit c3e32b03e1
2 changed files with 11 additions and 0 deletions

View File

@ -382,6 +382,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);
TF_CALL_tstring(REGISTER_SCATTER_ND_CPU);
TF_CALL_tstring(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_ADD_SUB_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_bool(REGISTER_SCATTER_ND_CPU);

View File

@ -167,6 +167,16 @@ class StatefulScatterNdTest(test.TestCase):
result = self.evaluate(scatter)
self.assertAllClose(result, expected)
@test_util.run_in_graph_and_eager_modes
def testString(self):
ref = variables.Variable(["qq", "ww", "ee", "rr", "", "", "", ""])
indices = constant_op.constant([[4], [3], [1], [7]])
updates = constant_op.constant(["aa", "dd", "cc", "bb"])
update = state_ops.scatter_nd_update(ref, indices, updates)
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual(self.evaluate(update),
[b"qq", b"cc", b"ee", b"dd", b"aa", b"", b"", b"bb"])
@test_util.run_deprecated_v1
def testSimpleResource(self):
indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)