STT-tensorflow/tensorflow/python/ops/op_def_library.py
Manjunath Kudlur 3dfd14421d TensorFlow: upstream changes to git.
Change 109418220
	Update WORKSPACE to use gmock.BUILD from google/protobuf instead of a duplicate.
	Update google/protobuf's commit hash to include damieng@'s commit.
Change 109417314
	TensorFlow: add .gitignore to ignore some in-tree modified files.
Change 109400051
	Optionally build full TensorFlow for Android.
	1. --define ANDROID_TYPES=__ANDROID_TYPES_FULL__ to register ops
	   for all types, not just float. Today this increases codesize
	   by ~700K when compiled for ARM, though only for clients who
	   request full type support.
	2. Add more ops to android_extended_ops, sufficient to train on
	   the linear regression baseball codelab.
Change 109388118
	Fix the option changed in templatize. Oops.
Change 109382553
	Allows setting a function name in an op's attr in the py frontend.
Change 109380896
	Remove assert_same_graph in favor of op_scope. Change the latter to handle tensor-like objects such as SparseTensor, IndexedSlices, and Variable.

Base CL: 109418322
2015-12-04 09:54:09 -08:00

672 lines
26 KiB
Python

# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class to hold a library of OpDefs and use it to create Brain operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import constant_op
from tensorflow.python.platform import logging
from tensorflow.python.util import compat
def _Attr(op_def, name):
for attr in op_def.attr:
if attr.name == name:
return attr
raise TypeError("Inconsistent OpDef for '%s', missing attr '%s'" %
(op_def.name, name))
def _AttrValue(attr_protos, name):
if name in attr_protos:
return attr_protos[name]
raise TypeError("Inconsistent OpDef, missing attr '%s' from '%s'." %
(name, attr_protos))
def _SatisfiesTypeConstraint(dtype, attr_def):
if attr_def.HasField("allowed_values"):
allowed_list = attr_def.allowed_values.list.type
if dtype not in allowed_list:
raise TypeError(
"DataType %s for attr '%s' not in list of allowed values: %s" %
(dtypes.as_dtype(dtype).name, attr_def.name,
", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
def _IsListParameter(arg):
if arg.number_attr:
return True
elif arg.type_list_attr:
return True
return False
def _NumTypeFields(arg):
  """Counts how many of the three type-specifying fields `arg` populates.

  The fields considered are `type` (populated when it differs from
  `types_pb2.DT_INVALID`), `type_attr` and `type_list_attr` (populated when
  truthy).

  Args:
    arg: Object with `type`, `type_attr` and `type_list_attr` attributes.

  Returns:
    An integer between 0 and 3.
  """
  populated = (
      arg.type != types_pb2.DT_INVALID,
      bool(arg.type_attr),
      bool(arg.type_list_attr),
  )
  return sum(populated)
def _IsListValue(v):
return isinstance(v, (list, tuple))
def _Flatten(l):
"""Converts [1, 2, [3, 4], [5]] to [1, 2, 3, 4, 5]."""
# [1, 2, [3, 4], [5]] -> [[1], [2], [3, 4], [5]]
l_of_l = [x if _IsListValue(x) else [x] for x in l]
# [[1], [2], [3, 4], [5]] -> [1, 2, 3, 4, 5]
return [item for sublist in l_of_l for item in sublist]
def _Restructure(l, structure):
"""Returns the elements of list l structured according to the given structure.
A structure is represented by a list whose elements are either
`None` or a non-negative integer. `None` corresponds to a single
element in the output list, and an integer N corresponds to a nested
list of length N.
The function returns a data structure whose shape is given by
`structure`, and whose elements are taken from `l`. If `structure`
is a singleton, the function returns the single data structure
implied by the 0th element of `structure`. For example:
_Restructure(["foo", "bar", "baz", "qux"], [None, 2, None])
-> ["foo", ["bar", "baz"], "qux"]
_Restructure(["foo"], [None]) -> "foo"
_Restructure(["foo"], [1]) -> ["foo"]
_Restructure([], [0]) -> []
Args:
l: A list.
structure: A list whose elements are either `None` or a non-negative
integer.
Returns:
The elements of `l`, restructured according to `structure`. If
`structure` is a list of length 1, this function returns the
single data structure implied by `structure[0]`.
"""
result = []
current_index = 0
for element in structure:
if element is None:
result.append(l[current_index])
current_index += 1
else:
result.append(l[current_index:current_index+element])
current_index += element
if len(result) == 1:
return result[0]
else:
return tuple(result)
def _MakeFloat(v, arg_name):
  """Converts `v` to a Python float.

  Args:
    v: Value to convert; must be an instance of `compat.real_types`.
    arg_name: Argument name, used only in the error message.

  Returns:
    `float(v)`.

  Raises:
    TypeError: If `v` is not an instance of `compat.real_types`.
  """
  if isinstance(v, compat.real_types):
    return float(v)
  raise TypeError("Expected float for argument '%s' not %s." %
                  (arg_name, repr(v)))
def _MakeInt(v, arg_name):
  """Converts `v` to a Python int, refusing strings.

  Args:
    v: Value to convert.
    arg_name: Argument name, used only in the error message.

  Returns:
    `int(v)`.

  Raises:
    TypeError: If `v` is a string (even a numeric one), or if `int(v)`
      raises `ValueError` or `TypeError`.
  """
  complaint = "Expected int for argument '%s' not %s."
  # Strings are rejected up front so that e.g. "3" is not silently accepted.
  if isinstance(v, six.string_types):
    raise TypeError(complaint % (arg_name, repr(v)))
  try:
    converted = int(v)
  except (ValueError, TypeError):
    raise TypeError(complaint % (arg_name, repr(v)))
  return converted
def _MakeStr(v, arg_name):
  """Converts `v` to bytes.

  Args:
    v: Value to convert; must be an instance of
      `compat.bytes_or_text_types`.
    arg_name: Argument name, used only in the error message.

  Returns:
    `compat.as_bytes(v)`.

  Raises:
    TypeError: If `v` is not an instance of `compat.bytes_or_text_types`.
  """
  if isinstance(v, compat.bytes_or_text_types):
    # Unicode strings are converted to bytes here.
    return compat.as_bytes(v)
  raise TypeError("Expected string for argument '%s' not %s." %
                  (arg_name, repr(v)))
def _MakeBool(v, arg_name):
if not isinstance(v, bool):
raise TypeError("Expected bool for argument '%s' not %s." %
(arg_name, repr(v)))
return v
def _MakeType(v, attr_def):
  """Converts `v` to a DataType enum value allowed by `attr_def`.

  Args:
    v: Anything accepted by `dtypes.as_dtype`.
    attr_def: Attr definition; supplies the name for error messages and the
      constraint checked by `_SatisfiesTypeConstraint`.

  Returns:
    The `as_datatype_enum` of the dtype that `v` resolves to.

  Raises:
    TypeError: If `dtypes.as_dtype(v)` raises `TypeError`, or if the
      resulting enum fails `_SatisfiesTypeConstraint`.
  """
  try:
    resolved = dtypes.as_dtype(v)
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (attr_def.name, repr(v)))
  enum_value = resolved.as_datatype_enum
  _SatisfiesTypeConstraint(enum_value, attr_def)
  return enum_value
def _MakeShape(v, arg_name):
  """Converts `v` into a TensorShapeProto.

  A `TensorShapeProto` input is returned as-is; if any of its dimensions
  has a name, a single warning is logged first. Any other input is passed
  through `tensor_shape.as_shape` and copied dimension by dimension into a
  new proto.

  Args:
    v: A TensorShapeProto, or a value accepted by `tensor_shape.as_shape`.
    arg_name: String intended for error messages; not referenced by the
      current implementation.

  Returns:
    A TensorShapeProto.
  """
  if isinstance(v, tensor_shape_pb2.TensorShapeProto):
    if any(dim.name for dim in v.dim):
      logging.warning("Warning: TensorShapeProto with a named dimension: %s",
                      str(v))
    return v
  dimension_sizes = tensor_shape.as_shape(v).as_dimension_list()
  shape_proto = tensor_shape_pb2.TensorShapeProto()
  for dimension_size in dimension_sizes:
    shape_proto.dim.add(size=dimension_size)
  return shape_proto
def _MakeTensor(v, arg_name):
  """Validates that `v` is a TensorProto and returns it unchanged.

  Args:
    v: Value to validate.
    arg_name: Argument name, used only in the error message.

  Returns:
    `v` itself.

  Raises:
    TypeError: If `v` is not a `tensor_pb2.TensorProto`.
  """
  if not isinstance(v, tensor_pb2.TensorProto):
    raise TypeError(
        "Don't know how to convert %s to a TensorProto for argument '%s'" %
        (repr(v), arg_name))
  return v
class _OpInfo(object):
"""All per-Op state we would like to precompute/validate."""
def __init__(self, op_def):
self.op_def = op_def
# TODO(josh11b): SWIG the ValidateOpDef() function from C++ and call it
# here, instead of these checks.
for arg in list(op_def.input_arg) + list(op_def.output_arg):
num_type_fields = _NumTypeFields(arg)
if num_type_fields != 1:
raise TypeError("Arg '%s' of '%s' must have one type field not %d" %
(arg.name, op_def.name, num_type_fields))
if arg.type_attr:
attr_type = _Attr(op_def, arg.type_attr).type
if attr_type != "type":
raise TypeError("Attr '%s' of '%s' used as a type_attr "
"but has type %s" %
(arg.type_attr, op_def.name, attr_type))
if arg.type_list_attr:
attr_type = _Attr(op_def, arg.type_list_attr).type
if attr_type != "list(type)":
raise TypeError(
"Attr '%s' of '%s' used as a type_list_attr but has type %s" %
(arg.type_attr, op_def.name, attr_type))
if arg.number_attr:
attr_type = _Attr(op_def, arg.number_attr).type
if attr_type != "int":
raise TypeError(
"Attr '%s' of '%s' used as a number_attr but has type %s" %
(arg.number_attr, op_def.name, attr_type))
class OpDefLibrary(object):
  """Holds a collection of OpDefs, can add the corresponding Ops to a graph."""

  def __init__(self):
    # Maps op name -> _OpInfo wrapping the registered OpDef.
    self._ops = {}

  def add_op(self, op_def):
    """Register an OpDef. May call apply_op with the name afterwards.

    Args:
      op_def: An `op_def_pb2.OpDef` with a non-empty `name`.

    Raises:
      TypeError: If `op_def` is not an `op_def_pb2.OpDef`, or if `_OpInfo`
        rejects it during validation.
      ValueError: If `op_def.name` is empty.
      RuntimeError: If an op with the same name is already registered.
    """
    if not isinstance(op_def, op_def_pb2.OpDef):
      raise TypeError("%s is %s, not an op_def_pb2.OpDef" %
                      (op_def, type(op_def)))
    if not op_def.name:
      raise ValueError("%s missing name." % op_def)
    if op_def.name in self._ops:
      raise RuntimeError("Op name %s registered twice." % op_def.name)
    self._ops[op_def.name] = _OpInfo(op_def)

  def add_op_list(self, op_list):
    """Register the OpDefs from an OpList.

    Args:
      op_list: An `op_def_pb2.OpList`; each entry of `op_list.op` is passed
        to `add_op`.

    Raises:
      TypeError: If `op_list` is not an `op_def_pb2.OpList`.
    """
    if not isinstance(op_list, op_def_pb2.OpList):
      raise TypeError("%s is %s, not an op_def_pb2.OpList" %
                      (op_list, type(op_list)))
    for op_def in op_list.op:
      self.add_op(op_def)

  def apply_op(self, op_type_name, g=None, name=None, **keywords):
    # pylint: disable=g-doc-args
    """Add a node invoking a registered Op to a graph.

    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # If the graph cannot be determined from the inputs, specify it.
       op_def_library.apply_op("op", input1=input1, g=g)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    An input or attr whose OpDef name collides with a Python keyword or
    built-in may be passed with a trailing underscore (e.g. `name_`).

    NOTE(review): an earlier version of this docstring said config proto
    extensions are passed via an 'ext' keyword and that extra keywords are
    forwarded to the Operation constructor. Nothing in this method consumes
    such keywords; any keyword that is not an input or attr of the OpDef
    raises TypeError.

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      g: The graph context (optional)
      name: string. Optional name of the created op; defaults to
        `op_type_name`.
      **keywords: input Tensor and attr arguments specified by name.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if there are no outputs.

    Raises:
      RuntimeError: If `op_type_name` is not registered, or the graph cannot
        be determined from the inputs and `g`.
      TypeError: On missing, unexpected or wrongly-typed inputs/attrs.
      ValueError: On list-length or attr-value constraint violations.
    """
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
      raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    try:
      # Need to flatten all the arguments into a list.
      # pylint: disable=protected-access
      g = ops._get_graph_from_inputs(_Flatten(keywords.values()), graph=g)
      # pylint: enable=protected-access
    except AssertionError as e:
      # str(e) rather than e.message: exceptions have no `.message`
      # attribute on Python 3.
      raise RuntimeError(
          "Need to specify g=graph to Op '%s' (could not determine graph due "
          "to: %s)" % (op_type_name, str(e)))

    # Default name if not specified.
    if name is None:
      name = op_type_name

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

      # Perform input type inference
      # Maps attr name -> name of the input it was inferred from (used only
      # in error messages).
      inferred_from = {}
      for input_arg in op_def.input_arg:
        input_name = input_arg.name
        if input_name in keywords:
          values = keywords.pop(input_name)
        elif input_name + "_" in keywords:
          # Handle the case where the name is a keyword or built-in
          # for Python so we use the name + _ instead.
          input_name += "_"
          values = keywords.pop(input_name)
        else:
          raise TypeError("No argument for input " + input_name)

        # Goals:
        # * Convert values to Tensors if it contains constants.
        # * Verify that values is a list if that matches the input_arg's
        #   type.
        # * If the input_arg's type is determined by attrs, either set
        #   those attrs and validate those attr values are legal (if
        #   they have not yet been set) or validate the input matches
        #   the type indicated by the attrs (if they have already been
        #   inferred via an earlier input).
        # * If the input_arg has an explicit type, make sure the input
        #   conforms.

        if _IsListParameter(input_arg):
          if not _IsListValue(values):
            raise TypeError(
                "Expected list for '%s' argument to '%s' Op, not %s." %
                (input_name, op_type_name, values))
          # In cases where we expect all elements of the list to have the
          # same dtype, try to cast non-Tensor elements to that type.
          dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.number_attr:
            if input_arg.type_attr in attrs:
              dtype = attrs[input_arg.type_attr]
            else:
              for t in values:
                if isinstance(t, ops.Tensor):
                  dtype = t.dtype
                  break

          try:
            if not input_arg.is_ref and dtype:
              dtype = dtypes.as_dtype(dtype).base_dtype
            values = ops.convert_n_to_tensor_or_indexed_slices(
                values, name=input_arg.name,
                dtype=dtype if dtype else None,
                as_ref=input_arg.is_ref)
          except (TypeError, ValueError):
            assert dtype is not None, "Should not fail if dtype is None"
            assert input_arg.number_attr, "Should be number_attr case"
            # What types does the conversion function think values have?
            values = ops.convert_n_to_tensor_or_indexed_slices(
                values, as_ref=input_arg.is_ref)
            observed = ", ".join(v.dtype.base_dtype.name for v in values)

            prefix = (
                "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                (input_name, op_type_name, observed))
            # dtype may still be a raw DataType enum here (ref inputs skip
            # the as_dtype() normalization above), so normalize before
            # reading .name.
            expected_name = dtypes.as_dtype(dtype).name
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s that do not match expected type %s." %
                              (prefix, expected_name))
            elif input_arg.type_attr in attrs:
              raise TypeError("%s that do not match type %s inferred from "
                              "earlier arguments." %
                              (prefix, expected_name))
            else:
              raise TypeError("%s that don't all match." % prefix)

          types = [x.dtype for x in values]
          inputs.extend(values)
        else:
          # In cases where we have an expected type, try to convert non-Tensor
          # arguments to that type.
          dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]

          try:
            values = ops.convert_to_tensor(
                values, name=input_arg.name, dtype=dtype,
                as_ref=input_arg.is_ref)
          except ValueError:
            # What type does convert_to_tensor think it has?
            observed = ops.convert_to_tensor(values,
                                             as_ref=input_arg.is_ref).dtype.name
            prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                      (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s expected type of %s." %
                              (prefix, dtypes.as_dtype(input_arg.type).name))
            else:
              raise TypeError(
                  "%s type %s of argument '%s'." %
                  (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                   inferred_from[input_arg.type_attr]))

          types = [values.dtype]
          inputs.append(values)
        base_types = [x.base_dtype for x in types]

        if input_arg.number_attr:
          # <number-attr> * <type> or <number-attr> * <type-attr>
          if input_arg.number_attr in attrs:
            if len(values) != attrs[input_arg.number_attr]:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d must match "
                  "length %d of argument '%s'." %
                  (input_name, op_type_name, len(values),
                   attrs[input_arg.number_attr],
                   inferred_from[input_arg.number_attr]))
          else:
            attrs[input_arg.number_attr] = len(values)
            inferred_from[input_arg.number_attr] = input_name
            num_attr = _Attr(op_def, input_arg.number_attr)
            if num_attr.has_minimum and len(values) < num_attr.minimum:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d shorter "
                  "than minimum length %d." %
                  (input_name, op_type_name, len(values), num_attr.minimum))
          # All tensors must have the same base type.
          if any([bt != base_types[0] for bt in base_types]):
            raise TypeError(
                "All tensors passed to '%s' of '%s' Op "
                "must have the same type." %
                (input_name, op_type_name))
          if input_arg.type != types_pb2.DT_INVALID:
            # <number-attr> * <type> case
            if base_types and base_types[0] != input_arg.type:
              assert False, "Unreachable"
          elif input_arg.type_attr in attrs:
            # <number-attr> * <type-attr> case, where <type-attr> already
            # has an inferred value.
            if base_types and base_types[0] != attrs[input_arg.type_attr]:
              assert False, "Unreachable"
          else:
            # <number-attr> * <type-attr> case, where we are now setting
            # the <type-attr> based on this input
            if not base_types:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  "list passed to input '%s' of '%s' Op." %
                  (input_name, op_type_name))
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr)
        elif input_arg.type_attr:
          # <type-attr>
          attr_value = base_types[0]
          if input_arg.type_attr in attrs:
            if attrs[input_arg.type_attr] != attr_value:
              assert False, "Unreachable"
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_attr))
            attrs[input_arg.type_attr] = attr_value
            inferred_from[input_arg.type_attr] = input_name
        elif input_arg.type_list_attr:
          # <type-list-attr>
          attr_value = base_types
          if input_arg.type_list_attr in attrs:
            if attrs[input_arg.type_list_attr] != attr_value:
              raise TypeError(
                  "Input '%s' of '%s' Op has type list of %s that does not "
                  "match type list %s of argument '%s'." %
                  (input_name, op_type_name,
                   ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                   ", ".join(dtypes.as_dtype(x).name
                             for x in attrs[input_arg.type_list_attr]),
                   inferred_from[input_arg.type_list_attr]))
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_list_attr))
            attrs[input_arg.type_list_attr] = attr_value
            inferred_from[input_arg.type_list_attr] = input_name
        else:
          # single Tensor with specified type
          if base_types[0] != input_arg.type:
            assert False, "Unreachable"

        if input_arg.is_ref:
          if not all(x.is_ref_dtype for x in types):
            raise TypeError(
                "Input '%s' of '%s' Op requires l-value input" %
                (input_name, op_type_name))
          input_types.extend(types)
        else:
          input_types.extend(base_types)

      # Process remaining attrs
      for attr in op_def.attr:
        # Skip attrs that have already had their values inferred
        if attr.name in attrs:
          if attr.name in keywords:
            raise TypeError(
                "Should not specify value for inferred attr '%s'." % attr.name)
          continue
        if attr.name in keywords:
          attrs[attr.name] = keywords.pop(attr.name)
        elif attr.name + "_" in keywords:
          # Attrs whose names match Python keywords have an extra '_'
          # appended, so we must check for that as well.
          attrs[attr.name] = keywords.pop(attr.name + "_")
        else:
          raise TypeError("No argument for attr " + attr.name)

      # Convert attr values to AttrValue protos.
      attr_protos = {}
      for attr_def in op_def.attr:
        key = attr_def.name
        value = attrs[key]
        attr_value = attr_value_pb2.AttrValue()
        if attr_def.HasField("default_value") and value is None:
          # None selects the OpDef's declared default.
          attr_value.CopyFrom(attr_def.default_value)
          attr_protos[key] = attr_value
          continue
        if attr_def.type.startswith("list("):
          if not _IsListValue(value):
            raise TypeError("Expected list for attr " + key)
          if attr_def.has_minimum:
            if len(value) < attr_def.minimum:
              raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                               "less than minimum %d." %
                               (key, op_type_name, len(value),
                                attr_def.minimum))
        if attr_def.type == "string":
          attr_value.s = _MakeStr(value, key)
          if attr_def.HasField("allowed_values"):
            if attr_value.s not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, compat.as_text(attr_value.s),
                   '", "'.join(map(compat.as_text,
                                   attr_def.allowed_values.list.s))))
        elif attr_def.type == "list(string)":
          attr_value.list.s.extend([_MakeStr(x, key) for x in value])
          if attr_def.HasField("allowed_values"):
            for x in attr_value.list.s:
              if x not in attr_def.allowed_values.list.s:
                raise ValueError(
                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                    (key, op_type_name, compat.as_text(x),
                     '", "'.join(map(compat.as_text,
                                     attr_def.allowed_values.list.s))))
        elif attr_def.type == "int":
          attr_value.i = _MakeInt(value, key)
          if attr_def.has_minimum:
            if attr_value.i < attr_def.minimum:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                  (key, op_type_name, attr_value.i, attr_def.minimum))
        elif attr_def.type == "list(int)":
          attr_value.list.i.extend([_MakeInt(x, key) for x in value])
        elif attr_def.type == "float":
          attr_value.f = _MakeFloat(value, key)
        elif attr_def.type == "list(float)":
          attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
        elif attr_def.type == "bool":
          attr_value.b = _MakeBool(value, key)
        elif attr_def.type == "list(bool)":
          attr_value.list.b.extend([_MakeBool(x, key) for x in value])
        elif attr_def.type == "type":
          attr_value.type = _MakeType(value, attr_def)
        elif attr_def.type == "list(type)":
          attr_value.list.type.extend(
              [_MakeType(x, attr_def) for x in value])
        elif attr_def.type == "shape":
          attr_value.shape.CopyFrom(_MakeShape(value, key))
        elif attr_def.type == "list(shape)":
          attr_value.list.shape.extend(
              [_MakeShape(x, key) for x in value])
        elif attr_def.type == "tensor":
          attr_value.tensor.CopyFrom(_MakeTensor(value, key))
        elif attr_def.type == "list(tensor)":
          attr_value.list.tensor.extend(
              [_MakeTensor(x, key) for x in value])
        elif attr_def.type == "func":
          if not isinstance(value, compat.bytes_or_text_types):
            raise TypeError("Expects a string for the func name")
          attr_value.func.name = value
        else:
          raise TypeError("Unrecognized Attr type " + attr_def.type)

        attr_protos[key] = attr_value
      del attrs  # attrs is no longer authoritative, use attr_protos instead

      # Determine output types (possibly using attrs)
      output_types = []
      # One entry per output arg: None for a single tensor, or the list
      # length for a list output. Consumed by _Restructure below.
      output_structure = []
      for arg in op_def.output_arg:
        types = []
        if arg.number_attr:
          n = _AttrValue(attr_protos, arg.number_attr).i
          if arg.type_attr:
            types = [_AttrValue(attr_protos, arg.type_attr).type] * n
          else:
            types = [arg.type] * n
          output_structure.append(n)
        elif arg.type_attr:
          t = _AttrValue(attr_protos, arg.type_attr)
          types = [t.type]
          output_structure.append(None)
        elif arg.type_list_attr:
          t = _AttrValue(attr_protos, arg.type_list_attr)
          types = t.list.type
          output_structure.append(len(t.list.type))
        else:
          types = [arg.type]
          output_structure.append(None)
        if arg.is_ref:
          types = [dtypes.as_dtype(x).as_ref for x in types]
        output_types.extend(types)

      # Every recognized input/attr keyword was popped above; anything left
      # is unknown to this OpDef.
      if keywords:
        raise TypeError("apply_op() got unexpected keyword arguments: " +
                        ", ".join(sorted(keywords.keys())))

      # Add Op to graph
      op = g.create_op(op_type_name, inputs, output_types, name=scope,
                       input_types=input_types, attrs=attr_protos,
                       op_def=op_def)
      if output_structure:
        return _Restructure(
            ops.convert_n_to_tensor_or_indexed_slices(op.outputs),
            output_structure)
      # No declared outputs: hand back the Operation itself.
      return op