Move TensorForestEstimator to contrib, since that's where most of its code is and it will not be considered a canned estimator in the near future.

Change: 143989623
This commit is contained in:
A. Unique TensorFlower 2017-01-09 11:54:28 -08:00 committed by TensorFlower Gardener
parent c71ac2dce6
commit 7ad7e4dfae
7 changed files with 39 additions and 29 deletions

View File

@ -32,10 +32,6 @@ py_library(
"//tensorflow/contrib/rnn:rnn_py",
"//tensorflow/contrib/session_bundle:exporter",
"//tensorflow/contrib/session_bundle:gc",
"//tensorflow/contrib/tensor_forest:client_lib",
"//tensorflow/contrib/tensor_forest:data_ops_py",
"//tensorflow/contrib/tensor_forest:eval_metrics",
"//tensorflow/contrib/tensor_forest:tensor_forest_py",
"//tensorflow/contrib/training:training_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
@ -674,21 +670,6 @@ py_test(
],
)
py_test(
name = "random_forest_test",
size = "medium",
srcs = ["python/learn/estimators/random_forest_test.py"],
srcs_version = "PY2AND3",
deps = [
":learn",
"//tensorflow/contrib/learn/python/learn/datasets",
"//tensorflow/contrib/tensor_forest:tensor_forest_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//third_party/py/numpy",
],
)
py_test(
name = "dynamic_rnn_estimator_test",
size = "medium",

View File

@ -322,8 +322,6 @@ from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModeKeys
from tensorflow.contrib.learn.python.learn.estimators.prediction_key import PredictionKey
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook
from tensorflow.contrib.learn.python.learn.estimators.run_config import ClusterConfig
from tensorflow.contrib.learn.python.learn.estimators.run_config import Environment
from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig

View File

@ -121,6 +121,7 @@ py_library(
":constants",
":data_ops_py",
":eval_metrics",
":random_forest",
":tensor_forest_ops_py",
":tensor_forest_py",
],
@ -395,3 +396,34 @@ py_test(
"//tensorflow/python:variables",
],
)
py_library(
name = "random_forest",
srcs = ["client/random_forest.py"],
srcs_version = "PY2AND3",
deps = [
":client_lib",
":data_ops_py",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/learn",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:state_ops",
],
)
py_test(
name = "random_forest_test",
size = "medium",
srcs = ["client/random_forest_test.py"],
srcs_version = "PY2AND3",
deps = [
":random_forest",
":tensor_forest_py",
"//tensorflow/contrib/learn/python/learn/datasets",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//third_party/py/numpy",
],
)

View File

@ -19,4 +19,5 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.client import random_forest
# pylint: enable=unused-import

View File

@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import trainable

View File

@ -28,7 +28,7 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
import numpy as np
from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.contrib.learn.python.learn.estimators import random_forest
from tensorflow.contrib.tensor_forest.client import random_forest
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.platform import test

View File

@ -21,25 +21,24 @@ import argparse
import sys
import tempfile
import tensorflow as tf
# pylint: disable=g-backslash-continuation
from tensorflow.contrib.learn.python.learn\
import metric_spec
from tensorflow.contrib.learn.python.learn.estimators\
import random_forest
from tensorflow.contrib.tensor_forest.client\
import eval_metrics
from tensorflow.contrib.tensor_forest.client\
import random_forest
from tensorflow.contrib.tensor_forest.python\
import tensor_forest
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.platform import app
FLAGS = None
def build_estimator(model_dir):
"""Build an estimator."""
params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
params = tensor_forest.ForestHParams(
num_classes=10, num_features=784,
num_trees=FLAGS.num_trees, max_nodes=FLAGS.max_nodes)
graph_builder_class = tensor_forest.RandomForestGraphs
@ -129,4 +128,4 @@ if __name__ == '__main__':
help='If true, use training loss as termination criteria.'
)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
app.run(main=main, argv=[sys.argv[0]] + unparsed)