Switch ops in functional_ops, math_ops, state_ops, string_ops,
contrib/quantization, and contrib/layers to use C++ shape inference functions. Implement Betainc C++ shape inference function; it's a little different from the python one, mainly because propagating an unknown input (a_shape in this case) when all are unknown is not correct for C++ (in case we later backwards bind the value), since it could end up being a broadcasted scalar. Change ReduceJoin's C++ shape inference function to match python. Fix SymbolicGradient's shape function in C++ - there are actually more inputs than outputs. Changes QuantizedBiasAdd to be more precise (the C++ shape fn invokes the common bias_add to compute the shape of the first output). Change: 132688147
This commit is contained in:
parent
b7541f67c2
commit
29faf36ed2
@ -17,10 +17,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import load_library
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import resource_loader
|
||||
|
||||
@ -87,13 +87,5 @@ def sparse_feature_cross(inputs, hashed_output=False, num_buckets=0,
|
||||
return ops.SparseTensor(indices_out, values_out, shape_out)
|
||||
|
||||
|
||||
@ops.RegisterShape("SparseFeatureCross")
|
||||
def _SparseFeatureCrossShape(unused_op): # pylint: disable=invalid-name
|
||||
return [
|
||||
tensor_shape.matrix(None, 2),
|
||||
tensor_shape.vector(None),
|
||||
tensor_shape.vector(2)
|
||||
]
|
||||
|
||||
|
||||
ops.RegisterShape("SparseFeatureCross")(common_shapes.call_cpp_shape_fn)
|
||||
ops.NoGradient("SparseFeatureCross")
|
||||
|
@ -23,16 +23,6 @@ from tensorflow.contrib.quantization.ops import gen_math_ops
|
||||
from tensorflow.contrib.quantization.ops.gen_math_ops import *
|
||||
from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
|
||||
|
||||
# QuantizedMatMul* ops.
|
||||
@ops.RegisterShape("QuantizedMatMul")
|
||||
def _QuantizedMatMulShape(op):
|
||||
unused_a_min = op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
|
||||
unused_a_max = op.inputs[3].get_shape().merge_with(tensor_shape.scalar())
|
||||
unused_b_min = op.inputs[4].get_shape().merge_with(tensor_shape.scalar())
|
||||
unused_b_max = op.inputs[5].get_shape().merge_with(tensor_shape.scalar())
|
||||
result = common_shapes.matmul_shape(op)
|
||||
result.extend([tensor_shape.scalar(), tensor_shape.scalar()])
|
||||
return result
|
||||
ops.RegisterShape("QuantizedMatMul")(common_shapes.call_cpp_shape_fn)
|
||||
|
@ -23,60 +23,12 @@ from tensorflow.contrib.quantization.ops import gen_nn_ops
|
||||
from tensorflow.contrib.quantization.ops.gen_nn_ops import *
|
||||
from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
|
||||
|
||||
# QuantizedAvgPool* ops.
|
||||
@ops.RegisterShape("QuantizedAvgPool")
|
||||
def _QuantizedAvgPoolShape(op):
|
||||
return [common_shapes.avg_pool_shape(op)[0], tensor_shape.scalar(),
|
||||
tensor_shape.scalar()]
|
||||
|
||||
|
||||
# QuantizedBiasAdd op.
|
||||
@ops.RegisterShape("QuantizedBiasAdd")
|
||||
def _QuantizedBiasAddShape(op):
|
||||
"""Returns the same shape as the input, plus min and max scalar values.
|
||||
|
||||
Args:
|
||||
op: Input operation.
|
||||
Returns:
|
||||
Shape of ops first input, plus min and max tensors.
|
||||
"""
|
||||
unused_input_min = op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
|
||||
unused_input_max = op.inputs[3].get_shape().merge_with(tensor_shape.scalar())
|
||||
unused_bias_min = op.inputs[4].get_shape().merge_with(tensor_shape.scalar())
|
||||
unused_bias_max = op.inputs[5].get_shape().merge_with(tensor_shape.scalar())
|
||||
return [op.inputs[0].get_shape(), tensor_shape.scalar(),
|
||||
tensor_shape.scalar()]
|
||||
|
||||
|
||||
# QuantizedConv2D* ops.
|
||||
@ops.RegisterShape("QuantizedConv2D")
|
||||
def _QuantizedConv2DShape(op):
|
||||
"""Returns the same shape as Conv2D, plus min and max scalar values.
|
||||
|
||||
Args:
|
||||
op: Input operation.
|
||||
Returns:
|
||||
Shape of float Conv2D, plus min and max tensors.
|
||||
"""
|
||||
unused_input_min = op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
|
||||
unused_input_max = op.inputs[3].get_shape().merge_with(tensor_shape.scalar())
|
||||
unused_filter_min = op.inputs[4].get_shape().merge_with(tensor_shape.scalar())
|
||||
unused_filter_max = op.inputs[5].get_shape().merge_with(tensor_shape.scalar())
|
||||
result = common_shapes.conv2d_shape(op)
|
||||
result.extend([tensor_shape.scalar(), tensor_shape.scalar()])
|
||||
return result
|
||||
|
||||
|
||||
# QuantizedMaxPool* ops.
|
||||
@ops.RegisterShape("QuantizedMaxPool")
|
||||
def _QuantizedMaxPoolShape(op):
|
||||
return [common_shapes.max_pool_shape(op)[0], tensor_shape.scalar(),
|
||||
tensor_shape.scalar()]
|
||||
|
||||
|
||||
ops.RegisterShape("QuantizedAvgPool")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("QuantizedBiasAdd")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("QuantizedConv2D")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("QuantizedMaxPool")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("QuantizedRelu")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("QuantizedRelu6")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("QuantizedReluX")(common_shapes.call_cpp_shape_fn)
|
||||
|
@ -657,6 +657,64 @@ Status ReductionShape(InferenceContext* c) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ReductionShapeForReduceJoin(InferenceContext* c) {
|
||||
ShapeHandle input = c->input(0);
|
||||
|
||||
ShapeHandle indices;
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
|
||||
|
||||
bool keep_dims;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));
|
||||
|
||||
const Tensor* reduction_indices_t = c->input_tensor(1);
|
||||
if (reduction_indices_t == nullptr || !c->RankKnown(input)) {
|
||||
// If we do not have the reduction values at runtime, or the
|
||||
// rank of the input, we don't know the output shape.
|
||||
return shape_inference::UnknownShape(c);
|
||||
}
|
||||
|
||||
const int32 input_rank = c->Rank(input);
|
||||
std::set<int32> true_indices;
|
||||
auto reduction_indices = reduction_indices_t->flat<int32>();
|
||||
for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
|
||||
int32 reduction_index = reduction_indices(i);
|
||||
if (reduction_index < -input_rank || reduction_index >= input_rank) {
|
||||
return errors::InvalidArgument("Invalid reduction dimension ",
|
||||
reduction_index, " for input with ",
|
||||
input_rank, " dimensions.");
|
||||
}
|
||||
|
||||
int32 wrapped_index = reduction_index;
|
||||
if (wrapped_index < 0) {
|
||||
wrapped_index += input_rank;
|
||||
}
|
||||
|
||||
if (!true_indices.insert(wrapped_index).second) {
|
||||
return errors::InvalidArgument("Duplicate reduction index ",
|
||||
wrapped_index);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<DimensionHandle> dims;
|
||||
bool reduce_all = (reduction_indices_t->NumElements() == 0);
|
||||
for (int i = 0; i < input_rank; ++i) {
|
||||
if (reduce_all || true_indices.count(i) > 0) {
|
||||
if (c->Value(c->Dim(input, i)) == 0) {
|
||||
return errors::InvalidArgument("Cannot reduce dimension ", i,
|
||||
" with size 0");
|
||||
}
|
||||
if (keep_dims) {
|
||||
dims.emplace_back(c->MakeDim(1));
|
||||
}
|
||||
} else {
|
||||
dims.emplace_back(c->Dim(input, i));
|
||||
}
|
||||
}
|
||||
|
||||
c->set_output(0, c->MakeShape(dims));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConcatShape(InferenceContext* c) {
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
|
||||
|
@ -178,6 +178,10 @@ Status UnknownShape(shape_inference::InferenceContext* c);
|
||||
// Shape function for reduction operations.
|
||||
Status ReductionShape(shape_inference::InferenceContext* c);
|
||||
|
||||
// Shape function for reduction operations where an empty reduction indices
|
||||
// vector means to reduce all.
|
||||
Status ReductionShapeForReduceJoin(shape_inference::InferenceContext* c);
|
||||
|
||||
// Shape function for concat operations.
|
||||
Status ConcatShape(shape_inference::InferenceContext* c);
|
||||
|
||||
|
@ -739,5 +739,64 @@ TEST(CommonShapeFnsTest, Reduce_ShapeFn) {
|
||||
INFER_OK(op, "[?,?,?];[2]", "[?,?,?]");
|
||||
}
|
||||
|
||||
TEST(CommonShapeFnsTest, ReduceWithEmptyReductionIndices_ShapeFn) {
|
||||
ShapeInferenceTestOp op("ReduceJoin");
|
||||
op.input_tensors.resize(2);
|
||||
|
||||
TF_ASSERT_OK(NodeDefBuilder("test", "ReduceJoin")
|
||||
.Input("input", 0, DT_STRING)
|
||||
.Input("reduction_indices", 1, DT_INT32)
|
||||
.Attr("keep_dims", false)
|
||||
.Finalize(&op.node_def));
|
||||
|
||||
// Reduction indices not available, so output is unknown.
|
||||
INFER_OK(op, "[2,4,5];[2]", "?");
|
||||
INFER_OK(op, "?;[2]", "?");
|
||||
|
||||
Tensor indices = test::AsTensor<int32>({1, 2});
|
||||
op.input_tensors[1] = &indices;
|
||||
|
||||
// Reduction indices available
|
||||
INFER_OK(op, "[2,4,5];[2]", "[d0_0]");
|
||||
|
||||
// Wrapped indices
|
||||
indices = test::AsTensor<int32>({-1, -2});
|
||||
op.input_tensors[1] = &indices;
|
||||
INFER_OK(op, "[2,4,5];[2]", "[d0_0]");
|
||||
|
||||
// Scalar
|
||||
indices = test::AsScalar<int32>(0);
|
||||
op.input_tensors[1] = &indices;
|
||||
INFER_OK(op, "[2,4,5];[]", "[d0_1,d0_2]");
|
||||
|
||||
indices = test::AsScalar<int32>(-4);
|
||||
op.input_tensors[1] = &indices;
|
||||
INFER_ERROR("Invalid reduction dimension", op, "[2,4,5];[]");
|
||||
|
||||
// Empty reduction indices. Unlike Reduce_ShapeFn, this reduces all dims away.
|
||||
indices = test::AsTensor<int32>({});
|
||||
op.input_tensors[1] = &indices;
|
||||
INFER_OK(op, "[2,4,5];[0]", "[]");
|
||||
|
||||
// Keep dims = true
|
||||
TF_ASSERT_OK(NodeDefBuilder("test", op.name)
|
||||
.Input("input", 0, DT_STRING)
|
||||
.Input("reduction_indices", 1, DT_INT32)
|
||||
.Attr("keep_dims", true)
|
||||
.Finalize(&op.node_def));
|
||||
indices = test::AsTensor<int32>({-1, -2});
|
||||
op.input_tensors[1] = &indices;
|
||||
INFER_OK(op, "[2,4,5];[2]", "[d0_0, 1, 1]");
|
||||
|
||||
// input rank is known, but reduction indices are not (with keep_dim=true).
|
||||
// The output rank is unknown because reduction indices could end up being
|
||||
// empty and cause it all to be reduced.
|
||||
op.input_tensors[1] = nullptr;
|
||||
INFER_OK(op, "[?,?,?];?", "?");
|
||||
// TODO(cwhipkey): in this case, it could output [?,?,?], because the shape of
|
||||
// reduction indices is known to be non-empty.
|
||||
INFER_OK(op, "[?,?,?];[2]", "?");
|
||||
}
|
||||
|
||||
} // namespace shape_inference
|
||||
} // namespace tensorflow
|
||||
|
@ -29,13 +29,13 @@ REGISTER_OP("SymbolicGradient")
|
||||
.Attr("Tout: list(type)")
|
||||
.Attr("f: func")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
if (c->num_inputs() != c->num_outputs()) {
|
||||
return errors::InvalidArgument("len(inputs) != len(outputs)");
|
||||
if (c->num_inputs() < c->num_outputs()) {
|
||||
return errors::InvalidArgument("len(inputs) < len(outputs)");
|
||||
}
|
||||
// Say, (u, v) = f(x, y, z), _symbolic_gradient(f) is a function of
|
||||
// (x, y, z, du, dv) -> (dx, dy, dz). Therefore, shapes of its
|
||||
// outputs (dx, dy, dz) are the same as (x, y, z).
|
||||
for (int i = 0; i < c->num_inputs(); ++i) {
|
||||
for (int i = 0; i < c->num_outputs(); ++i) {
|
||||
c->set_output(i, c->input(i));
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -24,22 +24,27 @@ namespace tensorflow {
|
||||
|
||||
TEST(FunctionalOpsTest, SymbolicGradient_ShapeFn) {
|
||||
ShapeInferenceTestOp op("SymbolicGradient");
|
||||
int n = 4;
|
||||
int num_inputs = 4;
|
||||
int num_outputs = 3;
|
||||
std::vector<NodeDefBuilder::NodeOut> src_list;
|
||||
std::vector<DataType> type_list;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
type_list.emplace_back(DT_FLOAT);
|
||||
std::vector<DataType> in_type_list;
|
||||
std::vector<DataType> out_type_list;
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
in_type_list.emplace_back(DT_FLOAT);
|
||||
src_list.emplace_back("a", 0, DT_FLOAT);
|
||||
}
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
out_type_list.emplace_back(DT_FLOAT);
|
||||
}
|
||||
TF_ASSERT_OK(NodeDefBuilder("test", "SymbolicGradient")
|
||||
.Input(src_list)
|
||||
.Attr("Tin", type_list)
|
||||
.Attr("Tout", type_list)
|
||||
.Attr("Tin", in_type_list)
|
||||
.Attr("Tout", out_type_list)
|
||||
.Finalize(&op.node_def));
|
||||
|
||||
// Inputs transferred to outputs.
|
||||
INFER_OK(op, "?;?;?;?", "in0;in1;in2;in3");
|
||||
INFER_OK(op, "[];[2];?;?", "in0;in1;in2;in3");
|
||||
INFER_OK(op, "?;?;?;?", "in0;in1;in2");
|
||||
INFER_OK(op, "[];[2];?;?", "in0;in1;in2");
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -714,6 +714,38 @@ REGISTER_OP("Betainc")
|
||||
.Input("x: T")
|
||||
.Output("z: T")
|
||||
.Attr("T: {float, double}")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
const int num_inputs = 3;
|
||||
ShapeHandle output = c->UnknownShape();
|
||||
int num_scalars = 0;
|
||||
ShapeHandle some_non_scalar;
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
ShapeHandle in = c->input(i);
|
||||
if (!c->RankKnown(in)) {
|
||||
some_non_scalar = in;
|
||||
// An input with unknown rank could be either a scalar (to be
|
||||
// broadcast) or some other shape.
|
||||
} else if (c->Rank(in) == 0) {
|
||||
// Input is a scalar, it will be broadcast to the output shape.
|
||||
++num_scalars;
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(c->Merge(output, in, &output));
|
||||
some_non_scalar = output;
|
||||
}
|
||||
}
|
||||
|
||||
if (num_scalars == num_inputs - 1) {
|
||||
// If all but one input is known to be a scalar, then output is the
|
||||
// remaining input.
|
||||
output = some_non_scalar;
|
||||
} else if (num_scalars == num_inputs) {
|
||||
// If all are scalars, output is scalar; pick the first one arbitrarily.
|
||||
output = c->input(0);
|
||||
}
|
||||
|
||||
c->set_output(0, output);
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Compute the regularized incomplete beta integral \\(I_x(a, b)\\).
|
||||
|
||||
|
@ -431,4 +431,27 @@ TEST(MathOpsTest, ArgOps_ShapeFn) {
|
||||
INFER_ERROR("must be in the range [0, 3)", op, "[2,3,4];[]");
|
||||
}
|
||||
|
||||
TEST(MathOpsTest, Betainc_ShapeFn) {
|
||||
ShapeInferenceTestOp op("Betainc");
|
||||
|
||||
INFER_OK(op, "?;?;?", "?");
|
||||
INFER_OK(op, "[?,?];?;?", "in0");
|
||||
INFER_OK(op, "[?,2];?;[1,?]", "[d2_0,d0_1]");
|
||||
INFER_OK(op, "[?,2,?];[1,?,?];[?,?,3]", "[d1_0,d0_1,d2_2]");
|
||||
|
||||
INFER_OK(op, "[?,2,?];[];[?,?,3]", "[d0_0|d2_0,d0_1,d2_2]");
|
||||
INFER_OK(op, "[];[];[?,?,3]", "in2");
|
||||
|
||||
// All but one is a scalar, so use it.
|
||||
INFER_OK(op, "[];[];?", "in2");
|
||||
INFER_OK(op, "[];[];[1,2,3,4]", "in2");
|
||||
|
||||
// All scalar input; implementation picks in0.
|
||||
INFER_OK(op, "[];[];[]", "in0");
|
||||
|
||||
// Non-scalars must match shape.
|
||||
INFER_ERROR("must be equal", op, "[1,2];[];[1,4]");
|
||||
INFER_ERROR("must be equal", op, "[1,2];[];[1,2,3]");
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -95,7 +95,7 @@ REGISTER_OP("ReduceJoin")
|
||||
.Attr("keep_dims: bool = false")
|
||||
.Attr("separator: string = ''")
|
||||
.Output("output: string")
|
||||
.SetShapeFn(shape_inference::ReductionShape)
|
||||
.SetShapeFn(shape_inference::ReductionShapeForReduceJoin)
|
||||
.Doc(R"doc(
|
||||
Joins a string Tensor across the given dimensions.
|
||||
|
||||
|
@ -75,7 +75,7 @@ class BetaincTest(tf.test.TestCase):
|
||||
special.betainc(0.1, 0.1, 0.1).astype(np_dt),
|
||||
tf.betainc(0.1, 0.1, 0.1).eval(), rtol=tol, atol=tol)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "Shapes .* are not compatible"):
|
||||
with self.assertRaisesRegexp(ValueError, "must be equal"):
|
||||
tf.betainc(0.5, [0.5], [[0.5]])
|
||||
|
||||
with self.test_session(use_gpu=self.use_gpu):
|
||||
|
@ -258,7 +258,7 @@ class ReduceJoinTest(UnicodeTestCase):
|
||||
|
||||
def testInvalidReductionIndices(self):
|
||||
with self.test_session():
|
||||
with self.assertRaisesRegexp(ValueError, "scalar"):
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid reduction dim"):
|
||||
tf.reduce_join(inputs="", reduction_indices=0)
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"Invalid reduction dimension -3"):
|
||||
|
@ -30,6 +30,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
@ -576,9 +577,4 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
|
||||
return output_pack(results_flat)
|
||||
|
||||
|
||||
@ops.RegisterShape("SymbolicGradient")
|
||||
def _symbolic_gradient_shape(op):
|
||||
# Say, (u, v) = f(x, y, z), _symbolic_gradient(f) is a function of
|
||||
# (x, y, z, du, dv) -> (dx, dy, dz). Therefore, shapes of its
|
||||
# outputs (dx, dy, dz) are the same as (x, y, z).
|
||||
return [op.inputs[i].get_shape() for i in range(len(op.outputs))]
|
||||
ops.RegisterShape("SymbolicGradient")(common_shapes.call_cpp_shape_fn)
|
||||
|
@ -1842,21 +1842,7 @@ def _BroadcastShape(op):
|
||||
op.inputs[1].get_shape())]
|
||||
|
||||
|
||||
@ops.RegisterShape("Betainc")
|
||||
def _BetaincOpShape(op): # pylint: disable=invalid-name
|
||||
"""Shape function for BetaincOp."""
|
||||
a_shape = op.inputs[0].get_shape()
|
||||
b_shape = op.inputs[1].get_shape()
|
||||
x_shape = op.inputs[2].get_shape()
|
||||
merged_shape = tensor_shape.TensorShape(None)
|
||||
for shape in (a_shape, b_shape, x_shape):
|
||||
if shape.ndims != 0:
|
||||
merged_shape = merged_shape.merge_with(shape)
|
||||
# Scalars get broadcasted; non-scalar shapes must all match.
|
||||
# Output will be the merged non-scalar shape, if any.
|
||||
return [merged_shape if merged_shape.ndims is not None else a_shape]
|
||||
|
||||
|
||||
ops.RegisterShape("Betainc")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("SparseDenseCwiseMul")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("SparseDenseCwiseDiv")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("SparseDenseCwiseAdd")(common_shapes.call_cpp_shape_fn)
|
||||
|
@ -147,8 +147,7 @@ def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
|
||||
|
||||
|
||||
# NOTE(mrry): Shapes are conditionally set in the Python wrapper.
|
||||
ops.RegisterShape("Variable")(common_shapes.unknown_shape)
|
||||
|
||||
ops.RegisterShape("Variable")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("IsVariableInitialized")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("TemporaryVariable")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("DestroyTemporaryVariable")(common_shapes.call_cpp_shape_fn)
|
||||
|
@ -51,8 +51,6 @@ import six
|
||||
from tensorflow.python.framework import common_shapes
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.ops import gen_string_ops
|
||||
@ -133,47 +131,7 @@ ops.RegisterShape("DecodeBase64")(common_shapes.call_cpp_shape_fn)
|
||||
|
||||
@ops.RegisterShape("ReduceJoin")
|
||||
def _ReduceJoinShape(op):
|
||||
"""Shape function for the ReduceJoin op."""
|
||||
reduction_indices = tensor_util.constant_value(op.inputs[1])
|
||||
if reduction_indices is None:
|
||||
return [tensor_shape.unknown_shape()]
|
||||
|
||||
input_shape = op.inputs[0].get_shape()
|
||||
keep_dims = op.get_attr("keep_dims")
|
||||
|
||||
if input_shape.ndims is None:
|
||||
return [tensor_shape.unknown_shape()]
|
||||
|
||||
if input_shape.ndims == 0:
|
||||
raise ValueError("Input string tensor cannot be a scalar.")
|
||||
|
||||
true_indices = set()
|
||||
for reduction_index in np.ravel(reduction_indices):
|
||||
if (reduction_index < -input_shape.ndims or
|
||||
reduction_index >= input_shape.ndims):
|
||||
raise ValueError("Invalid reduction dimension %d for input with %d "
|
||||
"dimensions" % (reduction_index, input_shape.ndims))
|
||||
|
||||
true_index = reduction_index % input_shape.ndims
|
||||
if true_index in true_indices:
|
||||
raise ValueError("Duplicate reduction index %d." % reduction_index)
|
||||
|
||||
if input_shape.dims[true_index] == 0:
|
||||
raise ValueError("Cannot reduce dimension %d with size 0." %
|
||||
reduction_index)
|
||||
|
||||
true_indices.add(true_index)
|
||||
|
||||
returned_dims = []
|
||||
reduce_all = reduction_indices.size == 0
|
||||
for i, dim in enumerate(input_shape.dims):
|
||||
if reduce_all or i in true_indices:
|
||||
if keep_dims:
|
||||
returned_dims.append(1)
|
||||
else:
|
||||
returned_dims.append(dim)
|
||||
|
||||
return [tensor_shape.TensorShape(returned_dims)]
|
||||
return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
|
||||
|
||||
|
||||
ops.RegisterShape("StringJoin")(common_shapes.call_cpp_shape_fn)
|
||||
|
Loading…
Reference in New Issue
Block a user