Split some of the code in _apply_op_helper out into helper functions.

PiperOrigin-RevId: 328931451
Change-Id: I33eed3decaa6f21485d42eaf4ba036e9da81cd06
This commit is contained in:
Edward Loper 2020-08-28 07:02:24 -07:00 committed by TensorFlower Gardener
parent 6ba62e9fea
commit 46b6d3707f

View File

@ -63,6 +63,27 @@ def _SatisfiesTypeConstraint(dtype, attr_def, param_name):
", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
def _SatisfiesLengthConstraint(length, attr_def, param_name, op_type_name):
if attr_def.has_minimum and length < attr_def.minimum:
raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
"less than minimum %d." %
(param_name, op_type_name, length, attr_def.minimum))
def _SatisfiesAllowedStringsConstraint(value, attr_def, arg_name, op_type_name):
if value not in attr_def.allowed_values.list.s:
raise ValueError(
"Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
(arg_name, op_type_name, compat.as_text(value), '", "'.join(
map(compat.as_text, attr_def.allowed_values.list.s))))
def _SatisfiesIntMinimumConstraint(value, attr_def, arg_name, op_type_name):
if value < attr_def.minimum:
raise ValueError("Attr '%s' of '%s' Op passed %d less than minimum %d." %
(arg_name, op_type_name, value, attr_def.minimum))
def _IsListParameter(arg):
if arg.number_attr:
return True
@ -172,15 +193,13 @@ def _MakeBool(v, arg_name):
return v
def _MakeType(v, attr_def):
def _MakeType(v, arg_name):
try:
v = dtypes.as_dtype(v).base_dtype
except TypeError:
raise TypeError("Expected DataType for argument '%s' not %s." %
(attr_def.name, repr(v)))
i = v.as_datatype_enum
_SatisfiesTypeConstraint(i, attr_def, param_name=attr_def.name)
return i
(arg_name, repr(v)))
return v.as_datatype_enum
def _MakeShape(v, arg_name):
@ -670,78 +689,32 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in
for attr_def in op_def.attr:
key = attr_def.name
value = attrs[key]
attr_value = attr_value_pb2.AttrValue()
if attr_def.HasField("default_value") and value is None:
attr_value = attr_value_pb2.AttrValue()
attr_value.CopyFrom(attr_def.default_value)
attr_protos[key] = attr_value
continue
attr_value = value_to_attr_value(value, attr_def.type, key)
if attr_def.type.startswith("list("):
if not _IsListValue(value):
raise TypeError("Expected list for attr " + key)
if attr_def.has_minimum:
if len(value) < attr_def.minimum:
raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
"less than minimum %d." %
(key, op_type_name, len(value),
attr_def.minimum))
attr_value.list.SetInParent()
if attr_def.type == "string":
attr_value.s = _MakeStr(value, key)
if attr_def.HasField("allowed_values"):
if attr_value.s not in attr_def.allowed_values.list.s:
raise ValueError(
"Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
(key, op_type_name, compat.as_text(attr_value.s),
'", "'.join(map(compat.as_text,
attr_def.allowed_values.list.s))))
elif attr_def.type == "list(string)":
attr_value.list.s.extend([_MakeStr(x, key) for x in value])
if attr_def.HasField("allowed_values"):
for x in attr_value.list.s:
if x not in attr_def.allowed_values.list.s:
raise ValueError(
"Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
(key, op_type_name, compat.as_text(x),
'", "'.join(map(compat.as_text,
attr_def.allowed_values.list.s))))
elif attr_def.type == "int":
attr_value.i = _MakeInt(value, key)
if attr_def.has_minimum:
if attr_value.i < attr_def.minimum:
raise ValueError(
"Attr '%s' of '%s' Op passed %d less than minimum %d." %
(key, op_type_name, attr_value.i, attr_def.minimum))
elif attr_def.type == "list(int)":
attr_value.list.i.extend([_MakeInt(x, key) for x in value])
elif attr_def.type == "float":
attr_value.f = _MakeFloat(value, key)
elif attr_def.type == "list(float)":
attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
elif attr_def.type == "bool":
attr_value.b = _MakeBool(value, key)
elif attr_def.type == "list(bool)":
attr_value.list.b.extend([_MakeBool(x, key) for x in value])
elif attr_def.type == "type":
attr_value.type = _MakeType(value, attr_def)
elif attr_def.type == "list(type)":
attr_value.list.type.extend(
[_MakeType(x, attr_def) for x in value])
elif attr_def.type == "shape":
attr_value.shape.CopyFrom(_MakeShape(value, key))
elif attr_def.type == "list(shape)":
attr_value.list.shape.extend(
[_MakeShape(x, key) for x in value])
elif attr_def.type == "tensor":
attr_value.tensor.CopyFrom(_MakeTensor(value, key))
elif attr_def.type == "list(tensor)":
attr_value.list.tensor.extend(
[_MakeTensor(x, key) for x in value])
elif attr_def.type == "func":
attr_value.func.CopyFrom(_MakeFunc(value, key))
elif attr_def.type == "list(func)":
attr_value.list.func.extend([_MakeFunc(x, key) for x in value])
else:
raise TypeError("Unrecognized Attr type " + attr_def.type)
_SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name)
if attr_def.HasField("allowed_values"):
if attr_def.type == "string":
_SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key,
op_type_name)
elif attr_def.type == "list(string)":
for value in attr_value.list.s:
_SatisfiesAllowedStringsConstraint(value, attr_def, key,
op_type_name)
if attr_def.has_minimum and attr_def.type == "int":
_SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key,
op_type_name)
if attr_def.type == "type":
_SatisfiesTypeConstraint(attr_value.type, attr_def, key)
if attr_def.type == "list(type)":
for value in attr_value.list.type:
_SatisfiesTypeConstraint(value, attr_def, key)
attr_protos[key] = attr_value
del attrs # attrs is no longer authoritative, use attr_protos instead
@ -792,6 +765,61 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in
return output_structure, op_def.is_stateful, op, outputs
def value_to_attr_value(value, attr_type, arg_name): # pylint: disable=invalid-name
"""Encodes a Python value as an `AttrValue` proto message.
Args:
value: The value to convert.
attr_type: The value type (string) -- see the AttrValue proto definition for
valid strings.
arg_name: Argument name (for error messages).
Returns:
An AttrValue proto message that encodes `value`.
"""
attr_value = attr_value_pb2.AttrValue()
if attr_type.startswith("list("):
if not _IsListValue(value):
raise TypeError("Expected list for attr " + arg_name)
if attr_type == "string":
attr_value.s = _MakeStr(value, arg_name)
elif attr_type == "list(string)":
attr_value.list.s.extend([_MakeStr(x, arg_name) for x in value])
elif attr_type == "int":
attr_value.i = _MakeInt(value, arg_name)
elif attr_type == "list(int)":
attr_value.list.i.extend([_MakeInt(x, arg_name) for x in value])
elif attr_type == "float":
attr_value.f = _MakeFloat(value, arg_name)
elif attr_type == "list(float)":
attr_value.list.f.extend([_MakeFloat(x, arg_name) for x in value])
elif attr_type == "bool":
attr_value.b = _MakeBool(value, arg_name)
elif attr_type == "list(bool)":
attr_value.list.b.extend([_MakeBool(x, arg_name) for x in value])
elif attr_type == "type":
attr_value.type = _MakeType(value, arg_name)
elif attr_type == "list(type)":
attr_value.list.type.extend([_MakeType(x, arg_name) for x in value])
elif attr_type == "shape":
attr_value.shape.CopyFrom(_MakeShape(value, arg_name))
elif attr_type == "list(shape)":
attr_value.list.shape.extend([_MakeShape(x, arg_name) for x in value])
elif attr_type == "tensor":
attr_value.tensor.CopyFrom(_MakeTensor(value, arg_name))
elif attr_type == "list(tensor)":
attr_value.list.tensor.extend([_MakeTensor(x, arg_name) for x in value])
elif attr_type == "func":
attr_value.func.CopyFrom(_MakeFunc(value, arg_name))
elif attr_type == "list(func)":
attr_value.list.func.extend([_MakeFunc(x, arg_name) for x in value])
else:
raise TypeError("Unrecognized Attr type " + attr_type)
return attr_value
# The following symbols are used by op_def_util.cc.
_pywrap_utils.RegisterPyObject("tf.dtypes.DType", dtypes.DType)
_pywrap_utils.RegisterPyObject("tf.dtypes.as_dtype", dtypes.as_dtype)