Changed OpDefLibrary to use op_def_registry
Prior to this change OpDefLibrary used a local OpDef registry to construct and validate ops. Recent changes in op_def_registry allowed to switch to op_def_registry.get() for OpDef lookups making OpDefLibrary redundant. Note also that this changes removes binary ProtoBuf blobs from auto generated op wrappers because they were only used for defining an OpDefLibrary. Before: $ wc -c [...]/tensorflow/python/ops/gen_*_ops.py ... 4796803 total After: $ wc -c [...]/tensorflow/python/ops/gen_*_ops.py ... 4497581 total PiperOrigin-RevId: 271557944
This commit is contained in:
parent
ac6c12c16c
commit
32be3c5175
@ -974,6 +974,7 @@ tf_py_test(
|
||||
":function_def_to_graph",
|
||||
":graph_to_function_def",
|
||||
":math_ops",
|
||||
":op_def_library",
|
||||
":test_ops",
|
||||
],
|
||||
tags = ["no_pip"],
|
||||
@ -1031,6 +1032,7 @@ py_library(
|
||||
deps = [
|
||||
":dtypes",
|
||||
":framework_ops",
|
||||
":op_def_registry",
|
||||
":platform",
|
||||
":tensor_shape",
|
||||
":util",
|
||||
@ -2039,7 +2041,6 @@ tf_py_test(
|
||||
":framework_for_generated_wrappers",
|
||||
":framework_test_lib",
|
||||
":platform_test",
|
||||
":test_ops",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -23,10 +23,11 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function_def_to_graph
|
||||
from tensorflow.python.framework import graph_to_function_def
|
||||
from tensorflow.python.framework import op_def_library
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import test_ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
@ -119,8 +120,7 @@ class FunctionDefToGraphDefTest(test.TestCase):
|
||||
y = array_ops.placeholder(dtypes.int32, name="y")
|
||||
z = array_ops.placeholder(dtypes.int32, name="z")
|
||||
|
||||
d_1, e_1 = test_ops._op_def_lib.apply_op(
|
||||
"Foo1", name="foo_1", a=x, b=y, c=z)
|
||||
d_1, e_1 = op_def_library.apply_op("Foo1", name="foo_1", a=x, b=y, c=z)
|
||||
|
||||
list_output0, list_output1 = test_ops.list_output(
|
||||
T=[dtypes.int32, dtypes.int32], name="list_output")
|
||||
|
@ -22,10 +22,10 @@ from __future__ import print_function
|
||||
import six
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import op_def_pb2
|
||||
from tensorflow.core.framework import tensor_pb2
|
||||
from tensorflow.core.framework import tensor_shape_pb2
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.python.framework import op_def_registry
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import op_callbacks
|
||||
from tensorflow.python.framework import ops
|
||||
@ -229,38 +229,6 @@ def _MakeFunc(v, arg_name):
|
||||
return fn_attr
|
||||
|
||||
|
||||
class _OpInfo(object):
|
||||
"""All per-Op state we would like to precompute/validate."""
|
||||
|
||||
def __init__(self, op_def):
|
||||
self.op_def = op_def
|
||||
# TODO(josh11b): SWIG the ValidateOpDef() function from C++ and call it
|
||||
# here, instead of these checks.
|
||||
for arg in list(op_def.input_arg) + list(op_def.output_arg):
|
||||
num_type_fields = _NumTypeFields(arg)
|
||||
if num_type_fields != 1:
|
||||
raise TypeError("Arg '%s' of '%s' must have one type field not %d" %
|
||||
(arg.name, op_def.name, num_type_fields))
|
||||
if arg.type_attr:
|
||||
attr_type = _Attr(op_def, arg.type_attr).type
|
||||
if attr_type != "type":
|
||||
raise TypeError("Attr '%s' of '%s' used as a type_attr "
|
||||
"but has type %s" %
|
||||
(arg.type_attr, op_def.name, attr_type))
|
||||
if arg.type_list_attr:
|
||||
attr_type = _Attr(op_def, arg.type_list_attr).type
|
||||
if attr_type != "list(type)":
|
||||
raise TypeError(
|
||||
"Attr '%s' of '%s' used as a type_list_attr but has type %s" %
|
||||
(arg.type_attr, op_def.name, attr_type))
|
||||
if arg.number_attr:
|
||||
attr_type = _Attr(op_def, arg.number_attr).type
|
||||
if attr_type != "int":
|
||||
raise TypeError(
|
||||
"Attr '%s' of '%s' used as a number_attr but has type %s" %
|
||||
(arg.number_attr, op_def.name, attr_type))
|
||||
|
||||
|
||||
# pylint: disable=g-doc-return-or-yield
|
||||
@tf_contextlib.contextmanager
|
||||
def _MaybeColocateWith(inputs):
|
||||
@ -282,34 +250,7 @@ def _MaybeColocateWith(inputs):
|
||||
# pylint: enable=g-doc-return-or-yield
|
||||
|
||||
|
||||
class OpDefLibrary(object):
|
||||
"""Holds a collection of OpDefs, can add the corresponding Ops to a graph."""
|
||||
|
||||
def __init__(self):
|
||||
self._ops = {}
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def add_op(self, op_def):
|
||||
"""Register an OpDef. May call apply_op with the name afterwards."""
|
||||
if not isinstance(op_def, op_def_pb2.OpDef):
|
||||
raise TypeError("%s is %s, not an op_def_pb2.OpDef" %
|
||||
(op_def, type(op_def)))
|
||||
if not op_def.name:
|
||||
raise ValueError("%s missing name." % op_def)
|
||||
if op_def.name in self._ops:
|
||||
raise RuntimeError("Op name %s registered twice." % op_def.name)
|
||||
self._ops[op_def.name] = _OpInfo(op_def)
|
||||
|
||||
def add_op_list(self, op_list):
|
||||
"""Register the OpDefs from an OpList."""
|
||||
if not isinstance(op_list, op_def_pb2.OpList):
|
||||
raise TypeError("%s is %s, not an op_def_pb2.OpList" %
|
||||
(op_list, type(op_list)))
|
||||
for op_def in op_list.op:
|
||||
self.add_op(op_def)
|
||||
|
||||
def apply_op(self, op_type_name, name=None, **keywords):
|
||||
# pylint: disable=g-doc-args
|
||||
def apply_op(op_type_name, name=None, **keywords): # pylint: disable=invalid-name
|
||||
"""Add a node invoking a registered Op to a graph.
|
||||
|
||||
Example usage:
|
||||
@ -341,7 +282,7 @@ class OpDefLibrary(object):
|
||||
TypeError: On some errors.
|
||||
ValueError: On some errors.
|
||||
"""
|
||||
output_structure, is_stateful, op, outputs = self._apply_op_helper(
|
||||
output_structure, is_stateful, op, outputs = _apply_op_helper(
|
||||
op_type_name, name, **keywords)
|
||||
if output_structure:
|
||||
res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
|
||||
@ -352,12 +293,12 @@ class OpDefLibrary(object):
|
||||
else:
|
||||
return op
|
||||
|
||||
def _apply_op_helper(self, op_type_name, name=None, **keywords):
|
||||
|
||||
def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=invalid-name
|
||||
"""Implementation of apply_op that returns output_structure, op."""
|
||||
op_info = self._ops.get(op_type_name, None)
|
||||
if op_info is None:
|
||||
op_def = op_def_registry.get(op_type_name)
|
||||
if op_def is None:
|
||||
raise RuntimeError("Unrecognized Op name " + op_type_name)
|
||||
op_def = op_info.op_def
|
||||
|
||||
# Determine the graph context.
|
||||
try:
|
||||
@ -807,5 +748,3 @@ class OpDefLibrary(object):
|
||||
outputs = callback_outputs
|
||||
|
||||
return output_structure, op_def.is_stateful, op, outputs
|
||||
|
||||
# pylint: enable=invalid-name
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -375,8 +375,8 @@ void GenEagerPythonOp::HandleGraphMode(
|
||||
if (api_def_.visibility() == ApiDef::VISIBLE) {
|
||||
strings::StrAppend(&result_, " try:\n ");
|
||||
}
|
||||
strings::StrAppend(&result_,
|
||||
" _, _, _op, _outputs = _op_def_lib._apply_op_helper(\n");
|
||||
strings::StrAppend(
|
||||
&result_, " _, _, _op, _outputs = _op_def_library._apply_op_helper(\n");
|
||||
AddBodyNoReturn(strings::StrCat(" \"", op_def_.name(), "\", "));
|
||||
AddDispatch(" ");
|
||||
|
||||
@ -1007,7 +1007,6 @@ from tensorflow.python.eager import core as _core
|
||||
from tensorflow.python.eager import execute as _execute
|
||||
from tensorflow.python.framework import dtypes as _dtypes
|
||||
|
||||
from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
|
||||
from tensorflow.python.framework import op_def_registry as _op_def_registry
|
||||
from tensorflow.python.framework import ops as _ops
|
||||
from tensorflow.python.framework import op_def_library as _op_def_library
|
||||
@ -1017,10 +1016,6 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
)");
|
||||
|
||||
// We'll make a copy of ops that filters out descriptions.
|
||||
OpList cleaned_ops;
|
||||
auto out = cleaned_ops.mutable_op();
|
||||
out->Reserve(ops.op_size());
|
||||
for (const auto& op_def : ops.op()) {
|
||||
const auto* api_def = api_defs.GetApiDef(op_def.name());
|
||||
|
||||
@ -1064,22 +1059,8 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
strings::StrAppend(&result,
|
||||
GetEagerPythonOp(op_def, *api_def, function_name));
|
||||
|
||||
auto added = out->Add();
|
||||
*added = op_def;
|
||||
RemoveNonDeprecationDescriptionsFromOpDef(added);
|
||||
}
|
||||
|
||||
result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes):
|
||||
op_list = _op_def_pb2.OpList()
|
||||
op_list.ParseFromString(op_list_proto_bytes)
|
||||
op_def_lib = _op_def_library.OpDefLibrary()
|
||||
op_def_lib.add_op_list(op_list)
|
||||
return op_def_lib
|
||||
)");
|
||||
|
||||
strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n",
|
||||
absl::CEscape(cleaned_ops.SerializeAsString()).c_str());
|
||||
return result;
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user