Adding Python ApiDef overrides.
PiperOrigin-RevId: 172960496
This commit is contained in:
parent
0d6a2e3531
commit
93e8f3c67d
@ -3326,6 +3326,11 @@ filegroup(
|
||||
data = glob(["api_def/base_api/*"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "python_api_def",
|
||||
data = glob(["api_def/python_api/*"]),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "api_test",
|
||||
srcs = ["api_def/api_test.cc"],
|
||||
|
18
tensorflow/core/api_def/python_api/api_def_B.pbtxt
Normal file
18
tensorflow/core/api_def/python_api/api_def_B.pbtxt
Normal file
@ -0,0 +1,18 @@
|
||||
op {
|
||||
graph_op_name: "BitwiseAnd"
|
||||
endpoint {
|
||||
name: "bitwise.bitwise_and"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "BitwiseOr"
|
||||
endpoint {
|
||||
name: "bitwise.bitwise_or"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "BitwiseXor"
|
||||
endpoint {
|
||||
name: "bitwise.bitwise_xor"
|
||||
}
|
||||
}
|
15
tensorflow/core/api_def/python_api/api_def_C.pbtxt
Normal file
15
tensorflow/core/api_def/python_api/api_def_C.pbtxt
Normal file
@ -0,0 +1,15 @@
|
||||
op {
|
||||
graph_op_name: "Cholesky"
|
||||
endpoint {
|
||||
name: "cholesky"
|
||||
}
|
||||
endpoint {
|
||||
name: "linalg.cholesky"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "CropAndResize"
|
||||
endpoint {
|
||||
name: "image.crop_and_resize"
|
||||
}
|
||||
}
|
54
tensorflow/core/api_def/python_api/api_def_D.pbtxt
Normal file
54
tensorflow/core/api_def/python_api/api_def_D.pbtxt
Normal file
@ -0,0 +1,54 @@
|
||||
op {
|
||||
graph_op_name: "DecodeAndCropJpeg"
|
||||
endpoint {
|
||||
name: "image.decode_and_crop_jpeg"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "DecodeBmp"
|
||||
endpoint {
|
||||
name: "image.decode_bmp"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "DecodeGif"
|
||||
endpoint {
|
||||
name: "image.decode_gif"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "DecodeJpeg"
|
||||
endpoint {
|
||||
name: "image.decode_jpeg"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "DecodePng"
|
||||
endpoint {
|
||||
name: "image.decode_png"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "DepthwiseConv2dNative"
|
||||
endpoint {
|
||||
name: "nn.depthwise_conv2d_native"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "DepthwiseConv2dNativeBackpropFilter"
|
||||
endpoint {
|
||||
name: "nn.depthwise_conv2d_native_backprop_filter"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "DepthwiseConv2dNativeBackpropInput"
|
||||
endpoint {
|
||||
name: "nn.depthwise_conv2d_native_backprop_input"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "DrawBoundingBoxes"
|
||||
endpoint {
|
||||
name: "image.draw_bounding_boxes"
|
||||
}
|
||||
}
|
30
tensorflow/core/api_def/python_api/api_def_E.pbtxt
Normal file
30
tensorflow/core/api_def/python_api/api_def_E.pbtxt
Normal file
@ -0,0 +1,30 @@
|
||||
op {
|
||||
graph_op_name: "Elu"
|
||||
endpoint {
|
||||
name: "nn.elu"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "EncodeJpeg"
|
||||
endpoint {
|
||||
name: "image.encode_jpeg"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "EncodePng"
|
||||
endpoint {
|
||||
name: "image.encode_png"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ExtractGlimpse"
|
||||
endpoint {
|
||||
name: "image.extract_glimpse"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ExtractJpegShape"
|
||||
endpoint {
|
||||
name: "image.extract_jpeg_shape"
|
||||
}
|
||||
}
|
21
tensorflow/core/api_def/python_api/api_def_F.pbtxt
Normal file
21
tensorflow/core/api_def/python_api/api_def_F.pbtxt
Normal file
@ -0,0 +1,21 @@
|
||||
op {
|
||||
graph_op_name: "FFT"
|
||||
endpoint {
|
||||
name: "fft"
|
||||
}
|
||||
endpoint {
|
||||
name: "spectral.fft"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "FractionalAvgPool"
|
||||
endpoint {
|
||||
name: "nn.fractional_avg_pool"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "FractionalMaxPool"
|
||||
endpoint {
|
||||
name: "nn.fractional_max_pool"
|
||||
}
|
||||
}
|
6
tensorflow/core/api_def/python_api/api_def_H.pbtxt
Normal file
6
tensorflow/core/api_def/python_api/api_def_H.pbtxt
Normal file
@ -0,0 +1,6 @@
|
||||
op {
|
||||
graph_op_name: "HSVToRGB"
|
||||
endpoint {
|
||||
name: "image.hsv_to_rgb"
|
||||
}
|
||||
}
|
15
tensorflow/core/api_def/python_api/api_def_I.pbtxt
Normal file
15
tensorflow/core/api_def/python_api/api_def_I.pbtxt
Normal file
@ -0,0 +1,15 @@
|
||||
op {
|
||||
graph_op_name: "IFFT"
|
||||
endpoint {
|
||||
name: "ifft"
|
||||
}
|
||||
endpoint {
|
||||
name: "spectral.ifft"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Invert"
|
||||
endpoint {
|
||||
name: "bitwise.invert"
|
||||
}
|
||||
}
|
24
tensorflow/core/api_def/python_api/api_def_L.pbtxt
Normal file
24
tensorflow/core/api_def/python_api/api_def_L.pbtxt
Normal file
@ -0,0 +1,24 @@
|
||||
op {
|
||||
graph_op_name: "L2Loss"
|
||||
endpoint {
|
||||
name: "nn.l2_loss"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "LRN"
|
||||
endpoint {
|
||||
name: "nn.local_response_normalization"
|
||||
}
|
||||
endpoint {
|
||||
name: "nn.lrn"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "LinSpace"
|
||||
endpoint {
|
||||
name: "lin_space"
|
||||
}
|
||||
endpoint {
|
||||
name: "linspace"
|
||||
}
|
||||
}
|
78
tensorflow/core/api_def/python_api/api_def_M.pbtxt
Normal file
78
tensorflow/core/api_def/python_api/api_def_M.pbtxt
Normal file
@ -0,0 +1,78 @@
|
||||
op {
|
||||
graph_op_name: "MatrixBandPart"
|
||||
endpoint {
|
||||
name: "linalg.band_part"
|
||||
}
|
||||
endpoint {
|
||||
name: "matrix_band_part"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "MatrixDeterminant"
|
||||
endpoint {
|
||||
name: "linalg.det"
|
||||
}
|
||||
endpoint {
|
||||
name: "matrix_determinant"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "MatrixDiag"
|
||||
endpoint {
|
||||
name: "linalg.diag"
|
||||
}
|
||||
endpoint {
|
||||
name: "matrix_diag"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "MatrixDiagPart"
|
||||
endpoint {
|
||||
name: "linalg.diag_part"
|
||||
}
|
||||
endpoint {
|
||||
name: "matrix_diag_part"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "MatrixInverse"
|
||||
endpoint {
|
||||
name: "linalg.inv"
|
||||
}
|
||||
endpoint {
|
||||
name: "matrix_inverse"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "MatrixSetDiag"
|
||||
endpoint {
|
||||
name: "linalg.set_diag"
|
||||
}
|
||||
endpoint {
|
||||
name: "matrix_set_diag"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "MatrixSolve"
|
||||
endpoint {
|
||||
name: "linalg.solve"
|
||||
}
|
||||
endpoint {
|
||||
name: "matrix_solve"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "MatrixTriangularSolve"
|
||||
endpoint {
|
||||
name: "linalg.triangular_solve"
|
||||
}
|
||||
endpoint {
|
||||
name: "matrix_triangular_solve"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "MaxPoolWithArgmax"
|
||||
endpoint {
|
||||
name: "nn.max_pool_with_argmax"
|
||||
}
|
||||
}
|
27
tensorflow/core/api_def/python_api/api_def_Q.pbtxt
Normal file
27
tensorflow/core/api_def/python_api/api_def_Q.pbtxt
Normal file
@ -0,0 +1,27 @@
|
||||
op {
|
||||
graph_op_name: "Qr"
|
||||
endpoint {
|
||||
name: "linalg.qr"
|
||||
}
|
||||
endpoint {
|
||||
name: "qr"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "QuantizedAvgPool"
|
||||
endpoint {
|
||||
name: "nn.quantized_avg_pool"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "QuantizedMaxPool"
|
||||
endpoint {
|
||||
name: "nn.quantized_max_pool"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "QuantizedReluX"
|
||||
endpoint {
|
||||
name: "nn.quantized_relu_x"
|
||||
}
|
||||
}
|
36
tensorflow/core/api_def/python_api/api_def_R.pbtxt
Normal file
36
tensorflow/core/api_def/python_api/api_def_R.pbtxt
Normal file
@ -0,0 +1,36 @@
|
||||
op {
|
||||
graph_op_name: "RGBToHSV"
|
||||
endpoint {
|
||||
name: "image.rgb_to_hsv"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Relu"
|
||||
endpoint {
|
||||
name: "nn.relu"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ResizeArea"
|
||||
endpoint {
|
||||
name: "image.resize_area"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ResizeBicubic"
|
||||
endpoint {
|
||||
name: "image.resize_bicubic"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ResizeBilinear"
|
||||
endpoint {
|
||||
name: "image.resize_bilinear"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "ResizeNearestNeighbor"
|
||||
endpoint {
|
||||
name: "image.resize_nearest_neighbor"
|
||||
}
|
||||
}
|
36
tensorflow/core/api_def/python_api/api_def_S.pbtxt
Normal file
36
tensorflow/core/api_def/python_api/api_def_S.pbtxt
Normal file
@ -0,0 +1,36 @@
|
||||
op {
|
||||
graph_op_name: "SdcaFprint"
|
||||
endpoint {
|
||||
name: "train.sdca_fprint"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "SdcaOptimizer"
|
||||
endpoint {
|
||||
name: "train.sdca_optimizer"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "SdcaShrinkL1"
|
||||
endpoint {
|
||||
name: "train.sdca_shrink_l1"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Selu"
|
||||
endpoint {
|
||||
name: "nn.selu"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Softplus"
|
||||
endpoint {
|
||||
name: "nn.softplus"
|
||||
}
|
||||
}
|
||||
op {
|
||||
graph_op_name: "Softsign"
|
||||
endpoint {
|
||||
name: "nn.softsign"
|
||||
}
|
||||
}
|
@ -11,10 +11,15 @@ exports_files([
|
||||
"API_UPDATE_WARNING.txt",
|
||||
])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
||||
|
||||
py_test(
|
||||
name = "api_compatibility_test",
|
||||
srcs = ["api_compatibility_test.py"],
|
||||
data = [
|
||||
":convert_from_multiline",
|
||||
"//tensorflow/core:base_api_def",
|
||||
"//tensorflow/core:python_api_def",
|
||||
"//tensorflow/tools/api/golden:api_golden",
|
||||
"//tensorflow/tools/api/tests:API_UPDATE_WARNING.txt",
|
||||
"//tensorflow/tools/api/tests:README.txt",
|
||||
@ -23,6 +28,7 @@ py_test(
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/tools/api/lib:python_object_to_proto_visitor",
|
||||
@ -31,6 +37,15 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "convert_from_multiline",
|
||||
srcs = ["convert_from_multiline.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:op_gen_lib",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -28,8 +28,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
from operator import attrgetter
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
@ -37,6 +40,7 @@ import tensorflow as tf
|
||||
|
||||
from google.protobuf import text_format
|
||||
|
||||
from tensorflow.core.framework import api_def_pb2
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.platform import test
|
||||
@ -64,6 +68,11 @@ _API_GOLDEN_FOLDER = 'tensorflow/tools/api/golden'
|
||||
_TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt'
|
||||
_UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt'
|
||||
|
||||
_ALPHABET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
||||
_CONVERT_FROM_MULTILINE_SCRIPT = 'tensorflow/tools/api/tests/convert_from_multiline'
|
||||
_BASE_API_DIR = 'tensorflow/core/api_def/base_api'
|
||||
_PYTHON_API_DIR = 'tensorflow/core/api_def/python_api'
|
||||
|
||||
|
||||
def _KeyToFilePath(key):
|
||||
"""From a given key, construct a filepath."""
|
||||
@ -88,6 +97,30 @@ def _FileNameToKey(filename):
|
||||
return api_object_key
|
||||
|
||||
|
||||
def _GetSymbol(symbol_id):
|
||||
"""Get TensorFlow symbol based on the given identifier.
|
||||
|
||||
Args:
|
||||
symbol_id: Symbol identifier in the form module1.module2. ... .sym.
|
||||
|
||||
Returns:
|
||||
Symbol corresponding to the given id.
|
||||
"""
|
||||
# Ignore first module which should be tensorflow
|
||||
symbol_id_split = symbol_id.split('.')[1:]
|
||||
symbol = tf
|
||||
for sym in symbol_id_split:
|
||||
symbol = getattr(symbol, sym)
|
||||
return symbol
|
||||
|
||||
|
||||
def _IsGenModule(module_name):
|
||||
if not module_name:
|
||||
return False
|
||||
module_name_split = module_name.split('.')
|
||||
return module_name_split[-1].startswith('gen_')
|
||||
|
||||
|
||||
class ApiCompatibilityTest(test.TestCase):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@ -229,6 +262,150 @@ class ApiCompatibilityTest(test.TestCase):
|
||||
update_goldens=FLAGS.update_goldens)
|
||||
|
||||
|
||||
class ApiDefTest(test.TestCase):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ApiDefTest, self).__init__(*args, **kwargs)
|
||||
self._first_cap_pattern = re.compile('(.)([A-Z][a-z]+)')
|
||||
self._all_cap_pattern = re.compile('([a-z0-9])([A-Z])')
|
||||
|
||||
def _GenerateLowerCaseOpName(self, op_name):
|
||||
lower_case_name = self._first_cap_pattern.sub(r'\1_\2', op_name)
|
||||
return self._all_cap_pattern.sub(r'\1_\2', lower_case_name).lower()
|
||||
|
||||
def _CreatePythonApiDef(self, base_api_def, endpoint_names):
|
||||
"""Creates Python ApiDef that overrides base_api_def if needed.
|
||||
|
||||
Args:
|
||||
base_api_def: (api_def_pb2.ApiDef) base ApiDef instance.
|
||||
endpoint_names: List of Python endpoint names.
|
||||
|
||||
Returns:
|
||||
api_def_pb2.ApiDef instance with overrides for base_api_def
|
||||
if module.name endpoint is different from any existing
|
||||
endpoints in base_api_def. Otherwise, returns None.
|
||||
"""
|
||||
endpoint_names_set = set(endpoint_names)
|
||||
base_endpoint_names_set = {
|
||||
self._GenerateLowerCaseOpName(endpoint.name)
|
||||
for endpoint in base_api_def.endpoint}
|
||||
|
||||
if endpoint_names_set == base_endpoint_names_set:
|
||||
return None # All endpoints are the same
|
||||
|
||||
api_def = api_def_pb2.ApiDef()
|
||||
api_def.graph_op_name = base_api_def.graph_op_name
|
||||
|
||||
for endpoint_name in sorted(endpoint_names):
|
||||
new_endpoint = api_def.endpoint.add()
|
||||
new_endpoint.name = endpoint_name
|
||||
|
||||
return api_def
|
||||
|
||||
def _GetBaseApiMap(self):
|
||||
"""Get a map from graph op name to its base ApiDef.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping graph op name to corresponding ApiDef.
|
||||
"""
|
||||
# Convert base ApiDef in Multiline format to Proto format.
|
||||
converted_base_api_dir = os.path.join(
|
||||
test.get_temp_dir(), 'temp_base_api_defs')
|
||||
subprocess.check_call(
|
||||
[os.path.join(resource_loader.get_root_dir_with_all_resources(),
|
||||
_CONVERT_FROM_MULTILINE_SCRIPT),
|
||||
_BASE_API_DIR, converted_base_api_dir])
|
||||
|
||||
name_to_base_api_def = {}
|
||||
base_api_files = file_io.get_matching_files(
|
||||
os.path.join(converted_base_api_dir, 'api_def_*.pbtxt'))
|
||||
for base_api_file in base_api_files:
|
||||
if file_io.file_exists(base_api_file):
|
||||
api_defs = api_def_pb2.ApiDefs()
|
||||
text_format.Merge(
|
||||
file_io.read_file_to_string(base_api_file), api_defs)
|
||||
for api_def in api_defs.op:
|
||||
lower_case_name = self._GenerateLowerCaseOpName(api_def.graph_op_name)
|
||||
name_to_base_api_def[lower_case_name] = api_def
|
||||
return name_to_base_api_def
|
||||
|
||||
@unittest.skipUnless(
|
||||
sys.version_info.major == 2 and os.uname()[0] == 'Linux',
|
||||
'API compabitility test goldens are generated using python2 on Linux.')
|
||||
def testAPIDefCompatibility(self):
|
||||
# Get base ApiDef
|
||||
name_to_base_api_def = self._GetBaseApiMap()
|
||||
# Extract Python API
|
||||
visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
|
||||
public_api_visitor = public_api.PublicAPIVisitor(visitor)
|
||||
public_api_visitor.do_not_descend_map['tf'].append('contrib')
|
||||
traverse.traverse(tf, public_api_visitor)
|
||||
proto_dict = visitor.GetProtos()
|
||||
|
||||
# Map from first character of op name to Python ApiDefs.
|
||||
api_def_map = defaultdict(api_def_pb2.ApiDefs)
|
||||
# We need to override all endpoints even if 1 endpoint differs from base
|
||||
# ApiDef. So, we first create a map from an op to all its endpoints.
|
||||
op_to_endpoint_name = defaultdict(list)
|
||||
|
||||
# Generate map from generated python op to endpoint names.
|
||||
for public_module, value in proto_dict.items():
|
||||
module_obj = _GetSymbol(public_module)
|
||||
for sym in value.tf_module.member_method:
|
||||
obj = getattr(module_obj, sym.name)
|
||||
|
||||
# Check if object is defined in gen_* module. That is,
|
||||
# the object has been generated from OpDef.
|
||||
if hasattr(obj, '__module__') and _IsGenModule(obj.__module__):
|
||||
if obj.__name__ not in name_to_base_api_def:
|
||||
# Symbol might be defined only in Python and not generated from
|
||||
# C++ api.
|
||||
continue
|
||||
relative_public_module = public_module[len('tensorflow.'):]
|
||||
full_name = (relative_public_module + '.' + sym.name
|
||||
if relative_public_module else sym.name)
|
||||
op_to_endpoint_name[obj].append(full_name)
|
||||
|
||||
# Generate Python ApiDef overrides.
|
||||
for op, endpoint_names in op_to_endpoint_name.items():
|
||||
api_def = self._CreatePythonApiDef(
|
||||
name_to_base_api_def[op.__name__], endpoint_names)
|
||||
if api_def:
|
||||
api_defs = api_def_map[op.__name__[0].upper()]
|
||||
api_defs.op.extend([api_def])
|
||||
|
||||
for key in _ALPHABET:
|
||||
# Get new ApiDef for the given key.
|
||||
new_api_defs_str = ''
|
||||
if key in api_def_map:
|
||||
new_api_defs = api_def_map[key]
|
||||
new_api_defs.op.sort(key=attrgetter('graph_op_name'))
|
||||
new_api_defs_str = str(new_api_defs)
|
||||
|
||||
# Get current ApiDef for the given key.
|
||||
api_defs_file_path = os.path.join(
|
||||
_PYTHON_API_DIR, 'api_def_%s.pbtxt' % key)
|
||||
old_api_defs_str = ''
|
||||
if file_io.file_exists(api_defs_file_path):
|
||||
old_api_defs_str = file_io.read_file_to_string(api_defs_file_path)
|
||||
|
||||
if old_api_defs_str == new_api_defs_str:
|
||||
continue
|
||||
|
||||
if FLAGS.update_goldens:
|
||||
if not new_api_defs_str:
|
||||
logging.info('Deleting %s...' % api_defs_file_path)
|
||||
file_io.delete_file(api_defs_file_path)
|
||||
else:
|
||||
logging.info('Updating %s...' % api_defs_file_path)
|
||||
file_io.write_string_to_file(api_defs_file_path, new_api_defs_str)
|
||||
else:
|
||||
self.assertMultiLineEqual(
|
||||
old_api_defs_str, new_api_defs_str,
|
||||
'To update golden API files, run api_compatibility_test locally '
|
||||
'with --update_goldens=True flag.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
|
63
tensorflow/tools/api/tests/convert_from_multiline.cc
Normal file
63
tensorflow/tools/api/tests/convert_from_multiline.cc
Normal file
@ -0,0 +1,63 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// Converts all *.pbtxt files in a directory from Multiline to proto format.
|
||||
#include "tensorflow/core/framework/op_gen_lib.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
constexpr char kApiDefFilePattern[] = "*.pbtxt";
|
||||
|
||||
Status ConvertFilesFromMultiline(const string& input_dir,
|
||||
const string& output_dir) {
|
||||
Env* env = Env::Default();
|
||||
|
||||
const string file_pattern = io::JoinPath(input_dir, kApiDefFilePattern);
|
||||
std::vector<string> matching_paths;
|
||||
TF_CHECK_OK(env->GetMatchingPaths(file_pattern, &matching_paths));
|
||||
|
||||
if (!env->IsDirectory(output_dir).ok()) {
|
||||
TF_RETURN_IF_ERROR(env->CreateDir(output_dir));
|
||||
}
|
||||
|
||||
for (const auto& path : matching_paths) {
|
||||
string contents;
|
||||
TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, path, &contents));
|
||||
contents = tensorflow::PBTxtFromMultiline(contents);
|
||||
string output_path = io::JoinPath(output_dir, io::Basename(path));
|
||||
// Write contents to output_path
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::WriteStringToFile(env, output_path, contents));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
tensorflow::port::InitMain(argv[0], &argc, &argv);
|
||||
|
||||
const std::string usage =
|
||||
"Usage: convert_from_multiline input_dir output_dir";
|
||||
if (argc != 3) {
|
||||
std::cerr << usage << std::endl;
|
||||
return -1;
|
||||
}
|
||||
TF_CHECK_OK(tensorflow::ConvertFilesFromMultiline(argv[1], argv[2]));
|
||||
return 0;
|
||||
}
|
Loading…
Reference in New Issue
Block a user