Allow one tensor to be the input to the estimator.
PiperOrigin-RevId: 163747076
This commit is contained in:
parent
104f349e9e
commit
1333e77450
@ -68,6 +68,9 @@ def model_builder(features, labels, mode, params, config):
|
|||||||
|
|
||||||
center_bias = params["center_bias"]
|
center_bias = params["center_bias"]
|
||||||
|
|
||||||
|
if isinstance(features, ops.Tensor):
|
||||||
|
features = {features.name: features}
|
||||||
|
|
||||||
# Make a shallow copy of features to ensure downstream usage
|
# Make a shallow copy of features to ensure downstream usage
|
||||||
# is unaffected by modifications in the model function.
|
# is unaffected by modifications in the model function.
|
||||||
training_features = copy.copy(features)
|
training_features = copy.copy(features)
|
||||||
|
@ -142,7 +142,7 @@ def extract_features(features, feature_columns):
|
|||||||
"""Extracts columns from a dictionary of features.
|
"""Extracts columns from a dictionary of features.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
features: `Tensor` or `dict` of `Tensor` objects.
|
features: `dict` of `Tensor` objects.
|
||||||
feature_columns: A list of feature_columns.
|
feature_columns: A list of feature_columns.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -161,9 +161,6 @@ def extract_features(features, feature_columns):
|
|||||||
if not features:
|
if not features:
|
||||||
raise ValueError("Features dictionary must be specified.")
|
raise ValueError("Features dictionary must be specified.")
|
||||||
|
|
||||||
if isinstance(features, ops.Tensor):
|
|
||||||
features = {features.name, features}
|
|
||||||
|
|
||||||
# Make a shallow copy of features to ensure downstream usage
|
# Make a shallow copy of features to ensure downstream usage
|
||||||
# is unaffected by modifications in the model function.
|
# is unaffected by modifications in the model function.
|
||||||
features = copy.copy(features)
|
features = copy.copy(features)
|
||||||
@ -277,7 +274,7 @@ class GradientBoostedDecisionTreeModel(object):
|
|||||||
examples based on the depth of the layer that's being built.
|
examples based on the depth of the layer that's being built.
|
||||||
learner_config: A learner config.
|
learner_config: A learner config.
|
||||||
print split, sorted_feature_names[split.feature_column]
|
print split, sorted_feature_names[split.feature_column]
|
||||||
features: `Tensor` or `dict` of `Tensor` objects.
|
features: `dict` of `Tensor` objects.
|
||||||
feature_columns: A list of feature columns.
|
feature_columns: A list of feature columns.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user