Delegate to C++ shape functions for python ops in training_ops and io_ops.
Change: 132345086
This commit is contained in:
parent
8cec0eebd1
commit
7946b8481f
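
The pattern applied throughout this change deletes each hand-written Python shape function and registers the generic delegator common_shapes.call_cpp_shape_fn in its place, so shape inference runs the C++ shape function already declared with the op. A condensed sketch of the before/after registration, using the ReadFile op from the io_ops hunk below (illustrative only; registering it standalone would collide with the library's own registration):

from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import ops

# Old form (deleted in this change): a hand-written Python shape function.
#
# @ops.RegisterShape("ReadFile")
# def _ReadFileShape(op):
#   """Shape function for the ReadFile op."""
#   return [op.inputs[0].get_shape().merge_with(tensor_shape.scalar())]

# New form: delegate to the C++ shape function declared with the op.
ops.RegisterShape("ReadFile")(common_shapes.call_cpp_shape_fn)
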
@@ -38,10 +38,11 @@ static Status HandleGradAndIndicesInputs(InferenceContext* c, bool sparse,
   DimensionHandle unused;
   TF_RETURN_IF_ERROR(c->Merge(c->Dim(indices, 0), c->Dim(grad, 0), &unused));

-  // Trailing part of grad matches *s.
-  ShapeHandle grad_subshape;
-  TF_RETURN_IF_ERROR(c->Subshape(grad, 1, &grad_subshape));
-  TF_RETURN_IF_ERROR(c->Merge(*s, grad_subshape, s));
+  // Trailing part of grad matches trailing part of *s.
+  ShapeHandle grad_unknown_first;
+  TF_RETURN_IF_ERROR(
+      c->ReplaceDim(grad, 0, c->UnknownDim(), &grad_unknown_first));
+  TF_RETURN_IF_ERROR(c->Merge(*s, grad_unknown_first, s));

   return Status::OK();
 }
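
This hunk changes what HandleGradAndIndicesInputs requires of a sparse gradient. Previously grad's trailing part (everything after the leading indices dimension) was merged against all of *s, so grad needed one more dimension than the variable; now grad's leading dimension is replaced with an unknown and the whole shapes are merged, so grad must have the same rank as the variable, with only the trailing dimensions constrained. A toy Python illustration of the two merge behaviors (plain Python, not the real C++ InferenceContext API; None stands for an unknown dimension):

# Toy shape merge mirroring InferenceContext::Merge semantics.
def merge(a, b):
  """Merges two shapes, requiring equal rank and compatible dimensions."""
  assert len(a) == len(b), "Shapes must be equal rank"
  merged = []
  for x, y in zip(a, b):
    assert x is None or y is None or x == y, "Dimensions must be equal"
    merged.append(y if x is None else x)
  return merged

var = [1, None]    # *s: the variable's shape.
grad = [None, 2]   # Sparse gradient: one row per index.

# Old: merge(var, grad[1:]) -- grad's trailing part had to match *all* of
# *s, so grad needed rank(var) + 1 and this pair would fail on rank.
# New: wildcard grad's leading (index count) dimension, then merge whole
# shapes, so grad has the same rank as var.
grad_unknown_first = [None] + grad[1:]
print(merge(var, grad_unknown_first))  # [1, 2]
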
@@ -30,15 +30,14 @@ static void TestGradAndIndicesErrorHandling(ShapeInferenceTestOp op,
                            grad_indices_spec, shape_spec_end);
   };

-  // mismatch between grad[1] and var[0].
-  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              shape_spec("[1]", "[?,2];[?]").c_str());
+  // mismatch between grad[1] and var[1].
+  INFER_ERROR("Dimension 1 in both shapes must be equal", op,
+              shape_spec("[?,1]", "[?,2];[?]").c_str());
   // grad[0] and indices[0] must match.
   INFER_ERROR("Dimensions must be equal, but are 1 and 2", op,
               shape_spec("?", "[2,?];[1]").c_str());
   // grad is wrong rank.
-  INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op,
-              shape_spec("[1]", "[2];[?]").c_str());
+  INFER_ERROR("must be equal rank", op, shape_spec("[1]", "[?,2];[?]").c_str());
   // indices is wrong rank.
   INFER_ERROR("Shape must be rank 1 but is rank 2", op,
               shape_spec("[?]", "[?];[1,2]").c_str());
@@ -74,7 +73,7 @@ TEST(TrainingOpsTest, SparseApplyProximalGradientDescent_ShapeFn) {
   ShapeInferenceTestOp op("SparseApplyProximalGradientDescent");

   // Output is a merge of inputs 0 (var) and the non-indices part of 4 (delta).
-  INFER_OK(op, "[1,?];[];[];[];[?,?,2];[3]", "[d0_0,d4_2]");
+  INFER_OK(op, "[1,?];[];[];[];[?,2];[3]", "[d0_0,d4_1]");

   TestGradAndIndicesErrorHandling(op, "[];[];[]");

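The test updates track that semantic change: the old expectation of a rank-3 delta for a rank-2 var becomes a rank-2 delta. In the INFER_OK spec strings, input shapes are separated by ';', '?' marks an unknown dimension or shape, and the expected output "[d0_0,d4_1]" names each output dimension by input and index (dimension 0 of input 0, dimension 1 of input 4). A small illustrative parser for the input-spec syntax (a hypothetical helper, not part of the test framework):

# Hypothetical helper mirroring the INFER_OK input-spec syntax.
def parse_shapes(spec):
  shapes = []
  for part in spec.split(";"):
    if part == "?":
      shapes.append(None)  # Wholly unknown shape.
    elif part == "[]":
      shapes.append([])    # Scalar.
    else:
      dims = part.strip("[]").split(",")
      shapes.append([None if d == "?" else int(d) for d in dims])
  return shapes

print(parse_shapes("[1,?];[];[];[];[?,2];[3]"))
# [[1, None], [], [], [], [None, 2], [3]]
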
@@ -109,14 +108,14 @@ TEST(TrainingOpsTest, SparseApplyAdadelta_ShapeFn) {

   // Output is a merge of inputs 0, 1, 2, and non-indices part of 6 (var, accum,
   // accum_update, grad).
-  INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[];[];[];[?,?,?,?,4];?",
-           "[d0_0,d1_1,d2_2,d6_4]");
+  INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[];[];[];[?,?,?,4];?",
+           "[d0_0,d1_1,d2_2,d6_3]");
   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
               "[1];[2];[1];[];[];[];[1];?");
   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
               "[1];[1];[2];[];[];[];[1];?");
-  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[1];[1];[];[];[];[?,2];?");
+  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
+              "[?,1];[?,1];[?,1];[];[];[];[?,2];?");

   TestGradAndIndicesErrorHandling(op, "?;?;?;?;?");

@@ -145,11 +144,11 @@ TEST(TrainingOpsTest, SparseApplyAdagrad_ShapeFn) {

   // Output is a merge of inputs 0, 1, and non-indices part of 3 (var, accum,
   // grad).
-  INFER_OK(op, "[1,?,?];[?,2,?];[];[?,?,?,3];?", "[d0_0,d1_1,d3_3]");
-  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[2];[];[1];?");
-  INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op,
-              "[1];[1];[];[2];?");
+  INFER_OK(op, "[1,?,?];[?,2,?];[];[?,?,3];?", "[d0_0,d1_1,d3_2]");
+  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
+              "[?,1];[?,2];[];[?,1];?");
+  INFER_ERROR("Shapes must be equal rank, but are 2 and 3", op,
+              "[?,1];[?,1];[];[?,?,2];?");

   TestGradAndIndicesErrorHandling(op, "?;?");

@@ -178,11 +177,11 @@ TEST(TrainingOpsTest, SparseApplyProximalAdagrad_ShapeFn) {

   // Output is a merge of inputs 0, 1, and the non-indices part of 5 (var,
   // accum, grad).
-  INFER_OK(op, "[1,?,?];[?,2,?];[];[];[];[?,?,?,3];?", "[d0_0,d1_1,d5_3]");
+  INFER_OK(op, "[1,?,?];[?,2,?];[];[];[];[?,?,3];?", "[d0_0,d1_1,d5_2]");
   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
               "[1];[2];[];[];[];[?,1];?");
-  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[1];[];[];[];[?,2];?");
+  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
+              "[?,1];[?,1];[];[];[];[?,2];?");

   TestGradAndIndicesErrorHandling(op, "?;?;?;?");

@@ -217,14 +216,14 @@ TEST(TrainingOpsTest, SparseApplyFtrl_ShapeFn) {

   // Output is a merge of inputs 0, 1, 2, and non-indices part of 3 (var, accum,
   // linear, grad).
-  INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[?,?,?,?,4];?;[];[];[];[]",
-           "[d0_0,d1_1,d2_2,d3_4]");
+  INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[?,?,?,4];?;[];[];[];[]",
+           "[d0_0,d1_1,d2_2,d3_3]");
   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
               "[1];[2];[1];[?,1];?;[];[];[];[]");
   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
               "[1];[1];[2];[?,1];?;[];[];[];[]");
-  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[1];[1];[?,2];?;[];[];[];[]");
+  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
+              "[?,1];[?,1];[?,1];[?,2];?;[];[];[];[]");

   TestGradAndIndicesErrorHandling(op, "?;?", ";?;?;?;?");

@@ -255,11 +254,11 @@ TEST(TrainingOpsTest, SparseApplyMomentum_ShapeFn) {

   // Output is a merge of inputs 0, 1, and non-indices part of 3 (var, accum,
   // grad).
-  INFER_OK(op, "[1,?,?];[?,2,?];[];[?,?,?,3];?;[]", "[d0_0,d1_1,d3_3]");
-  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[2];[];[?,1];?;[]");
-  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[1];[];[?,2];?;[]");
+  INFER_OK(op, "[1,?,?];[?,2,?];[];[?,?,3];?;[]", "[d0_0,d1_1,d3_2]");
+  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
+              "[?,1];[?,2];[];[?,1];?;[]");
+  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
+              "[?,1];[?,1];[];[?,2];?;[]");

   TestGradAndIndicesErrorHandling(op, "?;?", ";?");

@@ -316,14 +315,14 @@ TEST(TrainingOpsTest, SparseApplyRMSProp_ShapeFn) {

   // Output is a merge of inputs 0, 1, 2, and the non-indices part of 7 (var,
   // ms, mom, and grad).
-  INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[];[];[];[];[?,?,?,?,4];?",
-           "[d0_0,d1_1,d2_2,d7_4]");
+  INFER_OK(op, "[1,?,?,?];[?,2,?,?];[?,?,3,?];[];[];[];[];[?,?,?,4];?",
+           "[d0_0,d1_1,d2_2,d7_3]");
   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[2];[1];[];[];[];[];[?,1];?");
+              "[1];[2];[1];[];[];[];[];[1];?");
   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[1];[2];[];[];[];[];[?,1];?");
-  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 1 and 2", op,
-              "[1];[1];[1];[];[];[];[];[?,2];?");
+              "[1];[1];[2];[];[];[];[];[1];?");
+  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 1 and 2", op,
+              "[?,1];[?,1];[?,1];[];[];[];[];[?,2];?");

   TestGradAndIndicesErrorHandling(op, "?;?;?;?;?;?");

@@ -138,7 +138,6 @@ from __future__ import print_function
 from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
 from tensorflow.python.lib.io import python_io
 from tensorflow.python.ops import gen_io_ops
 # go/tf-wildcard-import
@@ -205,80 +204,12 @@ def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
                                    preferred_shard, name=name)


-@ops.RegisterShape("Restore")
-def _RestoreShape(op):
-  """Shape function for Restore op."""
-  # Validate input shapes.
-  unused_file_pattern = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_tensor_name = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.unknown_shape()]
-
-
-@ops.RegisterShape("RestoreSlice")
-def _RestoreSliceShape(op):
-  """Shape function for RestoreSlice op."""
-  # Validate input shapes.
-  unused_file_pattern = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_tensor_name = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_shape_and_slice_shape = op.inputs[2].get_shape().merge_with(
-      tensor_shape.scalar())
-  # TODO(mrry): Attempt to parse the shape_and_slice value and use it
-  # to form the shape of the output.
-  return [tensor_shape.unknown_shape()]
-
-
-@ops.RegisterShape("Save")
-def _SaveShape(op):
-  """Shape function for Save op."""
-  # Validate input shapes.
-  unused_filename = op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
-  data_count = len(op.inputs) - 2
-  unused_tensor_names_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.vector(data_count))
-  return []
-
-
-@ops.RegisterShape("SaveSlices")
-def _SaveSlicesShape(op):
-  """Shape function for SaveSlices op."""
-  # Validate input shapes.
-  unused_filename = op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
-  data_count = len(op.inputs) - 3
-  unused_tensor_names_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.vector(data_count))
-  unused_shapes_and_slices_shape = op.inputs[2].get_shape().merge_with(
-      tensor_shape.vector(data_count))
-  # TODO(mrry): Attempt to parse the shapes_and_slices values and use
-  # them to constrain the shape of the remaining inputs.
-  return []
-
-
-@ops.RegisterShape("ShardedFilename")
-def _ShardedFilenameShape(op):
-  """Shape function for ShardedFilename op."""
-  # Validate input shapes.
-  unused_basename_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_shard_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_num_shards_shape = op.inputs[2].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.scalar()]
-
-
-@ops.RegisterShape("ShardedFilespec")
-def _ShardedFilespecShape(op):
-  """Shape function for ShardedFilespec op."""
-  # Validate input shapes.
-  unused_basename_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_num_shards_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.scalar()]
+ops.RegisterShape("Restore")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("RestoreSlice")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Save")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SaveSlices")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ShardedFilename")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ShardedFilespec")(common_shapes.call_cpp_shape_fn)


 class ReaderBase(object):
@@ -574,61 +505,13 @@ ops.RegisterShape("TextLineReader")(common_shapes.scalar_shape)
 ops.RegisterShape("WholeFileReader")(common_shapes.scalar_shape)
 ops.RegisterShape("TFRecordReader")(common_shapes.scalar_shape)

-
-@ops.RegisterShape("ReaderNumRecordsProduced")
-@ops.RegisterShape("ReaderNumWorkUnitsCompleted")
-@ops.RegisterShape("ReaderSerializeState")
-def _ReaderScalarShape(op):
-  """Shape function for ops that transform a reader to a scalar."""
-  unused_handle_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.scalar()]
-
-
-@ops.RegisterShape("ReaderRead")
-def _ReaderReadShape(op):
-  """Shape function for the ReaderBase.Read op."""
-  unused_handle_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_queue_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.scalar(), tensor_shape.scalar()]
-
-
-@ops.RegisterShape("ReaderReadUpTo")
-def _ReaderReadUpToShape(_):
-  """Shape function for the ReaderBase.ReadUpTo op."""
-  return [tensor_shape.unknown_shape(ndims=1),
-          tensor_shape.unknown_shape(ndims=1)]
-
-
-@ops.RegisterShape("ReaderReset")
-def _ReaderResetShape(op):
-  """Shape function for the ReaderBase.Reset op."""
-  unused_handle_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  return []
-
-
-@ops.RegisterShape("ReaderRestoreState")
-def _ReaderRestoreStateShape(op):
-  """Shape function for the ReaderBase.Restore op."""
-  unused_handle_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  unused_state_shape = op.inputs[1].get_shape().merge_with(
-      tensor_shape.scalar())
-  return []
-
-
-@ops.RegisterShape("ReadFile")
-def _ReadFileShape(op):
-  """Shape function for the ReadFile op."""
-  return [op.inputs[0].get_shape().merge_with(tensor_shape.scalar())]
-
-
-@ops.RegisterShape("MatchingFiles")
-def _MatchingFilesShape(op):
-  """Shape function for the MatchingFiles op."""
-  unused_patern_shape = op.inputs[0].get_shape().merge_with(
-      tensor_shape.scalar())
-  return [tensor_shape.unknown_shape(ndims=1)]
+ops.RegisterShape("ReaderNumRecordsProduced")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderNumWorkUnitsCompleted")(
+    common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderSerializeState")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderRead")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderReadUpTo")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderReset")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReaderRestoreState")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReadFile")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MatchingFiles")(common_shapes.call_cpp_shape_fn)
@@ -19,6 +19,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.training import gen_training_ops
@@ -48,246 +49,23 @@ def _AssertInputIsScalar(op, index):
   op.inputs[index].get_shape().assert_is_compatible_with(tensor_shape.scalar())


-@ops.RegisterShape("ApplyAdadelta")
-def _ApplyAdadeltaShape(op):
-  """Shape function for the ApplyAdadelta op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  accum_update_shape = op.inputs[2].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 3)  # lr
-  _AssertInputIsScalar(op, 4)  # rho
-  _AssertInputIsScalar(op, 5)  # epsilon
-  grad_shape = op.inputs[6].get_shape().merge_with(accum_shape)
-  return [grad_shape]
-
-@ops.RegisterShape("ApplyAdagrad")
-def _ApplyAdagradShape(op):
-  """Shape function for the ApplyAdagrad op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 2)  # lr
-  grad_shape = op.inputs[3].get_shape().merge_with(accum_shape)
-  return [grad_shape]
-
-@ops.RegisterShape("ApplyProximalAdagrad")
-def _ApplyProximalAdagradShape(op):
-  """Shape function for the ApplyProximalAdagrad op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 2)  # lr
-  _AssertInputIsScalar(op, 3)  # l1
-  _AssertInputIsScalar(op, 4)  # l2
-  grad_shape = op.inputs[5].get_shape().merge_with(accum_shape)
-  return [grad_shape]
-
-
-@ops.RegisterShape("ApplyFtrl")
-def _ApplyFtrlShape(op):
-  """Shape function for the ApplyFtrlOp op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  linear_shape = op.inputs[2].get_shape().merge_with(accum_shape)
-  grad_shape = op.inputs[3].get_shape().merge_with(linear_shape)
-  _AssertInputIsScalar(op, 4)  # lr
-  _AssertInputIsScalar(op, 5)  # l1
-  _AssertInputIsScalar(op, 6)  # l2
-  _AssertInputIsScalar(op, 7)  # lr_power
-  return [grad_shape]
-
-
-@ops.RegisterShape("ApplyAdagradDA")
-def ApplyAdagradDAShape(op):
-  """Shape function for the ApplyAdagradDAOp op."""
-  var_shape = op.inputs[0].get_shape()
-  g_accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  gg_accum_shape = op.inputs[2].get_shape().merge_with(g_accum_shape)
-  grad_shape = op.inputs[3].get_shape().merge_with(gg_accum_shape)
-  _AssertInputIsScalar(op, 4)  # lr
-  _AssertInputIsScalar(op, 5)  # l1
-  _AssertInputIsScalar(op, 6)  # l2
-  _AssertInputIsScalar(op, 7)  # global_step
-  return [grad_shape]
-
-
-@ops.RegisterShape("ApplyAdam")
-def _ApplyAdamShape(op):
-  """Shape function for the ApplyAdam op."""
-  var_shape = op.inputs[0].get_shape()
-  m_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  v_shape = op.inputs[2].get_shape().merge_with(m_shape)
-  _AssertInputIsScalar(op, 3)  # beta1_power
-  _AssertInputIsScalar(op, 4)  # beta2_power
-  _AssertInputIsScalar(op, 5)  # lr
-  _AssertInputIsScalar(op, 6)  # beta1
-  _AssertInputIsScalar(op, 7)  # beta2
-  _AssertInputIsScalar(op, 8)  # epsilon
-  grad_shape = op.inputs[9].get_shape().merge_with(v_shape)
-  return [grad_shape]
-
-
-@ops.RegisterShape("ApplyMomentum")
-def _ApplyMomentumShape(op):
-  """Shape function for the ApplyMomentum op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 2)  # lr
-  grad_shape = op.inputs[3].get_shape().merge_with(accum_shape)
-  _AssertInputIsScalar(op, 4)  # momentum
-  return [grad_shape]
-
-
-@ops.RegisterShape("ApplyRMSProp")
-def _ApplyRMSPropShape(op):
-  """Shape function for the ApplyRMSProp op."""
-  var_shape = op.inputs[0].get_shape()
-  ms_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  mom_shape = op.inputs[2].get_shape().merge_with(ms_shape)
-  _AssertInputIsScalar(op, 3)  # lr
-  _AssertInputIsScalar(op, 4)  # rho
-  _AssertInputIsScalar(op, 5)  # momentum
-  _AssertInputIsScalar(op, 6)  # epsilon
-  grad_shape = op.inputs[7].get_shape().merge_with(mom_shape)
-  return [grad_shape]
-
-
-@ops.RegisterShape("ApplyGradientDescent")
-def _ApplyGradientDescentShape(op):
-  """Shape function for the ApplyGradientDescent op."""
-  var_shape = op.inputs[0].get_shape()
-  _AssertInputIsScalar(op, 1)  # alpha
-  delta_shape = op.inputs[2].get_shape().merge_with(var_shape)
-  return [delta_shape]
-
-
-@ops.RegisterShape("ApplyProximalGradientDescent")
-def _ApplyProximalGradientDescentShape(op):
-  """Shape function for the ApplyProximalGradientDescent op."""
-  var_shape = op.inputs[0].get_shape()
-  _AssertInputIsScalar(op, 1)  # alpha
-  _AssertInputIsScalar(op, 2)  # l1
-  _AssertInputIsScalar(op, 3)  # l2
-  delta_shape = op.inputs[4].get_shape().merge_with(var_shape)
-  return [delta_shape]
-
-
-@ops.RegisterShape("SparseApplyProximalGradientDescent")
-def _SparseApplyProximalGradientDescentShape(op):
-  """Shape function for the SparseApplyGradientDescent op."""
-  var_shape = op.inputs[0].get_shape()
-  _AssertInputIsScalar(op, 1)  # lr
-  _AssertInputIsScalar(op, 2)  # l1
-  _AssertInputIsScalar(op, 3)  # l2
-  grad_shape = op.inputs[4].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(var_shape[1:]))
-  unused_indices_shape = op.inputs[5].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  return [var_shape]
-
-
-@ops.RegisterShape("SparseApplyRMSProp")
-def _SparseApplyRMSPropShape(op):
-  """Shape function for the SparseApplyRMSProp op."""
-  var_shape = op.inputs[0].get_shape()
-  ms_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  mom_shape = op.inputs[2].get_shape().merge_with(ms_shape)
-  _AssertInputIsScalar(op, 3)  # lr
-  _AssertInputIsScalar(op, 4)  # rho
-  _AssertInputIsScalar(op, 5)  # momentum
-  _AssertInputIsScalar(op, 6)  # epsilon
-  grad_shape = op.inputs[7].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(mom_shape[1:]))
-  unused_indices_shape = op.inputs[8].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  return [mom_shape]
-
-
-@ops.RegisterShape("SparseApplyAdadelta")
-def _SparseApplyAdadeltaShape(op):
-  """Shape function for the SparseApplyAdadelta op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_grad_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  accum_update_shape = op.inputs[2].get_shape().merge_with(accum_grad_shape)
-  _AssertInputIsScalar(op, 3)  # lr
-  _AssertInputIsScalar(op, 4)  # decay_rate
-  _AssertInputIsScalar(op, 5)  # epsilon
-  grad_shape = op.inputs[6].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(accum_update_shape[1:]))
-  unused_indices_shape = op.inputs[7].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  return [accum_update_shape]
-
-
-@ops.RegisterShape("SparseApplyAdagrad")
-def _SparseApplyAdagradShape(op):
-  """Shape function for the SparseApplyAdagrad op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 2)  # lr
-  grad_shape = op.inputs[3].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
-  unused_indices_shape = op.inputs[4].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  return [accum_shape]
-
-
-@ops.RegisterShape("SparseApplyProximalAdagrad")
-def _SparseApplyProximalAdagradShape(op):
-  """Shape function for the SparseApplyProximalAdagrad op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 2)  # lr
-  _AssertInputIsScalar(op, 3)  # l1
-  _AssertInputIsScalar(op, 4)  # l2
-  grad_shape = op.inputs[5].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
-  unused_indices_shape = op.inputs[6].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  return [accum_shape]
-
-
-@ops.RegisterShape("SparseApplyFtrl")
-def _SparseApplyFtrlShape(op):
-  """Shape function for the SparseApplyFtrl op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  linear_shape = op.inputs[2].get_shape().merge_with(accum_shape)
-  grad_shape = op.inputs[3].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(linear_shape[1:]))
-  unused_indices_shape = op.inputs[4].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  _AssertInputIsScalar(op, 5)  # lr
-  _AssertInputIsScalar(op, 6)  # l1
-  _AssertInputIsScalar(op, 7)  # l2
-  _AssertInputIsScalar(op, 8)  # lr_power
-  return [linear_shape]
-
-
-@ops.RegisterShape("SparseApplyAdagradDA")
-def _SparseApplyAdagradDAShape(op):
-  """Shape function for the SparseApplyAdagradDA op."""
-  var_shape = op.inputs[0].get_shape()
-  g_accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  gg_accum_shape = op.inputs[2].get_shape().merge_with(g_accum_shape)
-  grad_shape = op.inputs[3].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(gg_accum_shape[1:]))
-  unused_indices_shape = op.inputs[4].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  _AssertInputIsScalar(op, 5)  # lr
-  _AssertInputIsScalar(op, 6)  # l1
-  _AssertInputIsScalar(op, 7)  # l2
-  _AssertInputIsScalar(op, 8)  # global_step
-  return [gg_accum_shape]
-
-
-@ops.RegisterShape("SparseApplyMomentum")
-def _SparseApplyMomentumShape(op):
-  """Shape function for the SparseApplyMomentum op."""
-  var_shape = op.inputs[0].get_shape()
-  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
-  _AssertInputIsScalar(op, 2)  # lr
-  grad_shape = op.inputs[3].get_shape().merge_with(
-      tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
-  unused_indices_shape = op.inputs[4].get_shape().merge_with(
-      tensor_shape.vector(grad_shape[0]))
-  _AssertInputIsScalar(op, 5)  # momentum
-  return [accum_shape]
+ops.RegisterShape("ApplyAdadelta")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyAdagrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyProximalAdagrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyFtrl")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyAdagradDA")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyAdam")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyMomentum")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyRMSProp")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyGradientDescent")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ApplyProximalGradientDescent")(
+    common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseApplyProximalGradientDescent")(
+    common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseApplyRMSProp")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseApplyAdadelta")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseApplyAdagrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseApplyProximalAdagrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseApplyFtrl")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseApplyAdagradDA")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseApplyMomentum")(common_shapes.call_cpp_shape_fn)