Do not dispatch to TF conditionals for SparseTensor arguments.
PiperOrigin-RevId: 252076962
This commit is contained in:
parent
4c7f6c8404
commit
b807d424a0
@ -24,6 +24,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.python.autograph.converters import control_flow
|
||||
from tensorflow.python.autograph.core import converter_testing
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -183,6 +184,18 @@ class ControlFlowTest(converter_testing.TestCase):
|
||||
self.assertTransformedResult(test_fn, constant_op.constant(1), (-1, 0))
|
||||
self.assertTransformedResult(test_fn, constant_op.constant(-1), (0, -2))
|
||||
|
||||
def test_if_sparse_tensor(self):
|
||||
|
||||
def test_fn(cond, a):
|
||||
if cond:
|
||||
a = -a
|
||||
return a
|
||||
|
||||
st = sparse_tensor.SparseTensor(
|
||||
indices=((0,),), values=(0,), dense_shape=(1,))
|
||||
self.assertTransformedResult(test_fn, (st, constant_op.constant(1)), -1)
|
||||
self.assertTransformedResult(test_fn, (None, constant_op.constant(1)), 1)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_if_complex_outputs(self):
|
||||
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
from tensorflow.python.autograph.operators import py_builtins
|
||||
from tensorflow.python.autograph.operators import special_values
|
||||
from tensorflow.python.autograph.utils import ag_logging
|
||||
from tensorflow.python.autograph.utils import tensors
|
||||
from tensorflow.python.data.experimental.ops import scan_ops
|
||||
from tensorflow.python.data.experimental.ops import take_while_ops
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
@ -300,7 +301,7 @@ def while_stmt(test, body, init_state, opts=None):
|
||||
|
||||
# TensorFlow: Multiple evaluations are acceptable in this case, so we're fine
|
||||
# with the re-evaluation of `test` that `_tf_while_stmt` will make.
|
||||
if tensor_util.is_tensor(init_test):
|
||||
if tensors.is_dense_tensor(init_test):
|
||||
return _tf_while_stmt(test, body, init_state, opts)
|
||||
|
||||
# Normal Python: We already consumed one evaluation of `test`; consistently,
|
||||
@ -435,7 +436,8 @@ def if_stmt(cond, body, orelse, get_state, set_state):
|
||||
Returns:
|
||||
Tuple containing the statement outputs.
|
||||
"""
|
||||
if tensor_util.is_tensor(cond):
|
||||
# Note: tf.cond doesn't support SparseTensor.
|
||||
if tensors.is_dense_tensor(cond):
|
||||
return tf_if_stmt(cond, body, orelse, get_state, set_state)
|
||||
else:
|
||||
return _py_if_stmt(cond, body, orelse)
|
||||
|
@ -24,10 +24,17 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
|
||||
|
||||
def is_dense_tensor(t):
|
||||
# TODO(mdan): Resolve this inconsistency.
|
||||
return (tensor_util.is_tensor(t) and
|
||||
not isinstance(t, sparse_tensor.SparseTensor))
|
||||
|
||||
|
||||
def is_tensor_array(t):
|
||||
return isinstance(t, tensor_array_ops.TensorArray)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user