Feng Liu 8ccc8fe538 Add the Python API and AST translator for authoring TF op composition
The new API allows the user to define a new op with the composition of existing
TF ops. For example, the following code defines a new op called 'NewConv2D'
with three tensor inputs and one output:

```
@Composite(
    op_name='NewConv2D',
    inputs=['input_: T', 'filter_: T', 'bias: T'],
    attrs=[
        'stride_w: int', 'stride_h: int', 'dilation_w: int', 'dilation_h: int',
        'padding: {"SAME", "VALID"}', 'act: {"", "RELU", "RELU6", "TANH"} = ""'
    ],
    derived_attrs=['T: {float, int8}'],
    outputs=['o: T'])
def _composite_conv_add_relu(input_, filter_, bias, stride_w, stride_h,
                             dilation_w, dilation_h, padding, act):
  res = tf.raw_ops.Conv2D(
      input=input_,
      filter=filter_,
      strides=[1, stride_w, stride_h, 1],
      dilations=[1, dilation_w, dilation_h, 1],
      padding=padding)
  res = tf.raw_ops.Add(x=res, y=bias)
  if act == 'RELU':
    return tf.raw_ops.Relu(features=res)
  elif act == 'RELU6':
    return tf.raw_ops.Relu6(features=res)
  elif act == 'TANH':
    return tf.raw_ops.Tanh(x=res)
  else:
    return res
```

The translator can be invoked to translate this function directly to an MLIR
module, which can be used to (optionally) expand the 'NewConv2D' in the graph
to a sequence of leaf nodes with kernels and bridge support. Optimizations
(constant folding, loop unrolling, branch condition evaluation, etc., which
are checked in separately) can be applied to simplify the resulting graph.

PiperOrigin-RevId: 335976996
Change-Id: Ie4d641c0e86b234cd94d0a6d36e26269659676f9
2020-10-07 16:35:35 -07:00

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Op composition registration."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


# TODO(fengliuai): add the tf_export decorator
class Composite(object):
  """A decorator to register a function as a composition for a TF operator.

  The argument to the decorator must be the name of the TF raw operator that
  the function provides a composition for. The decorated function must take
  positional arguments that correspond to the inputs and attributes in the
  OpDef of the TF operation.

  # TODO(fengliuai): more documents here.

  Example:
    @composite.Composite('AddN')
    def _compose_add_n(inputs, N):
      if N == 1:
        ....
  """

  # TODO(fengliuai): support input_binding and output_binding so the arguments
  # are not positional.
  def __init__(self,
               op_name,
               inputs=None,
               attrs=None,
               derived_attrs=None,
               outputs=None):
    self._op_name = op_name
    self._inputs = inputs
    self._attrs = attrs
    self._derived_attrs = derived_attrs
    self._outputs = outputs

  def __call__(self, compose_fn):
    # TODO(fengliuai): add more sanity checks of the input function to make
    # sure its bound arguments match the 'inputs' and 'attrs'.
    setattr(compose_fn, '_tfr_op_name', self._op_name)
    return compose_fn