Changed OpDefLibrary to use op_def_registry

Prior to this change OpDefLibrary used a local OpDef registry to
construct and validate ops. Recent changes in op_def_registry allowed
to switch to op_def_registry.get() for OpDef lookups making OpDefLibrary
redundant.

Note also that this changes removes binary ProtoBuf blobs from auto
generated op wrappers because they were only used for defining an
OpDefLibrary.

Before:

$ wc -c [...]/tensorflow/python/ops/gen_*_ops.py
...
4796803 total

After:

$ wc -c [...]/tensorflow/python/ops/gen_*_ops.py
...
4497581 total

PiperOrigin-RevId: 271557944
This commit is contained in:
Sergei Lebedev 2019-09-27 06:28:08 -07:00 committed by TensorFlower Gardener
parent ac6c12c16c
commit 32be3c5175
5 changed files with 700 additions and 846 deletions

View File

@ -974,6 +974,7 @@ tf_py_test(
":function_def_to_graph",
":graph_to_function_def",
":math_ops",
":op_def_library",
":test_ops",
],
tags = ["no_pip"],
@ -1031,6 +1032,7 @@ py_library(
deps = [
":dtypes",
":framework_ops",
":op_def_registry",
":platform",
":tensor_shape",
":util",
@ -2039,7 +2041,6 @@ tf_py_test(
":framework_for_generated_wrappers",
":framework_test_lib",
":platform_test",
":test_ops",
],
)

View File

@ -23,10 +23,11 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import op_def_library
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
@ -119,8 +120,7 @@ class FunctionDefToGraphDefTest(test.TestCase):
y = array_ops.placeholder(dtypes.int32, name="y")
z = array_ops.placeholder(dtypes.int32, name="z")
d_1, e_1 = test_ops._op_def_lib.apply_op(
"Foo1", name="foo_1", a=x, b=y, c=z)
d_1, e_1 = op_def_library.apply_op("Foo1", name="foo_1", a=x, b=y, c=z)
list_output0, list_output1 = test_ops.list_output(
T=[dtypes.int32, dtypes.int32], name="list_output")

View File

@ -22,10 +22,10 @@ from __future__ import print_function
import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_callbacks
from tensorflow.python.framework import ops
@ -229,38 +229,6 @@ def _MakeFunc(v, arg_name):
return fn_attr
class _OpInfo(object):
"""All per-Op state we would like to precompute/validate."""
def __init__(self, op_def):
self.op_def = op_def
# TODO(josh11b): SWIG the ValidateOpDef() function from C++ and call it
# here, instead of these checks.
for arg in list(op_def.input_arg) + list(op_def.output_arg):
num_type_fields = _NumTypeFields(arg)
if num_type_fields != 1:
raise TypeError("Arg '%s' of '%s' must have one type field not %d" %
(arg.name, op_def.name, num_type_fields))
if arg.type_attr:
attr_type = _Attr(op_def, arg.type_attr).type
if attr_type != "type":
raise TypeError("Attr '%s' of '%s' used as a type_attr "
"but has type %s" %
(arg.type_attr, op_def.name, attr_type))
if arg.type_list_attr:
attr_type = _Attr(op_def, arg.type_list_attr).type
if attr_type != "list(type)":
raise TypeError(
"Attr '%s' of '%s' used as a type_list_attr but has type %s" %
(arg.type_attr, op_def.name, attr_type))
if arg.number_attr:
attr_type = _Attr(op_def, arg.number_attr).type
if attr_type != "int":
raise TypeError(
"Attr '%s' of '%s' used as a number_attr but has type %s" %
(arg.number_attr, op_def.name, attr_type))
# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def _MaybeColocateWith(inputs):
@ -282,34 +250,7 @@ def _MaybeColocateWith(inputs):
# pylint: enable=g-doc-return-or-yield
class OpDefLibrary(object):
"""Holds a collection of OpDefs, can add the corresponding Ops to a graph."""
def __init__(self):
self._ops = {}
# pylint: disable=invalid-name
def add_op(self, op_def):
"""Register an OpDef. May call apply_op with the name afterwards."""
if not isinstance(op_def, op_def_pb2.OpDef):
raise TypeError("%s is %s, not an op_def_pb2.OpDef" %
(op_def, type(op_def)))
if not op_def.name:
raise ValueError("%s missing name." % op_def)
if op_def.name in self._ops:
raise RuntimeError("Op name %s registered twice." % op_def.name)
self._ops[op_def.name] = _OpInfo(op_def)
def add_op_list(self, op_list):
"""Register the OpDefs from an OpList."""
if not isinstance(op_list, op_def_pb2.OpList):
raise TypeError("%s is %s, not an op_def_pb2.OpList" %
(op_list, type(op_list)))
for op_def in op_list.op:
self.add_op(op_def)
def apply_op(self, op_type_name, name=None, **keywords):
# pylint: disable=g-doc-args
def apply_op(op_type_name, name=None, **keywords): # pylint: disable=invalid-name
"""Add a node invoking a registered Op to a graph.
Example usage:
@ -341,7 +282,7 @@ class OpDefLibrary(object):
TypeError: On some errors.
ValueError: On some errors.
"""
output_structure, is_stateful, op, outputs = self._apply_op_helper(
output_structure, is_stateful, op, outputs = _apply_op_helper(
op_type_name, name, **keywords)
if output_structure:
res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
@ -352,12 +293,12 @@ class OpDefLibrary(object):
else:
return op
def _apply_op_helper(self, op_type_name, name=None, **keywords):
def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=invalid-name
"""Implementation of apply_op that returns output_structure, op."""
op_info = self._ops.get(op_type_name, None)
if op_info is None:
op_def = op_def_registry.get(op_type_name)
if op_def is None:
raise RuntimeError("Unrecognized Op name " + op_type_name)
op_def = op_info.op_def
# Determine the graph context.
try:
@ -807,5 +748,3 @@ class OpDefLibrary(object):
outputs = callback_outputs
return output_structure, op_def.is_stateful, op, outputs
# pylint: enable=invalid-name

File diff suppressed because it is too large Load Diff

View File

@ -375,8 +375,8 @@ void GenEagerPythonOp::HandleGraphMode(
if (api_def_.visibility() == ApiDef::VISIBLE) {
strings::StrAppend(&result_, " try:\n ");
}
strings::StrAppend(&result_,
" _, _, _op, _outputs = _op_def_lib._apply_op_helper(\n");
strings::StrAppend(
&result_, " _, _, _op, _outputs = _op_def_library._apply_op_helper(\n");
AddBodyNoReturn(strings::StrCat(" \"", op_def_.name(), "\", "));
AddDispatch(" ");
@ -1007,7 +1007,6 @@ from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
@ -1017,10 +1016,6 @@ from tensorflow.python.util.tf_export import tf_export
)");
// We'll make a copy of ops that filters out descriptions.
OpList cleaned_ops;
auto out = cleaned_ops.mutable_op();
out->Reserve(ops.op_size());
for (const auto& op_def : ops.op()) {
const auto* api_def = api_defs.GetApiDef(op_def.name());
@ -1064,22 +1059,8 @@ from tensorflow.python.util.tf_export import tf_export
strings::StrAppend(&result,
GetEagerPythonOp(op_def, *api_def, function_name));
auto added = out->Add();
*added = op_def;
RemoveNonDeprecationDescriptionsFromOpDef(added);
}
result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes):
op_list = _op_def_pb2.OpList()
op_list.ParseFromString(op_list_proto_bytes)
op_def_lib = _op_def_library.OpDefLibrary()
op_def_lib.add_op_list(op_list)
return op_def_lib
)");
strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n",
absl::CEscape(cleaned_ops.SerializeAsString()).c_str());
return result;
}