Move the simple_estimator_example to the examples/ dir.
PiperOrigin-RevId: 191017560
This commit is contained in:
parent
a8a95bf470
commit
ac39aec50f
@ -162,6 +162,7 @@ tensorflow/contrib/decision_trees/proto
|
||||
tensorflow/contrib/deprecated
|
||||
tensorflow/contrib/distribute
|
||||
tensorflow/contrib/distribute/python
|
||||
tensorflow/contrib/distribute/python/examples
|
||||
tensorflow/contrib/distributions
|
||||
tensorflow/contrib/distributions/python
|
||||
tensorflow/contrib/distributions/python/ops
|
||||
|
@ -340,21 +340,6 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "simple_estimator_example",
|
||||
srcs = ["simple_estimator_example.py"],
|
||||
deps = [
|
||||
":mirrored_strategy",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:layers",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/estimator:estimator_py",
|
||||
"//tensorflow/python/estimator:model_fn",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "cross_tower_utils",
|
||||
srcs = ["cross_tower_utils.py"],
|
||||
|
19
tensorflow/contrib/distribute/python/examples/BUILD
Normal file
19
tensorflow/contrib/distribute/python/examples/BUILD
Normal file
@ -0,0 +1,19 @@
|
||||
# Example TensorFlow models that use DistributionStrategy for training.
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
)
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_binary(
|
||||
name = "simple_estimator_example",
|
||||
srcs = ["simple_estimator_example.py"],
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
@ -20,63 +20,53 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.distribute.python import mirrored_strategy
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.estimator import estimator as estimator_lib
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.estimator import run_config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.layers import core
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.training import gradient_descent
|
||||
from tensorflow.python.training import training_util
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def build_model_fn_optimizer():
|
||||
"""Simple model_fn with optimizer."""
|
||||
# TODO(anjalisridhar): Move this inside the model_fn once OptimizerV2 is
|
||||
# done?
|
||||
optimizer = gradient_descent.GradientDescentOptimizer(0.2)
|
||||
optimizer = tf.train.GradientDescentOptimizer(0.2)
|
||||
|
||||
def model_fn(features, labels, mode): # pylint: disable=unused-argument
|
||||
"""model_fn which uses a single unit Dense layer."""
|
||||
# You can also use the Flatten layer if you want to test a model without any
|
||||
# weights.
|
||||
layer = core.Dense(1, use_bias=True)
|
||||
layer = tf.layers.Dense(1, use_bias=True)
|
||||
logits = layer(features)
|
||||
|
||||
if mode == model_fn_lib.ModeKeys.PREDICT:
|
||||
if mode == tf.estimator.ModeKeys.PREDICT:
|
||||
predictions = {"logits": logits}
|
||||
return model_fn_lib.EstimatorSpec(mode, predictions=predictions)
|
||||
return tf.estimator.EstimatorSpec(mode, predictions=predictions)
|
||||
|
||||
def loss_fn():
|
||||
y = array_ops.reshape(logits, []) - constant_op.constant(1.)
|
||||
y = tf.reshape(logits, []) - tf.constant(1.)
|
||||
return y * y
|
||||
|
||||
if mode == model_fn_lib.ModeKeys.EVAL:
|
||||
return model_fn_lib.EstimatorSpec(mode, loss=loss_fn())
|
||||
if mode == tf.estimator.ModeKeys.EVAL:
|
||||
return tf.estimator.EstimatorSpec(mode, loss=loss_fn())
|
||||
|
||||
assert mode == model_fn_lib.ModeKeys.TRAIN
|
||||
assert mode == tf.estimator.ModeKeys.TRAIN
|
||||
|
||||
global_step = training_util.get_global_step()
|
||||
global_step = tf.train.get_global_step()
|
||||
train_op = optimizer.minimize(loss_fn(), global_step=global_step)
|
||||
return model_fn_lib.EstimatorSpec(mode, loss=loss_fn(), train_op=train_op)
|
||||
return tf.estimator.EstimatorSpec(mode, loss=loss_fn(), train_op=train_op)
|
||||
|
||||
return model_fn
|
||||
|
||||
|
||||
def main(_):
|
||||
distribution = mirrored_strategy.MirroredStrategy(
|
||||
distribution = tf.contrib.distribute.MirroredStrategy(
|
||||
["/device:GPU:0", "/device:GPU:1"])
|
||||
config = run_config.RunConfig(distribute=distribution)
|
||||
config = tf.estimator.RunConfig(distribute=distribution)
|
||||
|
||||
def input_fn():
|
||||
features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
|
||||
labels = dataset_ops.Dataset.from_tensors([1.]).repeat(10)
|
||||
return dataset_ops.Dataset.zip((features, labels))
|
||||
features = tf.data.Dataset.from_tensors([[1.]]).repeat(10)
|
||||
labels = tf.data.Dataset.from_tensors([1.]).repeat(10)
|
||||
return tf.data.Dataset.zip((features, labels))
|
||||
|
||||
estimator = estimator_lib.Estimator(
|
||||
estimator = tf.estimator.Estimator(
|
||||
model_fn=build_model_fn_optimizer(), config=config)
|
||||
estimator.train(input_fn=input_fn, steps=10)
|
||||
|
||||
@ -84,7 +74,7 @@ def main(_):
|
||||
print("Eval result: {}".format(eval_result))
|
||||
|
||||
def predict_input_fn():
|
||||
predict_features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
|
||||
predict_features = tf.data.Dataset.from_tensors([[1.]]).repeat(10)
|
||||
return predict_features
|
||||
|
||||
predictions = estimator.predict(input_fn=predict_input_fn)
|
||||
@ -94,4 +84,4 @@ def main(_):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
tf.app.run()
|
Loading…
x
Reference in New Issue
Block a user