Fix keras/backend_test for v2 control flow.
PiperOrigin-RevId: 257228448
This commit is contained in:
parent
eebe2f7247
commit
b05c0ee754
@ -5429,9 +5429,9 @@ def ctc_label_dense_to_sparse(labels, label_lengths):
|
||||
num_batches_tns = array_ops.stack([label_shape[0]])
|
||||
max_num_labels_tns = array_ops.stack([label_shape[1]])
|
||||
|
||||
def range_less_than(_, current_input):
|
||||
def range_less_than(old_input, current_input):
|
||||
return array_ops.expand_dims(
|
||||
math_ops.range(label_shape[1]), 0) < array_ops.fill(
|
||||
math_ops.range(array_ops.shape(old_input)[1]), 0) < array_ops.fill(
|
||||
max_num_labels_tns, current_input)
|
||||
|
||||
init = math_ops.cast(
|
||||
|
@ -1661,6 +1661,7 @@ class BackendCrossEntropyLossesTest(test.TestCase):
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
@test_util.with_control_flow_v2
|
||||
class TestCTC(test.TestCase):
|
||||
|
||||
def test_ctc_decode(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user