Make DynamicStitch's shape function handle the case where all inputs are constant.

PiperOrigin-RevId: 170637740
parent 418fac23f1
commit af8da61ad4
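In effect, when every indices input to DynamicStitch is a graph-time constant, shape inference can now pin down dimension 0 of the output as max(flatten(indices)) + 1 instead of leaving it unknown. A minimal sketch of the new behavior, assuming a TF1-style graph build that includes this change (tf.dynamic_stitch and tf.placeholder are existing APIs; the inline shape comments reflect what the updated shape function infers):

    import tensorflow as tf

    indices = [tf.constant([0, 1]), tf.constant([2, 3])]
    data = [tf.constant([10, 20]), tf.constant([30, 40])]

    # All indices inputs are constants, so dimension 0 is max(0,1,2,3) + 1 = 4.
    stitched = tf.dynamic_stitch(indices, data)
    print(stitched.get_shape().as_list())  # [4]

    # If any indices input is not a constant, dimension 0 stays unknown.
    ph = tf.placeholder(tf.int32, shape=[2])
    stitched_unknown = tf.dynamic_stitch([ph, tf.constant([2, 3])], data)
    print(stitched_unknown.get_shape().as_list())  # [None]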
@@ -133,17 +133,23 @@ num_partitions: The number of partitions to output.
 namespace {
 
 Status DynamicStitchShapeFunction(InferenceContext* c) {
-  int64 num_partitions;
+  int32 num_partitions;
   TF_RETURN_IF_ERROR(c->GetAttr("N", &num_partitions));
 
+  bool all_indices_constant = true;
+  int32 max_index = 0;
   ShapeHandle extra_shape = c->UnknownShape();
-  for (int64 i = 0; i < num_partitions; ++i) {
+  for (int i = 0; i < num_partitions; ++i) {
+    const Tensor* indices_t = c->input_tensor(i);
+    if (indices_t == nullptr) {
+      all_indices_constant = false;
+    }
     ShapeHandle indices_shape = c->input(i);
     ShapeHandle data_shape = c->input(i + num_partitions);
     if (!c->RankKnown(indices_shape)) {
       continue;
     }
 
     const int64 indices_rank = c->Rank(indices_shape);
 
     // Assert that data_shape starts with indices_shape.
@@ -155,9 +161,21 @@ Status DynamicStitchShapeFunction(InferenceContext* c) {
     ShapeHandle rest;
     TF_RETURN_IF_ERROR(c->Subshape(data_shape, indices_rank, &rest));
     TF_RETURN_IF_ERROR(c->Merge(extra_shape, rest, &extra_shape));
+
+    if (indices_t != nullptr) {
+      // The length is based on the highest index from flattened indices.
+      const int32* indices = indices_t->flat<int32>().data();
+      int64 count = indices_t->NumElements();
+      for (int64 i = 0; i < count; ++i) {
+        if (indices[i] > max_index) {
+          max_index = indices[i];
+        }
+      }
+    }
   }
 
-  ShapeHandle output_shape = c->Vector(c->UnknownDim());
+  ShapeHandle output_shape = c->Vector(
+      all_indices_constant ? c->MakeDim(max_index + 1) : c->UnknownDim());
   TF_RETURN_IF_ERROR(c->Concatenate(output_shape, extra_shape, &output_shape));
   c->set_output(0, output_shape);
   return Status::OK();
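Summarizing the rule the updated shape function implements: input_tensor(i) returns the indices tensor's value when it is statically known and nullptr otherwise, so dimension 0 of the output is the largest flattened index plus one when every indices input is constant, and unknown otherwise. A hypothetical NumPy mirror of that logic (stitched_dim0 is illustrative only, not a TensorFlow API):

    import numpy as np

    def stitched_dim0(index_tensors):
        """index_tensors: list of np.ndarray, or None where the value is unknown."""
        max_index = 0
        all_constant = True
        for t in index_tensors:
            if t is None:                # mirrors input_tensor(i) == nullptr
                all_constant = False
                continue
            for v in t.ravel():          # highest index over flattened indices
                max_index = max(max_index, int(v))
        return max_index + 1 if all_constant else None

    print(stitched_dim0([np.array([0, 1]), np.array([2, 3])]))  # 4
    print(stitched_dim0([None, np.array([2, 3])]))              # None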
@@ -126,8 +126,6 @@ TEST(DataFlowOpsTest, DynamicStitch) {
           .Attr("N", 2)
           .Finalize(&op.node_def));
 
-  INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[?,d2_2,d2_3]");
-
   // Bad prefix for the second data input.
   INFER_ERROR("Dimensions must be equal, but are 10 and 5", op,
               "[2,3];[5,6];[2,3,4,5];[10,11,4,5]");
@@ -135,6 +133,32 @@ TEST(DataFlowOpsTest, DynamicStitch) {
   // Inconsistent suffix dimensions
   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 4 and 13", op,
               "[2,3];[5,6];[2,3,4,5];[5,6,13,14]");
+
+  // Good case, but no known input tensors.
+  INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[?,d2_2,d2_3]");
+
+  // One known input tensor, not enough to change the answer.
+  Tensor tensor_2 = test::AsTensor<int32>(
+      std::vector<int32>{2, 4, 6, 0, 10, 11}, TensorShape({2, 3}));
+  Tensor tensor_5 = test::AsTensor<int32>(
+      std::vector<int32>{0,    1,  2,  3,  4,  5,  6,  7,  8,  9,
+                         10,   11, 12, 13, 14, 15, 16, 17, 18, 19,
+                         1000, 21, 22, 23, 24, 25, 26, 27, 28, 29},
+      TensorShape({5, 6}));
+  op.input_tensors.push_back(nullptr);
+  op.input_tensors.push_back(&tensor_5);
+  INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[?,d2_2,d2_3]");
+
+  op.input_tensors[0] = &tensor_2;
+  op.input_tensors[1] = nullptr;
+  INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[?,d2_2,d2_3]");
+  INFER_OK(op, "[2,3];?;[2,3,4,5];[5,6,4,5]", "[?,d2_2,d2_3]");
+
+  op.input_tensors[1] = &tensor_5;
+  INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[1001,d2_2,d2_3]");
+
+  tensor_2.flat<int32>()(3) = 10000;
+  INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[10001,d2_2,d2_3]");
 }
 
 TEST(DataFlowOpsTest, ParallelDynamicStitch) {
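The new expectations follow from the max(flatten(indices)) + 1 rule: tensor_5's largest entry is 1000, so with both index tensors known dimension 0 becomes 1001, and after overwriting element 3 of tensor_2 with 10000 it becomes 10001. A quick NumPy check of that arithmetic (illustrative only):

    import numpy as np

    t2 = np.array([2, 4, 6, 0, 10, 11]).reshape(2, 3)
    t5 = np.array(list(range(20)) + [1000] + list(range(21, 30))).reshape(5, 6)
    assert max(t2.max(), t5.max()) + 1 == 1001   # matches "[1001,d2_2,d2_3]"
    t2.flat[3] = 10000                           # mirrors tensor_2.flat<int32>()(3) = 10000
    assert max(t2.max(), t5.max()) + 1 == 10001  # matches "[10001,d2_2,d2_3]"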
@@ -22,6 +22,7 @@ import numpy as np
 
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gradients_impl
 import tensorflow.python.ops.data_flow_grad  # pylint: disable=unused-import
@@ -42,8 +43,18 @@ class DynamicStitchTestBase(object):
         stitched_t = self.stitch_op(indices[::step], data)
         stitched_val = stitched_t.eval()
         self.assertAllEqual([40, 60][::step], stitched_val)
-        # Dimension 0 is determined by the max index in indices, so we
-        # can only infer that the output is a vector of some unknown
-        # length.
-        self.assertEqual([None], stitched_t.get_shape().as_list())
+        # Dimension 0 is max(flatten(indices))+1.
+        self.assertEqual([2], stitched_t.get_shape().as_list())
 
+  def testShapeInferenceForScalarWithNonConstantIndices(self):
+    with self.test_session(use_gpu=True):
+      indices = [array_ops.placeholder(dtype=dtypes.int32),
+                 constant_op.constant(1)]
+      data = [constant_op.constant(40), constant_op.constant(60)]
+      for step in -1, 1:
+        stitched_t = self.stitch_op(indices[::step], data)
+        # Dimension 0 is max(flatten(indices))+1, but the first indices input is
+        # not a constant tensor, so we can only infer it as a vector of unknown
+        # length.
+        self.assertEqual([None], stitched_t.get_shape().as_list())
+
@@ -59,10 +70,8 @@ class DynamicStitchTestBase(object):
       stitched_t = self.stitch_op(indices, data)
       stitched_val = stitched_t.eval()
       self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_val)
-      # Dimension 0 is determined by the max index in indices, so we
-      # can only infer that the output is a vector of some unknown
-      # length.
-      self.assertEqual([None], stitched_t.get_shape().as_list())
+      # Dimension 0 is max(flatten(indices))+1.
+      self.assertEqual([8], stitched_t.get_shape().as_list())
 
   def testOneListOneDimensional(self):
     with self.test_session(use_gpu=True):
@@ -71,10 +80,8 @@ class DynamicStitchTestBase(object):
       stitched_t = self.stitch_op(indices, data)
      stitched_val = stitched_t.eval()
       self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_val)
-      # Dimension 0 is determined by the max index in indices, so we
-      # can only infer that the output is a vector of some unknown
-      # length.
-      self.assertEqual([None], stitched_t.get_shape().as_list())
+      # Dimension 0 is max(flatten(indices))+1.
+      self.assertEqual([8], stitched_t.get_shape().as_list())
 
   def testSimpleTwoDimensional(self):
     with self.test_session(use_gpu=True):
@@ -91,10 +98,8 @@ class DynamicStitchTestBase(object):
       stitched_val = stitched_t.eval()
       self.assertAllEqual([[0, 1], [10, 11], [20, 21], [30, 31], [40, 41],
                           [50, 51], [60, 61], [70, 71]], stitched_val)
-      # Dimension 0 is determined by the max index in indices, so we
-      # can only infer that the output is a matrix with 2 columns and
-      # some unknown number of rows.
-      self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+      # Dimension 0 is max(flatten(indices))+1.
+      self.assertEqual([8, 2], stitched_t.get_shape().as_list())
 
   def testHigherRank(self):
     with self.test_session(use_gpu=True) as sess:
@@ -111,7 +116,7 @@ class DynamicStitchTestBase(object):
       stitched_val = stitched_t.eval()
       correct = 10 * np.arange(7)[:, None] + [1, 2]
       self.assertAllEqual(correct, stitched_val)
-      self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+      self.assertEqual([7, 2], stitched_t.get_shape().as_list())
       # Test gradients
       stitched_grad = 7 * stitched_val
       grads = gradients_impl.gradients(stitched_t, indices + data,
@@ -186,10 +191,8 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
       stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
       stitched_val = stitched_t.eval()
       self.assertAllEqual([40.0, 60.0][::step], stitched_val)
-      # Dimension 0 is determined by the max index in indices, so we
-      # can only infer that the output is a vector of some unknown
-      # length.
-      self.assertEqual([None], stitched_t.get_shape().as_list())
+      # Dimension 0 is max(flatten(indices))+1.
+      self.assertEqual([2], stitched_t.get_shape().as_list())
 
   def testHigherRank(self):
     with self.test_session(use_gpu=True) as sess:
@@ -208,7 +211,7 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
       stitched_val = stitched_t.eval()
       correct = 10 * np.arange(7)[:, None] + [1.0, 2.0]
       self.assertAllEqual(correct, stitched_val)
-      self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+      self.assertEqual([7, 2], stitched_t.get_shape().as_list())
       # Test gradients
       stitched_grad = 7 * stitched_val
       grads = gradients_impl.gradients(stitched_t, indices + data,
@@ -226,10 +229,8 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
       stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
       stitched_val = stitched_t.eval()
       self.assertAllEqual([40.0, 60.0][::step], stitched_val)
-      # Dimension 0 is determined by the max index in indices, so we
-      # can only infer that the output is a vector of some unknown
-      # length.
-      self.assertEqual([None], stitched_t.get_shape().as_list())
+      # Dimension 0 is max(flatten(indices))+1.
+      self.assertEqual([2], stitched_t.get_shape().as_list())
 
   def testHigherRankGPU(self):
     with self.test_session() as sess:
@@ -246,7 +247,7 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
       stitched_val = stitched_t.eval()
       correct = 10 * np.arange(7)[:, None] + [1.0, 2.0]
       self.assertAllEqual(correct, stitched_val)
-      self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+      self.assertEqual([7, 2], stitched_t.get_shape().as_list())
       # Test gradients
       stitched_grad = 7 * stitched_val
       grads = gradients_impl.gradients(stitched_t, indices + data,