Move the simple_estimator_example to the examples/ dir.

PiperOrigin-RevId: 191017560
This commit is contained in:
Anjali Sridhar 2018-03-29 18:15:34 -07:00 committed by TensorFlower Gardener
parent a8a95bf470
commit ac39aec50f
4 changed files with 39 additions and 44 deletions

View File

@ -162,6 +162,7 @@ tensorflow/contrib/decision_trees/proto
tensorflow/contrib/deprecated
tensorflow/contrib/distribute
tensorflow/contrib/distribute/python
tensorflow/contrib/distribute/python/examples
tensorflow/contrib/distributions
tensorflow/contrib/distributions/python
tensorflow/contrib/distributions/python/ops

View File

@ -340,21 +340,6 @@ py_test(
],
)
py_binary(
name = "simple_estimator_example",
srcs = ["simple_estimator_example.py"],
deps = [
":mirrored_strategy",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:layers",
"//tensorflow/python:training",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/estimator:model_fn",
],
)
py_library(
name = "cross_tower_utils",
srcs = ["cross_tower_utils.py"],

View File

@ -0,0 +1,19 @@
# Example TensorFlow models that use DistributionStrategy for training.
package(
default_visibility = [
"//tensorflow:internal",
],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
py_binary(
name = "simple_estimator_example",
srcs = ["simple_estimator_example.py"],
deps = [
"//tensorflow:tensorflow_py",
],
)

View File

@ -20,63 +20,53 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.layers import core
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import app
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import training_util
import tensorflow as tf
def build_model_fn_optimizer():
"""Simple model_fn with optimizer."""
# TODO(anjalisridhar): Move this inside the model_fn once OptimizerV2 is
# done?
optimizer = gradient_descent.GradientDescentOptimizer(0.2)
optimizer = tf.train.GradientDescentOptimizer(0.2)
def model_fn(features, labels, mode): # pylint: disable=unused-argument
"""model_fn which uses a single unit Dense layer."""
# You can also use the Flatten layer if you want to test a model without any
# weights.
layer = core.Dense(1, use_bias=True)
layer = tf.layers.Dense(1, use_bias=True)
logits = layer(features)
if mode == model_fn_lib.ModeKeys.PREDICT:
if mode == tf.estimator.ModeKeys.PREDICT:
predictions = {"logits": logits}
return model_fn_lib.EstimatorSpec(mode, predictions=predictions)
return tf.estimator.EstimatorSpec(mode, predictions=predictions)
def loss_fn():
y = array_ops.reshape(logits, []) - constant_op.constant(1.)
y = tf.reshape(logits, []) - tf.constant(1.)
return y * y
if mode == model_fn_lib.ModeKeys.EVAL:
return model_fn_lib.EstimatorSpec(mode, loss=loss_fn())
if mode == tf.estimator.ModeKeys.EVAL:
return tf.estimator.EstimatorSpec(mode, loss=loss_fn())
assert mode == model_fn_lib.ModeKeys.TRAIN
assert mode == tf.estimator.ModeKeys.TRAIN
global_step = training_util.get_global_step()
global_step = tf.train.get_global_step()
train_op = optimizer.minimize(loss_fn(), global_step=global_step)
return model_fn_lib.EstimatorSpec(mode, loss=loss_fn(), train_op=train_op)
return tf.estimator.EstimatorSpec(mode, loss=loss_fn(), train_op=train_op)
return model_fn
def main(_):
distribution = mirrored_strategy.MirroredStrategy(
distribution = tf.contrib.distribute.MirroredStrategy(
["/device:GPU:0", "/device:GPU:1"])
config = run_config.RunConfig(distribute=distribution)
config = tf.estimator.RunConfig(distribute=distribution)
def input_fn():
features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
labels = dataset_ops.Dataset.from_tensors([1.]).repeat(10)
return dataset_ops.Dataset.zip((features, labels))
features = tf.data.Dataset.from_tensors([[1.]]).repeat(10)
labels = tf.data.Dataset.from_tensors([1.]).repeat(10)
return tf.data.Dataset.zip((features, labels))
estimator = estimator_lib.Estimator(
estimator = tf.estimator.Estimator(
model_fn=build_model_fn_optimizer(), config=config)
estimator.train(input_fn=input_fn, steps=10)
@ -84,7 +74,7 @@ def main(_):
print("Eval result: {}".format(eval_result))
def predict_input_fn():
predict_features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
predict_features = tf.data.Dataset.from_tensors([[1.]]).repeat(10)
return predict_features
predictions = estimator.predict(input_fn=predict_input_fn)
@ -94,4 +84,4 @@ def main(_):
if __name__ == "__main__":
app.run(main)
tf.app.run()