diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 5e19d597f02..585240ec898 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -278,6 +278,7 @@ filegroup( "//tensorflow/contrib/data/python/util:all_files", "//tensorflow/contrib/decision_trees/proto:all_files", "//tensorflow/contrib/distributions:all_files", + "//tensorflow/contrib/eager/python:all_files", "//tensorflow/contrib/factorization:all_files", "//tensorflow/contrib/factorization/kernels:all_files", "//tensorflow/contrib/ffmpeg:all_files", diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 89e9072fa0d..0bfbdb81686 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -23,6 +23,7 @@ py_library( "//tensorflow/contrib/data", "//tensorflow/contrib/deprecated:deprecated_py", "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/contrib/eager/python:tfe", "//tensorflow/contrib/factorization:factorization_py", "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", "//tensorflow/contrib/framework:framework_py", diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD new file mode 100644 index 00000000000..cdad3e6e348 --- /dev/null +++ b/tensorflow/contrib/eager/python/BUILD @@ -0,0 +1,31 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +py_library( + name = "tfe", + srcs = ["tfe.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:core", + "//tensorflow/python/eager:custom_gradient", + "//tensorflow/python/eager:function", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + "g3doc/sitemap.md", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py new file mode 100644 index 00000000000..1a8086cd510 --- /dev/null +++ b/tensorflow/contrib/eager/python/tfe.py @@ -0,0 +1,67 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow Eager execution prototype. + +EXPERIMENTAL: APIs here are unstable and likely to change without notice. + +To use, at program startup, call `tfe.enable_eager_execution()`. + +@@list_devices +@@device + + +@@defun +@@implicit_gradients +@@implicit_value_and_gradients +@@gradients_function +@@value_and_gradients_function + +@@enable_tracing +@@flush_trace + +@@run +@@enable_eager_execution + +@@custom_gradient +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +# pylint:disable=g-bad-import-order,g-import-not-at-top,unused-import +# +from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.eager import backprop +from tensorflow.python.eager.custom_gradient import custom_gradient +from tensorflow.python.eager import function +from tensorflow.python.eager.context import context +from tensorflow.python.eager.context import device +from tensorflow.python.eager.context import enable_eager_execution +from tensorflow.python.eager.context import run +from tensorflow.python.eager.core import enable_tracing + + +def list_devices(): + return context().devices() + +defun = function.defun +implicit_gradients = backprop.implicit_grad +implicit_value_and_gradients = backprop.implicit_val_and_grad +gradients_function = backprop.gradients_function +value_and_gradients_function = backprop.val_and_grad_function + +remove_undocumented(__name__)