Add test for API generated with tf_export calls to api_compatibility_test.py.

Also, fix remaining differences between this new API and the current TensorFlow API.

PiperOrigin-RevId: 188943768
Parent: 86dd46a3c6
Commit: d57f0213bf
Changed files:

tensorflow/
    BUILD
    core/api_def/python_api/
        api_def_Assign.pbtxt
        api_def_AssignAdd.pbtxt
        api_def_AssignSub.pbtxt
        api_def_SparseReduceMax.pbtxt
        api_def_SparseReduceMaxSparse.pbtxt
        api_def_SparseReduceSum.pbtxt
        api_def_SparseReduceSumSparse.pbtxt
        api_def_SparseSlice.pbtxt
        api_def_SparseSoftmax.pbtxt
    experimental_api.py
    python/
        framework/
        keras/
        ops/
        platform/
    tools/api/
tensorflow/BUILD

@@ -830,3 +830,14 @@ py_library(
     visibility = ["//visibility:public"],
     deps = ["//tensorflow/python"],
 )
+
+py_library(
+    name = "experimental_tensorflow_py",
+    srcs = ["experimental_api.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow/tools/api/tests:__subpackages__"],
+    deps = [
+        "//tensorflow/python",
+        "//tensorflow/tools/api/generator:python_api",
+    ],
+)
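Note: the new experimental_tensorflow_py target packages experimental_api.py (added below), a variant of the top-level tensorflow module whose contents come from the tf_export-generated API rather than from the hand-maintained imports; its visibility limits use to the API tests. The compatibility test imports both surfaces side by side, so a hypothetical quick comparison would look like this (the real test below instead diffs protobuf descriptions of every symbol against golden files):

    import tensorflow as tf                         # current, hand-curated API
    from tensorflow import experimental_api as api  # new, generated API

    old_names = set(n for n in dir(tf) if not n.startswith('_'))
    new_names = set(n for n in dir(api) if not n.startswith('_'))
    print(sorted(old_names - new_names))  # symbols missing from the generated API
    print(sorted(new_names - old_names))  # symbols only in the generated API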
tensorflow/core/api_def/python_api/api_def_Assign.pbtxt (new file, +4)
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "Assign"
+  visibility: HIDDEN
+}
tensorflow/core/api_def/python_api/api_def_AssignAdd.pbtxt (new file, +4)

@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "AssignAdd"
+  visibility: HIDDEN
+}
tensorflow/core/api_def/python_api/api_def_AssignSub.pbtxt (new file, +4)

@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "AssignSub"
+  visibility: HIDDEN
+}
tensorflow/core/api_def/python_api/api_def_SparseReduceMax.pbtxt (new file, +4)

@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "SparseReduceMax"
+  visibility: HIDDEN
+}
tensorflow/core/api_def/python_api/api_def_SparseReduceMaxSparse.pbtxt (new file, +4)

@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "SparseReduceMaxSparse"
+  visibility: HIDDEN
+}
tensorflow/core/api_def/python_api/api_def_SparseReduceSum.pbtxt (new file, +4)

@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "SparseReduceSum"
+  visibility: HIDDEN
+}
tensorflow/core/api_def/python_api/api_def_SparseReduceSumSparse.pbtxt (new file, +4)

@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "SparseReduceSumSparse"
+  visibility: HIDDEN
+}
tensorflow/core/api_def/python_api/api_def_SparseSlice.pbtxt (new file, +4)

@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "SparseSlice"
+  visibility: HIDDEN
+}
tensorflow/core/api_def/python_api/api_def_SparseSoftmax.pbtxt (new file, +4)

@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "SparseSoftmax"
+  visibility: HIDDEN
+}
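Note: all nine api_def overrides set visibility: HIDDEN, which keeps the raw generated op out of the public tf.* namespace; the public symbol is instead a hand-written Python wrapper carrying a tf_export decorator (the manip.roll change below shows this exact pattern). A minimal sketch; the generated module and the exported name here are illustrative assumptions, not the commit's code:

    from tensorflow.python.ops import gen_sparse_ops  # generated op wrappers
    from tensorflow.python.util.tf_export import tf_export

    @tf_export('sparse_softmax')  # illustrative export name
    def sparse_softmax(sp_indices, sp_values, sp_shape, name=None):
      # Only this wrapper is public; the generated SparseSoftmax op stays hidden.
      return gen_sparse_ops.sparse_softmax(sp_indices, sp_values, sp_shape,
                                           name=name)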
tensorflow/experimental_api.py (new file, +38)
@@ -0,0 +1,38 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Bring in all of the public TensorFlow interface into this
+# module.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=g-bad-import-order
+from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
+# pylint: disable=wildcard-import
+from tensorflow.tools.api.generator.api import *  # pylint: disable=redefined-builtin
+# pylint: enable=wildcard-import
+
+from tensorflow.python.util.lazy_loader import LazyLoader
+contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
+del LazyLoader
+
+from tensorflow.python.platform import flags  # pylint: disable=g-import-not-at-top
+app.flags = flags  # pylint: disable=undefined-variable
+
+del absolute_import
+del division
+del print_function
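Note: contrib is bound through LazyLoader so that importing the module does not pay contrib's import cost (and avoids a circular import); the submodule is only loaded on first attribute access. A minimal sketch of the idea, assuming the real class in tensorflow.python.util.lazy_loader behaves along these lines:

    import importlib
    import types

    class LazyLoader(types.ModuleType):
      """Loads a module on first attribute access (simplified sketch)."""

      def __init__(self, local_name, parent_globals, import_path):
        self._local_name = local_name
        self._parent_globals = parent_globals
        self._import_path = import_path
        super(LazyLoader, self).__init__(import_path)

      def _load(self):
        module = importlib.import_module(self._import_path)
        # Replace ourselves in the parent namespace so later lookups are direct.
        self._parent_globals[self._local_name] = module
        return module

      def __getattr__(self, item):
        return getattr(self._load(), item)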
tensorflow/python/framework/dtypes.py

@@ -343,7 +343,9 @@ tf_export("uint8").export_constant(__name__, "uint8")
 uint16 = DType(types_pb2.DT_UINT16)
 tf_export("uint16").export_constant(__name__, "uint16")
 uint32 = DType(types_pb2.DT_UINT32)
+tf_export("uint32").export_constant(__name__, "uint32")
 uint64 = DType(types_pb2.DT_UINT64)
+tf_export("uint64").export_constant(__name__, "uint64")
 int16 = DType(types_pb2.DT_INT16)
 tf_export("int16").export_constant(__name__, "int16")
 int8 = DType(types_pb2.DT_INT8)
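Note: export_constant exists because module-level constants cannot carry attributes the way functions and classes can; the export is instead recorded on the defining module, under the _tf_api_constants attribute that create_python_api.py reads back below (_API_CONSTANTS_ATTR). A rough sketch of the mechanism, assuming tf_export works roughly like this:

    import sys

    class tf_export(object):
      """Simplified sketch of tensorflow.python.util.tf_export."""

      def __init__(self, *names):
        self._names = names

      def __call__(self, func):
        # Functions/classes: tag the object itself (read back as _tf_api_names).
        func._tf_api_names = self._names
        return func

      def export_constant(self, module_name, name):
        # Constants: record (api_names, module_attribute) pairs on the module.
        module = sys.modules[module_name]
        if not hasattr(module, '_tf_api_constants'):
          module._tf_api_constants = []
        module._tf_api_constants.append((self._names, name))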
tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py

@@ -24,8 +24,10 @@ import os
 import numpy as np

 from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.util.tf_export import tf_export


+@tf_export('keras.datasets.fashion_mnist.load_data')
 def load_data():
   """Loads the Fashion-MNIST dataset.
tensorflow/python/keras/_impl/keras/engine/input_layer.py

@@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.util.tf_export import tf_export


+@tf_export('keras.layers.InputLayer')
 class InputLayer(base_layer.Layer):
   """Layer to be used as an entry point into a Network (a graph of layers).
tensorflow/python/keras/datasets/fashion_mnist/__init__.py (new file, +25)

@@ -0,0 +1,25 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Fashion-MNIST dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.datasets.fashion_mnist import load_data
+
+del absolute_import
+del division
+del print_function
tensorflow/python/ops/image_ops_impl.py

@@ -1795,6 +1795,7 @@ _rgb_to_yiq_kernel = [[0.299, 0.59590059,
                       [0.114, -0.32134392, 0.31119955]]


+@tf_export('image.rgb_to_yiq')
 def rgb_to_yiq(images):
   """Converts one or more images from RGB to YIQ.

@@ -1820,6 +1821,7 @@ _yiq_to_rgb_kernel = [[1, 1, 1], [0.95598634, -0.27201283, -1.10674021],
                      [0.6208248, -0.64720424, 1.70423049]]


+@tf_export('image.yiq_to_rgb')
 def yiq_to_rgb(images):
   """Converts one or more images from YIQ to RGB.

@@ -1847,6 +1849,7 @@ _rgb_to_yuv_kernel = [[0.299, -0.14714119,
                       [0.114, 0.43601035, -0.10001026]]


+@tf_export('image.rgb_to_yuv')
 def rgb_to_yuv(images):
   """Converts one or more images from RGB to YUV.

@@ -1872,6 +1875,7 @@ _yuv_to_rgb_kernel = [[1, 1, 1], [0, -0.394642334, 2.03206185],
                      [1.13988303, -0.58062185, 0]]


+@tf_export('image.yuv_to_rgb')
 def yuv_to_rgb(images):
   """Converts one or more images from YUV to RGB.
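Note: each converter multiplies the trailing channel dimension of the input by the fixed 3x3 kernel defined just above it. Once exported, the functions live under tf.image; a small usage sketch (TF 1.x-era API assumed):

    import tensorflow as tf

    images = tf.random_uniform([2, 64, 64, 3])  # batch of RGB images in [0, 1]
    yiq = tf.image.rgb_to_yiq(images)
    rgb = tf.image.yiq_to_rgb(yiq)  # round-trips up to float error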
tensorflow/python/ops/initializers_ns.py

@@ -39,5 +39,8 @@ global_variables = _variables.global_variables_initializer
 local_variables = _variables.local_variables_initializer

 # Seal API.
+del absolute_import
+del division
+del print_function
 del init_ops
 del _variables
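Note: the added del statements matter because every remaining module attribute counts as public API surface; the goldens change to tensorflow.initializers.pbtxt below removes exactly the absolute_import/division/print_function members that these statements delete. In miniature:

    # Sketch: names left at module level show up in dir(module) and the goldens.
    from __future__ import print_function

    def constant():  # the module's real, public content
      return 0

    del print_function  # seal: keep the __future__ import out of the public API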
tensorflow/python/ops/linalg/linalg_impl.py

@@ -32,15 +32,18 @@ cholesky = linalg_ops.cholesky
 cholesky_solve = linalg_ops.cholesky_solve
 det = linalg_ops.matrix_determinant
 slogdet = gen_linalg_ops.log_matrix_determinant
+tf_export('linalg.slogdet')(slogdet)
 diag = array_ops.matrix_diag
 diag_part = array_ops.matrix_diag_part
 eigh = linalg_ops.self_adjoint_eig
 eigvalsh = linalg_ops.self_adjoint_eigvals
 einsum = special_math_ops.einsum
 expm = gen_linalg_ops.matrix_exponential
+tf_export('linalg.expm')(expm)
 eye = linalg_ops.eye
 inv = linalg_ops.matrix_inverse
 logm = gen_linalg_ops.matrix_logarithm
+tf_export('linalg.logm')(logm)
 lstsq = linalg_ops.matrix_solve_ls
 norm = linalg_ops.norm
 qr = linalg_ops.qr
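Note: slogdet, expm and logm are plain aliases of generated ops, so there is no def to attach a decorator to; tf_export is applied as an ordinary call, which has the identical effect. Sketch of the equivalence:

    from tensorflow.python.ops import gen_linalg_ops
    from tensorflow.python.util.tf_export import tf_export

    # Call form, as used above for an alias:
    expm = gen_linalg_ops.matrix_exponential
    tf_export('linalg.expm')(expm)

    # Decorator form, usable only at a definition site:
    #   @tf_export('linalg.expm')
    #   def expm(input, name=None): ...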
tensorflow/python/ops/manip_ops.py

@@ -23,9 +23,11 @@ from __future__ import print_function

 from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops
 from tensorflow.python.util.all_util import remove_undocumented
+from tensorflow.python.util.tf_export import tf_export


 # pylint: disable=protected-access
+@tf_export('manip.roll')
 def roll(input, shift, axis):  # pylint: disable=redefined-builtin
   return _gen_manip_ops.roll(input, shift, axis)
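Note: the roll wrapper exists only so a tf_export decorator can attach to a def (the generated _gen_manip_ops.roll cannot be decorated in place); arguments are forwarded unchanged. Usage, assuming the TF 1.x-era API:

    import tensorflow as tf

    x = tf.constant([0, 1, 2, 3, 4])
    shifted = tf.manip.roll(x, shift=2, axis=0)  # -> [3, 4, 0, 1, 2]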
tensorflow/python/ops/math_ops.py

@@ -1184,11 +1184,16 @@ def floordiv(x, y, name=None):


 realdiv = gen_math_ops.real_div
+tf_export("realdiv")(realdiv)
 truncatediv = gen_math_ops.truncate_div
+tf_export("truncatediv")(truncatediv)
 # TODO(aselle): Rename this to floordiv when we can.
 floor_div = gen_math_ops.floor_div
+tf_export("floor_div")(floor_div)
 truncatemod = gen_math_ops.truncate_mod
+tf_export("truncatemod")(truncatemod)
 floormod = gen_math_ops.floor_mod
+tf_export("floormod")(floormod)


 def _mul_dispatch(x, y, name=None):

@@ -2111,6 +2116,7 @@ def matmul(a,
 _OverrideBinaryOperatorHelper(matmul, "matmul")

 sparse_matmul = gen_math_ops.sparse_mat_mul
+tf_export("sparse_matmul")(sparse_matmul)


 @ops.RegisterStatistics("MatMul", "flops")
tensorflow/python/ops/nn_impl.py

@@ -303,12 +303,12 @@ def _swish_grad(features, grad):
 # @Defun decorator with noinline=True so that sigmoid(features) is re-computed
 # during backprop, and we can free the sigmoid(features) expression immediately
 # after use during the forward pass.
+@tf_export("nn.swish")
 @function.Defun(
     grad_func=_swish_grad,
     shape_func=_swish_shape,
     func_name="swish",
     noinline=True)
-@tf_export("nn.swish")
 def swish(features):
   # pylint: disable=g-doc-args
   """Computes the Swish activation function: `x * sigmoid(x)`.
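Note: decorator order is the whole change here. Decorators apply bottom-up, so the outermost one receives the final object; with tf_export outermost, the export metadata lands on the Defun-wrapped callable that users actually get. A stand-in demonstration:

    # Stand-ins only: tag plays tf_export, wrap plays function.Defun.
    def tag(f):
      f.exported = True
      return f

    def wrap(f):
      def inner(*args, **kwargs):
        return f(*args, **kwargs)
      return inner

    @tag   # outermost decorator runs last and tags the wrapped object
    @wrap
    def swish(x):
      return x

    assert swish.exported  # holds only with @tag above @wrap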
@@ -1343,4 +1343,4 @@ def sampled_softmax_loss(weights,
   sampled_losses = nn_ops.softmax_cross_entropy_with_logits(
       labels=labels, logits=logits)
   # sampled_losses is a [batch_size] tensor.
-  return sampled_losses
+  return sampled_losses
tensorflow/python/ops/variable_scope.py

@@ -1385,7 +1385,6 @@ get_variable.__doc__ = get_variable_or_local_docstring % (
     "GraphKeys.GLOBAL_VARIABLES")


-@functools.wraps(get_variable)
 @tf_export("get_local_variable")
 def get_local_variable(*args, **kwargs):
   kwargs["trainable"] = False
tensorflow/python/platform/googletest.py

@@ -36,6 +36,7 @@ from tensorflow.python.platform import benchmark
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
+from tensorflow.python.util.tf_export import tf_export


 Benchmark = benchmark.TensorFlowBenchmark  # pylint: disable=invalid-name

@@ -138,6 +139,7 @@ def StatefulSessionAvailable():
   return False


+@tf_export('test.StubOutForTesting')
 class StubOutForTesting(object):
   """Support class for stubbing methods out for unit testing.
tensorflow/python/platform/test.py

@@ -62,6 +62,8 @@ if sys.version_info.major == 2:
 else:
   from unittest import mock  # pylint: disable=g-import-not-at-top

+tf_export('test.mock')(mock)
+
 # Import Benchmark class
 Benchmark = _googletest.Benchmark  # pylint: disable=invalid-name

tensorflow/tools/api/generator/BUILD

@@ -1,5 +1,6 @@
 # Description:
 # Scripts used to generate TensorFlow Python API.
+
 licenses(["notice"])  # Apache 2.0

 exports_files(["LICENSE"])

@@ -21,7 +22,7 @@ py_binary(
     srcs = ["create_python_api.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow:tensorflow_py",
+        "//tensorflow/python",
     ],
 )

@@ -80,6 +81,7 @@ genrule(
     "api/keras/datasets/boston_housing/__init__.py",
     "api/keras/datasets/cifar10/__init__.py",
     "api/keras/datasets/cifar100/__init__.py",
+    "api/keras/datasets/fashion_mnist/__init__.py",
     "api/keras/datasets/imdb/__init__.py",
     "api/keras/datasets/mnist/__init__.py",
     "api/keras/datasets/reuters/__init__.py",

@@ -102,6 +104,7 @@ genrule(
     "api/linalg/__init__.py",
     "api/logging/__init__.py",
     "api/losses/__init__.py",
+    "api/manip/__init__.py",
     "api/metrics/__init__.py",
     "api/nn/__init__.py",
     "api/nn/rnn_cell/__init__.py",

@@ -133,7 +136,9 @@ py_library(
     name = "python_api",
     srcs = [":python_api_gen"],
     srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:__subpackages__"],
     deps = [
+        "//tensorflow/contrib:contrib_py",  # keep
         "//tensorflow/python",  # keep
     ],
 )
tensorflow/tools/api/generator/create_python_api.py

@@ -23,15 +23,14 @@ import collections
 import os
 import sys

 # This import is needed so that we can traverse over TensorFlow modules.
-import tensorflow as tf  # pylint: disable=unused-import
+from tensorflow import python as tf
 from tensorflow.python.util import tf_decorator

 _API_CONSTANTS_ATTR = '_tf_api_constants'
 _API_NAMES_ATTR = '_tf_api_names'
 _API_DIR = '/api/'
-_CONTRIB_IMPORT = 'from tensorflow import contrib'
 _OUTPUT_MODULE = 'tensorflow.tools.api.generator.api'
 _GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.

 This file is MACHINE GENERATED! Do not edit.

@@ -92,7 +91,7 @@ def get_api_imports():
     if module_contents_name == _API_CONSTANTS_ATTR:
       for exports, value in attr:
         for export in exports:
-          names = ['tf'] + export.split('.')
+          names = export.split('.')
           dest_module = '.'.join(names[:-1])
           import_str = format_import(module.__name__, value, names[-1])
           module_imports[dest_module].append(import_str)

@@ -104,29 +103,43 @@ def get_api_imports():
       if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
         # The same op might be accessible from multiple modules.
         # We only want to consider location where function was defined.
-        if attr.__module__ != module.__name__:
+        # Here we check if the op is defined in another TensorFlow module in
+        # sys.modules.
+        if (hasattr(attr, '__module__') and
+            attr.__module__.startswith(tf.__name__) and
+            attr.__module__ != module.__name__ and
+            attr.__module__ in sys.modules and
+            module_contents_name in dir(sys.modules[attr.__module__])):
          continue

         for export in attr._tf_api_names:  # pylint: disable=protected-access
-          names = ['tf'] + export.split('.')
+          names = export.split('.')
           dest_module = '.'.join(names[:-1])
           import_str = format_import(
               module.__name__, module_contents_name, names[-1])
           module_imports[dest_module].append(import_str)

   # Import all required modules in their parent modules.
-  # For e.g. if we import 'tf.foo.bar.Value'. Then, we also
-  # import 'bar' in 'tf.foo'.
-  dest_modules = set(module_imports.keys())
-  for dest_module in dest_modules:
-    dest_module_split = dest_module.split('.')
-    for dest_submodule_index in range(1, len(dest_module_split)):
-      dest_submodule = '.'.join(dest_module_split[:dest_submodule_index])
+  # For e.g. if we import 'foo.bar.Value'. Then, we also
+  # import 'bar' in 'foo'.
+  imported_modules = set(module_imports.keys())
+  for module in imported_modules:
+    if not module:
+      continue
+    module_split = module.split('.')
+    parent_module = ''  # we import submodules in their parent_module
+
+    for submodule_index in range(len(module_split)):
+      import_from = _OUTPUT_MODULE
+      if submodule_index > 0:
+        parent_module += ('.' + module_split[submodule_index-1] if parent_module
+                          else module_split[submodule_index-1])
+        import_from += '.' + parent_module
       submodule_import = format_import(
-          '', dest_module_split[dest_submodule_index],
-          dest_module_split[dest_submodule_index])
-      if submodule_import not in module_imports[dest_submodule]:
-        module_imports[dest_submodule].append(submodule_import)
+          import_from, module_split[submodule_index],
+          module_split[submodule_index])
+      if submodule_import not in module_imports[parent_module]:
+        module_imports[parent_module].append(submodule_import)

   return module_imports
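Note: the rewritten loop drops the hard-coded 'tf' prefix (module keys are now 'foo.bar' instead of 'tf.foo.bar') and registers every submodule as an import in its parent, so attribute chains like tf.keras.datasets.fashion_mnist resolve. A runnable sketch of that bookkeeping (the emitted import strings are illustrative; the real ones come from format_import):

    _OUTPUT_MODULE = 'tensorflow.tools.api.generator.api'

    def parent_imports(module):
      """For 'a.b.c', yield (parent_module, import line) pairs."""
      module_split = module.split('.')
      parent_module = ''
      result = []
      for i in range(len(module_split)):
        import_from = _OUTPUT_MODULE
        if i > 0:
          parent_module += (('.' + module_split[i - 1]) if parent_module
                            else module_split[i - 1])
          import_from += '.' + parent_module
        result.append((parent_module,
                       'from %s import %s' % (import_from, module_split[i])))
      return result

    # parent_imports('keras.datasets.fashion_mnist') gives, per parent module:
    #   ''               -> from tensorflow.tools.api.generator.api import keras
    #   'keras'          -> from tensorflow.tools.api.generator.api.keras import datasets
    #   'keras.datasets' -> from tensorflow.tools.api.generator.api.keras.datasets import fashion_mnist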
@@ -151,8 +164,8 @@ def create_api_files(output_files):
     # First get module directory under _API_DIR.
     module_dir = os.path.dirname(
         output_file[output_file.rfind(_API_DIR)+len(_API_DIR):])
-    # Convert / to . and prefix with tf.
-    module_name = '.'.join(['tf', module_dir.replace('/', '.')]).strip('.')
+    # Convert / to .
+    module_name = module_dir.replace('/', '.').strip('.')
     module_name_to_file_path[module_name] = output_file

   # Create file for each expected output in genrule.

@@ -162,16 +175,14 @@ def create_api_files(output_files):
     open(file_path, 'a').close()

   module_imports = get_api_imports()
-  module_imports['tf'].append(_CONTRIB_IMPORT)  # Include all of contrib.

   # Add imports to output files.
   missing_output_files = []
   for module, exports in module_imports.items():
     # Make sure genrule output file list is in sync with API exports.
     if module not in module_name_to_file_path:
-      module_without_tf = module[len('tf.'):]
       module_file_path = '"api/%s/__init__.py"' % (
-          module_without_tf.replace('.', '/'))
+          module.replace('.', '/'))
       missing_output_files.append(module_file_path)
       continue
     with open(module_name_to_file_path[module], 'w') as fp:
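Note: with the 'tf' prefix gone, a genrule output path maps directly to a dotted module name. A quick sketch of the mapping performed above:

    import os

    _API_DIR = '/api/'

    def module_name_for(output_file):
      """Sketch: '.../api/keras/datasets/__init__.py' -> 'keras.datasets'."""
      module_dir = os.path.dirname(
          output_file[output_file.rfind(_API_DIR) + len(_API_DIR):])
      return module_dir.replace('/', '.').strip('.')

    # module_name_for('genfiles/api/keras/datasets/__init__.py') == 'keras.datasets'
    # module_name_for('genfiles/api/__init__.py') == ''  (the root module)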
tensorflow/tools/api/golden/tensorflow.initializers.pbtxt

@@ -1,17 +1,9 @@
 path: "tensorflow.initializers"
 tf_module {
-  member {
-    name: "absolute_import"
-    mtype: "<type \'instance\'>"
-  }
   member {
     name: "constant"
     mtype: "<type \'type\'>"
   }
-  member {
-    name: "division"
-    mtype: "<type \'instance\'>"
-  }
   member {
     name: "identity"
     mtype: "<type \'type\'>"

@@ -24,10 +16,6 @@ tf_module {
     name: "orthogonal"
     mtype: "<type \'type\'>"
   }
-  member {
-    name: "print_function"
-    mtype: "<type \'instance\'>"
-  }
   member {
     name: "random_normal"
     mtype: "<type \'type\'>"
tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt

@@ -1,3 +1,7 @@
 path: "tensorflow.keras.datasets.fashion_mnist"
 tf_module {
+  member_method {
+    name: "load_data"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
 }
tensorflow/tools/api/tests/BUILD

@@ -23,6 +23,7 @@ py_test(
     ],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow:experimental_tensorflow_py",
         "//tensorflow:tensorflow_py",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:lib",
tensorflow/tools/api/tests/api_compatibility_test.py

@@ -34,6 +34,7 @@ import sys
 import unittest

 import tensorflow as tf
+from tensorflow import experimental_api as api

 from google.protobuf import text_format

@@ -46,6 +47,9 @@ from tensorflow.tools.api.lib import python_object_to_proto_visitor
 from tensorflow.tools.common import public_api
 from tensorflow.tools.common import traverse

+if hasattr(tf, 'experimental_api'):
+  del tf.experimental_api
+
 # FLAGS defined at the bottom:
 FLAGS = None
 # DEFINE_boolean, update_goldens, default False:

@@ -109,7 +113,8 @@ class ApiCompatibilityTest(test.TestCase):
                              expected_dict,
                              actual_dict,
                              verbose=False,
-                             update_goldens=False):
+                             update_goldens=False,
+                             additional_missing_object_message=''):
     """Diff given dicts of protobufs and report differences in a readable way.

     Args:

@@ -120,6 +125,8 @@ class ApiCompatibilityTest(test.TestCase):
       verbose: Whether to log the full diffs, or simply report which files were
         different.
       update_goldens: Whether to update goldens when there are diffs found.
+      additional_missing_object_message: Message to print when a symbol is
+        missing.
     """
     diffs = []
     verbose_diffs = []

@@ -138,7 +145,8 @@ class ApiCompatibilityTest(test.TestCase):
       verbose_diff_message = ''
       # First check if the key is not found in one or the other.
       if key in only_in_expected:
-        diff_message = 'Object %s expected but not found (removed).' % key
+        diff_message = 'Object %s expected but not found (removed). %s' % (
+            key, additional_missing_object_message)
         verbose_diff_message = diff_message
       elif key in only_in_actual:
         diff_message = 'New object %s found (added).' % key

@@ -229,6 +237,64 @@ class ApiCompatibilityTest(test.TestCase):
         verbose=FLAGS.verbose_diffs,
         update_goldens=FLAGS.update_goldens)

+  @unittest.skipUnless(
+      sys.version_info.major == 2,
+      'API compatibility test goldens are generated using python2.')
+  def testNewAPIBackwardsCompatibility(self):
+    # Extract all API stuff.
+    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
+
+    public_api_visitor = public_api.PublicAPIVisitor(visitor)
+    public_api_visitor.do_not_descend_map['tf'].append('contrib')
+    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
+    # TODO(annarev): these symbols have been added recently with tf_export
+    # decorators, but they are not exported with old API. Export them using
+    # old API approach and remove them from here.
+    public_api_visitor.private_map['tf'] = [
+        'to_complex128', 'to_complex64', 'add_to_collections',
+        'unsorted_segment_mean']
+    traverse.traverse(api, public_api_visitor)
+
+    proto_dict = visitor.GetProtos()
+
+    # Read all golden files.
+    expression = os.path.join(
+        resource_loader.get_root_dir_with_all_resources(),
+        _KeyToFilePath('*'))
+    golden_file_list = file_io.get_matching_files(expression)
+
+    def _ReadFileToProto(filename):
+      """Read a filename, create a protobuf from its contents."""
+      ret_val = api_objects_pb2.TFAPIObject()
+      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
+      return ret_val
+
+    golden_proto_dict = {
+        _FileNameToKey(filename): _ReadFileToProto(filename)
+        for filename in golden_file_list
+    }
+
+    # user_ops is an empty module. It is currently available in TensorFlow API
+    # but we don't keep empty modules in the new API.
+    # We delete user_ops from golden_proto_dict to make sure assert passes
+    # when diffing new API against goldens.
+    # TODO(annarev): remove user_ops from goldens once we switch to new API.
+    tf_module = golden_proto_dict['tensorflow'].tf_module
+    for i in range(len(tf_module.member)):
+      if tf_module.member[i].name == 'user_ops':
+        del tf_module.member[i]
+        break
+
+    # Diff them. Do not fail if called with update.
+    # If the test is run to update goldens, only report diffs but do not fail.
+    self._AssertProtoDictEquals(
+        golden_proto_dict,
+        proto_dict,
+        verbose=FLAGS.verbose_diffs,
+        update_goldens=False,
+        additional_missing_object_message=
+        'Check if tf_export decorator/call is missing for this symbol.')
+

 if __name__ == '__main__':
   parser = argparse.ArgumentParser()
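Note: the new test never writes goldens (update_goldens=False): goldens remain tied to the current API, and the tf_export-generated API is diffed against them, with the extra message pointing at a missing tf_export as the likely cause of a "removed" symbol. _KeyToFilePath and _FileNameToKey are defined earlier in this test file; a simplified sketch of what they map between, assuming goldens live in tensorflow/tools/api/golden (the real helpers may also normalize names):

    import os

    _API_GOLDEN_FOLDER = 'tensorflow/tools/api/golden'  # assumed location

    def _KeyToFilePath(key):
      # 'tensorflow.keras.datasets.fashion_mnist'
      #   -> 'tensorflow/tools/api/golden/tensorflow.keras.datasets.fashion_mnist.pbtxt'
      return os.path.join(_API_GOLDEN_FOLDER, '%s.pbtxt' % key)

    def _FileNameToKey(filename):
      # Inverse of _KeyToFilePath.
      return os.path.basename(filename)[:-len('.pbtxt')]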