Surface control_flow_ops.case to public. Update docs. Add unit tests.

Change: 115496194
This commit is contained in:
Eugene Brevdo 2016-02-24 14:50:49 -08:00 committed by TensorFlower Gardener
parent 497606904b
commit 2861cc1d23
12 changed files with 421 additions and 43 deletions

View File

@ -1285,6 +1285,120 @@ boolean_mask(tensor, mask) ==> [[1, 2], [5, 6]]
```
- - -
### `tf.one_hot(indices, depth, on_value, off_value, axis=None, name=None)` {#one_hot}
Returns a one-hot tensor.
The locations represented by indices in `indices` take value `on_value`,
while all other locations take value `off_value`.
If the input `indices` is rank `N`, the output will have rank `N+1`,
The new axis is created at dimension `axis` (default: the new axis is
appended at the end).
If `indices` is a scalar the output shape will be a vector of length `depth`.
If `indices` is a vector of length `features`, the output shape will be:
```
features x depth if axis == -1
depth x features if axis == 0
```
If `indices` is a matrix (batch) with shape `[batch, features]`,
the output shape will be:
```
batch x features x depth if axis == -1
batch x depth x features if axis == 1
depth x batch x features if axis == 0
```
Examples
=========
Suppose that
```
indices = [0, 2, -1, 1]
depth = 3
on_value = 5.0
off_value = 0.0
axis = -1
```
Then output is `[4 x 3]`:
```output =
[5.0 0.0 0.0] // one_hot(0)
[0.0 0.0 5.0] // one_hot(2)
[0.0 0.0 0.0] // one_hot(-1)
[0.0 5.0 0.0] // one_hot(1)
```
Suppose that
```
indices = [0, 2, -1, 1]
depth = 3
on_value = 0.0
off_value = 3.0
axis = 0
```
Then output is `[3 x 4]`:
```output =
[0.0 3.0 3.0 3.0]
[3.0 3.0 3.0 0.0]
[3.0 3.0 3.0 3.0]
[3.0 0.0 3.0 3.0]
// ^ one_hot(0)
// ^ one_hot(2)
// ^ one_hot(-1)
// ^ one_hot(1)
```
Suppose that
```
indices = [[0, 2], [1, -1]]
depth = 3
on_value = 1.0
off_value = 0.0
axis = -1
```
Then output is `[2 x 2 x 3]`:
```output =
[
[1.0, 0.0, 0.0] // one_hot(0)
[0.0, 0.0, 1.0] // one_hot(2)
][
[0.0, 1.0, 0.0] // one_hot(1)
[0.0, 0.0, 0.0] // one_hot(-1)
]```
##### Args:
* <b>`indices`</b>: A `Tensor` of type `int64`. A tensor of indices.
* <b>`depth`</b>: A `Tensor` of type `int32`.
A scalar defining the depth of the one hot dimension.
* <b>`on_value`</b>: A `Tensor`.
A scalar defining the value to fill in output when `indices[j] = i`.
* <b>`off_value`</b>: A `Tensor`. Must have the same type as `on_value`.
A scalar defining the value to fill in output when `indices[j] != i`.
* <b>`axis`</b>: An optional `int`. Defaults to `-1`.
The axis to fill (default: -1, a new inner-most axis).
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
A `Tensor`. Has the same type as `on_value`. The one-hot tensor.
## Other Functions and Classes
- - -

View File

@ -301,7 +301,7 @@ activation.
- - -
### `tf.contrib.layers.summarize_tensor(tensor)` {#summarize_tensor}
### `tf.contrib.layers.summarize_tensor(tensor, tag=None)` {#summarize_tensor}
Summarize a tensor using a suitable summary type.
@ -313,6 +313,7 @@ other tensors, `histogram_summary` is used.
* <b>`tensor`</b>: The tensor to summarize
* <b>`tag`</b>: The tag to use, if None then use tensor's op's name.
##### Returns:
@ -377,3 +378,31 @@ be `dtypes.float32` or `dtypes.float64`. If neither `tensors` nor
float.
- - -
### `tf.contrib.layers.assert_scalar_int(tensor)` {#assert_scalar_int}
Assert `tensor` is 0-D, of type `tf.int32` or `tf.int64`.
##### Args:
* <b>`tensor`</b>: Tensor to test.
##### Returns:
`tensor`, for chaining.
##### Raises:
* <b>`ValueError`</b>: if `tensor` is not 0-D, of type `tf.int32` or `tf.int64`.
- - -
### `tf.contrib.layers.is_numeric_tensor(tensor)` {#is_numeric_tensor}

View File

@ -182,6 +182,84 @@ the same non-zero number and type of outputs.
```
- - -
### `tf.case(pred_fn_pairs, default, exclusive=False, name='case')` {#case}
Create a case operation.
The `pred_fn_pairs` parameter is a dict or list of pairs of size N.
Each pair contains a boolean scalar tensor and a python callable that
creates the tensors to be returned if the boolean evaluates to True. `default`
is a callable generating a list of tensors. All the callables in
`pred_fn_pairs` as well as `default` should return the same number and types
of tensors.
If `exclusive==True`, all predicates are evaluated, and a logging operation
with an error is returned if more than one of the predicates evaluates to
True. If `exclusive==False`, execution stops are the first predicate which
evaluates to True, and the tensors generated by the corresponding function
are returned immediately. If none of the predicates evaluate to True, this
operation returns the tensors generated by `default`.
Example 1:
Pseudocode:
```
if (x < y) return 17;
else return 23;
```
Expressions:
```
f1 = lambda: tf.constant(17)
f2 = lambda: tf.constant(23)
r = case([(tf.less(x, y), f1)], default=f2)
```
Example 2:
Pseudocode:
```
if (x < y && x > z) raise OpError("Only one predicate may evaluate true");
if (x < y) return 17;
else if (x > z) return 23;
else return -1;
```
Expressions:
```
x = tf.constant(0)
y = tf.constant(1)
z = tf.constant(2)
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = case({tf.less(x, y): f1, tf.greater(x, z): f2},
default=f3, exclusive=True)
```
##### Args:
* <b>`pred_fn_pairs`</b>: Dict or list of pairs of a boolean scalar tensor and a
callable which returns a list of tensors.
* <b>`default`</b>: A callable that returns a list of tensors.
* <b>`exclusive`</b>: True iff more than one predicate is allowed to evaluate to True.
* <b>`name`</b>: A name for this operation (optional).
##### Returns:
The tensors returned by the first pair whose predicate evaluated to True, or
those returned by `default` if none does.
##### Raises:
* <b>`TypeError`</b>: If `pred_fn_pairs` is not a list/dictionary.
* <b>`TypeError`</b>: If `pred_fn_pairs` is a list but does not contain 2-tuples.
* <b>`TypeError`</b>: If `fns[i]` is not callable for any i, or `default` is not
callable.
## Logical Operators

View File

@ -523,7 +523,7 @@ This method may be called concurrently from multiple threads.
- - -
#### `tf.Graph.unique_name(name)` {#Graph.unique_name}
#### `tf.Graph.unique_name(name, mark_as_used=True)` {#Graph.unique_name}
Return a unique operation name for `name`.
@ -537,10 +537,17 @@ Operation names are displayed in error messages reported by the
TensorFlow runtime, and in various visualization tools such as
TensorBoard.
If `mark_as_used` is set to `True`, which is the default, a new
unique name is created and marked as in use. If it's set to `False`,
the unique name is returned without actually being marked as used.
This is useful when the caller simply wants to know what the name
to be created will be.
##### Args:
* <b>`name`</b>: The name for an operation.
* <b>`mark_as_used`</b>: Whether to mark this name as being used.
##### Returns:

View File

@ -398,6 +398,39 @@ dimension.
- - -
### `tf.image.central_crop(image, central_fraction)` {#central_crop}
Crop the central region of the image.
Remove the outer parts of an image but retain the central region of the image
along each dimension. If we specify central_fraction = 0.5, this function
returns the region marked with "X" in the below diagram.
--------
| |
| XXXX |
| XXXX |
| | where "X" is the central 50% of the image.
--------
##### Args:
* <b>`image`</b>: 3-D float Tensor of shape [height, width, depth]
* <b>`central_fraction`</b>: float (0, 1], fraction of size to crop
##### Raises:
* <b>`ValueError`</b>: if central_crop_fraction is not within (0, 1].
##### Returns:
3-D float Tensor
- - -
### `tf.image.pad_to_bounding_box(image, offset_height, offset_width, target_height, target_width)` {#pad_to_bounding_box}

View File

@ -92,6 +92,7 @@
* [`dynamic_stitch`](../../api_docs/python/array_ops.md#dynamic_stitch)
* [`expand_dims`](../../api_docs/python/array_ops.md#expand_dims)
* [`gather`](../../api_docs/python/array_ops.md#gather)
* [`one_hot`](../../api_docs/python/array_ops.md#one_hot)
* [`pack`](../../api_docs/python/array_ops.md#pack)
* [`pad`](../../api_docs/python/array_ops.md#pad)
* [`rank`](../../api_docs/python/array_ops.md#rank)
@ -193,6 +194,7 @@
* [`sparse_segment_sum`](../../api_docs/python/math_ops.md#sparse_segment_sum)
* [`sqrt`](../../api_docs/python/math_ops.md#sqrt)
* [`square`](../../api_docs/python/math_ops.md#square)
* [`squared_difference`](../../api_docs/python/math_ops.md#squared_difference)
* [`sub`](../../api_docs/python/math_ops.md#sub)
* [`transpose`](../../api_docs/python/math_ops.md#transpose)
* [`truediv`](../../api_docs/python/math_ops.md#truediv)
@ -203,6 +205,7 @@
* **[Control Flow](../../api_docs/python/control_flow_ops.md)**:
* [`add_check_numerics_ops`](../../api_docs/python/control_flow_ops.md#add_check_numerics_ops)
* [`Assert`](../../api_docs/python/control_flow_ops.md#Assert)
* [`case`](../../api_docs/python/control_flow_ops.md#case)
* [`check_numerics`](../../api_docs/python/control_flow_ops.md#check_numerics)
* [`cond`](../../api_docs/python/control_flow_ops.md#cond)
* [`count_up_to`](../../api_docs/python/control_flow_ops.md#count_up_to)
@ -233,6 +236,7 @@
* [`adjust_contrast`](../../api_docs/python/image.md#adjust_contrast)
* [`adjust_hue`](../../api_docs/python/image.md#adjust_hue)
* [`adjust_saturation`](../../api_docs/python/image.md#adjust_saturation)
* [`central_crop`](../../api_docs/python/image.md#central_crop)
* [`convert_image_dtype`](../../api_docs/python/image.md#convert_image_dtype)
* [`crop_to_bounding_box`](../../api_docs/python/image.md#crop_to_bounding_box)
* [`decode_jpeg`](../../api_docs/python/image.md#decode_jpeg)
@ -316,6 +320,7 @@
* **[Neural Network](../../api_docs/python/nn.md)**:
* [`avg_pool`](../../api_docs/python/nn.md#avg_pool)
* [`batch_normalization`](../../api_docs/python/nn.md#batch_normalization)
* [`bias_add`](../../api_docs/python/nn.md#bias_add)
* [`compute_accidental_hits`](../../api_docs/python/nn.md#compute_accidental_hits)
* [`conv2d`](../../api_docs/python/nn.md#conv2d)
@ -422,8 +427,10 @@
* **[Layers (contrib)](../../api_docs/python/contrib.layers.md)**:
* [`assert_same_float_dtype`](../../api_docs/python/contrib.layers.md#assert_same_float_dtype)
* [`assert_scalar_int`](../../api_docs/python/contrib.layers.md#assert_scalar_int)
* [`convolution2d`](../../api_docs/python/contrib.layers.md#convolution2d)
* [`fully_connected`](../../api_docs/python/contrib.layers.md#fully_connected)
* [`is_numeric_tensor`](../../api_docs/python/contrib.layers.md#is_numeric_tensor)
* [`l1_regularizer`](../../api_docs/python/contrib.layers.md#l1_regularizer)
* [`l2_regularizer`](../../api_docs/python/contrib.layers.md#l2_regularizer)
* [`summarize_activation`](../../api_docs/python/contrib.layers.md#summarize_activation)

View File

@ -628,6 +628,24 @@ Computes complementary error function of `x` element-wise.
the return type is `quint8`.
- - -
### `tf.squared_difference(x, y, name=None)` {#squared_difference}
Returns (x - y)(x - y) element-wise.
##### Args:
* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `complex64`, `int64`.
* <b>`y`</b>: A `Tensor`. Must have the same type as `x`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
A `Tensor`. Has the same type as `x`.
## Matrix Math Functions

View File

@ -104,7 +104,7 @@ Creating a variable.
- - -
#### `tf.Variable.__init__(initial_value=None, trainable=True, collections=None, validate_shape=True, name=None, variable_def=None)` {#Variable.__init__}
#### `tf.Variable.__init__(initial_value=None, trainable=True, collections=None, validate_shape=True, caching_device=None, name=None, variable_def=None)` {#Variable.__init__}
Creates a new variable with value `initial_value`.
@ -131,6 +131,11 @@ variable to its initial value.
* <b>`validate_shape`</b>: If `False`, allows the variable to be initialized with a
value of unknown shape. If `True`, the default, the shape of
`initial_value` must be known.
* <b>`caching_device`</b>: Optional device string describing where the Variable
should be cached for reading. Defaults to the Variable's device.
If not `None`, caches on another device. Typical use is to cache
on the device where the Ops using the Variable reside, to deduplicate
copying through `Switch` and other conditional statements.
* <b>`name`</b>: Optional name for the variable. Defaults to `'Variable'` and gets
uniquified automatically.
* <b>`variable_def`</b>: `VariableDef` protocol buffer. If not `None`, recreates
@ -389,7 +394,7 @@ The `Operation` of this variable.
#### `tf.Variable.from_proto(variable_def)` {#Variable.from_proto}
Returns a `Variable` object created from `variable_def`.
- - -
@ -742,7 +747,7 @@ path can be passed directly to a call to `restore()`.
kept in the same directory as the checkpoint files, is automatically
managed by the saver to keep track of recent checkpoints. Defaults to
'checkpoint'.
* <b>`meta_graph_suffix`</b>: Suffix for MetaGraphDef file. Defaults to 'meta'.
* <b>`meta_graph_suffix`</b>: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
##### Returns:
@ -848,7 +853,7 @@ Writes `MetaGraphDef` to save_path/filename.
#### `tf.train.Saver.from_proto(saver_def)` {#Saver.from_proto}
Returns a `Saver` object created from `saver_def`.
- - -
@ -873,7 +878,11 @@ Sets the list of old checkpoint filenames and timestamps.
#### `tf.train.Saver.to_proto()` {#Saver.to_proto}
Returns a `SaverDef` protocol buffer.
Converts this `Saver` to a `SaverDef` protocol buffer.
##### Returns:
A `SaverDef` protocol buffer.
@ -1022,17 +1031,26 @@ Attributes:
initializer: default initializer passed to get_variable.
regularizer: default regularizer passed to get_variable.
reuse: Boolean or None, setting the reuse in get_variable.
caching_device: string, callable, or None: the caching device passed to
get_variable.
name_scope: The name passed to tf.name_scope.
- - -
#### `tf.VariableScope.__init__(reuse, name='', initializer=None, regularizer=None, name_scope='')` {#VariableScope.__init__}
#### `tf.VariableScope.__init__(reuse, name='', initializer=None, regularizer=None, caching_device=None, name_scope='')` {#VariableScope.__init__}
Creates a new VariableScope with the given properties.
- - -
#### `tf.VariableScope.get_variable(var_store, name, shape=None, dtype=tf.float32, initializer=None, regularizer=None, trainable=True, collections=None)` {#VariableScope.get_variable}
#### `tf.VariableScope.caching_device` {#VariableScope.caching_device}
- - -
#### `tf.VariableScope.get_variable(var_store, name, shape=None, dtype=tf.float32, initializer=None, regularizer=None, trainable=True, collections=None, caching_device=None)` {#VariableScope.get_variable}
Gets an existing variable with this name or create a new one.
@ -1072,6 +1090,13 @@ Gets an existing variable with this name or create a new one.
Reuse variables in this scope.
- - -
#### `tf.VariableScope.set_caching_device(caching_device)` {#VariableScope.set_caching_device}
Set caching_device for this scope.
- - -
#### `tf.VariableScope.set_initializer(initializer)` {#VariableScope.set_initializer}
@ -1089,7 +1114,7 @@ Set regularizer for this scope.
- - -
### `tf.variable_scope(name_or_scope, reuse=None, initializer=None, regularizer=None)` {#variable_scope}
### `tf.variable_scope(name_or_scope, reuse=None, initializer=None, regularizer=None, caching_device=None)` {#variable_scope}
Returns a context for variable scope.
@ -1157,6 +1182,7 @@ then all its sub-scopes become reusing as well.
well as all sub-scopes; if `None`, we just inherit the parent scope reuse.
* <b>`initializer`</b>: default initializer for variables within this scope.
* <b>`regularizer`</b>: default regularizer for variables within this scope.
* <b>`caching_device`</b>: default caching device for variables within this scope.
##### Returns:
@ -1172,7 +1198,7 @@ then all its sub-scopes become reusing as well.
- - -
### `tf.variable_op_scope(values, name, default_name, initializer=None, regularizer=None)` {#variable_op_scope}
### `tf.variable_op_scope(values, name, default_name, initializer=None, regularizer=None, caching_device=None)` {#variable_op_scope}
Returns a context manager for defining an op that creates variables.
@ -1208,8 +1234,10 @@ def my_op_with_vars(a, b, name=None):
uniquified in the variable scope.
* <b>`default_name`</b>: The default name to use if the `name` argument is `None`, this
name will be uniquified.
* <b>`initializer`</b>: A default initializer to pass to variable scope.
* <b>`regularizer`</b>: default regularizer for variables within this scope.
* <b>`initializer`</b>: The default initializer to pass to variable scope.
* <b>`regularizer`</b>: The default regularizer for variables within this scope.
* <b>`caching_device`</b>: The default caching device for variables within this scope.
##### Returns:

View File

@ -410,36 +410,15 @@ current good choice is 1.0 or 0.1.
Optimizer that implements the FTRL algorithm.
See this [paper](
https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
- - -
#### `tf.train.FtrlOptimizer.__init__(learning_rate, learning_rate_power=-0.5, initial_accumulator_value=0.1, l1_regularization_strength=0.0, l2_regularization_strength=0.0, use_locking=False, name='Ftrl')` {#FtrlOptimizer.__init__}
Construct a new FTRL optimizer.
The Ftrl-proximal algorithm, abbreviated for Follow-the-regularized-leader,
is described in the paper [Ad Click Prediction: a View from the Trenches](
https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
It can give a good performance vs. sparsity tradeoff.
Ftrl-proximal uses its own global base learning rate and can behave like
Adagrad with `learning_rate_power=-0.5`, or like gradient descent with
`learning_rate_power=0.0`.
The effective learning rate is adjusted per parameter, relative to this
base learning rate as:
```
effective_learning_rate_i = (learning_rate /
pow(k + summed_squared_gradients_for_i, learning_rate_power));
```
where k is the small constant `initial_accumulator_value`.
Note that the real regularization coefficient of `|w|^2` for objective
function is `1 / lambda_2` if specifying `l2 = lambda_2` as argument when
using this function.
##### Args:
@ -1442,7 +1421,7 @@ depending on whether or not a `Coordinator` was passed to
#### `tf.train.QueueRunner.from_proto(queue_runner_def)` {#QueueRunner.from_proto}
Returns a `QueueRunner` object created from `queue_runner_def`.
- - -
@ -1936,7 +1915,7 @@ global_step: 10
##### Args:
* <b>`sess`</b>: A brain `Session` object.
* <b>`sess`</b>: A TensorFlow `Session` object.
* <b>`global_step_tensor`</b>: `Tensor` or the `name` of the operation that contains
the global step.
@ -2237,9 +2216,12 @@ Generates a checkpoint state proto.
Recreates a Graph saved in a `MetaGraphDef` proto.
This function reads from a file containing a `MetaGraphDef` proto,
adds all the nodes from the graph_def proto to the current graph,
recreates all the collections, and returns a saver from saver_def.
This function takes a `MetaGraphDef` protocol buffer as input. If
the argument is a file containing a `MetaGraphDef` protocol buffer ,
it constructs a protocol buffer from the file content. The function
then adds all the nodes from the `graph_def` field to the
current graph, recreates all the collections, and returns a saver
constructed from the `saver_def` field.
In combination with `export_meta_graph()`, this function can be used to
@ -2250,6 +2232,38 @@ In combination with `export_meta_graph()`, this function can be used to
* Run inference from a saved graph and checkpoints.
```Python
...
# Create a saver.
saver = tf.train.Saver(...variables...)
# Remember the training_op we want to run by adding it to a collection.
tf.add_to_collection('train_op', train_op)
sess = tf.Session()
for step in xrange(1000000):
sess.run(train_op)
if step % 1000 == 0:
# Saves checkpoint, which by default also exports a meta_graph
# named 'my-model-global_step.meta'.
saver.save(sess, 'my-model', global_step=step)
```
Later we can continue training from this saved `meta_graph` without building
the model from scratch.
```Python
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
new_saver.restore(sess, 'my-save-dir/my-model-10000')
# tf.get_collection() retrurns a list. In this example we only want the
# first one.
train_op = tf.get_collection('train_op')[0]
for step in xrange(1000000):
sess.run(train_op)
```
NOTE: Restarting training from saved `meta_graph` only works if the
device assignments have not changed.
##### Args:

View File

@ -1232,6 +1232,50 @@ class ControlFlowTest(tf.test.TestCase):
self.assertAllClose(4.0, i.eval(feed_dict={d: 1}))
self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
def testCase(self):
with self.test_session():
x = tf.constant(1)
y = tf.constant(2)
z = tf.constant(3)
f1 = lambda: tf.constant(17)
f2 = lambda: tf.constant(23)
f3 = lambda: tf.constant(-1)
r1 = tf.case({x < y: f1, x > z: f2}, default=f3, exclusive=True)
self.assertAllEqual(r1.eval(), 17)
r2 = tf.case([(y > z, f1), (y > x, f2)], default=f3)
self.assertAllEqual(r2.eval(), 23)
# Duplicate events can happen, first one is selected
r3 = tf.case([(x < y, f1), (x < y, f2)], default=f3)
self.assertAllEqual(r3.eval(), 17)
# Duplicate events cause an error if exclusive = True
r4 = tf.case([(x < y, f1), (x < y, f2)], default=f3, exclusive=True)
with self.assertRaisesOpError(
"More than one condition evaluated as True but exclusive=True."):
r4.eval()
# Check that the default is called if none of the others are
r5 = tf.case({x > y: f1}, default=f3)
self.assertAllEqual(r5.eval(), -1)
ran_once = [False, False, False]
def break_run_twice(ix):
def _break():
assert not ran_once[ix]
ran_once[ix] = True
return tf.constant(ix)
return _break
# Should not fail - each conditional gets called exactly once
r6 = tf.case([(x < y, break_run_twice(0)), (x > y, break_run_twice(1))],
default=break_run_twice(2))
self.assertAllEqual(r6.eval(), 0)
def testOneOpCond(self):
with self.test_session():
v = tf.Variable(0)

View File

@ -24,6 +24,7 @@ the execution of operations and add conditional dependencies to your graph.
@@no_op
@@count_up_to
@@cond
@@case
## Logical Operators
@ -82,6 +83,7 @@ from tensorflow.python.ops import constant_op
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
# pylint: disable=wildcard-import,undefined-variable
@ -1974,6 +1976,9 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"):
Expressions:
```
x = tf.constant(0)
y = tf.constant(1)
z = tf.constant(2)
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
@ -2050,7 +2055,7 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"):
# and prev_case_seq will loop from case_sequence[0] to case_sequence[-1]
if exclusive:
# TODO(ebrevdo): Add Where() for DT_BOOL, replace with Size(Where(preds))
preds_c = array_ops.concat(0, preds, name="preds_c")
preds_c = array_ops.pack(preds, name="preds_c")
num_true_conditions = math_ops.reduce_sum(
math_ops.cast(preds_c, dtypes.int32), name="num_true_conds")
at_most_one_true_condition = math_ops.less(

View File

@ -36,6 +36,7 @@ from tensorflow.python.ops.control_flow_ops import group
from tensorflow.python.ops.control_flow_ops import no_op
from tensorflow.python.ops.control_flow_ops import tuple
from tensorflow.python.ops.control_flow_ops import cond
from tensorflow.python.ops.control_flow_ops import case
from tensorflow.python.ops.data_flow_ops import *
from tensorflow.python.ops.gradients import *
from tensorflow.python.ops.init_ops import *