diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a5d2301d706..4944b2cb109 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -974,6 +974,7 @@ tf_py_test( ":function_def_to_graph", ":graph_to_function_def", ":math_ops", + ":op_def_library", ":test_ops", ], tags = ["no_pip"], @@ -1031,6 +1032,7 @@ py_library( deps = [ ":dtypes", ":framework_ops", + ":op_def_registry", ":platform", ":tensor_shape", ":util", @@ -2039,7 +2041,6 @@ tf_py_test( ":framework_for_generated_wrappers", ":framework_test_lib", ":platform_test", - ":test_ops", ], ) diff --git a/tensorflow/python/framework/function_def_to_graph_test.py b/tensorflow/python/framework/function_def_to_graph_test.py index 5ccdf896127..7fa02e9fbfb 100644 --- a/tensorflow/python/framework/function_def_to_graph_test.py +++ b/tensorflow/python/framework/function_def_to_graph_test.py @@ -23,10 +23,11 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function_def_to_graph from tensorflow.python.framework import graph_to_function_def +from tensorflow.python.framework import op_def_library from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import test_util from tensorflow.python.framework import test_ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables @@ -119,8 +120,7 @@ class FunctionDefToGraphDefTest(test.TestCase): y = array_ops.placeholder(dtypes.int32, name="y") z = array_ops.placeholder(dtypes.int32, name="z") - d_1, e_1 = test_ops._op_def_lib.apply_op( - "Foo1", name="foo_1", a=x, b=y, c=z) + d_1, e_1 = op_def_library.apply_op("Foo1", name="foo_1", a=x, b=y, c=z) list_output0, list_output1 = test_ops.list_output( T=[dtypes.int32, dtypes.int32], name="list_output") diff --git 
a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py index ab3cefd9f58..ec780f26d6b 100644 --- a/tensorflow/python/framework/op_def_library.py +++ b/tensorflow/python/framework/op_def_library.py @@ -22,10 +22,10 @@ from __future__ import print_function import six from tensorflow.core.framework import attr_value_pb2 -from tensorflow.core.framework import op_def_pb2 from tensorflow.core.framework import tensor_pb2 from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.core.framework import types_pb2 +from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import dtypes from tensorflow.python.framework import op_callbacks from tensorflow.python.framework import ops @@ -229,38 +229,6 @@ def _MakeFunc(v, arg_name): return fn_attr -class _OpInfo(object): - """All per-Op state we would like to precompute/validate.""" - - def __init__(self, op_def): - self.op_def = op_def - # TODO(josh11b): SWIG the ValidateOpDef() function from C++ and call it - # here, instead of these checks. 
- for arg in list(op_def.input_arg) + list(op_def.output_arg): - num_type_fields = _NumTypeFields(arg) - if num_type_fields != 1: - raise TypeError("Arg '%s' of '%s' must have one type field not %d" % - (arg.name, op_def.name, num_type_fields)) - if arg.type_attr: - attr_type = _Attr(op_def, arg.type_attr).type - if attr_type != "type": - raise TypeError("Attr '%s' of '%s' used as a type_attr " - "but has type %s" % - (arg.type_attr, op_def.name, attr_type)) - if arg.type_list_attr: - attr_type = _Attr(op_def, arg.type_list_attr).type - if attr_type != "list(type)": - raise TypeError( - "Attr '%s' of '%s' used as a type_list_attr but has type %s" % - (arg.type_attr, op_def.name, attr_type)) - if arg.number_attr: - attr_type = _Attr(op_def, arg.number_attr).type - if attr_type != "int": - raise TypeError( - "Attr '%s' of '%s' used as a number_attr but has type %s" % - (arg.number_attr, op_def.name, attr_type)) - - # pylint: disable=g-doc-return-or-yield @tf_contextlib.contextmanager def _MaybeColocateWith(inputs): @@ -282,530 +250,501 @@ def _MaybeColocateWith(inputs): # pylint: enable=g-doc-return-or-yield -class OpDefLibrary(object): - """Holds a collection of OpDefs, can add the corresponding Ops to a graph.""" +def apply_op(op_type_name, name=None, **keywords): # pylint: disable=invalid-name + """Add a node invoking a registered Op to a graph. - def __init__(self): - self._ops = {} + Example usage: + # input1 and input2 can be Tensors or anything ops.convert_to_tensor() + # will convert to a Tensor. + op_def_library.apply_op("op", input1=input1, input2=input2) + # Can specify a node name. + op_def_library.apply_op("op", input1=input1, name="node_name") + # Must use keyword arguments, with the names specified in the OpDef. + op_def_library.apply_op("op", input_name=input, attr_name=attr) - # pylint: disable=invalid-name - def add_op(self, op_def): - """Register an OpDef. 
May call apply_op with the name afterwards.""" - if not isinstance(op_def, op_def_pb2.OpDef): - raise TypeError("%s is %s, not an op_def_pb2.OpDef" % - (op_def, type(op_def))) - if not op_def.name: - raise ValueError("%s missing name." % op_def) - if op_def.name in self._ops: - raise RuntimeError("Op name %s registered twice." % op_def.name) - self._ops[op_def.name] = _OpInfo(op_def) + All attrs must either be inferred from an input or specified. + (If inferred, the attr must not be specified.) If an attr has a default + value specified in the Op's OpDef, then you may pass None as the value + of that attr to get the default. - def add_op_list(self, op_list): - """Register the OpDefs from an OpList.""" - if not isinstance(op_list, op_def_pb2.OpList): - raise TypeError("%s is %s, not an op_def_pb2.OpList" % - (op_list, type(op_list))) - for op_def in op_list.op: - self.add_op(op_def) + Args: + op_type_name: string. Must match the name field of a registered Op. + name: string. Optional name of the created op. + **keywords: input Tensor and attr arguments specified by name, + and optional parameters to pass when constructing the Operation. - def apply_op(self, op_type_name, name=None, **keywords): - # pylint: disable=g-doc-args - """Add a node invoking a registered Op to a graph. + Returns: + The Tensor(s) representing the output of the operation, or the Operation + itself if there are no outputs. - Example usage: - # input1 and input2 can be Tensors or anything ops.convert_to_tensor() - # will convert to a Tensor. - op_def_library.apply_op("op", input1=input1, input2=input2) - # Can specify a node name. - op_def_library.apply_op("op", input1=input1, name="node_name") - # Must use keyword arguments, with the names specified in the OpDef. - op_def_library.apply_op("op", input_name=input, attr_name=attr) - - All attrs must either be inferred from an input or specified. - (If inferred, the attr must not be specified.) 
If an attr has a default - value specified in the Op's OpDef, then you may pass None as the value - of that attr to get the default. - - Args: - op_type_name: string. Must match the name field of a registered Op. - name: string. Optional name of the created op. - **keywords: input Tensor and attr arguments specified by name, - and optional parameters to pass when constructing the Operation. - - Returns: - The Tensor(s) representing the output of the operation, or the Operation - itself if there are no outputs. - - Raises: - RuntimeError: On some errors. - TypeError: On some errors. - ValueError: On some errors. - """ - output_structure, is_stateful, op, outputs = self._apply_op_helper( - op_type_name, name, **keywords) - if output_structure: - res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure) - if isinstance(res, list) and not res and is_stateful: - return op - else: - return res - else: + Raises: + RuntimeError: On some errors. + TypeError: On some errors. + ValueError: On some errors. + """ + output_structure, is_stateful, op, outputs = _apply_op_helper( + op_type_name, name, **keywords) + if output_structure: + res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure) + if isinstance(res, list) and not res and is_stateful: return op + else: + return res + else: + return op - def _apply_op_helper(self, op_type_name, name=None, **keywords): - """Implementation of apply_op that returns output_structure, op.""" - op_info = self._ops.get(op_type_name, None) - if op_info is None: - raise RuntimeError("Unrecognized Op name " + op_type_name) - op_def = op_info.op_def - # Determine the graph context. - try: - # Need to flatten all the arguments into a list. 
- # pylint: disable=protected-access - g = ops._get_graph_from_inputs(_Flatten(keywords.values())) - # pylint: enable=protected-access - except AssertionError as e: - raise RuntimeError( - "Cannot determine graph for Op '%s' due to: %s" - % (op_type_name, e.message)) +def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=invalid-name + """Implementation of apply_op that returns output_structure, op.""" + op_def = op_def_registry.get(op_type_name) + if op_def is None: + raise RuntimeError("Unrecognized Op name " + op_type_name) - # Default name if not specified. - if name is None: - name = op_type_name + # Determine the graph context. + try: + # Need to flatten all the arguments into a list. + # pylint: disable=protected-access + g = ops._get_graph_from_inputs(_Flatten(keywords.values())) + # pylint: enable=protected-access + except AssertionError as e: + raise RuntimeError( + "Cannot determine graph for Op '%s' due to: %s" + % (op_type_name, e.message)) - # Check for deprecation - deprecation_version = op_def.deprecation.version - if deprecation_version: - producer = g.graph_def_versions.producer - if producer >= deprecation_version: - raise NotImplementedError( - ("Op %s is not available in GraphDef version %d. " - "It has been removed in version %d. %s.") % - (op_type_name, producer, deprecation_version, - op_def.deprecation.explanation)) + # Default name if not specified. + if name is None: + name = op_type_name - # Fill in the list of default types for all "type" attrs. This - # will be used to choose a preferred dtype to convert to in the - # absence of input type information. - # - # TODO(b/31302892): Currently the defaults don't work in the right - # way if you have two inputs, one of whose type resolution depends - # on the other. Handling this will require restructuring this code - # significantly. 
- default_type_attr_map = {} - for attr_def in op_def.attr: - if attr_def.type != "type": - continue - key = attr_def.name - if attr_def.HasField("default_value"): - default_type_attr_map[key] = dtypes.as_dtype( - attr_def.default_value.type) + # Check for deprecation + deprecation_version = op_def.deprecation.version + if deprecation_version: + producer = g.graph_def_versions.producer + if producer >= deprecation_version: + raise NotImplementedError( + ("Op %s is not available in GraphDef version %d. " + "It has been removed in version %d. %s.") % + (op_type_name, producer, deprecation_version, + op_def.deprecation.explanation)) - # Requires that op_def has passed validation (using the C++ - # ValidateOpDef() from ../framework/op_def_util.h). - attrs = {} - inputs = [] - input_types = [] - with g.as_default(), ops.name_scope(name) as scope: + # Fill in the list of default types for all "type" attrs. This + # will be used to choose a preferred dtype to convert to in the + # absence of input type information. + # + # TODO(b/31302892): Currently the defaults don't work in the right + # way if you have two inputs, one of whose type resolution depends + # on the other. Handling this will require restructuring this code + # significantly. + default_type_attr_map = {} + for attr_def in op_def.attr: + if attr_def.type != "type": + continue + key = attr_def.name + if attr_def.HasField("default_value"): + default_type_attr_map[key] = dtypes.as_dtype( + attr_def.default_value.type) - # Perform input type inference - inferred_from = {} - for input_arg in op_def.input_arg: - input_name = input_arg.name - if input_name in keywords: - values = keywords.pop(input_name) - elif input_name + "_" in keywords: - # Handle the case where the name is a keyword or built-in - # for Python so we use the name + _ instead. 
- input_name += "_" - values = keywords.pop(input_name) - else: - raise TypeError("No argument for input " + input_name) + # Requires that op_def has passed validation (using the C++ + # ValidateOpDef() from ../framework/op_def_util.h). + attrs = {} + inputs = [] + input_types = [] + with g.as_default(), ops.name_scope(name) as scope: - # Goals: - # * Convert values to Tensors if it contains constants. - # * Verify that values is a list if that matches the input_arg's - # type. - # * If the input_arg's type is determined by attrs, either set - # those attrs and validate those attr values are legal (if - # they have not yet been set) or validate the input matches - # the type indicated by the attrs (if they have already been - # inferred via an earlier input). - # * If the input_arg has an explicit type, make sure the input - # conforms. + # Perform input type inference + inferred_from = {} + for input_arg in op_def.input_arg: + input_name = input_arg.name + if input_name in keywords: + values = keywords.pop(input_name) + elif input_name + "_" in keywords: + # Handle the case where the name is a keyword or built-in + # for Python so we use the name + _ instead. + input_name += "_" + values = keywords.pop(input_name) + else: + raise TypeError("No argument for input " + input_name) - if _IsListParameter(input_arg): - if not _IsListValue(values): - raise TypeError( - "Expected list for '%s' argument to '%s' Op, not %s." % - (input_name, op_type_name, values)) - # In cases where we expect all elements of the list to have the - # same dtype, try to cast non-Tensor elements to that type. - dtype = None - default_dtype = None - if input_arg.type != types_pb2.DT_INVALID: - dtype = input_arg.type - elif input_arg.number_attr: - if input_arg.type_attr in attrs: - dtype = attrs[input_arg.type_attr] - else: - for t in values: - if isinstance(t, ops.Tensor): - dtype = t.dtype - break + # Goals: + # * Convert values to Tensors if it contains constants. 
+ # * Verify that values is a list if that matches the input_arg's + # type. + # * If the input_arg's type is determined by attrs, either set + # those attrs and validate those attr values are legal (if + # they have not yet been set) or validate the input matches + # the type indicated by the attrs (if they have already been + # inferred via an earlier input). + # * If the input_arg has an explicit type, make sure the input + # conforms. - # dtype still not found, prefer using the default dtype - # from the attr. - if dtype is None and input_arg.type_attr in default_type_attr_map: - default_dtype = default_type_attr_map[input_arg.type_attr] - - try: - if not input_arg.is_ref and dtype: - dtype = dtypes.as_dtype(dtype).base_dtype - values = ops.internal_convert_n_to_tensor( - values, - name=input_arg.name, - dtype=dtype if dtype else None, - preferred_dtype=default_dtype, - as_ref=input_arg.is_ref) - if input_arg.number_attr and len( - set(v.dtype.base_dtype for v in values)) > 1: - raise TypeError() # All types should match. - except (TypeError, ValueError): - # What types does the conversion function think values have? - observed_types = [] - for value in values: - try: - converted_value = ops.internal_convert_to_tensor( - value, as_ref=input_arg.is_ref) - observed_types.append(converted_value.dtype.base_dtype.name) - except (TypeError, ValueError): - observed_types.append("") - observed = ", ".join(observed_types) - - prefix = ( - "Tensors in list passed to '%s' of '%s' Op have types [%s]" % - (input_name, op_type_name, observed)) - if input_arg.number_attr: - if input_arg.type != types_pb2.DT_INVALID: - raise TypeError("%s that do not match expected type %s." % - (prefix, dtype.name)) - elif input_arg.type_attr in attrs: - raise TypeError("%s that do not match type %s inferred from " - "earlier arguments." % - (prefix, dtype.name)) - else: - raise TypeError("%s that don't all match." % prefix) - else: - raise TypeError( - "%s that are invalid. 
Tensors: %s" % (prefix, values)) - - types = [x.dtype for x in values] - inputs.extend(values) - else: - # In cases where we have an expected type, try to convert non-Tensor - # arguments to that type. - dtype = None - default_dtype = None - if input_arg.type != types_pb2.DT_INVALID: - dtype = input_arg.type - elif input_arg.type_attr in attrs: + if _IsListParameter(input_arg): + if not _IsListValue(values): + raise TypeError( + "Expected list for '%s' argument to '%s' Op, not %s." % + (input_name, op_type_name, values)) + # In cases where we expect all elements of the list to have the + # same dtype, try to cast non-Tensor elements to that type. + dtype = None + default_dtype = None + if input_arg.type != types_pb2.DT_INVALID: + dtype = input_arg.type + elif input_arg.number_attr: + if input_arg.type_attr in attrs: dtype = attrs[input_arg.type_attr] - elif input_arg.type_attr in default_type_attr_map: - # The dtype could not be inferred solely from the inputs, - # so we prefer the attr's default, so code that adds a new attr - # with a default is backwards compatible. + else: + for t in values: + if isinstance(t, ops.Tensor): + dtype = t.dtype + break + + # dtype still not found, prefer using the default dtype + # from the attr. + if dtype is None and input_arg.type_attr in default_type_attr_map: default_dtype = default_type_attr_map[input_arg.type_attr] - try: - values = ops.internal_convert_to_tensor( - values, - name=input_arg.name, - dtype=dtype, - as_ref=input_arg.is_ref, - preferred_dtype=default_dtype) - except TypeError as err: - if dtype is None: - raise err - else: - raise TypeError( - "Expected %s passed to parameter '%s' of op '%s', got %s of " - "type '%s' instead. Error: %s" % - (dtypes.as_dtype(dtype).name, input_arg.name, op_type_name, - repr(values), type(values).__name__, err)) - except ValueError: - # What type does convert_to_tensor think it has? 
+ try: + if not input_arg.is_ref and dtype: + dtype = dtypes.as_dtype(dtype).base_dtype + values = ops.internal_convert_n_to_tensor( + values, + name=input_arg.name, + dtype=dtype if dtype else None, + preferred_dtype=default_dtype, + as_ref=input_arg.is_ref) + if input_arg.number_attr and len( + set(v.dtype.base_dtype for v in values)) > 1: + raise TypeError() # All types should match. + except (TypeError, ValueError): + # What types does the conversion function think values have? + observed_types = [] + for value in values: try: - observed = ops.internal_convert_to_tensor( - values, as_ref=input_arg.is_ref).dtype.name - except ValueError as err: - raise ValueError( - "Tried to convert '%s' to a tensor and failed. Error: %s" % - (input_name, err)) - prefix = ("Input '%s' of '%s' Op has type %s that does not match" % - (input_name, op_type_name, observed)) + converted_value = ops.internal_convert_to_tensor( + value, as_ref=input_arg.is_ref) + observed_types.append(converted_value.dtype.base_dtype.name) + except (TypeError, ValueError): + observed_types.append("") + observed = ", ".join(observed_types) + + prefix = ( + "Tensors in list passed to '%s' of '%s' Op have types [%s]" % + (input_name, op_type_name, observed)) + if input_arg.number_attr: if input_arg.type != types_pb2.DT_INVALID: - raise TypeError("%s expected type of %s." % - (prefix, dtypes.as_dtype(input_arg.type).name)) + raise TypeError("%s that do not match expected type %s." % + (prefix, dtype.name)) + elif input_arg.type_attr in attrs: + raise TypeError("%s that do not match type %s inferred from " + "earlier arguments." % + (prefix, dtype.name)) else: - # Update the maps with the default, if needed. - k = input_arg.type_attr - if k in default_type_attr_map: - if k not in attrs: - attrs[k] = default_type_attr_map[k] - if k not in inferred_from: - inferred_from[k] = "Default in OpDef" - - raise TypeError( - "%s type %s of argument '%s'." 
% - (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name, - inferred_from[input_arg.type_attr])) - - types = [values.dtype] - inputs.append(values) - base_types = [x.base_dtype for x in types] - - if input_arg.number_attr: - # * or * - if input_arg.number_attr in attrs: - if len(values) != attrs[input_arg.number_attr]: - raise ValueError( - "List argument '%s' to '%s' Op with length %d must match " - "length %d of argument '%s'." % - (input_name, op_type_name, len(values), - attrs[input_arg.number_attr], - inferred_from[input_arg.number_attr])) + raise TypeError("%s that don't all match." % prefix) else: - attrs[input_arg.number_attr] = len(values) - inferred_from[input_arg.number_attr] = input_name - num_attr = _Attr(op_def, input_arg.number_attr) - if num_attr.has_minimum and len(values) < num_attr.minimum: - raise ValueError( - "List argument '%s' to '%s' Op with length %d shorter " - "than minimum length %d." % - (input_name, op_type_name, len(values), num_attr.minimum)) - # All tensors must have the same base type. - if any(bt != base_types[0] for bt in base_types): raise TypeError( - "All tensors passed to '%s' of '%s' Op " - "must have the same type." % - (input_name, op_type_name)) + "%s that are invalid. Tensors: %s" % (prefix, values)) + + types = [x.dtype for x in values] + inputs.extend(values) + else: + # In cases where we have an expected type, try to convert non-Tensor + # arguments to that type. + dtype = None + default_dtype = None + if input_arg.type != types_pb2.DT_INVALID: + dtype = input_arg.type + elif input_arg.type_attr in attrs: + dtype = attrs[input_arg.type_attr] + elif input_arg.type_attr in default_type_attr_map: + # The dtype could not be inferred solely from the inputs, + # so we prefer the attr's default, so code that adds a new attr + # with a default is backwards compatible. 
+ default_dtype = default_type_attr_map[input_arg.type_attr] + + try: + values = ops.internal_convert_to_tensor( + values, + name=input_arg.name, + dtype=dtype, + as_ref=input_arg.is_ref, + preferred_dtype=default_dtype) + except TypeError as err: + if dtype is None: + raise err + else: + raise TypeError( + "Expected %s passed to parameter '%s' of op '%s', got %s of " + "type '%s' instead. Error: %s" % + (dtypes.as_dtype(dtype).name, input_arg.name, op_type_name, + repr(values), type(values).__name__, err)) + except ValueError: + # What type does convert_to_tensor think it has? + try: + observed = ops.internal_convert_to_tensor( + values, as_ref=input_arg.is_ref).dtype.name + except ValueError as err: + raise ValueError( + "Tried to convert '%s' to a tensor and failed. Error: %s" % + (input_name, err)) + prefix = ("Input '%s' of '%s' Op has type %s that does not match" % + (input_name, op_type_name, observed)) if input_arg.type != types_pb2.DT_INVALID: - # * case - if base_types and base_types[0] != input_arg.type: - assert False, "Unreachable" - elif input_arg.type_attr in attrs: - # * case, where already - # has an inferred value. - if base_types and base_types[0] != attrs[input_arg.type_attr]: - assert False, "Unreachable" + raise TypeError("%s expected type of %s." % + (prefix, dtypes.as_dtype(input_arg.type).name)) else: - # * case, where we are now setting - # the based on this input - if not base_types: - raise TypeError( - "Don't know how to infer type variable from empty input " - "list passed to input '%s' of '%s' Op." 
% - (input_name, op_type_name)) - attrs[input_arg.type_attr] = base_types[0] - inferred_from[input_arg.type_attr] = input_name - type_attr = _Attr(op_def, input_arg.type_attr) - _SatisfiesTypeConstraint(base_types[0], type_attr, - param_name=input_name) - elif input_arg.type_attr: - # - attr_value = base_types[0] - if input_arg.type_attr in attrs: - if attrs[input_arg.type_attr] != attr_value: - raise TypeError( - "Input '%s' of '%s' Op has type %s that does not " - "match type %s of argument '%s'." % - (input_name, op_type_name, dtypes.as_dtype(attr_value).name, - dtypes.as_dtype(attrs[input_arg.type_attr]).name, - inferred_from[input_arg.type_attr])) - else: - for base_type in base_types: - _SatisfiesTypeConstraint(base_type, - _Attr(op_def, input_arg.type_attr), - param_name=input_name) - attrs[input_arg.type_attr] = attr_value - inferred_from[input_arg.type_attr] = input_name - elif input_arg.type_list_attr: - # - attr_value = base_types - if input_arg.type_list_attr in attrs: - if attrs[input_arg.type_list_attr] != attr_value: - raise TypeError( - "Input '%s' of '%s' Op has type list of %s that does not " - "match type list %s of argument '%s'." % - (input_name, op_type_name, - ", ".join(dtypes.as_dtype(x).name for x in attr_value), - ", ".join(dtypes.as_dtype(x).name - for x in attrs[input_arg.type_list_attr]), - inferred_from[input_arg.type_list_attr])) - else: - for base_type in base_types: - _SatisfiesTypeConstraint(base_type, - _Attr(op_def, input_arg.type_list_attr), - param_name=input_name) - attrs[input_arg.type_list_attr] = attr_value - inferred_from[input_arg.type_list_attr] = input_name + # Update the maps with the default, if needed. + k = input_arg.type_attr + if k in default_type_attr_map: + if k not in attrs: + attrs[k] = default_type_attr_map[k] + if k not in inferred_from: + inferred_from[k] = "Default in OpDef" + + raise TypeError( + "%s type %s of argument '%s'." 
% + (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name, + inferred_from[input_arg.type_attr])) + + types = [values.dtype] + inputs.append(values) + base_types = [x.base_dtype for x in types] + + if input_arg.number_attr: + # * or * + if input_arg.number_attr in attrs: + if len(values) != attrs[input_arg.number_attr]: + raise ValueError( + "List argument '%s' to '%s' Op with length %d must match " + "length %d of argument '%s'." % + (input_name, op_type_name, len(values), + attrs[input_arg.number_attr], + inferred_from[input_arg.number_attr])) else: - # single Tensor with specified type - if base_types[0] != input_arg.type: + attrs[input_arg.number_attr] = len(values) + inferred_from[input_arg.number_attr] = input_name + num_attr = _Attr(op_def, input_arg.number_attr) + if num_attr.has_minimum and len(values) < num_attr.minimum: + raise ValueError( + "List argument '%s' to '%s' Op with length %d shorter " + "than minimum length %d." % + (input_name, op_type_name, len(values), num_attr.minimum)) + # All tensors must have the same base type. + if any(bt != base_types[0] for bt in base_types): + raise TypeError( + "All tensors passed to '%s' of '%s' Op " + "must have the same type." % + (input_name, op_type_name)) + if input_arg.type != types_pb2.DT_INVALID: + # * case + if base_types and base_types[0] != input_arg.type: + assert False, "Unreachable" + elif input_arg.type_attr in attrs: + # * case, where already + # has an inferred value. 
+ if base_types and base_types[0] != attrs[input_arg.type_attr]: assert False, "Unreachable" - - if input_arg.is_ref: - if not all(x._is_ref_dtype for x in types): # pylint: disable=protected-access - raise TypeError( - ("'%s' Op requires that input '%s' be a mutable tensor " - "(e.g.: a tf.Variable)") % (op_type_name, input_name)) - input_types.extend(types) else: - input_types.extend(base_types) - - # Process remaining attrs - for attr in op_def.attr: - # Skip attrs that have already had their values inferred - if attr.name in attrs: - if attr.name in keywords: + # * case, where we are now setting + # the based on this input + if not base_types: raise TypeError( - "Should not specify value for inferred attr '%s'." % attr.name) - continue + "Don't know how to infer type variable from empty input " + "list passed to input '%s' of '%s' Op." % + (input_name, op_type_name)) + attrs[input_arg.type_attr] = base_types[0] + inferred_from[input_arg.type_attr] = input_name + type_attr = _Attr(op_def, input_arg.type_attr) + _SatisfiesTypeConstraint(base_types[0], type_attr, + param_name=input_name) + elif input_arg.type_attr: + # + attr_value = base_types[0] + if input_arg.type_attr in attrs: + if attrs[input_arg.type_attr] != attr_value: + raise TypeError( + "Input '%s' of '%s' Op has type %s that does not " + "match type %s of argument '%s'." 
% + (input_name, op_type_name, dtypes.as_dtype(attr_value).name, + dtypes.as_dtype(attrs[input_arg.type_attr]).name, + inferred_from[input_arg.type_attr])) + else: + for base_type in base_types: + _SatisfiesTypeConstraint(base_type, + _Attr(op_def, input_arg.type_attr), + param_name=input_name) + attrs[input_arg.type_attr] = attr_value + inferred_from[input_arg.type_attr] = input_name + elif input_arg.type_list_attr: + # + attr_value = base_types + if input_arg.type_list_attr in attrs: + if attrs[input_arg.type_list_attr] != attr_value: + raise TypeError( + "Input '%s' of '%s' Op has type list of %s that does not " + "match type list %s of argument '%s'." % + (input_name, op_type_name, + ", ".join(dtypes.as_dtype(x).name for x in attr_value), + ", ".join(dtypes.as_dtype(x).name + for x in attrs[input_arg.type_list_attr]), + inferred_from[input_arg.type_list_attr])) + else: + for base_type in base_types: + _SatisfiesTypeConstraint(base_type, + _Attr(op_def, input_arg.type_list_attr), + param_name=input_name) + attrs[input_arg.type_list_attr] = attr_value + inferred_from[input_arg.type_list_attr] = input_name + else: + # single Tensor with specified type + if base_types[0] != input_arg.type: + assert False, "Unreachable" + + if input_arg.is_ref: + if not all(x._is_ref_dtype for x in types): # pylint: disable=protected-access + raise TypeError( + ("'%s' Op requires that input '%s' be a mutable tensor " + "(e.g.: a tf.Variable)") % (op_type_name, input_name)) + input_types.extend(types) + else: + input_types.extend(base_types) + + # Process remaining attrs + for attr in op_def.attr: + # Skip attrs that have already had their values inferred + if attr.name in attrs: if attr.name in keywords: - attrs[attr.name] = keywords.pop(attr.name) - elif attr.name + "_" in keywords: - # Attrs whose names match Python keywords have an extra '_' - # appended, so we must check for that as well. 
- attrs[attr.name] = keywords.pop(attr.name + "_") - else: - raise TypeError("No argument for attr " + attr.name) + raise TypeError( + "Should not specify value for inferred attr '%s'." % attr.name) + continue + if attr.name in keywords: + attrs[attr.name] = keywords.pop(attr.name) + elif attr.name + "_" in keywords: + # Attrs whose names match Python keywords have an extra '_' + # appended, so we must check for that as well. + attrs[attr.name] = keywords.pop(attr.name + "_") + else: + raise TypeError("No argument for attr " + attr.name) - # Convert attr values to AttrValue protos. - attr_protos = {} - for attr_def in op_def.attr: - key = attr_def.name - value = attrs[key] - attr_value = attr_value_pb2.AttrValue() - if attr_def.HasField("default_value") and value is None: - attr_value.CopyFrom(attr_def.default_value) - attr_protos[key] = attr_value - continue - if attr_def.type.startswith("list("): - if not _IsListValue(value): - raise TypeError("Expected list for attr " + key) - if attr_def.has_minimum: - if len(value) < attr_def.minimum: - raise ValueError("Attr '%s' of '%s' Op passed list of length %d " - "less than minimum %d." % - (key, op_type_name, len(value), - attr_def.minimum)) - attr_value.list.SetInParent() - if attr_def.type == "string": - attr_value.s = _MakeStr(value, key) - if attr_def.HasField("allowed_values"): - if attr_value.s not in attr_def.allowed_values.list.s: + # Convert attr values to AttrValue protos. 
+ attr_protos = {} + for attr_def in op_def.attr: + key = attr_def.name + value = attrs[key] + attr_value = attr_value_pb2.AttrValue() + if attr_def.HasField("default_value") and value is None: + attr_value.CopyFrom(attr_def.default_value) + attr_protos[key] = attr_value + continue + if attr_def.type.startswith("list("): + if not _IsListValue(value): + raise TypeError("Expected list for attr " + key) + if attr_def.has_minimum: + if len(value) < attr_def.minimum: + raise ValueError("Attr '%s' of '%s' Op passed list of length %d " + "less than minimum %d." % + (key, op_type_name, len(value), + attr_def.minimum)) + attr_value.list.SetInParent() + if attr_def.type == "string": + attr_value.s = _MakeStr(value, key) + if attr_def.HasField("allowed_values"): + if attr_value.s not in attr_def.allowed_values.list.s: + raise ValueError( + "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % + (key, op_type_name, compat.as_text(attr_value.s), + '", "'.join(map(compat.as_text, + attr_def.allowed_values.list.s)))) + elif attr_def.type == "list(string)": + attr_value.list.s.extend([_MakeStr(x, key) for x in value]) + if attr_def.HasField("allowed_values"): + for x in attr_value.list.s: + if x not in attr_def.allowed_values.list.s: raise ValueError( "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % - (key, op_type_name, compat.as_text(attr_value.s), + (key, op_type_name, compat.as_text(x), '", "'.join(map(compat.as_text, attr_def.allowed_values.list.s)))) - elif attr_def.type == "list(string)": - attr_value.list.s.extend([_MakeStr(x, key) for x in value]) - if attr_def.HasField("allowed_values"): - for x in attr_value.list.s: - if x not in attr_def.allowed_values.list.s: - raise ValueError( - "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." 
% - (key, op_type_name, compat.as_text(x), - '", "'.join(map(compat.as_text, - attr_def.allowed_values.list.s)))) - elif attr_def.type == "int": - attr_value.i = _MakeInt(value, key) - if attr_def.has_minimum: - if attr_value.i < attr_def.minimum: - raise ValueError( - "Attr '%s' of '%s' Op passed %d less than minimum %d." % - (key, op_type_name, attr_value.i, attr_def.minimum)) - elif attr_def.type == "list(int)": - attr_value.list.i.extend([_MakeInt(x, key) for x in value]) - elif attr_def.type == "float": - attr_value.f = _MakeFloat(value, key) - elif attr_def.type == "list(float)": - attr_value.list.f.extend([_MakeFloat(x, key) for x in value]) - elif attr_def.type == "bool": - attr_value.b = _MakeBool(value, key) - elif attr_def.type == "list(bool)": - attr_value.list.b.extend([_MakeBool(x, key) for x in value]) - elif attr_def.type == "type": - attr_value.type = _MakeType(value, attr_def) - elif attr_def.type == "list(type)": - attr_value.list.type.extend( - [_MakeType(x, attr_def) for x in value]) - elif attr_def.type == "shape": - attr_value.shape.CopyFrom(_MakeShape(value, key)) - elif attr_def.type == "list(shape)": - attr_value.list.shape.extend( - [_MakeShape(x, key) for x in value]) - elif attr_def.type == "tensor": - attr_value.tensor.CopyFrom(_MakeTensor(value, key)) - elif attr_def.type == "list(tensor)": - attr_value.list.tensor.extend( - [_MakeTensor(x, key) for x in value]) - elif attr_def.type == "func": - attr_value.func.CopyFrom(_MakeFunc(value, key)) - elif attr_def.type == "list(func)": - attr_value.list.func.extend([_MakeFunc(x, key) for x in value]) - else: - raise TypeError("Unrecognized Attr type " + attr_def.type) + elif attr_def.type == "int": + attr_value.i = _MakeInt(value, key) + if attr_def.has_minimum: + if attr_value.i < attr_def.minimum: + raise ValueError( + "Attr '%s' of '%s' Op passed %d less than minimum %d." 
% + (key, op_type_name, attr_value.i, attr_def.minimum)) + elif attr_def.type == "list(int)": + attr_value.list.i.extend([_MakeInt(x, key) for x in value]) + elif attr_def.type == "float": + attr_value.f = _MakeFloat(value, key) + elif attr_def.type == "list(float)": + attr_value.list.f.extend([_MakeFloat(x, key) for x in value]) + elif attr_def.type == "bool": + attr_value.b = _MakeBool(value, key) + elif attr_def.type == "list(bool)": + attr_value.list.b.extend([_MakeBool(x, key) for x in value]) + elif attr_def.type == "type": + attr_value.type = _MakeType(value, attr_def) + elif attr_def.type == "list(type)": + attr_value.list.type.extend( + [_MakeType(x, attr_def) for x in value]) + elif attr_def.type == "shape": + attr_value.shape.CopyFrom(_MakeShape(value, key)) + elif attr_def.type == "list(shape)": + attr_value.list.shape.extend( + [_MakeShape(x, key) for x in value]) + elif attr_def.type == "tensor": + attr_value.tensor.CopyFrom(_MakeTensor(value, key)) + elif attr_def.type == "list(tensor)": + attr_value.list.tensor.extend( + [_MakeTensor(x, key) for x in value]) + elif attr_def.type == "func": + attr_value.func.CopyFrom(_MakeFunc(value, key)) + elif attr_def.type == "list(func)": + attr_value.list.func.extend([_MakeFunc(x, key) for x in value]) + else: + raise TypeError("Unrecognized Attr type " + attr_def.type) - attr_protos[key] = attr_value - del attrs # attrs is no longer authoritative, use attr_protos instead + attr_protos[key] = attr_value + del attrs # attrs is no longer authoritative, use attr_protos instead - # Determine output types (possibly using attrs) - output_structure = [] - for arg in op_def.output_arg: - if arg.number_attr: - n = _AttrValue(attr_protos, arg.number_attr).i - output_structure.append(n) - elif arg.type_attr: - t = _AttrValue(attr_protos, arg.type_attr) - output_structure.append(None) - elif arg.type_list_attr: - t = _AttrValue(attr_protos, arg.type_list_attr) - output_structure.append(len(t.list.type)) - else: - 
output_structure.append(None) + # Determine output types (possibly using attrs) + output_structure = [] + for arg in op_def.output_arg: + if arg.number_attr: + n = _AttrValue(attr_protos, arg.number_attr).i + output_structure.append(n) + elif arg.type_attr: + t = _AttrValue(attr_protos, arg.type_attr) + output_structure.append(None) + elif arg.type_list_attr: + t = _AttrValue(attr_protos, arg.type_list_attr) + output_structure.append(len(t.list.type)) + else: + output_structure.append(None) - if keywords: - raise TypeError("apply_op() got unexpected keyword arguments: " + - ", ".join(sorted(keywords.keys()))) + if keywords: + raise TypeError("apply_op() got unexpected keyword arguments: " + + ", ".join(sorted(keywords.keys()))) - # NOTE(mrry): We add an explicit colocation constraint between - # the newly created op and any of its reference-typed inputs. - must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs) - if arg.is_ref] - with _MaybeColocateWith(must_colocate_inputs): - # Add Op to graph - # pylint: disable=protected-access - op = g._create_op_internal(op_type_name, inputs, dtypes=None, - name=scope, input_types=input_types, - attrs=attr_protos, op_def=op_def) + # NOTE(mrry): We add an explicit colocation constraint between + # the newly created op and any of its reference-typed inputs. + must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs) + if arg.is_ref] + with _MaybeColocateWith(must_colocate_inputs): + # Add Op to graph + # pylint: disable=protected-access + op = g._create_op_internal(op_type_name, inputs, dtypes=None, + name=scope, input_types=input_types, + attrs=attr_protos, op_def=op_def) - # `outputs` is returned as a separate return value so that the output - # tensors can the `op` per se can be decoupled so that the - # `op_callbacks` can function properly. See framework/op_callbacks.py - # for more details. - outputs = op.outputs - # Conditionally invoke tfdbg v2's op callback(s). 
- if op_callbacks.should_invoke_op_callbacks(): - callback_outputs = op_callbacks.invoke_op_callbacks( - op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs), - op_name=op.name, graph=g) - if callback_outputs is not None: - outputs = callback_outputs + # `outputs` is returned as a separate return value so that the output + tensors and the `op` per se can be decoupled so that the + `op_callbacks` can function properly. See framework/op_callbacks.py + for more details. + outputs = op.outputs + # Conditionally invoke tfdbg v2's op callback(s). + if op_callbacks.should_invoke_op_callbacks(): + callback_outputs = op_callbacks.invoke_op_callbacks( + op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs), + op_name=op.name, graph=g) + if callback_outputs is not None: + outputs = callback_outputs - return output_structure, op_def.is_stateful, op, outputs - -# pylint: enable=invalid-name + return output_structure, op_def.is_stateful, op, outputs diff --git a/tensorflow/python/framework/op_def_library_test.py b/tensorflow/python/framework/op_def_library_test.py index 71d708dd89e..83990d7648b 100644 --- a/tensorflow/python/framework/op_def_library_test.py +++ b/tensorflow/python/framework/op_def_library_test.py @@ -19,120 +19,46 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from google.protobuf import text_format - -from tensorflow.core.framework import op_def_pb2 from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import function +from tensorflow.python.framework import op_def_library from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.platform import googletest -def _unknown_shape(op): - """Shape function for use with ops whose output
shapes are unknown.""" - return [tensor_shape.unknown_shape() for _ in op.outputs] - - class OpDefLibraryTest(test_util.TensorFlowTestCase): - def setUp(self): - self._lib = test_ops._op_def_lib - - def _add_op(self, ascii): # pylint: disable=redefined-builtin - op_def = op_def_pb2.OpDef() - text_format.Merge(ascii, op_def) - self._lib.add_op(op_def) - def Tensor(self, t, name="in"): - return self._lib.apply_op("OutT", T=t, name=name) + return op_def_library.apply_op("OutT", T=t, name=name) def testNoRegisteredOpFails(self): with self.assertRaises(RuntimeError) as cm: - self._lib.apply_op("unknown") + op_def_library.apply_op("unknown") self.assertEqual(str(cm.exception), "Unrecognized Op name unknown") - def testAddOpValidation(self): - with self.assertRaises(TypeError) as cm: - self._add_op("name: 'MissingTypeAttr' " - "input_arg { name: 'a' type_attr: 'T' } ") - self.assertEqual(str(cm.exception), - "Inconsistent OpDef for 'MissingTypeAttr', " - "missing attr 'T'") - - with self.assertRaises(TypeError) as cm: - self._add_op("name: 'BadTypeAttr' " - "output_arg { name: 'a' type_attr: 'T' } " - "attr { name: 'T' type: 'int' }") - self.assertEqual( - str(cm.exception), - "Attr 'T' of 'BadTypeAttr' used as a type_attr but has type int") - - with self.assertRaises(TypeError) as cm: - self._add_op("name: 'MissingNumberAttr' " - "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } ") - self.assertEqual(str(cm.exception), - "Inconsistent OpDef for 'MissingNumberAttr', " - "missing attr 'N'") - - with self.assertRaises(TypeError) as cm: - self._add_op("name: 'BadNumberAttr' " - "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } " - "attr { name: 'N' type: 'type' }") - self.assertEqual( - str(cm.exception), - "Attr 'N' of 'BadNumberAttr' used as a number_attr but has type type") - - with self.assertRaises(TypeError) as cm: - self._add_op("name: 'TwoTypesA' " - "input_arg { name: 'a' type: DT_INT32 type_attr: 'T' } " - "attr { name: 'T' type: 'type' }") - 
self.assertEqual(str(cm.exception), - "Arg 'a' of 'TwoTypesA' must have one type field not 2") - - with self.assertRaises(TypeError) as cm: - self._add_op("name: 'TwoTypesB' " - "input_arg { name: 'a' type: DT_INT32 type_list_attr: 'T' } " - "attr { name: 'T' type: 'list(type)' }") - self.assertEqual(str(cm.exception), - "Arg 'a' of 'TwoTypesB' must have one type field not 2") - - with self.assertRaises(TypeError) as cm: - self._add_op("name: 'ThreeTypes' " - "input_arg { name: 'a' type: DT_INT32 type_attr: 'T' " - "type_list_attr: 'U' } " - "attr { name: 'T' type: 'type' } " - "attr { name: 'U' type: 'list(type)' }") - self.assertEqual(str(cm.exception), - "Arg 'a' of 'ThreeTypes' must have one type field not 3") - - with self.assertRaises(TypeError) as cm: - self._add_op("name: 'NoTypes' output_arg { name: 'a' } ") - self.assertEqual(str(cm.exception), - "Arg 'a' of 'NoTypes' must have one type field not 0") - def testSimple(self): with ops.Graph().as_default(): - out = self._lib.apply_op("Simple", a=3) + out = op_def_library.apply_op("Simple", a=3) self.assertEqual(dtypes.float32, out.dtype) self.assertProtoEquals(""" name: 'Simple' op: 'Simple' input: 'Simple/a' """, out.op.node_def) - out = self._lib.apply_op("Simple", a=4) + out = op_def_library.apply_op("Simple", a=4) self.assertProtoEquals(""" name: 'Simple_1' op: 'Simple' input: 'Simple_1/a' """, out.op.node_def) - out = self._lib.apply_op("Simple", a=5, name="named") + out = op_def_library.apply_op("Simple", a=5, name="named") self.assertProtoEquals(""" name: 'named' op: 'Simple' input: 'named/a' """, out.op.node_def) - out = self._lib.apply_op("Simple", a=[[1, 2, 3], [4, 5, 6]], name="two_d") + out = op_def_library.apply_op( + "Simple", a=[[1, 2, 3], [4, 5, 6]], name="two_d") self.assertProtoEquals(""" name: 'two_d' op: 'Simple' input: 'two_d/a' """, out.op.node_def) @@ -140,69 +66,70 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testSimpleFailures(self): with ops.Graph().as_default(): 
with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Simple", a="Bad string") + op_def_library.apply_op("Simple", a="Bad string") self.assertTrue( "Expected int32 passed to parameter 'a' of op 'Simple', " "got 'Bad string' of type 'str' instead." in str(cm.exception)) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Simple", a=self.Tensor(dtypes.string)) + op_def_library.apply_op("Simple", a=self.Tensor(dtypes.string)) self.assertTrue( "Input 'a' of 'Simple' Op has type string " "that does not match expected type of int32." in str(cm.exception)) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Simple", a=6, extra="bogus") + op_def_library.apply_op("Simple", a=6, extra="bogus") self.assertTrue( "apply_op() got unexpected keyword arguments: extra" in str(cm.exception)) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Simple", a=6, extra1="bogus", extra2="also_bogus") + op_def_library.apply_op( + "Simple", a=6, extra1="bogus", extra2="also_bogus") self.assertTrue( "apply_op() got unexpected keyword arguments: extra1, " "extra2" in str(cm.exception)) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Simple") + op_def_library.apply_op("Simple") self.assertTrue( "No argument for input a" in str(cm.exception)) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Simple", wrong=7) + op_def_library.apply_op("Simple", wrong=7) self.assertTrue( "No argument for input a" in str(cm.exception)) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Simple", a={"label": 1}) + op_def_library.apply_op("Simple", a={"label": 1}) self.assertTrue( "Expected int32 passed to parameter 'a' of op 'Simple', " "got {'label': 1} of type 'dict' instead." 
in str(cm.exception)) def testReservedInput(self): with ops.Graph().as_default(): - op = self._lib.apply_op("ReservedInput", input_=7, name="x") + op = op_def_library.apply_op("ReservedInput", input_=7, name="x") self.assertProtoEquals(""" name: 'x' op: 'ReservedInput' input: 'x/input' """, op.node_def) def testPolymorphic(self): with ops.Graph().as_default(): - out = self._lib.apply_op("Polymorphic", a=7, name="p") + out = op_def_library.apply_op("Polymorphic", a=7, name="p") self.assertEqual(dtypes.int32, out.dtype) self.assertProtoEquals(""" name: 'p' op: 'Polymorphic' input: 'p/a' attr { key: 'T' value { type: DT_INT32 } } """, out.op.node_def) - out = self._lib.apply_op("Polymorphic", a="s", name="q") + out = op_def_library.apply_op("Polymorphic", a="s", name="q") self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'q' op: 'Polymorphic' input: 'q/a' attr { key: 'T' value { type: DT_STRING } } """, out.op.node_def) - out = self._lib.apply_op("Polymorphic", a=["s", "t", "u"], name="r") + out = op_def_library.apply_op("Polymorphic", a=["s", "t", "u"], name="r") self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'r' op: 'Polymorphic' input: 'r/a' @@ -210,20 +137,20 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out.op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Polymorphic", a="s", T=dtypes.string) + op_def_library.apply_op("Polymorphic", a="s", T=dtypes.string) self.assertEqual(str(cm.exception), "Should not specify value for inferred attr 'T'.") def testPolymorphicOut(self): with ops.Graph().as_default(): - out = self._lib.apply_op("PolymorphicOut", T=dtypes.int32, name="p") + out = op_def_library.apply_op("PolymorphicOut", T=dtypes.int32, name="p") self.assertEqual(dtypes.int32, out.dtype) self.assertProtoEquals(""" name: 'p' op: 'PolymorphicOut' attr { key: 'T' value { type: DT_INT32 } } """, out.op.node_def) - out = self._lib.apply_op("PolymorphicOut", T=dtypes.bool, 
name="q") + out = op_def_library.apply_op("PolymorphicOut", T=dtypes.bool, name="q") self.assertEqual(dtypes.bool, out.dtype) self.assertProtoEquals(""" name: 'q' op: 'PolymorphicOut' @@ -231,25 +158,26 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out.op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("PolymorphicOut") + op_def_library.apply_op("PolymorphicOut") self.assertEqual(str(cm.exception), "No argument for attr T") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("PolymorphicOut", T=None) + op_def_library.apply_op("PolymorphicOut", T=None) self.assertEqual(str(cm.exception), "Expected DataType for argument 'T' not None.") def testPolymorphicDefaultOut(self): with ops.Graph().as_default(): - out = self._lib.apply_op("PolymorphicDefaultOut", T=None, name="p") + out = op_def_library.apply_op("PolymorphicDefaultOut", T=None, name="p") self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'p' op: 'PolymorphicDefaultOut' attr { key: 'T' value { type: DT_STRING } } """, out.op.node_def) - out = self._lib.apply_op("PolymorphicDefaultOut", T=dtypes.bool, name="q") + out = op_def_library.apply_op( + "PolymorphicDefaultOut", T=dtypes.bool, name="q") self.assertEqual(dtypes.bool, out.dtype) self.assertProtoEquals(""" name: 'q' op: 'PolymorphicDefaultOut' @@ -258,14 +186,14 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testBinary(self): with ops.Graph().as_default(): - out = self._lib.apply_op("Binary", a=8, b=9, name="b") + out = op_def_library.apply_op("Binary", a=8, b=9, name="b") self.assertEqual(dtypes.int32, out.dtype) self.assertProtoEquals(""" name: 'b' op: 'Binary' input: 'b/a' input: 'b/b' attr { key: 'T' value { type: DT_INT32 } } """, out.op.node_def) - out = self._lib.apply_op("Binary", a="left", b="right", name="c") + out = op_def_library.apply_op("Binary", a="left", b="right", name="c") self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'c' 
op: 'Binary' input: 'c/a' input: 'c/b' @@ -273,23 +201,22 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out.op.node_def) with self.assertRaises(TypeError): - self._lib.apply_op("Binary", a="left", b=12) + op_def_library.apply_op("Binary", a="left", b=12) with self.assertRaises(TypeError): - self._lib.apply_op("Binary", - a=self.Tensor(dtypes.string), - b=self.Tensor(dtypes.int32)) + op_def_library.apply_op( + "Binary", a=self.Tensor(dtypes.string), b=self.Tensor(dtypes.int32)) def testRestrict(self): with ops.Graph().as_default(): - out = self._lib.apply_op("Restrict", a="foo", name="g") + out = op_def_library.apply_op("Restrict", a="foo", name="g") self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'g' op: 'Restrict' input: 'g/a' attr { key: 'T' value { type: DT_STRING } } """, out.op.node_def) - out = self._lib.apply_op("Restrict", a=True, name="h") + out = op_def_library.apply_op("Restrict", a=True, name="h") self.assertEqual(dtypes.bool, out.dtype) self.assertProtoEquals(""" name: 'h' op: 'Restrict' input: 'h/a' @@ -297,61 +224,59 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out.op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Restrict", a=17) + op_def_library.apply_op("Restrict", a=17) self.assertEqual(str(cm.exception), "Value passed to parameter 'a' has DataType int32 " "not in list of allowed values: string, bool") def testTypeList(self): with ops.Graph().as_default(): - op = self._lib.apply_op("TypeList", a=["foo"], name="z") + op = op_def_library.apply_op("TypeList", a=["foo"], name="z") self.assertProtoEquals(""" name: 'z' op: 'TypeList' input: 'z/a_0' attr { key: 'T' value { list { type: DT_STRING } } } """, op.node_def) - op = self._lib.apply_op("TypeList", a=[True, 12], name="y") + op = op_def_library.apply_op("TypeList", a=[True, 12], name="y") self.assertProtoEquals(""" name: 'y' op: 'TypeList' input: 'y/a_0' input: 'y/a_1' attr { key: 'T' value { list { type: 
DT_BOOL type: DT_INT32 } } } """, op.node_def) - op = self._lib.apply_op("TypeList", a=[], name="empty") + op = op_def_library.apply_op("TypeList", a=[], name="empty") self.assertProtoEquals(""" name: 'empty' op: 'TypeList' attr { key: 'T' value { list { } } } """, op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("TypeList", a=17) + op_def_library.apply_op("TypeList", a=17) self.assertStartsWith(str(cm.exception), "Expected list for 'a' " "argument to 'TypeList' Op, not ") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("TypeList", a=[self.Tensor(dtypes.int32), None]) + op_def_library.apply_op("TypeList", a=[self.Tensor(dtypes.int32), None]) self.assertStartsWith(str(cm.exception), "Tensors in list passed to 'a' of 'TypeList' Op " "have types [int32, ]") def testTypeListTwice(self): with ops.Graph().as_default(): - op = self._lib.apply_op("TypeListTwice", - a=["foo", True], - b=["bar", False], - name="z") + op = op_def_library.apply_op( + "TypeListTwice", a=["foo", True], b=["bar", False], name="z") self.assertProtoEquals(""" name: 'z' op: 'TypeListTwice' input: 'z/a_0' input: 'z/a_1' input: 'z/b_0' input: 'z/b_1' attr { key: 'T' value { list { type: DT_STRING type: DT_BOOL } } } """, op.node_def) - op = self._lib.apply_op("TypeListTwice", a=[], b=[], name="empty") + op = op_def_library.apply_op("TypeListTwice", a=[], b=[], name="empty") self.assertProtoEquals(""" name: 'empty' op: 'TypeListTwice' attr { key: 'T' value { list { } } } """, op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("TypeListTwice", a=["foo", True], b=["bar", 6]) + op_def_library.apply_op("TypeListTwice", a=["foo", True], b=["bar", 6]) self.assertEqual(str(cm.exception), "Input 'b' of 'TypeListTwice' Op has type list of " "string, int32 that does not match type list " @@ -359,16 +284,16 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testOutTypeList(self): with ops.Graph().as_default(): - out, = 
self._lib.apply_op("OutTypeList", T=[dtypes.float32], name="x") + out, = op_def_library.apply_op( + "OutTypeList", T=[dtypes.float32], name="x") self.assertEqual(dtypes.float32, out.dtype) self.assertProtoEquals(""" name: 'x' op: 'OutTypeList' attr { key: 'T' value { list { type: DT_FLOAT } } } """, out.op.node_def) - out1, out2 = self._lib.apply_op("OutTypeList", - T=[dtypes.int32, dtypes.bool], - name="w") + out1, out2 = op_def_library.apply_op( + "OutTypeList", T=[dtypes.int32, dtypes.bool], name="w") self.assertEqual(dtypes.int32, out1.dtype) self.assertEqual(dtypes.bool, out2.dtype) self.assertProtoEquals(""" @@ -376,32 +301,32 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): attr { key: 'T' value { list { type: DT_INT32 type: DT_BOOL } } } """, out1.op.node_def) - out = self._lib.apply_op("OutTypeList", T=[], name="empty") + out = op_def_library.apply_op("OutTypeList", T=[], name="empty") self.assertEqual([], out) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("OutTypeList", T=dtypes.int32) + op_def_library.apply_op("OutTypeList", T=dtypes.int32) self.assertEqual(str(cm.exception), "Expected list for attr T") def testTypeListRestrict(self): with ops.Graph().as_default(): - op = self._lib.apply_op("TypeListRestrict", a=["foo", False], name="v") + op = op_def_library.apply_op( + "TypeListRestrict", a=["foo", False], name="v") self.assertProtoEquals(""" name: 'v' op: 'TypeListRestrict' input: 'v/a_0' input: 'v/a_1' attr { key: 'T' value { list { type: DT_STRING type: DT_BOOL } } } """, op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("TypeListRestrict", a=[True, 12]) + op_def_library.apply_op("TypeListRestrict", a=[True, 12]) self.assertEqual(str(cm.exception), "Value passed to parameter 'a' has DataType int32 " "not in list of allowed values: string, bool") def testOutTypeListRestrict(self): with ops.Graph().as_default(): - out1, out2 = self._lib.apply_op("OutTypeListRestrict", - t=[dtypes.bool, dtypes.string], - 
name="u") + out1, out2 = op_def_library.apply_op( + "OutTypeListRestrict", t=[dtypes.bool, dtypes.string], name="u") self.assertEqual(dtypes.bool, out1.dtype) self.assertEqual(dtypes.string, out2.dtype) self.assertProtoEquals(""" @@ -410,57 +335,58 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out1.op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("OutTypeListRestrict", - t=[dtypes.string, dtypes.int32]) + op_def_library.apply_op( + "OutTypeListRestrict", t=[dtypes.string, dtypes.int32]) self.assertEqual(str(cm.exception), "Value passed to parameter 't' has DataType int32 " "not in list of allowed values: string, bool") def testAttr(self): with ops.Graph().as_default(): - op = self._lib.apply_op("Attr", a=12, name="t") + op = op_def_library.apply_op("Attr", a=12, name="t") self.assertProtoEquals(""" name: 't' op: 'Attr' attr { key: 'a' value { i: 12 } } """, op.node_def) - op = self._lib.apply_op("Attr", a=tensor_shape.Dimension(13), name="u") + op = op_def_library.apply_op( + "Attr", a=tensor_shape.Dimension(13), name="u") self.assertProtoEquals(""" name: 'u' op: 'Attr' attr { key: 'a' value { i: 13 } } """, op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Attr", a="bad") + op_def_library.apply_op("Attr", a="bad") self.assertEqual(str(cm.exception), "Expected int for argument 'a' not 'bad'.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Attr", a=[12]) + op_def_library.apply_op("Attr", a=[12]) self.assertEqual(str(cm.exception), "Expected int for argument 'a' not [12].") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Attr", a=None) + op_def_library.apply_op("Attr", a=None) self.assertEqual(str(cm.exception), "Expected int for argument 'a' not None.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Attr") + op_def_library.apply_op("Attr") self.assertEqual(str(cm.exception), "No argument for attr a") def testAttrFloat(self): with 
ops.Graph().as_default(): - op = self._lib.apply_op("AttrFloat", a=1.2, name="t") + op = op_def_library.apply_op("AttrFloat", a=1.2, name="t") self.assertProtoEquals(""" name: 't' op: 'AttrFloat' attr { key: 'a' value { f: 1.2 } } """, op.node_def) - op = self._lib.apply_op("AttrFloat", a=12, name="u") + op = op_def_library.apply_op("AttrFloat", a=12, name="u") self.assertProtoEquals(""" name: 'u' op: 'AttrFloat' attr { key: 'a' value { f: 12 } } """, op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("AttrFloat", a="bad") + op_def_library.apply_op("AttrFloat", a="bad") self.assertEqual(str(cm.exception), "Expected float for argument 'a' not 'bad'.") @@ -469,14 +395,15 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): @function.Defun(dtypes.float32, func_name="MyFn") def fn(x): return 2 + x - op = self._lib.apply_op("FuncAttr", f=fn, name="t") + + op = op_def_library.apply_op("FuncAttr", f=fn, name="t") self.assertProtoEquals(""" name: 't' op: 'FuncAttr' attr { key: 'f' value { func { name: 'MyFn' } } } """, op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("FuncAttr", f=3) + op_def_library.apply_op("FuncAttr", f=3) self.assertEqual(str(cm.exception), "Don't know how to convert 3 to a func for argument f") @@ -491,7 +418,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): @function.Defun(dtypes.int32, func_name="MyFn3") def fn3(y): return 2 + y - op = self._lib.apply_op("FuncListAttr", f=[fn1, fn2, fn3], name="t") + + op = op_def_library.apply_op("FuncListAttr", f=[fn1, fn2, fn3], name="t") self.assertProtoEquals(""" name: 't' op: 'FuncListAttr' attr { key: 'f' value { list { func { name: 'MyFn' } @@ -500,90 +428,91 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("FuncListAttr", f=[fn1, 3, fn2]) + op_def_library.apply_op("FuncListAttr", f=[fn1, 3, fn2]) self.assertEqual(str(cm.exception), "Don't know how to convert 
3 to a func for argument f") def testAttrBool(self): with ops.Graph().as_default(): - op = self._lib.apply_op("AttrBool", a=True, name="t") + op = op_def_library.apply_op("AttrBool", a=True, name="t") self.assertProtoEquals(""" name: 't' op: 'AttrBool' attr { key: 'a' value { b: true } } """, op.node_def) - op = self._lib.apply_op("AttrBool", a=False, name="u") + op = op_def_library.apply_op("AttrBool", a=False, name="u") self.assertProtoEquals(""" name: 'u' op: 'AttrBool' attr { key: 'a' value { b: false } } """, op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("AttrBool", a=0) + op_def_library.apply_op("AttrBool", a=0) self.assertEqual(str(cm.exception), "Expected bool for argument 'a' not 0.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("AttrBool", a=1) + op_def_library.apply_op("AttrBool", a=1) self.assertEqual(str(cm.exception), "Expected bool for argument 'a' not 1.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("AttrBool", a=[]) + op_def_library.apply_op("AttrBool", a=[]) self.assertEqual(str(cm.exception), "Expected bool for argument 'a' not [].") def testAttrBoolList(self): with ops.Graph().as_default(): - op = self._lib.apply_op("AttrBoolList", a=[True, False, True], name="t") + op = op_def_library.apply_op( + "AttrBoolList", a=[True, False, True], name="t") self.assertProtoEquals(""" name: 't' op: 'AttrBoolList' attr { key: 'a' value { list { b: true b: false b:true } } } """, op.node_def) - op = self._lib.apply_op("AttrBoolList", a=[], name="u") + op = op_def_library.apply_op("AttrBoolList", a=[], name="u") self.assertProtoEquals(""" name: 'u' op: 'AttrBoolList' attr { key: 'a' value { list { } } } """, op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("AttrBoolList", a=[0]) + op_def_library.apply_op("AttrBoolList", a=[0]) self.assertEqual(str(cm.exception), "Expected bool for argument 'a' not 0.") def testAttrMin(self): with ops.Graph().as_default(): - op = 
self._lib.apply_op("AttrMin", a=12, name="s") + op = op_def_library.apply_op("AttrMin", a=12, name="s") self.assertProtoEquals(""" name: 's' op: 'AttrMin' attr { key: 'a' value { i: 12 } } """, op.node_def) with self.assertRaises(ValueError) as cm: - self._lib.apply_op("AttrMin", a=2) + op_def_library.apply_op("AttrMin", a=2) self.assertEqual(str(cm.exception), "Attr 'a' of 'AttrMin' Op passed 2 less than minimum 5.") def testAttrListMin(self): with ops.Graph().as_default(): - op = self._lib.apply_op("AttrListMin", a=[1, 2], name="r") + op = op_def_library.apply_op("AttrListMin", a=[1, 2], name="r") self.assertProtoEquals(""" name: 'r' op: 'AttrListMin' attr { key: 'a' value { list { i: 1 i: 2 } } } """, op.node_def) with self.assertRaises(ValueError) as cm: - self._lib.apply_op("AttrListMin", a=[17]) + op_def_library.apply_op("AttrListMin", a=[17]) self.assertEqual(str(cm.exception), "Attr 'a' of 'AttrListMin' Op " "passed list of length 1 less than minimum 2.") def testAttrEnum(self): with ops.Graph().as_default(): - op = self._lib.apply_op("AttrEnum", a="oranges", name="e") + op = op_def_library.apply_op("AttrEnum", a="oranges", name="e") self.assertProtoEquals(""" name: 'e' op: 'AttrEnum' attr { key: 'a' value { s: 'oranges' } } """, op.node_def) with self.assertRaises(ValueError) as cm: - self._lib.apply_op("AttrEnum", a="invalid") + op_def_library.apply_op("AttrEnum", a="invalid") self.assertEqual(str(cm.exception), 'Attr \'a\' of \'AttrEnum\' Op ' 'passed string \'invalid\' not in: ' @@ -591,14 +520,16 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testAttrEnumList(self): with ops.Graph().as_default(): - op = self._lib.apply_op("AttrEnumList", a=["oranges", "apples"], name="f") + op = op_def_library.apply_op( + "AttrEnumList", a=["oranges", "apples"], name="f") self.assertProtoEquals(""" name: 'f' op: 'AttrEnumList' attr { key: 'a' value { list { s: 'oranges' s: 'apples' } } } """, op.node_def) with self.assertRaises(ValueError) as cm: - 
self._lib.apply_op("AttrEnumList", a=["apples", "invalid", "oranges"]) + op_def_library.apply_op( + "AttrEnumList", a=["apples", "invalid", "oranges"]) self.assertEqual(str(cm.exception), 'Attr \'a\' of \'AttrEnumList\' Op ' 'passed string \'invalid\' not ' @@ -606,20 +537,20 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testAttrShape(self): with ops.Graph().as_default(): - op = self._lib.apply_op("AttrShape", a=[5], name="s1") + op = op_def_library.apply_op("AttrShape", a=[5], name="s1") self.assertProtoEquals(""" name: 's1' op: 'AttrShape' attr { key: 'a' value { shape { dim { size: 5 } } } } """, op.node_def) - op = self._lib.apply_op("AttrShape", a=(4, 3, 2), name="s2") + op = op_def_library.apply_op("AttrShape", a=(4, 3, 2), name="s2") self.assertProtoEquals(""" name: 's2' op: 'AttrShape' attr { key: 'a' value { shape { dim { size: 4 } dim { size: 3 } dim { size: 2 } } } } """, op.node_def) - op = self._lib.apply_op( + op = op_def_library.apply_op( "AttrShape", a=tensor_shape.TensorShape([3, 2]), name="s3") self.assertProtoEquals(""" name: 's3' op: 'AttrShape' @@ -627,7 +558,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): shape { dim { size: 3 } dim { size: 2 } } } } """, op.node_def) - op = self._lib.apply_op("AttrShape", a=[], name="s4") + op = op_def_library.apply_op("AttrShape", a=[], name="s4") self.assertProtoEquals(""" name: 's4' op: 'AttrShape' attr { key: 'a' value { shape { } } } """, op.node_def) @@ -635,7 +566,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): shape = tensor_shape_pb2.TensorShapeProto() shape.dim.add().size = 6 shape.dim.add().size = 3 - op = self._lib.apply_op("AttrShape", a=shape, name="s5") + op = op_def_library.apply_op("AttrShape", a=shape, name="s5") self.assertProtoEquals(""" name: 's5' op: 'AttrShape' attr { key: 'a' value { shape { dim { size: 6 } dim { size: 3 } } } } @@ -644,17 +575,18 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): # TODO(josh11b): Re-enable this test once we 
stop promoting scalars to # shapes. # with self.assertRaises(TypeError) as cm: - # self._lib.apply_op("AttrShape", a=5) + # op_def_library.apply_op("AttrShape", a=5) # self.assertEqual(str(cm.exception), # "Don't know how to convert 5 to a TensorShapeProto for" # " argument 'a'") with self.assertRaises(TypeError): - self._lib.apply_op("AttrShape", a="ABC") + op_def_library.apply_op("AttrShape", a="ABC") def testAttrShapeList(self): with ops.Graph().as_default(): - op = self._lib.apply_op("AttrShapeList", a=[[3, 2], [6, 5, 4]], name="sl") + op = op_def_library.apply_op( + "AttrShapeList", a=[[3, 2], [6, 5, 4]], name="sl") self.assertProtoEquals(""" name: 'sl' op: 'AttrShapeList' attr { key: 'a' value { list { @@ -662,27 +594,28 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): shape { dim { size: 6 } dim { size: 5 } dim { size: 4 } } } } } """, op.node_def) - op = self._lib.apply_op("AttrShapeList", a=[], name="esl") + op = op_def_library.apply_op("AttrShapeList", a=[], name="esl") self.assertProtoEquals(""" name: 'esl' op: 'AttrShapeList' attr { key: 'a' value { list { } } } """, op.node_def) def testAttrPartialShape(self): with ops.Graph().as_default(): - op = self._lib.apply_op("AttrPartialShape", a=[5], name="s1") + op = op_def_library.apply_op("AttrPartialShape", a=[5], name="s1") self.assertProtoEquals(""" name: 's1' op: 'AttrPartialShape' attr { key: 'a' value { shape { dim { size: 5 } } } } """, op.node_def) - op = self._lib.apply_op("AttrPartialShape", a=(4, None, 2), name="s2") + op = op_def_library.apply_op( + "AttrPartialShape", a=(4, None, 2), name="s2") self.assertProtoEquals(""" name: 's2' op: 'AttrPartialShape' attr { key: 'a' value { shape { dim { size: 4 } dim { size: -1 } dim { size: 2 } } } } """, op.node_def) - op = self._lib.apply_op( + op = op_def_library.apply_op( "AttrPartialShape", a=tensor_shape.TensorShape([3, None]), name="s3") self.assertProtoEquals(""" name: 's3' op: 'AttrPartialShape' @@ -690,7 +623,7 @@ class 
OpDefLibraryTest(test_util.TensorFlowTestCase): shape { dim { size: 3 } dim { size: -1 } } } } """, op.node_def) - op = self._lib.apply_op("AttrPartialShape", a=[], name="s4") + op = op_def_library.apply_op("AttrPartialShape", a=[], name="s4") self.assertProtoEquals(""" name: 's4' op: 'AttrPartialShape' attr { key: 'a' value { shape { } } } @@ -699,7 +632,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): shape = tensor_shape_pb2.TensorShapeProto() shape.dim.add().size = -1 shape.dim.add().size = 3 - op = self._lib.apply_op("AttrPartialShape", a=shape, name="s5") + op = op_def_library.apply_op("AttrPartialShape", a=shape, name="s5") self.assertProtoEquals(""" name: 's5' op: 'AttrPartialShape' attr { key: 'a' value { @@ -708,17 +641,17 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): # TODO(ebrevdo): Re-enable once we stop promoting scalars to shapes. # with self.assertRaises(TypeError) as cm: - # self._lib.apply_op("AttrPartialShape", a=5) + # op_def_library.apply_op("AttrPartialShape", a=5) # self.assertEqual(str(cm.exception), # "Don't know how to convert 5 to a TensorShapeProto for" # " argument 'a'") with self.assertRaises(TypeError): - self._lib.apply_op("AttrPartialShape", a="ABC") + op_def_library.apply_op("AttrPartialShape", a="ABC") def testAttrPartialShapeList(self): with ops.Graph().as_default(): - op = self._lib.apply_op( + op = op_def_library.apply_op( "AttrPartialShapeList", a=[[3, 2], [6, None, 4]], name="sl") self.assertProtoEquals(""" name: 'sl' op: 'AttrPartialShapeList' @@ -727,7 +660,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): shape { dim { size: 6 } dim { size: -1 } dim { size: 4 } } } } } """, op.node_def) - op = self._lib.apply_op("AttrPartialShapeList", a=[], name="esl") + op = op_def_library.apply_op("AttrPartialShapeList", a=[], name="esl") self.assertProtoEquals(""" name: 'esl' op: 'AttrPartialShapeList' attr { key: 'a' value { list { } } } @@ -735,31 +668,31 @@ class 
OpDefLibraryTest(test_util.TensorFlowTestCase): def testAttrDefault(self): with ops.Graph().as_default(): - op = self._lib.apply_op("AttrDefault", a=None, name="d") + op = op_def_library.apply_op("AttrDefault", a=None, name="d") self.assertProtoEquals(""" name: 'd' op: 'AttrDefault' attr { key: 'a' value { s: 'banana' } } """, op.node_def) - op = self._lib.apply_op("AttrDefault", a="kiwi", name="c") + op = op_def_library.apply_op("AttrDefault", a="kiwi", name="c") self.assertProtoEquals(""" name: 'c' op: 'AttrDefault' attr { key: 'a' value { s: 'kiwi' } } """, op.node_def) def testAttrListDefault(self): with ops.Graph().as_default(): - op = self._lib.apply_op("AttrListDefault", a=None, name="b") + op = op_def_library.apply_op("AttrListDefault", a=None, name="b") self.assertProtoEquals(""" name: 'b' op: 'AttrListDefault' attr { key: 'a' value { list { i: 5 i: 15 } } } """, op.node_def) - op = self._lib.apply_op("AttrListDefault", a=[3], name="a") + op = op_def_library.apply_op("AttrListDefault", a=[3], name="a") self.assertProtoEquals(""" name: 'a' op: 'AttrListDefault' attr { key: 'a' value { list { i: 3 } } } """, op.node_def) - op = self._lib.apply_op("AttrListDefault", a=[], name="empty") + op = op_def_library.apply_op("AttrListDefault", a=[], name="empty") self.assertProtoEquals(""" name: 'empty' op: 'AttrListDefault' attr { key: 'a' value { list { } } } @@ -767,19 +700,19 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testAttrEmptyListDefault(self): with ops.Graph().as_default(): - op = self._lib.apply_op("AttrEmptyListDefault", a=None, name="b") + op = op_def_library.apply_op("AttrEmptyListDefault", a=None, name="b") self.assertProtoEquals(""" name: 'b' op: 'AttrEmptyListDefault' attr { key: 'a' value { list { } } } """, op.node_def) - op = self._lib.apply_op("AttrEmptyListDefault", a=[3], name="a") + op = op_def_library.apply_op("AttrEmptyListDefault", a=[3], name="a") self.assertProtoEquals(""" name: 'a' op: 'AttrEmptyListDefault' attr { key: 
'a' value { list { f: 3 } } } """, op.node_def) - op = self._lib.apply_op("AttrEmptyListDefault", a=[], name="empty") + op = op_def_library.apply_op("AttrEmptyListDefault", a=[], name="empty") self.assertProtoEquals(""" name: 'empty' op: 'AttrEmptyListDefault' attr { key: 'a' value { list { } } } @@ -787,7 +720,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testReservedAttr(self): with ops.Graph().as_default(): - op = self._lib.apply_op("ReservedAttr", range_=7, name="x") + op = op_def_library.apply_op("ReservedAttr", range_=7, name="x") self.assertProtoEquals(""" name: 'x' op: 'ReservedAttr' attr { key: 'range' value { i: 7 } } """, op.node_def) @@ -795,7 +728,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testDefaultAttrType(self): with ops.Graph().as_default(): # Give an input whose type has no obvious output type. - op = self._lib.apply_op("AttrTypeDefault", a=[], name="n") + op = op_def_library.apply_op("AttrTypeDefault", a=[], name="n") self.assertProtoEquals(""" name: 'n' op: 'AttrTypeDefault' input: 'n/a' attr { key: 'T' value { type: DT_INT32 } } @@ -803,7 +736,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): # Give an input whose type can be inferred as different # than the default. - op = self._lib.apply_op("AttrTypeDefault", a=[1.0], name="f") + op = op_def_library.apply_op("AttrTypeDefault", a=[1.0], name="f") self.assertProtoEquals(""" name: 'f' op: 'AttrTypeDefault' input: 'f/a' attr { key: 'T' value { type: DT_FLOAT } } @@ -813,7 +746,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with ops.Graph().as_default(): # Give an input whose type can be inferred as different # than the default. 
- op = self._lib.apply_op("AttrListTypeDefault", a=[1.0], b=[2.0], name="n") + op = op_def_library.apply_op( + "AttrListTypeDefault", a=[1.0], b=[2.0], name="n") self.assertProtoEquals(""" name: 'n' op: 'AttrListTypeDefault' input: 'n/a_0' input: 'n/b_0' attr { key: 'T' value { type: DT_FLOAT } } @@ -822,13 +756,13 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testNIntsIn(self): with ops.Graph().as_default(): - op = self._lib.apply_op("NIntsIn", a=[1, 2], name="n") + op = op_def_library.apply_op("NIntsIn", a=[1, 2], name="n") self.assertProtoEquals(""" name: 'n' op: 'NIntsIn' input: 'n/a_0' input: 'n/a_1' attr { key: 'N' value { i: 2 } } """, op.node_def) - op = self._lib.apply_op("NIntsIn", a=[5, 4, 3, 2, 1], name="o") + op = op_def_library.apply_op("NIntsIn", a=[5, 4, 3, 2, 1], name="o") self.assertProtoEquals(""" name: 'o' op: 'NIntsIn' input: 'o/a_0' input: 'o/a_1' input: 'o/a_2' input: 'o/a_3' input: 'o/a_4' @@ -836,60 +770,63 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NIntsIn", a=["foo", "bar"]) + op_def_library.apply_op("NIntsIn", a=["foo", "bar"]) self.assertEqual( str(cm.exception), "Tensors in list passed to 'a' of 'NIntsIn' Op have types " "[string, string] that do not match expected type int32.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NIntsIn", - a=[self.Tensor(dtypes.string), - self.Tensor(dtypes.string)]) + op_def_library.apply_op( + "NIntsIn", + a=[self.Tensor(dtypes.string), + self.Tensor(dtypes.string)]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NIntsIn' Op have " "types [string, string] that do not match expected type " "int32.") with self.assertRaises(ValueError) as cm: - self._lib.apply_op("NIntsIn", a=[99]) + op_def_library.apply_op("NIntsIn", a=[99]) self.assertEqual(str(cm.exception), "List argument 'a' to 'NIntsIn' Op " "with length 1 shorter than " "minimum length 2.") with 
self.assertRaises(TypeError) as cm: - self._lib.apply_op("NIntsIn", a=[38, "bar"]) + op_def_library.apply_op("NIntsIn", a=[38, "bar"]) self.assertEqual( str(cm.exception), "Tensors in list passed to 'a' of 'NIntsIn' Op have types " "[int32, string] that do not match expected type int32.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NIntsIn", - a=[self.Tensor(dtypes.int32), - self.Tensor(dtypes.string)]) + op_def_library.apply_op( + "NIntsIn", + a=[self.Tensor(dtypes.int32), + self.Tensor(dtypes.string)]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NIntsIn' Op " "have types [int32, string] that do not match expected " "type int32.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NIntsIn", a=17) + op_def_library.apply_op("NIntsIn", a=17) self.assertStartsWith(str(cm.exception), "Expected list for 'a' argument " "to 'NIntsIn' Op, not ") def testNPolymorphicIn(self): with ops.Graph().as_default(): - op = self._lib.apply_op("NPolymorphicIn", a=[1, 2], name="n") + op = op_def_library.apply_op("NPolymorphicIn", a=[1, 2], name="n") self.assertProtoEquals(""" name: 'n' op: 'NPolymorphicIn' input: 'n/a_0' input: 'n/a_1' attr { key: 'T' value { type: DT_INT32 } } attr { key: 'N' value { i: 2 } } """, op.node_def) - op = self._lib.apply_op("NPolymorphicIn", a=[5, 4, 3, 2, 1], name="o") + op = op_def_library.apply_op( + "NPolymorphicIn", a=[5, 4, 3, 2, 1], name="o") self.assertProtoEquals(""" name: 'o' op: 'NPolymorphicIn' input: 'o/a_0' input: 'o/a_1' input: 'o/a_2' input: 'o/a_3' input: 'o/a_4' @@ -897,26 +834,30 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): attr { key: 'N' value { i: 5 } } """, op.node_def) - op = self._lib.apply_op("NPolymorphicIn", a=["foo", "bar"], name="p") + op = op_def_library.apply_op("NPolymorphicIn", a=["foo", "bar"], name="p") self.assertProtoEquals(""" name: 'p' op: 'NPolymorphicIn' input: 'p/a_0' input: 'p/a_1' attr { key: 'T' value { type: DT_STRING } } attr { key: 'N' 
value { i: 2 } } """, op.node_def) - op = self._lib.apply_op("NPolymorphicIn", - a=[1, self.Tensor(dtypes.float32, name="x")], - name="q") + op = op_def_library.apply_op( + "NPolymorphicIn", + a=[1, self.Tensor(dtypes.float32, name="x")], + name="q") self.assertProtoEquals(""" name: 'q' op: 'NPolymorphicIn' input: 'q/a_0' input: 'x' attr { key: 'T' value { type: DT_FLOAT } } attr { key: 'N' value { i: 2 } } """, op.node_def) - op = self._lib.apply_op("NPolymorphicIn", - a=[self.Tensor(dtypes.float32, name="y"), - self.Tensor(dtypes.float32_ref, name="z")], - name="r") + op = op_def_library.apply_op( + "NPolymorphicIn", + a=[ + self.Tensor(dtypes.float32, name="y"), + self.Tensor(dtypes.float32_ref, name="z") + ], + name="r") self.assertProtoEquals(""" name: 'r' op: 'NPolymorphicIn' input: 'y' input: 'z' attr { key: 'T' value { type: DT_FLOAT } } @@ -924,56 +865,56 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, op.node_def) with self.assertRaises(ValueError) as cm: - self._lib.apply_op("NPolymorphicIn", a=[99]) + op_def_library.apply_op("NPolymorphicIn", a=[99]) self.assertEqual(str(cm.exception), "List argument 'a' to 'NPolymorphicIn' Op with length 1 " "shorter than minimum length 2.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NPolymorphicIn", a=[38, "bar"]) + op_def_library.apply_op("NPolymorphicIn", a=[38, "bar"]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NPolymorphicIn' Op " "have types [int32, string] that don't all match.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NPolymorphicIn", a=[38, self.Tensor(dtypes.string)]) + op_def_library.apply_op( + "NPolymorphicIn", a=[38, self.Tensor(dtypes.string)]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NPolymorphicIn' Op " "have types [int32, string] that don't all match.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NPolymorphicIn", a=[38, None]) + op_def_library.apply_op("NPolymorphicIn", 
a=[38, None]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NPolymorphicIn' Op " "have types [int32, ] that " "don't all match.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NPolymorphicIn", - a=["abcd", self.Tensor(dtypes.int32)]) + op_def_library.apply_op( + "NPolymorphicIn", a=["abcd", self.Tensor(dtypes.int32)]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NPolymorphicIn' Op " "have types [string, int32] that don't all match.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NPolymorphicIn", a=17) + op_def_library.apply_op("NPolymorphicIn", a=17) self.assertStartsWith(str(cm.exception), "Expected list for 'a' argument " "to 'NPolymorphicIn' Op, not ") def testNPolymorphicRestrictIn(self): with ops.Graph().as_default(): - op = self._lib.apply_op("NPolymorphicRestrictIn", a=["foo", "bar"], - name="p") + op = op_def_library.apply_op( + "NPolymorphicRestrictIn", a=["foo", "bar"], name="p") self.assertProtoEquals(""" name: 'p' op: 'NPolymorphicRestrictIn' input: 'p/a_0' input: 'p/a_1' attr { key: 'T' value { type: DT_STRING } } attr { key: 'N' value { i: 2 } } """, op.node_def) - op = self._lib.apply_op("NPolymorphicRestrictIn", - a=[False, True, False], - name="b") + op = op_def_library.apply_op( + "NPolymorphicRestrictIn", a=[False, True, False], name="b") self.assertProtoEquals(""" name: 'b' op: 'NPolymorphicRestrictIn' input: 'b/a_0' input: 'b/a_1' input: 'b/a_2' @@ -982,7 +923,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NPolymorphicRestrictIn", a=[1, 2]) + op_def_library.apply_op("NPolymorphicRestrictIn", a=[1, 2]) self.assertEqual( str(cm.exception), "Value passed to parameter 'a' has DataType int32 not in " @@ -990,20 +931,21 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testNInTwice(self): with ops.Graph().as_default(): - op = self._lib.apply_op("NInTwice", a=[1, 
2], b=["one", "two"], name="n") + op = op_def_library.apply_op( + "NInTwice", a=[1, 2], b=["one", "two"], name="n") self.assertProtoEquals(""" name: 'n' op: 'NInTwice' input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1' attr { key: 'N' value { i: 2 } } """, op.node_def) - op = self._lib.apply_op("NInTwice", a=[], b=[], name="o") + op = op_def_library.apply_op("NInTwice", a=[], b=[], name="o") self.assertProtoEquals(""" name: 'o' op: 'NInTwice' attr { key: 'N' value { i: 0 } } """, op.node_def) with self.assertRaises(ValueError) as cm: - self._lib.apply_op("NInTwice", a=[1, 2, 3], b=["too short"]) + op_def_library.apply_op("NInTwice", a=[1, 2, 3], b=["too short"]) self.assertEqual(str(cm.exception), "List argument 'b' to 'NInTwice' Op " "with length 1 must match " @@ -1011,8 +953,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testNInPolymorphicTwice(self): with ops.Graph().as_default(): - op = self._lib.apply_op("NInPolymorphicTwice", a=[1, 2], b=[3, 4], - name="n") + op = op_def_library.apply_op( + "NInPolymorphicTwice", a=[1, 2], b=[3, 4], name="n") self.assertProtoEquals(""" name: 'n' op: 'NInPolymorphicTwice' input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1' @@ -1021,23 +963,25 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, op.node_def) with self.assertRaises(ValueError) as cm: - self._lib.apply_op("NInPolymorphicTwice", a=[1, 2, 3], b=[5]) + op_def_library.apply_op("NInPolymorphicTwice", a=[1, 2, 3], b=[5]) self.assertEqual(str(cm.exception), "List argument 'b' to 'NInPolymorphicTwice' Op " "with length 1 " "must match length 3 of argument 'a'.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NInPolymorphicTwice", a=[1, 2], b=["one", "two"]) + op_def_library.apply_op( + "NInPolymorphicTwice", a=[1, 2], b=["one", "two"]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'b' of 'NInPolymorphicTwice' " "Op have types [string, string] that do not match type " "int32 inferred from earlier 
arguments.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NInPolymorphicTwice", - a=[self.Tensor(dtypes.int32)], - b=[self.Tensor(dtypes.string)]) + op_def_library.apply_op( + "NInPolymorphicTwice", + a=[self.Tensor(dtypes.int32)], + b=[self.Tensor(dtypes.string)]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'b' of " "'NInPolymorphicTwice' Op have types [string] that do " @@ -1045,10 +989,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testNInTwoTypeVariables(self): with ops.Graph().as_default(): - op = self._lib.apply_op("NInTwoTypeVariables", - a=[1, 2], - b=[True, False], - name="n") + op = op_def_library.apply_op( + "NInTwoTypeVariables", a=[1, 2], b=[True, False], name="n") self.assertProtoEquals(""" name: 'n' op: 'NInTwoTypeVariables' input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1' @@ -1057,8 +999,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): attr { key: 'N' value { i: 2 } } """, op.node_def) - op = self._lib.apply_op("NInTwoTypeVariables", a=[1, 2], b=[3, 4], - name="o") + op = op_def_library.apply_op( + "NInTwoTypeVariables", a=[1, 2], b=[3, 4], name="o") self.assertProtoEquals(""" name: 'o' op: 'NInTwoTypeVariables' input: 'o/a_0' input: 'o/a_1' input: 'o/b_0' input: 'o/b_1' @@ -1067,10 +1009,11 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): attr { key: 'N' value { i: 2 } } """, op.node_def) - op = self._lib.apply_op("NInTwoTypeVariables", - a=[self.Tensor(dtypes.int32, name="q")], - b=[self.Tensor(dtypes.string, name="r")], - name="p") + op = op_def_library.apply_op( + "NInTwoTypeVariables", + a=[self.Tensor(dtypes.int32, name="q")], + b=[self.Tensor(dtypes.string, name="r")], + name="p") self.assertProtoEquals(""" name: 'p' op: 'NInTwoTypeVariables' input: 'q' input: 'r' attr { key: 'S' value { type: DT_INT32 } } @@ -1079,7 +1022,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, op.node_def) with self.assertRaises(ValueError) as cm: - 
self._lib.apply_op("NInTwoTypeVariables", a=[1, 2, 3], b=["5"]) + op_def_library.apply_op("NInTwoTypeVariables", a=[1, 2, 3], b=["5"]) self.assertEqual(str(cm.exception), "List argument 'b' to 'NInTwoTypeVariables' Op " "with length 1 " @@ -1087,8 +1030,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testInPolymorphicTwice(self): with ops.Graph().as_default(): - op = self._lib.apply_op("InPolymorphicTwice", a=[8], b=[3, 4, 5], - name="n") + op = op_def_library.apply_op( + "InPolymorphicTwice", a=[8], b=[3, 4, 5], name="n") self.assertProtoEquals(""" name: 'n' op: 'InPolymorphicTwice' input: 'n/a_0' input: 'n/b_0' input: 'n/b_1' input: 'n/b_2' @@ -1097,7 +1040,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): attr { key: 'M' value { i: 3 } } """, op.node_def) - op = self._lib.apply_op("InPolymorphicTwice", a=[8], b=[], name="o") + op = op_def_library.apply_op("InPolymorphicTwice", a=[8], b=[], name="o") self.assertProtoEquals(""" name: 'o' op: 'InPolymorphicTwice' input: 'o/a_0' attr { key: 'T' value { type: DT_INT32 } } @@ -1106,13 +1049,14 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("InPolymorphicTwice", a=[], b=[3, 4, 5]) + op_def_library.apply_op("InPolymorphicTwice", a=[], b=[3, 4, 5]) self.assertEqual(str(cm.exception), "Don't know how to infer type variable from empty input " "list passed to input 'a' of 'InPolymorphicTwice' Op.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("InPolymorphicTwice", a=[1, 2], b=["one", "two"]) + op_def_library.apply_op( + "InPolymorphicTwice", a=[1, 2], b=["one", "two"]) self.assertEqual( str(cm.exception), "Tensors in list passed to 'b' of 'InPolymorphicTwice' Op " @@ -1120,9 +1064,10 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "inferred from earlier arguments.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("InPolymorphicTwice", - a=[self.Tensor(dtypes.int32)], - 
b=[self.Tensor(dtypes.string)]) + op_def_library.apply_op( + "InPolymorphicTwice", + a=[self.Tensor(dtypes.int32)], + b=[self.Tensor(dtypes.string)]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'b' of 'InPolymorphicTwice' " "Op have types [string] that do not match type int32 " @@ -1130,14 +1075,14 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testNIntsOut(self): with ops.Graph().as_default(): - out1, out2 = self._lib.apply_op("NIntsOut", N=2, name="n") + out1, out2 = op_def_library.apply_op("NIntsOut", N=2, name="n") self.assertEqual(dtypes.int32, out1.dtype) self.assertEqual(dtypes.int32, out2.dtype) self.assertProtoEquals(""" name: 'n' op: 'NIntsOut' attr { key: 'N' value { i: 2 } } """, out1.op.node_def) - out1, out2, out3, out4, out5 = self._lib.apply_op( + out1, out2, out3, out4, out5 = op_def_library.apply_op( "NIntsOut", N=5, name="o") self.assertEqual(dtypes.int32, out1.dtype) self.assertEqual(dtypes.int32, out2.dtype) @@ -1149,19 +1094,19 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out5.op.node_def) with self.assertRaises(ValueError) as cm: - self._lib.apply_op("NIntsOut", N=1) + op_def_library.apply_op("NIntsOut", N=1) self.assertEqual( str(cm.exception), "Attr 'N' of 'NIntsOut' Op passed 1 less than minimum 2.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NIntsOut", N=[3]) + op_def_library.apply_op("NIntsOut", N=[3]) self.assertEqual(str(cm.exception), "Expected int for argument 'N' not [3].") def testNIntsOutDefault(self): with ops.Graph().as_default(): - out1, out2, out3 = self._lib.apply_op( + out1, out2, out3 = op_def_library.apply_op( "NIntsOutDefault", N=None, name="z") self.assertEqual(dtypes.int32, out1.dtype) self.assertEqual(dtypes.int32, out2.dtype) @@ -1170,7 +1115,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): name: 'z' op: 'NIntsOutDefault' attr { key: 'N' value { i: 3 } } """, out1.op.node_def) - out1, out2 = self._lib.apply_op("NIntsOutDefault", N=2, 
name="y") + out1, out2 = op_def_library.apply_op("NIntsOutDefault", N=2, name="y") self.assertEqual(dtypes.int32, out1.dtype) self.assertEqual(dtypes.int32, out2.dtype) self.assertProtoEquals(""" @@ -1179,10 +1124,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testNPolymorphicOut(self): with ops.Graph().as_default(): - out1, out2 = self._lib.apply_op("NPolymorphicOut", - N=2, - T=dtypes.int32, - name="n") + out1, out2 = op_def_library.apply_op( + "NPolymorphicOut", N=2, T=dtypes.int32, name="n") self.assertEqual(dtypes.int32, out1.dtype) self.assertEqual(dtypes.int32, out2.dtype) self.assertProtoEquals(""" @@ -1191,7 +1134,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): attr { key: 'N' value { i: 2 } } """, out1.op.node_def) - out1, out2, out3 = self._lib.apply_op( + out1, out2, out3 = op_def_library.apply_op( "NPolymorphicOut", T=dtypes.string, N=3, name="o") self.assertEqual(dtypes.string, out1.dtype) self.assertEqual(dtypes.string, out2.dtype) @@ -1203,20 +1146,20 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out3.op.node_def) with self.assertRaises(ValueError) as cm: - self._lib.apply_op("NPolymorphicOut", N=1, T=dtypes.string) + op_def_library.apply_op("NPolymorphicOut", N=1, T=dtypes.string) self.assertEqual(str(cm.exception), "Attr 'N' of 'NPolymorphicOut' Op " "passed 1 less than minimum 2.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NPolymorphicOut", N=3, T=[dtypes.string]) + op_def_library.apply_op("NPolymorphicOut", N=3, T=[dtypes.string]) self.assertEqual( str(cm.exception), "Expected DataType for argument 'T' not [tf.string].") def testNPolymorphicOutDefault(self): with ops.Graph().as_default(): - out1, out2 = self._lib.apply_op( + out1, out2 = op_def_library.apply_op( "NPolymorphicOutDefault", N=None, T=None, name="r") self.assertEqual(dtypes.bool, out1.dtype) self.assertEqual(dtypes.bool, out2.dtype) @@ -1226,7 +1169,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): attr { key: 
'N' value { i: 2 } } """, out1.op.node_def) - out1, out2, out3 = self._lib.apply_op( + out1, out2, out3 = op_def_library.apply_op( "NPolymorphicOutDefault", N=3, T=None, name="s") self.assertEqual(dtypes.bool, out1.dtype) self.assertEqual(dtypes.bool, out2.dtype) @@ -1237,7 +1180,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): attr { key: 'N' value { i: 3 } } """, out1.op.node_def) - out1, out2 = self._lib.apply_op( + out1, out2 = op_def_library.apply_op( "NPolymorphicOutDefault", N=None, T=dtypes.int32, name="t") self.assertEqual(dtypes.int32, out1.dtype) self.assertEqual(dtypes.int32, out2.dtype) @@ -1247,7 +1190,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): attr { key: 'N' value { i: 2 } } """, out1.op.node_def) - out1, out2, out3 = self._lib.apply_op( + out1, out2, out3 = op_def_library.apply_op( "NPolymorphicOutDefault", N=3, T=dtypes.int32, name="u") self.assertEqual(dtypes.int32, out1.dtype) self.assertEqual(dtypes.int32, out2.dtype) @@ -1260,7 +1203,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testNPolymorphicRestrictOut(self): with ops.Graph().as_default(): - out1, out2, out3 = self._lib.apply_op( + out1, out2, out3 = op_def_library.apply_op( "NPolymorphicRestrictOut", N=3, T=dtypes.bool, name="u") self.assertEqual(dtypes.bool, out1.dtype) self.assertEqual(dtypes.bool, out2.dtype) @@ -1272,21 +1215,21 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out1.op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NPolymorphicRestrictOut", N=2, T=dtypes.int32) + op_def_library.apply_op("NPolymorphicRestrictOut", N=2, T=dtypes.int32) self.assertEqual(str(cm.exception), "Value passed to parameter 'T' has DataType int32 " "not in list of allowed values: string, bool") def testRef(self): with ops.Graph().as_default(): - out = self._lib.apply_op("RefOut", T=dtypes.bool, name="o") + out = op_def_library.apply_op("RefOut", T=dtypes.bool, name="o") self.assertEqual(dtypes.bool_ref, out.dtype) 
self.assertProtoEquals(""" name: 'o' op: 'RefOut' attr { key: 'T' value { type: DT_BOOL } } """, out.op.node_def) - op = self._lib.apply_op("RefIn", a=out, name="i") + op = op_def_library.apply_op("RefIn", a=out, name="i") self.assertProtoEquals(""" name: 'i' op: 'RefIn' input: 'o' attr { key: 'T' value { type: DT_BOOL } } @@ -1294,23 +1237,23 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, op.node_def) # Can pass ref to non-ref input. - out = self._lib.apply_op("RefOut", T=dtypes.int32, name="r") - out = self._lib.apply_op("Simple", a=out, name="s") + out = op_def_library.apply_op("RefOut", T=dtypes.int32, name="r") + out = op_def_library.apply_op("Simple", a=out, name="s") self.assertProtoEquals(""" name: 's' op: 'Simple' input: 'r' """, out.op.node_def) # Can't pass non-ref to ref input. with self.assertRaises(TypeError) as cm: - self._lib.apply_op("RefIn", a=2) + op_def_library.apply_op("RefIn", a=2) self.assertEqual( str(cm.exception), "'RefIn' Op requires that input 'a' be a mutable tensor " + "(e.g.: a tf.Variable)") - input_a = self._lib.apply_op("RefOut", T=dtypes.int32, name="t") - input_b = self._lib.apply_op("RefOut", T=dtypes.int32, name="u") - op = self._lib.apply_op("TwoRefsIn", a=input_a, b=input_b, name="v") + input_a = op_def_library.apply_op("RefOut", T=dtypes.int32, name="t") + input_b = op_def_library.apply_op("RefOut", T=dtypes.int32, name="u") + op = op_def_library.apply_op("TwoRefsIn", a=input_a, b=input_b, name="v") # NOTE(mrry): The order of colocation constraints is an implementation # detail. self.assertProtoEquals(""" @@ -1323,7 +1266,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): graph = ops.Graph() with graph.as_default(): with graph.device("/job:ADevice"): - self._lib.apply_op("Simple", a=3) + op_def_library.apply_op("Simple", a=3) # We look at the whole graph here to make sure the Const op is also given # the specified device. 
graph_def = graph.as_graph_def() @@ -1334,14 +1277,14 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testStructuredOutputSingleList(self): with ops.Graph().as_default(): for n_a in [0, 1, 3]: - a = self._lib.apply_op("SimpleStruct", n_a=n_a) + a = op_def_library.apply_op("SimpleStruct", n_a=n_a) self.assertTrue(isinstance(a, list)) self.assertEqual(n_a, len(a)) def testStructuredOutputListAndSingle(self): with ops.Graph().as_default(): for n_a in [0, 1, 3]: - a, b = self._lib.apply_op("MixedStruct", n_a=n_a) + a, b = op_def_library.apply_op("MixedStruct", n_a=n_a) self.assertTrue(isinstance(a, list)) self.assertEqual(n_a, len(a)) self.assertTrue(all(x.dtype == dtypes.int32 for x in a)) @@ -1355,10 +1298,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): for t_c in [[], [dtypes.int32], [dtypes.int32, dtypes.float32]]: - a, b, c = self._lib.apply_op("ComplexStruct", - n_a=n_a, - n_b=n_b, - t_c=t_c) + a, b, c = op_def_library.apply_op( + "ComplexStruct", n_a=n_a, n_b=n_b, t_c=t_c) self.assertEqual(n_a, len(a)) self.assertTrue(all(x.dtype == dtypes.int32 for x in a)) @@ -1369,31 +1310,23 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): class OpDefLibraryGraphTest(test_util.TensorFlowTestCase): - def setUp(self): - self._lib = test_ops._op_def_lib - - def _add_op(self, ascii): # pylint: disable=redefined-builtin - op_def = op_def_pb2.OpDef() - text_format.Merge(ascii, op_def) - self._lib.add_op(op_def) - def testNoGraph(self): - out = self._lib.apply_op("Simple", a=3) + out = op_def_library.apply_op("Simple", a=3) self.assertEqual(out.graph, ops.get_default_graph()) def testDefaultGraph(self): graph = ops.Graph() with graph.as_default(): - out = self._lib.apply_op("Simple", a=3) + out = op_def_library.apply_op("Simple", a=3) self.assertEqual(out.graph, graph) def testDifferentGraphFails(self): with ops.Graph().as_default(): - a = self._lib.apply_op("Simple", a=3) + a = op_def_library.apply_op("Simple", a=3) with ops.Graph().as_default(): - 
b = self._lib.apply_op("Simple", a=4) + b = op_def_library.apply_op("Simple", a=4) with self.assertRaises(ValueError) as cm: - self._lib.apply_op("Binary", a=a, b=b) + op_def_library.apply_op("Binary", a=a, b=b) self.assertTrue("must be from the same graph" in str(cm.exception)) diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 7db4a0c5133..6eec2eb20ed 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -375,8 +375,8 @@ void GenEagerPythonOp::HandleGraphMode( if (api_def_.visibility() == ApiDef::VISIBLE) { strings::StrAppend(&result_, " try:\n "); } - strings::StrAppend(&result_, - " _, _, _op, _outputs = _op_def_lib._apply_op_helper(\n"); + strings::StrAppend( + &result_, " _, _, _op, _outputs = _op_def_library._apply_op_helper(\n"); AddBodyNoReturn(strings::StrCat(" \"", op_def_.name(), "\", ")); AddDispatch(" "); @@ -1007,7 +1007,6 @@ from tensorflow.python.eager import core as _core from tensorflow.python.eager import execute as _execute from tensorflow.python.framework import dtypes as _dtypes -from tensorflow.core.framework import op_def_pb2 as _op_def_pb2 from tensorflow.python.framework import op_def_registry as _op_def_registry from tensorflow.python.framework import ops as _ops from tensorflow.python.framework import op_def_library as _op_def_library @@ -1017,10 +1016,6 @@ from tensorflow.python.util.tf_export import tf_export )"); - // We'll make a copy of ops that filters out descriptions. 
- OpList cleaned_ops; - auto out = cleaned_ops.mutable_op(); - out->Reserve(ops.op_size()); for (const auto& op_def : ops.op()) { const auto* api_def = api_defs.GetApiDef(op_def.name()); @@ -1064,22 +1059,8 @@ from tensorflow.python.util.tf_export import tf_export strings::StrAppend(&result, GetEagerPythonOp(op_def, *api_def, function_name)); - - auto added = out->Add(); - *added = op_def; - RemoveNonDeprecationDescriptionsFromOpDef(added); } - result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes): - op_list = _op_def_pb2.OpList() - op_list.ParseFromString(op_list_proto_bytes) - op_def_lib = _op_def_library.OpDefLibrary() - op_def_lib.add_op_list(op_list) - return op_def_lib -)"); - - strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n", - absl::CEscape(cleaned_ops.SerializeAsString()).c_str()); return result; }