Do not dispatch to TF conditionals for SparseTensor arguments.

PiperOrigin-RevId: 252076962
This commit is contained in:
Dan Moldovan 2019-06-07 10:42:18 -07:00 committed by TensorFlower Gardener
parent 4c7f6c8404
commit b807d424a0
3 changed files with 24 additions and 2 deletions

View File

@ -24,6 +24,7 @@ import numpy as np
from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
@ -183,6 +184,18 @@ class ControlFlowTest(converter_testing.TestCase):
self.assertTransformedResult(test_fn, constant_op.constant(1), (-1, 0))
self.assertTransformedResult(test_fn, constant_op.constant(-1), (0, -2))
def test_if_sparse_tensor(self):
def test_fn(cond, a):
if cond:
a = -a
return a
st = sparse_tensor.SparseTensor(
indices=((0,),), values=(0,), dense_shape=(1,))
self.assertTransformedResult(test_fn, (st, constant_op.constant(1)), -1)
self.assertTransformedResult(test_fn, (None, constant_op.constant(1)), 1)
@test_util.run_deprecated_v1
def test_if_complex_outputs(self):

View File

@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.autograph.operators import special_values
from tensorflow.python.autograph.utils import ag_logging
from tensorflow.python.autograph.utils import tensors
from tensorflow.python.data.experimental.ops import scan_ops
from tensorflow.python.data.experimental.ops import take_while_ops
from tensorflow.python.data.ops import dataset_ops
@ -300,7 +301,7 @@ def while_stmt(test, body, init_state, opts=None):
# TensorFlow: Multiple evaluations are acceptable in this case, so we're fine
# with the re-evaluation of `test` that `_tf_while_stmt` will make.
if tensor_util.is_tensor(init_test):
if tensors.is_dense_tensor(init_test):
return _tf_while_stmt(test, body, init_state, opts)
# Normal Python: We already consumed one evaluation of `test`; consistently,
@ -435,7 +436,8 @@ def if_stmt(cond, body, orelse, get_state, set_state):
Returns:
Tuple containing the statement outputs.
"""
if tensor_util.is_tensor(cond):
# Note: tf.cond doesn't support SparseTensor.
if tensors.is_dense_tensor(cond):
return tf_if_stmt(cond, body, orelse, get_state, set_state)
else:
return _py_if_stmt(cond, body, orelse)

View File

@ -24,10 +24,17 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import tensor_array_ops
def is_dense_tensor(t):
# TODO(mdan): Resolve this inconsistency.
return (tensor_util.is_tensor(t) and
not isinstance(t, sparse_tensor.SparseTensor))
def is_tensor_array(t):
return isinstance(t, tensor_array_ops.TensorArray)