Implement C++ gradients for data_flow operators.

Closes #12856
PiperOrigin-RevId: 167949574
parent 2b5011625b · commit d27db78cd0
@@ -91,6 +91,7 @@ cc_library(
    name = "grad_ops",
    deps = [
        ":array_grad",
        ":data_flow_grad",
        ":math_grad",
        ":nn_grad",
    ],
@@ -363,6 +364,36 @@ tf_cc_test(
    ],
)

cc_library(
    name = "data_flow_grad",
    srcs = ["gradients/data_flow_grad.cc"],
    deps = [
        ":cc_ops",
        ":cc_ops_internal",
        ":grad_op_registry",
        ":gradients",
    ],
    alwayslink = 1,
)

tf_cc_test(
    name = "gradients_data_flow_grad_test",
    size = "small",
    srcs = ["gradients/data_flow_grad_test.cc"],
    deps = [
        ":cc_ops",
        ":data_flow_grad",
        ":grad_op_registry",
        ":grad_testutil",
        ":gradient_checker",
        ":testutil",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)

tf_gen_op_wrappers_cc(
    name = "cc_ops",
    op_lib_names = [
tensorflow/cc/gradients/data_flow_grad.cc (new file, 155 lines)
@@ -0,0 +1,155 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/data_flow_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

REGISTER_NO_GRADIENT_OP("Queue");
REGISTER_NO_GRADIENT_OP("QueueEnqueue");
REGISTER_NO_GRADIENT_OP("QueueEnqueueMany");
REGISTER_NO_GRADIENT_OP("QueueDequeue");
REGISTER_NO_GRADIENT_OP("QueueDequeueMany");
REGISTER_NO_GRADIENT_OP("QueueDequeueUpTo");
REGISTER_NO_GRADIENT_OP("QueueClose");
REGISTER_NO_GRADIENT_OP("QueueSize");
REGISTER_NO_GRADIENT_OP("Stack");
REGISTER_NO_GRADIENT_OP("StackPush");
REGISTER_NO_GRADIENT_OP("StackPop");
REGISTER_NO_GRADIENT_OP("StackClose");
REGISTER_NO_GRADIENT_OP("GetSessionHandle");
REGISTER_NO_GRADIENT_OP("GetSessionHandleV2");
REGISTER_NO_GRADIENT_OP("GetSessionTensor");
REGISTER_NO_GRADIENT_OP("DeleteSessionTensor");

Status DynamicPartitionGrad(const Scope& scope, const Operation& op,
                            const std::vector<Output>& grad_inputs,
                            std::vector<Output>* grad_outputs) {
  // DynamicPartition only moves input values into various positions
  // in the output, so the gradient operation only has to map incoming
  // gradients into their input source locations.
  // running example:
  // data = [10, 20, 30, 40, 50]
  // partitions = [0, 0, 1, 1, 0]
  // num_partitions = 2
  // dynamic_partition(data, partitions, num_partitions) = {
  //   [10, 20, 50],
  //   [30, 40]
  // }
  // grads = {
  //   [g1, g2, g3],
  //   [g4, g5]
  // }
  // The desired propagation of the gradients back to the data inputs is:
  // [g1, g2, g4, g5, g3]
  auto data = op.input(0);
  auto partitions = op.input(1);
  int32 num_partitions;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "num_partitions", &num_partitions));

  // Note: the shape of the partitions is a prefix of the data shape.
  // shape(partitions) = [5]
  auto partitions_shape = Shape(scope, partitions);
  // We now create a partitions-shaped tensor with integers from
  // [0..size(partitions)) This will be dynamic_partitioned with the
  // input parameters, providing the destination index for a given
  // source item.
  // partitions_size = prod([5]) = 5
  // reshape(range(partitions_size), [5]) = [0, 1, 2, 3, 4]
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);
  auto original_indices = Reshape(
      scope, Range(scope, zero, Prod(scope, partitions_shape, zero), one),
      partitions_shape);
  // dynamic_partition(
  //   [0, 1, 2, 3, 4],
  //   [0, 0, 1, 1, 0], 2)
  // = { [0, 1, 4],
  //     [2, 3] }
  auto partitioned_indices =
      DynamicPartition(scope, original_indices, partitions, num_partitions);

  // Invert these indices with dynamic_stitch to map the incoming
  // gradients to their source inputs.
  // dynamic_stitch(
  //   { [0, 1, 4], [2, 3] },
  //   { [g1, g2, g3], [g4, g5] })
  // = [g1, g2, g4, g5, g3]
  auto reconstructed =
      DynamicStitch(scope, partitioned_indices.outputs, grad_inputs);
  // reshape back into a data-shaped tensor to propagate gradients for the data
  // input.
  grad_outputs->push_back(Reshape(scope, reconstructed, Shape(scope, data)));
  // Stop propagation along the partitions input
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("DynamicPartition", DynamicPartitionGrad);

Status DynamicStitchGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  // Running example:
  // indices = {2, [1, 0]}
  // data = {[d_1, d_2], [[d_3, d_4], [d_5, d_6]]}
  // out = [[d_5, d_6], [d_3, d_4], [d_1, d_2]]
  // grad = [[g_1, g_2], [g_3, g_4], [g_5, g_6]]

  // indices and data are two equal-sized lists passed
  // into DynamicStitch.
  // num_values = 2
  int32 num_values = op.num_inputs() / 2;

  // Stop propagation along the indices list
  for (int32 i = 0; i < num_values; i++) {
    grad_outputs->push_back(NoGradient());
  }

  // DynamicStitch shuffles its data to the output (using items in
  // indices) so the gradient propagated to a given data input simply
  // selects the gradient for its output position.
  for (int32 i = 0; i < num_values; i++) {
    // index has the destination positions for the i'th data
    // element. We cast it into an int32 if necessary, so we can use
    // it from a Gather op.
    // i = 0: index = 2
    // i = 1: index = [1, 0]
    auto index = op.input(i);
    if (index.type() != DT_INT32) {
      index = Cast(scope, index, DT_INT32);
    }
    // Gather the index specified locations in the gradient and
    // propagate it as the gradient for the i'th data item.
    // i = 0: gather(grad, 2) = [g_5, g_6]
    // i = 1: gather(grad, [1, 0]) = [[g_3, g_4], [g_1, g_2]]
    grad_outputs->push_back(Gather(scope, grad_inputs[0], index));
  }

  return scope.status();
}
REGISTER_GRADIENT_OP("DynamicStitch", DynamicStitchGrad);
REGISTER_GRADIENT_OP("ParallelDynamicStitch", DynamicStitchGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow
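For reference, a minimal usage sketch of the DynamicPartition gradient registered above, reproducing the running example from the code comments. This is illustrative, not part of the commit: it assumes the standard C++ client pieces (ClientSession, and AddSymbolicGradients from tensorflow/cc/framework/gradients.h), and the constants and names (dys, dxs) are made up. AddSymbolicGradients resolves DynamicPartitionGrad through the gradient registry, which is why a binary must link :data_flow_grad (or :grad_ops) and why the BUILD rule sets alwayslink = 1.

// Illustrative sketch (not part of this change).
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/platform/logging.h"

int main() {
  using namespace tensorflow;       // NOLINT(build/namespaces)
  using namespace tensorflow::ops;  // NOLINT(build/namespaces)

  Scope scope = Scope::NewRootScope();
  // Forward op: the running example from the comments above.
  auto data = Const(scope, {10.0f, 20.0f, 30.0f, 40.0f, 50.0f});
  auto partitions = Const(scope, {0, 0, 1, 1, 0});
  auto parts = DynamicPartition(scope, data, partitions, /*num_partitions=*/2);

  // Incoming gradients for the two partition outputs: [g1, g2, g3], [g4, g5].
  std::vector<Output> dys = {Const(scope, {1.0f, 2.0f, 3.0f}),
                             Const(scope, {4.0f, 5.0f})};
  // Looks up DynamicPartitionGrad in the registry and wires it into the graph.
  std::vector<Output> dxs;
  TF_CHECK_OK(AddSymbolicGradients(scope, parts.outputs, {data}, dys, &dxs));

  ClientSession session(scope);
  std::vector<Tensor> out;
  TF_CHECK_OK(session.Run(dxs, &out));
  // Expected d(data) = [g1, g2, g4, g5, g3] = [1, 2, 4, 5, 3]:
  // the incoming gradients are simply routed back to their source positions.
  LOG(INFO) << out[0].DebugString();
  return 0;
}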
tensorflow/cc/gradients/data_flow_grad_test.cc (new file, 69 lines)
@@ -0,0 +1,69 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
using namespace ops;  // NOLINT(build/namespaces)

namespace {

class DataFlowGradTest : public ::testing::Test {
 protected:
  DataFlowGradTest() : scope_(Scope::NewRootScope()) {}

  void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
               const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK(
        ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
    EXPECT_LT(max_error, 1e-4);
  }

  Scope scope_;
};

TEST_F(DataFlowGradTest, DynamicPartitionGrad) {
  TensorShape data_shape({2, 3, 2});
  auto data = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(data_shape));
  auto partitions = Const(scope_, {{2, 1, 0}, {1, 2, 0}});
  auto y = DynamicPartition(scope_, data, partitions, 3);
  TensorShape partition_shape({2, 2});
  RunTest({data}, {data_shape}, y.outputs,
          {partition_shape, partition_shape, partition_shape});
}

TEST_F(DataFlowGradTest, DynamicStitchGrad) {
  TensorShape d1_shape({2});
  TensorShape d2_shape({2, 2});
  std::vector<Output> indices = {Const(scope_, 2), Const(scope_, {1, 0})};
  std::vector<Output> data = {
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(d1_shape)),
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(d2_shape))};
  auto y = DynamicStitch(scope_, indices, data);
  TensorShape y_shape({3, 2});
  RunTest(data, {d1_shape, d2_shape}, {y}, {y_shape});
}

}  // namespace
}  // namespace tensorflow
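The tests above check both gradients numerically via ComputeGradientError. In the same spirit as the earlier sketch, here is an illustrative (not part of the commit) symbolic-gradient walkthrough of DynamicStitch for the running example from data_flow_grad.cc (indices = {2, [1, 0]}); the client-API usage and concrete values are assumed as before. Each data input simply gathers its slice of the output gradient.

// Illustrative sketch (not part of this change).
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/platform/logging.h"

int main() {
  using namespace tensorflow;       // NOLINT(build/namespaces)
  using namespace tensorflow::ops;  // NOLINT(build/namespaces)

  Scope scope = Scope::NewRootScope();
  std::vector<Output> indices = {Const(scope, 2), Const(scope, {1, 0})};
  std::vector<Output> data = {
      Const(scope, {1.0f, 2.0f}),                   // [d_1, d_2]
      Const(scope, {{3.0f, 4.0f}, {5.0f, 6.0f}})};  // [[d_3, d_4], [d_5, d_6]]
  // Forward result: [[d_5, d_6], [d_3, d_4], [d_1, d_2]].
  auto y = DynamicStitch(scope, indices, data);

  // Incoming gradient for y: [[g_1, g_2], [g_3, g_4], [g_5, g_6]].
  auto dy = Const(scope, {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}});
  std::vector<Output> dxs;
  TF_CHECK_OK(AddSymbolicGradients(scope, {y}, data, {dy}, &dxs));

  ClientSession session(scope);
  std::vector<Tensor> out;
  TF_CHECK_OK(session.Run(dxs, &out));
  // Expected: d(data[0]) = gather(dy, 2)      = [g_5, g_6]               = [5, 6]
  //           d(data[1]) = gather(dy, [1, 0]) = [[g_3, g_4], [g_1, g_2]] = [[3, 4], [1, 2]]
  LOG(INFO) << out[0].DebugString() << "\n" << out[1].DebugString();
  return 0;
}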