Support scatter_nd_update for string dtypes
Fixes #30642 PiperOrigin-RevId: 268956189
This commit is contained in:
parent
80880afc82
commit
c3e32b03e1
@ -382,6 +382,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
|
||||
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);
|
||||
TF_CALL_tstring(REGISTER_SCATTER_ND_CPU);
|
||||
TF_CALL_tstring(REGISTER_SCATTER_ND_UPDATE_CPU);
|
||||
TF_CALL_bool(REGISTER_SCATTER_ND_ADD_SUB_CPU);
|
||||
TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_CPU);
|
||||
TF_CALL_bool(REGISTER_SCATTER_ND_CPU);
|
||||
|
@ -167,6 +167,16 @@ class StatefulScatterNdTest(test.TestCase):
|
||||
result = self.evaluate(scatter)
|
||||
self.assertAllClose(result, expected)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testString(self):
|
||||
ref = variables.Variable(["qq", "ww", "ee", "rr", "", "", "", ""])
|
||||
indices = constant_op.constant([[4], [3], [1], [7]])
|
||||
updates = constant_op.constant(["aa", "dd", "cc", "bb"])
|
||||
update = state_ops.scatter_nd_update(ref, indices, updates)
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual(self.evaluate(update),
|
||||
[b"qq", b"cc", b"ee", b"dd", b"aa", b"", b"", b"bb"])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSimpleResource(self):
|
||||
indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
|
||||
|
Loading…
Reference in New Issue
Block a user