Implement C++ gradients for data_flow operators.
Closes #12856

PiperOrigin-RevId: 167949574
Parent: 2b5011625b
Commit: d27db78cd0
@@ -91,6 +91,7 @@ cc_library(
     name = "grad_ops",
     deps = [
         ":array_grad",
+        ":data_flow_grad",
         ":math_grad",
         ":nn_grad",
     ],
@@ -363,6 +364,36 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "data_flow_grad",
+    srcs = ["gradients/data_flow_grad.cc"],
+    deps = [
+        ":cc_ops",
+        ":cc_ops_internal",
+        ":grad_op_registry",
+        ":gradients",
+    ],
+    alwayslink = 1,
+)
+
+tf_cc_test(
+    name = "gradients_data_flow_grad_test",
+    size = "small",
+    srcs = ["gradients/data_flow_grad_test.cc"],
+    deps = [
+        ":cc_ops",
+        ":data_flow_grad",
+        ":grad_op_registry",
+        ":grad_testutil",
+        ":gradient_checker",
+        ":testutil",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_gen_op_wrappers_cc(
     name = "cc_ops",
     op_lib_names = [
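A minimal sketch (not part of the commit) of why the new library sets alwayslink = 1: the REGISTER_GRADIENT_OP / REGISTER_NO_GRADIENT_OP calls in data_flow_grad.cc run as static initializers, so no symbol in that object file is referenced directly and the linker could otherwise drop it from dependent binaries. Assuming the GradOpRegistry API declared in tensorflow/cc/framework/grad_op_registry.h, the registration can be observed like this (the function name is illustrative only):

#include "tensorflow/cc/framework/grad_op_registry.h"

namespace tensorflow {

// Returns true only if data_flow_grad.cc was actually linked into the binary,
// i.e. its static gradient registration ran.
bool DynamicPartitionGradIsRegistered() {
  ops::GradFunc grad_fn;
  return ops::GradOpRegistry::Global()->Lookup("DynamicPartition", &grad_fn).ok();
}

}  // namespace tensorflow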
tensorflow/cc/gradients/data_flow_grad.cc (new file, 155 lines)
@@ -0,0 +1,155 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/data_flow_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

REGISTER_NO_GRADIENT_OP("Queue");
REGISTER_NO_GRADIENT_OP("QueueEnqueue");
REGISTER_NO_GRADIENT_OP("QueueEnqueueMany");
REGISTER_NO_GRADIENT_OP("QueueDequeue");
REGISTER_NO_GRADIENT_OP("QueueDequeueMany");
REGISTER_NO_GRADIENT_OP("QueueDequeueUpTo");
REGISTER_NO_GRADIENT_OP("QueueClose");
REGISTER_NO_GRADIENT_OP("QueueSize");
REGISTER_NO_GRADIENT_OP("Stack");
REGISTER_NO_GRADIENT_OP("StackPush");
REGISTER_NO_GRADIENT_OP("StackPop");
REGISTER_NO_GRADIENT_OP("StackClose");
REGISTER_NO_GRADIENT_OP("GetSessionHandle");
REGISTER_NO_GRADIENT_OP("GetSessionHandleV2");
REGISTER_NO_GRADIENT_OP("GetSessionTensor");
REGISTER_NO_GRADIENT_OP("DeleteSessionTensor");

Status DynamicPartitionGrad(const Scope& scope, const Operation& op,
                            const std::vector<Output>& grad_inputs,
                            std::vector<Output>* grad_outputs) {
  // DynamicPartition only moves input values into various positions
  // in the output, so the gradient operation only has to map incoming
  // gradients into their input source locations.
  // running example:
  // data = [10, 20, 30, 40, 50]
  // partitions = [0, 0, 1, 1, 0]
  // num_partitions = 2
  // dynamic_partition(data, partitions, num_partitions) = {
  //   [10, 20, 50],
  //   [30, 40]
  // }
  // grads = {
  //   [g1, g2, g3],
  //   [g4, g5]
  // }
  // The desired propagation of the gradients back to the data inputs is:
  // [g1, g2, g4, g5, g3]
  auto data = op.input(0);
  auto partitions = op.input(1);
  int32 num_partitions;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "num_partitions", &num_partitions));

  // Note: the shape of the partitions is a prefix of the data shape.
  // shape(partitions) = [5]
  auto partitions_shape = Shape(scope, partitions);
  // We now create a partitions-shaped tensor with integers from
  // [0..size(partitions)). This will be dynamic_partitioned with the
  // input parameters, providing the destination index for a given
  // source item.
  // partitions_size = prod([5]) = 5
  // reshape(range(partitions_size), [5]) = [0, 1, 2, 3, 4]
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);
  auto original_indices = Reshape(
      scope, Range(scope, zero, Prod(scope, partitions_shape, zero), one),
      partitions_shape);
  // dynamic_partition(
  //   [0, 1, 2, 3, 4],
  //   [0, 0, 1, 1, 0], 2)
  //  = { [0, 1, 4],
  //      [2, 3] }
  auto partitioned_indices =
      DynamicPartition(scope, original_indices, partitions, num_partitions);

  // Invert these indices with dynamic_stitch to map the incoming
  // gradients to their source inputs.
  // dynamic_stitch(
  //   { [0, 1, 4], [2, 3] },
  //   { [g1, g2, g3], [g4, g5] })
  //  = [g1, g2, g4, g5, g3]
  auto reconstructed =
      DynamicStitch(scope, partitioned_indices.outputs, grad_inputs);
  // reshape back into a data-shaped tensor to propagate gradients for the data
  // input.
  grad_outputs->push_back(Reshape(scope, reconstructed, Shape(scope, data)));
  // Stop propagation along the partitions input
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("DynamicPartition", DynamicPartitionGrad);

Status DynamicStitchGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  // Running example:
  // indices = {2, [1, 0]}
  // data = {[d_1, d_2], [[d_3, d_4], [d_5, d_6]]}
  // out = [[d_5, d_6], [d_3, d_4], [d_1, d_2]]
  // grad = [[g_1, g_2], [g_3, g_4], [g_5, g_6]]

  // indices and data are two equal-sized lists passed
  // into DynamicStitch.
  // num_values = 2
  int32 num_values = op.num_inputs() / 2;

  // Stop propagation along the indices list
  for (int32 i = 0; i < num_values; i++) {
    grad_outputs->push_back(NoGradient());
  }

  // DynamicStitch shuffles its data to the output (using items in
  // indices) so the gradient propagated to a given data input simply
  // selects the gradient for its output position.
  for (int32 i = 0; i < num_values; i++) {
    // index has the destination positions for the i'th data
    // element. We cast it into an int32 if necessary, so we can use
    // it from a Gather op.
    // i = 0: index = 2
    // i = 1: index = [1, 0]
    auto index = op.input(i);
    if (index.type() != DT_INT32) {
      index = Cast(scope, index, DT_INT32);
    }
    // Gather the index specified locations in the gradient and
    // propagate it as the gradient for the i'th data item.
    // i = 0: gather(grad, 2) = [g_5, g_6]
    // i = 1: gather(grad, [1, 0]) = [[g_3, g_4], [g_1, g_2]]
    grad_outputs->push_back(Gather(scope, grad_inputs[0], index));
  }

  return scope.status();
}
REGISTER_GRADIENT_OP("DynamicStitch", DynamicStitchGrad);
REGISTER_GRADIENT_OP("ParallelDynamicStitch", DynamicStitchGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow
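A minimal, self-contained sketch (not part of the commit) of how the newly registered DynamicPartition gradient is consumed through AddSymbolicGradients, reusing the running example from the comments above. With the default OnesLike seed gradients the value routed back to data is simply all ones; the point is only that the gradient flows through the stitch-based inverse mapping. Variable names here are illustrative.

#include <iostream>
#include <vector>

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"

int main() {
  using namespace tensorflow;       // NOLINT(build/namespaces)
  using namespace tensorflow::ops;  // NOLINT(build/namespaces)

  Scope scope = Scope::NewRootScope();
  auto data = Const(scope, {10.0f, 20.0f, 30.0f, 40.0f, 50.0f});
  auto partitions = Const(scope, {0, 0, 1, 1, 0});
  auto parts = DynamicPartition(scope, data, partitions, /*num_partitions=*/2);

  // Differentiate the partitioned outputs with respect to the data input;
  // the gradient registry dispatches to DynamicPartitionGrad above.
  std::vector<Output> grads;
  TF_CHECK_OK(AddSymbolicGradients(scope, parts.outputs, {data}, &grads));

  ClientSession session(scope);
  std::vector<Tensor> out;
  TF_CHECK_OK(session.Run({grads[0]}, &out));
  std::cout << out[0].DebugString() << std::endl;  // shape [5], all ones
  return 0;
}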
tensorflow/cc/gradients/data_flow_grad_test.cc (new file, 69 lines)
@@ -0,0 +1,69 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
using namespace ops;  // NOLINT(build/namespaces)

namespace {

class DataFlowGradTest : public ::testing::Test {
 protected:
  DataFlowGradTest() : scope_(Scope::NewRootScope()) {}

  void RunTest(const OutputList& xs, const std::vector<TensorShape>& x_shapes,
               const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
    TF_ASSERT_OK(scope_.status());
    float max_error;
    TF_ASSERT_OK(
        ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
    EXPECT_LT(max_error, 1e-4);
  }

  Scope scope_;
};

TEST_F(DataFlowGradTest, DynamicPartitionGrad) {
  TensorShape data_shape({2, 3, 2});
  auto data = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(data_shape));
  auto partitions = Const(scope_, {{2, 1, 0}, {1, 2, 0}});
  auto y = DynamicPartition(scope_, data, partitions, 3);
  TensorShape partition_shape({2, 2});
  RunTest({data}, {data_shape}, y.outputs,
          {partition_shape, partition_shape, partition_shape});
}

TEST_F(DataFlowGradTest, DynamicStitchGrad) {
  TensorShape d1_shape({2});
  TensorShape d2_shape({2, 2});
  std::vector<Output> indices = {Const(scope_, 2), Const(scope_, {1, 0})};
  std::vector<Output> data = {
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(d1_shape)),
      Placeholder(scope_, DT_FLOAT, Placeholder::Shape(d2_shape))};
  auto y = DynamicStitch(scope_, indices, data);
  TensorShape y_shape({3, 2});
  RunTest(data, {d1_shape, d2_shape}, {y}, {y_shape});
}

}  // namespace
}  // namespace tensorflow