Initial version of tf.contrib.labeled_tensor
Change: 139143754
This commit is contained in:
parent
887892a499
commit
9d20f4ea4b
@ -98,6 +98,7 @@ filegroup(
|
||||
"//tensorflow/contrib/graph_editor:all_files",
|
||||
"//tensorflow/contrib/grid_rnn:all_files",
|
||||
"//tensorflow/contrib/integrate:all_files",
|
||||
"//tensorflow/contrib/labeled_tensor:all_files",
|
||||
"//tensorflow/contrib/layers:all_files",
|
||||
"//tensorflow/contrib/layers/kernels:all_files",
|
||||
"//tensorflow/contrib/learn:all_files",
|
||||
|
@ -24,6 +24,7 @@ py_library(
|
||||
"//tensorflow/contrib/graph_editor:graph_editor_py",
|
||||
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
|
||||
"//tensorflow/contrib/integrate:integrate_py",
|
||||
"//tensorflow/contrib/labeled_tensor",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
|
||||
|
@ -29,6 +29,7 @@ from tensorflow.contrib import framework
|
||||
from tensorflow.contrib import graph_editor
|
||||
from tensorflow.contrib import grid_rnn
|
||||
from tensorflow.contrib import integrate
|
||||
from tensorflow.contrib import labeled_tensor
|
||||
from tensorflow.contrib import layers
|
||||
from tensorflow.contrib import learn
|
||||
from tensorflow.contrib import linear_optimizer
|
||||
|
166
tensorflow/contrib/labeled_tensor/BUILD
Normal file
166
tensorflow/contrib/labeled_tensor/BUILD
Normal file
@ -0,0 +1,166 @@
|
||||
# Description:
|
||||
# Labels for TensorFlow.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
py_library(
|
||||
name = "labeled_tensor",
|
||||
srcs = ["__init__.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":core",
|
||||
":io_ops",
|
||||
":nn",
|
||||
":ops",
|
||||
":sugar",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "_typecheck",
|
||||
srcs = ["python/ops/_typecheck.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = [":__subpackages__"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "core",
|
||||
srcs = ["python/ops/core.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":_typecheck",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "test_util",
|
||||
srcs = ["python/ops/test_util.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":_typecheck",
|
||||
":core",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "core_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"python/ops/core_test.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":core",
|
||||
":test_util",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "io_ops",
|
||||
srcs = ["python/ops/io_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":core",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "io_ops_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"python/ops/io_ops_test.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":io_ops",
|
||||
":ops",
|
||||
":test_util",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "nn",
|
||||
srcs = ["python/ops/nn.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":core",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "nn_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"python/ops/nn_test.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":nn",
|
||||
":test_util",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "ops",
|
||||
srcs = ["python/ops/ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":core",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ops_test",
|
||||
srcs = [
|
||||
"python/ops/ops_test.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ops",
|
||||
":test_util",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "sugar",
|
||||
srcs = ["python/ops/sugar.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":core",
|
||||
":ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "sugar_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"python/ops/sugar_test.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":sugar",
|
||||
":test_util",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
)
|
8
tensorflow/contrib/labeled_tensor/README.md
Normal file
8
tensorflow/contrib/labeled_tensor/README.md
Normal file
@ -0,0 +1,8 @@
|
||||
# Labels for TensorFlow
|
||||
|
||||
LabeledTensor is a library for adding semantically meaningful dimension and
|
||||
coordinate labels to tensors in Tensorflow.
|
||||
|
||||
Maintainers:
|
||||
- Stephan Hoyer (shoyer@google.com, github.com/shoyer)
|
||||
- Eric Christiansen (ericmc@google.com, github.com/emchristiansen)
|
139
tensorflow/contrib/labeled_tensor/__init__.py
Normal file
139
tensorflow/contrib/labeled_tensor/__init__.py
Normal file
@ -0,0 +1,139 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Labels for TensorFlow."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import core as _core
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import io_ops as _io_ops
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import nn
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import ops as _ops
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import sugar as _sugar
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
# Core types.
|
||||
Axis = _core.Axis
|
||||
Axes = _core.Axes
|
||||
LabeledTensor = _core.LabeledTensor
|
||||
|
||||
as_axis = _core.as_axis
|
||||
convert_to_labeled_tensor = _core.convert_to_labeled_tensor
|
||||
|
||||
identity = _core.identity
|
||||
slice = _core.slice_function # pylint: disable=redefined-builtin
|
||||
transpose = _core.transpose
|
||||
expand_dims = _core.expand_dims
|
||||
align = _core.align
|
||||
|
||||
axis_order_scope = _core.axis_order_scope
|
||||
check_axis_order = _core.check_axis_order
|
||||
impose_axis_order = _core.impose_axis_order
|
||||
AxisOrderError = _core.AxisOrderError
|
||||
|
||||
define_unary_op = _core.define_unary_op
|
||||
define_binary_op = _core.define_binary_op
|
||||
define_reduce_op = _ops.define_reduce_op
|
||||
|
||||
abs = _core.abs_function # pylint: disable=redefined-builtin
|
||||
neg = _core.neg
|
||||
sign = _core.sign
|
||||
inv = _core.inv
|
||||
square = _core.square
|
||||
round = _core.round_function # pylint: disable=redefined-builtin
|
||||
sqrt = _core.sqrt
|
||||
rsqrt = _core.rsqrt
|
||||
exp = _core.exp
|
||||
log = _core.log
|
||||
ceil = _core.ceil
|
||||
floor = _core.floor
|
||||
cos = _core.cos
|
||||
sin = _core.sin
|
||||
tan = _core.tan
|
||||
acos = _core.acos
|
||||
asin = _core.asin
|
||||
atan = _core.atan
|
||||
lgamma = _core.lgamma
|
||||
digamma = _core.digamma
|
||||
erf = _core.erf
|
||||
erfc = _core.erfc
|
||||
logical_not = _core.logical_not
|
||||
|
||||
add = _core.add
|
||||
sub = _core.sub
|
||||
mul = _core.mul
|
||||
div = _core.div
|
||||
mod = _core.mod
|
||||
pow = _core.pow_function # pylint: disable=redefined-builtin
|
||||
|
||||
equal = _core.equal
|
||||
greater = _core.greater
|
||||
greater_equal = _core.greater_equal
|
||||
not_equal = _core.not_equal
|
||||
less = _core.less
|
||||
less_equal = _core.less_equal
|
||||
logical_and = _core.logical_and
|
||||
logical_or = _core.logical_or
|
||||
logical_xor = _core.logical_xor
|
||||
|
||||
maximum = _core.maximum
|
||||
minimum = _core.minimum
|
||||
squared_difference = _core.squared_difference
|
||||
igamma = _core.igamma
|
||||
igammac = _core.igammac
|
||||
zeta = _core.zeta
|
||||
polygamma = _core.polygamma
|
||||
|
||||
select = _ops.select
|
||||
concat = _ops.concat
|
||||
pack = _ops.pack
|
||||
unpack = _ops.unpack
|
||||
reshape = _ops.reshape
|
||||
rename_axis = _ops.rename_axis
|
||||
random_crop = _ops.random_crop
|
||||
map_fn = _ops.map_fn
|
||||
squeeze = _ops.squeeze
|
||||
matmul = _ops.matmul
|
||||
tile = _ops.tile
|
||||
pad = _ops.pad
|
||||
constant = _ops.constant
|
||||
zeros_like = _ops.zeros_like
|
||||
ones_like = _ops.ones_like
|
||||
cast = _ops.cast
|
||||
verify_tensor_all_finite = _ops.verify_tensor_all_finite
|
||||
boolean_mask = _ops.boolean_mask
|
||||
where = _ops.where
|
||||
|
||||
reduce_all = _ops.reduce_all
|
||||
reduce_any = _ops.reduce_any
|
||||
reduce_logsumexp = _ops.reduce_logsumexp
|
||||
reduce_max = _ops.reduce_max
|
||||
reduce_mean = _ops.reduce_mean
|
||||
reduce_min = _ops.reduce_min
|
||||
reduce_prod = _ops.reduce_prod
|
||||
reduce_sum = _ops.reduce_sum
|
||||
|
||||
batch = _ops.batch
|
||||
shuffle_batch = _ops.shuffle_batch
|
||||
|
||||
FixedLenFeature = _io_ops.FixedLenFeature
|
||||
parse_example = _io_ops.parse_example
|
||||
parse_single_example = _io_ops.parse_single_example
|
||||
placeholder = _io_ops.placeholder
|
||||
|
||||
ReshapeCoder = _sugar.ReshapeCoder
|
322
tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py
Normal file
322
tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py
Normal file
@ -0,0 +1,322 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Minimal runtime type checking library.
|
||||
|
||||
This module should not be considered public API.
|
||||
"""
|
||||
# TODO(ericmc,shoyer): Delete this in favor of using pytype or mypy
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import inspect
|
||||
import re
|
||||
|
||||
|
||||
# used for register_type_abbreviation and _type_repr below.
|
||||
_TYPE_ABBREVIATIONS = {}
|
||||
|
||||
|
||||
class Type(object):
|
||||
"""Base class for type checker types.
|
||||
|
||||
The custom types defined in this module are based on types in the standard
|
||||
library's typing module (in Python 3.5):
|
||||
https://docs.python.org/3/library/typing.html
|
||||
|
||||
The only difference should be that we use actual instances of Type classes to
|
||||
represent custom types rather than the metaclass magic typing uses to create
|
||||
new class objects. In practice, all this should mean is that we use
|
||||
`List(int)` rather than `List[int]`.
|
||||
|
||||
Custom types should implement __instancecheck__ and inherit from Type. Every
|
||||
argument in the constructor must be a type or Type instance, and these
|
||||
arguments must be stored as a tuple on the `_types` attribute.
|
||||
"""
|
||||
|
||||
def __init__(self, *types):
|
||||
self._types = types
|
||||
|
||||
def __repr__(self):
|
||||
args_repr = ", ".join(repr(t) for t in self._types)
|
||||
return "typecheck.%s(%s)" % (type(self).__name__, args_repr)
|
||||
|
||||
|
||||
class _SingleArgumentType(Type):
|
||||
"""Use this subclass for parametric types that accept only one argument."""
|
||||
|
||||
def __init__(self, tpe):
|
||||
super(_SingleArgumentType, self).__init__(tpe)
|
||||
|
||||
@property
|
||||
def _type(self):
|
||||
tpe, = self._types # pylint: disable=unbalanced-tuple-unpacking
|
||||
return tpe
|
||||
|
||||
|
||||
class _TwoArgumentType(Type):
|
||||
"""Use this subclass for parametric types that accept two arguments."""
|
||||
|
||||
def __init__(self, first_type, second_type):
|
||||
super(_TwoArgumentType, self).__init__(first_type, second_type)
|
||||
|
||||
|
||||
class Union(Type):
|
||||
"""A sum type.
|
||||
|
||||
A correct type is any of the types provided.
|
||||
"""
|
||||
|
||||
def __instancecheck__(self, instance):
|
||||
return isinstance(instance, self._types)
|
||||
|
||||
|
||||
class Optional(_SingleArgumentType):
|
||||
"""An optional type.
|
||||
|
||||
A correct type is either the provided type or NoneType.
|
||||
"""
|
||||
|
||||
def __instancecheck__(self, instance):
|
||||
# types.NoneType does not exist in Python 3
|
||||
return isinstance(instance, (self._type, type(None)))
|
||||
|
||||
|
||||
class List(_SingleArgumentType):
|
||||
"""A typed list.
|
||||
|
||||
A correct type is a list where each element has the single provided type.
|
||||
"""
|
||||
|
||||
def __instancecheck__(self, instance):
|
||||
return (isinstance(instance, list)
|
||||
and all(isinstance(x, self._type) for x in instance))
|
||||
|
||||
|
||||
class Sequence(_SingleArgumentType):
|
||||
"""A typed sequence.
|
||||
|
||||
A correct type is a sequence where each element has the single provided type.
|
||||
"""
|
||||
|
||||
def __instancecheck__(self, instance):
|
||||
return (isinstance(instance, collections.Sequence)
|
||||
and all(isinstance(x, self._type) for x in instance))
|
||||
|
||||
|
||||
class Collection(_SingleArgumentType):
|
||||
"""A sized, iterable container.
|
||||
|
||||
A correct type is an iterable and container with known size where each element
|
||||
has the single provided type.
|
||||
|
||||
We use this in preference to Iterable because we check each instance of the
|
||||
iterable at runtime, and hence need to avoid iterables that could be
|
||||
exhausted.
|
||||
"""
|
||||
|
||||
def __instancecheck__(self, instance):
|
||||
return (isinstance(instance, collections.Iterable)
|
||||
and isinstance(instance, collections.Sized)
|
||||
and isinstance(instance, collections.Container)
|
||||
and all(isinstance(x, self._type) for x in instance))
|
||||
|
||||
|
||||
class Tuple(Type):
|
||||
"""A typed tuple.
|
||||
|
||||
A correct type is a tuple with the correct length where each element has
|
||||
the correct type.
|
||||
"""
|
||||
|
||||
def __instancecheck__(self, instance):
|
||||
return (isinstance(instance, tuple)
|
||||
and len(instance) == len(self._types)
|
||||
and all(isinstance(x, t) for x, t in zip(instance, self._types)))
|
||||
|
||||
|
||||
class Mapping(_TwoArgumentType):
|
||||
"""A typed mapping.
|
||||
|
||||
A correct type has the correct parametric types for keys and values.
|
||||
"""
|
||||
|
||||
def __instancecheck__(self, instance):
|
||||
key_type, value_type = self._types # pylint: disable=unbalanced-tuple-unpacking
|
||||
return (isinstance(instance, collections.Mapping)
|
||||
and all(isinstance(k, key_type) for k in instance.keys())
|
||||
and all(isinstance(k, value_type) for k in instance.values()))
|
||||
|
||||
|
||||
class Dict(Mapping):
|
||||
"""A typed dict.
|
||||
|
||||
A correct type has the correct parametric types for keys and values.
|
||||
"""
|
||||
|
||||
def __instancecheck__(self, instance):
|
||||
return (isinstance(instance, dict)
|
||||
and super(Dict, self).__instancecheck__(instance))
|
||||
|
||||
|
||||
def _replace_forward_references(t, context):
|
||||
"""Replace forward references in the given type."""
|
||||
if isinstance(t, str):
|
||||
return context[t]
|
||||
elif isinstance(t, Type):
|
||||
return type(t)(*[_replace_forward_references(t, context) for t in t._types]) # pylint: disable=protected-access
|
||||
else:
|
||||
return t
|
||||
|
||||
|
||||
def register_type_abbreviation(name, alias):
|
||||
"""Register an abbreviation for a type in typecheck tracebacks.
|
||||
|
||||
This makes otherwise very long typecheck errors much more readable.
|
||||
|
||||
Example:
|
||||
typecheck.register_type_abbreviation(tf.Dimension, 'tf.Dimension')
|
||||
|
||||
Args:
|
||||
name: type or class to abbreviate.
|
||||
alias: string alias to substitute.
|
||||
"""
|
||||
_TYPE_ABBREVIATIONS[name] = alias
|
||||
|
||||
|
||||
def _type_repr(t):
|
||||
"""A more succinct repr for typecheck tracebacks."""
|
||||
string = repr(t)
|
||||
for type_, alias in _TYPE_ABBREVIATIONS.items():
|
||||
string = string.replace(repr(type_), alias)
|
||||
string = re.sub(r"<(class|type) '([\w.]+)'>", r"\2", string)
|
||||
string = re.sub(r"typecheck\.(\w+)", r"\1", string)
|
||||
return string
|
||||
|
||||
|
||||
class Error(TypeError):
|
||||
"""Exception for typecheck failures."""
|
||||
|
||||
|
||||
def accepts(*types):
|
||||
"""A decorator which checks the input types of a function.
|
||||
|
||||
Based on:
|
||||
http://stackoverflow.com/questions/15299878/how-to-use-python-decorators-to-check-function-arguments
|
||||
The above draws from:
|
||||
https://www.python.org/dev/peps/pep-0318/
|
||||
|
||||
Args:
|
||||
*types: A list of Python types.
|
||||
|
||||
Returns:
|
||||
A function to use as a decorator.
|
||||
"""
|
||||
|
||||
def check_accepts(f):
|
||||
"""Check the types."""
|
||||
spec = inspect.getargspec(f)
|
||||
|
||||
num_function_arguments = len(spec.args)
|
||||
if len(types) != num_function_arguments:
|
||||
raise Error(
|
||||
"Function %r has %d arguments but only %d types were provided in the "
|
||||
"annotation." % (f, num_function_arguments, len(types)))
|
||||
|
||||
if spec.defaults:
|
||||
num_defaults = len(spec.defaults)
|
||||
for (name, a, t) in zip(spec.args[-num_defaults:],
|
||||
spec.defaults,
|
||||
types[-num_defaults:]):
|
||||
allowed_type = _replace_forward_references(t, f.__globals__)
|
||||
if not isinstance(a, allowed_type):
|
||||
raise Error("default argument value %r of type %r is not an instance "
|
||||
"of the allowed type %s for the %s argument to %r"
|
||||
% (a, type(a), _type_repr(allowed_type), name, f))
|
||||
|
||||
@functools.wraps(f)
|
||||
def new_f(*args, **kwds):
|
||||
"""A helper function."""
|
||||
for (a, t) in zip(args, types):
|
||||
allowed_type = _replace_forward_references(t, f.__globals__)
|
||||
if not isinstance(a, allowed_type):
|
||||
raise Error("%r of type %r is not an instance of the allowed type %s "
|
||||
"for %r" % (a, type(a), _type_repr(allowed_type), f))
|
||||
return f(*args, **kwds)
|
||||
|
||||
return new_f
|
||||
|
||||
return check_accepts
|
||||
|
||||
|
||||
def returns(*types):
|
||||
"""A decorator which checks the return types of a function.
|
||||
|
||||
Based on:
|
||||
http://stackoverflow.com/questions/15299878/how-to-use-python-decorators-to-check-function-arguments
|
||||
The above draws from:
|
||||
https://www.python.org/dev/peps/pep-0318/
|
||||
|
||||
Args:
|
||||
*types: A list of Python types.
|
||||
A list of one element corresponds to a single return value.
|
||||
A list of several elements corresponds to several return values.
|
||||
Note that a function with no explicit return value has an implicit
|
||||
NoneType return and should be annotated correspondingly.
|
||||
|
||||
Returns:
|
||||
A function to use as a decorator.
|
||||
"""
|
||||
|
||||
def check_returns(f):
|
||||
"""Check the types."""
|
||||
if not types:
|
||||
raise TypeError("A return type annotation must contain at least one type")
|
||||
|
||||
@functools.wraps(f)
|
||||
def new_f(*args, **kwds):
|
||||
"""A helper function."""
|
||||
return_value = f(*args, **kwds)
|
||||
|
||||
if len(types) == 1:
|
||||
# The function has a single return value.
|
||||
allowed_type = _replace_forward_references(types[0], f.__globals__)
|
||||
if not isinstance(return_value, allowed_type):
|
||||
raise Error("%r of type %r is not an instance of the allowed type %s "
|
||||
"for %r"
|
||||
% (return_value, type(return_value),
|
||||
_type_repr(allowed_type), f))
|
||||
|
||||
else:
|
||||
if len(return_value) != len(types):
|
||||
raise Error(
|
||||
"Function %r has %d return values but only %d types were "
|
||||
"provided in the annotation." %
|
||||
(f, len(return_value), len(types)))
|
||||
|
||||
for (r, t) in zip(return_value, types):
|
||||
allowed_type = _replace_forward_references(t, f.__globals__)
|
||||
if not isinstance(r, allowed_type):
|
||||
raise Error("%r of type %r is not an instance of allowed type %s "
|
||||
"for %r" % (r, type(r), _type_repr(allowed_type), f))
|
||||
|
||||
return return_value
|
||||
|
||||
return new_f
|
||||
|
||||
return check_returns
|
1197
tensorflow/contrib/labeled_tensor/python/ops/core.py
Normal file
1197
tensorflow/contrib/labeled_tensor/python/ops/core.py
Normal file
File diff suppressed because it is too large
Load Diff
842
tensorflow/contrib/labeled_tensor/python/ops/core_test.py
Normal file
842
tensorflow/contrib/labeled_tensor/python/ops/core_test.py
Normal file
@ -0,0 +1,842 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import operator
|
||||
import re
|
||||
import textwrap
|
||||
|
||||
import numpy as np
|
||||
from six.moves import range # pylint: disable=redefined-builtin
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import core
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import test_util
|
||||
|
||||
|
||||
class AxisTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
d_7 = tf.Dimension(7)
|
||||
p_rgb = ['red', 'green', 'blue']
|
||||
|
||||
self.i_7 = core.Axis('7', d_7)
|
||||
self.i_7p = core.Axis('7prime', d_7)
|
||||
self.i_rgb = core.Axis('rgb', p_rgb)
|
||||
self.i_range = core.Axis('range', range(7))
|
||||
self.i_unknown = core.Axis('unknown', None)
|
||||
|
||||
def test_equality(self):
|
||||
|
||||
axes = [self.i_7, self.i_7p, self.i_rgb, self.i_range, self.i_unknown]
|
||||
for i, axis_0 in enumerate(axes):
|
||||
for j, axis_1 in enumerate(axes):
|
||||
if i == j:
|
||||
self.assertEqual(axis_0, axis_1)
|
||||
else:
|
||||
self.assertNotEqual(axis_0, axis_1)
|
||||
|
||||
def test_axis_value(self):
|
||||
self.assertEqual(self.i_7.value, tf.Dimension(7))
|
||||
self.assertTrue(self.i_range.value == tuple(range(7)))
|
||||
|
||||
def test_axis_input(self):
|
||||
axes = [self.i_7, self.i_7p, self.i_rgb, self.i_range, self.i_unknown]
|
||||
for axis in axes:
|
||||
self.assertEqual(axis, core.Axis(axis.name, axis.value))
|
||||
|
||||
def test_axis_value_input(self):
|
||||
axis = self.i_range
|
||||
for value in [range(7), list(range(7)), np.arange(7)]:
|
||||
self.assertEqual(axis, core.Axis(axis.name, value))
|
||||
|
||||
def test_size(self):
|
||||
self.assertEqual(len(self.i_7), 7)
|
||||
self.assertEqual(len(self.i_rgb), 3)
|
||||
self.assertEqual(len(self.i_range), 7)
|
||||
self.assertEqual(self.i_unknown.size, None)
|
||||
|
||||
def test_concat_single(self):
|
||||
red = core.Axis('rgb', ['red'])
|
||||
|
||||
self.assertEqual(core.concat_axes([red]), red)
|
||||
|
||||
def test_concat_many(self):
|
||||
red = core.Axis('rgb', ['red'])
|
||||
green = core.Axis('rgb', ['green'])
|
||||
blue = core.Axis('rgb', ['blue'])
|
||||
red_green_blue = core.Axis('rgb', ['red', 'green', 'blue'])
|
||||
|
||||
self.assertEqual(core.concat_axes([red, green, blue]), red_green_blue)
|
||||
|
||||
def test_concat_different_names(self):
|
||||
red = core.Axis('red', ['red'])
|
||||
green = core.Axis('green', ['red'])
|
||||
with self.assertRaises(ValueError):
|
||||
core.concat_axes([red, green])
|
||||
|
||||
def test_concat_unknown(self):
|
||||
red = core.Axis('rgb', None)
|
||||
green = core.Axis('rgb', None)
|
||||
self.assertEqual(core.concat_axes([red, green]), red)
|
||||
|
||||
def test_repr(self):
|
||||
self.assertEqual("Axis('7', Dimension(7))", repr(self.i_7))
|
||||
|
||||
def test_invalid_input(self):
|
||||
with self.assertRaises(TypeError):
|
||||
core.Axis('foo', [{}])
|
||||
with self.assertRaises(ValueError):
|
||||
core.Axis('foo', [1, 2, 3, 1])
|
||||
red = core.Axis('foo', ['red'])
|
||||
with self.assertRaises(tc.Error):
|
||||
core.concat_axes([red, 1])
|
||||
|
||||
def test_as_axis(self):
|
||||
self.assertEqual(self.i_7, core.as_axis(('7', 7)))
|
||||
self.assertEqual(self.i_7, core.as_axis(self.i_7))
|
||||
|
||||
|
||||
class AxesTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
d_7 = tf.Dimension(7)
|
||||
d_8 = tf.Dimension(8)
|
||||
p_rgb = ['red', 'green', 'blue']
|
||||
p_range = range(7)
|
||||
|
||||
self.i_8 = core.Axis('8', d_8)
|
||||
|
||||
self.a0 = core.Axes([('d7', d_7)])
|
||||
self.a1 = core.Axes([('d7', d_7)])
|
||||
self.a2 = core.Axes([('d7', d_7), ('rgb', p_rgb)])
|
||||
self.a3 = core.Axes([('8', d_8), ('range', p_range)])
|
||||
|
||||
def test_equality(self):
|
||||
self.assertEqual(self.a0, self.a0)
|
||||
self.assertEqual(self.a0, self.a1)
|
||||
self.assertNotEqual(self.a0, self.a2)
|
||||
|
||||
def test_repr(self):
|
||||
self.assertEqual("Axes([('d7', Dimension(7))])", repr(self.a0))
|
||||
|
||||
def test_remove(self):
|
||||
a = self.a3.remove('range')
|
||||
self.assertEqual(a, core.Axes([self.i_8]))
|
||||
with self.assertRaises(KeyError):
|
||||
self.a3.remove('foobar')
|
||||
|
||||
def test_typecheck_error_message(self):
|
||||
pattern = ('List(Union(labeled_tensor.Axis, Tuple(..., '
|
||||
'Union(Union(numpy.ndarray, %s, list, tuple), '
|
||||
'Optional(Union(tensorflow.Dimension, int))))))' %
|
||||
range.__name__)
|
||||
regexp = re.escape(pattern).replace(re.escape('...'), '.*')
|
||||
with self.assertRaisesRegexp(tc.Error, 'allowed type ' + regexp):
|
||||
core.Axes(None)
|
||||
|
||||
|
||||
class LabeledTensorTest(test_util.Base):
|
||||
|
||||
def setUp(self):
|
||||
tensor = tf.ones([7, 3, 8, 1])
|
||||
a0 = ('x', range(7))
|
||||
a1 = ('channel', ['red', 'green', 'blue'])
|
||||
a2 = ('y', 8)
|
||||
a3 = ('z', tf.Dimension(1))
|
||||
|
||||
self.lt = core.LabeledTensor(tensor, [a0, a1, a2, a3])
|
||||
|
||||
def test_repr(self):
|
||||
pattern = textwrap.dedent("""\
|
||||
<LabeledTensor '...' shape=(7, 3, 8, 1) dtype=float32
|
||||
axes=[('x', ...),
|
||||
('channel', ...),
|
||||
('y', Dimension(8)),
|
||||
('z', Dimension(1))]>""")
|
||||
regexp = re.escape(pattern).replace(re.escape('...'), '.*')
|
||||
self.assertRegexpMatches(repr(self.lt), regexp)
|
||||
|
||||
def test_reuse_existing_axes(self):
|
||||
alt_lt = core.LabeledTensor(self.lt.tensor, self.lt.axes)
|
||||
self.assertLabeledTensorsEqual(alt_lt, self.lt)
|
||||
|
||||
def test_reuse_existing_axis_objects(self):
|
||||
alt_lt = core.LabeledTensor(self.lt.tensor, self.lt.axes.values())
|
||||
self.assertLabeledTensorsEqual(alt_lt, self.lt)
|
||||
|
||||
def test_indexing_scalars(self):
|
||||
actual = self.lt[:, :, :, 0]
|
||||
expected = core.LabeledTensor(self.lt.tensor[:, :, :, 0],
|
||||
list(self.lt.axes.values())[:-1])
|
||||
self.assertLabeledTensorsEqual(actual, expected)
|
||||
|
||||
actual = self.lt[1, :, :, 0]
|
||||
expected = core.LabeledTensor(self.lt.tensor[1, :, :, 0],
|
||||
list(self.lt.axes.values())[1:-1])
|
||||
self.assertLabeledTensorsEqual(actual, expected)
|
||||
|
||||
actual = self.lt[1, 2, :, 0]
|
||||
expected = core.LabeledTensor(self.lt.tensor[1, 2, :, 0],
|
||||
list(self.lt.axes.values())[2:-1])
|
||||
self.assertLabeledTensorsEqual(actual, expected)
|
||||
|
||||
def test_indexing_1d(self):
|
||||
lt_1d = self.lt[1, 2, :, 0]
|
||||
actual = lt_1d[3]
|
||||
expected = core.LabeledTensor(lt_1d.tensor[3], [])
|
||||
self.assertLabeledTensorsEqual(actual, expected)
|
||||
|
||||
def test_indexing_slices(self):
|
||||
actual = self.lt[:3, :, :, :]
|
||||
axes = [('x', range(3))] + list(self.lt.axes.values())[1:]
|
||||
expected = core.LabeledTensor(self.lt.tensor[:3, :, :, :], axes)
|
||||
self.assertLabeledTensorsEqual(actual, expected)
|
||||
|
||||
def test_invalid_indexing(self):
|
||||
with self.assertRaises(ValueError):
|
||||
self.lt[0] # pylint: disable=pointless-statement
|
||||
with self.assertRaises(ValueError):
|
||||
self.lt[:, :, :, :, 0] # pylint: disable=pointless-statement
|
||||
|
||||
def test_unknown_size(self):
|
||||
tensor = tf.placeholder(tf.string, [None])
|
||||
actual = core.LabeledTensor(tensor, ['x'])
|
||||
self.assertIsNone(actual.axes['x'].size)
|
||||
self.assertIs(actual.axes['x'].value, tensor.get_shape()[0])
|
||||
|
||||
def test_eq(self):
|
||||
self.assertEqual(self.lt, self.lt)
|
||||
self.assertNotEqual(self.lt, self.lt.tensor)
|
||||
self.assertNotEqual(self.lt.tensor, self.lt)
|
||||
|
||||
def test_hash(self):
|
||||
lt1 = self.lt
|
||||
lt2 = core.LabeledTensor(self.lt.tensor, self.lt.axes)
|
||||
self.assertEqual(lt1, lt2)
|
||||
self.assertEqual(hash(lt1), hash(lt2))
|
||||
|
||||
def test_name(self):
|
||||
self.assertEqual(self.lt.name, self.lt.tensor.name)
|
||||
|
||||
def test_dtype(self):
|
||||
self.assertEqual(self.lt.dtype, self.lt.tensor.dtype)
|
||||
|
||||
def test_get_shape(self):
|
||||
self.assertEqual(self.lt.get_shape(), self.lt.tensor.get_shape())
|
||||
|
||||
def test_convert_to_tensor(self):
|
||||
expected = self.lt.tensor
|
||||
actual = tf.convert_to_tensor(self.lt)
|
||||
self.assertIs(expected, actual)
|
||||
|
||||
|
||||
class Base(test_util.Base):
|
||||
|
||||
def setUp(self):
|
||||
self.x_size = 7
|
||||
self.channel_size = 3
|
||||
self.z_size = 4
|
||||
self.probs_size = 11
|
||||
|
||||
tensor = tf.range(0, self.x_size * self.channel_size * self.z_size *
|
||||
self.probs_size)
|
||||
tensor = tf.reshape(tensor, [self.x_size, self.channel_size, self.z_size,
|
||||
self.probs_size])
|
||||
a0 = ('x', range(self.x_size))
|
||||
a1 = ('channel', ['red', 'green', 'blue'])
|
||||
a2 = 'z'
|
||||
a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size))
|
||||
|
||||
self.tensor = tensor
|
||||
self.a0 = a0
|
||||
self.a1 = a1
|
||||
self.a2 = a2
|
||||
self.a3 = a3
|
||||
self.original_lt = core.LabeledTensor(tensor, [a0, a1, a2, a3])
|
||||
|
||||
self.x_probs_lt = core.slice_function(self.original_lt, {'z': 0,
|
||||
'channel': 0})
|
||||
self.channel_probs_lt = core.slice_function(self.original_lt, {'x': 3,
|
||||
'z': 0})
|
||||
|
||||
|
||||
class IdentityTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
identity_lt = core.identity(self.original_lt)
|
||||
self.assertIn('lt_identity', identity_lt.name)
|
||||
|
||||
|
||||
class SliceFunctionTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
select_lt = core.slice_function(self.original_lt, {'channel': 1})
|
||||
self.assertIn('lt_slice', select_lt.name)
|
||||
|
||||
def test_scalar(self):
|
||||
select_lt = core.slice_function(self.original_lt, {'channel': 1})
|
||||
golden_lt = core.LabeledTensor(self.tensor[:, 1, :, :], [self.a0, self.a2,
|
||||
self.a3])
|
||||
|
||||
self.assertLabeledTensorsEqual(select_lt, golden_lt)
|
||||
|
||||
def test_slice(self):
|
||||
select_lt = core.slice_function(self.original_lt, {'channel': slice(0, 2)})
|
||||
|
||||
a1_sliced = ('channel', ['red', 'green'])
|
||||
golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :],
|
||||
[self.a0, a1_sliced, self.a2, self.a3])
|
||||
|
||||
self.assertLabeledTensorsEqual(select_lt, golden_lt)
|
||||
|
||||
def test_slices(self):
|
||||
select_lt = core.slice_function(self.original_lt, {'x': slice(1, 5),
|
||||
'channel': slice(1,
|
||||
None)})
|
||||
|
||||
a0_sliced = ('x', range(1, 5))
|
||||
a1_sliced = ('channel', ['green', 'blue'])
|
||||
golden_lt = core.LabeledTensor(self.tensor[1:5, 1:, :, :],
|
||||
[a0_sliced, a1_sliced, self.a2, self.a3])
|
||||
|
||||
self.assertLabeledTensorsEqual(select_lt, golden_lt)
|
||||
|
||||
def test_slice_unlabeled(self):
|
||||
select_lt = core.slice_function(self.original_lt, {'z': slice(1, 3)})
|
||||
|
||||
a2_sliced = 'z'
|
||||
golden_lt = core.LabeledTensor(self.tensor[:, :, 1:3, :],
|
||||
[self.a0, self.a1, a2_sliced, self.a3])
|
||||
|
||||
self.assertLabeledTensorsEqual(select_lt, golden_lt)
|
||||
|
||||
def test_slice_unknown_shape(self):
|
||||
lt = core.LabeledTensor(tf.placeholder(tf.float32, [None, 1]), ['x', 'y'])
|
||||
sliced_lt = core.slice_function(lt, {'y': 0})
|
||||
self.assertEqual(list(sliced_lt.axes.values()), [lt.axes['x']])
|
||||
|
||||
|
||||
class TransposeTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
transpose_lt = core.transpose(self.original_lt,
|
||||
self.original_lt.axes.keys())
|
||||
self.assertIn('lt_transpose', transpose_lt.name)
|
||||
|
||||
def test_identity(self):
|
||||
transpose_lt = core.transpose(self.original_lt,
|
||||
self.original_lt.axes.keys())
|
||||
golden_lt = self.original_lt
|
||||
|
||||
self.assertLabeledTensorsEqual(transpose_lt, golden_lt)
|
||||
|
||||
def test(self):
|
||||
transpose_lt = core.transpose(self.original_lt, ['z', 'channel', 'x',
|
||||
'probs'])
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.transpose(self.tensor, [2, 1, 0, 3]),
|
||||
[self.a2, self.a1, self.a0, self.a3])
|
||||
|
||||
self.assertLabeledTensorsEqual(transpose_lt, golden_lt)
|
||||
|
||||
def test_default_axis_order(self):
|
||||
transpose_lt = core.transpose(self.original_lt)
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.transpose(self.tensor, [3, 2, 1, 0]),
|
||||
list(reversed(list(self.original_lt.axes.values()))))
|
||||
|
||||
self.assertLabeledTensorsEqual(transpose_lt, golden_lt)
|
||||
|
||||
def test_invalid_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
core.transpose(self.original_lt, ['channel', 'x', 'probs'])
|
||||
with self.assertRaises(ValueError):
|
||||
core.transpose(self.original_lt, ['z', 'foo', 'x', 'probs'])
|
||||
|
||||
|
||||
class ExpandDimsTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
expand_lt = core.expand_dims(self.original_lt, self.original_lt.axes.keys())
|
||||
self.assertIn('lt_expand', expand_lt.name)
|
||||
|
||||
def test_identity(self):
|
||||
expand_lt = core.expand_dims(self.original_lt, self.original_lt.axes.keys())
|
||||
golden_lt = self.original_lt
|
||||
|
||||
self.assertLabeledTensorsEqual(expand_lt, golden_lt)
|
||||
|
||||
def test(self):
|
||||
expand_lt = core.expand_dims(self.original_lt, ['foo', 'x', 'bar',
|
||||
'channel', 'z', 'probs',
|
||||
'grok'])
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.reshape(self.tensor, [1, self.x_size, 1, self.channel_size,
|
||||
self.z_size, self.probs_size, 1]),
|
||||
['foo', self.a0, 'bar', self.a1, self.a2, self.a3, 'grok'])
|
||||
|
||||
self.assertLabeledTensorsEqual(expand_lt, golden_lt)
|
||||
|
||||
def test_label(self):
|
||||
expand_lt = core.expand_dims(self.original_lt, ['x',
|
||||
'channel',
|
||||
('foo', 'bar'),
|
||||
'z',
|
||||
'probs',])
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.reshape(self.tensor, [self.x_size, self.channel_size, 1, self.z_size,
|
||||
self.probs_size]),
|
||||
[self.a0, self.a1, ('foo', ['bar']), self.a2, self.a3])
|
||||
|
||||
self.assertLabeledTensorsEqual(expand_lt, golden_lt)
|
||||
|
||||
def test_unknown_dimension(self):
|
||||
orig_lt = core.LabeledTensor(tf.placeholder(tf.float32, [None]), ['x'])
|
||||
expand_lt = core.expand_dims(orig_lt, ['x', 'y'])
|
||||
self.assertEqual(expand_lt.axes, core.Axes([('x', None), ('y', 1)]))
|
||||
|
||||
def test_invalid_input(self):
|
||||
with self.assertRaises(core.AxisOrderError):
|
||||
core.expand_dims(self.original_lt, ['foo', 'not_x', 'bar', 'channel', 'z',
|
||||
'probs', 'grok'])
|
||||
with self.assertRaises(core.AxisOrderError):
|
||||
core.expand_dims(self.original_lt, ['foo', 'z', 'bar', 'channel', 'x',
|
||||
'probs', 'grok'])
|
||||
|
||||
|
||||
class AxisOrderScopeTest(Base):
|
||||
|
||||
def test(self):
|
||||
xyz = ['x', 'y', 'z']
|
||||
abc = ['a', 'b', 'c']
|
||||
|
||||
self.assertIsNone(core.get_axis_order())
|
||||
|
||||
with core.axis_order_scope(xyz):
|
||||
self.assertEqual(core.get_axis_order(), xyz)
|
||||
|
||||
with core.axis_order_scope():
|
||||
self.assertIsNone(core.get_axis_order())
|
||||
|
||||
with core.axis_order_scope(abc):
|
||||
self.assertEqual(core.get_axis_order(), abc)
|
||||
|
||||
self.assertIsNone(core.get_axis_order())
|
||||
|
||||
self.assertEqual(core.get_axis_order(), xyz)
|
||||
|
||||
self.assertIsNone(core.get_axis_order())
|
||||
|
||||
|
||||
class CheckAxisOrderTest(Base):
|
||||
|
||||
def test_passes(self):
|
||||
axis_order = ['w', 'x', 'y', 'z']
|
||||
|
||||
lt = core.LabeledTensor(tf.ones((1, 1, 1, 1)), axis_order)
|
||||
core.check_axis_order(lt, axis_order)
|
||||
|
||||
lt = core.LabeledTensor(tf.ones((1, 1, 1)), axis_order[1:])
|
||||
core.check_axis_order(lt, axis_order)
|
||||
|
||||
lt = core.LabeledTensor(tf.ones((1, 1, 1)), axis_order[:-1])
|
||||
core.check_axis_order(lt, axis_order)
|
||||
|
||||
def test_invalid(self):
|
||||
axis_order = ['w', 'x', 'y', 'z']
|
||||
lt = core.LabeledTensor(tf.ones((1, 1, 1, 1)), axis_order)
|
||||
with self.assertRaises(core.AxisOrderError):
|
||||
core.check_axis_order(lt)
|
||||
with self.assertRaises(core.AxisOrderError):
|
||||
core.check_axis_order(lt, axis_order[:-1])
|
||||
with self.assertRaises(core.AxisOrderError):
|
||||
core.check_axis_order(lt, axis_order[::-1])
|
||||
|
||||
def test_scope(self):
|
||||
axis_order = ['w', 'x', 'y', 'z']
|
||||
lt = core.LabeledTensor(tf.ones((1, 1, 1, 1)), axis_order)
|
||||
with core.axis_order_scope(axis_order):
|
||||
core.check_axis_order(lt)
|
||||
|
||||
|
||||
class ImposeAxisOrderTest(Base):
|
||||
|
||||
def test_identity(self):
|
||||
axis_order = ['w', 'x', 'y', 'z']
|
||||
lt = core.LabeledTensor(tf.reshape(tf.range(24), (1, 2, 3, 4)), axis_order)
|
||||
actual = core.impose_axis_order(lt, axis_order)
|
||||
self.assertLabeledTensorsEqual(lt, actual)
|
||||
|
||||
lt = core.LabeledTensor(tf.reshape(tf.range(6), (1, 2, 3)), axis_order[:3])
|
||||
actual = core.impose_axis_order(lt, axis_order)
|
||||
self.assertLabeledTensorsEqual(lt, actual)
|
||||
|
||||
def test_reverse(self):
|
||||
axis_order = ['w', 'x', 'y', 'z']
|
||||
|
||||
lt = core.LabeledTensor(tf.reshape(tf.range(24), (1, 2, 3, 4)), axis_order)
|
||||
actual = core.impose_axis_order(lt, axis_order[::-1])
|
||||
expected = core.transpose(lt, axis_order[::-1])
|
||||
self.assertLabeledTensorsEqual(expected, actual)
|
||||
|
||||
lt = core.LabeledTensor(tf.reshape(tf.range(6), (1, 2, 3)), axis_order[:3])
|
||||
actual = core.impose_axis_order(lt, axis_order[::-1])
|
||||
expected = core.transpose(lt, ['y', 'x', 'w'])
|
||||
self.assertLabeledTensorsEqual(expected, actual)
|
||||
|
||||
def test_scope(self):
|
||||
axis_order = ['w', 'x', 'y', 'z']
|
||||
|
||||
lt = core.LabeledTensor(tf.reshape(tf.range(24), (1, 2, 3, 4)), axis_order)
|
||||
expected = core.transpose(lt, axis_order[::-1])
|
||||
with core.axis_order_scope(axis_order[::-1]):
|
||||
actual = core.impose_axis_order(lt)
|
||||
self.assertLabeledTensorsEqual(expected, actual)
|
||||
|
||||
def test_invalid(self):
|
||||
lt = core.LabeledTensor(tf.reshape(tf.range(2), (1, 2)), ['x', 'y'])
|
||||
with self.assertRaises(ValueError):
|
||||
core.impose_axis_order(lt)
|
||||
with self.assertRaises(ValueError):
|
||||
core.impose_axis_order(lt, ['x'])
|
||||
|
||||
|
||||
class FindConsistentOrderingTest(Base):
|
||||
|
||||
def test(self):
|
||||
cases = [
|
||||
([], [], []),
|
||||
(['x'], [], ['x']),
|
||||
([], ['x'], ['x']),
|
||||
(['x'], ['x'], ['x']),
|
||||
(['x'], ['y'], ['x', 'y']),
|
||||
(['y'], ['x'], ['y', 'x']),
|
||||
(['x', 'y'], ['x', 'y'], ['x', 'y']),
|
||||
(['x', 'y'], ['y', 'x'], None),
|
||||
(['x', 'y'], ['y', 'z'], ['x', 'y', 'z']),
|
||||
(['x', 'z'], ['y', 'z'], ['x', 'y', 'z']),
|
||||
(['x', 'y'], ['x', 'z'], ['x', 'y', 'z']),
|
||||
(['w', 'x'], ['y', 'z'], ['w', 'x', 'y', 'z']),
|
||||
(['x', 'y', 'z'], ['z', 'x'], None),
|
||||
(['x', 'y', 'z'], ['x'], ['x', 'y', 'z']),
|
||||
([], ['x', 'y', 'z'], ['x', 'y', 'z']),
|
||||
]
|
||||
for a, b, expected in cases:
|
||||
actual = core._find_consistent_ordering(a, b)
|
||||
msg = ('unexpected ordering between %r and %r:\nexpected: %r\nactual: %r'
|
||||
% (a, b, expected, actual))
|
||||
self.assertEqual(expected, actual, msg=msg)
|
||||
|
||||
|
||||
class AlignTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
align_lt_0, align_lt_1, _ = core.align(self.original_lt, self.original_lt)
|
||||
self.assertIn('lt_align', align_lt_0.name)
|
||||
self.assertIn('/0', align_lt_0.name)
|
||||
self.assertIn('lt_align', align_lt_1.name)
|
||||
self.assertIn('/1', align_lt_1.name)
|
||||
|
||||
def test_identical_shaped_inputs(self):
|
||||
offset_tensor = self.original_lt.tensor + 1
|
||||
offset_lt = core.LabeledTensor(offset_tensor, self.original_lt.axes)
|
||||
|
||||
align_lt, align_offset_lt, broadcast_axes = core.align(self.original_lt,
|
||||
offset_lt)
|
||||
|
||||
self.assertLabeledTensorsEqual(align_lt, self.original_lt)
|
||||
self.assertLabeledTensorsEqual(align_offset_lt, offset_lt)
|
||||
self.assertEqual(broadcast_axes, self.original_lt.axes)
|
||||
|
||||
def test_different_inputs(self):
|
||||
# The correct axis ordering is ['x', 'channel', 'probs'].
|
||||
align_x_probs_lt, align_channel_probs_lt, broadcast_axes = core.align(
|
||||
self.x_probs_lt, self.channel_probs_lt)
|
||||
|
||||
x_probs_golden_lt = core.LabeledTensor(
|
||||
tf.reshape(self.x_probs_lt.tensor, [self.x_size, 1, self.probs_size]),
|
||||
[self.a0, 'channel', self.a3])
|
||||
|
||||
self.assertLabeledTensorsEqual(align_x_probs_lt, x_probs_golden_lt)
|
||||
|
||||
channel_probs_golden_lt = core.LabeledTensor(
|
||||
tf.reshape(self.channel_probs_lt.tensor,
|
||||
[1, self.channel_size, self.probs_size]),
|
||||
['x', self.a1, self.a3])
|
||||
|
||||
self.assertLabeledTensorsEqual(align_channel_probs_lt,
|
||||
channel_probs_golden_lt)
|
||||
|
||||
self.assertEqual(broadcast_axes, core.Axes([self.a0, self.a1, self.a3]))
|
||||
|
||||
def test_axis_order_scope(self):
|
||||
xz_lt = core.LabeledTensor(tf.ones((2, 3)), ['x', 'z'])
|
||||
yz_lt = core.LabeledTensor(tf.ones((4, 3)), ['y', 'z'])
|
||||
|
||||
_, _, broadcast_axes = core.align(xz_lt, yz_lt)
|
||||
self.assertEqual(list(broadcast_axes.keys()), ['x', 'y', 'z'])
|
||||
|
||||
_, _, broadcast_axes = core.align(yz_lt, xz_lt)
|
||||
self.assertEqual(list(broadcast_axes.keys()), ['y', 'x', 'z'])
|
||||
|
||||
with core.axis_order_scope(['x', 'y', 'z']):
|
||||
_, _, broadcast_axes = core.align(yz_lt, xz_lt)
|
||||
self.assertEqual(list(broadcast_axes.keys()), ['x', 'y', 'z'])
|
||||
|
||||
with core.axis_order_scope(['x', 'y']):
|
||||
with self.assertRaises(core.AxisOrderError):
|
||||
core.align(xz_lt, yz_lt)
|
||||
with self.assertRaises(core.AxisOrderError):
|
||||
core.align(yz_lt, xz_lt)
|
||||
|
||||
def test_invalid_input(self):
|
||||
lt_0 = core.LabeledTensor(tf.zeros([5]), [('a', range(5))])
|
||||
lt_1 = core.LabeledTensor(tf.zeros([5]), [('a', range(1, 6))])
|
||||
with self.assertRaises(ValueError):
|
||||
core.align(lt_0, lt_1)
|
||||
|
||||
|
||||
class ConvertToLabeledTensorTest(Base):
|
||||
|
||||
# TODO(shoyer): Simplify these tests once we can reuse labeled tensors in
|
||||
# assertLabeledTensorsEqual.
|
||||
|
||||
def test_labeled_tensor(self):
|
||||
actual = core.convert_to_labeled_tensor(self.original_lt)
|
||||
self.assertLabeledTensorsEqual(actual, self.original_lt)
|
||||
|
||||
def test_python_scalar(self):
|
||||
actual = core.convert_to_labeled_tensor(42)
|
||||
golden_lt = core.LabeledTensor(tf.convert_to_tensor(42), [])
|
||||
self.assertLabeledTensorsEqual(actual, golden_lt)
|
||||
|
||||
def test_numpy_array(self):
|
||||
actual = core.convert_to_labeled_tensor(np.array(42))
|
||||
golden_lt = core.LabeledTensor(tf.convert_to_tensor(42), [])
|
||||
self.assertLabeledTensorsEqual(actual, golden_lt)
|
||||
|
||||
def test_tensor(self):
|
||||
actual = core.convert_to_labeled_tensor(tf.constant(42))
|
||||
golden_lt = core.LabeledTensor(tf.convert_to_tensor(42), [])
|
||||
self.assertLabeledTensorsEqual(actual, golden_lt)
|
||||
|
||||
def test_invalid_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
core.convert_to_labeled_tensor(tf.range(5))
|
||||
with self.assertRaises(ValueError):
|
||||
core.convert_to_labeled_tensor(np.array([1, 2]))
|
||||
|
||||
|
||||
class DocStringCheckMixin(object):
|
||||
# requires self.ops to be defined
|
||||
|
||||
def test_function_docstring_and_name(self):
|
||||
for op_name, _, _, lt_op in self.ops:
|
||||
if lt_op is not None:
|
||||
self.assertIn('tf.%s' % op_name, lt_op.__doc__)
|
||||
self.assertEqual(op_name, lt_op.__name__)
|
||||
|
||||
|
||||
class UnaryOpsTestsMixin(object):
|
||||
# requires self.ops and self.test_lt to be defined
|
||||
|
||||
def test_core_op(self):
|
||||
for op_name, _, tf_op, lt_op in self.ops:
|
||||
if tf_op is not None:
|
||||
golden_lt = core.LabeledTensor(tf_op(self.test_lt.tensor),
|
||||
self.test_lt.axes)
|
||||
actual_lt = lt_op(self.test_lt)
|
||||
self.assertIn(op_name, actual_lt.name)
|
||||
self.assertLabeledTensorsEqual(golden_lt, actual_lt)
|
||||
|
||||
def test_infix(self):
|
||||
for op_name, infix_op, _, _ in self.ops:
|
||||
if infix_op is not None:
|
||||
expected_lt = core.LabeledTensor(infix_op(self.test_lt.tensor),
|
||||
self.test_lt.axes)
|
||||
actual_lt = infix_op(self.test_lt)
|
||||
self.assertIn(op_name, actual_lt.name)
|
||||
self.assertLabeledTensorsEqual(expected_lt, actual_lt)
|
||||
|
||||
|
||||
class CoreUnaryOpsTest(Base, DocStringCheckMixin, UnaryOpsTestsMixin):
|
||||
|
||||
def setUp(self):
|
||||
super(CoreUnaryOpsTest, self).setUp()
|
||||
|
||||
self.ops = [
|
||||
('abs', operator.abs, tf.abs, core.abs_function),
|
||||
('neg', operator.neg, tf.neg, core.neg),
|
||||
# TODO(shoyer): add unary + to core TensorFlow
|
||||
('pos', None, None, None),
|
||||
('sign', None, tf.sign, core.sign),
|
||||
('inv', None, tf.inv, core.inv),
|
||||
('square', None, tf.square, core.square),
|
||||
('round', None, tf.round, core.round_function),
|
||||
('sqrt', None, tf.sqrt, core.sqrt),
|
||||
('rsqrt', None, tf.rsqrt, core.rsqrt),
|
||||
('log', None, tf.log, core.log),
|
||||
('exp', None, tf.exp, core.exp),
|
||||
('log', None, tf.log, core.log),
|
||||
('ceil', None, tf.ceil, core.ceil),
|
||||
('floor', None, tf.floor, core.floor),
|
||||
('cos', None, tf.cos, core.cos),
|
||||
('sin', None, tf.sin, core.sin),
|
||||
('tan', None, tf.tan, core.tan),
|
||||
('acos', None, tf.acos, core.acos),
|
||||
('asin', None, tf.asin, core.asin),
|
||||
('atan', None, tf.atan, core.atan),
|
||||
('lgamma', None, tf.lgamma, core.lgamma),
|
||||
('digamma', None, tf.digamma, core.digamma),
|
||||
('erf', None, tf.erf, core.erf),
|
||||
('erfc', None, tf.erfc, core.erfc),
|
||||
('lgamma', None, tf.lgamma, core.lgamma),
|
||||
]
|
||||
total_size = np.prod([v.size for v in self.original_lt.axes.values()])
|
||||
self.test_lt = core.LabeledTensor(
|
||||
tf.cast(self.original_lt, tf.float32) / total_size,
|
||||
self.original_lt.axes)
|
||||
|
||||
|
||||
class LogicalNotTest(Base, DocStringCheckMixin, UnaryOpsTestsMixin):
|
||||
|
||||
def setUp(self):
|
||||
super(LogicalNotTest, self).setUp()
|
||||
self.ops = [
|
||||
('logical_not', operator.invert, tf.logical_not, core.logical_not),
|
||||
]
|
||||
self.test_lt = self.original_lt < 10
|
||||
|
||||
|
||||
class BinaryOpsTestsMixin(object):
|
||||
# requires self.ops, self.test_lt_1, self.test_lt_2, self.test_lt_1_broadcast
|
||||
# and self.test_lt_2_broadcast to be defined
|
||||
|
||||
def test_core_op(self):
|
||||
for op_name, _, tf_op, lt_op in self.ops:
|
||||
golden_tensor = tf_op(self.test_lt_1_broadcast,
|
||||
self.test_lt_2_broadcast)
|
||||
golden_lt = core.LabeledTensor(golden_tensor, self.broadcast_axes)
|
||||
actual_lt = lt_op(self.test_lt_1, self.test_lt_2)
|
||||
self.assertIn(op_name, actual_lt.name)
|
||||
self.assertLabeledTensorsEqual(golden_lt, actual_lt)
|
||||
|
||||
def test_infix(self):
|
||||
for op_name, infix_op, _, lt_op in self.ops:
|
||||
if infix_op is not None:
|
||||
expected_lt = lt_op(self.test_lt_1, self.test_lt_2)
|
||||
actual_lt = infix_op(self.test_lt_1, self.test_lt_2)
|
||||
self.assertIn(op_name, actual_lt.name)
|
||||
self.assertLabeledTensorsEqual(expected_lt, actual_lt)
|
||||
|
||||
|
||||
class CoreBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):
|
||||
|
||||
def setUp(self):
|
||||
super(CoreBinaryOpsTest, self).setUp()
|
||||
|
||||
self.x_probs_broadcast_tensor = tf.reshape(
|
||||
self.x_probs_lt.tensor, [self.x_size, 1, self.probs_size])
|
||||
|
||||
self.channel_probs_broadcast_tensor = tf.reshape(
|
||||
self.channel_probs_lt.tensor, [1, self.channel_size, self.probs_size])
|
||||
|
||||
# == and != are not element-wise for tf.Tensor, so they shouldn't be
|
||||
# elementwise for LabeledTensor, either.
|
||||
self.ops = [
|
||||
('add', operator.add, tf.add, core.add),
|
||||
('sub', operator.sub, tf.sub, core.sub),
|
||||
('mul', operator.mul, tf.mul, core.mul),
|
||||
('div', operator.truediv, tf.div, core.div),
|
||||
('mod', operator.mod, tf.mod, core.mod),
|
||||
('pow', operator.pow, tf.pow, core.pow_function),
|
||||
('equal', None, tf.equal, core.equal),
|
||||
('less', operator.lt, tf.less, core.less),
|
||||
('less_equal', operator.le, tf.less_equal, core.less_equal),
|
||||
('not_equal', None, tf.not_equal, core.not_equal),
|
||||
('greater', operator.gt, tf.greater, core.greater),
|
||||
('greater_equal', operator.ge, tf.greater_equal, core.greater_equal),
|
||||
]
|
||||
self.test_lt_1 = self.x_probs_lt
|
||||
self.test_lt_2 = self.channel_probs_lt
|
||||
self.test_lt_1_broadcast = self.x_probs_broadcast_tensor
|
||||
self.test_lt_2_broadcast = self.channel_probs_broadcast_tensor
|
||||
self.broadcast_axes = [self.a0, self.a1, self.a3]
|
||||
|
||||
def test_reflexive(self):
|
||||
labeled_tensor = self.x_probs_lt + 1 # all elements must be >0 for division
|
||||
for op_name, infix_op, _, lt_op in self.ops:
|
||||
if infix_op is not None:
|
||||
expected_lt = lt_op(2, labeled_tensor)
|
||||
actual_lt = infix_op(2, labeled_tensor)
|
||||
# Python uses greater for the reflexive version of less (and vise-versa)
|
||||
if 'less' in op_name:
|
||||
op_name = op_name.replace('less', 'greater')
|
||||
elif 'greater' in op_name:
|
||||
op_name = op_name.replace('greater', 'less')
|
||||
self.assertIn(op_name, actual_lt.name)
|
||||
self.assertLabeledTensorsEqual(expected_lt, actual_lt)
|
||||
|
||||
|
||||
class LogicalBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):
|
||||
|
||||
def setUp(self):
|
||||
super(LogicalBinaryOpsTest, self).setUp()
|
||||
|
||||
self.ops = [
|
||||
('logical_and', operator.and_, tf.logical_and, core.logical_and),
|
||||
('logical_or', operator.or_, tf.logical_or, core.logical_or),
|
||||
('logical_xor', operator.xor, tf.logical_xor, core.logical_xor),
|
||||
]
|
||||
self.test_lt_1 = self.original_lt < 10
|
||||
self.test_lt_2 = self.original_lt < 5
|
||||
self.test_lt_1_broadcast = self.test_lt_1.tensor
|
||||
self.test_lt_2_broadcast = self.test_lt_2.tensor
|
||||
self.broadcast_axes = self.test_lt_1.axes
|
||||
|
||||
|
||||
class FloatBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):
|
||||
|
||||
def setUp(self):
|
||||
super(FloatBinaryOpsTest, self).setUp()
|
||||
|
||||
self.ops = [
|
||||
('igamma', None, tf.igamma, core.igamma),
|
||||
('igammac', None, tf.igammac, core.igammac),
|
||||
('zeta', None, tf.zeta, core.zeta),
|
||||
('polygamma', None, tf.polygamma, core.polygamma),
|
||||
('maximum', None, tf.maximum, core.maximum),
|
||||
('minimum', None, tf.minimum, core.minimum),
|
||||
('squared_difference', None, tf.squared_difference,
|
||||
core.squared_difference),
|
||||
]
|
||||
total_size = np.prod([v.size for v in self.original_lt.axes.values()])
|
||||
test_lt = core.LabeledTensor(
|
||||
tf.cast(self.original_lt, tf.float32) / total_size,
|
||||
self.original_lt.axes)
|
||||
self.test_lt_1 = test_lt
|
||||
self.test_lt_2 = 1.0 - test_lt
|
||||
self.test_lt_1_broadcast = self.test_lt_1.tensor
|
||||
self.test_lt_2_broadcast = self.test_lt_2.tensor
|
||||
self.broadcast_axes = self.test_lt_1.axes
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
178
tensorflow/contrib/labeled_tensor/python/ops/io_ops.py
Normal file
178
tensorflow/contrib/labeled_tensor/python/ops/io_ops.py
Normal file
@ -0,0 +1,178 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Input parsing code for LabeledTensors."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from six import string_types
|
||||
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import core
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
|
||||
|
||||
class FixedLenFeature(object):
|
||||
"""Configuration for parsing a fixed-length input feature.
|
||||
|
||||
Fields:
|
||||
axes: A list of Axis objects or tuples (axis_name, axis_value),
|
||||
where `axis_name` is a string and `axis_value` is None (unknown size), an
|
||||
integer or a list of tick labels.
|
||||
dtype: Data type of input.
|
||||
default_value: Value to be used if an example is missing this feature. It
|
||||
must be compatible with `dtype`.
|
||||
"""
|
||||
|
||||
def __init__(self, axes, dtype, default_value=None):
|
||||
self._axes = [core.as_axis(a) for a in axes]
|
||||
self._dtype = dtype
|
||||
self._default_value = default_value
|
||||
|
||||
@property
|
||||
def axes(self):
|
||||
return self._axes
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def default_value(self):
|
||||
return self._default_value
|
||||
|
||||
|
||||
@tc.returns(tc.Dict(string_types, parsing_ops.FixedLenFeature))
|
||||
@tc.accepts(tc.Mapping(string_types, FixedLenFeature))
|
||||
def _labeled_to_unlabeled_features(features):
|
||||
"""Convert a dict of lt.FixedLenFeature into a dict of tf.FixedLenFeature."""
|
||||
unlabeled_features = {}
|
||||
for name, labeled_feature in features.items():
|
||||
shape = [ax.size for ax in labeled_feature.axes]
|
||||
if any(size is None for size in shape):
|
||||
# This should be caught on the TensorFlow side, but it isn't yet:
|
||||
# https://github.com/tensorflow/tensorflow/issues/2874
|
||||
raise ValueError('axes with unknown size are not supported')
|
||||
dtype = labeled_feature.dtype
|
||||
default_value = labeled_feature.default_value
|
||||
unlabeled_features[name] = parsing_ops.FixedLenFeature(
|
||||
shape, dtype, default_value)
|
||||
return unlabeled_features
|
||||
|
||||
|
||||
@tc.returns(tc.Dict(string_types, core.LabeledTensor))
|
||||
@tc.accepts(core.LabeledTensorLike, tc.Mapping(string_types, FixedLenFeature),
|
||||
tc.Optional(string_types), object)
|
||||
def parse_example(serialized, features, name=None, example_names=None):
|
||||
"""Parse `Example` protos into a `dict` of labeled tensors.
|
||||
|
||||
See tf.parse_example.
|
||||
|
||||
Args:
|
||||
serialized: A 1-D LabeledTensor of strings, a batch of binary serialized
|
||||
`Example` protos.
|
||||
features: A `dict` mapping feature keys to `labeled_tensor.FixedLenFeature`
|
||||
values.
|
||||
name: A name for this operation (optional).
|
||||
example_names: A vector (1-D Tensor) of strings (optional), the names of
|
||||
the serialized protos in the batch.
|
||||
|
||||
Returns:
|
||||
A `dict` mapping feature keys to `LabeledTensor` values. The single axis
|
||||
from `serialized` will be prepended to the axes provided by each feature.
|
||||
|
||||
Raises:
|
||||
ValueError: if any feature is invalid.
|
||||
"""
|
||||
serialized = core.convert_to_labeled_tensor(serialized)
|
||||
unlabeled_features = _labeled_to_unlabeled_features(features)
|
||||
|
||||
unlabeled_parsed = parsing_ops.parse_example(
|
||||
serialized.tensor, unlabeled_features, name, example_names)
|
||||
|
||||
parsed = {}
|
||||
for name, parsed_feature in unlabeled_parsed.items():
|
||||
axes = list(serialized.axes.values()) + features[name].axes
|
||||
parsed[name] = core.LabeledTensor(parsed_feature, axes)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
@tc.returns(tc.Dict(string_types, core.LabeledTensor))
|
||||
@tc.accepts(core.LabeledTensorLike, tc.Mapping(string_types, FixedLenFeature),
|
||||
tc.Optional(string_types), object)
|
||||
def parse_single_example(serialized, features, name=None, example_names=None):
|
||||
"""Parses a single `Example` proto.
|
||||
|
||||
See tf.parse_single_example.
|
||||
|
||||
Args:
|
||||
serialized: A scalar string Tensor or LabeledTensor, a single serialized
|
||||
Example.
|
||||
features: A `dict` mapping feature keys to `labeled_tensor.FixedLenFeature`
|
||||
values.
|
||||
name: A name for this operation (optional).
|
||||
example_names: (Optional) A scalar string Tensor, the associated name.
|
||||
|
||||
Returns:
|
||||
A `dict` mapping feature keys to `LabeledTensor` values.
|
||||
|
||||
Raises:
|
||||
ValueError: if any feature is invalid.
|
||||
"""
|
||||
serialized = core.convert_to_labeled_tensor(serialized)
|
||||
unlabeled_features = _labeled_to_unlabeled_features(features)
|
||||
|
||||
unlabeled_parsed = parsing_ops.parse_single_example(
|
||||
serialized.tensor, unlabeled_features, name, example_names)
|
||||
|
||||
parsed = {}
|
||||
for name, parsed_feature in unlabeled_parsed.items():
|
||||
parsed[name] = core.LabeledTensor(parsed_feature, features[name].axes)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
@tc.returns(core.LabeledTensor)
|
||||
@tc.accepts(dtypes.DType, tc.Collection(tc.Union(string_types, core.AxisLike)),
|
||||
tc.Optional(string_types))
|
||||
def placeholder(dtype, axes, name=None):
|
||||
"""Create a placeholder for a labeled tensor.
|
||||
|
||||
For example:
|
||||
|
||||
lt.placeholder(tf.float32, ['batch', ('channel', ['r', 'g', 'b'])])
|
||||
|
||||
See tf.placeholder for more details.
|
||||
|
||||
Args:
|
||||
dtype: The type of elements in the tensor to be fed.
|
||||
axes: sequence of strings (denoting axes of unknown size) and/or objects
|
||||
convertable to lt.Axis to label the result.
|
||||
name: Optional op name.
|
||||
|
||||
Returns:
|
||||
Placeholder labeled tensor.
|
||||
"""
|
||||
with ops.name_scope(name, 'lt_placeholder', []) as scope:
|
||||
axes = core.Axes([(axis, None) if isinstance(axis, string_types) else axis
|
||||
for axis in axes])
|
||||
shape = [axis.size for axis in axes.values()]
|
||||
tensor = array_ops.placeholder(dtype, shape, name=scope)
|
||||
return core.LabeledTensor(tensor, axes)
|
106
tensorflow/contrib/labeled_tensor/python/ops/io_ops_test.py
Normal file
106
tensorflow/contrib/labeled_tensor/python/ops/io_ops_test.py
Normal file
@ -0,0 +1,106 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import core
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import io_ops
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import test_util
|
||||
|
||||
|
||||
class ParseBase(test_util.Base):
|
||||
|
||||
def setUp(self):
|
||||
super(ParseBase, self).setUp()
|
||||
examples = [
|
||||
tf.train.Example(features=tf.train.Features(feature={
|
||||
'a': tf.train.Feature(
|
||||
int64_list=tf.train.Int64List(value=[1])),
|
||||
'b': tf.train.Feature(
|
||||
int64_list=tf.train.Int64List(value=[2, 3, 4])),
|
||||
})),
|
||||
tf.train.Example(features=tf.train.Features(feature={
|
||||
'a': tf.train.Feature(
|
||||
int64_list=tf.train.Int64List(value=[5])),
|
||||
'b': tf.train.Feature(
|
||||
int64_list=tf.train.Int64List(value=[6, 7, 8])),
|
||||
})),
|
||||
]
|
||||
self.serialized = core.LabeledTensor(
|
||||
tf.constant([ex.SerializeToString() for ex in examples]), ['batch'])
|
||||
self.features = {'a': io_ops.FixedLenFeature([], tf.int64),
|
||||
'b': io_ops.FixedLenFeature([('x', 3)], tf.int64)}
|
||||
|
||||
|
||||
class TestParseExample(ParseBase):
|
||||
|
||||
def test(self):
|
||||
expected_a = core.LabeledTensor(tf.constant([1, 5]), ['batch'])
|
||||
expected_b = core.LabeledTensor(tf.constant([[2, 3, 4], [6, 7, 8]]),
|
||||
['batch', 'x'])
|
||||
parsed = io_ops.parse_example(self.serialized, self.features)
|
||||
self.assertLabeledTensorsEqual(expected_a, parsed['a'])
|
||||
self.assertLabeledTensorsEqual(expected_b, parsed['b'])
|
||||
|
||||
def test_placeholder(self):
|
||||
serialized = core.LabeledTensor(tf.placeholder(tf.string, [None]),
|
||||
['batch'])
|
||||
# should not raise
|
||||
io_ops.parse_example(serialized, self.features)
|
||||
|
||||
|
||||
class TestParseSingleExample(ParseBase):
|
||||
|
||||
def test(self):
|
||||
expected_a = core.LabeledTensor(tf.constant(1), [])
|
||||
expected_b = core.LabeledTensor(tf.constant([2, 3, 4]), ['x'])
|
||||
parsed = io_ops.parse_single_example(self.serialized[0], self.features)
|
||||
self.assertLabeledTensorsEqual(expected_a, parsed['a'])
|
||||
self.assertLabeledTensorsEqual(expected_b, parsed['b'])
|
||||
|
||||
def test_unknown_size(self):
|
||||
features = {'a': io_ops.FixedLenFeature([('x', None)], tf.int64)}
|
||||
serialized = tf.placeholder(tf.string, [])
|
||||
with self.assertRaisesRegexp(ValueError, 'unknown size'):
|
||||
io_ops.parse_single_example(serialized, features)
|
||||
|
||||
|
||||
class PlaceholderTest(test_util.Base):
|
||||
|
||||
def test_name(self):
|
||||
placeholder_lt = io_ops.placeholder(tf.float32, [])
|
||||
self.assertIn('lt_placeholder', placeholder_lt.name)
|
||||
|
||||
def test(self):
|
||||
placeholder_lt = io_ops.placeholder(tf.float32,
|
||||
['batch', ('x', ['a', 'b'])])
|
||||
self.assertEqual(placeholder_lt.dtype, tf.float32)
|
||||
self.assertEqual(placeholder_lt.axes,
|
||||
core.Axes([('batch', None), ('x', ['a', 'b'])]))
|
||||
|
||||
def test_feed(self):
|
||||
sess = tf.Session()
|
||||
placeholder_lt = io_ops.placeholder(tf.float32, [])
|
||||
two_times = 2.0 * placeholder_lt
|
||||
result = sess.run(two_times, {placeholder_lt.tensor: 1})
|
||||
self.assertEqual(result, 2.0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
42
tensorflow/contrib/labeled_tensor/python/ops/nn.py
Normal file
42
tensorflow/contrib/labeled_tensor/python/ops/nn.py
Normal file
@ -0,0 +1,42 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Neural network ops for LabeledTensors."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import core
|
||||
from tensorflow.python.ops import nn
|
||||
|
||||
|
||||
relu = core.define_unary_op('relu', nn.relu)
|
||||
relu6 = core.define_unary_op('relu6', nn.relu6)
|
||||
crelu = core.define_unary_op('crelu', nn.crelu)
|
||||
elu = core.define_unary_op('elu', nn.elu)
|
||||
softplus = core.define_unary_op('softplus', nn.softplus)
|
||||
|
||||
l2_loss = core.define_unary_op('l2_loss', nn.l2_loss)
|
||||
sigmoid_cross_entropy_with_logits = core.define_binary_op(
|
||||
'sigmoid_cross_entropy_with_logits',
|
||||
nn.sigmoid_cross_entropy_with_logits)
|
||||
softmax = core.define_unary_op('softmax', nn.softmax)
|
||||
log_softmax = core.define_unary_op('log_softmax', nn.log_softmax)
|
||||
softmax_cross_entropy_with_logits = core.define_binary_op(
|
||||
'softmax_cross_entropy_with_logits',
|
||||
nn.softmax_cross_entropy_with_logits)
|
||||
sparse_softmax_cross_entropy_with_logits = core.define_binary_op(
|
||||
'sparse_softmax_cross_entropy_with_logits',
|
||||
nn.sparse_softmax_cross_entropy_with_logits)
|
70
tensorflow/contrib/labeled_tensor/python/ops/nn_test.py
Normal file
70
tensorflow/contrib/labeled_tensor/python/ops/nn_test.py
Normal file
@ -0,0 +1,70 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import core
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import nn
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import test_util
|
||||
|
||||
|
||||
class NNTests(test_util.Base):
|
||||
|
||||
def setUp(self):
|
||||
super(NNTests, self).setUp()
|
||||
self.axes = ['x']
|
||||
self.original_lt = core.LabeledTensor([0.0, 0.5, 1.0], self.axes)
|
||||
self.other_lt = 1 - self.original_lt
|
||||
|
||||
def test_unary_ops(self):
|
||||
ops = [
|
||||
('relu', tf.nn.relu, nn.relu),
|
||||
('relu6', tf.nn.relu6, nn.relu6),
|
||||
('crelu', tf.nn.crelu, nn.crelu),
|
||||
('elu', tf.nn.elu, nn.elu),
|
||||
('softplus', tf.nn.softplus, nn.softplus),
|
||||
('l2_loss', tf.nn.l2_loss, nn.l2_loss),
|
||||
('softmax', tf.nn.softmax, nn.softmax),
|
||||
('log_softmax', tf.nn.log_softmax, nn.log_softmax),
|
||||
]
|
||||
for op_name, tf_op, lt_op in ops:
|
||||
golden_tensor = tf_op(self.original_lt.tensor)
|
||||
golden_lt = core.LabeledTensor(golden_tensor, self.axes)
|
||||
actual_lt = lt_op(self.original_lt)
|
||||
self.assertIn(op_name, actual_lt.name)
|
||||
self.assertLabeledTensorsEqual(golden_lt, actual_lt)
|
||||
|
||||
def test_binary_ops(self):
|
||||
ops = [
|
||||
('sigmoid_cross_entropy_with_logits',
|
||||
tf.nn.sigmoid_cross_entropy_with_logits,
|
||||
nn.sigmoid_cross_entropy_with_logits),
|
||||
('softmax_cross_entropy_with_logits',
|
||||
tf.nn.softmax_cross_entropy_with_logits,
|
||||
nn.softmax_cross_entropy_with_logits),
|
||||
('sparse_softmax_cross_entropy_with_logits',
|
||||
tf.nn.sparse_softmax_cross_entropy_with_logits,
|
||||
nn.sparse_softmax_cross_entropy_with_logits),
|
||||
]
|
||||
for op_name, tf_op, lt_op in ops:
|
||||
golden_tensor = tf_op(self.original_lt.tensor, self.other_lt.tensor)
|
||||
golden_lt = core.LabeledTensor(golden_tensor, self.axes)
|
||||
actual_lt = lt_op(self.original_lt, self.other_lt)
|
||||
self.assertIn(op_name, actual_lt.name)
|
||||
self.assertLabeledTensorsEqual(golden_lt, actual_lt)
|
1207
tensorflow/contrib/labeled_tensor/python/ops/ops.py
Normal file
1207
tensorflow/contrib/labeled_tensor/python/ops/ops.py
Normal file
File diff suppressed because it is too large
Load Diff
918
tensorflow/contrib/labeled_tensor/python/ops/ops_test.py
Normal file
918
tensorflow/contrib/labeled_tensor/python/ops/ops_test.py
Normal file
@ -0,0 +1,918 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from six.moves import range # pylint: disable=redefined-builtin
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import core
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import ops
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import test_util
|
||||
|
||||
|
||||
class Base(test_util.Base):
|
||||
|
||||
def setUp(self):
|
||||
super(Base, self).setUp()
|
||||
|
||||
self.x_size = 7
|
||||
self.channel_size = 3
|
||||
self.z_size = 4
|
||||
self.probs_size = 11
|
||||
|
||||
tensor = tf.range(0, self.x_size * self.channel_size * self.z_size *
|
||||
self.probs_size)
|
||||
tensor = tf.reshape(tensor, [self.x_size, self.channel_size, self.z_size,
|
||||
self.probs_size])
|
||||
a0 = ('x', range(self.x_size))
|
||||
a1 = ('channel', ['red', 'green', 'blue'])
|
||||
a2 = 'z'
|
||||
a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size))
|
||||
|
||||
self.tensor = tensor
|
||||
self.a0 = a0
|
||||
self.a1 = a1
|
||||
self.a2 = a2
|
||||
self.a2_resolved = ('z', self.z_size)
|
||||
self.a3 = a3
|
||||
self.original_lt = core.LabeledTensor(tensor, [a0, a1, a2, a3])
|
||||
|
||||
self.x_probs_lt = core.slice_function(self.original_lt, {'z': 0})
|
||||
self.x_probs_lt = ops.select(self.x_probs_lt, {'channel': 'red'})
|
||||
self.channel_probs_lt = core.slice_function(self.original_lt, {'x': 3,
|
||||
'z': 0})
|
||||
|
||||
|
||||
class SelectTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
select_lt = ops.select(self.original_lt, {'channel': 'green'})
|
||||
self.assertIn('lt_select', select_lt.name)
|
||||
|
||||
def test_scalar(self):
|
||||
select_lt = ops.select(self.original_lt, {'channel': 'green'})
|
||||
golden_lt = core.LabeledTensor(self.tensor[:, 1, :, :], [self.a0, self.a2,
|
||||
self.a3])
|
||||
self.assertLabeledTensorsEqual(select_lt, golden_lt)
|
||||
|
||||
def test_slice(self):
|
||||
select_lt = ops.select(self.original_lt, {'channel': slice('red', 'green')})
|
||||
a1_sliced = ('channel', ['red', 'green'])
|
||||
golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :],
|
||||
[self.a0, a1_sliced, self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(select_lt, golden_lt)
|
||||
|
||||
def test_slices(self):
|
||||
select_lt = ops.select(self.original_lt, {'x': slice(1, 4),
|
||||
'channel': slice('green', None)})
|
||||
|
||||
a0_sliced = ('x', range(1, 5))
|
||||
a1_sliced = ('channel', ['green', 'blue'])
|
||||
golden_lt = core.LabeledTensor(self.tensor[1:5, 1:, :, :],
|
||||
[a0_sliced, a1_sliced, self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(select_lt, golden_lt)
|
||||
|
||||
def test_list(self):
|
||||
select_lt = ops.select(self.original_lt, {'channel': ['red', 'green']})
|
||||
a1_sliced = ('channel', ['red', 'green'])
|
||||
golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :],
|
||||
[self.a0, a1_sliced, self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(select_lt, golden_lt)
|
||||
|
||||
def test_list_one_item(self):
|
||||
select_lt = ops.select(self.original_lt, {'channel': ['red']})
|
||||
a1_sliced = ('channel', ['red'])
|
||||
golden_lt = core.LabeledTensor(self.tensor[:, :1, :, :],
|
||||
[self.a0, a1_sliced, self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(select_lt, golden_lt)
|
||||
|
||||
def test_list_zero_items(self):
|
||||
select_lt = ops.select(self.original_lt, {'channel': []})
|
||||
golden_lt = core.LabeledTensor(self.tensor[:, :0, :, :],
|
||||
[self.a0, 'channel', self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(select_lt, golden_lt)
|
||||
|
||||
def test_scalars(self):
|
||||
select_lt = ops.select(self.original_lt, {'x': 1, 'channel': 'green'})
|
||||
golden_lt = core.LabeledTensor(self.tensor[1, 1, :, :],
|
||||
[self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(select_lt, golden_lt)
|
||||
|
||||
def test_invalid_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
ops.select(self.original_lt, {'foo': 1})
|
||||
with self.assertRaises(ValueError):
|
||||
ops.select(self.original_lt, {'z': 1})
|
||||
with self.assertRaises(KeyError):
|
||||
ops.select(self.original_lt, {'channel': 'purple'})
|
||||
with self.assertRaises(KeyError):
|
||||
ops.select(self.original_lt, {'channel': ['red', 'purple']})
|
||||
with self.assertRaises(NotImplementedError):
|
||||
ops.select(self.original_lt, {'channel': ['red'], 'x': [1]})
|
||||
with self.assertRaises(NotImplementedError):
|
||||
ops.select(self.original_lt, {'channel': ['red'], 'x': 1})
|
||||
with self.assertRaises(NotImplementedError):
|
||||
ops.select(self.original_lt, {'channel': slice('red', 'green', 2)})
|
||||
|
||||
|
||||
class ConcatTest(Base):
|
||||
|
||||
def setUp(self):
|
||||
super(ConcatTest, self).setUp()
|
||||
|
||||
self.red_lt = ops.select(self.original_lt, {'channel': ['red']})
|
||||
self.green_lt = ops.select(self.original_lt, {'channel': ['green']})
|
||||
self.blue_lt = ops.select(self.original_lt, {'channel': ['blue']})
|
||||
|
||||
def test_name(self):
|
||||
concat_lt = ops.concat([self.red_lt, self.blue_lt], 'channel')
|
||||
self.assertIn('lt_concat', concat_lt.name)
|
||||
|
||||
def test(self):
|
||||
concat_lt = ops.concat([self.red_lt, self.green_lt], 'channel')
|
||||
golden_lt = ops.select(self.original_lt, {'channel': ['red', 'green']})
|
||||
|
||||
self.assertLabeledTensorsEqual(concat_lt, golden_lt)
|
||||
|
||||
def test_transposed(self):
|
||||
green_transposed = core.transpose(self.green_lt,
|
||||
['probs', 'channel', 'z', 'x'])
|
||||
with self.assertRaises(ValueError):
|
||||
ops.concat([self.red_lt, green_transposed], 'channel')
|
||||
|
||||
def test_invalid_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
ops.concat([], 'channel')
|
||||
with self.assertRaises(ValueError):
|
||||
ops.concat([self.red_lt, self.red_lt], 'channel')
|
||||
with self.assertRaises(ValueError):
|
||||
ops.concat([self.red_lt, self.red_lt], 'foo')
|
||||
|
||||
|
||||
class PackTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
pack_lt = ops.pack([self.original_lt, self.original_lt], 'batch')
|
||||
self.assertIn('lt_pack', pack_lt.name)
|
||||
|
||||
def test(self):
|
||||
pack_lt = ops.pack([self.original_lt, self.original_lt], 'batch')
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.stack([self.original_lt.tensor, self.original_lt.tensor]),
|
||||
['batch', self.a0, self.a1, self.a2, self.a3])
|
||||
|
||||
self.assertLabeledTensorsEqual(pack_lt, golden_lt)
|
||||
|
||||
def test_axis(self):
|
||||
pack_lt = ops.pack([self.original_lt, self.original_lt],
|
||||
new_axis='batch',
|
||||
axis_position=4)
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.stack(
|
||||
[self.original_lt.tensor, self.original_lt.tensor], axis=4),
|
||||
[self.a0, self.a1, self.a2, self.a3, 'batch'])
|
||||
|
||||
self.assertLabeledTensorsEqual(pack_lt, golden_lt)
|
||||
|
||||
def test_invalid_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
ops.pack([self.original_lt, self.original_lt], 'channel')
|
||||
|
||||
|
||||
class UnpackTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
unpack_lts = ops.unpack(self.original_lt)
|
||||
for t in unpack_lts:
|
||||
self.assertIn('lt_unpack', t.name)
|
||||
|
||||
def test(self):
|
||||
unpack_lt = ops.unpack(self.original_lt)[0]
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.unstack(self.original_lt.tensor)[0], [self.a1, self.a2, self.a3])
|
||||
|
||||
self.assertLabeledTensorsEqual(unpack_lt, golden_lt)
|
||||
|
||||
def test_axis(self):
|
||||
unpack_lt = ops.unpack(self.original_lt, axis_name='z')[0]
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.unstack(
|
||||
self.original_lt.tensor, axis=2)[0], [self.a0, self.a1, self.a3])
|
||||
|
||||
self.assertLabeledTensorsEqual(unpack_lt, golden_lt)
|
||||
|
||||
def test_invalid_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
ops.unpack(self.original_lt, axis_name='not_found')
|
||||
|
||||
|
||||
class ReshapeTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
reshape_lt = ops.reshape(self.original_lt, ['channel'], ['foo'])
|
||||
self.assertIn('lt_reshape', reshape_lt.name)
|
||||
|
||||
def test_identity(self):
|
||||
reshape_lt = ops.reshape(self.original_lt, self.original_lt.axes.keys(),
|
||||
self.original_lt.axes.values())
|
||||
self.assertLabeledTensorsEqual(reshape_lt, self.original_lt)
|
||||
|
||||
def test_known_size(self):
|
||||
new_dim_size = self.channel_size * self.z_size * self.probs_size
|
||||
reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
|
||||
[('new_dim', new_dim_size)])
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.reshape(self.original_lt.tensor, [self.x_size, -1]),
|
||||
[self.original_lt.axes['x'], 'new_dim'])
|
||||
self.assertLabeledTensorsEqual(reshape_lt, golden_lt)
|
||||
|
||||
def test_unknown_size(self):
|
||||
reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
|
||||
['new_dim'])
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.reshape(self.original_lt.tensor, [self.x_size, -1]),
|
||||
[self.original_lt.axes['x'], 'new_dim'])
|
||||
self.assertLabeledTensorsEqual(reshape_lt, golden_lt)
|
||||
|
||||
def test_unknown_dimension(self):
|
||||
orig_lt = core.LabeledTensor(tf.placeholder(tf.float32, [None]), ['x'])
|
||||
reshape_lt = ops.reshape(orig_lt, ['x'], ['y', ('z', 1)])
|
||||
self.assertEqual(reshape_lt.axes, core.Axes([('y', None), ('z', 1)]))
|
||||
with self.test_session() as sess:
|
||||
result = sess.run(reshape_lt, feed_dict={orig_lt.tensor: [1, 2]})
|
||||
np.testing.assert_array_equal(result, [[1], [2]])
|
||||
|
||||
def test_with_labels(self):
|
||||
new_dim_size = self.channel_size * self.z_size * self.probs_size
|
||||
reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
|
||||
[('new_dim', range(new_dim_size))])
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.reshape(self.original_lt.tensor, [self.x_size, -1]),
|
||||
[self.original_lt.axes['x'], ('new_dim', range(new_dim_size))])
|
||||
self.assertLabeledTensorsEqual(reshape_lt, golden_lt)
|
||||
|
||||
def test_invalid_input(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'not contained in the set'):
|
||||
ops.reshape(self.original_lt, ['foo'], ['bar'])
|
||||
with self.assertRaisesRegexp(core.AxisOrderError,
|
||||
'not a slice of axis names'):
|
||||
ops.reshape(self.original_lt, ['probs', 'z'], ['bar'])
|
||||
with self.assertRaisesRegexp(ValueError, 'at most one axis in new_axes'):
|
||||
ops.reshape(self.original_lt, ['probs'], ['foo', 'bar'])
|
||||
|
||||
|
||||
class RenameAxisTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'foo')
|
||||
self.assertIn('lt_rename_axis', rename_axis_lt.name)
|
||||
|
||||
def test_identity(self):
|
||||
rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'channel')
|
||||
self.assertLabeledTensorsEqual(rename_axis_lt, self.original_lt)
|
||||
|
||||
def test_new_name(self):
|
||||
rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'foo')
|
||||
expected_axes = [(name if name != 'channel' else 'foo', axis.value)
|
||||
for name, axis in self.original_lt.axes.items()]
|
||||
expected_lt = core.LabeledTensor(self.original_lt.tensor, expected_axes)
|
||||
self.assertLabeledTensorsEqual(rename_axis_lt, expected_lt)
|
||||
|
||||
def test_invalid_input(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'not contained in the set'):
|
||||
ops.rename_axis(self.original_lt, 'foo', 'bar')
|
||||
|
||||
|
||||
class BatchTest(Base):
|
||||
|
||||
def setUp(self):
|
||||
super(BatchTest, self).setUp()
|
||||
|
||||
tensors = []
|
||||
for i in range(10):
|
||||
offset_lt = core.LabeledTensor(tf.constant(i), [])
|
||||
tensors.append(core.add(self.original_lt, offset_lt))
|
||||
self.pack_lt = ops.pack(tensors, 'batch')
|
||||
|
||||
def test_name(self):
|
||||
batch_ops = ops.batch([self.pack_lt, self.pack_lt],
|
||||
batch_size=2,
|
||||
enqueue_many=True)
|
||||
for bo in batch_ops:
|
||||
self.assertIn('lt_batch', bo.name)
|
||||
|
||||
def test_enqueue_many(self):
|
||||
[batch_2_op] = ops.batch([self.pack_lt], batch_size=2, enqueue_many=True)
|
||||
self.assertEqual(len(batch_2_op.axes['batch']), 2)
|
||||
|
||||
[batch_10_op] = ops.batch([batch_2_op], batch_size=10, enqueue_many=True)
|
||||
|
||||
self.assertLabeledTensorsEqual(self.pack_lt, batch_10_op)
|
||||
|
||||
def test_no_enqueue_many(self):
|
||||
[batch_2_op] = ops.batch([self.original_lt], batch_size=2)
|
||||
self.assertEqual(len(batch_2_op.axes['batch']), 2)
|
||||
|
||||
[batch_10_op] = ops.batch([batch_2_op], batch_size=10, enqueue_many=True)
|
||||
|
||||
self.assertLabeledTensorsEqual(
|
||||
ops.pack(10 * [self.original_lt], 'batch'), batch_10_op)
|
||||
|
||||
def test_invalid_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
ops.batch([self.original_lt], 3, enqueue_many=True)
|
||||
|
||||
def test_allow_smaller_final_batch(self):
|
||||
[batch_2_op] = ops.batch([self.original_lt], batch_size=2,
|
||||
allow_smaller_final_batch=True)
|
||||
self.assertEqual(batch_2_op.axes['batch'].size, None)
|
||||
|
||||
|
||||
class ShuffleBatchTest(Base):
|
||||
|
||||
def setUp(self):
|
||||
super(ShuffleBatchTest, self).setUp()
|
||||
|
||||
tensors = []
|
||||
for i in range(10):
|
||||
offset_lt = core.LabeledTensor(tf.constant(i), [])
|
||||
tensors.append(core.add(self.original_lt, offset_lt))
|
||||
self.pack_lt = ops.pack(tensors, 'batch')
|
||||
|
||||
def test_name(self):
|
||||
batch_lts = ops.shuffle_batch([self.pack_lt, self.pack_lt],
|
||||
batch_size=2,
|
||||
enqueue_many=True)
|
||||
for blt in batch_lts:
|
||||
self.assertIn('lt_shuffle_batch', blt.name)
|
||||
|
||||
def test_enqueue_many(self):
|
||||
[batch_2_lt] = ops.shuffle_batch([self.pack_lt],
|
||||
batch_size=2,
|
||||
enqueue_many=True,
|
||||
min_after_dequeue=8,
|
||||
seed=0)
|
||||
self.assertEqual(len(batch_2_lt.axes['batch']), 2)
|
||||
|
||||
[batch_10_lt] = ops.batch([batch_2_lt], batch_size=10, enqueue_many=True)
|
||||
|
||||
self.assertEqual(batch_10_lt.axes, self.pack_lt.axes)
|
||||
[batch_10, pack] = self.eval([batch_10_lt.tensor, self.pack_lt.tensor])
|
||||
self.assertFalse((batch_10 == pack).all())
|
||||
|
||||
def test_allow_smaller_final_batch(self):
|
||||
[batch_2_op] = ops.shuffle_batch([self.original_lt], batch_size=2,
|
||||
allow_smaller_final_batch=True)
|
||||
self.assertEqual(batch_2_op.axes['batch'].size, None)
|
||||
|
||||
|
||||
class RandomCropTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
crop_lt = ops.random_crop(self.original_lt, {'probs': 3})
|
||||
self.assertIn('lt_random_crop', crop_lt.name)
|
||||
|
||||
def test_single(self):
|
||||
crop_lt = ops.random_crop(self.original_lt, {'probs': 3})
|
||||
|
||||
self.assertEqual(
|
||||
core.Axes([self.a0, self.a1, self.a2_resolved, ('probs', 3)]),
|
||||
crop_lt.axes)
|
||||
|
||||
def test_double(self):
|
||||
crop_lt = ops.random_crop(self.original_lt, {'probs': 3, 'channel': 2})
|
||||
|
||||
self.assertEqual(
|
||||
core.Axes([self.a0, ('channel', 2), self.a2_resolved, ('probs', 3)]),
|
||||
crop_lt.axes)
|
||||
|
||||
def test_size1(self):
|
||||
crop_lt = ops.random_crop(self.original_lt, {'probs': 1})
|
||||
|
||||
self.assertEqual(
|
||||
core.Axes([self.a0, self.a1, self.a2_resolved, ('probs', 1)]),
|
||||
crop_lt.axes)
|
||||
|
||||
def test_different_seeds(self):
|
||||
crop_0_lt = ops.random_crop(self.original_lt, {'probs': 3,
|
||||
'channel': 2},
|
||||
seed=0)
|
||||
crop_1_lt = ops.random_crop(self.original_lt, {'probs': 3,
|
||||
'channel': 2},
|
||||
seed=1)
|
||||
|
||||
self.assertEqual(crop_0_lt.axes, crop_1_lt.axes)
|
||||
[crop_0, crop_1] = self.eval([crop_0_lt.tensor, crop_1_lt.tensor])
|
||||
self.assertFalse((crop_0 == crop_1).all())
|
||||
|
||||
def test_identical_seeds(self):
|
||||
crop_0_lt = ops.random_crop(self.original_lt, {'probs': 3,
|
||||
'channel': 2},
|
||||
seed=0)
|
||||
crop_1_lt = ops.random_crop(self.original_lt, {'probs': 3,
|
||||
'channel': 2},
|
||||
seed=0)
|
||||
|
||||
self.assertLabeledTensorsEqual(crop_0_lt, crop_1_lt)
|
||||
|
||||
def test_crop_idempotent(self):
|
||||
crop_0_lt = ops.random_crop(self.original_lt, {'probs': 3,
|
||||
'channel': 2},
|
||||
seed=0)
|
||||
crop_1_lt = ops.random_crop(crop_0_lt, {'probs': 3, 'channel': 2}, seed=1)
|
||||
|
||||
self.assertLabeledTensorsEqual(crop_0_lt, crop_1_lt)
|
||||
|
||||
def test_invalid_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
ops.random_crop(self.original_lt, {'foobar': 2})
|
||||
|
||||
|
||||
class MapFnTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
map_lt = ops.map_fn(core.identity, self.original_lt)
|
||||
self.assertIn('lt_map_fn', map_lt.name)
|
||||
|
||||
def test_identity(self):
|
||||
map_lt = ops.map_fn(core.identity, self.original_lt)
|
||||
self.assertLabeledTensorsEqual(map_lt, self.original_lt)
|
||||
|
||||
def test_callable_object(self):
|
||||
|
||||
class Identity(object):
|
||||
|
||||
def __call__(self, other):
|
||||
return other
|
||||
|
||||
map_lt = ops.map_fn(Identity(), self.original_lt)
|
||||
self.assertLabeledTensorsEqual(map_lt, self.original_lt)
|
||||
|
||||
def test_slice(self):
|
||||
map_lt = ops.map_fn(lambda t: core.slice_function(t, {'channel': 1}),
|
||||
self.original_lt)
|
||||
slice_lt = core.slice_function(self.original_lt, {'channel': 1})
|
||||
self.assertLabeledTensorsEqual(map_lt, slice_lt)
|
||||
|
||||
|
||||
class SqueezeTest(Base):
|
||||
|
||||
def setUp(self):
|
||||
super(SqueezeTest, self).setUp()
|
||||
|
||||
self.squeezable_lt = core.slice_function(self.original_lt,
|
||||
{'channel': slice(0, 1),
|
||||
'probs': slice(0, 1)})
|
||||
|
||||
def test_name(self):
|
||||
squeeze_lt = ops.squeeze(self.squeezable_lt)
|
||||
self.assertIn('lt_squeeze', squeeze_lt.name)
|
||||
|
||||
def test_none(self):
|
||||
none_lt = ops.squeeze(self.squeezable_lt, None)
|
||||
axes_lt = ops.squeeze(self.squeezable_lt, ['channel', 'probs'])
|
||||
self.assertLabeledTensorsEqual(none_lt, axes_lt)
|
||||
|
||||
def test(self):
|
||||
squeeze_lt = ops.squeeze(self.squeezable_lt, ['probs'])
|
||||
golden_lt = core.slice_function(self.squeezable_lt, {'probs': 0})
|
||||
self.assertLabeledTensorsEqual(squeeze_lt, golden_lt)
|
||||
|
||||
def test_invalid_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
ops.squeeze(self.original_lt, ['channel'])
|
||||
with self.assertRaises(ValueError):
|
||||
ops.squeeze(self.squeezable_lt, ['foo'])
|
||||
|
||||
|
||||
class MatMulTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
x_lt = core.LabeledTensor(tf.ones((3,)), ['x'])
|
||||
matmul_lt = ops.matmul(x_lt, x_lt)
|
||||
self.assertIn('lt_matmul', matmul_lt.name)
|
||||
|
||||
def test_vector_vector(self):
|
||||
x_lt = core.LabeledTensor(tf.range(3), ['x'])
|
||||
matmul_lt = ops.matmul(x_lt, x_lt)
|
||||
golden_lt = core.convert_to_labeled_tensor(5)
|
||||
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
|
||||
|
||||
def test_matrix_vector(self):
|
||||
xy_lt = core.LabeledTensor(tf.reshape(tf.range(6), (2, 3)), ['x', 'y'])
|
||||
y_lt = core.LabeledTensor(tf.range(3), ['y'])
|
||||
|
||||
matmul_lt = ops.matmul(xy_lt, y_lt)
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.matmul(xy_lt.tensor, tf.reshape(y_lt.tensor, (-1, 1)))[:, 0], ['x'])
|
||||
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
|
||||
|
||||
matmul_lt = ops.matmul(y_lt, xy_lt)
|
||||
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
|
||||
|
||||
def test_matrix_matrix(self):
|
||||
xy_lt = core.LabeledTensor(tf.reshape(tf.range(6), (2, 3)), ['x', 'y'])
|
||||
yz_lt = core.LabeledTensor(tf.reshape(tf.range(12), (3, 4)), ['y', 'z'])
|
||||
|
||||
matmul_lt = ops.matmul(xy_lt, yz_lt)
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z'])
|
||||
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
|
||||
|
||||
transpose = lambda x: core.transpose(x, list(x.axes.keys())[::-1])
|
||||
|
||||
matmul_lt = ops.matmul(xy_lt, transpose(yz_lt))
|
||||
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
|
||||
|
||||
matmul_lt = ops.matmul(transpose(xy_lt), yz_lt)
|
||||
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
|
||||
|
||||
matmul_lt = ops.matmul(transpose(xy_lt), transpose(yz_lt))
|
||||
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
|
||||
|
||||
matmul_lt = ops.matmul(yz_lt, xy_lt)
|
||||
self.assertLabeledTensorsEqual(matmul_lt, transpose(golden_lt))
|
||||
|
||||
def test_matrix_matrix_axis_order(self):
|
||||
xy_lt = core.LabeledTensor(tf.reshape(tf.range(6), (2, 3)), ['x', 'y'])
|
||||
yz_lt = core.LabeledTensor(tf.reshape(tf.range(12), (3, 4)), ['y', 'z'])
|
||||
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z'])
|
||||
|
||||
with core.axis_order_scope(['x', 'y', 'z']):
|
||||
|
||||
matmul_lt = ops.matmul(xy_lt, yz_lt)
|
||||
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
|
||||
|
||||
matmul_lt = ops.matmul(yz_lt, xy_lt)
|
||||
self.assertLabeledTensorsEqual(matmul_lt, golden_lt)
|
||||
|
||||
def test_invalid(self):
|
||||
scalar_lt = core.LabeledTensor(tf.ones(()), [])
|
||||
x_lt = core.LabeledTensor(tf.ones((2,)), ['x'])
|
||||
x2_lt = core.LabeledTensor(tf.ones((3,)), ['x'])
|
||||
y_lt = core.LabeledTensor(tf.ones((3,)), ['y'])
|
||||
xy_lt = core.LabeledTensor(tf.ones((2, 3)), ['x', 'y'])
|
||||
xyz_lt = core.LabeledTensor(tf.ones((2, 3, 1)), ['x', 'y', 'z'])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'inputs with at least rank'):
|
||||
ops.matmul(x_lt, scalar_lt)
|
||||
|
||||
with self.assertRaises(NotImplementedError):
|
||||
ops.matmul(x_lt, xyz_lt)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'exactly one axis in common'):
|
||||
ops.matmul(x_lt, y_lt)
|
||||
|
||||
with self.assertRaises(NotImplementedError):
|
||||
ops.matmul(xy_lt, xy_lt)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'does not match'):
|
||||
ops.matmul(x_lt, x2_lt)
|
||||
|
||||
|
||||
class ReduceSumTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
sum_lt = ops.reduce_sum(self.original_lt, {'channel'})
|
||||
self.assertIn('lt_reduce_sum', sum_lt.name)
|
||||
|
||||
def test_drop_axis(self):
|
||||
sum_lt = ops.reduce_sum(self.original_lt, {'channel'})
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.reduce_sum(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(sum_lt, golden_lt)
|
||||
|
||||
def test_drop_scalar_axis(self):
|
||||
sum_lt = ops.reduce_sum(self.original_lt, 'channel')
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.reduce_sum(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(sum_lt, golden_lt)
|
||||
|
||||
def test_keep_axis(self):
|
||||
sum_lt = ops.reduce_sum(self.original_lt, {('channel', 'hihowareyou')})
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.reduce_sum(self.original_lt.tensor,
|
||||
1, keep_dims=True),
|
||||
[self.a0, ('channel', ['hihowareyou']), self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(sum_lt, golden_lt)
|
||||
|
||||
def test_keep_scalar_axis(self):
|
||||
sum_lt = ops.reduce_sum(self.original_lt, ('channel', 'hihowareyou'))
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.reduce_sum(self.original_lt.tensor,
|
||||
1, keep_dims=True),
|
||||
[self.a0, ('channel', ['hihowareyou']), self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(sum_lt, golden_lt)
|
||||
|
||||
def test_scalar(self):
|
||||
scalar_lt = core.LabeledTensor(tf.constant(42), [])
|
||||
reduce_lt = ops.reduce_sum(scalar_lt, [])
|
||||
self.assertLabeledTensorsEqual(reduce_lt, scalar_lt)
|
||||
|
||||
def test_empty_list(self):
|
||||
reduce_lt = ops.reduce_sum(self.original_lt, [])
|
||||
self.assertLabeledTensorsEqual(reduce_lt, self.original_lt)
|
||||
|
||||
def test_none(self):
|
||||
sum_lt = ops.reduce_sum(self.original_lt)
|
||||
golden_lt = core.LabeledTensor(tf.reduce_sum(self.original_lt.tensor), [])
|
||||
self.assertLabeledTensorsEqual(sum_lt, golden_lt)
|
||||
|
||||
def test_function_docstring_and_name(self):
|
||||
self.assertIn('tf.reduce_sum', ops.reduce_sum.__doc__)
|
||||
self.assertEqual('reduce_sum', ops.reduce_sum.__name__)
|
||||
|
||||
|
||||
class ReduceMeanTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
actual_lt = ops.reduce_mean(self.original_lt, {'channel'})
|
||||
self.assertIn('lt_reduce_mean', actual_lt.name)
|
||||
|
||||
def test(self):
|
||||
actual_lt = ops.reduce_mean(self.original_lt, {'channel'})
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.reduce_mean(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(actual_lt, golden_lt)
|
||||
|
||||
|
||||
class ReduceProdTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
result_lt = ops.reduce_prod(self.original_lt, {'channel'})
|
||||
self.assertIn('lt_reduce_prod', result_lt.name)
|
||||
|
||||
def test(self):
|
||||
result_lt = ops.reduce_prod(self.original_lt, {'channel'})
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.reduce_prod(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(result_lt, golden_lt)
|
||||
|
||||
|
||||
class ReduceMinTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
result_lt = ops.reduce_min(self.original_lt, {'channel'})
|
||||
self.assertIn('lt_reduce_min', result_lt.name)
|
||||
|
||||
def test(self):
|
||||
result_lt = ops.reduce_min(self.original_lt, {'channel'})
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.reduce_min(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(result_lt, golden_lt)
|
||||
|
||||
|
||||
class ReduceMaxTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
result_lt = ops.reduce_max(self.original_lt, {'channel'})
|
||||
self.assertIn('lt_reduce_max', result_lt.name)
|
||||
|
||||
def test(self):
|
||||
result_lt = ops.reduce_max(self.original_lt, {'channel'})
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.reduce_max(self.original_lt.tensor, 1), [self.a0, self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(result_lt, golden_lt)
|
||||
|
||||
|
||||
class BaseReduceBoolean(Base):
|
||||
|
||||
def setUp(self):
|
||||
super(BaseReduceBoolean, self).setUp()
|
||||
self.bool_tensor = tf.cast(self.original_lt.tensor > 5, tf.bool)
|
||||
self.bool_lt = core.LabeledTensor(self.bool_tensor, self.original_lt.axes)
|
||||
|
||||
|
||||
class ReduceAllTest(BaseReduceBoolean):
|
||||
|
||||
def test_name(self):
|
||||
result_lt = ops.reduce_all(self.bool_lt, {'channel'})
|
||||
self.assertIn('lt_reduce_all', result_lt.name)
|
||||
|
||||
def test(self):
|
||||
result_lt = ops.reduce_all(self.bool_lt, {'channel'})
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.reduce_all(self.bool_tensor, 1), [self.a0, self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(result_lt, golden_lt)
|
||||
|
||||
|
||||
class ReduceAnyTest(BaseReduceBoolean):
|
||||
|
||||
def test_name(self):
|
||||
result_lt = ops.reduce_any(self.bool_lt, {'channel'})
|
||||
self.assertIn('lt_reduce_any', result_lt.name)
|
||||
|
||||
def test(self):
|
||||
result_lt = ops.reduce_any(self.bool_lt, {'channel'})
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.reduce_any(self.bool_tensor, 1), [self.a0, self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(result_lt, golden_lt)
|
||||
|
||||
|
||||
class TileTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
tile_lt = ops.tile(self.original_lt, {'z': 2})
|
||||
self.assertIn('lt_tile', tile_lt.name)
|
||||
|
||||
def test(self):
|
||||
for multiple in [2, tf.constant(2)]:
|
||||
tile_lt = ops.tile(self.original_lt, {'z': multiple})
|
||||
golden_op = tf.tile(self.original_lt.tensor, [1, 1, multiple, 1])
|
||||
golden_axes = ['z' if axis.name == 'z' else axis
|
||||
for axis in self.original_lt.axes.values()]
|
||||
golden_lt = core.LabeledTensor(golden_op, golden_axes)
|
||||
self.assertLabeledTensorsEqual(tile_lt, golden_lt)
|
||||
|
||||
def test_invalid_input(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'are not contained in the set'):
|
||||
ops.tile(self.original_lt, {'foo': 5})
|
||||
with self.assertRaisesRegexp(ValueError, 'axes with tick labels'):
|
||||
ops.tile(self.original_lt, {'x': 5})
|
||||
|
||||
|
||||
class PadTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
pad_lt = ops.pad(self.original_lt, {'x': (1, 1),
|
||||
'channel': ([], ['alpha'])})
|
||||
self.assertIn('lt_pad', pad_lt.name)
|
||||
|
||||
def test(self):
|
||||
pad_lt = ops.pad(self.original_lt, {'x': (1, 1),
|
||||
'channel': ([], ['alpha'])})
|
||||
|
||||
golden_op = tf.pad(self.original_lt.tensor, [[1, 1], [0, 1], [0, 0],
|
||||
[0, 0]])
|
||||
golden_axes = [('x', self.x_size + 2),
|
||||
('channel', ['red', 'green', 'blue', 'alpha']), self.a2,
|
||||
self.a3]
|
||||
golden_lt = core.LabeledTensor(golden_op, golden_axes)
|
||||
self.assertLabeledTensorsEqual(pad_lt, golden_lt)
|
||||
|
||||
def test_invalid_input(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'are not contained in the set'):
|
||||
ops.pad(self.original_lt, {'foo': (1, 1), 'channel': ([], ['alpha'])})
|
||||
|
||||
|
||||
class ConstantTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
constant_lt = ops.constant(1)
|
||||
self.assertIn('lt_constant', constant_lt.name)
|
||||
|
||||
def test_scalar(self):
|
||||
constant_lt = ops.constant(1)
|
||||
golden_lt = core.LabeledTensor(tf.constant(1), [])
|
||||
self.assertLabeledTensorsEqual(constant_lt, golden_lt)
|
||||
|
||||
def test_infer_shape(self):
|
||||
constant_lt = ops.constant([1, 2], axes=['x'])
|
||||
golden_lt = core.LabeledTensor(tf.constant([1, 2]), ['x'])
|
||||
self.assertLabeledTensorsEqual(constant_lt, golden_lt)
|
||||
|
||||
def test_specify_shape(self):
|
||||
constant_lt = ops.constant(1, axes=[('x', 3)])
|
||||
golden_lt = core.LabeledTensor(tf.constant(1, shape=(3,)), ['x'])
|
||||
self.assertLabeledTensorsEqual(constant_lt, golden_lt)
|
||||
|
||||
def test_existing_axes(self):
|
||||
golden_lt = core.LabeledTensor(tf.constant([1, 2]), ['x'])
|
||||
constant_lt = ops.constant([1, 2], axes=golden_lt.axes)
|
||||
self.assertLabeledTensorsEqual(constant_lt, golden_lt)
|
||||
|
||||
|
||||
class ZerosLikeTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
like_lt = ops.zeros_like(self.original_lt)
|
||||
self.assertIn('lt_zeros_like', like_lt.name)
|
||||
|
||||
def test(self):
|
||||
like_lt = ops.zeros_like(self.original_lt)
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.zeros_like(self.original_lt.tensor), self.original_lt.axes)
|
||||
self.assertLabeledTensorsEqual(like_lt, golden_lt)
|
||||
|
||||
|
||||
class OnesLikeTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
like_lt = ops.ones_like(self.original_lt)
|
||||
self.assertIn('lt_ones_like', like_lt.name)
|
||||
|
||||
def test(self):
|
||||
like_lt = ops.ones_like(self.original_lt)
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.ones_like(self.original_lt.tensor), self.original_lt.axes)
|
||||
self.assertLabeledTensorsEqual(like_lt, golden_lt)
|
||||
|
||||
|
||||
class CastTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
cast_lt = ops.cast(self.original_lt, tf.float16)
|
||||
self.assertIn('lt_cast', cast_lt.name)
|
||||
|
||||
def test(self):
|
||||
cast_lt = ops.cast(self.original_lt, tf.float16)
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.cast(self.original_lt.tensor, tf.float16), self.original_lt.axes)
|
||||
self.assertLabeledTensorsEqual(cast_lt, golden_lt)
|
||||
|
||||
|
||||
class VerifyTensorAllFiniteTest(Base):
|
||||
|
||||
def setUp(self):
|
||||
super(VerifyTensorAllFiniteTest, self).setUp()
|
||||
|
||||
self.finite_lt = core.LabeledTensor(tf.constant(42.0), [])
|
||||
self.nan_lt = core.LabeledTensor(tf.constant(np.nan), [])
|
||||
|
||||
self.checked_finite_lt = ops.verify_tensor_all_finite(self.finite_lt, '')
|
||||
self.checked_nan_lt = ops.verify_tensor_all_finite(self.nan_lt, '')
|
||||
|
||||
def test_name(self):
|
||||
self.assertIn('lt_verify_tensor_all_finite', self.checked_finite_lt.name)
|
||||
self.assertIn('lt_verify_tensor_all_finite', self.checked_nan_lt.name)
|
||||
|
||||
def test_finite(self):
|
||||
self.assertLabeledTensorsEqual(self.finite_lt, self.checked_finite_lt)
|
||||
|
||||
def test_nan(self):
|
||||
with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
|
||||
'Tensor had NaN values'):
|
||||
self.eval([self.checked_nan_lt])
|
||||
|
||||
|
||||
class BooleanMaskTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
mask = core.LabeledTensor(tf.range(7) > 3, [self.a0])
|
||||
masked_lt = ops.boolean_mask(self.original_lt, mask)
|
||||
self.assertIn('lt_boolean_mask', masked_lt.name)
|
||||
|
||||
def test(self):
|
||||
mask = core.LabeledTensor(tf.range(7) > 3, [self.a0])
|
||||
masked_lt = ops.boolean_mask(self.original_lt, mask)
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.boolean_mask(self.original_lt.tensor, mask.tensor),
|
||||
['x', self.a1, self.a2, self.a3])
|
||||
self.assertLabeledTensorsEqual(masked_lt, golden_lt)
|
||||
|
||||
def test_invalid_rank(self):
|
||||
mask = core.LabeledTensor(tf.ones((7, 3)) > 3, [self.a0, self.a1])
|
||||
with self.assertRaises(NotImplementedError):
|
||||
ops.boolean_mask(self.original_lt, mask)
|
||||
|
||||
def test_mismatched_axis(self):
|
||||
mask = core.LabeledTensor(tf.range(7) > 3, ['foo'])
|
||||
with self.assertRaisesRegexp(ValueError, 'not equal'):
|
||||
ops.boolean_mask(self.original_lt, mask)
|
||||
|
||||
|
||||
class WhereTest(Base):
|
||||
|
||||
def test_name(self):
|
||||
condition = core.LabeledTensor(tf.range(5) < 3, ['x'])
|
||||
where_lt = ops.where(condition, condition, condition)
|
||||
self.assertIn('lt_where', where_lt.name)
|
||||
|
||||
def test(self):
|
||||
condition = core.LabeledTensor(tf.range(5) < 3, ['x'])
|
||||
x = core.LabeledTensor(tf.ones(5), ['x'])
|
||||
y = core.LabeledTensor(tf.zeros(5), ['x'])
|
||||
where_lt = ops.where(condition, x, y)
|
||||
|
||||
golden_lt = core.LabeledTensor(
|
||||
tf.concat(0, [tf.ones(3), tf.zeros(2)]), ['x'])
|
||||
self.assertLabeledTensorsEqual(where_lt, golden_lt)
|
||||
|
||||
def test_mismatched_axes(self):
|
||||
condition = core.LabeledTensor(tf.range(5) < 3, ['x'])
|
||||
with self.assertRaisesRegexp(ValueError, 'equal axes'):
|
||||
ops.where(condition, condition[:3], condition)
|
||||
with self.assertRaisesRegexp(ValueError, 'equal axes'):
|
||||
ops.where(condition, condition, condition[:3])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
131
tensorflow/contrib/labeled_tensor/python/ops/sugar.py
Normal file
131
tensorflow/contrib/labeled_tensor/python/ops/sugar.py
Normal file
@ -0,0 +1,131 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tools to make it a bit easier to use LabeledTensor."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from six import string_types
|
||||
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import core
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import ops
|
||||
from tensorflow.python.framework import ops as tf_ops
|
||||
|
||||
|
||||
class ReshapeCoder(object):
|
||||
"""Utility class for mapping to and from another shape.
|
||||
|
||||
For example, say you have a function `crop_center` which expects a
|
||||
LabeledTensor with axes named ['batch', 'row', 'column', 'depth'], and
|
||||
you have a LabeledTensor `masked_image_lt` with axes ['batch', 'row',
|
||||
'column', 'channel', 'mask'].
|
||||
|
||||
To call `crop_center` with `masked_image_lt` you'd normally have to write:
|
||||
|
||||
>>> reshape_lt = lt.reshape(masked_image_lt, ['channel', 'mask'], ['depth'])
|
||||
>>> crop_lt = crop_center(reshape_lt)
|
||||
>>> result_lt = lt.reshape(crop_lt, ['depth'],
|
||||
... [masked_image_lt.axes['channel'], masked_image_lt.axes['mask']])
|
||||
|
||||
ReshapeCoder takes care of this renaming logic for you, allowing you to
|
||||
instead write:
|
||||
|
||||
>>> rc = ReshapeCoder(['channel', 'mask'], ['depth'])
|
||||
>>> result_lt = rc.decode(crop_center(rc.encode(masked_image_lt)))
|
||||
|
||||
Here, `decode` restores the original axes 'channel' and 'mask', so
|
||||
`crop_center` must not have modified the size of the 'depth' axis.
|
||||
"""
|
||||
|
||||
@tc.accepts(object, tc.Collection(str),
|
||||
tc.Collection(tc.Union(str, core.AxisLike)), tc.Optional(str))
|
||||
def __init__(self, existing_axis_names, new_axes, name=None):
|
||||
self._name = name
|
||||
self._existing_axis_names = existing_axis_names
|
||||
self._new_axes = new_axes
|
||||
|
||||
self._existing_axes = None
|
||||
|
||||
@tc.returns(core.LabeledTensor)
|
||||
@tc.accepts(object, core.LabeledTensorLike)
|
||||
def encode(self, labeled_tensor):
|
||||
"""Reshape the input to the target shape.
|
||||
|
||||
If called several times, the axes named in existing_axis_names must be
|
||||
identical.
|
||||
|
||||
Args:
|
||||
labeled_tensor: The input tensor.
|
||||
|
||||
Returns:
|
||||
The input reshaped to the target shape.
|
||||
|
||||
Raises:
|
||||
ValueError: If the axes in existing_axis_names don't match the axes of
|
||||
a tensor in a previous invocation of this method.
|
||||
"""
|
||||
with tf_ops.name_scope(self._name, 'lt_reshape_encode',
|
||||
[labeled_tensor]) as scope:
|
||||
labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
|
||||
|
||||
reshape_lt = ops.reshape(labeled_tensor,
|
||||
self._existing_axis_names,
|
||||
self._new_axes,
|
||||
name=scope)
|
||||
|
||||
axes = [labeled_tensor.axes[n] for n in self._existing_axis_names]
|
||||
if self._existing_axes is not None and self._existing_axes != axes:
|
||||
raise ValueError(
|
||||
'input axes %r do not match axes from previous method call %r' %
|
||||
(axes, self._existing_axes))
|
||||
else:
|
||||
self._existing_axes = axes
|
||||
|
||||
return reshape_lt
|
||||
|
||||
@tc.returns(core.LabeledTensor)
|
||||
@tc.accepts(object, core.LabeledTensorLike)
|
||||
def decode(self, labeled_tensor):
|
||||
"""Reshape the input to the original shape.
|
||||
|
||||
This is the inverse of encode.
|
||||
Encode must have been called at least once prior to this method being
|
||||
called.
|
||||
|
||||
Args:
|
||||
labeled_tensor: The input tensor.
|
||||
|
||||
Returns:
|
||||
The input reshaped to the original shape.
|
||||
|
||||
Raises:
|
||||
ValueError: If this method was called before encode was called.
|
||||
"""
|
||||
if self._existing_axes is None:
|
||||
raise ValueError('decode called before encode')
|
||||
|
||||
with tf_ops.name_scope(self._name, 'lt_reshape_decode',
|
||||
[labeled_tensor]) as scope:
|
||||
labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
|
||||
|
||||
new_axis_names = [axis if isinstance(axis, string_types) else
|
||||
core.as_axis(axis).name for axis in self._new_axes]
|
||||
|
||||
return ops.reshape(labeled_tensor,
|
||||
new_axis_names,
|
||||
self._existing_axes,
|
||||
name=scope)
|
106
tensorflow/contrib/labeled_tensor/python/ops/sugar_test.py
Normal file
106
tensorflow/contrib/labeled_tensor/python/ops/sugar_test.py
Normal file
@ -0,0 +1,106 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from six.moves import range # pylint: disable=redefined-builtin
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import core
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import ops
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import sugar
|
||||
from tensorflow.contrib.labeled_tensor.python.ops import test_util
|
||||
|
||||
|
||||
class Base(test_util.Base):
|
||||
|
||||
def setUp(self):
|
||||
super(Base, self).setUp()
|
||||
|
||||
self.small_lt = core.LabeledTensor(tf.constant([1]), [('x', 1)])
|
||||
|
||||
|
||||
class ReshapeCoderTest(Base):
|
||||
|
||||
def setUp(self):
|
||||
super(ReshapeCoderTest, self).setUp()
|
||||
|
||||
self.batch_size = 8
|
||||
self.num_rows = 50
|
||||
self.num_columns = 100
|
||||
self.channels = ['red', 'green', 'blue']
|
||||
self.masks = [False, True]
|
||||
|
||||
tensor = tf.range(0, self.batch_size * self.num_rows * self.num_columns *
|
||||
len(self.channels) * len(self.masks))
|
||||
tensor = tf.reshape(tensor, [self.batch_size, self.num_rows,
|
||||
self.num_columns, len(self.channels),
|
||||
len(self.masks)])
|
||||
|
||||
self.batch_axis = ('batch', range(self.batch_size))
|
||||
self.row_axis = ('row', range(self.num_rows))
|
||||
self.column_axis = ('column', range(self.num_columns))
|
||||
self.channel_axis = ('channel', self.channels)
|
||||
self.mask_axis = ('mask', self.masks)
|
||||
|
||||
axes = [self.batch_axis, self.row_axis, self.column_axis, self.channel_axis,
|
||||
self.mask_axis]
|
||||
self.masked_image_lt = core.LabeledTensor(tensor, axes)
|
||||
|
||||
def test_name(self):
|
||||
rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
|
||||
encode_lt = rc.encode(self.masked_image_lt)
|
||||
decode_lt = rc.decode(encode_lt)
|
||||
self.assertIn('lt_reshape_encode', encode_lt.name)
|
||||
self.assertIn('lt_reshape_decode', decode_lt.name)
|
||||
|
||||
def test_bijection_flat(self):
|
||||
rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
|
||||
|
||||
encode_lt = rc.encode(self.masked_image_lt)
|
||||
golden_axes = core.Axes([self.batch_axis, self.row_axis, self.column_axis,
|
||||
('depth', len(self.channels) * len(self.masks))])
|
||||
self.assertEqual(encode_lt.axes, golden_axes)
|
||||
|
||||
decode_lt = rc.decode(encode_lt)
|
||||
self.assertLabeledTensorsEqual(decode_lt, self.masked_image_lt)
|
||||
|
||||
def test_bijection_with_labels(self):
|
||||
depth_axis = core.Axis('depth', range(len(self.channels) * len(self.masks)))
|
||||
rc = sugar.ReshapeCoder(['channel', 'mask'], [depth_axis,
|
||||
('other', ['label'])])
|
||||
|
||||
encode_lt = rc.encode(self.masked_image_lt)
|
||||
golden_axes = core.Axes([self.batch_axis, self.row_axis, self.column_axis,
|
||||
depth_axis, ('other', ['label'])])
|
||||
self.assertEqual(encode_lt.axes, golden_axes)
|
||||
|
||||
decode_lt = rc.decode(encode_lt)
|
||||
self.assertLabeledTensorsEqual(decode_lt, self.masked_image_lt)
|
||||
|
||||
def test_invalid_input(self):
|
||||
with self.assertRaises(ValueError):
|
||||
rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
|
||||
rc.decode(self.masked_image_lt)
|
||||
with self.assertRaises(ValueError):
|
||||
rc = sugar.ReshapeCoder(['channel', 'mask'], ['depth'])
|
||||
rc.encode(self.masked_image_lt)
|
||||
rc.encode(ops.select(self.masked_image_lt, {'channel': 'red'}))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
47
tensorflow/contrib/labeled_tensor/python/ops/test_util.py
Normal file
47
tensorflow/contrib/labeled_tensor/python/ops/test_util.py
Normal file
@ -0,0 +1,47 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Utils for writing tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class Base(tf.test.TestCase):
|
||||
"""A class with some useful methods for testing."""
|
||||
|
||||
def eval(self, tensors):
|
||||
with self.test_session() as sess:
|
||||
coord = tf.train.Coordinator()
|
||||
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
|
||||
|
||||
try:
|
||||
results = sess.run(tensors)
|
||||
finally:
|
||||
coord.request_stop()
|
||||
coord.join(threads)
|
||||
|
||||
return results
|
||||
|
||||
def assertTensorsEqual(self, tensor_0, tensor_1):
|
||||
[tensor_0_eval, tensor_1_eval] = self.eval([tensor_0, tensor_1])
|
||||
self.assertAllEqual(tensor_0_eval, tensor_1_eval)
|
||||
|
||||
def assertLabeledTensorsEqual(self, tensor_0, tensor_1):
|
||||
self.assertEqual(tensor_0.axes, tensor_1.axes)
|
||||
self.assertTensorsEqual(tensor_0.tensor, tensor_1.tensor)
|
@ -90,6 +90,7 @@ sh_binary(
|
||||
":included_headers",
|
||||
":simple_console",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/contrib/labeled_tensor:all_files",
|
||||
"//tensorflow/contrib/ndlstm:all_files",
|
||||
"//tensorflow/contrib/session_bundle:all_files",
|
||||
"//tensorflow/contrib/slim:all_files",
|
||||
|
Loading…
Reference in New Issue
Block a user