Initial version of tf.contrib.labeled_tensor

Change: 139143754
Stephan Hoyer 2016-11-14 17:24:44 -08:00 committed by TensorFlower Gardener
parent 887892a499
commit 9d20f4ea4b
19 changed files with 5483 additions and 0 deletions


@@ -98,6 +98,7 @@ filegroup(
"//tensorflow/contrib/graph_editor:all_files",
"//tensorflow/contrib/grid_rnn:all_files",
"//tensorflow/contrib/integrate:all_files",
"//tensorflow/contrib/labeled_tensor:all_files",
"//tensorflow/contrib/layers:all_files",
"//tensorflow/contrib/layers/kernels:all_files",
"//tensorflow/contrib/learn:all_files",


@@ -24,6 +24,7 @@ py_library(
"//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
"//tensorflow/contrib/integrate:integrate_py",
"//tensorflow/contrib/labeled_tensor",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",


@@ -29,6 +29,7 @@ from tensorflow.contrib import framework
from tensorflow.contrib import graph_editor
from tensorflow.contrib import grid_rnn
from tensorflow.contrib import integrate
from tensorflow.contrib import labeled_tensor
from tensorflow.contrib import layers
from tensorflow.contrib import learn
from tensorflow.contrib import linear_optimizer


@@ -0,0 +1,166 @@
# Description:
# Labels for TensorFlow.
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
py_library(
name = "labeled_tensor",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
":core",
":io_ops",
":nn",
":ops",
":sugar",
],
)
py_library(
name = "_typecheck",
srcs = ["python/ops/_typecheck.py"],
srcs_version = "PY2AND3",
visibility = [":__subpackages__"],
)
py_library(
name = "core",
srcs = ["python/ops/core.py"],
srcs_version = "PY2AND3",
deps = [
":_typecheck",
],
)
py_library(
name = "test_util",
srcs = ["python/ops/test_util.py"],
srcs_version = "PY2AND3",
deps = [
":_typecheck",
":core",
"//tensorflow:tensorflow_py",
],
)
py_test(
name = "core_test",
size = "small",
srcs = [
"python/ops/core_test.py",
],
srcs_version = "PY2AND3",
deps = [
":core",
":test_util",
"//tensorflow:tensorflow_py",
],
)
py_library(
name = "io_ops",
srcs = ["python/ops/io_ops.py"],
srcs_version = "PY2AND3",
deps = [
":core",
],
)
py_test(
name = "io_ops_test",
size = "small",
srcs = [
"python/ops/io_ops_test.py",
],
srcs_version = "PY2AND3",
deps = [
":io_ops",
":ops",
":test_util",
"//tensorflow:tensorflow_py",
],
)
py_library(
name = "nn",
srcs = ["python/ops/nn.py"],
srcs_version = "PY2AND3",
deps = [
":core",
],
)
py_test(
name = "nn_test",
size = "small",
srcs = [
"python/ops/nn_test.py",
],
srcs_version = "PY2AND3",
deps = [
":nn",
":test_util",
"//tensorflow:tensorflow_py",
],
)
py_library(
name = "ops",
srcs = ["python/ops/ops.py"],
srcs_version = "PY2AND3",
deps = [
":core",
],
)
py_test(
name = "ops_test",
srcs = [
"python/ops/ops_test.py",
],
srcs_version = "PY2AND3",
deps = [
":ops",
":test_util",
"//tensorflow:tensorflow_py",
],
)
py_library(
name = "sugar",
srcs = ["python/ops/sugar.py"],
srcs_version = "PY2AND3",
deps = [
":core",
":ops",
],
)
py_test(
name = "sugar_test",
size = "small",
srcs = [
"python/ops/sugar_test.py",
],
srcs_version = "PY2AND3",
deps = [
":sugar",
":test_util",
"//tensorflow:tensorflow_py",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
)


@@ -0,0 +1,8 @@
# Labels for TensorFlow
LabeledTensor is a library for adding semantically meaningful dimension and
coordinate labels to tensors in TensorFlow.
Maintainers:
- Stephan Hoyer (shoyer@google.com, github.com/shoyer)
- Eric Christiansen (ericmc@google.com, github.com/emchristiansen)
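
For orientation, a minimal, hypothetical usage sketch; the `lt` alias, axis names, and shapes below are illustrative:

import tensorflow as tf
from tensorflow.contrib import labeled_tensor as lt

# Wrap an ordinary Tensor and name its dimensions. An axis can be just a
# name (size inferred from the tensor) or a (name, labels) pair.
image = lt.LabeledTensor(
    tf.ones([28, 28, 3]),
    ['row', 'column', ('channel', ['red', 'green', 'blue'])])

# Axes are then referenced by name instead of by position.
channels_first = lt.transpose(image, ['channel', 'row', 'column'])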


@@ -0,0 +1,139 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Labels for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.labeled_tensor.python.ops import core as _core
from tensorflow.contrib.labeled_tensor.python.ops import io_ops as _io_ops
from tensorflow.contrib.labeled_tensor.python.ops import nn
from tensorflow.contrib.labeled_tensor.python.ops import ops as _ops
from tensorflow.contrib.labeled_tensor.python.ops import sugar as _sugar
# pylint: disable=invalid-name
# Core types.
Axis = _core.Axis
Axes = _core.Axes
LabeledTensor = _core.LabeledTensor
as_axis = _core.as_axis
convert_to_labeled_tensor = _core.convert_to_labeled_tensor
identity = _core.identity
slice = _core.slice_function # pylint: disable=redefined-builtin
transpose = _core.transpose
expand_dims = _core.expand_dims
align = _core.align
axis_order_scope = _core.axis_order_scope
check_axis_order = _core.check_axis_order
impose_axis_order = _core.impose_axis_order
AxisOrderError = _core.AxisOrderError
define_unary_op = _core.define_unary_op
define_binary_op = _core.define_binary_op
define_reduce_op = _ops.define_reduce_op
abs = _core.abs_function # pylint: disable=redefined-builtin
neg = _core.neg
sign = _core.sign
inv = _core.inv
square = _core.square
round = _core.round_function # pylint: disable=redefined-builtin
sqrt = _core.sqrt
rsqrt = _core.rsqrt
exp = _core.exp
log = _core.log
ceil = _core.ceil
floor = _core.floor
cos = _core.cos
sin = _core.sin
tan = _core.tan
acos = _core.acos
asin = _core.asin
atan = _core.atan
lgamma = _core.lgamma
digamma = _core.digamma
erf = _core.erf
erfc = _core.erfc
logical_not = _core.logical_not
add = _core.add
sub = _core.sub
mul = _core.mul
div = _core.div
mod = _core.mod
pow = _core.pow_function # pylint: disable=redefined-builtin
equal = _core.equal
greater = _core.greater
greater_equal = _core.greater_equal
not_equal = _core.not_equal
less = _core.less
less_equal = _core.less_equal
logical_and = _core.logical_and
logical_or = _core.logical_or
logical_xor = _core.logical_xor
maximum = _core.maximum
minimum = _core.minimum
squared_difference = _core.squared_difference
igamma = _core.igamma
igammac = _core.igammac
zeta = _core.zeta
polygamma = _core.polygamma
select = _ops.select
concat = _ops.concat
pack = _ops.pack
unpack = _ops.unpack
reshape = _ops.reshape
rename_axis = _ops.rename_axis
random_crop = _ops.random_crop
map_fn = _ops.map_fn
squeeze = _ops.squeeze
matmul = _ops.matmul
tile = _ops.tile
pad = _ops.pad
constant = _ops.constant
zeros_like = _ops.zeros_like
ones_like = _ops.ones_like
cast = _ops.cast
verify_tensor_all_finite = _ops.verify_tensor_all_finite
boolean_mask = _ops.boolean_mask
where = _ops.where
reduce_all = _ops.reduce_all
reduce_any = _ops.reduce_any
reduce_logsumexp = _ops.reduce_logsumexp
reduce_max = _ops.reduce_max
reduce_mean = _ops.reduce_mean
reduce_min = _ops.reduce_min
reduce_prod = _ops.reduce_prod
reduce_sum = _ops.reduce_sum
batch = _ops.batch
shuffle_batch = _ops.shuffle_batch
FixedLenFeature = _io_ops.FixedLenFeature
parse_example = _io_ops.parse_example
parse_single_example = _io_ops.parse_single_example
placeholder = _io_ops.placeholder
ReshapeCoder = _sugar.ReshapeCoder
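
With this flat namespace, operations are written against axis names rather than positional indices. A hedged sketch (axis names are illustrative, and the `reduce_sum` axis argument is assumed from `define_reduce_op`; the tests further down pin down the exact semantics):

import tensorflow as tf
from tensorflow.contrib import labeled_tensor as lt

lt_image = lt.LabeledTensor(
    tf.ones([4, 3]),
    [('x', range(4)), ('channel', ['red', 'green', 'blue'])])

transposed = lt.transpose(lt_image, ['channel', 'x'])  # reorder axes by name
red_only = lt.select(lt_image, {'channel': 'red'})     # select by label; drops the channel axis
totals = lt.reduce_sum(lt_image, 'channel')            # assumed: reduce over the named axis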


@@ -0,0 +1,322 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Minimal runtime type checking library.
This module should not be considered public API.
"""
# TODO(ericmc,shoyer): Delete this in favor of using pytype or mypy
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
import inspect
import re
# used for register_type_abbreviation and _type_repr below.
_TYPE_ABBREVIATIONS = {}
class Type(object):
"""Base class for type checker types.
The custom types defined in this module are based on types in the standard
library's typing module (in Python 3.5):
https://docs.python.org/3/library/typing.html
The only difference should be that we use actual instances of Type classes to
represent custom types rather than the metaclass magic typing uses to create
new class objects. In practice, all this should mean is that we use
`List(int)` rather than `List[int]`.
Custom types should implement __instancecheck__ and inherit from Type. Every
argument in the constructor must be a type or Type instance, and these
arguments must be stored as a tuple on the `_types` attribute.
"""
def __init__(self, *types):
self._types = types
def __repr__(self):
args_repr = ", ".join(repr(t) for t in self._types)
return "typecheck.%s(%s)" % (type(self).__name__, args_repr)
class _SingleArgumentType(Type):
"""Use this subclass for parametric types that accept only one argument."""
def __init__(self, tpe):
super(_SingleArgumentType, self).__init__(tpe)
@property
def _type(self):
tpe, = self._types # pylint: disable=unbalanced-tuple-unpacking
return tpe
class _TwoArgumentType(Type):
"""Use this subclass for parametric types that accept two arguments."""
def __init__(self, first_type, second_type):
super(_TwoArgumentType, self).__init__(first_type, second_type)
class Union(Type):
"""A sum type.
A correct type is any of the types provided.
"""
def __instancecheck__(self, instance):
return isinstance(instance, self._types)
class Optional(_SingleArgumentType):
"""An optional type.
A correct type is either the provided type or NoneType.
"""
def __instancecheck__(self, instance):
# types.NoneType does not exist in Python 3
return isinstance(instance, (self._type, type(None)))
class List(_SingleArgumentType):
"""A typed list.
A correct type is a list where each element has the single provided type.
"""
def __instancecheck__(self, instance):
return (isinstance(instance, list)
and all(isinstance(x, self._type) for x in instance))
class Sequence(_SingleArgumentType):
"""A typed sequence.
A correct type is a sequence where each element has the single provided type.
"""
def __instancecheck__(self, instance):
return (isinstance(instance, collections.Sequence)
and all(isinstance(x, self._type) for x in instance))
class Collection(_SingleArgumentType):
"""A sized, iterable container.
A correct type is an iterable and container with known size where each element
has the single provided type.
We use this in preference to Iterable because we check each instance of the
iterable at runtime, and hence need to avoid iterables that could be
exhausted.
"""
def __instancecheck__(self, instance):
return (isinstance(instance, collections.Iterable)
and isinstance(instance, collections.Sized)
and isinstance(instance, collections.Container)
and all(isinstance(x, self._type) for x in instance))
class Tuple(Type):
"""A typed tuple.
A correct type is a tuple with the correct length where each element has
the correct type.
"""
def __instancecheck__(self, instance):
return (isinstance(instance, tuple)
and len(instance) == len(self._types)
and all(isinstance(x, t) for x, t in zip(instance, self._types)))
class Mapping(_TwoArgumentType):
"""A typed mapping.
A correct type has the correct parametric types for keys and values.
"""
def __instancecheck__(self, instance):
key_type, value_type = self._types # pylint: disable=unbalanced-tuple-unpacking
return (isinstance(instance, collections.Mapping)
and all(isinstance(k, key_type) for k in instance.keys())
and all(isinstance(k, value_type) for k in instance.values()))
class Dict(Mapping):
"""A typed dict.
A correct type has the correct parametric types for keys and values.
"""
def __instancecheck__(self, instance):
return (isinstance(instance, dict)
and super(Dict, self).__instancecheck__(instance))
def _replace_forward_references(t, context):
"""Replace forward references in the given type."""
if isinstance(t, str):
return context[t]
elif isinstance(t, Type):
return type(t)(*[_replace_forward_references(t, context) for t in t._types]) # pylint: disable=protected-access
else:
return t
def register_type_abbreviation(name, alias):
"""Register an abbreviation for a type in typecheck tracebacks.
This makes otherwise very long typecheck errors much more readable.
Example:
typecheck.register_type_abbreviation(tf.Dimension, 'tf.Dimension')
Args:
name: type or class to abbreviate.
alias: string alias to substitute.
"""
_TYPE_ABBREVIATIONS[name] = alias
def _type_repr(t):
"""A more succinct repr for typecheck tracebacks."""
string = repr(t)
for type_, alias in _TYPE_ABBREVIATIONS.items():
string = string.replace(repr(type_), alias)
string = re.sub(r"<(class|type) '([\w.]+)'>", r"\2", string)
string = re.sub(r"typecheck\.(\w+)", r"\1", string)
return string
class Error(TypeError):
"""Exception for typecheck failures."""
def accepts(*types):
"""A decorator which checks the input types of a function.
Based on:
http://stackoverflow.com/questions/15299878/how-to-use-python-decorators-to-check-function-arguments
The above draws from:
https://www.python.org/dev/peps/pep-0318/
Args:
*types: A list of Python types.
Returns:
A function to use as a decorator.
"""
def check_accepts(f):
"""Check the types."""
spec = inspect.getargspec(f)
num_function_arguments = len(spec.args)
if len(types) != num_function_arguments:
raise Error(
"Function %r has %d arguments but only %d types were provided in the "
"annotation." % (f, num_function_arguments, len(types)))
if spec.defaults:
num_defaults = len(spec.defaults)
for (name, a, t) in zip(spec.args[-num_defaults:],
spec.defaults,
types[-num_defaults:]):
allowed_type = _replace_forward_references(t, f.__globals__)
if not isinstance(a, allowed_type):
raise Error("default argument value %r of type %r is not an instance "
"of the allowed type %s for the %s argument to %r"
% (a, type(a), _type_repr(allowed_type), name, f))
@functools.wraps(f)
def new_f(*args, **kwds):
"""A helper function."""
for (a, t) in zip(args, types):
allowed_type = _replace_forward_references(t, f.__globals__)
if not isinstance(a, allowed_type):
raise Error("%r of type %r is not an instance of the allowed type %s "
"for %r" % (a, type(a), _type_repr(allowed_type), f))
return f(*args, **kwds)
return new_f
return check_accepts
def returns(*types):
"""A decorator which checks the return types of a function.
Based on:
http://stackoverflow.com/questions/15299878/how-to-use-python-decorators-to-check-function-arguments
The above draws from:
https://www.python.org/dev/peps/pep-0318/
Args:
*types: A list of Python types.
A list of one element corresponds to a single return value.
A list of several elements corresponds to several return values.
Note that a function with no explicit return value has an implicit
NoneType return and should be annotated correspondingly.
Returns:
A function to use as a decorator.
"""
def check_returns(f):
"""Check the types."""
if not types:
raise TypeError("A return type annotation must contain at least one type")
@functools.wraps(f)
def new_f(*args, **kwds):
"""A helper function."""
return_value = f(*args, **kwds)
if len(types) == 1:
# The function has a single return value.
allowed_type = _replace_forward_references(types[0], f.__globals__)
if not isinstance(return_value, allowed_type):
raise Error("%r of type %r is not an instance of the allowed type %s "
"for %r"
% (return_value, type(return_value),
_type_repr(allowed_type), f))
else:
if len(return_value) != len(types):
raise Error(
"Function %r has %d return values but only %d types were "
"provided in the annotation." %
(f, len(return_value), len(types)))
for (r, t) in zip(return_value, types):
allowed_type = _replace_forward_references(t, f.__globals__)
if not isinstance(r, allowed_type):
raise Error("%r of type %r is not an instance of allowed type %s "
"for %r" % (r, type(r), _type_repr(allowed_type), f))
return return_value
return new_f
return check_returns
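
As a hedged illustration of how `accepts`, `returns`, and the parametric types compose, here is a made-up function (not part of this module):

from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc

@tc.returns(str)
@tc.accepts(str, tc.Optional(int))
def repeat(text, times=None):
  """Repeat `text` the given number of times (defaults to once)."""
  return text * (1 if times is None else times)

repeat('ab', 3)       # OK: returns 'ababab'
repeat('ab', 'oops')  # raises tc.Error: 'oops' is not an instance of Optional(int)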

File diff suppressed because it is too large


@@ -0,0 +1,842 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import operator
import re
import textwrap
import numpy as np
from six.moves import range # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
from tensorflow.contrib.labeled_tensor.python.ops import core
from tensorflow.contrib.labeled_tensor.python.ops import test_util
class AxisTest(tf.test.TestCase):
def setUp(self):
d_7 = tf.Dimension(7)
p_rgb = ['red', 'green', 'blue']
self.i_7 = core.Axis('7', d_7)
self.i_7p = core.Axis('7prime', d_7)
self.i_rgb = core.Axis('rgb', p_rgb)
self.i_range = core.Axis('range', range(7))
self.i_unknown = core.Axis('unknown', None)
def test_equality(self):
axes = [self.i_7, self.i_7p, self.i_rgb, self.i_range, self.i_unknown]
for i, axis_0 in enumerate(axes):
for j, axis_1 in enumerate(axes):
if i == j:
self.assertEqual(axis_0, axis_1)
else:
self.assertNotEqual(axis_0, axis_1)
def test_axis_value(self):
self.assertEqual(self.i_7.value, tf.Dimension(7))
self.assertTrue(self.i_range.value == tuple(range(7)))
def test_axis_input(self):
axes = [self.i_7, self.i_7p, self.i_rgb, self.i_range, self.i_unknown]
for axis in axes:
self.assertEqual(axis, core.Axis(axis.name, axis.value))
def test_axis_value_input(self):
axis = self.i_range
for value in [range(7), list(range(7)), np.arange(7)]:
self.assertEqual(axis, core.Axis(axis.name, value))
def test_size(self):
self.assertEqual(len(self.i_7), 7)
self.assertEqual(len(self.i_rgb), 3)
self.assertEqual(len(self.i_range), 7)
self.assertEqual(self.i_unknown.size, None)
def test_concat_single(self):
red = core.Axis('rgb', ['red'])
self.assertEqual(core.concat_axes([red]), red)
def test_concat_many(self):
red = core.Axis('rgb', ['red'])
green = core.Axis('rgb', ['green'])
blue = core.Axis('rgb', ['blue'])
red_green_blue = core.Axis('rgb', ['red', 'green', 'blue'])
self.assertEqual(core.concat_axes([red, green, blue]), red_green_blue)
def test_concat_different_names(self):
red = core.Axis('red', ['red'])
green = core.Axis('green', ['red'])
with self.assertRaises(ValueError):
core.concat_axes([red, green])
def test_concat_unknown(self):
red = core.Axis('rgb', None)
green = core.Axis('rgb', None)
self.assertEqual(core.concat_axes([red, green]), red)
def test_repr(self):
self.assertEqual("Axis('7', Dimension(7))", repr(self.i_7))
def test_invalid_input(self):
with self.assertRaises(TypeError):
core.Axis('foo', [{}])
with self.assertRaises(ValueError):
core.Axis('foo', [1, 2, 3, 1])
red = core.Axis('foo', ['red'])
with self.assertRaises(tc.Error):
core.concat_axes([red, 1])
def test_as_axis(self):
self.assertEqual(self.i_7, core.as_axis(('7', 7)))
self.assertEqual(self.i_7, core.as_axis(self.i_7))
class AxesTest(tf.test.TestCase):
def setUp(self):
d_7 = tf.Dimension(7)
d_8 = tf.Dimension(8)
p_rgb = ['red', 'green', 'blue']
p_range = range(7)
self.i_8 = core.Axis('8', d_8)
self.a0 = core.Axes([('d7', d_7)])
self.a1 = core.Axes([('d7', d_7)])
self.a2 = core.Axes([('d7', d_7), ('rgb', p_rgb)])
self.a3 = core.Axes([('8', d_8), ('range', p_range)])
def test_equality(self):
self.assertEqual(self.a0, self.a0)
self.assertEqual(self.a0, self.a1)
self.assertNotEqual(self.a0, self.a2)
def test_repr(self):
self.assertEqual("Axes([('d7', Dimension(7))])", repr(self.a0))
def test_remove(self):
a = self.a3.remove('range')
self.assertEqual(a, core.Axes([self.i_8]))
with self.assertRaises(KeyError):
self.a3.remove('foobar')
def test_typecheck_error_message(self):
pattern = ('List(Union(labeled_tensor.Axis, Tuple(..., '
'Union(Union(numpy.ndarray, %s, list, tuple), '
'Optional(Union(tensorflow.Dimension, int))))))' %
range.__name__)
regexp = re.escape(pattern).replace(re.escape('...'), '.*')
with self.assertRaisesRegexp(tc.Error, 'allowed type ' + regexp):
core.Axes(None)
class LabeledTensorTest(test_util.Base):
def setUp(self):
tensor = tf.ones([7, 3, 8, 1])
a0 = ('x', range(7))
a1 = ('channel', ['red', 'green', 'blue'])
a2 = ('y', 8)
a3 = ('z', tf.Dimension(1))
self.lt = core.LabeledTensor(tensor, [a0, a1, a2, a3])
def test_repr(self):
pattern = textwrap.dedent("""\
<LabeledTensor '...' shape=(7, 3, 8, 1) dtype=float32
axes=[('x', ...),
('channel', ...),
('y', Dimension(8)),
('z', Dimension(1))]>""")
regexp = re.escape(pattern).replace(re.escape('...'), '.*')
self.assertRegexpMatches(repr(self.lt), regexp)
def test_reuse_existing_axes(self):
alt_lt = core.LabeledTensor(self.lt.tensor, self.lt.axes)
self.assertLabeledTensorsEqual(alt_lt, self.lt)
def test_reuse_existing_axis_objects(self):
alt_lt = core.LabeledTensor(self.lt.tensor, self.lt.axes.values())
self.assertLabeledTensorsEqual(alt_lt, self.lt)
def test_indexing_scalars(self):
actual = self.lt[:, :, :, 0]
expected = core.LabeledTensor(self.lt.tensor[:, :, :, 0],
list(self.lt.axes.values())[:-1])
self.assertLabeledTensorsEqual(actual, expected)
actual = self.lt[1, :, :, 0]
expected = core.LabeledTensor(self.lt.tensor[1, :, :, 0],
list(self.lt.axes.values())[1:-1])
self.assertLabeledTensorsEqual(actual, expected)
actual = self.lt[1, 2, :, 0]
expected = core.LabeledTensor(self.lt.tensor[1, 2, :, 0],
list(self.lt.axes.values())[2:-1])
self.assertLabeledTensorsEqual(actual, expected)
def test_indexing_1d(self):
lt_1d = self.lt[1, 2, :, 0]
actual = lt_1d[3]
expected = core.LabeledTensor(lt_1d.tensor[3], [])
self.assertLabeledTensorsEqual(actual, expected)
def test_indexing_slices(self):
actual = self.lt[:3, :, :, :]
axes = [('x', range(3))] + list(self.lt.axes.values())[1:]
expected = core.LabeledTensor(self.lt.tensor[:3, :, :, :], axes)
self.assertLabeledTensorsEqual(actual, expected)
def test_invalid_indexing(self):
with self.assertRaises(ValueError):
self.lt[0] # pylint: disable=pointless-statement
with self.assertRaises(ValueError):
self.lt[:, :, :, :, 0] # pylint: disable=pointless-statement
def test_unknown_size(self):
tensor = tf.placeholder(tf.string, [None])
actual = core.LabeledTensor(tensor, ['x'])
self.assertIsNone(actual.axes['x'].size)
self.assertIs(actual.axes['x'].value, tensor.get_shape()[0])
def test_eq(self):
self.assertEqual(self.lt, self.lt)
self.assertNotEqual(self.lt, self.lt.tensor)
self.assertNotEqual(self.lt.tensor, self.lt)
def test_hash(self):
lt1 = self.lt
lt2 = core.LabeledTensor(self.lt.tensor, self.lt.axes)
self.assertEqual(lt1, lt2)
self.assertEqual(hash(lt1), hash(lt2))
def test_name(self):
self.assertEqual(self.lt.name, self.lt.tensor.name)
def test_dtype(self):
self.assertEqual(self.lt.dtype, self.lt.tensor.dtype)
def test_get_shape(self):
self.assertEqual(self.lt.get_shape(), self.lt.tensor.get_shape())
def test_convert_to_tensor(self):
expected = self.lt.tensor
actual = tf.convert_to_tensor(self.lt)
self.assertIs(expected, actual)
class Base(test_util.Base):
def setUp(self):
self.x_size = 7
self.channel_size = 3
self.z_size = 4
self.probs_size = 11
tensor = tf.range(0, self.x_size * self.channel_size * self.z_size *
self.probs_size)
tensor = tf.reshape(tensor, [self.x_size, self.channel_size, self.z_size,
self.probs_size])
a0 = ('x', range(self.x_size))
a1 = ('channel', ['red', 'green', 'blue'])
a2 = 'z'
a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size))
self.tensor = tensor
self.a0 = a0
self.a1 = a1
self.a2 = a2
self.a3 = a3
self.original_lt = core.LabeledTensor(tensor, [a0, a1, a2, a3])
self.x_probs_lt = core.slice_function(self.original_lt, {'z': 0,
'channel': 0})
self.channel_probs_lt = core.slice_function(self.original_lt, {'x': 3,
'z': 0})
class IdentityTest(Base):
def test_name(self):
identity_lt = core.identity(self.original_lt)
self.assertIn('lt_identity', identity_lt.name)
class SliceFunctionTest(Base):
def test_name(self):
select_lt = core.slice_function(self.original_lt, {'channel': 1})
self.assertIn('lt_slice', select_lt.name)
def test_scalar(self):
select_lt = core.slice_function(self.original_lt, {'channel': 1})
golden_lt = core.LabeledTensor(self.tensor[:, 1, :, :], [self.a0, self.a2,
self.a3])
self.assertLabeledTensorsEqual(select_lt, golden_lt)
def test_slice(self):
select_lt = core.slice_function(self.original_lt, {'channel': slice(0, 2)})
a1_sliced = ('channel', ['red', 'green'])
golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :],
[self.a0, a1_sliced, self.a2, self.a3])
self.assertLabeledTensorsEqual(select_lt, golden_lt)
def test_slices(self):
select_lt = core.slice_function(self.original_lt, {'x': slice(1, 5),
'channel': slice(1,
None)})
a0_sliced = ('x', range(1, 5))
a1_sliced = ('channel', ['green', 'blue'])
golden_lt = core.LabeledTensor(self.tensor[1:5, 1:, :, :],
[a0_sliced, a1_sliced, self.a2, self.a3])
self.assertLabeledTensorsEqual(select_lt, golden_lt)
def test_slice_unlabeled(self):
select_lt = core.slice_function(self.original_lt, {'z': slice(1, 3)})
a2_sliced = 'z'
golden_lt = core.LabeledTensor(self.tensor[:, :, 1:3, :],
[self.a0, self.a1, a2_sliced, self.a3])
self.assertLabeledTensorsEqual(select_lt, golden_lt)
def test_slice_unknown_shape(self):
lt = core.LabeledTensor(tf.placeholder(tf.float32, [None, 1]), ['x', 'y'])
sliced_lt = core.slice_function(lt, {'y': 0})
self.assertEqual(list(sliced_lt.axes.values()), [lt.axes['x']])
class TransposeTest(Base):
def test_name(self):
transpose_lt = core.transpose(self.original_lt,
self.original_lt.axes.keys())
self.assertIn('lt_transpose', transpose_lt.name)
def test_identity(self):
transpose_lt = core.transpose(self.original_lt,
self.original_lt.axes.keys())
golden_lt = self.original_lt
self.assertLabeledTensorsEqual(transpose_lt, golden_lt)
def test(self):
transpose_lt = core.transpose(self.original_lt, ['z', 'channel', 'x',
'probs'])
golden_lt = core.LabeledTensor(
tf.transpose(self.tensor, [2, 1, 0, 3]),
[self.a2, self.a1, self.a0, self.a3])
self.assertLabeledTensorsEqual(transpose_lt, golden_lt)
def test_default_axis_order(self):
transpose_lt = core.transpose(self.original_lt)
golden_lt = core.LabeledTensor(
tf.transpose(self.tensor, [3, 2, 1, 0]),
list(reversed(list(self.original_lt.axes.values()))))
self.assertLabeledTensorsEqual(transpose_lt, golden_lt)
def test_invalid_input(self):
with self.assertRaises(ValueError):
core.transpose(self.original_lt, ['channel', 'x', 'probs'])
with self.assertRaises(ValueError):
core.transpose(self.original_lt, ['z', 'foo', 'x', 'probs'])
class ExpandDimsTest(Base):
def test_name(self):
expand_lt = core.expand_dims(self.original_lt, self.original_lt.axes.keys())
self.assertIn('lt_expand', expand_lt.name)
def test_identity(self):
expand_lt = core.expand_dims(self.original_lt, self.original_lt.axes.keys())
golden_lt = self.original_lt
self.assertLabeledTensorsEqual(expand_lt, golden_lt)
def test(self):
expand_lt = core.expand_dims(self.original_lt, ['foo', 'x', 'bar',
'channel', 'z', 'probs',
'grok'])
golden_lt = core.LabeledTensor(
tf.reshape(self.tensor, [1, self.x_size, 1, self.channel_size,
self.z_size, self.probs_size, 1]),
['foo', self.a0, 'bar', self.a1, self.a2, self.a3, 'grok'])
self.assertLabeledTensorsEqual(expand_lt, golden_lt)
def test_label(self):
expand_lt = core.expand_dims(self.original_lt, ['x',
'channel',
('foo', 'bar'),
'z',
'probs',])
golden_lt = core.LabeledTensor(
tf.reshape(self.tensor, [self.x_size, self.channel_size, 1, self.z_size,
self.probs_size]),
[self.a0, self.a1, ('foo', ['bar']), self.a2, self.a3])
self.assertLabeledTensorsEqual(expand_lt, golden_lt)
def test_unknown_dimension(self):
orig_lt = core.LabeledTensor(tf.placeholder(tf.float32, [None]), ['x'])
expand_lt = core.expand_dims(orig_lt, ['x', 'y'])
self.assertEqual(expand_lt.axes, core.Axes([('x', None), ('y', 1)]))
def test_invalid_input(self):
with self.assertRaises(core.AxisOrderError):
core.expand_dims(self.original_lt, ['foo', 'not_x', 'bar', 'channel', 'z',
'probs', 'grok'])
with self.assertRaises(core.AxisOrderError):
core.expand_dims(self.original_lt, ['foo', 'z', 'bar', 'channel', 'x',
'probs', 'grok'])
class AxisOrderScopeTest(Base):
def test(self):
xyz = ['x', 'y', 'z']
abc = ['a', 'b', 'c']
self.assertIsNone(core.get_axis_order())
with core.axis_order_scope(xyz):
self.assertEqual(core.get_axis_order(), xyz)
with core.axis_order_scope():
self.assertIsNone(core.get_axis_order())
with core.axis_order_scope(abc):
self.assertEqual(core.get_axis_order(), abc)
self.assertIsNone(core.get_axis_order())
self.assertEqual(core.get_axis_order(), xyz)
self.assertIsNone(core.get_axis_order())
class CheckAxisOrderTest(Base):
def test_passes(self):
axis_order = ['w', 'x', 'y', 'z']
lt = core.LabeledTensor(tf.ones((1, 1, 1, 1)), axis_order)
core.check_axis_order(lt, axis_order)
lt = core.LabeledTensor(tf.ones((1, 1, 1)), axis_order[1:])
core.check_axis_order(lt, axis_order)
lt = core.LabeledTensor(tf.ones((1, 1, 1)), axis_order[:-1])
core.check_axis_order(lt, axis_order)
def test_invalid(self):
axis_order = ['w', 'x', 'y', 'z']
lt = core.LabeledTensor(tf.ones((1, 1, 1, 1)), axis_order)
with self.assertRaises(core.AxisOrderError):
core.check_axis_order(lt)
with self.assertRaises(core.AxisOrderError):
core.check_axis_order(lt, axis_order[:-1])
with self.assertRaises(core.AxisOrderError):
core.check_axis_order(lt, axis_order[::-1])
def test_scope(self):
axis_order = ['w', 'x', 'y', 'z']
lt = core.LabeledTensor(tf.ones((1, 1, 1, 1)), axis_order)
with core.axis_order_scope(axis_order):
core.check_axis_order(lt)
class ImposeAxisOrderTest(Base):
def test_identity(self):
axis_order = ['w', 'x', 'y', 'z']
lt = core.LabeledTensor(tf.reshape(tf.range(24), (1, 2, 3, 4)), axis_order)
actual = core.impose_axis_order(lt, axis_order)
self.assertLabeledTensorsEqual(lt, actual)
lt = core.LabeledTensor(tf.reshape(tf.range(6), (1, 2, 3)), axis_order[:3])
actual = core.impose_axis_order(lt, axis_order)
self.assertLabeledTensorsEqual(lt, actual)
def test_reverse(self):
axis_order = ['w', 'x', 'y', 'z']
lt = core.LabeledTensor(tf.reshape(tf.range(24), (1, 2, 3, 4)), axis_order)
actual = core.impose_axis_order(lt, axis_order[::-1])
expected = core.transpose(lt, axis_order[::-1])
self.assertLabeledTensorsEqual(expected, actual)
lt = core.LabeledTensor(tf.reshape(tf.range(6), (1, 2, 3)), axis_order[:3])
actual = core.impose_axis_order(lt, axis_order[::-1])
expected = core.transpose(lt, ['y', 'x', 'w'])
self.assertLabeledTensorsEqual(expected, actual)
def test_scope(self):
axis_order = ['w', 'x', 'y', 'z']
lt = core.LabeledTensor(tf.reshape(tf.range(24), (1, 2, 3, 4)), axis_order)
expected = core.transpose(lt, axis_order[::-1])
with core.axis_order_scope(axis_order[::-1]):
actual = core.impose_axis_order(lt)
self.assertLabeledTensorsEqual(expected, actual)
def test_invalid(self):
lt = core.LabeledTensor(tf.reshape(tf.range(2), (1, 2)), ['x', 'y'])
with self.assertRaises(ValueError):
core.impose_axis_order(lt)
with self.assertRaises(ValueError):
core.impose_axis_order(lt, ['x'])
class FindConsistentOrderingTest(Base):
def test(self):
cases = [
([], [], []),
(['x'], [], ['x']),
([], ['x'], ['x']),
(['x'], ['x'], ['x']),
(['x'], ['y'], ['x', 'y']),
(['y'], ['x'], ['y', 'x']),
(['x', 'y'], ['x', 'y'], ['x', 'y']),
(['x', 'y'], ['y', 'x'], None),
(['x', 'y'], ['y', 'z'], ['x', 'y', 'z']),
(['x', 'z'], ['y', 'z'], ['x', 'y', 'z']),
(['x', 'y'], ['x', 'z'], ['x', 'y', 'z']),
(['w', 'x'], ['y', 'z'], ['w', 'x', 'y', 'z']),
(['x', 'y', 'z'], ['z', 'x'], None),
(['x', 'y', 'z'], ['x'], ['x', 'y', 'z']),
([], ['x', 'y', 'z'], ['x', 'y', 'z']),
]
for a, b, expected in cases:
actual = core._find_consistent_ordering(a, b)
msg = ('unexpected ordering between %r and %r:\nexpected: %r\nactual: %r'
% (a, b, expected, actual))
self.assertEqual(expected, actual, msg=msg)
class AlignTest(Base):
def test_name(self):
align_lt_0, align_lt_1, _ = core.align(self.original_lt, self.original_lt)
self.assertIn('lt_align', align_lt_0.name)
self.assertIn('/0', align_lt_0.name)
self.assertIn('lt_align', align_lt_1.name)
self.assertIn('/1', align_lt_1.name)
def test_identical_shaped_inputs(self):
offset_tensor = self.original_lt.tensor + 1
offset_lt = core.LabeledTensor(offset_tensor, self.original_lt.axes)
align_lt, align_offset_lt, broadcast_axes = core.align(self.original_lt,
offset_lt)
self.assertLabeledTensorsEqual(align_lt, self.original_lt)
self.assertLabeledTensorsEqual(align_offset_lt, offset_lt)
self.assertEqual(broadcast_axes, self.original_lt.axes)
def test_different_inputs(self):
# The correct axis ordering is ['x', 'channel', 'probs'].
align_x_probs_lt, align_channel_probs_lt, broadcast_axes = core.align(
self.x_probs_lt, self.channel_probs_lt)
x_probs_golden_lt = core.LabeledTensor(
tf.reshape(self.x_probs_lt.tensor, [self.x_size, 1, self.probs_size]),
[self.a0, 'channel', self.a3])
self.assertLabeledTensorsEqual(align_x_probs_lt, x_probs_golden_lt)
channel_probs_golden_lt = core.LabeledTensor(
tf.reshape(self.channel_probs_lt.tensor,
[1, self.channel_size, self.probs_size]),
['x', self.a1, self.a3])
self.assertLabeledTensorsEqual(align_channel_probs_lt,
channel_probs_golden_lt)
self.assertEqual(broadcast_axes, core.Axes([self.a0, self.a1, self.a3]))
def test_axis_order_scope(self):
xz_lt = core.LabeledTensor(tf.ones((2, 3)), ['x', 'z'])
yz_lt = core.LabeledTensor(tf.ones((4, 3)), ['y', 'z'])
_, _, broadcast_axes = core.align(xz_lt, yz_lt)
self.assertEqual(list(broadcast_axes.keys()), ['x', 'y', 'z'])
_, _, broadcast_axes = core.align(yz_lt, xz_lt)
self.assertEqual(list(broadcast_axes.keys()), ['y', 'x', 'z'])
with core.axis_order_scope(['x', 'y', 'z']):
_, _, broadcast_axes = core.align(yz_lt, xz_lt)
self.assertEqual(list(broadcast_axes.keys()), ['x', 'y', 'z'])
with core.axis_order_scope(['x', 'y']):
with self.assertRaises(core.AxisOrderError):
core.align(xz_lt, yz_lt)
with self.assertRaises(core.AxisOrderError):
core.align(yz_lt, xz_lt)
def test_invalid_input(self):
lt_0 = core.LabeledTensor(tf.zeros([5]), [('a', range(5))])
lt_1 = core.LabeledTensor(tf.zeros([5]), [('a', range(1, 6))])
with self.assertRaises(ValueError):
core.align(lt_0, lt_1)
class ConvertToLabeledTensorTest(Base):
# TODO(shoyer): Simplify these tests once we can reuse labeled tensors in
# assertLabeledTensorsEqual.
def test_labeled_tensor(self):
actual = core.convert_to_labeled_tensor(self.original_lt)
self.assertLabeledTensorsEqual(actual, self.original_lt)
def test_python_scalar(self):
actual = core.convert_to_labeled_tensor(42)
golden_lt = core.LabeledTensor(tf.convert_to_tensor(42), [])
self.assertLabeledTensorsEqual(actual, golden_lt)
def test_numpy_array(self):
actual = core.convert_to_labeled_tensor(np.array(42))
golden_lt = core.LabeledTensor(tf.convert_to_tensor(42), [])
self.assertLabeledTensorsEqual(actual, golden_lt)
def test_tensor(self):
actual = core.convert_to_labeled_tensor(tf.constant(42))
golden_lt = core.LabeledTensor(tf.convert_to_tensor(42), [])
self.assertLabeledTensorsEqual(actual, golden_lt)
def test_invalid_input(self):
with self.assertRaises(ValueError):
core.convert_to_labeled_tensor(tf.range(5))
with self.assertRaises(ValueError):
core.convert_to_labeled_tensor(np.array([1, 2]))
class DocStringCheckMixin(object):
# requires self.ops to be defined
def test_function_docstring_and_name(self):
for op_name, _, _, lt_op in self.ops:
if lt_op is not None:
self.assertIn('tf.%s' % op_name, lt_op.__doc__)
self.assertEqual(op_name, lt_op.__name__)
class UnaryOpsTestsMixin(object):
# requires self.ops and self.test_lt to be defined
def test_core_op(self):
for op_name, _, tf_op, lt_op in self.ops:
if tf_op is not None:
golden_lt = core.LabeledTensor(tf_op(self.test_lt.tensor),
self.test_lt.axes)
actual_lt = lt_op(self.test_lt)
self.assertIn(op_name, actual_lt.name)
self.assertLabeledTensorsEqual(golden_lt, actual_lt)
def test_infix(self):
for op_name, infix_op, _, _ in self.ops:
if infix_op is not None:
expected_lt = core.LabeledTensor(infix_op(self.test_lt.tensor),
self.test_lt.axes)
actual_lt = infix_op(self.test_lt)
self.assertIn(op_name, actual_lt.name)
self.assertLabeledTensorsEqual(expected_lt, actual_lt)
class CoreUnaryOpsTest(Base, DocStringCheckMixin, UnaryOpsTestsMixin):
def setUp(self):
super(CoreUnaryOpsTest, self).setUp()
self.ops = [
('abs', operator.abs, tf.abs, core.abs_function),
('neg', operator.neg, tf.neg, core.neg),
# TODO(shoyer): add unary + to core TensorFlow
('pos', None, None, None),
('sign', None, tf.sign, core.sign),
('inv', None, tf.inv, core.inv),
('square', None, tf.square, core.square),
('round', None, tf.round, core.round_function),
('sqrt', None, tf.sqrt, core.sqrt),
('rsqrt', None, tf.rsqrt, core.rsqrt),
('log', None, tf.log, core.log),
('exp', None, tf.exp, core.exp),
('log', None, tf.log, core.log),
('ceil', None, tf.ceil, core.ceil),
('floor', None, tf.floor, core.floor),
('cos', None, tf.cos, core.cos),
('sin', None, tf.sin, core.sin),
('tan', None, tf.tan, core.tan),
('acos', None, tf.acos, core.acos),
('asin', None, tf.asin, core.asin),
('atan', None, tf.atan, core.atan),
('lgamma', None, tf.lgamma, core.lgamma),
('digamma', None, tf.digamma, core.digamma),
('erf', None, tf.erf, core.erf),
('erfc', None, tf.erfc, core.erfc),
('lgamma', None, tf.lgamma, core.lgamma),
]
total_size = np.prod([v.size for v in self.original_lt.axes.values()])
self.test_lt = core.LabeledTensor(
tf.cast(self.original_lt, tf.float32) / total_size,
self.original_lt.axes)
class LogicalNotTest(Base, DocStringCheckMixin, UnaryOpsTestsMixin):
def setUp(self):
super(LogicalNotTest, self).setUp()
self.ops = [
('logical_not', operator.invert, tf.logical_not, core.logical_not),
]
self.test_lt = self.original_lt < 10
class BinaryOpsTestsMixin(object):
# requires self.ops, self.test_lt_1, self.test_lt_2, self.test_lt_1_broadcast
# and self.test_lt_2_broadcast to be defined
def test_core_op(self):
for op_name, _, tf_op, lt_op in self.ops:
golden_tensor = tf_op(self.test_lt_1_broadcast,
self.test_lt_2_broadcast)
golden_lt = core.LabeledTensor(golden_tensor, self.broadcast_axes)
actual_lt = lt_op(self.test_lt_1, self.test_lt_2)
self.assertIn(op_name, actual_lt.name)
self.assertLabeledTensorsEqual(golden_lt, actual_lt)
def test_infix(self):
for op_name, infix_op, _, lt_op in self.ops:
if infix_op is not None:
expected_lt = lt_op(self.test_lt_1, self.test_lt_2)
actual_lt = infix_op(self.test_lt_1, self.test_lt_2)
self.assertIn(op_name, actual_lt.name)
self.assertLabeledTensorsEqual(expected_lt, actual_lt)
class CoreBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):
def setUp(self):
super(CoreBinaryOpsTest, self).setUp()
self.x_probs_broadcast_tensor = tf.reshape(
self.x_probs_lt.tensor, [self.x_size, 1, self.probs_size])
self.channel_probs_broadcast_tensor = tf.reshape(
self.channel_probs_lt.tensor, [1, self.channel_size, self.probs_size])
# == and != are not element-wise for tf.Tensor, so they shouldn't be
# elementwise for LabeledTensor, either.
self.ops = [
('add', operator.add, tf.add, core.add),
('sub', operator.sub, tf.sub, core.sub),
('mul', operator.mul, tf.mul, core.mul),
('div', operator.truediv, tf.div, core.div),
('mod', operator.mod, tf.mod, core.mod),
('pow', operator.pow, tf.pow, core.pow_function),
('equal', None, tf.equal, core.equal),
('less', operator.lt, tf.less, core.less),
('less_equal', operator.le, tf.less_equal, core.less_equal),
('not_equal', None, tf.not_equal, core.not_equal),
('greater', operator.gt, tf.greater, core.greater),
('greater_equal', operator.ge, tf.greater_equal, core.greater_equal),
]
self.test_lt_1 = self.x_probs_lt
self.test_lt_2 = self.channel_probs_lt
self.test_lt_1_broadcast = self.x_probs_broadcast_tensor
self.test_lt_2_broadcast = self.channel_probs_broadcast_tensor
self.broadcast_axes = [self.a0, self.a1, self.a3]
def test_reflexive(self):
labeled_tensor = self.x_probs_lt + 1 # all elements must be >0 for division
for op_name, infix_op, _, lt_op in self.ops:
if infix_op is not None:
expected_lt = lt_op(2, labeled_tensor)
actual_lt = infix_op(2, labeled_tensor)
# Python uses greater for the reflexive version of less (and vice versa)
if 'less' in op_name:
op_name = op_name.replace('less', 'greater')
elif 'greater' in op_name:
op_name = op_name.replace('greater', 'less')
self.assertIn(op_name, actual_lt.name)
self.assertLabeledTensorsEqual(expected_lt, actual_lt)
class LogicalBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):
def setUp(self):
super(LogicalBinaryOpsTest, self).setUp()
self.ops = [
('logical_and', operator.and_, tf.logical_and, core.logical_and),
('logical_or', operator.or_, tf.logical_or, core.logical_or),
('logical_xor', operator.xor, tf.logical_xor, core.logical_xor),
]
self.test_lt_1 = self.original_lt < 10
self.test_lt_2 = self.original_lt < 5
self.test_lt_1_broadcast = self.test_lt_1.tensor
self.test_lt_2_broadcast = self.test_lt_2.tensor
self.broadcast_axes = self.test_lt_1.axes
class FloatBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):
def setUp(self):
super(FloatBinaryOpsTest, self).setUp()
self.ops = [
('igamma', None, tf.igamma, core.igamma),
('igammac', None, tf.igammac, core.igammac),
('zeta', None, tf.zeta, core.zeta),
('polygamma', None, tf.polygamma, core.polygamma),
('maximum', None, tf.maximum, core.maximum),
('minimum', None, tf.minimum, core.minimum),
('squared_difference', None, tf.squared_difference,
core.squared_difference),
]
total_size = np.prod([v.size for v in self.original_lt.axes.values()])
test_lt = core.LabeledTensor(
tf.cast(self.original_lt, tf.float32) / total_size,
self.original_lt.axes)
self.test_lt_1 = test_lt
self.test_lt_2 = 1.0 - test_lt
self.test_lt_1_broadcast = self.test_lt_1.tensor
self.test_lt_2_broadcast = self.test_lt_2.tensor
self.broadcast_axes = self.test_lt_1.axes
if __name__ == '__main__':
tf.test.main()


@@ -0,0 +1,178 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input parsing code for LabeledTensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six import string_types
from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
from tensorflow.contrib.labeled_tensor.python.ops import core
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
class FixedLenFeature(object):
"""Configuration for parsing a fixed-length input feature.
Fields:
axes: A list of Axis objects or tuples (axis_name, axis_value),
where `axis_name` is a string and `axis_value` is None (unknown size), an
integer or a list of tick labels.
dtype: Data type of input.
default_value: Value to be used if an example is missing this feature. It
must be compatible with `dtype`.
"""
def __init__(self, axes, dtype, default_value=None):
self._axes = [core.as_axis(a) for a in axes]
self._dtype = dtype
self._default_value = default_value
@property
def axes(self):
return self._axes
@property
def dtype(self):
return self._dtype
@property
def default_value(self):
return self._default_value
@tc.returns(tc.Dict(string_types, parsing_ops.FixedLenFeature))
@tc.accepts(tc.Mapping(string_types, FixedLenFeature))
def _labeled_to_unlabeled_features(features):
"""Convert a dict of lt.FixedLenFeature into a dict of tf.FixedLenFeature."""
unlabeled_features = {}
for name, labeled_feature in features.items():
shape = [ax.size for ax in labeled_feature.axes]
if any(size is None for size in shape):
# This should be caught on the TensorFlow side, but it isn't yet:
# https://github.com/tensorflow/tensorflow/issues/2874
raise ValueError('axes with unknown size are not supported')
dtype = labeled_feature.dtype
default_value = labeled_feature.default_value
unlabeled_features[name] = parsing_ops.FixedLenFeature(
shape, dtype, default_value)
return unlabeled_features
@tc.returns(tc.Dict(string_types, core.LabeledTensor))
@tc.accepts(core.LabeledTensorLike, tc.Mapping(string_types, FixedLenFeature),
tc.Optional(string_types), object)
def parse_example(serialized, features, name=None, example_names=None):
"""Parse `Example` protos into a `dict` of labeled tensors.
See tf.parse_example.
Args:
serialized: A 1-D LabeledTensor of strings, a batch of binary serialized
`Example` protos.
features: A `dict` mapping feature keys to `labeled_tensor.FixedLenFeature`
values.
name: A name for this operation (optional).
example_names: A vector (1-D Tensor) of strings (optional), the names of
the serialized protos in the batch.
Returns:
A `dict` mapping feature keys to `LabeledTensor` values. The single axis
from `serialized` will be prepended to the axes provided by each feature.
Raises:
ValueError: if any feature is invalid.
"""
serialized = core.convert_to_labeled_tensor(serialized)
unlabeled_features = _labeled_to_unlabeled_features(features)
unlabeled_parsed = parsing_ops.parse_example(
serialized.tensor, unlabeled_features, name, example_names)
parsed = {}
for name, parsed_feature in unlabeled_parsed.items():
axes = list(serialized.axes.values()) + features[name].axes
parsed[name] = core.LabeledTensor(parsed_feature, axes)
return parsed
@tc.returns(tc.Dict(string_types, core.LabeledTensor))
@tc.accepts(core.LabeledTensorLike, tc.Mapping(string_types, FixedLenFeature),
tc.Optional(string_types), object)
def parse_single_example(serialized, features, name=None, example_names=None):
"""Parses a single `Example` proto.
See tf.parse_single_example.
Args:
serialized: A scalar string Tensor or LabeledTensor, a single serialized
Example.
features: A `dict` mapping feature keys to `labeled_tensor.FixedLenFeature`
values.
name: A name for this operation (optional).
example_names: (Optional) A scalar string Tensor, the associated name.
Returns:
A `dict` mapping feature keys to `LabeledTensor` values.
Raises:
ValueError: if any feature is invalid.
"""
serialized = core.convert_to_labeled_tensor(serialized)
unlabeled_features = _labeled_to_unlabeled_features(features)
unlabeled_parsed = parsing_ops.parse_single_example(
serialized.tensor, unlabeled_features, name, example_names)
parsed = {}
for name, parsed_feature in unlabeled_parsed.items():
parsed[name] = core.LabeledTensor(parsed_feature, features[name].axes)
return parsed
@tc.returns(core.LabeledTensor)
@tc.accepts(dtypes.DType, tc.Collection(tc.Union(string_types, core.AxisLike)),
tc.Optional(string_types))
def placeholder(dtype, axes, name=None):
"""Create a placeholder for a labeled tensor.
For example:
lt.placeholder(tf.float32, ['batch', ('channel', ['r', 'g', 'b'])])
See tf.placeholder for more details.
Args:
dtype: The type of elements in the tensor to be fed.
axes: sequence of strings (denoting axes of unknown size) and/or objects
convertible to lt.Axis to label the result.
name: Optional op name.
Returns:
Placeholder labeled tensor.
"""
with ops.name_scope(name, 'lt_placeholder', []) as scope:
axes = core.Axes([(axis, None) if isinstance(axis, string_types) else axis
for axis in axes])
shape = [axis.size for axis in axes.values()]
tensor = array_ops.placeholder(dtype, shape, name=scope)
return core.LabeledTensor(tensor, axes)
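
A hedged sketch of the intended flow, combining `placeholder`, `FixedLenFeature`, and `parse_example`; the feature name and axes are illustrative (the test file below exercises the real behavior):

import tensorflow as tf
from tensorflow.contrib import labeled_tensor as lt

# A batch of serialized Example protos with a single, unknown-size 'batch' axis.
serialized = lt.placeholder(tf.string, ['batch'])
features = {
    'rgb': lt.FixedLenFeature(
        [('channel', ['red', 'green', 'blue'])], tf.float32),
}
parsed = lt.parse_example(serialized, features)
# parsed['rgb'] is a LabeledTensor with axes ['batch', 'channel'].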


@@ -0,0 +1,106 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.labeled_tensor.python.ops import core
from tensorflow.contrib.labeled_tensor.python.ops import io_ops
from tensorflow.contrib.labeled_tensor.python.ops import test_util
class ParseBase(test_util.Base):
def setUp(self):
super(ParseBase, self).setUp()
examples = [
tf.train.Example(features=tf.train.Features(feature={
'a': tf.train.Feature(
int64_list=tf.train.Int64List(value=[1])),
'b': tf.train.Feature(
int64_list=tf.train.Int64List(value=[2, 3, 4])),
})),
tf.train.Example(features=tf.train.Features(feature={
'a': tf.train.Feature(
int64_list=tf.train.Int64List(value=[5])),
'b': tf.train.Feature(
int64_list=tf.train.Int64List(value=[6, 7, 8])),
})),
]
self.serialized = core.LabeledTensor(
tf.constant([ex.SerializeToString() for ex in examples]), ['batch'])
self.features = {'a': io_ops.FixedLenFeature([], tf.int64),
'b': io_ops.FixedLenFeature([('x', 3)], tf.int64)}
class TestParseExample(ParseBase):
def test(self):
expected_a = core.LabeledTensor(tf.constant([1, 5]), ['batch'])
expected_b = core.LabeledTensor(tf.constant([[2, 3, 4], [6, 7, 8]]),
['batch', 'x'])
parsed = io_ops.parse_example(self.serialized, self.features)
self.assertLabeledTensorsEqual(expected_a, parsed['a'])
self.assertLabeledTensorsEqual(expected_b, parsed['b'])
def test_placeholder(self):
serialized = core.LabeledTensor(tf.placeholder(tf.string, [None]),
['batch'])
# should not raise
io_ops.parse_example(serialized, self.features)
class TestParseSingleExample(ParseBase):
def test(self):
expected_a = core.LabeledTensor(tf.constant(1), [])
expected_b = core.LabeledTensor(tf.constant([2, 3, 4]), ['x'])
parsed = io_ops.parse_single_example(self.serialized[0], self.features)
self.assertLabeledTensorsEqual(expected_a, parsed['a'])
self.assertLabeledTensorsEqual(expected_b, parsed['b'])
def test_unknown_size(self):
features = {'a': io_ops.FixedLenFeature([('x', None)], tf.int64)}
serialized = tf.placeholder(tf.string, [])
with self.assertRaisesRegexp(ValueError, 'unknown size'):
io_ops.parse_single_example(serialized, features)
class PlaceholderTest(test_util.Base):
def test_name(self):
placeholder_lt = io_ops.placeholder(tf.float32, [])
self.assertIn('lt_placeholder', placeholder_lt.name)
def test(self):
placeholder_lt = io_ops.placeholder(tf.float32,
['batch', ('x', ['a', 'b'])])
self.assertEqual(placeholder_lt.dtype, tf.float32)
self.assertEqual(placeholder_lt.axes,
core.Axes([('batch', None), ('x', ['a', 'b'])]))
def test_feed(self):
sess = tf.Session()
placeholder_lt = io_ops.placeholder(tf.float32, [])
two_times = 2.0 * placeholder_lt
result = sess.run(two_times, {placeholder_lt.tensor: 1})
self.assertEqual(result, 2.0)
if __name__ == '__main__':
tf.test.main()


@@ -0,0 +1,42 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Neural network ops for LabeledTensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.labeled_tensor.python.ops import core
from tensorflow.python.ops import nn
relu = core.define_unary_op('relu', nn.relu)
relu6 = core.define_unary_op('relu6', nn.relu6)
crelu = core.define_unary_op('crelu', nn.crelu)
elu = core.define_unary_op('elu', nn.elu)
softplus = core.define_unary_op('softplus', nn.softplus)
l2_loss = core.define_unary_op('l2_loss', nn.l2_loss)
sigmoid_cross_entropy_with_logits = core.define_binary_op(
'sigmoid_cross_entropy_with_logits',
nn.sigmoid_cross_entropy_with_logits)
softmax = core.define_unary_op('softmax', nn.softmax)
log_softmax = core.define_unary_op('log_softmax', nn.log_softmax)
softmax_cross_entropy_with_logits = core.define_binary_op(
'softmax_cross_entropy_with_logits',
nn.softmax_cross_entropy_with_logits)
sparse_softmax_cross_entropy_with_logits = core.define_binary_op(
'sparse_softmax_cross_entropy_with_logits',
nn.sparse_softmax_cross_entropy_with_logits)
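
Each wrapper applies the corresponding `tf.nn` op and keeps the labeled axes of its input; a small hypothetical sketch:

import tensorflow as tf
from tensorflow.contrib import labeled_tensor as lt

logits = lt.LabeledTensor(tf.constant([-1.0, 0.0, 2.0]), ['x'])
activations = lt.nn.relu(logits)   # elementwise relu; the 'x' axis is preserved
probs = lt.nn.softmax(logits)      # also returns a LabeledTensor with axis 'x'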


@@ -0,0 +1,70 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.labeled_tensor.python.ops import core
from tensorflow.contrib.labeled_tensor.python.ops import nn
from tensorflow.contrib.labeled_tensor.python.ops import test_util
class NNTests(test_util.Base):
def setUp(self):
super(NNTests, self).setUp()
self.axes = ['x']
self.original_lt = core.LabeledTensor([0.0, 0.5, 1.0], self.axes)
self.other_lt = 1 - self.original_lt
def test_unary_ops(self):
ops = [
('relu', tf.nn.relu, nn.relu),
('relu6', tf.nn.relu6, nn.relu6),
('crelu', tf.nn.crelu, nn.crelu),
('elu', tf.nn.elu, nn.elu),
('softplus', tf.nn.softplus, nn.softplus),
('l2_loss', tf.nn.l2_loss, nn.l2_loss),
('softmax', tf.nn.softmax, nn.softmax),
('log_softmax', tf.nn.log_softmax, nn.log_softmax),
]
for op_name, tf_op, lt_op in ops:
golden_tensor = tf_op(self.original_lt.tensor)
golden_lt = core.LabeledTensor(golden_tensor, self.axes)
actual_lt = lt_op(self.original_lt)
self.assertIn(op_name, actual_lt.name)
self.assertLabeledTensorsEqual(golden_lt, actual_lt)
def test_binary_ops(self):
ops = [
('sigmoid_cross_entropy_with_logits',
tf.nn.sigmoid_cross_entropy_with_logits,
nn.sigmoid_cross_entropy_with_logits),
('softmax_cross_entropy_with_logits',
tf.nn.softmax_cross_entropy_with_logits,
nn.softmax_cross_entropy_with_logits),
('sparse_softmax_cross_entropy_with_logits',
tf.nn.sparse_softmax_cross_entropy_with_logits,
nn.sparse_softmax_cross_entropy_with_logits),
]
for op_name, tf_op, lt_op in ops:
golden_tensor = tf_op(self.original_lt.tensor, self.other_lt.tensor)
golden_lt = core.LabeledTensor(golden_tensor, self.axes)
actual_lt = lt_op(self.original_lt, self.other_lt)
self.assertIn(op_name, actual_lt.name)
self.assertLabeledTensorsEqual(golden_lt, actual_lt)

File diff suppressed because it is too large

View File

@ -0,0 +1,918 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import range # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.labeled_tensor.python.ops import core
from tensorflow.contrib.labeled_tensor.python.ops import ops
from tensorflow.contrib.labeled_tensor.python.ops import test_util
class Base(test_util.Base):
def setUp(self):
super(Base, self).setUp()
self.x_size = 7
self.channel_size = 3
self.z_size = 4
self.probs_size = 11
tensor = tf.range(0, self.x_size * self.channel_size * self.z_size *
self.probs_size)
tensor = tf.reshape(tensor, [self.x_size, self.channel_size, self.z_size,
self.probs_size])
a0 = ('x', range(self.x_size))
a1 = ('channel', ['red', 'green', 'blue'])
a2 = 'z'
a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size))
self.tensor = tensor
self.a0 = a0
self.a1 = a1
self.a2 = a2
self.a2_resolved = ('z', self.z_size)
self.a3 = a3
self.original_lt = core.LabeledTensor(tensor, [a0, a1, a2, a3])
self.x_probs_lt = core.slice_function(self.original_lt, {'z': 0})
self.x_probs_lt = ops.select(self.x_probs_lt, {'channel': 'red'})
self.channel_probs_lt = core.slice_function(self.original_lt, {'x': 3,
'z': 0})
class SelectTest(Base):
def test_name(self):
select_lt = ops.select(self.original_lt, {'channel': 'green'})
self.assertIn('lt_select', select_lt.name)
def test_scalar(self):
select_lt = ops.select(self.original_lt, {'channel': 'green'})
golden_lt = core.LabeledTensor(self.tensor[:, 1, :, :], [self.a0, self.a2,
self.a3])
self.assertLabeledTensorsEqual(select_lt, golden_lt)
def test_slice(self):
select_lt = ops.select(self.original_lt, {'channel': slice('red', 'green')})
a1_sliced = ('channel', ['red', 'green'])
golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :],
[self.a0, a1_sliced, self.a2, self.a3])
self.assertLabeledTensorsEqual(select_lt, golden_lt)
def test_slices(self):
select_lt = ops.select(self.original_lt, {'x': slice(1, 4),
'channel': slice('green', None)})
a0_sliced = ('x', range(1, 5))
a1_sliced = ('channel', ['green', 'blue'])
golden_lt = core.LabeledTensor(self.tensor[1:5, 1:, :, :],
[a0_sliced, a1_sliced, self.a2, self.a3])
self.assertLabeledTensorsEqual(select_lt, golden_lt)
def test_list(self):
select_lt = ops.select(self.original_lt, {'channel': ['red', 'green']})
a1_sliced = ('channel', ['red', 'green'])
golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :],
[self.a0, a1_sliced, self.a2, self.a3])
self.assertLabeledTensorsEqual(select_lt, golden_lt)
def test_list_one_item(self):
select_lt = ops.select(self.original_lt, {'channel': ['red']})
a1_sliced = ('channel', ['red'])
golden_lt = core.LabeledTensor(self.tensor[:, :1, :, :],
[self.a0, a1_sliced, self.a2, self.a3])
self.assertLabeledTensorsEqual(select_lt, golden_lt)
def test_list_zero_items(self):
select_lt = ops.select(self.original_lt, {'channel': []})
golden_lt = core.LabeledTensor(self.tensor[:, :0, :, :],
[self.a0, 'channel', self.a2, self.a3])
self.assertLabeledTensorsEqual(select_lt, golden_lt)
def test_scalars(self):
select_lt = ops.select(self.original_lt, {'x': 1, 'channel': 'green'})
golden_lt = core.LabeledTensor(self.tensor[1, 1, :, :],
[self.a2, self.a3])
self.assertLabeledTensorsEqual(select_lt, golden_lt)
def test_invalid_input(self):
with self.assertRaises(ValueError):
ops.select(self.original_lt, {'foo': 1})
with self.assertRaises(ValueError):
ops.select(self.original_lt, {'z': 1})
with self.assertRaises(KeyError):
ops.select(self.original_lt, {'channel': 'purple'})
with self.assertRaises(KeyError):
ops.select(self.original_lt, {'channel': ['red', 'purple']})
with self.assertRaises(NotImplementedError):
ops.select(self.original_lt, {'channel': ['red'], 'x': [1]})
with self.assertRaises(NotImplementedError):
ops.select(self.original_lt, {'channel': ['red'], 'x': 1})
with self.assertRaises(NotImplementedError):
ops.select(self.original_lt, {'channel': slice('red', 'green', 2)})
class ConcatTest(Base):
def setUp(self):
super(ConcatTest, self).setUp()
self.red_lt = ops.select(self.original_lt, {'channel': ['red']})
self.green_lt = ops.select(self.original_lt, {'channel': ['green']})
self.blue_lt = ops.select(self.original_lt, {'channel': ['blue']})
def test_name(self):
concat_lt = ops.concat([self.red_lt, self.blue_lt], 'channel')
self.assertIn('lt_concat', concat_lt.name)
def test(self):
concat_lt = ops.concat([self.red_lt, self.green_lt], 'channel')
golden_lt = ops.select(self.original_lt, {'channel': ['red', 'green']})
self.assertLabeledTensorsEqual(concat_lt, golden_lt)
def test_transposed(self):
green_transposed = core.transpose(self.green_lt,
['probs', 'channel', 'z', 'x'])
with self.assertRaises(ValueError):
ops.concat([self.red_lt, green_transposed], 'channel')
def test_invalid_input(self):
with self.assertRaises(ValueError):
ops.concat([], 'channel')
with self.assertRaises(ValueError):
ops.concat([self.red_lt, self.red_lt], 'channel')
with self.assertRaises(ValueError):
ops.concat([self.red_lt, self.red_lt], 'foo')
class PackTest(Base):
def test_name(self):
pack_lt = ops.pack([self.original_lt, self.original_lt], 'batch')
self.assertIn('lt_pack', pack_lt.name)
def test(self):
pack_lt = ops.pack([self.original_lt, self.original_lt], 'batch')
golden_lt = core.LabeledTensor(
tf.stack([self.original_lt.tensor, self.original_lt.tensor]),
['batch', self.a0, self.a1, self.a2, self.a3])
self.assertLabeledTensorsEqual(pack_lt, golden_lt)
def test_axis(self):
pack_lt = ops.pack([self.original_lt, self.original_lt],
new_axis='batch',
axis_position=4)
golden_lt = core.LabeledTensor(
tf.stack(
[self.original_lt.tensor, self.original_lt.tensor], axis=4),
[self.a0, self.a1, self.a2, self.a3, 'batch'])
self.assertLabeledTensorsEqual(pack_lt, golden_lt)
def test_invalid_input(self):
with self.assertRaises(ValueError):
ops.pack([self.original_lt, self.original_lt], 'channel')
class UnpackTest(Base):
def test_name(self):
unpack_lts = ops.unpack(self.original_lt)
for t in unpack_lts:
self.assertIn('lt_unpack', t.name)
def test(self):
unpack_lt = ops.unpack(self.original_lt)[0]
golden_lt = core.LabeledTensor(
tf.unstack(self.original_lt.tensor)[0], [self.a1, self.a2, self.a3])
self.assertLabeledTensorsEqual(unpack_lt, golden_lt)
def test_axis(self):
unpack_lt = ops.unpack(self.original_lt, axis_name='z')[0]
golden_lt = core.LabeledTensor(
tf.unstack(
self.original_lt.tensor, axis=2)[0], [self.a0, self.a1, self.a3])
self.assertLabeledTensorsEqual(unpack_lt, golden_lt)
def test_invalid_input(self):
with self.assertRaises(ValueError):
ops.unpack(self.original_lt, axis_name='not_found')
class ReshapeTest(Base):
def test_name(self):
reshape_lt = ops.reshape(self.original_lt, ['channel'], ['foo'])
self.assertIn('lt_reshape', reshape_lt.name)
def test_identity(self):
reshape_lt = ops.reshape(self.original_lt, self.original_lt.axes.keys(),
self.original_lt.axes.values())
self.assertLabeledTensorsEqual(reshape_lt, self.original_lt)
def test_known_size(self):
new_dim_size = self.channel_size * self.z_size * self.probs_size
reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
[('new_dim', new_dim_size)])
golden_lt = core.LabeledTensor(
tf.reshape(self.original_lt.tensor, [self.x_size, -1]),
[self.original_lt.axes['x'], 'new_dim'])
self.assertLabeledTensorsEqual(reshape_lt, golden_lt)
def test_unknown_size(self):
reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
['new_dim'])
golden_lt = core.LabeledTensor(
tf.reshape(self.original_lt.tensor, [self.x_size, -1]),
[self.original_lt.axes['x'], 'new_dim'])
self.assertLabeledTensorsEqual(reshape_lt, golden_lt)
def test_unknown_dimension(self):
orig_lt = core.LabeledTensor(tf.placeholder(tf.float32, [None]), ['x'])
reshape_lt = ops.reshape(orig_lt, ['x'], ['y', ('z', 1)])
self.assertEqual(reshape_lt.axes, core.Axes([('y', None), ('z', 1)]))
with self.test_session() as sess:
result = sess.run(reshape_lt, feed_dict={orig_lt.tensor: [1, 2]})
np.testing.assert_array_equal(result, [[1], [2]])
def test_with_labels(self):
new_dim_size = self.channel_size * self.z_size * self.probs_size
reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
[('new_dim', range(new_dim_size))])
golden_lt = core.LabeledTensor(
tf.reshape(self.original_lt.tensor, [self.x_size, -1]),
[self.original_lt.axes['x'], ('new_dim', range(new_dim_size))])
self.assertLabeledTensorsEqual(reshape_lt, golden_lt)
def test_invalid_input(self):
with self.assertRaisesRegexp(ValueError, 'not contained in the set'):
ops.reshape(self.original_lt, ['foo'], ['bar'])
with self.assertRaisesRegexp(core.AxisOrderError,
'not a slice of axis names'):
ops.reshape(self.original_lt, ['probs', 'z'], ['bar'])
with self.assertRaisesRegexp(ValueError, 'at most one axis in new_axes'):
ops.reshape(self.original_lt, ['probs'], ['foo', 'bar'])
class RenameAxisTest(Base):
def test_name(self):
rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'foo')
self.assertIn('lt_rename_axis', rename_axis_lt.name)
def test_identity(self):
rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'channel')
self.assertLabeledTensorsEqual(rename_axis_lt, self.original_lt)
def test_new_name(self):
rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'foo')
expected_axes = [(name if name != 'channel' else 'foo', axis.value)
for name, axis in self.original_lt.axes.items()]
expected_lt = core.LabeledTensor(self.original_lt.tensor, expected_axes)
self.assertLabeledTensorsEqual(rename_axis_lt, expected_lt)
def test_invalid_input(self):
with self.assertRaisesRegexp(ValueError, 'not contained in the set'):
ops.rename_axis(self.original_lt, 'foo', 'bar')
class BatchTest(Base):
def setUp(self):
super(BatchTest, self).setUp()
tensors = []
for i in range(10):
offset_lt = core.LabeledTensor(tf.constant(i), [])
tensors.append(core.add(self.original_lt, offset_lt))
self.pack_lt = ops.pack(tensors, 'batch')
def test_name(self):
batch_ops = ops.batch([self.pack_lt, self.pack_lt],
batch_size=2,
enqueue_many=True)
for bo in batch_ops:
self.assertIn('lt_batch', bo.name)
def test_enqueue_many(self):
[batch_2_op] = ops.batch([self.pack_lt], batch_size=2, enqueue_many=True)
self.assertEqual(len(batch_2_op.axes['batch']), 2)
[batch_10_op] = ops.batch([batch_2_op], batch_size=10, enqueue_many=True)
self.assertLabeledTensorsEqual(self.pack_lt, batch_10_op)
def test_no_enqueue_many(self):
[batch_2_op] = ops.batch([self.original_lt], batch_size=2)
self.assertEqual(len(batch_2_op.axes['batch']), 2)
[batch_10_op] = ops.batch([batch_2_op], batch_size=10, enqueue_many=True)
self.assertLabeledTensorsEqual(
ops.pack(10 * [self.original_lt], 'batch'), batch_10_op)
def test_invalid_input(self):
with self.assertRaises(ValueError):
ops.batch([self.original_lt], 3, enqueue_many=True)
def test_allow_smaller_final_batch(self):
[batch_2_op] = ops.batch([self.original_lt], batch_size=2,
allow_smaller_final_batch=True)
self.assertEqual(batch_2_op.axes['batch'].size, None)
class ShuffleBatchTest(Base):
def setUp(self):
super(ShuffleBatchTest, self).setUp()
tensors = []
for i in range(10):
offset_lt = core.LabeledTensor(tf.constant(i), [])
tensors.append(core.add(self.original_lt, offset_lt))
self.pack_lt = ops.pack(tensors, 'batch')
def test_name(self):
batch_lts = ops.shuffle_batch([self.pack_lt, self.pack_lt],
batch_size=2,
enqueue_many=True)
for blt in batch_lts:
self.assertIn('lt_shuffle_batch', blt.name)
def test_enqueue_many(self):
[batch_2_lt] = ops.shuffle_batch([self.pack_lt],
batch_size=2,
enqueue_many=True,
min_after_dequeue=8,
seed=0)
self.assertEqual(len(batch_2_lt.axes['batch']), 2)
[batch_10_lt] = ops.batch([batch_2_lt], batch_size=10, enqueue_many=True)
self.assertEqual(batch_10_lt.axes, self.pack_lt.axes)
[batch_10, pack] = self.eval([batch_10_lt.tensor, self.pack_lt.tensor])
self.assertFalse((batch_10 == pack).all())
def test_allow_smaller_final_batch(self):
[batch_2_op] = ops.shuffle_batch([self.original_lt], batch_size=2,
allow_smaller_final_batch=True)
self.assertEqual(batch_2_op.axes['batch'].size, None)
class RandomCropTest(Base):
def test_name(self):
crop_lt = ops.random_crop(self.original_lt, {'probs': 3})
self.assertIn('lt_random_crop', crop_lt.name)
def test_single(self):
crop_lt = ops.random_crop(self.original_lt, {'probs': 3})
self.assertEqual(
core.Axes([self.a0, self.a1, self.a2_resolved, ('probs', 3)]),
crop_lt.axes)
def test_double(self):
crop_lt = ops.random_crop(self.original_lt, {'probs': 3, 'channel': 2})
self.assertEqual(
core.Axes([self.a0, ('channel', 2), self.a2_resolved, ('probs', 3)]),
crop_lt.axes)
def test_size1(self):
crop_lt = ops.random_crop(self.original_lt, {'probs': 1})
self.assertEqual(
core.Axes([self.a0, self.a1, self.a2_resolved, ('probs', 1)]),
crop_lt.axes)
def test_different_seeds(self):
crop_0_lt = ops.random_crop(self.original_lt, {'probs': 3,
'channel': 2},
seed=0)
crop_1_lt = ops.random_crop(self.original_lt, {'probs': 3,
'channel': 2},
seed=1)
self.assertEqual(crop_0_lt.axes, crop_1_lt.axes)
[crop_0, crop_1] = self.eval([crop_0_lt.tensor, crop_1_lt.tensor])
self.assertFalse((crop_0 == crop_1).all())
def test_identical_seeds(self):
crop_0_lt = ops.random_crop(self.original_lt, {'probs': 3,
'channel': 2},
seed=0)
crop_1_lt = ops.random_crop(self.original_lt, {'probs': 3,
'channel': 2},
seed=0)
self.assertLabeledTensorsEqual(crop_0_lt, crop_1_lt)
def test_crop_idempotent(self):
crop_0_lt = ops.random_crop(self.original_lt, {'probs': 3,
'channel': 2},
seed=0)
crop_1_lt = ops.random_crop(crop_0_lt, {'probs': 3, 'channel': 2}, seed=1)
self.assertLabeledTensorsEqual(crop_0_lt, crop_1_lt)
def test_invalid_input(self):
with self.assertRaises(ValueError):
ops.random_crop(self.original_lt, {'foobar': 2})
class MapFnTest(Base):
def test_name(self):
map_lt = ops.map_fn(core.identity, self.original_lt)
self.assertIn('lt_map_fn', map_lt.name)
def test_identity(self):
map_lt = ops.map_fn(core.identity, self.original_lt)
self.assertLabeledTensorsEqual(map_lt, self.original_lt)
def test_callable_object(self):
class Identity(object):
def __call__(self, other):
return other
map_lt = ops.map_fn(Identity(), self.original_lt)
self.assertLabeledTensorsEqual(map_lt, self.original_lt)
def test_slice(self):
map_lt = ops.map_fn(lambda t: core.slice_function(t, {'channel': 1}),
self.original_lt)
slice_lt = core.slice_function(self.original_lt, {'channel': 1})
self.assertLabeledTensorsEqual(map_lt, slice_lt)
class SqueezeTest(Base):
def setUp(self):
super(SqueezeTest, self).setUp()
self.squeezable_lt = core.slice_function(self.original_lt,
{'channel': slice(0, 1),
'probs': slice(0, 1)})
def test_name(self):
squeeze_lt = ops.squeeze(self.squeezable_lt)
self.assertIn('lt_squeeze', squeeze_lt.name)
def test_none(self):
none_lt = ops.squeeze(self.squeezable_lt, None)
axes_lt = ops.squeeze(self.squeezable_lt, ['channel', 'probs'])
self.assertLabeledTensorsEqual(none_lt, axes_lt)
def test(self):
squeeze_lt = ops.squeeze(self.squeezable_lt, ['probs'])
golden_lt = core.slice_function(self.squeezable_lt, {'probs': 0})
self.assertLabeledTensorsEqual(squeeze_lt, golden_lt)
def test_invalid_input(self):
with self.assertRaises(ValueError):
ops.squeeze(self.original_lt, ['channel'])
with self.assertRaises(ValueError):
ops.squeeze(self.squeezable_lt, ['foo'])
class MatMulTest(Base):
def test_name(self):
x_lt = core.LabeledTensor(tf.ones((3,)), ['x'])
matmul_lt = ops.matmul(x_lt, x_lt)
self.assertIn('lt_matmul', matmul_lt.name)
def test_vector_vector(self):
x_lt = core.LabeledTensor(tf.range(3), ['x'])
matmul_lt = ops.matmul(x_lt, x_lt)
golden_lt = core.convert_to_labeled_tensor(5)
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
def test_matrix_vector(self):
xy_lt = core.LabeledTensor(tf.reshape(tf.range(6), (2, 3)), ['x', 'y'])
y_lt = core.LabeledTensor(tf.range(3), ['y'])
matmul_lt = ops.matmul(xy_lt, y_lt)
golden_lt = core.LabeledTensor(
tf.matmul(xy_lt.tensor, tf.reshape(y_lt.tensor, (-1, 1)))[:, 0], ['x'])
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
matmul_lt = ops.matmul(y_lt, xy_lt)
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
def test_matrix_matrix(self):
xy_lt = core.LabeledTensor(tf.reshape(tf.range(6), (2, 3)), ['x', 'y'])
yz_lt = core.LabeledTensor(tf.reshape(tf.range(12), (3, 4)), ['y', 'z'])
matmul_lt = ops.matmul(xy_lt, yz_lt)
golden_lt = core.LabeledTensor(
tf.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z'])
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
transpose = lambda x: core.transpose(x, list(x.axes.keys())[::-1])
matmul_lt = ops.matmul(xy_lt, transpose(yz_lt))
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
matmul_lt = ops.matmul(transpose(xy_lt), yz_lt)
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
matmul_lt = ops.matmul(transpose(xy_lt), transpose(yz_lt))
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
matmul_lt = ops.matmul(yz_lt, xy_lt)
self.assertLabeledTensorsEqual(matmul_lt, transpose(golden_lt))
def test_matrix_matrix_axis_order(self):
xy_lt = core.LabeledTensor(tf.reshape(tf.range(6), (2, 3)), ['x', 'y'])
yz_lt = core.LabeledTensor(tf.reshape(tf.range(12), (3, 4)), ['y', 'z'])
golden_lt = core.LabeledTensor(
tf.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z'])
with core.axis_order_scope(['x', 'y', 'z']):
matmul_lt = ops.matmul(xy_lt, yz_lt)
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
matmul_lt = ops.matmul(yz_lt, xy_lt)
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
def test_invalid(self):
scalar_lt = core.LabeledTensor(tf.ones(()), [])
x_lt = core.LabeledTensor(tf.ones((2,)), ['x'])
x2_lt = core.LabeledTensor(tf.ones((3,)), ['x'])
y_lt = core.LabeledTensor(tf.ones((3,)), ['y'])
xy_lt = core.LabeledTensor(tf.ones((2, 3)), ['x', 'y'])
xyz_lt = core.LabeledTensor(tf.ones((2, 3, 1)), ['x', 'y', 'z'])
with self.assertRaisesRegexp(ValueError, 'inputs with at least rank'):
ops.matmul(x_lt, scalar_lt)
with self.assertRaises(NotImplementedError):
ops.matmul(x_lt, xyz_lt)
with self.assertRaisesRegexp(ValueError, 'exactly one axis in common'):
ops.matmul(x_lt, y_lt)
with self.assertRaises(NotImplementedError):
ops.matmul(xy_lt, xy_lt)
with self.assertRaisesRegexp(ValueError, 'does not match'):
ops.matmul(x_lt, x2_lt)
class ReduceSumTest(Base):
def test_name(self):
sum_lt = ops.reduce_sum(self.original_lt, {'channel'})
self.assertIn('lt_reduce_sum', sum_lt.name)
def test_drop_axis(self):
sum_lt = ops.reduce_sum(self.original_lt, {'channel'})
golden_lt = core.LabeledTensor(
tf.reduce_sum(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
self.assertLabeledTensorsEqual(sum_lt, golden_lt)
def test_drop_scalar_axis(self):
sum_lt = ops.reduce_sum(self.original_lt, 'channel')
golden_lt = core.LabeledTensor(
tf.reduce_sum(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
self.assertLabeledTensorsEqual(sum_lt, golden_lt)
def test_keep_axis(self):
sum_lt = ops.reduce_sum(self.original_lt, {('channel', 'hihowareyou')})
golden_lt = core.LabeledTensor(
tf.reduce_sum(self.original_lt.tensor,
1, keep_dims=True),
[self.a0, ('channel', ['hihowareyou']), self.a2, self.a3])
self.assertLabeledTensorsEqual(sum_lt, golden_lt)
def test_keep_scalar_axis(self):
sum_lt = ops.reduce_sum(self.original_lt, ('channel', 'hihowareyou'))
golden_lt = core.LabeledTensor(
tf.reduce_sum(self.original_lt.tensor,
1, keep_dims=True),
[self.a0, ('channel', ['hihowareyou']), self.a2, self.a3])
self.assertLabeledTensorsEqual(sum_lt, golden_lt)
def test_scalar(self):
scalar_lt = core.LabeledTensor(tf.constant(42), [])
reduce_lt = ops.reduce_sum(scalar_lt, [])
self.assertLabeledTensorsEqual(reduce_lt, scalar_lt)
def test_empty_list(self):
reduce_lt = ops.reduce_sum(self.original_lt, [])
self.assertLabeledTensorsEqual(reduce_lt, self.original_lt)
def test_none(self):
sum_lt = ops.reduce_sum(self.original_lt)
golden_lt = core.LabeledTensor(tf.reduce_sum(self.original_lt.tensor), [])
self.assertLabeledTensorsEqual(sum_lt, golden_lt)
def test_function_docstring_and_name(self):
self.assertIn('tf.reduce_sum', ops.reduce_sum.__doc__)
self.assertEqual('reduce_sum', ops.reduce_sum.__name__)
class ReduceMeanTest(Base):
def test_name(self):
actual_lt = ops.reduce_mean(self.original_lt, {'channel'})
self.assertIn('lt_reduce_mean', actual_lt.name)
def test(self):
actual_lt = ops.reduce_mean(self.original_lt, {'channel'})
golden_lt = core.LabeledTensor(
tf.reduce_mean(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
self.assertLabeledTensorsEqual(actual_lt, golden_lt)
class ReduceProdTest(Base):
def test_name(self):
result_lt = ops.reduce_prod(self.original_lt, {'channel'})
self.assertIn('lt_reduce_prod', result_lt.name)
def test(self):
result_lt = ops.reduce_prod(self.original_lt, {'channel'})
golden_lt = core.LabeledTensor(
tf.reduce_prod(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
self.assertLabeledTensorsEqual(result_lt, golden_lt)
class ReduceMinTest(Base):
def test_name(self):
result_lt = ops.reduce_min(self.original_lt, {'channel'})
self.assertIn('lt_reduce_min', result_lt.name)
def test(self):
result_lt = ops.reduce_min(self.original_lt, {'channel'})
golden_lt = core.LabeledTensor(
tf.reduce_min(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
self.assertLabeledTensorsEqual(result_lt, golden_lt)
class ReduceMaxTest(Base):
def test_name(self):
result_lt = ops.reduce_max(self.original_lt, {'channel'})
self.assertIn('lt_reduce_max', result_lt.name)
def test(self):
result_lt = ops.reduce_max(self.original_lt, {'channel'})
golden_lt = core.LabeledTensor(
tf.reduce_max(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
self.assertLabeledTensorsEqual(result_lt, golden_lt)
class BaseReduceBoolean(Base):
def setUp(self):
super(BaseReduceBoolean, self).setUp()
self.bool_tensor = tf.cast(self.original_lt.tensor > 5, tf.bool)
self.bool_lt = core.LabeledTensor(self.bool_tensor, self.original_lt.axes)
class ReduceAllTest(BaseReduceBoolean):
def test_name(self):
result_lt = ops.reduce_all(self.bool_lt, {'channel'})
self.assertIn('lt_reduce_all', result_lt.name)
def test(self):
result_lt = ops.reduce_all(self.bool_lt, {'channel'})
golden_lt = core.LabeledTensor(
tf.reduce_all(self.bool_tensor, 1), [self.a0, self.a2, self.a3])
self.assertLabeledTensorsEqual(result_lt, golden_lt)
class ReduceAnyTest(BaseReduceBoolean):
def test_name(self):
result_lt = ops.reduce_any(self.bool_lt, {'channel'})
self.assertIn('lt_reduce_any', result_lt.name)
def test(self):
result_lt = ops.reduce_any(self.bool_lt, {'channel'})
golden_lt = core.LabeledTensor(
tf.reduce_any(self.bool_tensor, 1), [self.a0, self.a2, self.a3])
self.assertLabeledTensorsEqual(result_lt, golden_lt)
class TileTest(Base):
def test_name(self):
tile_lt = ops.tile(self.original_lt, {'z': 2})
self.assertIn('lt_tile', tile_lt.name)
def test(self):
for multiple in [2, tf.constant(2)]:
tile_lt = ops.tile(self.original_lt, {'z': multiple})
golden_op = tf.tile(self.original_lt.tensor, [1, 1, multiple, 1])
golden_axes = ['z' if axis.name == 'z' else axis
for axis in self.original_lt.axes.values()]
golden_lt = core.LabeledTensor(golden_op, golden_axes)
self.assertLabeledTensorsEqual(tile_lt, golden_lt)
def test_invalid_input(self):
with self.assertRaisesRegexp(ValueError, 'are not contained in the set'):
ops.tile(self.original_lt, {'foo': 5})
with self.assertRaisesRegexp(ValueError, 'axes with tick labels'):
ops.tile(self.original_lt, {'x': 5})
class PadTest(Base):
def test_name(self):
pad_lt = ops.pad(self.original_lt, {'x': (1, 1),
'channel': ([], ['alpha'])})
self.assertIn('lt_pad', pad_lt.name)
def test(self):
pad_lt = ops.pad(self.original_lt, {'x': (1, 1),
'channel': ([], ['alpha'])})
golden_op = tf.pad(self.original_lt.tensor, [[1, 1], [0, 1], [0, 0],
[0, 0]])
golden_axes = [('x', self.x_size + 2),
('channel', ['red', 'green', 'blue', 'alpha']), self.a2,
self.a3]
golden_lt = core.LabeledTensor(golden_op, golden_axes)
self.assertLabeledTensorsEqual(pad_lt, golden_lt)
def test_invalid_input(self):
with self.assertRaisesRegexp(ValueError, 'are not contained in the set'):
ops.pad(self.original_lt, {'foo': (1, 1), 'channel': ([], ['alpha'])})
class ConstantTest(Base):
def test_name(self):
constant_lt = ops.constant(1)
self.assertIn('lt_constant', constant_lt.name)
def test_scalar(self):
constant_lt = ops.constant(1)
golden_lt = core.LabeledTensor(tf.constant(1), [])
self.assertLabeledTensorsEqual(constant_lt, golden_lt)
def test_infer_shape(self):
constant_lt = ops.constant([1, 2], axes=['x'])
golden_lt = core.LabeledTensor(tf.constant([1, 2]), ['x'])
self.assertLabeledTensorsEqual(constant_lt, golden_lt)
def test_specify_shape(self):
constant_lt = ops.constant(1, axes=[('x', 3)])
golden_lt = core.LabeledTensor(tf.constant(1, shape=(3,)), ['x'])
self.assertLabeledTensorsEqual(constant_lt, golden_lt)
def test_existing_axes(self):
golden_lt = core.LabeledTensor(tf.constant([1, 2]), ['x'])
constant_lt = ops.constant([1, 2], axes=golden_lt.axes)
self.assertLabeledTensorsEqual(constant_lt, golden_lt)
class ZerosLikeTest(Base):
def test_name(self):
like_lt = ops.zeros_like(self.original_lt)
self.assertIn('lt_zeros_like', like_lt.name)
def test(self):
like_lt = ops.zeros_like(self.original_lt)
golden_lt = core.LabeledTensor(
tf.zeros_like(self.original_lt.tensor), self.original_lt.axes)
self.assertLabeledTensorsEqual(like_lt, golden_lt)
class OnesLikeTest(Base):
def test_name(self):
like_lt = ops.ones_like(self.original_lt)
self.assertIn('lt_ones_like', like_lt.name)
def test(self):
like_lt = ops.ones_like(self.original_lt)
golden_lt = core.LabeledTensor(
tf.ones_like(self.original_lt.tensor), self.original_lt.axes)
self.assertLabeledTensorsEqual(like_lt, golden_lt)
class CastTest(Base):
def test_name(self):
cast_lt = ops.cast(self.original_lt, tf.float16)
self.assertIn('lt_cast', cast_lt.name)
def test(self):
cast_lt = ops.cast(self.original_lt, tf.float16)
golden_lt = core.LabeledTensor(
tf.cast(self.original_lt.tensor, tf.float16), self.original_lt.axes)
self.assertLabeledTensorsEqual(cast_lt, golden_lt)
class VerifyTensorAllFiniteTest(Base):
def setUp(self):
super(VerifyTensorAllFiniteTest, self).setUp()
self.finite_lt = core.LabeledTensor(tf.constant(42.0), [])
self.nan_lt = core.LabeledTensor(tf.constant(np.nan), [])
self.checked_finite_lt = ops.verify_tensor_all_finite(self.finite_lt, '')
self.checked_nan_lt = ops.verify_tensor_all_finite(self.nan_lt, '')
def test_name(self):
self.assertIn('lt_verify_tensor_all_finite', self.checked_finite_lt.name)
self.assertIn('lt_verify_tensor_all_finite', self.checked_nan_lt.name)
def test_finite(self):
self.assertLabeledTensorsEqual(self.finite_lt, self.checked_finite_lt)
def test_nan(self):
with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
'Tensor had NaN values'):
self.eval([self.checked_nan_lt])
class BooleanMaskTest(Base):
def test_name(self):
mask = core.LabeledTensor(tf.range(7) > 3, [self.a0])
masked_lt = ops.boolean_mask(self.original_lt, mask)
self.assertIn('lt_boolean_mask', masked_lt.name)
def test(self):
mask = core.LabeledTensor(tf.range(7) > 3, [self.a0])
masked_lt = ops.boolean_mask(self.original_lt, mask)
golden_lt = core.LabeledTensor(
tf.boolean_mask(self.original_lt.tensor, mask.tensor),
['x', self.a1, self.a2, self.a3])
self.assertLabeledTensorsEqual(masked_lt, golden_lt)
def test_invalid_rank(self):
mask = core.LabeledTensor(tf.ones((7, 3)) > 3, [self.a0, self.a1])
with self.assertRaises(NotImplementedError):
ops.boolean_mask(self.original_lt, mask)
def test_mismatched_axis(self):
mask = core.LabeledTensor(tf.range(7) > 3, ['foo'])
with self.assertRaisesRegexp(ValueError, 'not equal'):
ops.boolean_mask(self.original_lt, mask)
class WhereTest(Base):
def test_name(self):
condition = core.LabeledTensor(tf.range(5) < 3, ['x'])
where_lt = ops.where(condition, condition, condition)
self.assertIn('lt_where', where_lt.name)
def test(self):
condition = core.LabeledTensor(tf.range(5) < 3, ['x'])
x = core.LabeledTensor(tf.ones(5), ['x'])
y = core.LabeledTensor(tf.zeros(5), ['x'])
where_lt = ops.where(condition, x, y)
golden_lt = core.LabeledTensor(
tf.concat(0, [tf.ones(3), tf.zeros(2)]), ['x'])
self.assertLabeledTensorsEqual(where_lt, golden_lt)
def test_mismatched_axes(self):
condition = core.LabeledTensor(tf.range(5) < 3, ['x'])
with self.assertRaisesRegexp(ValueError, 'equal axes'):
ops.where(condition, condition[:3], condition)
with self.assertRaisesRegexp(ValueError, 'equal axes'):
ops.where(condition, condition, condition[:3])
if __name__ == '__main__':
tf.test.main()
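The tests above double as usage documentation for the ops module. A condensed sketch of a few of the same calls outside the test harness (axis names, sizes, and tick labels here are illustrative):

import tensorflow as tf
from tensorflow.contrib.labeled_tensor.python.ops import core
from tensorflow.contrib.labeled_tensor.python.ops import ops

image_lt = core.LabeledTensor(
    tf.zeros([7, 3]),
    [('x', range(7)), ('channel', ['red', 'green', 'blue'])])

red_lt = ops.select(image_lt, {'channel': ['red']})     # keep only the 'red' tick
green_lt = ops.select(image_lt, {'channel': ['green']})
rg_lt = ops.concat([red_lt, green_lt], 'channel')       # concatenate along 'channel'
total_lt = ops.reduce_sum(image_lt, {'channel'})        # drop the 'channel' axis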

View File

@ -0,0 +1,131 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools to make it a bit easier to use LabeledTensor."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six import string_types
from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
from tensorflow.contrib.labeled_tensor.python.ops import core
from tensorflow.contrib.labeled_tensor.python.ops import ops
from tensorflow.python.framework import ops as tf_ops
class ReshapeCoder(object):
"""Utility class for mapping to and from another shape.
For example, say you have a function `crop_center` which expects a
LabeledTensor with axes named ['batch', 'row', 'column', 'depth'], and
you have a LabeledTensor `masked_image_lt` with axes ['batch', 'row',
'column', 'channel', 'mask'].
To call `crop_center` with `masked_image_lt` you'd normally have to write:
>>> reshape_lt = lt.reshape(masked_image_lt, ['channel', 'mask'], ['depth'])
>>> crop_lt = crop_center(reshape_lt)
>>> result_lt = lt.reshape(crop_lt, ['depth'],
... [masked_image_lt.axes['channel'], masked_image_lt.axes['mask']])
ReshapeCoder takes care of this reshaping logic for you, allowing you to
instead write:
>>> rc = ReshapeCoder(['channel', 'mask'], ['depth'])
>>> result_lt = rc.decode(crop_center(rc.encode(masked_image_lt)))
Here, `decode` restores the original axes 'channel' and 'mask', so
`crop_center` must not have modified the size of the 'depth' axis.
"""
@tc.accepts(object, tc.Collection(str),
tc.Collection(tc.Union(str, core.AxisLike)), tc.Optional(str))
def __init__(self, existing_axis_names, new_axes, name=None):
self._name = name
self._existing_axis_names = existing_axis_names
self._new_axes = new_axes
self._existing_axes = None
@tc.returns(core.LabeledTensor)
@tc.accepts(object, core.LabeledTensorLike)
def encode(self, labeled_tensor):
"""Reshape the input to the target shape.
If called several times, the axes named in existing_axis_names must be
identical.
Args:
labeled_tensor: The input tensor.
Returns:
The input reshaped to the target shape.
Raises:
ValueError: If the axes in existing_axis_names don't match the axes of
a tensor in a previous invocation of this method.
"""
with tf_ops.name_scope(self._name, 'lt_reshape_encode',
[labeled_tensor]) as scope:
labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
reshape_lt = ops.reshape(labeled_tensor,
self._existing_axis_names,
self._new_axes,
name=scope)
axes = [labeled_tensor.axes[n] for n in self._existing_axis_names]
if self._existing_axes is not None and self._existing_axes != axes:
raise ValueError(
'input axes %r do not match axes from previous method call %r' %
(axes, self._existing_axes))
else:
self._existing_axes = axes
return reshape_lt
@tc.returns(core.LabeledTensor)
@tc.accepts(object, core.LabeledTensorLike)
def decode(self, labeled_tensor):
"""Reshape the input to the original shape.
This is the inverse of encode.
Encode must have been called at least once prior to this method being
called.
Args:
labeled_tensor: The input tensor.
Returns:
The input reshaped to the original shape.
Raises:
ValueError: If this method was called before encode was called.
"""
if self._existing_axes is None:
raise ValueError('decode called before encode')
with tf_ops.name_scope(self._name, 'lt_reshape_decode',
[labeled_tensor]) as scope:
labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
new_axis_names = [axis if isinstance(axis, string_types) else
core.as_axis(axis).name for axis in self._new_axes]
return ops.reshape(labeled_tensor,
new_axis_names,
self._existing_axes,
name=scope)
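A rough end-to-end sketch of ReshapeCoder, mirroring the pattern exercised in sugar_test.py below (the axis names and sizes are illustrative):

import tensorflow as tf
from tensorflow.contrib.labeled_tensor.python.ops import core
from tensorflow.contrib.labeled_tensor.python.ops import sugar

# A 4-D LabeledTensor whose trailing axes are 'channel' and 'mask'.
image_lt = core.LabeledTensor(
    tf.zeros([2, 4, 3, 2]),
    ['batch', 'row', ('channel', ['r', 'g', 'b']), ('mask', [False, True])])

rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
encoded_lt = rc.encode(image_lt)    # axes: batch, row, depth (size 3 * 2 = 6)
decoded_lt = rc.decode(encoded_lt)  # original 'channel' and 'mask' axes restored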

View File

@ -0,0 +1,106 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.labeled_tensor.python.ops import core
from tensorflow.contrib.labeled_tensor.python.ops import ops
from tensorflow.contrib.labeled_tensor.python.ops import sugar
from tensorflow.contrib.labeled_tensor.python.ops import test_util
class Base(test_util.Base):
def setUp(self):
super(Base, self).setUp()
self.small_lt = core.LabeledTensor(tf.constant([1]), [('x', 1)])
class ReshapeCoderTest(Base):
def setUp(self):
super(ReshapeCoderTest, self).setUp()
self.batch_size = 8
self.num_rows = 50
self.num_columns = 100
self.channels = ['red', 'green', 'blue']
self.masks = [False, True]
tensor = tf.range(0, self.batch_size * self.num_rows * self.num_columns *
len(self.channels) * len(self.masks))
tensor = tf.reshape(tensor, [self.batch_size, self.num_rows,
self.num_columns, len(self.channels),
len(self.masks)])
self.batch_axis = ('batch', range(self.batch_size))
self.row_axis = ('row', range(self.num_rows))
self.column_axis = ('column', range(self.num_columns))
self.channel_axis = ('channel', self.channels)
self.mask_axis = ('mask', self.masks)
axes = [self.batch_axis, self.row_axis, self.column_axis, self.channel_axis,
self.mask_axis]
self.masked_image_lt = core.LabeledTensor(tensor, axes)
def test_name(self):
rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
encode_lt = rc.encode(self.masked_image_lt)
decode_lt = rc.decode(encode_lt)
self.assertIn('lt_reshape_encode', encode_lt.name)
self.assertIn('lt_reshape_decode', decode_lt.name)
def test_bijection_flat(self):
rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
encode_lt = rc.encode(self.masked_image_lt)
golden_axes = core.Axes([self.batch_axis, self.row_axis, self.column_axis,
('depth', len(self.channels) * len(self.masks))])
self.assertEqual(encode_lt.axes, golden_axes)
decode_lt = rc.decode(encode_lt)
self.assertLabeledTensorsEqual(decode_lt, self.masked_image_lt)
def test_bijection_with_labels(self):
depth_axis = core.Axis('depth', range(len(self.channels) * len(self.masks)))
rc = sugar.ReshapeCoder(['channel', 'mask'], [depth_axis,
('other', ['label'])])
encode_lt = rc.encode(self.masked_image_lt)
golden_axes = core.Axes([self.batch_axis, self.row_axis, self.column_axis,
depth_axis, ('other', ['label'])])
self.assertEqual(encode_lt.axes, golden_axes)
decode_lt = rc.decode(encode_lt)
self.assertLabeledTensorsEqual(decode_lt, self.masked_image_lt)
def test_invalid_input(self):
with self.assertRaises(ValueError):
rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
rc.decode(self.masked_image_lt)
with self.assertRaises(ValueError):
rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
rc.encode(self.masked_image_lt)
rc.encode(ops.select(self.masked_image_lt, {'channel': 'red'}))
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,47 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for writing tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class Base(tf.test.TestCase):
"""A class with some useful methods for testing."""
def eval(self, tensors):
with self.test_session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
results = sess.run(tensors)
finally:
coord.request_stop()
coord.join(threads)
return results
def assertTensorsEqual(self, tensor_0, tensor_1):
[tensor_0_eval, tensor_1_eval] = self.eval([tensor_0, tensor_1])
self.assertAllEqual(tensor_0_eval, tensor_1_eval)
def assertLabeledTensorsEqual(self, tensor_0, tensor_1):
self.assertEqual(tensor_0.axes, tensor_1.axes)
self.assertTensorsEqual(tensor_0.tensor, tensor_1.tensor)
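A test built on this helper might look roughly like the following (an illustrative sketch, not part of this change):

import tensorflow as tf
from tensorflow.contrib.labeled_tensor.python.ops import core
from tensorflow.contrib.labeled_tensor.python.ops import test_util

class ExampleTest(test_util.Base):

  def test_roundtrip(self):
    a_lt = core.LabeledTensor(tf.constant([1, 2, 3]), ['x'])
    b_lt = core.LabeledTensor(tf.constant([1, 2, 3]), ['x'])
    # Compares the axes directly, then evaluates both tensors in a test session.
    self.assertLabeledTensorsEqual(a_lt, b_lt)

if __name__ == '__main__':
  tf.test.main()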

View File

@ -90,6 +90,7 @@ sh_binary(
":included_headers",
":simple_console",
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/labeled_tensor:all_files",
"//tensorflow/contrib/ndlstm:all_files",
"//tensorflow/contrib/session_bundle:all_files",
"//tensorflow/contrib/slim:all_files",