Adding Python ApiDef overrides.
PiperOrigin-RevId: 172960496
This commit is contained in:
parent
0d6a2e3531
commit
93e8f3c67d
@ -3326,6 +3326,11 @@ filegroup(
|
|||||||
data = glob(["api_def/base_api/*"]),
|
data = glob(["api_def/base_api/*"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "python_api_def",
|
||||||
|
data = glob(["api_def/python_api/*"]),
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "api_test",
|
name = "api_test",
|
||||||
srcs = ["api_def/api_test.cc"],
|
srcs = ["api_def/api_test.cc"],
|
||||||
|
18
tensorflow/core/api_def/python_api/api_def_B.pbtxt
Normal file
18
tensorflow/core/api_def/python_api/api_def_B.pbtxt
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "BitwiseAnd"
|
||||||
|
endpoint {
|
||||||
|
name: "bitwise.bitwise_and"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "BitwiseOr"
|
||||||
|
endpoint {
|
||||||
|
name: "bitwise.bitwise_or"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "BitwiseXor"
|
||||||
|
endpoint {
|
||||||
|
name: "bitwise.bitwise_xor"
|
||||||
|
}
|
||||||
|
}
|
15
tensorflow/core/api_def/python_api/api_def_C.pbtxt
Normal file
15
tensorflow/core/api_def/python_api/api_def_C.pbtxt
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "Cholesky"
|
||||||
|
endpoint {
|
||||||
|
name: "cholesky"
|
||||||
|
}
|
||||||
|
endpoint {
|
||||||
|
name: "linalg.cholesky"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "CropAndResize"
|
||||||
|
endpoint {
|
||||||
|
name: "image.crop_and_resize"
|
||||||
|
}
|
||||||
|
}
|
54
tensorflow/core/api_def/python_api/api_def_D.pbtxt
Normal file
54
tensorflow/core/api_def/python_api/api_def_D.pbtxt
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "DecodeAndCropJpeg"
|
||||||
|
endpoint {
|
||||||
|
name: "image.decode_and_crop_jpeg"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "DecodeBmp"
|
||||||
|
endpoint {
|
||||||
|
name: "image.decode_bmp"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "DecodeGif"
|
||||||
|
endpoint {
|
||||||
|
name: "image.decode_gif"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "DecodeJpeg"
|
||||||
|
endpoint {
|
||||||
|
name: "image.decode_jpeg"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "DecodePng"
|
||||||
|
endpoint {
|
||||||
|
name: "image.decode_png"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "DepthwiseConv2dNative"
|
||||||
|
endpoint {
|
||||||
|
name: "nn.depthwise_conv2d_native"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "DepthwiseConv2dNativeBackpropFilter"
|
||||||
|
endpoint {
|
||||||
|
name: "nn.depthwise_conv2d_native_backprop_filter"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "DepthwiseConv2dNativeBackpropInput"
|
||||||
|
endpoint {
|
||||||
|
name: "nn.depthwise_conv2d_native_backprop_input"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "DrawBoundingBoxes"
|
||||||
|
endpoint {
|
||||||
|
name: "image.draw_bounding_boxes"
|
||||||
|
}
|
||||||
|
}
|
30
tensorflow/core/api_def/python_api/api_def_E.pbtxt
Normal file
30
tensorflow/core/api_def/python_api/api_def_E.pbtxt
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "Elu"
|
||||||
|
endpoint {
|
||||||
|
name: "nn.elu"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "EncodeJpeg"
|
||||||
|
endpoint {
|
||||||
|
name: "image.encode_jpeg"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "EncodePng"
|
||||||
|
endpoint {
|
||||||
|
name: "image.encode_png"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "ExtractGlimpse"
|
||||||
|
endpoint {
|
||||||
|
name: "image.extract_glimpse"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "ExtractJpegShape"
|
||||||
|
endpoint {
|
||||||
|
name: "image.extract_jpeg_shape"
|
||||||
|
}
|
||||||
|
}
|
21
tensorflow/core/api_def/python_api/api_def_F.pbtxt
Normal file
21
tensorflow/core/api_def/python_api/api_def_F.pbtxt
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "FFT"
|
||||||
|
endpoint {
|
||||||
|
name: "fft"
|
||||||
|
}
|
||||||
|
endpoint {
|
||||||
|
name: "spectral.fft"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "FractionalAvgPool"
|
||||||
|
endpoint {
|
||||||
|
name: "nn.fractional_avg_pool"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "FractionalMaxPool"
|
||||||
|
endpoint {
|
||||||
|
name: "nn.fractional_max_pool"
|
||||||
|
}
|
||||||
|
}
|
6
tensorflow/core/api_def/python_api/api_def_H.pbtxt
Normal file
6
tensorflow/core/api_def/python_api/api_def_H.pbtxt
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "HSVToRGB"
|
||||||
|
endpoint {
|
||||||
|
name: "image.hsv_to_rgb"
|
||||||
|
}
|
||||||
|
}
|
15
tensorflow/core/api_def/python_api/api_def_I.pbtxt
Normal file
15
tensorflow/core/api_def/python_api/api_def_I.pbtxt
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "IFFT"
|
||||||
|
endpoint {
|
||||||
|
name: "ifft"
|
||||||
|
}
|
||||||
|
endpoint {
|
||||||
|
name: "spectral.ifft"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "Invert"
|
||||||
|
endpoint {
|
||||||
|
name: "bitwise.invert"
|
||||||
|
}
|
||||||
|
}
|
24
tensorflow/core/api_def/python_api/api_def_L.pbtxt
Normal file
24
tensorflow/core/api_def/python_api/api_def_L.pbtxt
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "L2Loss"
|
||||||
|
endpoint {
|
||||||
|
name: "nn.l2_loss"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "LRN"
|
||||||
|
endpoint {
|
||||||
|
name: "nn.local_response_normalization"
|
||||||
|
}
|
||||||
|
endpoint {
|
||||||
|
name: "nn.lrn"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "LinSpace"
|
||||||
|
endpoint {
|
||||||
|
name: "lin_space"
|
||||||
|
}
|
||||||
|
endpoint {
|
||||||
|
name: "linspace"
|
||||||
|
}
|
||||||
|
}
|
78
tensorflow/core/api_def/python_api/api_def_M.pbtxt
Normal file
78
tensorflow/core/api_def/python_api/api_def_M.pbtxt
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "MatrixBandPart"
|
||||||
|
endpoint {
|
||||||
|
name: "linalg.band_part"
|
||||||
|
}
|
||||||
|
endpoint {
|
||||||
|
name: "matrix_band_part"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "MatrixDeterminant"
|
||||||
|
endpoint {
|
||||||
|
name: "linalg.det"
|
||||||
|
}
|
||||||
|
endpoint {
|
||||||
|
name: "matrix_determinant"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "MatrixDiag"
|
||||||
|
endpoint {
|
||||||
|
name: "linalg.diag"
|
||||||
|
}
|
||||||
|
endpoint {
|
||||||
|
name: "matrix_diag"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "MatrixDiagPart"
|
||||||
|
endpoint {
|
||||||
|
name: "linalg.diag_part"
|
||||||
|
}
|
||||||
|
endpoint {
|
||||||
|
name: "matrix_diag_part"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "MatrixInverse"
|
||||||
|
endpoint {
|
||||||
|
name: "linalg.inv"
|
||||||
|
}
|
||||||
|
endpoint {
|
||||||
|
name: "matrix_inverse"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "MatrixSetDiag"
|
||||||
|
endpoint {
|
||||||
|
name: "linalg.set_diag"
|
||||||
|
}
|
||||||
|
endpoint {
|
||||||
|
name: "matrix_set_diag"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "MatrixSolve"
|
||||||
|
endpoint {
|
||||||
|
name: "linalg.solve"
|
||||||
|
}
|
||||||
|
endpoint {
|
||||||
|
name: "matrix_solve"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "MatrixTriangularSolve"
|
||||||
|
endpoint {
|
||||||
|
name: "linalg.triangular_solve"
|
||||||
|
}
|
||||||
|
endpoint {
|
||||||
|
name: "matrix_triangular_solve"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "MaxPoolWithArgmax"
|
||||||
|
endpoint {
|
||||||
|
name: "nn.max_pool_with_argmax"
|
||||||
|
}
|
||||||
|
}
|
27
tensorflow/core/api_def/python_api/api_def_Q.pbtxt
Normal file
27
tensorflow/core/api_def/python_api/api_def_Q.pbtxt
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "Qr"
|
||||||
|
endpoint {
|
||||||
|
name: "linalg.qr"
|
||||||
|
}
|
||||||
|
endpoint {
|
||||||
|
name: "qr"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "QuantizedAvgPool"
|
||||||
|
endpoint {
|
||||||
|
name: "nn.quantized_avg_pool"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "QuantizedMaxPool"
|
||||||
|
endpoint {
|
||||||
|
name: "nn.quantized_max_pool"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "QuantizedReluX"
|
||||||
|
endpoint {
|
||||||
|
name: "nn.quantized_relu_x"
|
||||||
|
}
|
||||||
|
}
|
36
tensorflow/core/api_def/python_api/api_def_R.pbtxt
Normal file
36
tensorflow/core/api_def/python_api/api_def_R.pbtxt
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "RGBToHSV"
|
||||||
|
endpoint {
|
||||||
|
name: "image.rgb_to_hsv"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "Relu"
|
||||||
|
endpoint {
|
||||||
|
name: "nn.relu"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "ResizeArea"
|
||||||
|
endpoint {
|
||||||
|
name: "image.resize_area"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "ResizeBicubic"
|
||||||
|
endpoint {
|
||||||
|
name: "image.resize_bicubic"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "ResizeBilinear"
|
||||||
|
endpoint {
|
||||||
|
name: "image.resize_bilinear"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "ResizeNearestNeighbor"
|
||||||
|
endpoint {
|
||||||
|
name: "image.resize_nearest_neighbor"
|
||||||
|
}
|
||||||
|
}
|
36
tensorflow/core/api_def/python_api/api_def_S.pbtxt
Normal file
36
tensorflow/core/api_def/python_api/api_def_S.pbtxt
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "SdcaFprint"
|
||||||
|
endpoint {
|
||||||
|
name: "train.sdca_fprint"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "SdcaOptimizer"
|
||||||
|
endpoint {
|
||||||
|
name: "train.sdca_optimizer"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "SdcaShrinkL1"
|
||||||
|
endpoint {
|
||||||
|
name: "train.sdca_shrink_l1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "Selu"
|
||||||
|
endpoint {
|
||||||
|
name: "nn.selu"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "Softplus"
|
||||||
|
endpoint {
|
||||||
|
name: "nn.softplus"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
graph_op_name: "Softsign"
|
||||||
|
endpoint {
|
||||||
|
name: "nn.softsign"
|
||||||
|
}
|
||||||
|
}
|
@ -11,10 +11,15 @@ exports_files([
|
|||||||
"API_UPDATE_WARNING.txt",
|
"API_UPDATE_WARNING.txt",
|
||||||
])
|
])
|
||||||
|
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "api_compatibility_test",
|
name = "api_compatibility_test",
|
||||||
srcs = ["api_compatibility_test.py"],
|
srcs = ["api_compatibility_test.py"],
|
||||||
data = [
|
data = [
|
||||||
|
":convert_from_multiline",
|
||||||
|
"//tensorflow/core:base_api_def",
|
||||||
|
"//tensorflow/core:python_api_def",
|
||||||
"//tensorflow/tools/api/golden:api_golden",
|
"//tensorflow/tools/api/golden:api_golden",
|
||||||
"//tensorflow/tools/api/tests:API_UPDATE_WARNING.txt",
|
"//tensorflow/tools/api/tests:API_UPDATE_WARNING.txt",
|
||||||
"//tensorflow/tools/api/tests:README.txt",
|
"//tensorflow/tools/api/tests:README.txt",
|
||||||
@ -23,6 +28,7 @@ py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow:tensorflow_py",
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:lib",
|
"//tensorflow/python:lib",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/tools/api/lib:python_object_to_proto_visitor",
|
"//tensorflow/tools/api/lib:python_object_to_proto_visitor",
|
||||||
@ -31,6 +37,15 @@ py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_binary(
|
||||||
|
name = "convert_from_multiline",
|
||||||
|
srcs = ["convert_from_multiline.cc"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:op_gen_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "all_files",
|
name = "all_files",
|
||||||
srcs = glob(
|
srcs = glob(
|
||||||
|
@ -28,8 +28,11 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
from collections import defaultdict
|
||||||
|
from operator import attrgetter
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@ -37,6 +40,7 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from google.protobuf import text_format
|
from google.protobuf import text_format
|
||||||
|
|
||||||
|
from tensorflow.core.framework import api_def_pb2
|
||||||
from tensorflow.python.lib.io import file_io
|
from tensorflow.python.lib.io import file_io
|
||||||
from tensorflow.python.platform import resource_loader
|
from tensorflow.python.platform import resource_loader
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -64,6 +68,11 @@ _API_GOLDEN_FOLDER = 'tensorflow/tools/api/golden'
|
|||||||
_TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt'
|
_TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt'
|
||||||
_UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt'
|
_UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt'
|
||||||
|
|
||||||
|
_ALPHABET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
||||||
|
_CONVERT_FROM_MULTILINE_SCRIPT = 'tensorflow/tools/api/tests/convert_from_multiline'
|
||||||
|
_BASE_API_DIR = 'tensorflow/core/api_def/base_api'
|
||||||
|
_PYTHON_API_DIR = 'tensorflow/core/api_def/python_api'
|
||||||
|
|
||||||
|
|
||||||
def _KeyToFilePath(key):
|
def _KeyToFilePath(key):
|
||||||
"""From a given key, construct a filepath."""
|
"""From a given key, construct a filepath."""
|
||||||
@ -88,6 +97,30 @@ def _FileNameToKey(filename):
|
|||||||
return api_object_key
|
return api_object_key
|
||||||
|
|
||||||
|
|
||||||
|
def _GetSymbol(symbol_id):
|
||||||
|
"""Get TensorFlow symbol based on the given identifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol_id: Symbol identifier in the form module1.module2. ... .sym.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Symbol corresponding to the given id.
|
||||||
|
"""
|
||||||
|
# Ignore first module which should be tensorflow
|
||||||
|
symbol_id_split = symbol_id.split('.')[1:]
|
||||||
|
symbol = tf
|
||||||
|
for sym in symbol_id_split:
|
||||||
|
symbol = getattr(symbol, sym)
|
||||||
|
return symbol
|
||||||
|
|
||||||
|
|
||||||
|
def _IsGenModule(module_name):
|
||||||
|
if not module_name:
|
||||||
|
return False
|
||||||
|
module_name_split = module_name.split('.')
|
||||||
|
return module_name_split[-1].startswith('gen_')
|
||||||
|
|
||||||
|
|
||||||
class ApiCompatibilityTest(test.TestCase):
|
class ApiCompatibilityTest(test.TestCase):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
@ -229,6 +262,150 @@ class ApiCompatibilityTest(test.TestCase):
|
|||||||
update_goldens=FLAGS.update_goldens)
|
update_goldens=FLAGS.update_goldens)
|
||||||
|
|
||||||
|
|
||||||
|
class ApiDefTest(test.TestCase):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(ApiDefTest, self).__init__(*args, **kwargs)
|
||||||
|
self._first_cap_pattern = re.compile('(.)([A-Z][a-z]+)')
|
||||||
|
self._all_cap_pattern = re.compile('([a-z0-9])([A-Z])')
|
||||||
|
|
||||||
|
def _GenerateLowerCaseOpName(self, op_name):
|
||||||
|
lower_case_name = self._first_cap_pattern.sub(r'\1_\2', op_name)
|
||||||
|
return self._all_cap_pattern.sub(r'\1_\2', lower_case_name).lower()
|
||||||
|
|
||||||
|
def _CreatePythonApiDef(self, base_api_def, endpoint_names):
|
||||||
|
"""Creates Python ApiDef that overrides base_api_def if needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_api_def: (api_def_pb2.ApiDef) base ApiDef instance.
|
||||||
|
endpoint_names: List of Python endpoint names.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
api_def_pb2.ApiDef instance with overrides for base_api_def
|
||||||
|
if module.name endpoint is different from any existing
|
||||||
|
endpoints in base_api_def. Otherwise, returns None.
|
||||||
|
"""
|
||||||
|
endpoint_names_set = set(endpoint_names)
|
||||||
|
base_endpoint_names_set = {
|
||||||
|
self._GenerateLowerCaseOpName(endpoint.name)
|
||||||
|
for endpoint in base_api_def.endpoint}
|
||||||
|
|
||||||
|
if endpoint_names_set == base_endpoint_names_set:
|
||||||
|
return None # All endpoints are the same
|
||||||
|
|
||||||
|
api_def = api_def_pb2.ApiDef()
|
||||||
|
api_def.graph_op_name = base_api_def.graph_op_name
|
||||||
|
|
||||||
|
for endpoint_name in sorted(endpoint_names):
|
||||||
|
new_endpoint = api_def.endpoint.add()
|
||||||
|
new_endpoint.name = endpoint_name
|
||||||
|
|
||||||
|
return api_def
|
||||||
|
|
||||||
|
def _GetBaseApiMap(self):
|
||||||
|
"""Get a map from graph op name to its base ApiDef.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping graph op name to corresponding ApiDef.
|
||||||
|
"""
|
||||||
|
# Convert base ApiDef in Multiline format to Proto format.
|
||||||
|
converted_base_api_dir = os.path.join(
|
||||||
|
test.get_temp_dir(), 'temp_base_api_defs')
|
||||||
|
subprocess.check_call(
|
||||||
|
[os.path.join(resource_loader.get_root_dir_with_all_resources(),
|
||||||
|
_CONVERT_FROM_MULTILINE_SCRIPT),
|
||||||
|
_BASE_API_DIR, converted_base_api_dir])
|
||||||
|
|
||||||
|
name_to_base_api_def = {}
|
||||||
|
base_api_files = file_io.get_matching_files(
|
||||||
|
os.path.join(converted_base_api_dir, 'api_def_*.pbtxt'))
|
||||||
|
for base_api_file in base_api_files:
|
||||||
|
if file_io.file_exists(base_api_file):
|
||||||
|
api_defs = api_def_pb2.ApiDefs()
|
||||||
|
text_format.Merge(
|
||||||
|
file_io.read_file_to_string(base_api_file), api_defs)
|
||||||
|
for api_def in api_defs.op:
|
||||||
|
lower_case_name = self._GenerateLowerCaseOpName(api_def.graph_op_name)
|
||||||
|
name_to_base_api_def[lower_case_name] = api_def
|
||||||
|
return name_to_base_api_def
|
||||||
|
|
||||||
|
@unittest.skipUnless(
|
||||||
|
sys.version_info.major == 2 and os.uname()[0] == 'Linux',
|
||||||
|
'API compabitility test goldens are generated using python2 on Linux.')
|
||||||
|
def testAPIDefCompatibility(self):
|
||||||
|
# Get base ApiDef
|
||||||
|
name_to_base_api_def = self._GetBaseApiMap()
|
||||||
|
# Extract Python API
|
||||||
|
visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
|
||||||
|
public_api_visitor = public_api.PublicAPIVisitor(visitor)
|
||||||
|
public_api_visitor.do_not_descend_map['tf'].append('contrib')
|
||||||
|
traverse.traverse(tf, public_api_visitor)
|
||||||
|
proto_dict = visitor.GetProtos()
|
||||||
|
|
||||||
|
# Map from first character of op name to Python ApiDefs.
|
||||||
|
api_def_map = defaultdict(api_def_pb2.ApiDefs)
|
||||||
|
# We need to override all endpoints even if 1 endpoint differs from base
|
||||||
|
# ApiDef. So, we first create a map from an op to all its endpoints.
|
||||||
|
op_to_endpoint_name = defaultdict(list)
|
||||||
|
|
||||||
|
# Generate map from generated python op to endpoint names.
|
||||||
|
for public_module, value in proto_dict.items():
|
||||||
|
module_obj = _GetSymbol(public_module)
|
||||||
|
for sym in value.tf_module.member_method:
|
||||||
|
obj = getattr(module_obj, sym.name)
|
||||||
|
|
||||||
|
# Check if object is defined in gen_* module. That is,
|
||||||
|
# the object has been generated from OpDef.
|
||||||
|
if hasattr(obj, '__module__') and _IsGenModule(obj.__module__):
|
||||||
|
if obj.__name__ not in name_to_base_api_def:
|
||||||
|
# Symbol might be defined only in Python and not generated from
|
||||||
|
# C++ api.
|
||||||
|
continue
|
||||||
|
relative_public_module = public_module[len('tensorflow.'):]
|
||||||
|
full_name = (relative_public_module + '.' + sym.name
|
||||||
|
if relative_public_module else sym.name)
|
||||||
|
op_to_endpoint_name[obj].append(full_name)
|
||||||
|
|
||||||
|
# Generate Python ApiDef overrides.
|
||||||
|
for op, endpoint_names in op_to_endpoint_name.items():
|
||||||
|
api_def = self._CreatePythonApiDef(
|
||||||
|
name_to_base_api_def[op.__name__], endpoint_names)
|
||||||
|
if api_def:
|
||||||
|
api_defs = api_def_map[op.__name__[0].upper()]
|
||||||
|
api_defs.op.extend([api_def])
|
||||||
|
|
||||||
|
for key in _ALPHABET:
|
||||||
|
# Get new ApiDef for the given key.
|
||||||
|
new_api_defs_str = ''
|
||||||
|
if key in api_def_map:
|
||||||
|
new_api_defs = api_def_map[key]
|
||||||
|
new_api_defs.op.sort(key=attrgetter('graph_op_name'))
|
||||||
|
new_api_defs_str = str(new_api_defs)
|
||||||
|
|
||||||
|
# Get current ApiDef for the given key.
|
||||||
|
api_defs_file_path = os.path.join(
|
||||||
|
_PYTHON_API_DIR, 'api_def_%s.pbtxt' % key)
|
||||||
|
old_api_defs_str = ''
|
||||||
|
if file_io.file_exists(api_defs_file_path):
|
||||||
|
old_api_defs_str = file_io.read_file_to_string(api_defs_file_path)
|
||||||
|
|
||||||
|
if old_api_defs_str == new_api_defs_str:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if FLAGS.update_goldens:
|
||||||
|
if not new_api_defs_str:
|
||||||
|
logging.info('Deleting %s...' % api_defs_file_path)
|
||||||
|
file_io.delete_file(api_defs_file_path)
|
||||||
|
else:
|
||||||
|
logging.info('Updating %s...' % api_defs_file_path)
|
||||||
|
file_io.write_string_to_file(api_defs_file_path, new_api_defs_str)
|
||||||
|
else:
|
||||||
|
self.assertMultiLineEqual(
|
||||||
|
old_api_defs_str, new_api_defs_str,
|
||||||
|
'To update golden API files, run api_compatibility_test locally '
|
||||||
|
'with --update_goldens=True flag.')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
63
tensorflow/tools/api/tests/convert_from_multiline.cc
Normal file
63
tensorflow/tools/api/tests/convert_from_multiline.cc
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
// Converts all *.pbtxt files in a directory from Multiline to proto format.
|
||||||
|
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||||
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
constexpr char kApiDefFilePattern[] = "*.pbtxt";
|
||||||
|
|
||||||
|
Status ConvertFilesFromMultiline(const string& input_dir,
|
||||||
|
const string& output_dir) {
|
||||||
|
Env* env = Env::Default();
|
||||||
|
|
||||||
|
const string file_pattern = io::JoinPath(input_dir, kApiDefFilePattern);
|
||||||
|
std::vector<string> matching_paths;
|
||||||
|
TF_CHECK_OK(env->GetMatchingPaths(file_pattern, &matching_paths));
|
||||||
|
|
||||||
|
if (!env->IsDirectory(output_dir).ok()) {
|
||||||
|
TF_RETURN_IF_ERROR(env->CreateDir(output_dir));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto& path : matching_paths) {
|
||||||
|
string contents;
|
||||||
|
TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, path, &contents));
|
||||||
|
contents = tensorflow::PBTxtFromMultiline(contents);
|
||||||
|
string output_path = io::JoinPath(output_dir, io::Basename(path));
|
||||||
|
// Write contents to output_path
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
tensorflow::WriteStringToFile(env, output_path, contents));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||||
|
|
||||||
|
const std::string usage =
|
||||||
|
"Usage: convert_from_multiline input_dir output_dir";
|
||||||
|
if (argc != 3) {
|
||||||
|
std::cerr << usage << std::endl;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
TF_CHECK_OK(tensorflow::ConvertFilesFromMultiline(argv[1], argv[2]));
|
||||||
|
return 0;
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user