diff --git a/tensorflow/contrib/layers/python/ops/sparse_feature_cross_op.py b/tensorflow/contrib/layers/python/ops/sparse_feature_cross_op.py
index fa279d9ae1e..528b85a02b5 100644
--- a/tensorflow/contrib/layers/python/ops/sparse_feature_cross_op.py
+++ b/tensorflow/contrib/layers/python/ops/sparse_feature_cross_op.py
@@ -17,10 +17,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import load_library
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import resource_loader
 
@@ -87,13 +87,5 @@ def sparse_feature_cross(inputs, hashed_output=False, num_buckets=0,
     return ops.SparseTensor(indices_out, values_out, shape_out)
 
 
-@ops.RegisterShape("SparseFeatureCross")
-def _SparseFeatureCrossShape(unused_op):  # pylint: disable=invalid-name
-  return [
-      tensor_shape.matrix(None, 2),
-      tensor_shape.vector(None),
-      tensor_shape.vector(2)
-  ]
-
-
+ops.RegisterShape("SparseFeatureCross")(common_shapes.call_cpp_shape_fn)
 ops.NoGradient("SparseFeatureCross")
diff --git a/tensorflow/contrib/quantization/python/math_ops.py b/tensorflow/contrib/quantization/python/math_ops.py
index 43c1409358c..d4fabbd36bd 100644
--- a/tensorflow/contrib/quantization/python/math_ops.py
+++ b/tensorflow/contrib/quantization/python/math_ops.py
@@ -23,16 +23,6 @@ from tensorflow.contrib.quantization.ops import gen_math_ops
 from tensorflow.contrib.quantization.ops.gen_math_ops import *
 from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-
-
-# QuantizedMatMul* ops.
-@ops.RegisterShape("QuantizedMatMul")
-def _QuantizedMatMulShape(op):
-  unused_a_min = op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
-  unused_a_max = op.inputs[3].get_shape().merge_with(tensor_shape.scalar())
-  unused_b_min = op.inputs[4].get_shape().merge_with(tensor_shape.scalar())
-  unused_b_max = op.inputs[5].get_shape().merge_with(tensor_shape.scalar())
-  result = common_shapes.matmul_shape(op)
-  result.extend([tensor_shape.scalar(), tensor_shape.scalar()])
-  return result
+ops.RegisterShape("QuantizedMatMul")(common_shapes.call_cpp_shape_fn)
diff --git a/tensorflow/contrib/quantization/python/nn_ops.py b/tensorflow/contrib/quantization/python/nn_ops.py
index 122d93fd236..d31f1d4e686 100644
--- a/tensorflow/contrib/quantization/python/nn_ops.py
+++ b/tensorflow/contrib/quantization/python/nn_ops.py
@@ -23,60 +23,12 @@ from tensorflow.contrib.quantization.ops import gen_nn_ops
 from tensorflow.contrib.quantization.ops.gen_nn_ops import *
 from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-
-
-# QuantizedAvgPool* ops.
-@ops.RegisterShape("QuantizedAvgPool")
-def _QuantizedAvgPoolShape(op):
-  return [common_shapes.avg_pool_shape(op)[0], tensor_shape.scalar(),
-          tensor_shape.scalar()]
-
-
-# QuantizedBiasAdd op.
-@ops.RegisterShape("QuantizedBiasAdd")
-def _QuantizedBiasAddShape(op):
-  """Returns the same shape as the input, plus min and max scalar values.
-
-  Args:
-    op: Input operation.
-  Returns:
-    Shape of ops first input, plus min and max tensors.
-  """
-  unused_input_min = op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
-  unused_input_max = op.inputs[3].get_shape().merge_with(tensor_shape.scalar())
-  unused_bias_min = op.inputs[4].get_shape().merge_with(tensor_shape.scalar())
-  unused_bias_max = op.inputs[5].get_shape().merge_with(tensor_shape.scalar())
-  return [op.inputs[0].get_shape(), tensor_shape.scalar(),
-          tensor_shape.scalar()]
-
-
-# QuantizedConv2D* ops.
-@ops.RegisterShape("QuantizedConv2D")
-def _QuantizedConv2DShape(op):
-  """Returns the same shape as Conv2D, plus min and max scalar values.
-
-  Args:
-    op: Input operation.
-  Returns:
-    Shape of float Conv2D, plus min and max tensors.
-  """
-  unused_input_min = op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
-  unused_input_max = op.inputs[3].get_shape().merge_with(tensor_shape.scalar())
-  unused_filter_min = op.inputs[4].get_shape().merge_with(tensor_shape.scalar())
-  unused_filter_max = op.inputs[5].get_shape().merge_with(tensor_shape.scalar())
-  result = common_shapes.conv2d_shape(op)
-  result.extend([tensor_shape.scalar(), tensor_shape.scalar()])
-  return result
-
-
-# QuantizedMaxPool* ops.
-@ops.RegisterShape("QuantizedMaxPool")
-def _QuantizedMaxPoolShape(op):
-  return [common_shapes.max_pool_shape(op)[0], tensor_shape.scalar(),
-          tensor_shape.scalar()]
+ops.RegisterShape("QuantizedAvgPool")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("QuantizedBiasAdd")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("QuantizedConv2D")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("QuantizedMaxPool")(common_shapes.call_cpp_shape_fn)
 
 ops.RegisterShape("QuantizedRelu")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("QuantizedRelu6")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("QuantizedReluX")(common_shapes.call_cpp_shape_fn)
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 8df470aa229..7b6a4eecccf 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -657,6 +657,64 @@ Status ReductionShape(InferenceContext* c) {
   return Status::OK();
 }
 
+Status ReductionShapeForReduceJoin(InferenceContext* c) {
+  ShapeHandle input = c->input(0);
+
+  ShapeHandle indices;
+  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
+
+  bool keep_dims;
+  TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));
+
+  const Tensor* reduction_indices_t = c->input_tensor(1);
+  if (reduction_indices_t == nullptr || !c->RankKnown(input)) {
+    // If we do not have the reduction values at runtime, or the
+    // rank of the input, we don't know the output shape.
+    return shape_inference::UnknownShape(c);
+  }
+
+  const int32 input_rank = c->Rank(input);
+  std::set<int64> true_indices;
+  auto reduction_indices = reduction_indices_t->flat<int32>();
+  for (int i = 0; i < reduction_indices_t->NumElements(); ++i) {
+    int32 reduction_index = reduction_indices(i);
+    if (reduction_index < -input_rank || reduction_index >= input_rank) {
+      return errors::InvalidArgument("Invalid reduction dimension ",
+                                     reduction_index, " for input with ",
+                                     input_rank, " dimensions.");
+    }
+
+    int32 wrapped_index = reduction_index;
+    if (wrapped_index < 0) {
+      wrapped_index += input_rank;
+    }
+
+    if (!true_indices.insert(wrapped_index).second) {
+      return errors::InvalidArgument("Duplicate reduction index ",
+                                     wrapped_index);
+    }
+  }
+
+  std::vector<DimensionHandle> dims;
+  bool reduce_all = (reduction_indices_t->NumElements() == 0);
+  for (int i = 0; i < input_rank; ++i) {
+    if (reduce_all || true_indices.count(i) > 0) {
+      if (c->Value(c->Dim(input, i)) == 0) {
+        return errors::InvalidArgument("Cannot reduce dimension ", i,
+                                       " with size 0");
+      }
+      if (keep_dims) {
+        dims.emplace_back(c->MakeDim(1));
+      }
+    } else {
+      dims.emplace_back(c->Dim(input, i));
+    }
+  }
+
+  c->set_output(0, c->MakeShape(dims));
+  return Status::OK();
+}
+
 Status ConcatShape(InferenceContext* c) {
   ShapeHandle unused;
   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index b828b23dfe7..b4dc1fb1547 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -178,6 +178,10 @@ Status UnknownShape(shape_inference::InferenceContext* c);
 // Shape function for reduction operations.
 Status ReductionShape(shape_inference::InferenceContext* c);
 
+// Shape function for reduction operations where an empty reduction indices
+// vector means to reduce all.
+Status ReductionShapeForReduceJoin(shape_inference::InferenceContext* c);
+
 // Shape function for concat operations.
 Status ConcatShape(shape_inference::InferenceContext* c);
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index 570ac28127a..334d673c787 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -739,5 +739,64 @@ TEST(CommonShapeFnsTest, Reduce_ShapeFn) {
   INFER_OK(op, "[?,?,?];[2]", "[?,?,?]");
 }
 
+TEST(CommonShapeFnsTest, ReduceWithEmptyReductionIndices_ShapeFn) {
+  ShapeInferenceTestOp op("ReduceJoin");
+  op.input_tensors.resize(2);
+
+  TF_ASSERT_OK(NodeDefBuilder("test", "ReduceJoin")
+                   .Input("input", 0, DT_STRING)
+                   .Input("reduction_indices", 1, DT_INT32)
+                   .Attr("keep_dims", false)
+                   .Finalize(&op.node_def));
+
+  // Reduction indices not available, so output is unknown.
+  INFER_OK(op, "[2,4,5];[2]", "?");
+  INFER_OK(op, "?;[2]", "?");
+
+  Tensor indices = test::AsTensor<int32>({1, 2});
+  op.input_tensors[1] = &indices;
+
+  // Reduction indices available.
+  INFER_OK(op, "[2,4,5];[2]", "[d0_0]");
+
+  // Wrapped indices.
+  indices = test::AsTensor<int32>({-1, -2});
+  op.input_tensors[1] = &indices;
+  INFER_OK(op, "[2,4,5];[2]", "[d0_0]");
+
+  // Scalar.
+  indices = test::AsScalar<int32>(0);
+  op.input_tensors[1] = &indices;
+  INFER_OK(op, "[2,4,5];[]", "[d0_1,d0_2]");
+
+  indices = test::AsScalar<int32>(-4);
+  op.input_tensors[1] = &indices;
+  INFER_ERROR("Invalid reduction dimension", op, "[2,4,5];[]");
+
+  // Empty reduction indices. Unlike Reduce_ShapeFn, this reduces all dims
+  // away.
+  indices = test::AsTensor<int32>({});
+  op.input_tensors[1] = &indices;
+  INFER_OK(op, "[2,4,5];[0]", "[]");
+
+  // Keep dims = true.
+  TF_ASSERT_OK(NodeDefBuilder("test", op.name)
+                   .Input("input", 0, DT_STRING)
+                   .Input("reduction_indices", 1, DT_INT32)
+                   .Attr("keep_dims", true)
+                   .Finalize(&op.node_def));
+  indices = test::AsTensor<int32>({-1, -2});
+  op.input_tensors[1] = &indices;
+  INFER_OK(op, "[2,4,5];[2]", "[d0_0, 1, 1]");
+
+  // Input rank is known, but reduction indices are not (with keep_dims=true).
+  // The output rank is unknown because reduction indices could end up being
+  // empty and cause it all to be reduced.
+  op.input_tensors[1] = nullptr;
+  INFER_OK(op, "[?,?,?];?", "?");
+  // TODO(cwhipkey): in this case, it could output [?,?,?], because the shape
+  // of reduction indices is known to be non-empty.
+  INFER_OK(op, "[?,?,?];[2]", "?");
+}
+
 }  // namespace shape_inference
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index 04d7ac53943..63b15204295 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -29,13 +29,13 @@ REGISTER_OP("SymbolicGradient")
     .Attr("Tout: list(type)")
     .Attr("f: func")
     .SetShapeFn([](InferenceContext* c) {
-      if (c->num_inputs() != c->num_outputs()) {
-        return errors::InvalidArgument("len(inputs) != len(outputs)");
+      if (c->num_inputs() < c->num_outputs()) {
+        return errors::InvalidArgument("len(inputs) < len(outputs)");
       }
       // Say, (u, v) = f(x, y, z), _symbolic_gradient(f) is a function of
       // (x, y, z, du, dv) -> (dx, dy, dz). Therefore, shapes of its
       // outputs (dx, dy, dz) are the same as (x, y, z).
-      for (int i = 0; i < c->num_inputs(); ++i) {
+      for (int i = 0; i < c->num_outputs(); ++i) {
         c->set_output(i, c->input(i));
       }
       return Status::OK();
diff --git a/tensorflow/core/ops/functional_ops_test.cc b/tensorflow/core/ops/functional_ops_test.cc
index 827c34c7b91..37ee301c3bd 100644
--- a/tensorflow/core/ops/functional_ops_test.cc
+++ b/tensorflow/core/ops/functional_ops_test.cc
@@ -24,22 +24,27 @@ namespace tensorflow {
 
 TEST(FunctionalOpsTest, SymbolicGradient_ShapeFn) {
   ShapeInferenceTestOp op("SymbolicGradient");
-  int n = 4;
+  int num_inputs = 4;
+  int num_outputs = 3;
   std::vector<NodeDefBuilder::NodeOut> src_list;
-  std::vector<DataType> type_list;
+  std::vector<DataType> in_type_list;
+  std::vector<DataType> out_type_list;
-  for (int i = 0; i < n; ++i) {
-    type_list.emplace_back(DT_FLOAT);
+  for (int i = 0; i < num_inputs; ++i) {
+    in_type_list.emplace_back(DT_FLOAT);
     src_list.emplace_back("a", 0, DT_FLOAT);
   }
+  for (int i = 0; i < num_outputs; ++i) {
+    out_type_list.emplace_back(DT_FLOAT);
+  }
 
   TF_ASSERT_OK(NodeDefBuilder("test", "SymbolicGradient")
                    .Input(src_list)
-                   .Attr("Tin", type_list)
-                   .Attr("Tout", type_list)
+                   .Attr("Tin", in_type_list)
+                   .Attr("Tout", out_type_list)
                    .Finalize(&op.node_def));
 
   // Inputs transferred to outputs.
-  INFER_OK(op, "?;?;?;?", "in0;in1;in2;in3");
-  INFER_OK(op, "[];[2];?;?", "in0;in1;in2;in3");
+  INFER_OK(op, "?;?;?;?", "in0;in1;in2");
+  INFER_OK(op, "[];[2];?;?", "in0;in1;in2");
 }
 
 }  // end namespace tensorflow
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index a4707007f85..e4e5f32561b 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -714,6 +714,38 @@ REGISTER_OP("Betainc")
     .Input("x: T")
     .Output("z: T")
     .Attr("T: {float, double}")
+    .SetShapeFn([](InferenceContext* c) {
+      const int num_inputs = 3;
+      ShapeHandle output = c->UnknownShape();
+      int num_scalars = 0;
+      ShapeHandle some_non_scalar;
+      for (int i = 0; i < num_inputs; ++i) {
+        ShapeHandle in = c->input(i);
+        if (!c->RankKnown(in)) {
+          some_non_scalar = in;
+          // An input with unknown rank could be either a scalar (to be
+          // broadcast) or some other shape.
+        } else if (c->Rank(in) == 0) {
+          // Input is a scalar, it will be broadcast to the output shape.
+          ++num_scalars;
+        } else {
+          TF_RETURN_IF_ERROR(c->Merge(output, in, &output));
+          some_non_scalar = output;
+        }
+      }
+
+      if (num_scalars == num_inputs - 1) {
+        // If all but one input is known to be a scalar, then output is the
+        // remaining input.
+        output = some_non_scalar;
+      } else if (num_scalars == num_inputs) {
+        // If all are scalars, output is scalar; pick the first one
+        // arbitrarily.
+        output = c->input(0);
+      }
+
+      c->set_output(0, output);
+      return Status::OK();
+    })
     .Doc(R"doc(
Compute the regularized incomplete beta integral \\(I_x(a, b)\\).
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index cb7b50262b0..75e4e304f08 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -431,4 +431,27 @@ TEST(MathOpsTest, ArgOps_ShapeFn) {
   INFER_ERROR("must be in the range [0, 3)", op, "[2,3,4];[]");
 }
 
+TEST(MathOpsTest, Betainc_ShapeFn) {
+  ShapeInferenceTestOp op("Betainc");
+
+  INFER_OK(op, "?;?;?", "?");
+  INFER_OK(op, "[?,?];?;?", "in0");
+  INFER_OK(op, "[?,2];?;[1,?]", "[d2_0,d0_1]");
+  INFER_OK(op, "[?,2,?];[1,?,?];[?,?,3]", "[d1_0,d0_1,d2_2]");
+
+  INFER_OK(op, "[?,2,?];[];[?,?,3]", "[d0_0|d2_0,d0_1,d2_2]");
+  INFER_OK(op, "[];[];[?,?,3]", "in2");
+
+  // All but one is a scalar, so use it.
+  INFER_OK(op, "[];[];?", "in2");
+  INFER_OK(op, "[];[];[1,2,3,4]", "in2");
+
+  // All scalar input; implementation picks in0.
+  INFER_OK(op, "[];[];[]", "in0");
+
+  // Non-scalars must match shape.
+  INFER_ERROR("must be equal", op, "[1,2];[];[1,4]");
+  INFER_ERROR("must be equal", op, "[1,2];[];[1,2,3]");
+}
+
 }  // end namespace tensorflow
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 248fb7cbbb3..fac5b20097d 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -95,7 +95,7 @@ REGISTER_OP("ReduceJoin")
     .Attr("keep_dims: bool = false")
    .Attr("separator: string = ''")
    .Output("output: string")
-    .SetShapeFn(shape_inference::ReductionShape)
+    .SetShapeFn(shape_inference::ReductionShapeForReduceJoin)
    .Doc(R"doc(
Joins a string Tensor across the given dimensions.
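
Note on the Betainc change above: with the shape function now implemented in C++, scalar arguments are broadcast against the non-scalar arguments, and mismatched non-scalar shapes are rejected at graph-construction time (hence the "must be equal" message the Python test below starts expecting). A minimal Python sketch of the user-visible behaviour, assuming the 2016-era TensorFlow Python API; the placeholder names are illustrative and not part of this patch:

    import tensorflow as tf

    a = tf.placeholder(tf.float32, shape=[2, 3])
    b = tf.constant(0.5)   # scalar, broadcast against the non-scalar inputs
    x = tf.placeholder(tf.float32, shape=[2, 3])

    z = tf.betainc(a, b, x)
    print(z.get_shape())   # (2, 3): the merged non-scalar shape

    # Non-scalar shapes must merge; with the C++ shape function this now
    # fails during graph construction with a "must be equal" error rather
    # than the old "Shapes ... are not compatible" message:
    # tf.betainc(tf.zeros([1, 2]), b, tf.zeros([1, 4]))
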
diff --git a/tensorflow/python/kernel_tests/betainc_op_test.py b/tensorflow/python/kernel_tests/betainc_op_test.py
index a78a1b934f1..d311f45a19e 100644
--- a/tensorflow/python/kernel_tests/betainc_op_test.py
+++ b/tensorflow/python/kernel_tests/betainc_op_test.py
@@ -75,7 +75,7 @@ class BetaincTest(tf.test.TestCase):
           special.betainc(0.1, 0.1, 0.1).astype(np_dt),
           tf.betainc(0.1, 0.1, 0.1).eval(), rtol=tol, atol=tol)
 
-      with self.assertRaisesRegexp(ValueError, "Shapes .* are not compatible"):
+      with self.assertRaisesRegexp(ValueError, "must be equal"):
         tf.betainc(0.5, [0.5], [[0.5]])
 
     with self.test_session(use_gpu=self.use_gpu):
diff --git a/tensorflow/python/kernel_tests/reduce_join_op_test.py b/tensorflow/python/kernel_tests/reduce_join_op_test.py
index c4ef55ef028..695d442382a 100644
--- a/tensorflow/python/kernel_tests/reduce_join_op_test.py
+++ b/tensorflow/python/kernel_tests/reduce_join_op_test.py
@@ -258,7 +258,7 @@ class ReduceJoinTest(UnicodeTestCase):
 
   def testInvalidReductionIndices(self):
     with self.test_session():
-      with self.assertRaisesRegexp(ValueError, "scalar"):
+      with self.assertRaisesRegexp(ValueError, "Invalid reduction dim"):
         tf.reduce_join(inputs="", reduction_indices=0)
       with self.assertRaisesRegexp(ValueError,
                                    "Invalid reduction dimension -3"):
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index a68c278c140..38b672f47cf 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -30,6 +30,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -576,9 +577,4 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
   return output_pack(results_flat)
 
 
-@ops.RegisterShape("SymbolicGradient")
-def _symbolic_gradient_shape(op):
-  # Say, (u, v) = f(x, y, z), _symbolic_gradient(f) is a function of
-  # (x, y, z, du, dv) -> (dx, dy, dz). Therefore, shapes of its
-  # outputs (dx, dy, dz) are the same as (x, y, z).
-  return [op.inputs[i].get_shape() for i in range(len(op.outputs))]
+ops.RegisterShape("SymbolicGradient")(common_shapes.call_cpp_shape_fn)
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 3941783c1eb..5516779fe71 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1842,21 +1842,7 @@ def _BroadcastShape(op):
           op.inputs[1].get_shape())]
 
 
-@ops.RegisterShape("Betainc")
-def _BetaincOpShape(op):  # pylint: disable=invalid-name
-  """Shape function for BetaincOp."""
-  a_shape = op.inputs[0].get_shape()
-  b_shape = op.inputs[1].get_shape()
-  x_shape = op.inputs[2].get_shape()
-  merged_shape = tensor_shape.TensorShape(None)
-  for shape in (a_shape, b_shape, x_shape):
-    if shape.ndims != 0:
-      merged_shape = merged_shape.merge_with(shape)
-  # Scalars get broadcasted; non-scalar shapes must all match.
-  # Output will be the merged non-scalar shape, if any.
-  return [merged_shape if merged_shape.ndims is not None else a_shape]
-
-
+ops.RegisterShape("Betainc")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("SparseDenseCwiseMul")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("SparseDenseCwiseDiv")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("SparseDenseCwiseAdd")(common_shapes.call_cpp_shape_fn)
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 1df379491d1..ad8463a30c4 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -147,8 +147,7 @@ def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
 
 
 # NOTE(mrry): Shapes are conditionally set in the Python wrapper.
-ops.RegisterShape("Variable")(common_shapes.unknown_shape)
-
+ops.RegisterShape("Variable")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("IsVariableInitialized")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("TemporaryVariable")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("DestroyTemporaryVariable")(common_shapes.call_cpp_shape_fn)
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index bd2cd3bd49c..bda736ee180 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -51,8 +51,6 @@ import six
 from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
 
 # pylint: disable=unused-import
 from tensorflow.python.ops import gen_string_ops
@@ -133,47 +131,7 @@ ops.RegisterShape("DecodeBase64")(common_shapes.call_cpp_shape_fn)
 
 @ops.RegisterShape("ReduceJoin")
 def _ReduceJoinShape(op):
-  """Shape function for the ReduceJoin op."""
-  reduction_indices = tensor_util.constant_value(op.inputs[1])
-  if reduction_indices is None:
-    return [tensor_shape.unknown_shape()]
-
-  input_shape = op.inputs[0].get_shape()
-  keep_dims = op.get_attr("keep_dims")
-
-  if input_shape.ndims is None:
-    return [tensor_shape.unknown_shape()]
-
-  if input_shape.ndims == 0:
-    raise ValueError("Input string tensor cannot be a scalar.")
-
-  true_indices = set()
-  for reduction_index in np.ravel(reduction_indices):
-    if (reduction_index < -input_shape.ndims or
-        reduction_index >= input_shape.ndims):
-      raise ValueError("Invalid reduction dimension %d for input with %d "
-                       "dimensions" % (reduction_index, input_shape.ndims))
-
-    true_index = reduction_index % input_shape.ndims
-    if true_index in true_indices:
-      raise ValueError("Duplicate reduction index %d." % reduction_index)
-
-    if input_shape.dims[true_index] == 0:
-      raise ValueError("Cannot reduce dimension %d with size 0." %
-                       reduction_index)
-
-    true_indices.add(true_index)
-
-  returned_dims = []
-  reduce_all = reduction_indices.size == 0
-  for i, dim in enumerate(input_shape.dims):
-    if reduce_all or i in true_indices:
-      if keep_dims:
-        returned_dims.append(1)
-    else:
-      returned_dims.append(dim)
-
-  return [tensor_shape.TensorShape(returned_dims)]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
 
 
 ops.RegisterShape("StringJoin")(common_shapes.call_cpp_shape_fn)
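
Note on the ReduceJoin change: routing _ReduceJoinShape through common_shapes.call_cpp_shape_fn with input_tensors_needed=[1] lets the C++ ReductionShapeForReduceJoin see constant reduction indices, so reduce_join keeps full static shapes and empty indices mean "join every dimension". A minimal Python sketch of the intended behaviour, assuming the same-era tf.reduce_join API; illustrative only, not part of this patch:

    import tensorflow as tf

    t = tf.constant([["a", "b"], ["c", "d"]])   # shape (2, 2)

    joined = tf.reduce_join(t, reduction_indices=1)
    print(joined.get_shape())   # (2,): dimension 1 is joined away

    kept = tf.reduce_join(t, reduction_indices=[0], keep_dims=True)
    print(kept.get_shape())     # (1, 2): reduced dimension kept as size 1

    # Unlike the generic reduction shape function, empty reduction indices
    # join across every dimension, so the static shape is scalar.
    everything = tf.reduce_join(t, reduction_indices=[])
    print(everything.get_shape())   # ()
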