Minor optimization for RowPartition.from_uniform_row_length

PiperOrigin-RevId: 300118346
Change-Id: I9b653a73ee65668fea34931c888052d44520629c
This commit is contained in:
Edward Loper 2020-03-10 10:10:31 -07:00 committed by TensorFlower Gardener
parent 268853ee81
commit 49eac7f8e2

View File

@ -519,10 +519,9 @@ class RowPartition(object):
if nrows is None:
if const_row_length is None:
# Avoid division by zero if uniform_row_length==0 (and nvals==0).
rowlen_or_1 = control_flow_ops.cond(
math_ops.equal(uniform_row_length, 0),
lambda: constant_op.constant(1, uniform_row_length.dtype),
lambda: uniform_row_length)
rowlen_or_1 = math_ops.maximum(
uniform_row_length,
constant_op.constant(1, uniform_row_length.dtype))
nrows = nvals // rowlen_or_1
elif const_row_length == 0:
nrows = 0