Add test for API generated with tf_export calls to api_compatibility_test.py.

Also, fix remaining differences between this new API and the current TensorFlow
API.

PiperOrigin-RevId: 188943768
This commit is contained in:
Anna R 2018-03-13 15:51:36 -07:00 committed by TensorFlower Gardener
parent 86dd46a3c6
commit d57f0213bf
30 changed files with 251 additions and 40 deletions

View File

@ -830,3 +830,14 @@ py_library(
visibility = ["//visibility:public"],
deps = ["//tensorflow/python"],
)
py_library(
name = "experimental_tensorflow_py",
srcs = ["experimental_api.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow/tools/api/tests:__subpackages__"],
deps = [
"//tensorflow/python",
"//tensorflow/tools/api/generator:python_api",
],
)

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "Assign"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "AssignAdd"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "AssignSub"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "SparseReduceMax"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "SparseReduceMaxSparse"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "SparseReduceSum"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "SparseReduceSumSparse"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "SparseSlice"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "SparseSoftmax"
visibility: HIDDEN
}

View File

@ -0,0 +1,38 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Bring in all of the public TensorFlow interface into this
# module.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
# pylint: disable=wildcard-import
from tensorflow.tools.api.generator.api import * # pylint: disable=redefined-builtin
# pylint: enable=wildcard-import
from tensorflow.python.util.lazy_loader import LazyLoader
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
app.flags = flags # pylint: disable=undefined-variable
del absolute_import
del division
del print_function

View File

@ -343,7 +343,9 @@ tf_export("uint8").export_constant(__name__, "uint8")
uint16 = DType(types_pb2.DT_UINT16)
tf_export("uint16").export_constant(__name__, "uint16")
uint32 = DType(types_pb2.DT_UINT32)
tf_export("uint32").export_constant(__name__, "uint32")
uint64 = DType(types_pb2.DT_UINT64)
tf_export("uint64").export_constant(__name__, "uint32")
int16 = DType(types_pb2.DT_INT16)
tf_export("int16").export_constant(__name__, "int16")
int8 = DType(types_pb2.DT_INT8)

View File

@ -24,8 +24,10 @@ import os
import numpy as np
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
from tensorflow.python.util.tf_export import tf_export
@tf_export('keras.datasets.fashion_mnist.load_data')
def load_data():
"""Loads the Fashion-MNIST dataset.

View File

@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.util.tf_export import tf_export
@tf_export('keras.layers.InputLayer')
class InputLayer(base_layer.Layer):
"""Layer to be used as an entry point into a Network (a graph of layers).

View File

@ -0,0 +1,25 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fashion-MNIST dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras._impl.keras.datasets.fashion_mnist import load_data
del absolute_import
del division
del print_function

View File

@ -1795,6 +1795,7 @@ _rgb_to_yiq_kernel = [[0.299, 0.59590059,
[0.114, -0.32134392, 0.31119955]]
@tf_export('image.rgb_to_yiq')
def rgb_to_yiq(images):
"""Converts one or more images from RGB to YIQ.
@ -1820,6 +1821,7 @@ _yiq_to_rgb_kernel = [[1, 1, 1], [0.95598634, -0.27201283, -1.10674021],
[0.6208248, -0.64720424, 1.70423049]]
@tf_export('image.yiq_to_rgb')
def yiq_to_rgb(images):
"""Converts one or more images from YIQ to RGB.
@ -1847,6 +1849,7 @@ _rgb_to_yuv_kernel = [[0.299, -0.14714119,
[0.114, 0.43601035, -0.10001026]]
@tf_export('image.rgb_to_yuv')
def rgb_to_yuv(images):
"""Converts one or more images from RGB to YUV.
@ -1872,6 +1875,7 @@ _yuv_to_rgb_kernel = [[1, 1, 1], [0, -0.394642334, 2.03206185],
[1.13988303, -0.58062185, 0]]
@tf_export('image.yuv_to_rgb')
def yuv_to_rgb(images):
"""Converts one or more images from YUV to RGB.

View File

@ -39,5 +39,8 @@ global_variables = _variables.global_variables_initializer
local_variables = _variables.local_variables_initializer
# Seal API.
del absolute_import
del division
del print_function
del init_ops
del _variables

View File

@ -32,15 +32,18 @@ cholesky = linalg_ops.cholesky
cholesky_solve = linalg_ops.cholesky_solve
det = linalg_ops.matrix_determinant
slogdet = gen_linalg_ops.log_matrix_determinant
tf_export('linalg.slogdet')(slogdet)
diag = array_ops.matrix_diag
diag_part = array_ops.matrix_diag_part
eigh = linalg_ops.self_adjoint_eig
eigvalsh = linalg_ops.self_adjoint_eigvals
einsum = special_math_ops.einsum
expm = gen_linalg_ops.matrix_exponential
tf_export('linalg.expm')(expm)
eye = linalg_ops.eye
inv = linalg_ops.matrix_inverse
logm = gen_linalg_ops.matrix_logarithm
tf_export('linalg.logm')(logm)
lstsq = linalg_ops.matrix_solve_ls
norm = linalg_ops.norm
qr = linalg_ops.qr

View File

@ -23,9 +23,11 @@ from __future__ import print_function
from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops
from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access
@tf_export('manip.roll')
def roll(input, shift, axis): # pylint: disable=redefined-builtin
return _gen_manip_ops.roll(input, shift, axis)

View File

@ -1184,11 +1184,16 @@ def floordiv(x, y, name=None):
realdiv = gen_math_ops.real_div
tf_export("realdiv")(realdiv)
truncatediv = gen_math_ops.truncate_div
tf_export("truncatediv")(truncatediv)
# TODO(aselle): Rename this to floordiv when we can.
floor_div = gen_math_ops.floor_div
tf_export("floor_div")(floor_div)
truncatemod = gen_math_ops.truncate_mod
tf_export("truncatemod")(truncatemod)
floormod = gen_math_ops.floor_mod
tf_export("floormod")(floormod)
def _mul_dispatch(x, y, name=None):
@ -2111,6 +2116,7 @@ def matmul(a,
_OverrideBinaryOperatorHelper(matmul, "matmul")
sparse_matmul = gen_math_ops.sparse_mat_mul
tf_export("sparse_matmul")(sparse_matmul)
@ops.RegisterStatistics("MatMul", "flops")

View File

@ -303,12 +303,12 @@ def _swish_grad(features, grad):
# @Defun decorator with noinline=True so that sigmoid(features) is re-computed
# during backprop, and we can free the sigmoid(features) expression immediately
# after use during the forward pass.
@tf_export("nn.swish")
@function.Defun(
grad_func=_swish_grad,
shape_func=_swish_shape,
func_name="swish",
noinline=True)
@tf_export("nn.swish")
def swish(features):
# pylint: disable=g-doc-args
"""Computes the Swish activation function: `x * sigmoid(x)`.
@ -1343,4 +1343,4 @@ def sampled_softmax_loss(weights,
sampled_losses = nn_ops.softmax_cross_entropy_with_logits(
labels=labels, logits=logits)
# sampled_losses is a [batch_size] tensor.
return sampled_losses
return sampled_losses

View File

@ -1385,7 +1385,6 @@ get_variable.__doc__ = get_variable_or_local_docstring % (
"GraphKeys.GLOBAL_VARIABLES")
@functools.wraps(get_variable)
@tf_export("get_local_variable")
def get_local_variable(*args, **kwargs):
kwargs["trainable"] = False

View File

@ -36,6 +36,7 @@ from tensorflow.python.platform import benchmark
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
Benchmark = benchmark.TensorFlowBenchmark # pylint: disable=invalid-name
@ -138,6 +139,7 @@ def StatefulSessionAvailable():
return False
@tf_export('test.StubOutForTesting')
class StubOutForTesting(object):
"""Support class for stubbing methods out for unit testing.

View File

@ -62,6 +62,8 @@ if sys.version_info.major == 2:
else:
from unittest import mock # pylint: disable=g-import-not-at-top
tf_export('test.mock')(mock)
# Import Benchmark class
Benchmark = _googletest.Benchmark # pylint: disable=invalid-name

View File

@ -1,5 +1,6 @@
# Description:
# Scripts used to generate TensorFlow Python API.
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
@ -21,7 +22,7 @@ py_binary(
srcs = ["create_python_api.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python",
],
)
@ -80,6 +81,7 @@ genrule(
"api/keras/datasets/boston_housing/__init__.py",
"api/keras/datasets/cifar10/__init__.py",
"api/keras/datasets/cifar100/__init__.py",
"api/keras/datasets/fashion_mnist/__init__.py",
"api/keras/datasets/imdb/__init__.py",
"api/keras/datasets/mnist/__init__.py",
"api/keras/datasets/reuters/__init__.py",
@ -102,6 +104,7 @@ genrule(
"api/linalg/__init__.py",
"api/logging/__init__.py",
"api/losses/__init__.py",
"api/manip/__init__.py",
"api/metrics/__init__.py",
"api/nn/__init__.py",
"api/nn/rnn_cell/__init__.py",
@ -133,7 +136,9 @@ py_library(
name = "python_api",
srcs = [":python_api_gen"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
"//tensorflow/contrib:contrib_py", # keep
"//tensorflow/python", # keep
],
)

View File

@ -23,15 +23,14 @@ import collections
import os
import sys
# This import is needed so that we can traverse over TensorFlow modules.
import tensorflow as tf # pylint: disable=unused-import
from tensorflow import python as tf
from tensorflow.python.util import tf_decorator
_API_CONSTANTS_ATTR = '_tf_api_constants'
_API_NAMES_ATTR = '_tf_api_names'
_API_DIR = '/api/'
_CONTRIB_IMPORT = 'from tensorflow import contrib'
_OUTPUT_MODULE = 'tensorflow.tools.api.generator.api'
_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.
This file is MACHINE GENERATED! Do not edit.
@ -92,7 +91,7 @@ def get_api_imports():
if module_contents_name == _API_CONSTANTS_ATTR:
for exports, value in attr:
for export in exports:
names = ['tf'] + export.split('.')
names = export.split('.')
dest_module = '.'.join(names[:-1])
import_str = format_import(module.__name__, value, names[-1])
module_imports[dest_module].append(import_str)
@ -104,29 +103,43 @@ def get_api_imports():
if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
# The same op might be accessible from multiple modules.
# We only want to consider location where function was defined.
if attr.__module__ != module.__name__:
# Here we check if the op is defined in another TensorFlow module in
# sys.modules.
if (hasattr(attr, '__module__') and
attr.__module__.startswith(tf.__name__) and
attr.__module__ != module.__name__ and
attr.__module__ in sys.modules and
module_contents_name in dir(sys.modules[attr.__module__])):
continue
for export in attr._tf_api_names: # pylint: disable=protected-access
names = ['tf'] + export.split('.')
names = export.split('.')
dest_module = '.'.join(names[:-1])
import_str = format_import(
module.__name__, module_contents_name, names[-1])
module_imports[dest_module].append(import_str)
# Import all required modules in their parent modules.
# For e.g. if we import 'tf.foo.bar.Value'. Then, we also
# import 'bar' in 'tf.foo'.
dest_modules = set(module_imports.keys())
for dest_module in dest_modules:
dest_module_split = dest_module.split('.')
for dest_submodule_index in range(1, len(dest_module_split)):
dest_submodule = '.'.join(dest_module_split[:dest_submodule_index])
# For e.g. if we import 'foo.bar.Value'. Then, we also
# import 'bar' in 'foo'.
imported_modules = set(module_imports.keys())
for module in imported_modules:
if not module:
continue
module_split = module.split('.')
parent_module = '' # we import submodules in their parent_module
for submodule_index in range(len(module_split)):
import_from = _OUTPUT_MODULE
if submodule_index > 0:
parent_module += ('.' + module_split[submodule_index-1] if parent_module
else module_split[submodule_index-1])
import_from += '.' + parent_module
submodule_import = format_import(
'', dest_module_split[dest_submodule_index],
dest_module_split[dest_submodule_index])
if submodule_import not in module_imports[dest_submodule]:
module_imports[dest_submodule].append(submodule_import)
import_from, module_split[submodule_index],
module_split[submodule_index])
if submodule_import not in module_imports[parent_module]:
module_imports[parent_module].append(submodule_import)
return module_imports
@ -151,8 +164,8 @@ def create_api_files(output_files):
# First get module directory under _API_DIR.
module_dir = os.path.dirname(
output_file[output_file.rfind(_API_DIR)+len(_API_DIR):])
# Convert / to . and prefix with tf.
module_name = '.'.join(['tf', module_dir.replace('/', '.')]).strip('.')
# Convert / to .
module_name = module_dir.replace('/', '.').strip('.')
module_name_to_file_path[module_name] = output_file
# Create file for each expected output in genrule.
@ -162,16 +175,14 @@ def create_api_files(output_files):
open(file_path, 'a').close()
module_imports = get_api_imports()
module_imports['tf'].append(_CONTRIB_IMPORT) # Include all of contrib.
# Add imports to output files.
missing_output_files = []
for module, exports in module_imports.items():
# Make sure genrule output file list is in sync with API exports.
if module not in module_name_to_file_path:
module_without_tf = module[len('tf.'):]
module_file_path = '"api/%s/__init__.py"' % (
module_without_tf.replace('.', '/'))
module.replace('.', '/'))
missing_output_files.append(module_file_path)
continue
with open(module_name_to_file_path[module], 'w') as fp:

View File

@ -1,17 +1,9 @@
path: "tensorflow.initializers"
tf_module {
member {
name: "absolute_import"
mtype: "<type \'instance\'>"
}
member {
name: "constant"
mtype: "<type \'type\'>"
}
member {
name: "division"
mtype: "<type \'instance\'>"
}
member {
name: "identity"
mtype: "<type \'type\'>"
@ -24,10 +16,6 @@ tf_module {
name: "orthogonal"
mtype: "<type \'type\'>"
}
member {
name: "print_function"
mtype: "<type \'instance\'>"
}
member {
name: "random_normal"
mtype: "<type \'type\'>"

View File

@ -1,3 +1,7 @@
path: "tensorflow.keras.datasets.fashion_mnist"
tf_module {
member_method {
name: "load_data"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -23,6 +23,7 @@ py_test(
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:experimental_tensorflow_py",
"//tensorflow:tensorflow_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:lib",

View File

@ -34,6 +34,7 @@ import sys
import unittest
import tensorflow as tf
from tensorflow import experimental_api as api
from google.protobuf import text_format
@ -46,6 +47,9 @@ from tensorflow.tools.api.lib import python_object_to_proto_visitor
from tensorflow.tools.common import public_api
from tensorflow.tools.common import traverse
if hasattr(tf, 'experimental_api'):
del tf.experimental_api
# FLAGS defined at the bottom:
FLAGS = None
# DEFINE_boolean, update_goldens, default False:
@ -109,7 +113,8 @@ class ApiCompatibilityTest(test.TestCase):
expected_dict,
actual_dict,
verbose=False,
update_goldens=False):
update_goldens=False,
additional_missing_object_message=''):
"""Diff given dicts of protobufs and report differences a readable way.
Args:
@ -120,6 +125,8 @@ class ApiCompatibilityTest(test.TestCase):
verbose: Whether to log the full diffs, or simply report which files were
different.
update_goldens: Whether to update goldens when there are diffs found.
additional_missing_object_message: Message to print when a symbol is
missing.
"""
diffs = []
verbose_diffs = []
@ -138,7 +145,8 @@ class ApiCompatibilityTest(test.TestCase):
verbose_diff_message = ''
# First check if the key is not found in one or the other.
if key in only_in_expected:
diff_message = 'Object %s expected but not found (removed).' % key
diff_message = 'Object %s expected but not found (removed). %s' % (
key, additional_missing_object_message)
verbose_diff_message = diff_message
elif key in only_in_actual:
diff_message = 'New object %s found (added).' % key
@ -229,6 +237,64 @@ class ApiCompatibilityTest(test.TestCase):
verbose=FLAGS.verbose_diffs,
update_goldens=FLAGS.update_goldens)
@unittest.skipUnless(
sys.version_info.major == 2,
'API compabitility test goldens are generated using python2.')
def testNewAPIBackwardsCompatibility(self):
# Extract all API stuff.
visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
public_api_visitor = public_api.PublicAPIVisitor(visitor)
public_api_visitor.do_not_descend_map['tf'].append('contrib')
public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
# TODO(annarev): these symbols have been added recently with tf_export
# decorators, but they are not exported with old API. Export them using
# old API approach and remove them from here.
public_api_visitor.private_map['tf'] = [
'to_complex128', 'to_complex64', 'add_to_collections',
'unsorted_segment_mean']
traverse.traverse(api, public_api_visitor)
proto_dict = visitor.GetProtos()
# Read all golden files.
expression = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*'))
golden_file_list = file_io.get_matching_files(expression)
def _ReadFileToProto(filename):
"""Read a filename, create a protobuf from its contents."""
ret_val = api_objects_pb2.TFAPIObject()
text_format.Merge(file_io.read_file_to_string(filename), ret_val)
return ret_val
golden_proto_dict = {
_FileNameToKey(filename): _ReadFileToProto(filename)
for filename in golden_file_list
}
# user_ops is an empty module. It is currently available in TensorFlow API
# but we don't keep empty modules in the new API.
# We delete user_ops from golden_proto_dict to make sure assert passes
# when diffing new API against goldens.
# TODO(annarev): remove user_ops from goldens once we switch to new API.
tf_module = golden_proto_dict['tensorflow'].tf_module
for i in range(len(tf_module.member)):
if tf_module.member[i].name == 'user_ops':
del tf_module.member[i]
break
# Diff them. Do not fail if called with update.
# If the test is run to update goldens, only report diffs but do not fail.
self._AssertProtoDictEquals(
golden_proto_dict,
proto_dict,
verbose=FLAGS.verbose_diffs,
update_goldens=False,
additional_missing_object_message=
'Check if tf_export decorator/call is missing for this symbol.')
if __name__ == '__main__':
parser = argparse.ArgumentParser()