Add test for API generated with tf_export calls to api_compatibility_test.py.
Also, fix remaining differences between this new API and the current TensorFlow API. PiperOrigin-RevId: 188943768
This commit is contained in:
parent
86dd46a3c6
commit
d57f0213bf
@ -830,3 +830,14 @@ py_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = ["//tensorflow/python"],
|
deps = ["//tensorflow/python"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "experimental_tensorflow_py",
|
||||||
|
srcs = ["experimental_api.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
visibility = ["//tensorflow/tools/api/tests:__subpackages__"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python",
|
||||||
|
"//tensorflow/tools/api/generator:python_api",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
4
tensorflow/core/api_def/python_api/api_def_Assign.pbtxt
Normal file
4
tensorflow/core/api_def/python_api/api_def_Assign.pbtxt
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "Assign"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "AssignAdd"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "AssignSub"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "SparseReduceMax"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "SparseReduceMaxSparse"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "SparseReduceSum"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "SparseReduceSumSparse"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "SparseSlice"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "SparseSoftmax"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
38
tensorflow/experimental_api.py
Normal file
38
tensorflow/experimental_api.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
# Bring in all of the public TensorFlow interface into this
|
||||||
|
# module.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# pylint: disable=g-bad-import-order
|
||||||
|
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
||||||
|
# pylint: disable=wildcard-import
|
||||||
|
from tensorflow.tools.api.generator.api import * # pylint: disable=redefined-builtin
|
||||||
|
# pylint: enable=wildcard-import
|
||||||
|
|
||||||
|
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||||
|
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
|
||||||
|
del LazyLoader
|
||||||
|
|
||||||
|
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
|
||||||
|
app.flags = flags # pylint: disable=undefined-variable
|
||||||
|
|
||||||
|
del absolute_import
|
||||||
|
del division
|
||||||
|
del print_function
|
@ -343,7 +343,9 @@ tf_export("uint8").export_constant(__name__, "uint8")
|
|||||||
uint16 = DType(types_pb2.DT_UINT16)
|
uint16 = DType(types_pb2.DT_UINT16)
|
||||||
tf_export("uint16").export_constant(__name__, "uint16")
|
tf_export("uint16").export_constant(__name__, "uint16")
|
||||||
uint32 = DType(types_pb2.DT_UINT32)
|
uint32 = DType(types_pb2.DT_UINT32)
|
||||||
|
tf_export("uint32").export_constant(__name__, "uint32")
|
||||||
uint64 = DType(types_pb2.DT_UINT64)
|
uint64 = DType(types_pb2.DT_UINT64)
|
||||||
|
tf_export("uint64").export_constant(__name__, "uint32")
|
||||||
int16 = DType(types_pb2.DT_INT16)
|
int16 = DType(types_pb2.DT_INT16)
|
||||||
tf_export("int16").export_constant(__name__, "int16")
|
tf_export("int16").export_constant(__name__, "int16")
|
||||||
int8 = DType(types_pb2.DT_INT8)
|
int8 = DType(types_pb2.DT_INT8)
|
||||||
|
@ -24,8 +24,10 @@ import os
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
|
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
|
||||||
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('keras.datasets.fashion_mnist.load_data')
|
||||||
def load_data():
|
def load_data():
|
||||||
"""Loads the Fashion-MNIST dataset.
|
"""Loads the Fashion-MNIST dataset.
|
||||||
|
|
||||||
|
@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('keras.layers.InputLayer')
|
||||||
class InputLayer(base_layer.Layer):
|
class InputLayer(base_layer.Layer):
|
||||||
"""Layer to be used as an entry point into a Network (a graph of layers).
|
"""Layer to be used as an entry point into a Network (a graph of layers).
|
||||||
|
|
||||||
|
@ -0,0 +1,25 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Fashion-MNIST dataset."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.keras._impl.keras.datasets.fashion_mnist import load_data
|
||||||
|
|
||||||
|
del absolute_import
|
||||||
|
del division
|
||||||
|
del print_function
|
@ -1795,6 +1795,7 @@ _rgb_to_yiq_kernel = [[0.299, 0.59590059,
|
|||||||
[0.114, -0.32134392, 0.31119955]]
|
[0.114, -0.32134392, 0.31119955]]
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('image.rgb_to_yiq')
|
||||||
def rgb_to_yiq(images):
|
def rgb_to_yiq(images):
|
||||||
"""Converts one or more images from RGB to YIQ.
|
"""Converts one or more images from RGB to YIQ.
|
||||||
|
|
||||||
@ -1820,6 +1821,7 @@ _yiq_to_rgb_kernel = [[1, 1, 1], [0.95598634, -0.27201283, -1.10674021],
|
|||||||
[0.6208248, -0.64720424, 1.70423049]]
|
[0.6208248, -0.64720424, 1.70423049]]
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('image.yiq_to_rgb')
|
||||||
def yiq_to_rgb(images):
|
def yiq_to_rgb(images):
|
||||||
"""Converts one or more images from YIQ to RGB.
|
"""Converts one or more images from YIQ to RGB.
|
||||||
|
|
||||||
@ -1847,6 +1849,7 @@ _rgb_to_yuv_kernel = [[0.299, -0.14714119,
|
|||||||
[0.114, 0.43601035, -0.10001026]]
|
[0.114, 0.43601035, -0.10001026]]
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('image.rgb_to_yuv')
|
||||||
def rgb_to_yuv(images):
|
def rgb_to_yuv(images):
|
||||||
"""Converts one or more images from RGB to YUV.
|
"""Converts one or more images from RGB to YUV.
|
||||||
|
|
||||||
@ -1872,6 +1875,7 @@ _yuv_to_rgb_kernel = [[1, 1, 1], [0, -0.394642334, 2.03206185],
|
|||||||
[1.13988303, -0.58062185, 0]]
|
[1.13988303, -0.58062185, 0]]
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('image.yuv_to_rgb')
|
||||||
def yuv_to_rgb(images):
|
def yuv_to_rgb(images):
|
||||||
"""Converts one or more images from YUV to RGB.
|
"""Converts one or more images from YUV to RGB.
|
||||||
|
|
||||||
|
@ -39,5 +39,8 @@ global_variables = _variables.global_variables_initializer
|
|||||||
local_variables = _variables.local_variables_initializer
|
local_variables = _variables.local_variables_initializer
|
||||||
|
|
||||||
# Seal API.
|
# Seal API.
|
||||||
|
del absolute_import
|
||||||
|
del division
|
||||||
|
del print_function
|
||||||
del init_ops
|
del init_ops
|
||||||
del _variables
|
del _variables
|
||||||
|
@ -32,15 +32,18 @@ cholesky = linalg_ops.cholesky
|
|||||||
cholesky_solve = linalg_ops.cholesky_solve
|
cholesky_solve = linalg_ops.cholesky_solve
|
||||||
det = linalg_ops.matrix_determinant
|
det = linalg_ops.matrix_determinant
|
||||||
slogdet = gen_linalg_ops.log_matrix_determinant
|
slogdet = gen_linalg_ops.log_matrix_determinant
|
||||||
|
tf_export('linalg.slogdet')(slogdet)
|
||||||
diag = array_ops.matrix_diag
|
diag = array_ops.matrix_diag
|
||||||
diag_part = array_ops.matrix_diag_part
|
diag_part = array_ops.matrix_diag_part
|
||||||
eigh = linalg_ops.self_adjoint_eig
|
eigh = linalg_ops.self_adjoint_eig
|
||||||
eigvalsh = linalg_ops.self_adjoint_eigvals
|
eigvalsh = linalg_ops.self_adjoint_eigvals
|
||||||
einsum = special_math_ops.einsum
|
einsum = special_math_ops.einsum
|
||||||
expm = gen_linalg_ops.matrix_exponential
|
expm = gen_linalg_ops.matrix_exponential
|
||||||
|
tf_export('linalg.expm')(expm)
|
||||||
eye = linalg_ops.eye
|
eye = linalg_ops.eye
|
||||||
inv = linalg_ops.matrix_inverse
|
inv = linalg_ops.matrix_inverse
|
||||||
logm = gen_linalg_ops.matrix_logarithm
|
logm = gen_linalg_ops.matrix_logarithm
|
||||||
|
tf_export('linalg.logm')(logm)
|
||||||
lstsq = linalg_ops.matrix_solve_ls
|
lstsq = linalg_ops.matrix_solve_ls
|
||||||
norm = linalg_ops.norm
|
norm = linalg_ops.norm
|
||||||
qr = linalg_ops.qr
|
qr = linalg_ops.qr
|
||||||
|
@ -23,9 +23,11 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops
|
from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops
|
||||||
from tensorflow.python.util.all_util import remove_undocumented
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
|
@tf_export('manip.roll')
|
||||||
def roll(input, shift, axis): # pylint: disable=redefined-builtin
|
def roll(input, shift, axis): # pylint: disable=redefined-builtin
|
||||||
return _gen_manip_ops.roll(input, shift, axis)
|
return _gen_manip_ops.roll(input, shift, axis)
|
||||||
|
|
||||||
|
@ -1184,11 +1184,16 @@ def floordiv(x, y, name=None):
|
|||||||
|
|
||||||
|
|
||||||
realdiv = gen_math_ops.real_div
|
realdiv = gen_math_ops.real_div
|
||||||
|
tf_export("realdiv")(realdiv)
|
||||||
truncatediv = gen_math_ops.truncate_div
|
truncatediv = gen_math_ops.truncate_div
|
||||||
|
tf_export("truncatediv")(truncatediv)
|
||||||
# TODO(aselle): Rename this to floordiv when we can.
|
# TODO(aselle): Rename this to floordiv when we can.
|
||||||
floor_div = gen_math_ops.floor_div
|
floor_div = gen_math_ops.floor_div
|
||||||
|
tf_export("floor_div")(floor_div)
|
||||||
truncatemod = gen_math_ops.truncate_mod
|
truncatemod = gen_math_ops.truncate_mod
|
||||||
|
tf_export("truncatemod")(truncatemod)
|
||||||
floormod = gen_math_ops.floor_mod
|
floormod = gen_math_ops.floor_mod
|
||||||
|
tf_export("floormod")(floormod)
|
||||||
|
|
||||||
|
|
||||||
def _mul_dispatch(x, y, name=None):
|
def _mul_dispatch(x, y, name=None):
|
||||||
@ -2111,6 +2116,7 @@ def matmul(a,
|
|||||||
_OverrideBinaryOperatorHelper(matmul, "matmul")
|
_OverrideBinaryOperatorHelper(matmul, "matmul")
|
||||||
|
|
||||||
sparse_matmul = gen_math_ops.sparse_mat_mul
|
sparse_matmul = gen_math_ops.sparse_mat_mul
|
||||||
|
tf_export("sparse_matmul")(sparse_matmul)
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterStatistics("MatMul", "flops")
|
@ops.RegisterStatistics("MatMul", "flops")
|
||||||
|
@ -303,12 +303,12 @@ def _swish_grad(features, grad):
|
|||||||
# @Defun decorator with noinline=True so that sigmoid(features) is re-computed
|
# @Defun decorator with noinline=True so that sigmoid(features) is re-computed
|
||||||
# during backprop, and we can free the sigmoid(features) expression immediately
|
# during backprop, and we can free the sigmoid(features) expression immediately
|
||||||
# after use during the forward pass.
|
# after use during the forward pass.
|
||||||
|
@tf_export("nn.swish")
|
||||||
@function.Defun(
|
@function.Defun(
|
||||||
grad_func=_swish_grad,
|
grad_func=_swish_grad,
|
||||||
shape_func=_swish_shape,
|
shape_func=_swish_shape,
|
||||||
func_name="swish",
|
func_name="swish",
|
||||||
noinline=True)
|
noinline=True)
|
||||||
@tf_export("nn.swish")
|
|
||||||
def swish(features):
|
def swish(features):
|
||||||
# pylint: disable=g-doc-args
|
# pylint: disable=g-doc-args
|
||||||
"""Computes the Swish activation function: `x * sigmoid(x)`.
|
"""Computes the Swish activation function: `x * sigmoid(x)`.
|
||||||
@ -1343,4 +1343,4 @@ def sampled_softmax_loss(weights,
|
|||||||
sampled_losses = nn_ops.softmax_cross_entropy_with_logits(
|
sampled_losses = nn_ops.softmax_cross_entropy_with_logits(
|
||||||
labels=labels, logits=logits)
|
labels=labels, logits=logits)
|
||||||
# sampled_losses is a [batch_size] tensor.
|
# sampled_losses is a [batch_size] tensor.
|
||||||
return sampled_losses
|
return sampled_losses
|
||||||
|
@ -1385,7 +1385,6 @@ get_variable.__doc__ = get_variable_or_local_docstring % (
|
|||||||
"GraphKeys.GLOBAL_VARIABLES")
|
"GraphKeys.GLOBAL_VARIABLES")
|
||||||
|
|
||||||
|
|
||||||
@functools.wraps(get_variable)
|
|
||||||
@tf_export("get_local_variable")
|
@tf_export("get_local_variable")
|
||||||
def get_local_variable(*args, **kwargs):
|
def get_local_variable(*args, **kwargs):
|
||||||
kwargs["trainable"] = False
|
kwargs["trainable"] = False
|
||||||
|
@ -36,6 +36,7 @@ from tensorflow.python.platform import benchmark
|
|||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_decorator
|
||||||
from tensorflow.python.util import tf_inspect
|
from tensorflow.python.util import tf_inspect
|
||||||
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
Benchmark = benchmark.TensorFlowBenchmark # pylint: disable=invalid-name
|
Benchmark = benchmark.TensorFlowBenchmark # pylint: disable=invalid-name
|
||||||
@ -138,6 +139,7 @@ def StatefulSessionAvailable():
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('test.StubOutForTesting')
|
||||||
class StubOutForTesting(object):
|
class StubOutForTesting(object):
|
||||||
"""Support class for stubbing methods out for unit testing.
|
"""Support class for stubbing methods out for unit testing.
|
||||||
|
|
||||||
|
@ -62,6 +62,8 @@ if sys.version_info.major == 2:
|
|||||||
else:
|
else:
|
||||||
from unittest import mock # pylint: disable=g-import-not-at-top
|
from unittest import mock # pylint: disable=g-import-not-at-top
|
||||||
|
|
||||||
|
tf_export('test.mock')(mock)
|
||||||
|
|
||||||
# Import Benchmark class
|
# Import Benchmark class
|
||||||
Benchmark = _googletest.Benchmark # pylint: disable=invalid-name
|
Benchmark = _googletest.Benchmark # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Description:
|
# Description:
|
||||||
# Scripts used to generate TensorFlow Python API.
|
# Scripts used to generate TensorFlow Python API.
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
exports_files(["LICENSE"])
|
exports_files(["LICENSE"])
|
||||||
@ -21,7 +22,7 @@ py_binary(
|
|||||||
srcs = ["create_python_api.py"],
|
srcs = ["create_python_api.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow:tensorflow_py",
|
"//tensorflow/python",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -80,6 +81,7 @@ genrule(
|
|||||||
"api/keras/datasets/boston_housing/__init__.py",
|
"api/keras/datasets/boston_housing/__init__.py",
|
||||||
"api/keras/datasets/cifar10/__init__.py",
|
"api/keras/datasets/cifar10/__init__.py",
|
||||||
"api/keras/datasets/cifar100/__init__.py",
|
"api/keras/datasets/cifar100/__init__.py",
|
||||||
|
"api/keras/datasets/fashion_mnist/__init__.py",
|
||||||
"api/keras/datasets/imdb/__init__.py",
|
"api/keras/datasets/imdb/__init__.py",
|
||||||
"api/keras/datasets/mnist/__init__.py",
|
"api/keras/datasets/mnist/__init__.py",
|
||||||
"api/keras/datasets/reuters/__init__.py",
|
"api/keras/datasets/reuters/__init__.py",
|
||||||
@ -102,6 +104,7 @@ genrule(
|
|||||||
"api/linalg/__init__.py",
|
"api/linalg/__init__.py",
|
||||||
"api/logging/__init__.py",
|
"api/logging/__init__.py",
|
||||||
"api/losses/__init__.py",
|
"api/losses/__init__.py",
|
||||||
|
"api/manip/__init__.py",
|
||||||
"api/metrics/__init__.py",
|
"api/metrics/__init__.py",
|
||||||
"api/nn/__init__.py",
|
"api/nn/__init__.py",
|
||||||
"api/nn/rnn_cell/__init__.py",
|
"api/nn/rnn_cell/__init__.py",
|
||||||
@ -133,7 +136,9 @@ py_library(
|
|||||||
name = "python_api",
|
name = "python_api",
|
||||||
srcs = [":python_api_gen"],
|
srcs = [":python_api_gen"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
|
visibility = ["//tensorflow:__subpackages__"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/contrib:contrib_py", # keep
|
"//tensorflow/contrib:contrib_py", # keep
|
||||||
|
"//tensorflow/python", # keep
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -23,15 +23,14 @@ import collections
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
# This import is needed so that we can traverse over TensorFlow modules.
|
from tensorflow import python as tf
|
||||||
import tensorflow as tf # pylint: disable=unused-import
|
|
||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_decorator
|
||||||
|
|
||||||
|
|
||||||
_API_CONSTANTS_ATTR = '_tf_api_constants'
|
_API_CONSTANTS_ATTR = '_tf_api_constants'
|
||||||
_API_NAMES_ATTR = '_tf_api_names'
|
_API_NAMES_ATTR = '_tf_api_names'
|
||||||
_API_DIR = '/api/'
|
_API_DIR = '/api/'
|
||||||
_CONTRIB_IMPORT = 'from tensorflow import contrib'
|
_OUTPUT_MODULE = 'tensorflow.tools.api.generator.api'
|
||||||
_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.
|
_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.
|
||||||
|
|
||||||
This file is MACHINE GENERATED! Do not edit.
|
This file is MACHINE GENERATED! Do not edit.
|
||||||
@ -92,7 +91,7 @@ def get_api_imports():
|
|||||||
if module_contents_name == _API_CONSTANTS_ATTR:
|
if module_contents_name == _API_CONSTANTS_ATTR:
|
||||||
for exports, value in attr:
|
for exports, value in attr:
|
||||||
for export in exports:
|
for export in exports:
|
||||||
names = ['tf'] + export.split('.')
|
names = export.split('.')
|
||||||
dest_module = '.'.join(names[:-1])
|
dest_module = '.'.join(names[:-1])
|
||||||
import_str = format_import(module.__name__, value, names[-1])
|
import_str = format_import(module.__name__, value, names[-1])
|
||||||
module_imports[dest_module].append(import_str)
|
module_imports[dest_module].append(import_str)
|
||||||
@ -104,29 +103,43 @@ def get_api_imports():
|
|||||||
if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
|
if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
|
||||||
# The same op might be accessible from multiple modules.
|
# The same op might be accessible from multiple modules.
|
||||||
# We only want to consider location where function was defined.
|
# We only want to consider location where function was defined.
|
||||||
if attr.__module__ != module.__name__:
|
# Here we check if the op is defined in another TensorFlow module in
|
||||||
|
# sys.modules.
|
||||||
|
if (hasattr(attr, '__module__') and
|
||||||
|
attr.__module__.startswith(tf.__name__) and
|
||||||
|
attr.__module__ != module.__name__ and
|
||||||
|
attr.__module__ in sys.modules and
|
||||||
|
module_contents_name in dir(sys.modules[attr.__module__])):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for export in attr._tf_api_names: # pylint: disable=protected-access
|
for export in attr._tf_api_names: # pylint: disable=protected-access
|
||||||
names = ['tf'] + export.split('.')
|
names = export.split('.')
|
||||||
dest_module = '.'.join(names[:-1])
|
dest_module = '.'.join(names[:-1])
|
||||||
import_str = format_import(
|
import_str = format_import(
|
||||||
module.__name__, module_contents_name, names[-1])
|
module.__name__, module_contents_name, names[-1])
|
||||||
module_imports[dest_module].append(import_str)
|
module_imports[dest_module].append(import_str)
|
||||||
|
|
||||||
# Import all required modules in their parent modules.
|
# Import all required modules in their parent modules.
|
||||||
# For e.g. if we import 'tf.foo.bar.Value'. Then, we also
|
# For e.g. if we import 'foo.bar.Value'. Then, we also
|
||||||
# import 'bar' in 'tf.foo'.
|
# import 'bar' in 'foo'.
|
||||||
dest_modules = set(module_imports.keys())
|
imported_modules = set(module_imports.keys())
|
||||||
for dest_module in dest_modules:
|
for module in imported_modules:
|
||||||
dest_module_split = dest_module.split('.')
|
if not module:
|
||||||
for dest_submodule_index in range(1, len(dest_module_split)):
|
continue
|
||||||
dest_submodule = '.'.join(dest_module_split[:dest_submodule_index])
|
module_split = module.split('.')
|
||||||
|
parent_module = '' # we import submodules in their parent_module
|
||||||
|
|
||||||
|
for submodule_index in range(len(module_split)):
|
||||||
|
import_from = _OUTPUT_MODULE
|
||||||
|
if submodule_index > 0:
|
||||||
|
parent_module += ('.' + module_split[submodule_index-1] if parent_module
|
||||||
|
else module_split[submodule_index-1])
|
||||||
|
import_from += '.' + parent_module
|
||||||
submodule_import = format_import(
|
submodule_import = format_import(
|
||||||
'', dest_module_split[dest_submodule_index],
|
import_from, module_split[submodule_index],
|
||||||
dest_module_split[dest_submodule_index])
|
module_split[submodule_index])
|
||||||
if submodule_import not in module_imports[dest_submodule]:
|
if submodule_import not in module_imports[parent_module]:
|
||||||
module_imports[dest_submodule].append(submodule_import)
|
module_imports[parent_module].append(submodule_import)
|
||||||
|
|
||||||
return module_imports
|
return module_imports
|
||||||
|
|
||||||
@ -151,8 +164,8 @@ def create_api_files(output_files):
|
|||||||
# First get module directory under _API_DIR.
|
# First get module directory under _API_DIR.
|
||||||
module_dir = os.path.dirname(
|
module_dir = os.path.dirname(
|
||||||
output_file[output_file.rfind(_API_DIR)+len(_API_DIR):])
|
output_file[output_file.rfind(_API_DIR)+len(_API_DIR):])
|
||||||
# Convert / to . and prefix with tf.
|
# Convert / to .
|
||||||
module_name = '.'.join(['tf', module_dir.replace('/', '.')]).strip('.')
|
module_name = module_dir.replace('/', '.').strip('.')
|
||||||
module_name_to_file_path[module_name] = output_file
|
module_name_to_file_path[module_name] = output_file
|
||||||
|
|
||||||
# Create file for each expected output in genrule.
|
# Create file for each expected output in genrule.
|
||||||
@ -162,16 +175,14 @@ def create_api_files(output_files):
|
|||||||
open(file_path, 'a').close()
|
open(file_path, 'a').close()
|
||||||
|
|
||||||
module_imports = get_api_imports()
|
module_imports = get_api_imports()
|
||||||
module_imports['tf'].append(_CONTRIB_IMPORT) # Include all of contrib.
|
|
||||||
|
|
||||||
# Add imports to output files.
|
# Add imports to output files.
|
||||||
missing_output_files = []
|
missing_output_files = []
|
||||||
for module, exports in module_imports.items():
|
for module, exports in module_imports.items():
|
||||||
# Make sure genrule output file list is in sync with API exports.
|
# Make sure genrule output file list is in sync with API exports.
|
||||||
if module not in module_name_to_file_path:
|
if module not in module_name_to_file_path:
|
||||||
module_without_tf = module[len('tf.'):]
|
|
||||||
module_file_path = '"api/%s/__init__.py"' % (
|
module_file_path = '"api/%s/__init__.py"' % (
|
||||||
module_without_tf.replace('.', '/'))
|
module.replace('.', '/'))
|
||||||
missing_output_files.append(module_file_path)
|
missing_output_files.append(module_file_path)
|
||||||
continue
|
continue
|
||||||
with open(module_name_to_file_path[module], 'w') as fp:
|
with open(module_name_to_file_path[module], 'w') as fp:
|
||||||
|
@ -1,17 +1,9 @@
|
|||||||
path: "tensorflow.initializers"
|
path: "tensorflow.initializers"
|
||||||
tf_module {
|
tf_module {
|
||||||
member {
|
|
||||||
name: "absolute_import"
|
|
||||||
mtype: "<type \'instance\'>"
|
|
||||||
}
|
|
||||||
member {
|
member {
|
||||||
name: "constant"
|
name: "constant"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
member {
|
|
||||||
name: "division"
|
|
||||||
mtype: "<type \'instance\'>"
|
|
||||||
}
|
|
||||||
member {
|
member {
|
||||||
name: "identity"
|
name: "identity"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
@ -24,10 +16,6 @@ tf_module {
|
|||||||
name: "orthogonal"
|
name: "orthogonal"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
member {
|
|
||||||
name: "print_function"
|
|
||||||
mtype: "<type \'instance\'>"
|
|
||||||
}
|
|
||||||
member {
|
member {
|
||||||
name: "random_normal"
|
name: "random_normal"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
|
@ -1,3 +1,7 @@
|
|||||||
path: "tensorflow.keras.datasets.fashion_mnist"
|
path: "tensorflow.keras.datasets.fashion_mnist"
|
||||||
tf_module {
|
tf_module {
|
||||||
|
member_method {
|
||||||
|
name: "load_data"
|
||||||
|
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -23,6 +23,7 @@ py_test(
|
|||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow:experimental_tensorflow_py",
|
||||||
"//tensorflow:tensorflow_py",
|
"//tensorflow:tensorflow_py",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:lib",
|
"//tensorflow/python:lib",
|
||||||
|
@ -34,6 +34,7 @@ import sys
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
from tensorflow import experimental_api as api
|
||||||
|
|
||||||
from google.protobuf import text_format
|
from google.protobuf import text_format
|
||||||
|
|
||||||
@ -46,6 +47,9 @@ from tensorflow.tools.api.lib import python_object_to_proto_visitor
|
|||||||
from tensorflow.tools.common import public_api
|
from tensorflow.tools.common import public_api
|
||||||
from tensorflow.tools.common import traverse
|
from tensorflow.tools.common import traverse
|
||||||
|
|
||||||
|
if hasattr(tf, 'experimental_api'):
|
||||||
|
del tf.experimental_api
|
||||||
|
|
||||||
# FLAGS defined at the bottom:
|
# FLAGS defined at the bottom:
|
||||||
FLAGS = None
|
FLAGS = None
|
||||||
# DEFINE_boolean, update_goldens, default False:
|
# DEFINE_boolean, update_goldens, default False:
|
||||||
@ -109,7 +113,8 @@ class ApiCompatibilityTest(test.TestCase):
|
|||||||
expected_dict,
|
expected_dict,
|
||||||
actual_dict,
|
actual_dict,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
update_goldens=False):
|
update_goldens=False,
|
||||||
|
additional_missing_object_message=''):
|
||||||
"""Diff given dicts of protobufs and report differences a readable way.
|
"""Diff given dicts of protobufs and report differences a readable way.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -120,6 +125,8 @@ class ApiCompatibilityTest(test.TestCase):
|
|||||||
verbose: Whether to log the full diffs, or simply report which files were
|
verbose: Whether to log the full diffs, or simply report which files were
|
||||||
different.
|
different.
|
||||||
update_goldens: Whether to update goldens when there are diffs found.
|
update_goldens: Whether to update goldens when there are diffs found.
|
||||||
|
additional_missing_object_message: Message to print when a symbol is
|
||||||
|
missing.
|
||||||
"""
|
"""
|
||||||
diffs = []
|
diffs = []
|
||||||
verbose_diffs = []
|
verbose_diffs = []
|
||||||
@ -138,7 +145,8 @@ class ApiCompatibilityTest(test.TestCase):
|
|||||||
verbose_diff_message = ''
|
verbose_diff_message = ''
|
||||||
# First check if the key is not found in one or the other.
|
# First check if the key is not found in one or the other.
|
||||||
if key in only_in_expected:
|
if key in only_in_expected:
|
||||||
diff_message = 'Object %s expected but not found (removed).' % key
|
diff_message = 'Object %s expected but not found (removed). %s' % (
|
||||||
|
key, additional_missing_object_message)
|
||||||
verbose_diff_message = diff_message
|
verbose_diff_message = diff_message
|
||||||
elif key in only_in_actual:
|
elif key in only_in_actual:
|
||||||
diff_message = 'New object %s found (added).' % key
|
diff_message = 'New object %s found (added).' % key
|
||||||
@ -229,6 +237,64 @@ class ApiCompatibilityTest(test.TestCase):
|
|||||||
verbose=FLAGS.verbose_diffs,
|
verbose=FLAGS.verbose_diffs,
|
||||||
update_goldens=FLAGS.update_goldens)
|
update_goldens=FLAGS.update_goldens)
|
||||||
|
|
||||||
|
@unittest.skipUnless(
|
||||||
|
sys.version_info.major == 2,
|
||||||
|
'API compabitility test goldens are generated using python2.')
|
||||||
|
def testNewAPIBackwardsCompatibility(self):
|
||||||
|
# Extract all API stuff.
|
||||||
|
visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
|
||||||
|
|
||||||
|
public_api_visitor = public_api.PublicAPIVisitor(visitor)
|
||||||
|
public_api_visitor.do_not_descend_map['tf'].append('contrib')
|
||||||
|
public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
|
||||||
|
# TODO(annarev): these symbols have been added recently with tf_export
|
||||||
|
# decorators, but they are not exported with old API. Export them using
|
||||||
|
# old API approach and remove them from here.
|
||||||
|
public_api_visitor.private_map['tf'] = [
|
||||||
|
'to_complex128', 'to_complex64', 'add_to_collections',
|
||||||
|
'unsorted_segment_mean']
|
||||||
|
traverse.traverse(api, public_api_visitor)
|
||||||
|
|
||||||
|
proto_dict = visitor.GetProtos()
|
||||||
|
|
||||||
|
# Read all golden files.
|
||||||
|
expression = os.path.join(
|
||||||
|
resource_loader.get_root_dir_with_all_resources(),
|
||||||
|
_KeyToFilePath('*'))
|
||||||
|
golden_file_list = file_io.get_matching_files(expression)
|
||||||
|
|
||||||
|
def _ReadFileToProto(filename):
|
||||||
|
"""Read a filename, create a protobuf from its contents."""
|
||||||
|
ret_val = api_objects_pb2.TFAPIObject()
|
||||||
|
text_format.Merge(file_io.read_file_to_string(filename), ret_val)
|
||||||
|
return ret_val
|
||||||
|
|
||||||
|
golden_proto_dict = {
|
||||||
|
_FileNameToKey(filename): _ReadFileToProto(filename)
|
||||||
|
for filename in golden_file_list
|
||||||
|
}
|
||||||
|
|
||||||
|
# user_ops is an empty module. It is currently available in TensorFlow API
|
||||||
|
# but we don't keep empty modules in the new API.
|
||||||
|
# We delete user_ops from golden_proto_dict to make sure assert passes
|
||||||
|
# when diffing new API against goldens.
|
||||||
|
# TODO(annarev): remove user_ops from goldens once we switch to new API.
|
||||||
|
tf_module = golden_proto_dict['tensorflow'].tf_module
|
||||||
|
for i in range(len(tf_module.member)):
|
||||||
|
if tf_module.member[i].name == 'user_ops':
|
||||||
|
del tf_module.member[i]
|
||||||
|
break
|
||||||
|
|
||||||
|
# Diff them. Do not fail if called with update.
|
||||||
|
# If the test is run to update goldens, only report diffs but do not fail.
|
||||||
|
self._AssertProtoDictEquals(
|
||||||
|
golden_proto_dict,
|
||||||
|
proto_dict,
|
||||||
|
verbose=FLAGS.verbose_diffs,
|
||||||
|
update_goldens=False,
|
||||||
|
additional_missing_object_message=
|
||||||
|
'Check if tf_export decorator/call is missing for this symbol.')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
Loading…
Reference in New Issue
Block a user