Move TensorForestEstimator to contrib, since that's where most of its code is and it will not be considered a canned estimator in the near future.
Change: 143989623
This commit is contained in:
parent
c71ac2dce6
commit
7ad7e4dfae
@ -32,10 +32,6 @@ py_library(
|
||||
"//tensorflow/contrib/rnn:rnn_py",
|
||||
"//tensorflow/contrib/session_bundle:exporter",
|
||||
"//tensorflow/contrib/session_bundle:gc",
|
||||
"//tensorflow/contrib/tensor_forest:client_lib",
|
||||
"//tensorflow/contrib/tensor_forest:data_ops_py",
|
||||
"//tensorflow/contrib/tensor_forest:eval_metrics",
|
||||
"//tensorflow/contrib/tensor_forest:tensor_forest_py",
|
||||
"//tensorflow/contrib/training:training_py",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
@ -674,21 +670,6 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "random_forest_test",
|
||||
size = "medium",
|
||||
srcs = ["python/learn/estimators/random_forest_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":learn",
|
||||
"//tensorflow/contrib/learn/python/learn/datasets",
|
||||
"//tensorflow/contrib/tensor_forest:tensor_forest_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "dynamic_rnn_estimator_test",
|
||||
size = "medium",
|
||||
|
@ -322,8 +322,6 @@ from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import
|
||||
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
|
||||
from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModeKeys
|
||||
from tensorflow.contrib.learn.python.learn.estimators.prediction_key import PredictionKey
|
||||
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook
|
||||
from tensorflow.contrib.learn.python.learn.estimators.run_config import ClusterConfig
|
||||
from tensorflow.contrib.learn.python.learn.estimators.run_config import Environment
|
||||
from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig
|
||||
|
@ -121,6 +121,7 @@ py_library(
|
||||
":constants",
|
||||
":data_ops_py",
|
||||
":eval_metrics",
|
||||
":random_forest",
|
||||
":tensor_forest_ops_py",
|
||||
":tensor_forest_py",
|
||||
],
|
||||
@ -395,3 +396,34 @@ py_test(
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "random_forest",
|
||||
srcs = ["client/random_forest.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":client_lib",
|
||||
":data_ops_py",
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:state_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "random_forest_test",
|
||||
size = "medium",
|
||||
srcs = ["client/random_forest_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":random_forest",
|
||||
":tensor_forest_py",
|
||||
"//tensorflow/contrib/learn/python/learn/datasets",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
@ -19,4 +19,5 @@ from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.contrib.tensor_forest.client import eval_metrics
|
||||
from tensorflow.contrib.tensor_forest.client import random_forest
|
||||
# pylint: enable=unused-import
|
||||
|
@ -18,7 +18,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib import framework as contrib_framework
|
||||
from tensorflow.contrib.framework import deprecated_arg_values
|
||||
from tensorflow.contrib.framework.python.framework import experimental
|
||||
from tensorflow.contrib.learn.python.learn import evaluable
|
||||
from tensorflow.contrib.learn.python.learn import trainable
|
@ -28,7 +28,7 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.datasets import base
|
||||
from tensorflow.contrib.learn.python.learn.estimators import random_forest
|
||||
from tensorflow.contrib.tensor_forest.client import random_forest
|
||||
from tensorflow.contrib.tensor_forest.python import tensor_forest
|
||||
from tensorflow.python.platform import test
|
||||
|
@ -21,25 +21,24 @@ import argparse
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
# pylint: disable=g-backslash-continuation
|
||||
from tensorflow.contrib.learn.python.learn\
|
||||
import metric_spec
|
||||
from tensorflow.contrib.learn.python.learn.estimators\
|
||||
import random_forest
|
||||
from tensorflow.contrib.tensor_forest.client\
|
||||
import eval_metrics
|
||||
from tensorflow.contrib.tensor_forest.client\
|
||||
import random_forest
|
||||
from tensorflow.contrib.tensor_forest.python\
|
||||
import tensor_forest
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
from tensorflow.python.platform import app
|
||||
|
||||
FLAGS = None
|
||||
|
||||
|
||||
def build_estimator(model_dir):
|
||||
"""Build an estimator."""
|
||||
params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
|
||||
params = tensor_forest.ForestHParams(
|
||||
num_classes=10, num_features=784,
|
||||
num_trees=FLAGS.num_trees, max_nodes=FLAGS.max_nodes)
|
||||
graph_builder_class = tensor_forest.RandomForestGraphs
|
||||
@ -129,4 +128,4 @@ if __name__ == '__main__':
|
||||
help='If true, use training loss as termination criteria.'
|
||||
)
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
Loading…
x
Reference in New Issue
Block a user