Split some of the code in _apply_op_helper out into helper functions.
PiperOrigin-RevId: 328931451 Change-Id: I33eed3decaa6f21485d42eaf4ba036e9da81cd06
This commit is contained in:
parent
6ba62e9fea
commit
46b6d3707f
@ -63,6 +63,27 @@ def _SatisfiesTypeConstraint(dtype, attr_def, param_name):
|
||||
", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
|
||||
|
||||
|
||||
def _SatisfiesLengthConstraint(length, attr_def, param_name, op_type_name):
|
||||
if attr_def.has_minimum and length < attr_def.minimum:
|
||||
raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
|
||||
"less than minimum %d." %
|
||||
(param_name, op_type_name, length, attr_def.minimum))
|
||||
|
||||
|
||||
def _SatisfiesAllowedStringsConstraint(value, attr_def, arg_name, op_type_name):
|
||||
if value not in attr_def.allowed_values.list.s:
|
||||
raise ValueError(
|
||||
"Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
|
||||
(arg_name, op_type_name, compat.as_text(value), '", "'.join(
|
||||
map(compat.as_text, attr_def.allowed_values.list.s))))
|
||||
|
||||
|
||||
def _SatisfiesIntMinimumConstraint(value, attr_def, arg_name, op_type_name):
|
||||
if value < attr_def.minimum:
|
||||
raise ValueError("Attr '%s' of '%s' Op passed %d less than minimum %d." %
|
||||
(arg_name, op_type_name, value, attr_def.minimum))
|
||||
|
||||
|
||||
def _IsListParameter(arg):
|
||||
if arg.number_attr:
|
||||
return True
|
||||
@ -172,15 +193,13 @@ def _MakeBool(v, arg_name):
|
||||
return v
|
||||
|
||||
|
||||
def _MakeType(v, attr_def):
|
||||
def _MakeType(v, arg_name):
|
||||
try:
|
||||
v = dtypes.as_dtype(v).base_dtype
|
||||
except TypeError:
|
||||
raise TypeError("Expected DataType for argument '%s' not %s." %
|
||||
(attr_def.name, repr(v)))
|
||||
i = v.as_datatype_enum
|
||||
_SatisfiesTypeConstraint(i, attr_def, param_name=attr_def.name)
|
||||
return i
|
||||
(arg_name, repr(v)))
|
||||
return v.as_datatype_enum
|
||||
|
||||
|
||||
def _MakeShape(v, arg_name):
|
||||
@ -670,78 +689,32 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in
|
||||
for attr_def in op_def.attr:
|
||||
key = attr_def.name
|
||||
value = attrs[key]
|
||||
attr_value = attr_value_pb2.AttrValue()
|
||||
|
||||
if attr_def.HasField("default_value") and value is None:
|
||||
attr_value = attr_value_pb2.AttrValue()
|
||||
attr_value.CopyFrom(attr_def.default_value)
|
||||
attr_protos[key] = attr_value
|
||||
continue
|
||||
|
||||
attr_value = value_to_attr_value(value, attr_def.type, key)
|
||||
if attr_def.type.startswith("list("):
|
||||
if not _IsListValue(value):
|
||||
raise TypeError("Expected list for attr " + key)
|
||||
if attr_def.has_minimum:
|
||||
if len(value) < attr_def.minimum:
|
||||
raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
|
||||
"less than minimum %d." %
|
||||
(key, op_type_name, len(value),
|
||||
attr_def.minimum))
|
||||
attr_value.list.SetInParent()
|
||||
if attr_def.type == "string":
|
||||
attr_value.s = _MakeStr(value, key)
|
||||
if attr_def.HasField("allowed_values"):
|
||||
if attr_value.s not in attr_def.allowed_values.list.s:
|
||||
raise ValueError(
|
||||
"Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
|
||||
(key, op_type_name, compat.as_text(attr_value.s),
|
||||
'", "'.join(map(compat.as_text,
|
||||
attr_def.allowed_values.list.s))))
|
||||
elif attr_def.type == "list(string)":
|
||||
attr_value.list.s.extend([_MakeStr(x, key) for x in value])
|
||||
if attr_def.HasField("allowed_values"):
|
||||
for x in attr_value.list.s:
|
||||
if x not in attr_def.allowed_values.list.s:
|
||||
raise ValueError(
|
||||
"Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
|
||||
(key, op_type_name, compat.as_text(x),
|
||||
'", "'.join(map(compat.as_text,
|
||||
attr_def.allowed_values.list.s))))
|
||||
elif attr_def.type == "int":
|
||||
attr_value.i = _MakeInt(value, key)
|
||||
if attr_def.has_minimum:
|
||||
if attr_value.i < attr_def.minimum:
|
||||
raise ValueError(
|
||||
"Attr '%s' of '%s' Op passed %d less than minimum %d." %
|
||||
(key, op_type_name, attr_value.i, attr_def.minimum))
|
||||
elif attr_def.type == "list(int)":
|
||||
attr_value.list.i.extend([_MakeInt(x, key) for x in value])
|
||||
elif attr_def.type == "float":
|
||||
attr_value.f = _MakeFloat(value, key)
|
||||
elif attr_def.type == "list(float)":
|
||||
attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
|
||||
elif attr_def.type == "bool":
|
||||
attr_value.b = _MakeBool(value, key)
|
||||
elif attr_def.type == "list(bool)":
|
||||
attr_value.list.b.extend([_MakeBool(x, key) for x in value])
|
||||
elif attr_def.type == "type":
|
||||
attr_value.type = _MakeType(value, attr_def)
|
||||
elif attr_def.type == "list(type)":
|
||||
attr_value.list.type.extend(
|
||||
[_MakeType(x, attr_def) for x in value])
|
||||
elif attr_def.type == "shape":
|
||||
attr_value.shape.CopyFrom(_MakeShape(value, key))
|
||||
elif attr_def.type == "list(shape)":
|
||||
attr_value.list.shape.extend(
|
||||
[_MakeShape(x, key) for x in value])
|
||||
elif attr_def.type == "tensor":
|
||||
attr_value.tensor.CopyFrom(_MakeTensor(value, key))
|
||||
elif attr_def.type == "list(tensor)":
|
||||
attr_value.list.tensor.extend(
|
||||
[_MakeTensor(x, key) for x in value])
|
||||
elif attr_def.type == "func":
|
||||
attr_value.func.CopyFrom(_MakeFunc(value, key))
|
||||
elif attr_def.type == "list(func)":
|
||||
attr_value.list.func.extend([_MakeFunc(x, key) for x in value])
|
||||
else:
|
||||
raise TypeError("Unrecognized Attr type " + attr_def.type)
|
||||
_SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name)
|
||||
if attr_def.HasField("allowed_values"):
|
||||
if attr_def.type == "string":
|
||||
_SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key,
|
||||
op_type_name)
|
||||
elif attr_def.type == "list(string)":
|
||||
for value in attr_value.list.s:
|
||||
_SatisfiesAllowedStringsConstraint(value, attr_def, key,
|
||||
op_type_name)
|
||||
if attr_def.has_minimum and attr_def.type == "int":
|
||||
_SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key,
|
||||
op_type_name)
|
||||
if attr_def.type == "type":
|
||||
_SatisfiesTypeConstraint(attr_value.type, attr_def, key)
|
||||
if attr_def.type == "list(type)":
|
||||
for value in attr_value.list.type:
|
||||
_SatisfiesTypeConstraint(value, attr_def, key)
|
||||
|
||||
attr_protos[key] = attr_value
|
||||
del attrs # attrs is no longer authoritative, use attr_protos instead
|
||||
@ -792,6 +765,61 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in
|
||||
return output_structure, op_def.is_stateful, op, outputs
|
||||
|
||||
|
||||
def value_to_attr_value(value, attr_type, arg_name): # pylint: disable=invalid-name
|
||||
"""Encodes a Python value as an `AttrValue` proto message.
|
||||
|
||||
Args:
|
||||
value: The value to convert.
|
||||
attr_type: The value type (string) -- see the AttrValue proto definition for
|
||||
valid strings.
|
||||
arg_name: Argument name (for error messages).
|
||||
|
||||
Returns:
|
||||
An AttrValue proto message that encodes `value`.
|
||||
"""
|
||||
attr_value = attr_value_pb2.AttrValue()
|
||||
|
||||
if attr_type.startswith("list("):
|
||||
if not _IsListValue(value):
|
||||
raise TypeError("Expected list for attr " + arg_name)
|
||||
|
||||
if attr_type == "string":
|
||||
attr_value.s = _MakeStr(value, arg_name)
|
||||
elif attr_type == "list(string)":
|
||||
attr_value.list.s.extend([_MakeStr(x, arg_name) for x in value])
|
||||
elif attr_type == "int":
|
||||
attr_value.i = _MakeInt(value, arg_name)
|
||||
elif attr_type == "list(int)":
|
||||
attr_value.list.i.extend([_MakeInt(x, arg_name) for x in value])
|
||||
elif attr_type == "float":
|
||||
attr_value.f = _MakeFloat(value, arg_name)
|
||||
elif attr_type == "list(float)":
|
||||
attr_value.list.f.extend([_MakeFloat(x, arg_name) for x in value])
|
||||
elif attr_type == "bool":
|
||||
attr_value.b = _MakeBool(value, arg_name)
|
||||
elif attr_type == "list(bool)":
|
||||
attr_value.list.b.extend([_MakeBool(x, arg_name) for x in value])
|
||||
elif attr_type == "type":
|
||||
attr_value.type = _MakeType(value, arg_name)
|
||||
elif attr_type == "list(type)":
|
||||
attr_value.list.type.extend([_MakeType(x, arg_name) for x in value])
|
||||
elif attr_type == "shape":
|
||||
attr_value.shape.CopyFrom(_MakeShape(value, arg_name))
|
||||
elif attr_type == "list(shape)":
|
||||
attr_value.list.shape.extend([_MakeShape(x, arg_name) for x in value])
|
||||
elif attr_type == "tensor":
|
||||
attr_value.tensor.CopyFrom(_MakeTensor(value, arg_name))
|
||||
elif attr_type == "list(tensor)":
|
||||
attr_value.list.tensor.extend([_MakeTensor(x, arg_name) for x in value])
|
||||
elif attr_type == "func":
|
||||
attr_value.func.CopyFrom(_MakeFunc(value, arg_name))
|
||||
elif attr_type == "list(func)":
|
||||
attr_value.list.func.extend([_MakeFunc(x, arg_name) for x in value])
|
||||
else:
|
||||
raise TypeError("Unrecognized Attr type " + attr_type)
|
||||
return attr_value
|
||||
|
||||
|
||||
# The following symbols are used by op_def_util.cc.
|
||||
_pywrap_utils.RegisterPyObject("tf.dtypes.DType", dtypes.DType)
|
||||
_pywrap_utils.RegisterPyObject("tf.dtypes.as_dtype", dtypes.as_dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user