Add control dependencies to the correct graph when simplifying packing ops.
PiperOrigin-RevId: 198622727
This commit is contained in:
parent
5810723cc8
commit
5c751fe8d7
@ -7101,6 +7101,14 @@ class CohenKappaTest(test.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
metrics.cohen_kappa(labels, invalid_predictions, 3)
|
||||
|
||||
def testConditionalPackingOptimization(self):
|
||||
placeholder = array_ops.placeholder(dtypes_lib.float32, [None])
|
||||
values, update_op = metric_ops.streaming_concat(placeholder)
|
||||
with self.test_session() as sess:
|
||||
sess.run(variables.local_variables_initializer())
|
||||
for feed in range(10):
|
||||
sess.run(update_op, feed_dict={placeholder: [feed]})
|
||||
print(sess.run(values))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -2171,7 +2171,7 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
|
||||
}
|
||||
// Add a control dependency to make sure axis_node is in the right frame.
|
||||
const string ctrl_dep = ConstantFolding::AddControlDependency(
|
||||
node->input(0), graph_, node_map_.get());
|
||||
node->input(0), optimized_graph, node_map_.get());
|
||||
axis_node->add_input(ctrl_dep);
|
||||
axis_node->set_device(node->device());
|
||||
node->set_op("ExpandDims");
|
||||
|
Loading…
x
Reference in New Issue
Block a user