Adding Python ApiDef overrides.

PiperOrigin-RevId: 172960496
This commit is contained in:
Anna R 2017-10-20 18:20:05 -07:00 committed by TensorFlower Gardener
parent 0d6a2e3531
commit 93e8f3c67d
16 changed files with 620 additions and 0 deletions

View File

@ -3326,6 +3326,11 @@ filegroup(
data = glob(["api_def/base_api/*"]),
)
filegroup(
name = "python_api_def",
data = glob(["api_def/python_api/*"]),
)
tf_cc_test(
name = "api_test",
srcs = ["api_def/api_test.cc"],

View File

@ -0,0 +1,18 @@
op {
graph_op_name: "BitwiseAnd"
endpoint {
name: "bitwise.bitwise_and"
}
}
op {
graph_op_name: "BitwiseOr"
endpoint {
name: "bitwise.bitwise_or"
}
}
op {
graph_op_name: "BitwiseXor"
endpoint {
name: "bitwise.bitwise_xor"
}
}

View File

@ -0,0 +1,15 @@
op {
graph_op_name: "Cholesky"
endpoint {
name: "cholesky"
}
endpoint {
name: "linalg.cholesky"
}
}
op {
graph_op_name: "CropAndResize"
endpoint {
name: "image.crop_and_resize"
}
}

View File

@ -0,0 +1,54 @@
op {
graph_op_name: "DecodeAndCropJpeg"
endpoint {
name: "image.decode_and_crop_jpeg"
}
}
op {
graph_op_name: "DecodeBmp"
endpoint {
name: "image.decode_bmp"
}
}
op {
graph_op_name: "DecodeGif"
endpoint {
name: "image.decode_gif"
}
}
op {
graph_op_name: "DecodeJpeg"
endpoint {
name: "image.decode_jpeg"
}
}
op {
graph_op_name: "DecodePng"
endpoint {
name: "image.decode_png"
}
}
op {
graph_op_name: "DepthwiseConv2dNative"
endpoint {
name: "nn.depthwise_conv2d_native"
}
}
op {
graph_op_name: "DepthwiseConv2dNativeBackpropFilter"
endpoint {
name: "nn.depthwise_conv2d_native_backprop_filter"
}
}
op {
graph_op_name: "DepthwiseConv2dNativeBackpropInput"
endpoint {
name: "nn.depthwise_conv2d_native_backprop_input"
}
}
op {
graph_op_name: "DrawBoundingBoxes"
endpoint {
name: "image.draw_bounding_boxes"
}
}

View File

@ -0,0 +1,30 @@
op {
graph_op_name: "Elu"
endpoint {
name: "nn.elu"
}
}
op {
graph_op_name: "EncodeJpeg"
endpoint {
name: "image.encode_jpeg"
}
}
op {
graph_op_name: "EncodePng"
endpoint {
name: "image.encode_png"
}
}
op {
graph_op_name: "ExtractGlimpse"
endpoint {
name: "image.extract_glimpse"
}
}
op {
graph_op_name: "ExtractJpegShape"
endpoint {
name: "image.extract_jpeg_shape"
}
}

View File

@ -0,0 +1,21 @@
op {
graph_op_name: "FFT"
endpoint {
name: "fft"
}
endpoint {
name: "spectral.fft"
}
}
op {
graph_op_name: "FractionalAvgPool"
endpoint {
name: "nn.fractional_avg_pool"
}
}
op {
graph_op_name: "FractionalMaxPool"
endpoint {
name: "nn.fractional_max_pool"
}
}

View File

@ -0,0 +1,6 @@
op {
graph_op_name: "HSVToRGB"
endpoint {
name: "image.hsv_to_rgb"
}
}

View File

@ -0,0 +1,15 @@
op {
graph_op_name: "IFFT"
endpoint {
name: "ifft"
}
endpoint {
name: "spectral.ifft"
}
}
op {
graph_op_name: "Invert"
endpoint {
name: "bitwise.invert"
}
}

View File

@ -0,0 +1,24 @@
op {
graph_op_name: "L2Loss"
endpoint {
name: "nn.l2_loss"
}
}
op {
graph_op_name: "LRN"
endpoint {
name: "nn.local_response_normalization"
}
endpoint {
name: "nn.lrn"
}
}
op {
graph_op_name: "LinSpace"
endpoint {
name: "lin_space"
}
endpoint {
name: "linspace"
}
}

View File

@ -0,0 +1,78 @@
op {
graph_op_name: "MatrixBandPart"
endpoint {
name: "linalg.band_part"
}
endpoint {
name: "matrix_band_part"
}
}
op {
graph_op_name: "MatrixDeterminant"
endpoint {
name: "linalg.det"
}
endpoint {
name: "matrix_determinant"
}
}
op {
graph_op_name: "MatrixDiag"
endpoint {
name: "linalg.diag"
}
endpoint {
name: "matrix_diag"
}
}
op {
graph_op_name: "MatrixDiagPart"
endpoint {
name: "linalg.diag_part"
}
endpoint {
name: "matrix_diag_part"
}
}
op {
graph_op_name: "MatrixInverse"
endpoint {
name: "linalg.inv"
}
endpoint {
name: "matrix_inverse"
}
}
op {
graph_op_name: "MatrixSetDiag"
endpoint {
name: "linalg.set_diag"
}
endpoint {
name: "matrix_set_diag"
}
}
op {
graph_op_name: "MatrixSolve"
endpoint {
name: "linalg.solve"
}
endpoint {
name: "matrix_solve"
}
}
op {
graph_op_name: "MatrixTriangularSolve"
endpoint {
name: "linalg.triangular_solve"
}
endpoint {
name: "matrix_triangular_solve"
}
}
op {
graph_op_name: "MaxPoolWithArgmax"
endpoint {
name: "nn.max_pool_with_argmax"
}
}

View File

@ -0,0 +1,27 @@
op {
graph_op_name: "Qr"
endpoint {
name: "linalg.qr"
}
endpoint {
name: "qr"
}
}
op {
graph_op_name: "QuantizedAvgPool"
endpoint {
name: "nn.quantized_avg_pool"
}
}
op {
graph_op_name: "QuantizedMaxPool"
endpoint {
name: "nn.quantized_max_pool"
}
}
op {
graph_op_name: "QuantizedReluX"
endpoint {
name: "nn.quantized_relu_x"
}
}

View File

@ -0,0 +1,36 @@
op {
graph_op_name: "RGBToHSV"
endpoint {
name: "image.rgb_to_hsv"
}
}
op {
graph_op_name: "Relu"
endpoint {
name: "nn.relu"
}
}
op {
graph_op_name: "ResizeArea"
endpoint {
name: "image.resize_area"
}
}
op {
graph_op_name: "ResizeBicubic"
endpoint {
name: "image.resize_bicubic"
}
}
op {
graph_op_name: "ResizeBilinear"
endpoint {
name: "image.resize_bilinear"
}
}
op {
graph_op_name: "ResizeNearestNeighbor"
endpoint {
name: "image.resize_nearest_neighbor"
}
}

View File

@ -0,0 +1,36 @@
op {
graph_op_name: "SdcaFprint"
endpoint {
name: "train.sdca_fprint"
}
}
op {
graph_op_name: "SdcaOptimizer"
endpoint {
name: "train.sdca_optimizer"
}
}
op {
graph_op_name: "SdcaShrinkL1"
endpoint {
name: "train.sdca_shrink_l1"
}
}
op {
graph_op_name: "Selu"
endpoint {
name: "nn.selu"
}
}
op {
graph_op_name: "Softplus"
endpoint {
name: "nn.softplus"
}
}
op {
graph_op_name: "Softsign"
endpoint {
name: "nn.softsign"
}
}

View File

@ -11,10 +11,15 @@ exports_files([
"API_UPDATE_WARNING.txt",
])
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
py_test(
name = "api_compatibility_test",
srcs = ["api_compatibility_test.py"],
data = [
":convert_from_multiline",
"//tensorflow/core:base_api_def",
"//tensorflow/core:python_api_def",
"//tensorflow/tools/api/golden:api_golden",
"//tensorflow/tools/api/tests:API_UPDATE_WARNING.txt",
"//tensorflow/tools/api/tests:README.txt",
@ -23,6 +28,7 @@ py_test(
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:lib",
"//tensorflow/python:platform",
"//tensorflow/tools/api/lib:python_object_to_proto_visitor",
@ -31,6 +37,15 @@ py_test(
],
)
tf_cc_binary(
name = "convert_from_multiline",
srcs = ["convert_from_multiline.cc"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:op_gen_lib",
],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@ -28,8 +28,11 @@ from __future__ import division
from __future__ import print_function
import argparse
from collections import defaultdict
from operator import attrgetter
import os
import re
import subprocess
import sys
import unittest
@ -37,6 +40,7 @@ import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import api_def_pb2
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
@ -64,6 +68,11 @@ _API_GOLDEN_FOLDER = 'tensorflow/tools/api/golden'
_TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt'
_UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt'
_ALPHABET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
_CONVERT_FROM_MULTILINE_SCRIPT = 'tensorflow/tools/api/tests/convert_from_multiline'
_BASE_API_DIR = 'tensorflow/core/api_def/base_api'
_PYTHON_API_DIR = 'tensorflow/core/api_def/python_api'
def _KeyToFilePath(key):
"""From a given key, construct a filepath."""
@ -88,6 +97,30 @@ def _FileNameToKey(filename):
return api_object_key
def _GetSymbol(symbol_id):
"""Get TensorFlow symbol based on the given identifier.
Args:
symbol_id: Symbol identifier in the form module1.module2. ... .sym.
Returns:
Symbol corresponding to the given id.
"""
# Ignore first module which should be tensorflow
symbol_id_split = symbol_id.split('.')[1:]
symbol = tf
for sym in symbol_id_split:
symbol = getattr(symbol, sym)
return symbol
def _IsGenModule(module_name):
if not module_name:
return False
module_name_split = module_name.split('.')
return module_name_split[-1].startswith('gen_')
class ApiCompatibilityTest(test.TestCase):
def __init__(self, *args, **kwargs):
@ -229,6 +262,150 @@ class ApiCompatibilityTest(test.TestCase):
update_goldens=FLAGS.update_goldens)
class ApiDefTest(test.TestCase):
def __init__(self, *args, **kwargs):
super(ApiDefTest, self).__init__(*args, **kwargs)
self._first_cap_pattern = re.compile('(.)([A-Z][a-z]+)')
self._all_cap_pattern = re.compile('([a-z0-9])([A-Z])')
def _GenerateLowerCaseOpName(self, op_name):
lower_case_name = self._first_cap_pattern.sub(r'\1_\2', op_name)
return self._all_cap_pattern.sub(r'\1_\2', lower_case_name).lower()
def _CreatePythonApiDef(self, base_api_def, endpoint_names):
"""Creates Python ApiDef that overrides base_api_def if needed.
Args:
base_api_def: (api_def_pb2.ApiDef) base ApiDef instance.
endpoint_names: List of Python endpoint names.
Returns:
api_def_pb2.ApiDef instance with overrides for base_api_def
if module.name endpoint is different from any existing
endpoints in base_api_def. Otherwise, returns None.
"""
endpoint_names_set = set(endpoint_names)
base_endpoint_names_set = {
self._GenerateLowerCaseOpName(endpoint.name)
for endpoint in base_api_def.endpoint}
if endpoint_names_set == base_endpoint_names_set:
return None # All endpoints are the same
api_def = api_def_pb2.ApiDef()
api_def.graph_op_name = base_api_def.graph_op_name
for endpoint_name in sorted(endpoint_names):
new_endpoint = api_def.endpoint.add()
new_endpoint.name = endpoint_name
return api_def
def _GetBaseApiMap(self):
"""Get a map from graph op name to its base ApiDef.
Returns:
Dictionary mapping graph op name to corresponding ApiDef.
"""
# Convert base ApiDef in Multiline format to Proto format.
converted_base_api_dir = os.path.join(
test.get_temp_dir(), 'temp_base_api_defs')
subprocess.check_call(
[os.path.join(resource_loader.get_root_dir_with_all_resources(),
_CONVERT_FROM_MULTILINE_SCRIPT),
_BASE_API_DIR, converted_base_api_dir])
name_to_base_api_def = {}
base_api_files = file_io.get_matching_files(
os.path.join(converted_base_api_dir, 'api_def_*.pbtxt'))
for base_api_file in base_api_files:
if file_io.file_exists(base_api_file):
api_defs = api_def_pb2.ApiDefs()
text_format.Merge(
file_io.read_file_to_string(base_api_file), api_defs)
for api_def in api_defs.op:
lower_case_name = self._GenerateLowerCaseOpName(api_def.graph_op_name)
name_to_base_api_def[lower_case_name] = api_def
return name_to_base_api_def
@unittest.skipUnless(
sys.version_info.major == 2 and os.uname()[0] == 'Linux',
'API compabitility test goldens are generated using python2 on Linux.')
def testAPIDefCompatibility(self):
# Get base ApiDef
name_to_base_api_def = self._GetBaseApiMap()
# Extract Python API
visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
public_api_visitor = public_api.PublicAPIVisitor(visitor)
public_api_visitor.do_not_descend_map['tf'].append('contrib')
traverse.traverse(tf, public_api_visitor)
proto_dict = visitor.GetProtos()
# Map from first character of op name to Python ApiDefs.
api_def_map = defaultdict(api_def_pb2.ApiDefs)
# We need to override all endpoints even if 1 endpoint differs from base
# ApiDef. So, we first create a map from an op to all its endpoints.
op_to_endpoint_name = defaultdict(list)
# Generate map from generated python op to endpoint names.
for public_module, value in proto_dict.items():
module_obj = _GetSymbol(public_module)
for sym in value.tf_module.member_method:
obj = getattr(module_obj, sym.name)
# Check if object is defined in gen_* module. That is,
# the object has been generated from OpDef.
if hasattr(obj, '__module__') and _IsGenModule(obj.__module__):
if obj.__name__ not in name_to_base_api_def:
# Symbol might be defined only in Python and not generated from
# C++ api.
continue
relative_public_module = public_module[len('tensorflow.'):]
full_name = (relative_public_module + '.' + sym.name
if relative_public_module else sym.name)
op_to_endpoint_name[obj].append(full_name)
# Generate Python ApiDef overrides.
for op, endpoint_names in op_to_endpoint_name.items():
api_def = self._CreatePythonApiDef(
name_to_base_api_def[op.__name__], endpoint_names)
if api_def:
api_defs = api_def_map[op.__name__[0].upper()]
api_defs.op.extend([api_def])
for key in _ALPHABET:
# Get new ApiDef for the given key.
new_api_defs_str = ''
if key in api_def_map:
new_api_defs = api_def_map[key]
new_api_defs.op.sort(key=attrgetter('graph_op_name'))
new_api_defs_str = str(new_api_defs)
# Get current ApiDef for the given key.
api_defs_file_path = os.path.join(
_PYTHON_API_DIR, 'api_def_%s.pbtxt' % key)
old_api_defs_str = ''
if file_io.file_exists(api_defs_file_path):
old_api_defs_str = file_io.read_file_to_string(api_defs_file_path)
if old_api_defs_str == new_api_defs_str:
continue
if FLAGS.update_goldens:
if not new_api_defs_str:
logging.info('Deleting %s...' % api_defs_file_path)
file_io.delete_file(api_defs_file_path)
else:
logging.info('Updating %s...' % api_defs_file_path)
file_io.write_string_to_file(api_defs_file_path, new_api_defs_str)
else:
self.assertMultiLineEqual(
old_api_defs_str, new_api_defs_str,
'To update golden API files, run api_compatibility_test locally '
'with --update_goldens=True flag.')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(

View File

@ -0,0 +1,63 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Converts all *.pbtxt files in a directory from Multiline to proto format.
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
namespace tensorflow {
namespace {
constexpr char kApiDefFilePattern[] = "*.pbtxt";
Status ConvertFilesFromMultiline(const string& input_dir,
const string& output_dir) {
Env* env = Env::Default();
const string file_pattern = io::JoinPath(input_dir, kApiDefFilePattern);
std::vector<string> matching_paths;
TF_CHECK_OK(env->GetMatchingPaths(file_pattern, &matching_paths));
if (!env->IsDirectory(output_dir).ok()) {
TF_RETURN_IF_ERROR(env->CreateDir(output_dir));
}
for (const auto& path : matching_paths) {
string contents;
TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, path, &contents));
contents = tensorflow::PBTxtFromMultiline(contents);
string output_path = io::JoinPath(output_dir, io::Basename(path));
// Write contents to output_path
TF_RETURN_IF_ERROR(
tensorflow::WriteStringToFile(env, output_path, contents));
}
return Status::OK();
}
} // namespace
} // namespace tensorflow
int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
const std::string usage =
"Usage: convert_from_multiline input_dir output_dir";
if (argc != 3) {
std::cerr << usage << std::endl;
return -1;
}
TF_CHECK_OK(tensorflow::ConvertFilesFromMultiline(argv[1], argv[2]));
return 0;
}