Changes:
- Ptrdiff -> DenseIndex change by @jiayq
- Fix to scoping the logging in logging.py by @dga
- Improvement to Conv2DBackpropFilter on CPU by Andy
- Remove lookup table wrappers for the time being (wasn't in our public API yet) by Yukata
- Add a check similar to numpy to make sure the user isn't in the tensorflow src directory by @vrv
- More changes for python 3 compat by @girving
- Make dropout preserve shape info from input (@mrry)
- Significant speed improvements by @zheng-xq to the BFC allocator, bringing it on par (CPU overhead-wise) with the region allocator. Make the BFC allocator the default now that it's working well for a variety of models.
- Fix a bunch of typos reported by users (@vrv)
- Enable concat for bfloat16 on GPU by Ashish.

Base CL: 107733123
"""## Arithmetic Operators

TensorFlow provides several operations that you can use to add basic arithmetic
operators to your graph.

@@add
@@sub
@@mul
@@div
@@mod

## Basic Math Functions

TensorFlow provides several operations that you can use to add basic
mathematical functions to your graph.

@@add_n
@@abs
@@neg
@@sign
@@inv
@@square
@@round
@@sqrt
@@rsqrt
@@pow
@@exp
@@log
@@ceil
@@floor
@@maximum
@@minimum
@@cos
@@sin

## Matrix Math Functions

TensorFlow provides several operations that you can use to add basic
mathematical functions for matrices to your graph.

@@diag
@@transpose

@@matmul
@@batch_matmul

@@matrix_determinant
@@batch_matrix_determinant

@@matrix_inverse
@@batch_matrix_inverse

@@cholesky
@@batch_cholesky

## Complex Number Functions

TensorFlow provides several operations that you can use to add complex number
functions to your graph.

@@complex
@@complex_abs
@@conj
@@imag
@@real

## Reduction

TensorFlow provides several operations that you can use to perform
common math computations that reduce various dimensions of a tensor.

@@reduce_sum
@@reduce_prod
@@reduce_min
@@reduce_max
@@reduce_mean
@@reduce_all
@@reduce_any

@@accumulate_n

## Segmentation

TensorFlow provides several operations that you can use to perform common
math computations on tensor segments.
Here a segmentation is a partitioning of a tensor along
the first dimension, i.e. it defines a mapping from the first dimension onto
`segment_ids`. The `segment_ids` tensor should be the size of
the first dimension, `d0`, with consecutive IDs in the range `0` to `k`,
where `k < d0`.
In particular, a segmentation of a matrix tensor is a mapping of rows to
segments.

For example:

```python
c = tf.constant([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]])
tf.segment_sum(c, tf.constant([0, 0, 1]))
==> [[0 0 0 0]
     [5 6 7 8]]
```

@@segment_sum
@@segment_prod
@@segment_min
@@segment_max
@@segment_mean

@@unsorted_segment_sum

@@sparse_segment_sum
@@sparse_segment_mean

## Sequence Comparison and Indexing

TensorFlow provides several operations that you can use to add sequence
comparison and index extraction to your graph. You can use these operations to
determine sequence differences and determine the indexes of specific values in
a tensor.

@@argmin
@@argmax

@@listdiff
@@where
@@unique

@@edit_distance

@@invert_permutation
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.python.platform

import numpy as np
import six.moves

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import gen_state_ops
# pylint: disable=wildcard-import,undefined-variable
from tensorflow.python.ops.gen_math_ops import *


# Aliases for some automatically-generated names.
argmax = gen_math_ops.arg_max
argmin = gen_math_ops.arg_min
linspace = gen_math_ops.lin_space


# pylint: disable=anomalous-backslash-in-string,protected-access
def abs(x, name=None):
  """Computes the absolute value of a tensor.

  Given a tensor of real numbers `x`, this operation returns a tensor
  containing the absolute value of each element in `x`. For example, if x is
  an input element and y is an output element, this operation computes
  \\\\(y = |x|\\\\).

  See [`tf.complex_abs()`](#tf_complex_abs) to compute the absolute value of a
  complex number.

  Args:
    x: A `Tensor` of type `float`, `double`, `int32`, or `int64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` the same size and type as `x` with absolute values.
  """
  with ops.op_scope([x], name, "Abs") as name:
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype == types.complex64:
      return gen_math_ops.complex_abs(x, name=name)
    return gen_math_ops._abs(x, name=name)
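

# A minimal usage sketch for abs (example values assumed): complex64 inputs
# dispatch to complex_abs, everything else to the real-valued kernel.
#
#   # 'x' is [-1.5, 0.0, 2.0]
#   tf.abs(x) ==> [1.5, 0.0, 2.0]
#
#   # 'z' is [3 + 4j]  (complex64)
#   tf.abs(z) ==> [5.0]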


def pow(x, y, name=None):
  """Computes the power of one value to another.

  Given a tensor `x` and a tensor `y`, this operation computes \\\\(x^y\\\\) for
  corresponding elements in `x` and `y`. For example:

  ```
  # tensor 'x' is [[2, 2], [3, 3]]
  # tensor 'y' is [[8, 16], [2, 3]]
  tf.pow(x, y) ==> [[256, 65536], [9, 27]]
  ```

  Args:
    x: A `Tensor` of type `float`, `double`, `int32`, `complex64`, or `int64`.
    y: A `Tensor` of type `float`, `double`, `int32`, `complex64`, or `int64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`.
  """
  with ops.op_scope([x, y], name, "Pow") as name:
    return gen_math_ops._pow(x, y, name=name)


def complex(real, imag, name=None):
  """Converts two real numbers to a complex number.

  Given a tensor `real` representing the real part of a complex number, and a
  tensor `imag` representing the imaginary part of a complex number, this
  operation computes complex numbers elementwise of the form \\\\(a + bj\\\\),
  where *a* represents the `real` part and *b* represents the `imag` part.

  The input tensors `real` and `imag` must be the same shape.

  For example:

  ```
  # tensor 'real' is [2.25, 3.25]
  # tensor `imag` is [4.75, 5.75]
  tf.complex(real, imag) ==> [2.25 + 4.75j, 3.25 + 5.75j]
  ```

  Args:
    real: A `Tensor` of type `float`.
    imag: A `Tensor` of type `float`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `complex64`.
  """
  with ops.op_scope([real, imag], name, "Complex") as name:
    return gen_math_ops._complex(real, imag, name=name)


def round(x, name=None):
  """Rounds the values of a tensor to the nearest integer, element-wise.

  For example:

  ```python
  # 'a' is [0.9, 2.5, 2.3, -4.4]
  tf.round(a) ==> [ 1.0, 3.0, 2.0, -4.0 ]
  ```

  Args:
    x: A `Tensor` of type `float` or `double`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of same shape and type as `x`.
  """
  x = ops.convert_to_tensor(x, name="x")
  if x.dtype.is_integer:
    return x
  else:
    return floor(x + 0.5, name=name)


def cast(x, dtype, name=None):
  """Casts a tensor to a new type.

  The operation casts `x` (in case of `Tensor`) or `x.values`
  (in case of `SparseTensor`) to `dtype`.

  For example:

  ```python
  # tensor `a` is [1.8, 2.2], dtype=tf.float32
  tf.cast(a, tf.int32) ==> [1, 2]  # dtype=tf.int32
  ```

  Args:
    x: A `Tensor` or `SparseTensor`.
    dtype: The destination type.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor` with same shape as `x`.

  Raises:
    TypeError: If `x` cannot be cast to the `dtype`.
  """
  with ops.op_scope([x], name, "Cast") as name:
    if isinstance(x, ops.SparseTensor):
      values_cast = cast(x.values, dtype, name=name)
      return ops.SparseTensor(x.indices, values_cast, x.shape)
    else:
      # TODO(touts): Handle what Josh said.
      #
      # Could return ops.convert_to_tensor(x, dtype=dtype, ...) here, but that
      # allows some conversions that cast() can't do, e.g. casting numbers to
      # strings.
      x = ops.convert_to_tensor(x, name="x")
      if x.dtype.base_dtype == dtype:
        return x
      return gen_math_ops.cast(x, dtype, name=name)


def to_float(x, name="ToFloat"):
  """Casts a tensor to type `float32`.

  Args:
    x: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor` with same shape as `x` with type `float32`.

  Raises:
    TypeError: If `x` cannot be cast to `float32`.
  """
  return cast(x, types.float32, name=name)


def to_double(x, name="ToDouble"):
  """Casts a tensor to type `float64`.

  Args:
    x: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor` with same shape as `x` with type `float64`.

  Raises:
    TypeError: If `x` cannot be cast to `float64`.
  """
  return cast(x, types.float64, name=name)


def to_int32(x, name="ToInt32"):
  """Casts a tensor to type `int32`.

  Args:
    x: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor` with same shape as `x` with type `int32`.

  Raises:
    TypeError: If `x` cannot be cast to `int32`.
  """
  return cast(x, types.int32, name=name)


def to_int64(x, name="ToInt64"):
  """Casts a tensor to type `int64`.

  Args:
    x: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor` with same shape as `x` with type `int64`.

  Raises:
    TypeError: If `x` cannot be cast to `int64`.
  """
  return cast(x, types.int64, name=name)


def to_bfloat16(x, name="ToBFloat16"):
  """Casts a tensor to type `bfloat16`.

  Args:
    x: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor` with same shape as `x` with type `bfloat16`.

  Raises:
    TypeError: If `x` cannot be cast to `bfloat16`.
  """
  return cast(x, types.bfloat16, name=name)


ops.Tensor._override_operator("__neg__", neg)
ops.Tensor._override_operator("__abs__", abs)
# __invert__ corresponds to the ~ operator. Here we follow the numpy convention
# that ~ marks an elementwise bit-wise inverse. This is only implemented for
# boolean tensors and will throw a TypeError if used on non-boolean arrays.
ops.Tensor._override_operator("__invert__", logical_not)
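

# A minimal sketch of the ~ override (example values assumed):
#
#   # 'x' is [True, False]  (bool)
#   ~x ==> [False, True]    # dispatches to tf.logical_not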


def _OverrideBinaryOperatorHelper(func, op_name):
  """Register operators with different tensor and scalar versions.

  Args:
    func: the operator
    op_name: name of the operator being overridden
  """

  def binary_op_wrapper(x, y):
    with ops.op_scope([x, y], None, op_name) as name:
      assert isinstance(x, ops.Tensor)
      y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y")
      return func(x, y, name=name)

  ops.Tensor._override_operator("__%s__" % op_name, binary_op_wrapper)
  del binary_op_wrapper

  def r_binary_op_wrapper(y, x):
    with ops.op_scope([x, y], None, op_name) as name:
      assert isinstance(y, ops.Tensor)
      x = ops.convert_to_tensor(x, dtype=y.dtype.base_dtype, name="x")
      return func(x, y, name=name)

  ops.Tensor._override_operator("__r%s__" % op_name, r_binary_op_wrapper)
  del r_binary_op_wrapper
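

# A minimal sketch of what the helper enables (example values assumed): once
# add is registered below, both orderings of a Tensor/scalar sum work, with
# the non-Tensor operand converted to the Tensor operand's base dtype.
#
#   # 'x' is a float32 Tensor [1.0, 2.0]
#   x + 1.0 ==> [2.0, 3.0]   # __add__  -> tf.add(x, 1.0)
#   1.0 + x ==> [2.0, 3.0]   # __radd__ -> tf.add(1.0, x)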


# Conversion table for __truediv__. None entries mean no conversion required.
_TRUEDIV_TABLE = {
    types.uint8: types.float32,
    types.int8: types.float32,
    types.int16: types.float32,
    types.int32: types.float64,
    types.int64: types.float64,
    types.float32: None,
    types.float64: None,
    types.complex64: None,
}


def truediv(x, y, name=None):
  """Divides x / y elementwise, always producing floating point results.

  The same as `tf.div` for floating point arguments, but casts integer
  arguments to floating point before dividing so that the result is always
  floating point. This op is generated by normal `x / y` division in Python 3
  and in Python 2.7 with `from __future__ import division`. If you want
  integer division that rounds down, use `x // y` or `tf.floordiv`.

  `x` and `y` must have the same numeric type. If the inputs are floating
  point, the output will have the same type. If the inputs are integral, the
  inputs are cast to `float32` for `uint8`, `int8`, and `int16` and to
  `float64` for `int32` and `int64` (matching the behavior of Numpy).

  Args:
    x: `Tensor` numerator of numeric type.
    y: `Tensor` denominator of numeric type.
    name: A name for the operation (optional).

  Returns:
    `x / y` evaluated in floating point.

  Raises:
    TypeError: If `x` and `y` have different dtypes.
  """
  with ops.op_scope([x, y], name, "truediv") as name:
    x = ops.convert_to_tensor(x, name="x")
    y = ops.convert_to_tensor(y, name="y")
    x_dtype = x.dtype.base_dtype
    y_dtype = y.dtype.base_dtype
    if x_dtype != y_dtype:
      raise TypeError("x and y must have the same dtype, got %r != %r" %
                      (x_dtype, y_dtype))
    try:
      dtype = _TRUEDIV_TABLE[x_dtype]
    except KeyError:
      raise TypeError("Invalid dtype %r in __truediv__" % x_dtype)
    if dtype is not None:
      x = cast(x, dtype)
      y = cast(y, dtype)
    return div(x, y, name=name)
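

# A minimal sketch of truediv's casting rule (example values assumed): int32
# operands are cast to float64 via _TRUEDIV_TABLE before dividing.
#
#   # 'x' is [1, 2, 3]  (int32), 'y' is [2, 2, 2]  (int32)
#   tf.truediv(x, y) ==> [0.5, 1.0, 1.5]   # float64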


def floordiv(x, y, name=None):
  """Divides `x / y` elementwise, rounding down for floating point.

  The same as `tf.div(x,y)`, but uses `tf.floor(tf.div(x,y))` for floating
  point arguments so that the result is always an integer (though possibly an
  integer represented as floating point). This op is generated by `x // y`
  floor division in Python 3 and in Python 2.7 with
  `from __future__ import division`.

  Note that for efficiency, __floordiv__ uses C semantics for negative numbers
  (unlike Python and Numpy).

  `x` and `y` must have the same type, and the result will have the same type
  as well.

  Args:
    x: `Tensor` numerator of real numeric type.
    y: `Tensor` denominator of real numeric type.
    name: A name for the operation (optional).

  Returns:
    `x / y` rounded down (except possibly towards zero for negative integers,
    following C semantics).

  Raises:
    TypeError: If the inputs are complex.
  """
  with ops.op_scope([x, y], name, "floordiv") as name:
    x = ops.convert_to_tensor(x, name="x")
    dtype = x.dtype
    if dtype.is_floating:
      return floor(div(x, y), name=name)
    else:
      if not dtype.is_integer:
        raise TypeError("Expected floating point or integer, got %r" % dtype)
      return div(x, y, name=name)
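

# A minimal sketch of the C-semantics caveat above (example values assumed):
# integer floordiv truncates negative quotients towards zero rather than
# rounding down as Python does.
#
#   # 'x' is [-7]  (int32), 'y' is [2]  (int32)
#   tf.floordiv(x, y) ==> [-3]   # Python's -7 // 2 gives -4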


_OverrideBinaryOperatorHelper(add, "add")
_OverrideBinaryOperatorHelper(sub, "sub")
_OverrideBinaryOperatorHelper(mul, "mul")
_OverrideBinaryOperatorHelper(div, "div")
_OverrideBinaryOperatorHelper(truediv, "truediv")
_OverrideBinaryOperatorHelper(floordiv, "floordiv")
_OverrideBinaryOperatorHelper(mod, "mod")


def logical_xor(x, y, name="LogicalXor"):
  """x ^ y = (x | y) & ~(x & y)."""
  # TODO(alemi) Make this a cwise op if people end up relying on it.
  return logical_and(logical_or(x, y), logical_not(logical_and(x, y)),
                     name=name)
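

# A minimal sketch of logical_xor (example values assumed):
#
#   # 'x' is [True, True, False, False]
#   # 'y' is [True, False, True, False]
#   tf.logical_xor(x, y) ==> [False, True, True, False]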


_OverrideBinaryOperatorHelper(logical_and, "and")
_OverrideBinaryOperatorHelper(logical_or, "or")
_OverrideBinaryOperatorHelper(logical_xor, "xor")

ops.Tensor._override_operator("__lt__", less)
ops.Tensor._override_operator("__le__", less_equal)
ops.Tensor._override_operator("__gt__", greater)
ops.Tensor._override_operator("__ge__", greater_equal)


def range(start, limit=None, delta=1, name="range"):
  """Creates a sequence of integers.

  Creates a sequence of integers that begins at `start` and extends by
  increments of `delta` up to but not including `limit`.

  Like the Python builtin `range`, `start` defaults to 0, so that
  `range(n) = range(0, n)`.

  For example:

  ```
  # 'start' is 3
  # 'limit' is 18
  # 'delta' is 3
  tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]

  # 'limit' is 5
  tf.range(limit) ==> [0, 1, 2, 3, 4]
  ```

  Args:
    start: A 0-D (scalar) of type `int32`. First entry in sequence.
      Defaults to 0.
    limit: A 0-D (scalar) of type `int32`. Upper limit of sequence,
      exclusive.
    delta: A 0-D `Tensor` (scalar) of type `int32`. Optional. Default is 1.
      Number that increments `start`.
    name: A name for the operation (optional).

  Returns:
    A 1-D `int32` `Tensor`.
  """
  if limit is None:
    start, limit = 0, start
  return gen_math_ops._range(start, limit, delta, name=name)


@ops.RegisterShape("Range")
def _RangeShape(op):
  start_value = tensor_util.ConstantValue(op.inputs[0])
  limit_value = tensor_util.ConstantValue(op.inputs[1])
  delta_value = tensor_util.ConstantValue(op.inputs[2])
  if start_value is None or limit_value is None or delta_value is None:
    return [tensor_shape.vector(None)]
  else:
    return [tensor_shape.vector((limit_value - start_value + delta_value - 1) //
                                delta_value)]
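

# A minimal check of the length formula above (example values assumed): it is
# ceiling division, ceil((limit - start) / delta).
#
#   start=3, limit=18, delta=3: (18 - 3 + 3 - 1) // 3 = 17 // 3 = 5 elements,
#   matching tf.range(3, 18, 3) ==> [3, 6, 9, 12, 15].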


# Reduction operations
def _ReductionDims(x, reduction_indices):
  """Returns range(0, rank(x)) if reduction_indices is None."""
  if reduction_indices is not None:
    return reduction_indices
  else:
    return range(0, array_ops.rank(x))


def reduce_sum(input_tensor, reduction_indices=None, keep_dims=False,
               name=None):
  """Computes the sum of elements across dimensions of a tensor.

  Reduces `input_tensor` along the dimensions given in `reduction_indices`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `reduction_indices` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  For example:

  ```python
  # 'x' is [[1, 1, 1],
  #         [1, 1, 1]]
  tf.reduce_sum(x) ==> 6
  tf.reduce_sum(x, 0) ==> [2, 2, 2]
  tf.reduce_sum(x, 1) ==> [3, 3]
  tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]]
  tf.reduce_sum(x, [0, 1]) ==> 6
  ```

  Args:
    input_tensor: The tensor to reduce. Should have numeric type.
    reduction_indices: The dimensions to reduce. If `None` (the default),
      reduces all dimensions.
    keep_dims: If true, retains reduced dimensions with length 1.
    name: A name for the operation (optional).

  Returns:
    The reduced tensor.
  """
  return gen_math_ops._sum(input_tensor, _ReductionDims(input_tensor,
                                                        reduction_indices),
                           keep_dims, name=name)


def reduce_mean(input_tensor, reduction_indices=None, keep_dims=False,
                name=None):
  """Computes the mean of elements across dimensions of a tensor.

  Reduces `input_tensor` along the dimensions given in `reduction_indices`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `reduction_indices` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  For example:

  ```python
  # 'x' is [[1., 1.],
  #         [2., 2.]]
  tf.reduce_mean(x) ==> 1.5
  tf.reduce_mean(x, 0) ==> [1.5, 1.5]
  tf.reduce_mean(x, 1) ==> [1., 2.]
  ```

  Args:
    input_tensor: The tensor to reduce. Should have numeric type.
    reduction_indices: The dimensions to reduce. If `None` (the default),
      reduces all dimensions.
    keep_dims: If true, retains reduced dimensions with length 1.
    name: A name for the operation (optional).

  Returns:
    The reduced tensor.
  """
  return gen_math_ops._mean(input_tensor, _ReductionDims(input_tensor,
                                                         reduction_indices),
                            keep_dims, name=name)


def reduce_prod(input_tensor, reduction_indices=None, keep_dims=False,
                name=None):
  """Computes the product of elements across dimensions of a tensor.

  Reduces `input_tensor` along the dimensions given in `reduction_indices`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `reduction_indices` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  Args:
    input_tensor: The tensor to reduce. Should have numeric type.
    reduction_indices: The dimensions to reduce. If `None` (the default),
      reduces all dimensions.
    keep_dims: If true, retains reduced dimensions with length 1.
    name: A name for the operation (optional).

  Returns:
    The reduced tensor.
  """
  return gen_math_ops._prod(input_tensor, _ReductionDims(input_tensor,
                                                         reduction_indices),
                            keep_dims, name=name)


def reduce_min(input_tensor, reduction_indices=None, keep_dims=False,
               name=None):
  """Computes the minimum of elements across dimensions of a tensor.

  Reduces `input_tensor` along the dimensions given in `reduction_indices`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `reduction_indices` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  Args:
    input_tensor: The tensor to reduce. Should have numeric type.
    reduction_indices: The dimensions to reduce. If `None` (the default),
      reduces all dimensions.
    keep_dims: If true, retains reduced dimensions with length 1.
    name: A name for the operation (optional).

  Returns:
    The reduced tensor.
  """
  return gen_math_ops._min(input_tensor, _ReductionDims(input_tensor,
                                                        reduction_indices),
                           keep_dims, name=name)


def reduce_max(input_tensor, reduction_indices=None, keep_dims=False,
               name=None):
  """Computes the maximum of elements across dimensions of a tensor.

  Reduces `input_tensor` along the dimensions given in `reduction_indices`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `reduction_indices` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  Args:
    input_tensor: The tensor to reduce. Should have numeric type.
    reduction_indices: The dimensions to reduce. If `None` (the default),
      reduces all dimensions.
    keep_dims: If true, retains reduced dimensions with length 1.
    name: A name for the operation (optional).

  Returns:
    The reduced tensor.
  """
  return gen_math_ops._max(input_tensor, _ReductionDims(input_tensor,
                                                        reduction_indices),
                           keep_dims, name=name)


def reduce_all(input_tensor, reduction_indices=None, keep_dims=False,
               name=None):
  """Computes the "logical and" of elements across dimensions of a tensor.

  Reduces `input_tensor` along the dimensions given in `reduction_indices`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `reduction_indices` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  For example:

  ```python
  # 'x' is [[True,  True],
  #         [False, False]]
  tf.reduce_all(x) ==> False
  tf.reduce_all(x, 0) ==> [False, False]
  tf.reduce_all(x, 1) ==> [True, False]
  ```

  Args:
    input_tensor: The boolean tensor to reduce.
    reduction_indices: The dimensions to reduce. If `None` (the default),
      reduces all dimensions.
    keep_dims: If true, retains reduced dimensions with length 1.
    name: A name for the operation (optional).

  Returns:
    The reduced tensor.
  """
  return gen_math_ops._all(input_tensor, _ReductionDims(input_tensor,
                                                        reduction_indices),
                           keep_dims, name=name)


def reduce_any(input_tensor, reduction_indices=None, keep_dims=False,
               name=None):
  """Computes the "logical or" of elements across dimensions of a tensor.

  Reduces `input_tensor` along the dimensions given in `reduction_indices`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `reduction_indices`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `reduction_indices` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  For example:

  ```python
  # 'x' is [[True,  True],
  #         [False, False]]
  tf.reduce_any(x) ==> True
  tf.reduce_any(x, 0) ==> [True, True]
  tf.reduce_any(x, 1) ==> [True, False]
  ```

  Args:
    input_tensor: The boolean tensor to reduce.
    reduction_indices: The dimensions to reduce. If `None` (the default),
      reduces all dimensions.
    keep_dims: If true, retains reduced dimensions with length 1.
    name: A name for the operation (optional).

  Returns:
    The reduced tensor.
  """
  return gen_math_ops._any(input_tensor, _ReductionDims(input_tensor,
                                                        reduction_indices),
                           keep_dims, name=name)


def matmul(a, b,
           transpose_a=False, transpose_b=False,
           a_is_sparse=False, b_is_sparse=False,
           name=None):
  """Multiplies matrix `a` by matrix `b`, producing `a` * `b`.

  The inputs must be two-dimensional matrices, with matching inner dimensions,
  possibly after transposition.

  Both matrices must be of the same type. The supported types are:
  `float`, `double`, `int32`, `complex64`.

  Either matrix can be transposed on the fly by setting the corresponding flag
  to `True`. This is `False` by default.

  If one or both of the matrices contain a lot of zeros, a more efficient
  multiplication algorithm can be used by setting the corresponding
  `a_is_sparse` or `b_is_sparse` flag to `True`. These are `False` by default.

  For example:

  ```python
  # 2-D tensor `a`
  a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3]) => [[1 2 3]
                                                        [4 5 6]]
  # 2-D tensor `b`
  b = tf.constant([7, 8, 9, 10, 11, 12], shape=[3, 2]) => [[7 8]
                                                           [9 10]
                                                           [11 12]]
  c = tf.matmul(a, b) => [[58 64]
                          [139 154]]
  ```

  Args:
    a: `Tensor` of type `float`, `double`, `int32` or `complex64`.
    b: `Tensor` with same type as `a`.
    transpose_a: If `True`, `a` is transposed before multiplication.
    transpose_b: If `True`, `b` is transposed before multiplication.
    a_is_sparse: If `True`, `a` is treated as a sparse matrix.
    b_is_sparse: If `True`, `b` is treated as a sparse matrix.
    name: Name for the operation (optional).

  Returns:
    A `Tensor` of the same type as `a`.
  """
  with ops.op_scope([a, b], name, "MatMul") as name:
    a = ops.convert_to_tensor(a, name="a")
    b = ops.convert_to_tensor(b, name="b")
    if a.dtype == types.float32 and (a_is_sparse or b_is_sparse):
      return sparse_matmul(a, b,
                           transpose_a=transpose_a,
                           transpose_b=transpose_b,
                           a_is_sparse=a_is_sparse,
                           b_is_sparse=b_is_sparse,
                           name=name)
    else:
      return gen_math_ops._mat_mul(a, b,
                                   transpose_a=transpose_a,
                                   transpose_b=transpose_b,
                                   name=name)


sparse_matmul = gen_math_ops._sparse_mat_mul
batch_matmul = gen_math_ops._batch_mat_mul

ops.RegisterShape("MatMul")(common_shapes.matmul_shape)
ops.RegisterShape("SparseMatMul")(common_shapes.matmul_shape)


def _as_indexed_slices(x):
  """Convert 'x' to IndexedSlices.

  Convert a dense Tensor to a block-sparse IndexedSlices.

  Args:
    x: Either a Tensor object, or an IndexedSlices object.

  Returns:
    An IndexedSlices object.

  Raises:
    TypeError: If 'x' is not a Tensor or an IndexedSlices object.
  """
  # TODO(touts): op_scope
  if not isinstance(x, (ops.Tensor, ops.IndexedSlices)):
    raise TypeError("Not a Tensor or IndexedSlices: %s" % type(x))
  if isinstance(x, ops.IndexedSlices):
    return x
  x_shape = array_ops.shape(x)
  return ops.IndexedSlices(x, range(0, x_shape[0]), x_shape)


def _as_indexed_slices_list(inputs):
  """Convert all elements of 'inputs' to IndexedSlices.

  Additionally, homogenize the types of all the indices to
  either int32 or int64.

  Args:
    inputs: List containing either Tensor or IndexedSlices objects.

  Returns:
    A list of IndexedSlices objects.

  Raises:
    TypeError: If 'inputs' is not a list or a tuple.
  """
  if not isinstance(inputs, (list, tuple)):
    raise TypeError("Expected a list or tuple, not a %s" % type(inputs))
  outputs = [_as_indexed_slices(i) for i in inputs]
  with_int32_index = [o.indices for o in outputs
                      if o.indices.dtype == types.int32]
  if not with_int32_index or len(with_int32_index) == len(outputs):
    return outputs
  casted_outputs = []
  for o in outputs:
    if o.indices.dtype == types.int32:
      casted_outputs.append(
          ops.IndexedSlices(o.values, cast(o.indices, types.int64),
                            o.dense_shape))
    else:
      casted_outputs.append(o)
  return casted_outputs


def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
  """Returns the element-wise sum of a list of tensors.

  Optionally, pass `shape` and `tensor_dtype` for shape and type checking,
  otherwise, these are inferred.

  For example:

  ```python
  # tensor 'a' is [[1, 2], [3, 4]]
  # tensor `b` is [[5, 0], [0, 6]]
  tf.accumulate_n([a, b, a]) ==> [[7, 4], [6, 14]]

  # Explicitly pass shape and type
  tf.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32)
    ==> [[7, 4], [6, 14]]
  ```

  Args:
    inputs: A list of `Tensor` objects, each with same shape and type.
    shape: Shape of elements of `inputs`.
    tensor_dtype: The type of `inputs`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of same shape and type as the elements of `inputs`.

  Raises:
    ValueError: If `inputs` don't all have same shape and dtype or the shape
    cannot be inferred.
  """
  if tensor_dtype is None:
    if not inputs or not isinstance(inputs, (list, tuple)):
      raise ValueError("inputs must be a list of at least one Tensor with the "
                       "same dtype and shape")
    inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
    if not all(isinstance(x, ops.Tensor) for x in inputs):
      raise ValueError("inputs must be a list of at least one Tensor with the "
                       "same dtype and shape")
    if not all(x.dtype == inputs[0].dtype for x in inputs):
      raise ValueError("inputs must be a list of at least one Tensor with the "
                       "same dtype and shape")
    tensor_dtype = inputs[0].dtype
  if shape is not None:
    shape = tensor_shape.as_shape(shape)
  else:
    shape = tensor_shape.unknown_shape()
    for input_tensor in inputs:
      if isinstance(input_tensor, ops.Tensor):
        shape = shape.merge_with(input_tensor.get_shape())
  if not shape.is_fully_defined():
    # TODO(pbar): Make a version of assign_add that accepts an uninitialized
    # lvalue, and takes its shape from that? This would allow accumulate_n to
    # work in all situations that add_n currently works.
    raise ValueError("Cannot infer the shape of the accumulator for "
                     "accumulate_n. Pass the shape argument, or set the shape "
                     "of at least one of the inputs.")
  with ops.op_scope(inputs, name, "AccumulateN") as name:
    var = gen_state_ops._temporary_variable(shape=shape, dtype=tensor_dtype)
    var_name = var.op.name
    var = state_ops.assign(var, array_ops.zeros_like(inputs[0]))
    update_ops = []
    for input_tensor in inputs:
      op = state_ops.assign_add(var, input_tensor, use_locking=True)
      update_ops.append(op)
    with ops.control_dependencies(update_ops):
      return gen_state_ops._destroy_temporary_variable(var,
                                                       var_name=var_name,
                                                       name=name)


@ops.RegisterShape("BatchMatMul")
def _BatchMatMulShape(op):
  """Shape function for BatchMatMul op."""
  a_shape = op.inputs[0].get_shape()
  adj_a = op.get_attr("adj_x")
  b_shape = op.inputs[1].get_shape()
  adj_b = op.get_attr("adj_y")
  if not a_shape.is_fully_defined() or not b_shape.is_fully_defined():
    return [tensor_shape.unknown_shape()]
  batch_dims = a_shape[:-2].merge_with(b_shape[:-2])
  output_rows = a_shape[-1] if adj_a else a_shape[-2]
  output_cols = b_shape[-2] if adj_b else b_shape[-1]
  inner_a = a_shape[-2] if adj_a else a_shape[-1]
  inner_b = b_shape[-1] if adj_b else b_shape[-2]
  inner_a.assert_is_compatible_with(inner_b)
  return [batch_dims.concatenate([output_rows, output_cols])]
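

# A minimal worked case for the shape function above (shapes assumed):
#
#   a: [2, 3, 4], b: [2, 4, 5], adj_x=False, adj_y=False
#   batch_dims = [2]; output_rows = 3; output_cols = 5; inner dims 4 vs 4 agree
#   result shape: [2, 3, 5]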


def sigmoid(x, name=None):
  """Computes sigmoid of `x` element-wise.

  Specifically, `y = 1 / (1 + exp(-x))`.

  Args:
    x: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
      or `qint32`.
    name: A name for the operation (optional).

  Returns:
    A Tensor with the same type as `x` if `x.dtype != qint32`
    otherwise the return type is `quint8`.
  """
  with ops.op_scope([x], name, "Sigmoid") as name:
    x = ops.convert_to_tensor(x, name="x")
    return gen_math_ops._sigmoid(x, name=name)


def tanh(x, name=None):
  """Computes hyperbolic tangent of `x` element-wise.

  Args:
    x: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
      or `qint32`.
    name: A name for the operation (optional).

  Returns:
    A Tensor with the same type as `x` if `x.dtype != qint32` otherwise
    the return type is `quint8`.
  """
  with ops.op_scope([x], name, "Tanh") as name:
    x = ops.convert_to_tensor(x, name="x")
    return gen_math_ops._tanh(x, name=name)


ops.RegisterShape("Abs")(common_shapes.unchanged_shape)
ops.RegisterShape("Ceil")(common_shapes.unchanged_shape)
ops.RegisterShape("Conj")(common_shapes.unchanged_shape)
ops.RegisterShape("Cos")(common_shapes.unchanged_shape)
ops.RegisterShape("Exp")(common_shapes.unchanged_shape)
ops.RegisterShape("Floor")(common_shapes.unchanged_shape)
ops.RegisterShape("Imag")(common_shapes.unchanged_shape)
ops.RegisterShape("Inv")(common_shapes.unchanged_shape)
ops.RegisterShape("IsFinite")(common_shapes.unchanged_shape)
ops.RegisterShape("IsInf")(common_shapes.unchanged_shape)
ops.RegisterShape("IsNan")(common_shapes.unchanged_shape)
ops.RegisterShape("Log")(common_shapes.unchanged_shape)
ops.RegisterShape("LogicalNot")(common_shapes.unchanged_shape)
ops.RegisterShape("Neg")(common_shapes.unchanged_shape)
ops.RegisterShape("Real")(common_shapes.unchanged_shape)
ops.RegisterShape("Rsqrt")(common_shapes.unchanged_shape)
ops.RegisterShape("Sign")(common_shapes.unchanged_shape)
ops.RegisterShape("Sin")(common_shapes.unchanged_shape)
ops.RegisterShape("Sqrt")(common_shapes.unchanged_shape)
ops.RegisterShape("Square")(common_shapes.unchanged_shape)
ops.RegisterShape("Sigmoid")(common_shapes.unchanged_shape)
ops.RegisterShape("Tanh")(common_shapes.unchanged_shape)
ops.RegisterShape("Cast")(common_shapes.unchanged_shape)
ops.RegisterShape("ComplexAbs")(common_shapes.unchanged_shape)


@ops.RegisterShape("Add")
@ops.RegisterShape("Complex")
@ops.RegisterShape("Div")
@ops.RegisterShape("Equal")
@ops.RegisterShape("Greater")
@ops.RegisterShape("GreaterEqual")
@ops.RegisterShape("Less")
@ops.RegisterShape("LessEqual")
@ops.RegisterShape("LogicalAnd")
@ops.RegisterShape("LogicalOr")
@ops.RegisterShape("Maximum")
@ops.RegisterShape("Minimum")
@ops.RegisterShape("Mod")
@ops.RegisterShape("Mul")
@ops.RegisterShape("NotEqual")
@ops.RegisterShape("Pow")
@ops.RegisterShape("Sub")
def _BroadcastShape(op):
  """Common shape function for binary operators that broadcast their inputs."""
  shape_x = op.inputs[0].get_shape()
  shape_y = op.inputs[1].get_shape()
  if shape_x.ndims is None or shape_y.ndims is None:
    return [tensor_shape.unknown_shape()]

  # To compute the broadcasted dimensions, we zip together shape_x and shape_y,
  # and pad with 1 to make them the same length.
  broadcasted_dims = reversed(list(six.moves.zip_longest(
      reversed(shape_x.dims),
      reversed(shape_y.dims),
      fillvalue=tensor_shape.Dimension(1))))
  # Next we combine the dimensions according to the numpy broadcasting rules.
  # http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
  return_dims = []
  for (dim_x, dim_y) in broadcasted_dims:
    if dim_x.value is None or dim_y.value is None:
      # One or both dimensions is unknown. If either dimension is greater than
      # 1, we assume that the program is correct, and the other dimension will
      # be broadcast to match it.
      # TODO(mrry): If we eliminate the shape checks in C++, we must still
      # assert that the unknown dim is either 1 or the same as the known dim.
      if dim_x.value is not None and dim_x.value > 1:
        return_dims.append(dim_x)
      elif dim_y.value is not None and dim_y.value > 1:
        return_dims.append(dim_y)
      else:
        return_dims.append(None)
    elif dim_x.value == 1:
      # We will broadcast dim_x to dim_y.
      return_dims.append(dim_y)
    elif dim_y.value == 1:
      # We will broadcast dim_y to dim_x.
      return_dims.append(dim_x)
    elif dim_x.value == dim_y.value:
      # The dimensions are compatible, so output is the same size in that
      # dimension.
      return_dims.append(dim_x.merge_with(dim_y))
    else:
      raise ValueError("Incompatible shapes for broadcasting: %s and %s"
                       % (shape_x, shape_y))
  return [tensor_shape.TensorShape(return_dims)]
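

# A minimal worked case of the numpy broadcasting rules above (shapes
# assumed): trailing dimensions are aligned and size-1 dimensions stretch.
#
#   shape_x = [2, 1, 5], shape_y = [3, 1]
#   pad shape_y to [1, 3, 1]; combine per dimension:
#   (2, 1) -> 2, (1, 3) -> 3, (5, 1) -> 5   ==> result shape [2, 3, 5]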


@ops.RegisterShape("AddN")
def _AddNShape(op):
  merged_shape = tensor_shape.unknown_shape()
  for input_ in op.inputs:
    merged_shape = merged_shape.merge_with(input_.get_shape())
  return [merged_shape]


@ops.RegisterShape("Select")
def _SelectShape(op):
  # All three inputs must have the same shape.
  return [op.inputs[0].get_shape()
          .merge_with(op.inputs[1].get_shape())
          .merge_with(op.inputs[2].get_shape())]


@ops.RegisterShape("ArgMax")
@ops.RegisterShape("ArgMin")
def _ArgOpShape(op):
  """Common shape function for arg-reduction ops."""
  dimension_shape = op.inputs[1].get_shape()
  dimension_shape.assert_is_compatible_with(tensor_shape.scalar())
  input_shape = op.inputs[0].get_shape()
  if input_shape.ndims is None:
    return [tensor_shape.unknown_shape()]
  elif input_shape.ndims <= 1:
    return [tensor_shape.scalar()]

  dimension = tensor_util.ConstantValue(op.inputs[1])
  if dimension is None:
    return [tensor_shape.unknown_shape(ndims=input_shape.ndims - 1)]
  elif 0 <= dimension < input_shape.ndims:
    returned_shape = []
    for i, dim in enumerate(input_shape.dims):
      if i != dimension:
        returned_shape.append(dim)
    return [tensor_shape.TensorShape(returned_shape)]
  else:
    raise ValueError(
        "dimension (%d) must be in the range [0, %d), where %d is the number "
        "of dimensions in the input"
        % (dimension, input_shape.ndims, input_shape.ndims))


@ops.RegisterShape("All")
@ops.RegisterShape("Any")
@ops.RegisterShape("Max")
@ops.RegisterShape("Mean")
@ops.RegisterShape("Min")
@ops.RegisterShape("Prod")
@ops.RegisterShape("Sum")
def _ReductionShape(op):
  """Common shape function for reduction ops."""
  input_shape = op.inputs[0].get_shape()
  reduction_indices = tensor_util.ConstantValue(op.inputs[1])
  keep_dims = op.get_attr("keep_dims")
  if reduction_indices is None or input_shape.ndims is None:
    if keep_dims:
      return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
    else:
      return [tensor_shape.unknown_shape()]

  # Turn reduction_indices from scalar to vector if necessary
  reduction_indices = np.ravel(reduction_indices)

  for reduction_index in reduction_indices:
    if reduction_index < 0 or reduction_index >= input_shape.ndims:
      raise ValueError("Invalid reduction dimension %d for input with %d "
                       "dimensions" % (reduction_index, input_shape.ndims))

  returned_dims = []
  if keep_dims:
    for i, dim in enumerate(input_shape.dims):
      if i in reduction_indices:
        returned_dims.append(1)
      else:
        returned_dims.append(dim)
  else:
    for i, dim in enumerate(input_shape.dims):
      if i not in reduction_indices:
        returned_dims.append(dim)
  return [tensor_shape.TensorShape(returned_dims)]
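

# A minimal worked case for the reduction shape function above (shapes
# assumed):
#
#   input_shape = [2, 3, 5], reduction_indices = [1]
#   keep_dims=False ==> [2, 5];  keep_dims=True ==> [2, 1, 5]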


@ops.RegisterShape("SegmentMax")
@ops.RegisterShape("SegmentMean")
@ops.RegisterShape("SegmentMin")
@ops.RegisterShape("SegmentProd")
@ops.RegisterShape("SegmentSum")
def _SegmentReductionShape(op):
  """Common shape function for segment reduction ops."""
  data_shape = op.inputs[0].get_shape()
  segment_ids_shape = op.inputs[1].get_shape()
  segment_ids_shape.assert_has_rank(1)
  return [tensor_shape.TensorShape([None]).concatenate(data_shape[1:])]


@ops.RegisterShape("SparseSegmentMean")
@ops.RegisterShape("SparseSegmentSum")
def _SparseSegmentReductionShape(op):
  """Common shape function for sparse segment reduction ops."""
  data_shape = op.inputs[0].get_shape()
  indices_shape = op.inputs[1].get_shape()
  indices_shape.assert_has_rank(1)
  segment_ids_shape = op.inputs[2].get_shape()
  segment_ids_shape.assert_has_rank(1)
  indices_shape.assert_is_compatible_with(segment_ids_shape)
  return [tensor_shape.TensorShape([None]).concatenate(data_shape[1:])]


@ops.RegisterShape("SparseSegmentMeanGrad")
def _SparseSegmentMeanGradShape(op):
  """Shape function for the SparseSegmentMeanGrad op."""
  input_shape = op.inputs[0].get_shape()
  indices_shape = op.inputs[1].get_shape().with_rank(1)
  unused_segment_ids_shape = op.inputs[2].get_shape().merge_with(indices_shape)
  unused_output_dim0_shape = op.inputs[3].get_shape().merge_with(
      tensor_shape.scalar())
  output_dim0 = tensor_util.ConstantValue(op.inputs[3])
  if output_dim0 is not None:
    dim0 = output_dim0[0]
  else:
    dim0 = None
  return [tensor_shape.TensorShape([dim0]).concatenate(input_shape[1:])]


@ops.RegisterShape("UnsortedSegmentSum")
def _UnsortedSegmentSumShape(op):
  """Shape function for UnsortedSegmentSum."""
  data_shape = op.inputs[0].get_shape()
  segment_ids_shape = op.inputs[1].get_shape()
  mid = segment_ids_shape.ndims
  if mid is None:
    return [tensor_shape.unknown_shape()]
  else:
    num_segments = tensor_util.ConstantValue(op.inputs[2])
    return [tensor_shape.TensorShape([num_segments]).concatenate(
        data_shape[mid:])]


@ops.RegisterShape("LinSpace")
def _LinspaceShape(op):
  num = tensor_util.ConstantValue(op.inputs[2])
  return [tensor_shape.vector(num)]