Moving generated API to tensorflow/.

PiperOrigin-RevId: 198767512
Author: Anna R, 2018-05-31 13:11:43 -07:00 (committed by TensorFlower Gardener)
Parent: eebbcaf554
Commit: 106191ccf0
18 changed files, 342 additions, 233 deletions

View File

@@ -19,6 +19,10 @@ load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_additional_binary_deps",
)
load(
"//tensorflow/tools/api/generator:api_gen.bzl",
"gen_api_init_files", # @unused
)
# Config setting for determining if we are building for Android.
config_setting(
@@ -536,13 +540,16 @@ exports_files(
],
)
gen_api_init_files(
name = "python_api_gen",
srcs = ["api_template.__init__.py"],
root_init_template = "api_template.__init__.py",
)
py_library(
name = "tensorflow_py",
srcs = ["__init__.py"],
srcs = [":python_api_gen"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/python",
"//tensorflow/tools/api/generator:python_api",
],
deps = ["//tensorflow/python"],
)

View File

@@ -22,9 +22,6 @@ from __future__ import print_function
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
# pylint: disable=wildcard-import
from tensorflow.tools.api.generator.api import * # pylint: disable=redefined-builtin
# pylint: enable=wildcard-import
from tensorflow.python.util.lazy_loader import LazyLoader
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')

View File

@@ -0,0 +1,43 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in all of the public TensorFlow interface into this module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
# API IMPORTS PLACEHOLDER
from tensorflow.python.util.lazy_loader import LazyLoader
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
app.flags = flags # pylint: disable=undefined-variable
del absolute_import
del division
del print_function
# These symbols appear because we import the python package which
# in turn imports from tensorflow.core and tensorflow.python. They
# must come from this module. So python adds these symbols for the
# resolution to succeed.
# pylint: disable=undefined-variable
del python
del core
# pylint: enable=undefined-variable
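As a hedged aside (not part of the commit itself), here is roughly what the generated tensorflow/__init__.py ends up containing once create_python_api.py processes this template: the file is copied through unchanged except that the "# API IMPORTS PLACEHOLDER" comment is replaced with import statements collected from @tf_export decorators. The two imports below are illustrative examples only.

# Hypothetical excerpt of a generated tensorflow/__init__.py (illustrative only):
# the surrounding template text above is preserved, and the placeholder expands
# into imports gathered from modules decorated with @tf_export, e.g.:
from tensorflow.python.ops.array_ops import reshape  # exposed as tf.reshape
from tensorflow.python.ops.math_ops import add  # exposed as tf.add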

View File

@@ -725,7 +725,7 @@ endif()
########################################################
# Parse tensorflow/tools/api/generator/BUILD to get list of generated files.
FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/BUILD api_generator_BUILD_text)
FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text)
STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text})
string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text})
string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text})
@@ -736,7 +736,7 @@ foreach(api_init_file ${api_init_files_list})
string(STRIP "${api_init_file}" api_init_file)
if(api_init_file)
string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes
list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/${api_init_file}")
list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/${api_init_file}")
endif()
endforeach(api_init_file)
set(api_init_list_file "${tensorflow_source_dir}/api_init_files_list.txt")
@@ -749,18 +749,14 @@ add_custom_command(
# tensorflow/__init__.py depends on files generated in this step. So, remove it while
# this step is running since the files aren't there yet.
COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py
COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
# Run create_python_api.py to generate API init files.
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE}
"${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" "${api_init_list_file}"
# Re-add tensorflow/__init__.py back.
COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py
${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
"${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
"--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py"
"--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow"
"${api_init_list_file}"
COMMENT "Generating __init__.py files for Python API."
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python"
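For readers reproducing this step outside of CMake, the command above boils down to the following hedged Python sketch; the tf_python/... paths stand in for the expanded ${CMAKE_CURRENT_BINARY_DIR}/tf_python locations, and the list-file name mirrors api_init_files_list.txt from the hunk above.

import os
import subprocess
import sys

# Run the API generator with the new --root_init_template/--apidir flags plus the
# single file that lists every __init__.py to generate (see the CMake command above).
env = dict(os.environ, PYTHONPATH="tf_python")
subprocess.check_call(
    [
        sys.executable,
        "tf_python/tensorflow/tools/api/generator/create_python_api.py",
        "--root_init_template=tf_python/tensorflow/api_template.__init__.py",
        "--apidir=tf_python/tensorflow",
        "api_init_files_list.txt",
    ],
    env=env,
)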

View File

@@ -212,6 +212,10 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/gmm_test.py"
# Disable following manual tag in BUILD.
"${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py"
# These tests depend on a .so file
${tensorflow_source_dir}/tensorflow/python/kernel_tests/duplicate_op_test.py
${tensorflow_source_dir}/tensorflow/python/kernel_tests/invalid_op_test.py
${tensorflow_source_dir}/tensorflow/python/kernel_tests/ackermann_test.py
)
if (WIN32)

View File

@@ -71,6 +71,7 @@ py_library(
visibility = [
"//tensorflow:__pkg__",
"//tensorflow/python/tools:__pkg__",
"//tensorflow/tools/api/generator:__pkg__",
],
deps = [
":array_ops",

View File

@@ -9,6 +9,7 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "sycl_py_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
# CPU only tests should use tf_py_test, GPU tests use cuda_py_test
# Please avoid the py_tests and cuda_py_tests (plural) while we
@@ -3029,3 +3030,60 @@ tf_py_test(
"//tensorflow/python/eager:tape",
],
)
# Custom op tests
tf_custom_op_library(
name = "ackermann_op.so",
srcs = ["ackermann_op.cc"],
)
tf_py_test(
name = "ackermann_test",
size = "small",
srcs = ["ackermann_test.py"],
additional_deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:platform",
],
data = [":ackermann_op.so"],
tags = ["no_pip"],
)
tf_custom_op_library(
name = "duplicate_op.so",
srcs = ["duplicate_op.cc"],
)
tf_py_test(
name = "duplicate_op_test",
size = "small",
srcs = ["duplicate_op_test.py"],
additional_deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
],
data = [":duplicate_op.so"],
tags = ["no_pip"],
)
tf_custom_op_library(
name = "invalid_op.so",
srcs = ["invalid_op.cc"],
)
tf_py_test(
name = "invalid_op_test",
size = "small",
srcs = ["invalid_op_test.py"],
additional_deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:platform",
],
data = [":invalid_op.so"],
tags = ["no_pip"],
)

View File

@@ -17,17 +17,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import os
import tensorflow as tf
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
class AckermannTest(tf.test.TestCase):
class AckermannTest(test.TestCase):
def testBasic(self):
library_filename = os.path.join(tf.resource_loader.get_data_files_path(),
library_filename = os.path.join(resource_loader.get_data_files_path(),
'ackermann_op.so')
ackermann = tf.load_op_library(library_filename)
ackermann = load_library.load_op_library(library_filename)
self.assertEqual(len(ackermann.OP_LIST.op), 1)
self.assertEqual(ackermann.OP_LIST.op[0].name, 'Ackermann')
@@ -37,4 +39,4 @@ class AckermannTest(tf.test.TestCase):
if __name__ == '__main__':
tf.test.main()
test.main()

View File

@@ -17,23 +17,26 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import os
import tensorflow as tf
from tensorflow.python.framework import load_library
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
class DuplicateOpTest(tf.test.TestCase):
class DuplicateOpTest(test.TestCase):
def testBasic(self):
library_filename = os.path.join(tf.resource_loader.get_data_files_path(),
library_filename = os.path.join(resource_loader.get_data_files_path(),
'duplicate_op.so')
duplicate = tf.load_op_library(library_filename)
duplicate = load_library.load_op_library(library_filename)
self.assertEqual(len(duplicate.OP_LIST.op), 0)
with self.test_session():
self.assertEqual(tf.add(1, 41).eval(), 42)
self.assertEqual(math_ops.add(1, 41).eval(), 42)
if __name__ == '__main__':
tf.test.main()
test.main()

View File

@@ -17,19 +17,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import os
import tensorflow as tf
from tensorflow.python.framework import errors
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
class InvalidOpTest(tf.test.TestCase):
class InvalidOpTest(test.TestCase):
def testBasic(self):
library_filename = os.path.join(tf.resource_loader.get_data_files_path(),
library_filename = os.path.join(resource_loader.get_data_files_path(),
'invalid_op.so')
with self.assertRaises(tf.errors.InvalidArgumentError):
tf.load_op_library(library_filename)
with self.assertRaises(errors.InvalidArgumentError):
load_library.load_op_library(library_filename)
if __name__ == '__main__':
tf.test.main()
test.main()

View File

@@ -88,9 +88,4 @@ def NewStatSummarizer(unused):
def DeleteStatSummarizer(stat_summarizer):
_DeleteStatSummarizer(stat_summarizer)
NewStatSummarizer._tf_api_names = ["contrib.stat_summarizer.NewStatSummarizer"]
DeleteStatSummarizer._tf_api_names = [
"contrib.stat_summarizer.DeleteStatSummarizer"]
StatSummarizer._tf_api_names = ["contrib.stat_summarizer.StatSummarizer"]
%}

View File

@@ -9,8 +9,9 @@ py_binary(
name = "create_python_api",
srcs = ["create_python_api.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/python",
"//tensorflow/python:no_contrib",
],
)
@@ -23,116 +24,3 @@ py_test(
"//tensorflow/python:client_testlib",
],
)
genrule(
name = "python_api_gen",
# List of API files. This list should include file name for
# every module exported using tf_export. For e.g. if an op is decorated with
# @tf_export('module1.module2', 'module3'). Then, outs should include
# api/module1/module2/__init__.py and api/module3/__init__.py.
# keep sorted
outs = [
# BEGIN GENERATED FILES
"api/__init__.py",
"api/app/__init__.py",
"api/bitwise/__init__.py",
"api/compat/__init__.py",
"api/contrib/__init__.py",
"api/contrib/stat_summarizer/__init__.py",
"api/data/__init__.py",
"api/distributions/__init__.py",
"api/distributions/bijectors/__init__.py",
"api/errors/__init__.py",
"api/estimator/__init__.py",
"api/estimator/export/__init__.py",
"api/estimator/inputs/__init__.py",
"api/feature_column/__init__.py",
"api/gfile/__init__.py",
"api/graph_util/__init__.py",
"api/image/__init__.py",
"api/initializers/__init__.py",
"api/keras/__init__.py",
"api/keras/activations/__init__.py",
"api/keras/applications/__init__.py",
"api/keras/applications/densenet/__init__.py",
"api/keras/applications/inception_resnet_v2/__init__.py",
"api/keras/applications/inception_v3/__init__.py",
"api/keras/applications/mobilenet/__init__.py",
"api/keras/applications/nasnet/__init__.py",
"api/keras/applications/resnet50/__init__.py",
"api/keras/applications/vgg16/__init__.py",
"api/keras/applications/vgg19/__init__.py",
"api/keras/applications/xception/__init__.py",
"api/keras/backend/__init__.py",
"api/keras/callbacks/__init__.py",
"api/keras/constraints/__init__.py",
"api/keras/datasets/__init__.py",
"api/keras/datasets/boston_housing/__init__.py",
"api/keras/datasets/cifar10/__init__.py",
"api/keras/datasets/cifar100/__init__.py",
"api/keras/datasets/fashion_mnist/__init__.py",
"api/keras/datasets/imdb/__init__.py",
"api/keras/datasets/mnist/__init__.py",
"api/keras/datasets/reuters/__init__.py",
"api/keras/estimator/__init__.py",
"api/keras/initializers/__init__.py",
"api/keras/layers/__init__.py",
"api/keras/losses/__init__.py",
"api/keras/metrics/__init__.py",
"api/keras/models/__init__.py",
"api/keras/optimizers/__init__.py",
"api/keras/preprocessing/__init__.py",
"api/keras/preprocessing/image/__init__.py",
"api/keras/preprocessing/sequence/__init__.py",
"api/keras/preprocessing/text/__init__.py",
"api/keras/regularizers/__init__.py",
"api/keras/utils/__init__.py",
"api/keras/wrappers/__init__.py",
"api/keras/wrappers/scikit_learn/__init__.py",
"api/layers/__init__.py",
"api/linalg/__init__.py",
"api/logging/__init__.py",
"api/losses/__init__.py",
"api/manip/__init__.py",
"api/math/__init__.py",
"api/metrics/__init__.py",
"api/nn/__init__.py",
"api/nn/rnn_cell/__init__.py",
"api/profiler/__init__.py",
"api/python_io/__init__.py",
"api/resource_loader/__init__.py",
"api/strings/__init__.py",
"api/saved_model/__init__.py",
"api/saved_model/builder/__init__.py",
"api/saved_model/constants/__init__.py",
"api/saved_model/loader/__init__.py",
"api/saved_model/main_op/__init__.py",
"api/saved_model/signature_constants/__init__.py",
"api/saved_model/signature_def_utils/__init__.py",
"api/saved_model/tag_constants/__init__.py",
"api/saved_model/utils/__init__.py",
"api/sets/__init__.py",
"api/sparse/__init__.py",
"api/spectral/__init__.py",
"api/summary/__init__.py",
"api/sysconfig/__init__.py",
"api/test/__init__.py",
"api/train/__init__.py",
"api/train/queue_runner/__init__.py",
"api/user_ops/__init__.py",
# END GENERATED FILES
],
cmd = "$(location create_python_api) $(OUTS)",
tools = ["create_python_api"],
)
py_library(
name = "python_api",
srcs = [":python_api_gen"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
"//tensorflow/contrib:contrib_py", # keep
"//tensorflow/python", # keep
],
)

View File

@@ -0,0 +1,125 @@
"""Targets for generating TensorFlow Python API __init__.py files."""
# keep sorted
TENSORFLOW_API_INIT_FILES = [
# BEGIN GENERATED FILES
"__init__.py",
"app/__init__.py",
"bitwise/__init__.py",
"compat/__init__.py",
"data/__init__.py",
"distributions/__init__.py",
"distributions/bijectors/__init__.py",
"errors/__init__.py",
"estimator/__init__.py",
"estimator/export/__init__.py",
"estimator/inputs/__init__.py",
"feature_column/__init__.py",
"gfile/__init__.py",
"graph_util/__init__.py",
"image/__init__.py",
"initializers/__init__.py",
"keras/__init__.py",
"keras/activations/__init__.py",
"keras/applications/__init__.py",
"keras/applications/densenet/__init__.py",
"keras/applications/inception_resnet_v2/__init__.py",
"keras/applications/inception_v3/__init__.py",
"keras/applications/mobilenet/__init__.py",
"keras/applications/nasnet/__init__.py",
"keras/applications/resnet50/__init__.py",
"keras/applications/vgg16/__init__.py",
"keras/applications/vgg19/__init__.py",
"keras/applications/xception/__init__.py",
"keras/backend/__init__.py",
"keras/callbacks/__init__.py",
"keras/constraints/__init__.py",
"keras/datasets/__init__.py",
"keras/datasets/boston_housing/__init__.py",
"keras/datasets/cifar10/__init__.py",
"keras/datasets/cifar100/__init__.py",
"keras/datasets/fashion_mnist/__init__.py",
"keras/datasets/imdb/__init__.py",
"keras/datasets/mnist/__init__.py",
"keras/datasets/reuters/__init__.py",
"keras/estimator/__init__.py",
"keras/initializers/__init__.py",
"keras/layers/__init__.py",
"keras/losses/__init__.py",
"keras/metrics/__init__.py",
"keras/models/__init__.py",
"keras/optimizers/__init__.py",
"keras/preprocessing/__init__.py",
"keras/preprocessing/image/__init__.py",
"keras/preprocessing/sequence/__init__.py",
"keras/preprocessing/text/__init__.py",
"keras/regularizers/__init__.py",
"keras/utils/__init__.py",
"keras/wrappers/__init__.py",
"keras/wrappers/scikit_learn/__init__.py",
"layers/__init__.py",
"linalg/__init__.py",
"logging/__init__.py",
"losses/__init__.py",
"manip/__init__.py",
"math/__init__.py",
"metrics/__init__.py",
"nn/__init__.py",
"nn/rnn_cell/__init__.py",
"profiler/__init__.py",
"python_io/__init__.py",
"resource_loader/__init__.py",
"strings/__init__.py",
"saved_model/__init__.py",
"saved_model/builder/__init__.py",
"saved_model/constants/__init__.py",
"saved_model/loader/__init__.py",
"saved_model/main_op/__init__.py",
"saved_model/signature_constants/__init__.py",
"saved_model/signature_def_utils/__init__.py",
"saved_model/tag_constants/__init__.py",
"saved_model/utils/__init__.py",
"sets/__init__.py",
"sparse/__init__.py",
"spectral/__init__.py",
"summary/__init__.py",
"sysconfig/__init__.py",
"test/__init__.py",
"train/__init__.py",
"train/queue_runner/__init__.py",
"user_ops/__init__.py",
# END GENERATED FILES
]
# Creates a genrule that generates a directory structure with __init__.py
# files that import all exported modules (i.e. modules with tf_export
# decorators).
#
# Args:
# name: name of genrule to create.
# output_files: List of __init__.py files that should be generated.
# This list should include file name for every module exported using
# tf_export. For e.g. if an op is decorated with
# @tf_export('module1.module2', 'module3'). Then, output_files should
# include module1/module2/__init__.py and module3/__init__.py.
# root_init_template: Python init file that should be used as template for
# root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this
# template will be replaced with root imports collected by this genrule.
# srcs: genrule sources. If passing root_init_template, the template file
# must be included in sources.
def gen_api_init_files(name,
output_files=TENSORFLOW_API_INIT_FILES,
root_init_template=None,
srcs=[]):
root_init_template_flag = ""
if root_init_template:
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
native.genrule(
name = name,
outs = output_files,
cmd = (
"$(location //tensorflow/tools/api/generator:create_python_api) " +
root_init_template_flag + " --apidir=$(@D) $(OUTS)"),
srcs = srcs,
tools = ["//tensorflow/tools/api/generator:create_python_api"],
)
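For context, the macro defined above is consumed from the top-level tensorflow/BUILD file in the first hunk of this commit; pulling those added lines together, a minimal invocation looks like this sketch.

# In tensorflow/BUILD (restating the additions from the first hunk of this commit):
load(
    "//tensorflow/tools/api/generator:api_gen.bzl",
    "gen_api_init_files",
)

gen_api_init_files(
    name = "python_api_gen",
    srcs = ["api_template.__init__.py"],
    root_init_template = "api_template.__init__.py",
)

py_library(
    name = "tensorflow_py",
    srcs = [":python_api_gen"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = ["//tensorflow/python"],
)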

View File

@@ -29,9 +29,13 @@ from tensorflow.python.util import tf_decorator
_API_CONSTANTS_ATTR = '_tf_api_constants'
_API_NAMES_ATTR = '_tf_api_names'
_API_DIR = '/api/'
_DEFAULT_PACKAGE = 'tensorflow.python'
_OUTPUT_MODULE = 'tensorflow.tools.api.generator.api'
_GENFILES_DIR_SUFFIX = 'genfiles/'
_SYMBOLS_TO_SKIP_EXPLICITLY = {
# Overrides __getattr__, so that unwrapping tf_decorator
# would have side effects.
'tensorflow.python.platform.flags.FLAGS'
}
_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.
This file is MACHINE GENERATED! Do not edit.
@@ -143,8 +147,8 @@ class _ModuleInitCodeBuilder(object):
# the script outputs.
module_text_map[''] = module_text_map.get('', '') + '''
_names_with_underscore = [%s]
__all__ = [s for s in dir() if not s.startswith('_')]
__all__.extend([s for s in _names_with_underscore])
__all__ = [_s for _s in dir() if not _s.startswith('_')]
__all__.extend([_s for _s in _names_with_underscore])
''' % underscore_names_str
return module_text_map
@@ -177,6 +181,9 @@ def get_api_init_text(package):
continue
for module_contents_name in dir(module):
if (module.__name__ + '.' + module_contents_name
in _SYMBOLS_TO_SKIP_EXPLICITLY):
continue
attr = getattr(module, module_contents_name)
# If attr is _tf_api_constants attribute, then add the constants.
@@ -189,7 +196,11 @@ def get_api_init_text(package):
-1, dest_module, module.__name__, value, names[-1])
continue
_, attr = tf_decorator.unwrap(attr)
try:
_, attr = tf_decorator.unwrap(attr)
except Exception as e:
print('5555: %s %s' % (module, module_contents_name), file=sys.stderr)
raise e
# If attr is a symbol with _tf_api_names attribute, then
# add import for it.
if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
@@ -204,6 +215,7 @@ def get_api_init_text(package):
# For e.g. if we import 'foo.bar.Value'. Then, we also
# import 'bar' in 'foo'.
imported_modules = set(module_code_builder.module_imports.keys())
import_from = '.'
for module in imported_modules:
if not module:
continue
@@ -211,11 +223,9 @@ def get_api_init_text(package):
parent_module = '' # we import submodules in their parent_module
for submodule_index in range(len(module_split)):
import_from = _OUTPUT_MODULE
if submodule_index > 0:
parent_module += ('.' + module_split[submodule_index-1] if parent_module
else module_split[submodule_index-1])
import_from += '.' + parent_module
module_code_builder.add_import(
-1, parent_module, import_from,
module_split[submodule_index], module_split[submodule_index])
@@ -223,7 +233,24 @@ def get_api_init_text(package):
return module_code_builder.build()
def create_api_files(output_files, package):
def get_module(dir_path, relative_to_dir):
"""Get module that corresponds to path relative to relative_to_dir.
Args:
dir_path: Path to directory.
relative_to_dir: Get module relative to this directory.
Returns:
module that corresponds to the given directory.
"""
dir_path = dir_path[len(relative_to_dir):]
# Convert path separators to '/' for easier parsing below.
dir_path = dir_path.replace(os.sep, '/')
return dir_path.replace('/', '.').strip('.')
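A couple of illustrative calls (directory names invented for this sketch) show the mapping get_module performs:

# get_module maps an output directory to a dotted module name relative to the
# API root; the root itself maps to the empty string (the top-level module).
assert get_module('genfiles/tensorflow/keras/layers', 'genfiles/tensorflow') == 'keras.layers'
assert get_module('genfiles/tensorflow', 'genfiles/tensorflow') == ''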
def create_api_files(
output_files, package, root_init_template, output_dir):
"""Creates __init__.py files for the Python API.
Args:
@@ -231,6 +258,10 @@ def create_api_files(output_files, package):
Each file must be under api/ directory.
package: Base python package containing python with target tf_export
decorators.
root_init_template: Template for top-level __init__.py file.
"#API IMPORTS PLACEHOLDER" comment in the template file will be replaced
with imports.
output_dir: output API root directory.
Raises:
ValueError: if an output file is not under api/ directory,
@ -238,18 +269,7 @@ def create_api_files(output_files, package):
"""
module_name_to_file_path = {}
for output_file in output_files:
# Convert path separators to '/' for easier parsing below.
normalized_output_file = output_file.replace(os.sep, '/')
if _API_DIR not in output_file:
raise ValueError(
'Output files must be in api/ directory, found %s.' % output_file)
# Get the module name that corresponds to output_file.
# First get module directory under _API_DIR.
module_dir = os.path.dirname(
normalized_output_file[
normalized_output_file.rfind(_API_DIR)+len(_API_DIR):])
# Convert / to .
module_name = module_dir.replace('/', '.').strip('.')
module_name = get_module(os.path.dirname(output_file), output_dir)
module_name_to_file_path[module_name] = os.path.normpath(output_file)
# Create file for each expected output in genrule.
@@ -265,12 +285,20 @@ def create_api_files(output_files, package):
for module, text in module_text_map.items():
# Make sure genrule output file list is in sync with API exports.
if module not in module_name_to_file_path:
module_file_path = '"api/%s/__init__.py"' % (
module_file_path = '"%s/__init__.py"' % (
module.replace('.', '/'))
missing_output_files.append(module_file_path)
continue
contents = ''
if module or not root_init_template:
contents = _GENERATED_FILE_HEADER + text
else:
# Read base init file
with open(root_init_template, 'r') as root_init_template_file:
contents = root_init_template_file.read()
contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
with open(module_name_to_file_path[module], 'w') as fp:
fp.write(_GENERATED_FILE_HEADER + text)
fp.write(contents)
if missing_output_files:
raise ValueError(
@@ -292,6 +320,16 @@ def main():
'--package', default=_DEFAULT_PACKAGE, type=str,
help='Base package that imports modules containing the target tf_export '
'decorators.')
parser.add_argument(
'--root_init_template', default='', type=str,
help='Template for top level __init__.py file. '
'"#API IMPORTS PLACEHOLDER" comment will be replaced with imports.')
parser.add_argument(
'--apidir', type=str, required=True,
help='Directory where generated output files are placed. '
'gendir should be a prefix of apidir. Also, apidir '
'should be a prefix of every directory in outputs.')
args = parser.parse_args()
if len(args.outputs) == 1:
@@ -304,7 +342,8 @@ def main():
# Populate `sys.modules` with modules containing tf_export().
importlib.import_module(args.package)
create_api_files(outputs, args.package)
create_api_files(
outputs, args.package, args.root_init_template, args.apidir)
if __name__ == '__main__':

View File

@@ -1,52 +0,0 @@
# Description:
# An example for custom op and kernel defined as a TensorFlow plugin.
package(
default_visibility = ["//tensorflow:internal"],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
tf_custom_op_library(
name = "ackermann_op.so",
srcs = ["ackermann_op.cc"],
)
tf_py_test(
name = "ackermann_test",
size = "small",
srcs = ["ackermann_test.py"],
additional_deps = ["//tensorflow:tensorflow_py"],
data = [":ackermann_op.so"],
)
tf_custom_op_library(
name = "duplicate_op.so",
srcs = ["duplicate_op.cc"],
)
tf_py_test(
name = "duplicate_op_test",
size = "small",
srcs = ["duplicate_op_test.py"],
additional_deps = ["//tensorflow:tensorflow_py"],
data = [":duplicate_op.so"],
)
tf_custom_op_library(
name = "invalid_op.so",
srcs = ["invalid_op.cc"],
)
tf_py_test(
name = "invalid_op_test",
size = "small",
srcs = ["invalid_op_test.py"],
additional_deps = ["//tensorflow:tensorflow_py"],
data = [":invalid_op.so"],
)