From 46b6d3707f97b9ef95a7d6d6bd75a93fda4ab8ea Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Fri, 28 Aug 2020 07:02:24 -0700 Subject: [PATCH] Split some of the code in _apply_op_helper out into helper functions. PiperOrigin-RevId: 328931451 Change-Id: I33eed3decaa6f21485d42eaf4ba036e9da81cd06 --- tensorflow/python/framework/op_def_library.py | 172 ++++++++++-------- 1 file changed, 100 insertions(+), 72 deletions(-) diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py index 53d092787f6..016af65fc0a 100644 --- a/tensorflow/python/framework/op_def_library.py +++ b/tensorflow/python/framework/op_def_library.py @@ -63,6 +63,27 @@ def _SatisfiesTypeConstraint(dtype, attr_def, param_name): ", ".join(dtypes.as_dtype(x).name for x in allowed_list))) +def _SatisfiesLengthConstraint(length, attr_def, param_name, op_type_name): + if attr_def.has_minimum and length < attr_def.minimum: + raise ValueError("Attr '%s' of '%s' Op passed list of length %d " + "less than minimum %d." % + (param_name, op_type_name, length, attr_def.minimum)) + + +def _SatisfiesAllowedStringsConstraint(value, attr_def, arg_name, op_type_name): + if value not in attr_def.allowed_values.list.s: + raise ValueError( + "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % + (arg_name, op_type_name, compat.as_text(value), '", "'.join( + map(compat.as_text, attr_def.allowed_values.list.s)))) + + +def _SatisfiesIntMinimumConstraint(value, attr_def, arg_name, op_type_name): + if value < attr_def.minimum: + raise ValueError("Attr '%s' of '%s' Op passed %d less than minimum %d." % + (arg_name, op_type_name, value, attr_def.minimum)) + + def _IsListParameter(arg): if arg.number_attr: return True @@ -172,15 +193,13 @@ def _MakeBool(v, arg_name): return v -def _MakeType(v, attr_def): +def _MakeType(v, arg_name): try: v = dtypes.as_dtype(v).base_dtype except TypeError: raise TypeError("Expected DataType for argument '%s' not %s." % - (attr_def.name, repr(v))) - i = v.as_datatype_enum - _SatisfiesTypeConstraint(i, attr_def, param_name=attr_def.name) - return i + (arg_name, repr(v))) + return v.as_datatype_enum def _MakeShape(v, arg_name): @@ -670,78 +689,32 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in for attr_def in op_def.attr: key = attr_def.name value = attrs[key] - attr_value = attr_value_pb2.AttrValue() + if attr_def.HasField("default_value") and value is None: + attr_value = attr_value_pb2.AttrValue() attr_value.CopyFrom(attr_def.default_value) attr_protos[key] = attr_value continue + + attr_value = value_to_attr_value(value, attr_def.type, key) if attr_def.type.startswith("list("): - if not _IsListValue(value): - raise TypeError("Expected list for attr " + key) - if attr_def.has_minimum: - if len(value) < attr_def.minimum: - raise ValueError("Attr '%s' of '%s' Op passed list of length %d " - "less than minimum %d." % - (key, op_type_name, len(value), - attr_def.minimum)) - attr_value.list.SetInParent() - if attr_def.type == "string": - attr_value.s = _MakeStr(value, key) - if attr_def.HasField("allowed_values"): - if attr_value.s not in attr_def.allowed_values.list.s: - raise ValueError( - "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % - (key, op_type_name, compat.as_text(attr_value.s), - '", "'.join(map(compat.as_text, - attr_def.allowed_values.list.s)))) - elif attr_def.type == "list(string)": - attr_value.list.s.extend([_MakeStr(x, key) for x in value]) - if attr_def.HasField("allowed_values"): - for x in attr_value.list.s: - if x not in attr_def.allowed_values.list.s: - raise ValueError( - "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % - (key, op_type_name, compat.as_text(x), - '", "'.join(map(compat.as_text, - attr_def.allowed_values.list.s)))) - elif attr_def.type == "int": - attr_value.i = _MakeInt(value, key) - if attr_def.has_minimum: - if attr_value.i < attr_def.minimum: - raise ValueError( - "Attr '%s' of '%s' Op passed %d less than minimum %d." % - (key, op_type_name, attr_value.i, attr_def.minimum)) - elif attr_def.type == "list(int)": - attr_value.list.i.extend([_MakeInt(x, key) for x in value]) - elif attr_def.type == "float": - attr_value.f = _MakeFloat(value, key) - elif attr_def.type == "list(float)": - attr_value.list.f.extend([_MakeFloat(x, key) for x in value]) - elif attr_def.type == "bool": - attr_value.b = _MakeBool(value, key) - elif attr_def.type == "list(bool)": - attr_value.list.b.extend([_MakeBool(x, key) for x in value]) - elif attr_def.type == "type": - attr_value.type = _MakeType(value, attr_def) - elif attr_def.type == "list(type)": - attr_value.list.type.extend( - [_MakeType(x, attr_def) for x in value]) - elif attr_def.type == "shape": - attr_value.shape.CopyFrom(_MakeShape(value, key)) - elif attr_def.type == "list(shape)": - attr_value.list.shape.extend( - [_MakeShape(x, key) for x in value]) - elif attr_def.type == "tensor": - attr_value.tensor.CopyFrom(_MakeTensor(value, key)) - elif attr_def.type == "list(tensor)": - attr_value.list.tensor.extend( - [_MakeTensor(x, key) for x in value]) - elif attr_def.type == "func": - attr_value.func.CopyFrom(_MakeFunc(value, key)) - elif attr_def.type == "list(func)": - attr_value.list.func.extend([_MakeFunc(x, key) for x in value]) - else: - raise TypeError("Unrecognized Attr type " + attr_def.type) + _SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name) + if attr_def.HasField("allowed_values"): + if attr_def.type == "string": + _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key, + op_type_name) + elif attr_def.type == "list(string)": + for value in attr_value.list.s: + _SatisfiesAllowedStringsConstraint(value, attr_def, key, + op_type_name) + if attr_def.has_minimum and attr_def.type == "int": + _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key, + op_type_name) + if attr_def.type == "type": + _SatisfiesTypeConstraint(attr_value.type, attr_def, key) + if attr_def.type == "list(type)": + for value in attr_value.list.type: + _SatisfiesTypeConstraint(value, attr_def, key) attr_protos[key] = attr_value del attrs # attrs is no longer authoritative, use attr_protos instead @@ -792,6 +765,61 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in return output_structure, op_def.is_stateful, op, outputs +def value_to_attr_value(value, attr_type, arg_name): # pylint: disable=invalid-name + """Encodes a Python value as an `AttrValue` proto message. + + Args: + value: The value to convert. + attr_type: The value type (string) -- see the AttrValue proto definition for + valid strings. + arg_name: Argument name (for error messages). + + Returns: + An AttrValue proto message that encodes `value`. + """ + attr_value = attr_value_pb2.AttrValue() + + if attr_type.startswith("list("): + if not _IsListValue(value): + raise TypeError("Expected list for attr " + arg_name) + + if attr_type == "string": + attr_value.s = _MakeStr(value, arg_name) + elif attr_type == "list(string)": + attr_value.list.s.extend([_MakeStr(x, arg_name) for x in value]) + elif attr_type == "int": + attr_value.i = _MakeInt(value, arg_name) + elif attr_type == "list(int)": + attr_value.list.i.extend([_MakeInt(x, arg_name) for x in value]) + elif attr_type == "float": + attr_value.f = _MakeFloat(value, arg_name) + elif attr_type == "list(float)": + attr_value.list.f.extend([_MakeFloat(x, arg_name) for x in value]) + elif attr_type == "bool": + attr_value.b = _MakeBool(value, arg_name) + elif attr_type == "list(bool)": + attr_value.list.b.extend([_MakeBool(x, arg_name) for x in value]) + elif attr_type == "type": + attr_value.type = _MakeType(value, arg_name) + elif attr_type == "list(type)": + attr_value.list.type.extend([_MakeType(x, arg_name) for x in value]) + elif attr_type == "shape": + attr_value.shape.CopyFrom(_MakeShape(value, arg_name)) + elif attr_type == "list(shape)": + attr_value.list.shape.extend([_MakeShape(x, arg_name) for x in value]) + elif attr_type == "tensor": + attr_value.tensor.CopyFrom(_MakeTensor(value, arg_name)) + elif attr_type == "list(tensor)": + attr_value.list.tensor.extend([_MakeTensor(x, arg_name) for x in value]) + elif attr_type == "func": + attr_value.func.CopyFrom(_MakeFunc(value, arg_name)) + elif attr_type == "list(func)": + attr_value.list.func.extend([_MakeFunc(x, arg_name) for x in value]) + else: + raise TypeError("Unrecognized Attr type " + attr_type) + return attr_value + + # The following symbols are used by op_def_util.cc. _pywrap_utils.RegisterPyObject("tf.dtypes.DType", dtypes.DType) _pywrap_utils.RegisterPyObject("tf.dtypes.as_dtype", dtypes.as_dtype)