Add control dependencies to the correct graph when simplifying packing ops.
PiperOrigin-RevId: 198622727
This commit is contained in:
parent
5810723cc8
commit
5c751fe8d7
@ -7101,6 +7101,14 @@ class CohenKappaTest(test.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
metrics.cohen_kappa(labels, invalid_predictions, 3)
|
metrics.cohen_kappa(labels, invalid_predictions, 3)
|
||||||
|
|
||||||
|
def testConditionalPackingOptimization(self):
|
||||||
|
placeholder = array_ops.placeholder(dtypes_lib.float32, [None])
|
||||||
|
values, update_op = metric_ops.streaming_concat(placeholder)
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(variables.local_variables_initializer())
|
||||||
|
for feed in range(10):
|
||||||
|
sess.run(update_op, feed_dict={placeholder: [feed]})
|
||||||
|
print(sess.run(values))
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -2171,7 +2171,7 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
|
|||||||
}
|
}
|
||||||
// Add a control dependency to make sure axis_node is in the right frame.
|
// Add a control dependency to make sure axis_node is in the right frame.
|
||||||
const string ctrl_dep = ConstantFolding::AddControlDependency(
|
const string ctrl_dep = ConstantFolding::AddControlDependency(
|
||||||
node->input(0), graph_, node_map_.get());
|
node->input(0), optimized_graph, node_map_.get());
|
||||||
axis_node->add_input(ctrl_dep);
|
axis_node->add_input(ctrl_dep);
|
||||||
axis_node->set_device(node->device());
|
axis_node->set_device(node->device());
|
||||||
node->set_op("ExpandDims");
|
node->set_op("ExpandDims");
|
||||||
|
Loading…
x
Reference in New Issue
Block a user