C++ Gradients: Port a couple of gradient functions to the new C++ graph building API.

Change: 129016020
A. Unique TensorFlower 2016-08-01 11:52:49 -08:00 committed by TensorFlower Gardener
parent 01336ac329
commit 84391613c3
7 changed files with 483 additions and 2 deletions
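
For orientation, the pieces below fit together as follows: gradient functions are registered against the forward op name in a global GradOpRegistry, and callers look them up and invoke them to append gradient nodes to the graph. A minimal sketch (illustrative only, not part of the diff; AddMatMulGradients is a hypothetical helper, while GradOpRegistry, GradFunc and the "MatMul" gradient are all introduced in the files below):

#include "tensorflow/cc/framework/grad_op_registry.h"

namespace tensorflow {
namespace ops {

// Sketch only: look up the gradient function that math_grad.cc registers for
// "MatMul" and call it to add the dX/dY nodes for one forward MatMul op.
// 'matmul_op' is the forward Operation and 'dz' the incoming gradient.
Status AddMatMulGradients(const Scope& scope, const Operation& matmul_op,
                          const Output& dz, std::vector<Output>* grads) {
  GradFunc grad_fn;
  TF_RETURN_IF_ERROR(GradOpRegistry::Global()->Lookup("MatMul", &grad_fn));
  // grad_fn is MatMulGrad below; it appends the dX and dY endpoints to *grads.
  return grad_fn(scope, matmul_op, {dz}, grads);
}

}  // namespace ops
}  // namespace tensorflow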

tensorflow/cc/BUILD

@@ -73,6 +73,46 @@ tf_cc_test(
    ],
)

cc_library(
    name = "grad_op_registry",
    srcs = ["framework/grad_op_registry.cc"],
    hdrs = ["framework/grad_op_registry.h"],
    deps = [
        ":ops",
        ":scope",
    ],
)

cc_library(
    name = "math_grad",
    srcs = ["gradients/math_grad.cc"],
    deps = [
        ":cc_ops",
        ":grad_op_registry",
        ":ops",
        ":scope",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
    ],
)

tf_cc_test(
    name = "gradients/math_grad_test",
    deps = [
        ":cc_ops",
        ":grad_op_registry",
        ":math_grad",
        "//tensorflow/core:all_kernels",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:tensorflow",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)
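
With these rules in place the new library and its test are ordinary Bazel targets; something like "bazel test //tensorflow/cc:gradients/math_grad_test" should build and run the gradient test, though the exact invocation depends on the local TensorFlow build configuration.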
tf_gen_op_wrappers_cc(
    name = "cc_ops",
    op_lib_names = [

tensorflow/cc/framework/grad_op_registry.cc

@@ -0,0 +1,42 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/grad_op_registry.h"
namespace tensorflow {
namespace ops {
// static
GradOpRegistry* GradOpRegistry::Global() {
static GradOpRegistry* grad_op_registry = new GradOpRegistry;
return grad_op_registry;
}
bool GradOpRegistry::Register(const string& op, GradFunc func) {
CHECK(registry_.insert({op, func}).second) << "Existing gradient for " << op;
return true;
}
Status GradOpRegistry::Lookup(const string& op, GradFunc* func) {
auto iter = registry_.find(op);
if (iter == registry_.end()) {
return errors::NotFound("No gradient defined for op: ", op);
}
*func = iter->second;
return Status::OK();
}
} // end namespace ops
} // namespace tensorflow

tensorflow/cc/framework/grad_op_registry.h

@@ -0,0 +1,75 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_

#include <unordered_map>

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"

namespace tensorflow {
namespace ops {

// GradFunc is the signature for all gradient functions in GradOpRegistry.
// Implementations should add operations to compute the gradient outputs of
// 'op' (returned in 'grad_outputs') using 'scope' and 'grad_inputs'.
typedef Status (*GradFunc)(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs);

// GradOpRegistry maintains a static registry of gradient functions.
// Gradient functions are indexed in the registry by the forward op name (i.e.
// "MatMul" -> MatMulGrad func).
class GradOpRegistry {
 public:
  // Registers 'func' as the gradient function for 'op'.
  // Returns true if registration was successful; otherwise check-fails.
  bool Register(const string& op, GradFunc func);

  // Sets 'func' to the gradient function for 'op' and returns Status OK if
  // a gradient function for 'op' exists in the registry; returns an error
  // status otherwise.
  // Note that 'func' can be null for ops that have registered a no-gradient
  // with the registry.
  Status Lookup(const string& op, GradFunc* func);

  // Returns a pointer to the global gradient function registry.
  static GradOpRegistry* Global();

 private:
  std::unordered_map<string, GradFunc> registry_;
};

}  // namespace ops
// Macros used to define gradient functions for ops.
#define REGISTER_GRADIENT_OP(name, fn) \
  REGISTER_GRADIENT_OP_UNIQ_HELPER(__COUNTER__, name, fn)

#define REGISTER_NO_GRADIENT_OP(name) \
  REGISTER_GRADIENT_OP_UNIQ_HELPER(__COUNTER__, name, nullptr)

#define REGISTER_GRADIENT_OP_UNIQ_HELPER(ctr, name, fn) \
  REGISTER_GRADIENT_OP_UNIQ(ctr, name, fn)

#define REGISTER_GRADIENT_OP_UNIQ(ctr, name, fn) \
  static bool unused_ret_val_##ctr =             \
      ::tensorflow::ops::GradOpRegistry::Global()->Register(name, fn)

}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
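
To illustrate how the macros above are meant to be used, here is a minimal, hypothetical gradient for the unary "Neg" op (not part of this change). It follows the GradFunc signature defined above and assumes the generated cc_ops wrapper Neg is available:

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
namespace ops {
namespace {

// d(-x)/dx = -1, so the gradient of Neg simply negates the incoming gradient.
Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return Status::OK();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

}  // namespace
}  // namespace ops
}  // namespace tensorflow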

tensorflow/cc/framework/ops.cc

@@ -18,6 +18,44 @@ limitations under the License.
namespace tensorflow {
namespace ops {

Operation::Operation(Node* n) : inputs_(GetInputs(n)), node_(n) {}

Output Operation::input(int i) const {
  CHECK_NOTNULL(node_);
  CHECK_GE(i, 0);
  CHECK_LT(i, node_->num_inputs());
  // Handle the case where the input was unknown at the time this
  // Operation was constructed.
  if (inputs_[i].first == nullptr && inputs_[i].second == -1) {
    for (const Edge* e : node_->in_edges()) {
      if (e->IsControlEdge()) continue;
      if (e->dst_input() == i) {
        return Output(e->src(), e->src_output());
      }
    }
  }
  return Output(inputs_[i].first, inputs_[i].second);
}

Output Operation::output(int i) const {
  CHECK_NOTNULL(node_);
  CHECK_GE(i, 0);
  CHECK_LT(i, node_->num_outputs());
  return Output(node_, i);
}

Operation::Inputs Operation::GetInputs(Node* node) {
  Operation::Inputs inputs;
  if (node != nullptr) {
    inputs.resize(node->num_inputs(), {nullptr, -1});
    for (const Edge* e : node->in_edges()) {
      if (e->IsControlEdge()) continue;
      inputs[e->dst_input()] = std::make_pair(e->src(), e->src_output());
    }
  }
  return inputs;
}

Input::Initializer::Initializer(
    const std::initializer_list<Input::Initializer>& v) {
  if (v.size() < 1) {

tensorflow/cc/framework/ops.h

@@ -27,17 +27,29 @@ limitations under the License.
namespace tensorflow {
namespace ops {

class Output;

// Represents a node in the computation graph.
class Operation {
 public:
  Operation() : node_(nullptr) {}
  explicit Operation(Node* n) : node_(n) {}
  explicit Operation(Node* n);

  int num_inputs() const { return node_->num_inputs(); }
  DataType input_type(int o) const { return node_->input_type(o); }
  Output input(int i) const;

  int num_outputs() const { return node_->num_outputs(); }
  DataType output_type(int o) const { return node_->output_type(o); }
  Output output(int i) const;

  Node* node() const { return node_; }

 private:
  typedef std::vector<std::pair<Node*, int64>> Inputs;
  static Inputs GetInputs(Node* node);

  Inputs inputs_;
  Node* node_;
};
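
The new input()/output() accessors are what gradient functions use to reach a forward op's endpoints. A small sketch (DumpEndpoints is a hypothetical helper written for illustration; node(), name() and index() are existing Output/Node accessors):

// Sketch: walk an Operation's input and output endpoints via the new API.
void DumpEndpoints(const tensorflow::ops::Operation& op) {
  for (int i = 0; i < op.num_inputs(); ++i) {
    // input(i) returns the producing node and its output index.
    tensorflow::ops::Output in = op.input(i);
    LOG(INFO) << "input " << i << " <- " << in.node()->name() << ":"
              << in.index();
  }
  for (int i = 0; i < op.num_outputs(); ++i) {
    tensorflow::ops::Output out = op.output(i);
    LOG(INFO) << "output " << i << " = " << out.node()->name() << ":"
              << out.index();
  }
}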
@@ -81,7 +93,7 @@ class Input {
      tensor = t;
    }

    explicit Initializer(const Tensor& t) : tensor(t) {}
    Initializer(const Tensor& t) : tensor(t) {}  // NOLINT(runtime/explicit)

    // Construct from a scalar value and an explicit shape
    template <typename T, typename = typename std::enable_if<

tensorflow/cc/gradients/math_grad.cc

@@ -0,0 +1,91 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
namespace tensorflow {
namespace ops {
namespace {
// TODO(andydavis) Move this to a more appropriate file.
REGISTER_NO_GRADIENT_OP("Const");
// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const Output& x0, const bool adj_x0,
const Output& x1, const bool adj_x1, const Output& y0,
const bool adj_y0, const Output& y1, const bool adj_y1,
std::vector<Output>* grad_outputs) {
auto dx =
MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
grad_outputs->push_back(dx);
auto dy =
MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
grad_outputs->push_back(dy);
return Status::OK();
}
// Common code for MatMulGrad: reads and checks the node's attr state, then
// chooses the proper MatMul products for the gradients based on the input
// matrix transposition combinations.
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype));
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex data type is not supported yet.");
  }

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));

  if (!ta && !tb) {
    return MatMulGradHelper(scope, grad_inputs[0], false, op.input(1), true,
                            op.input(0), true, grad_inputs[0], false,
                            grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, grad_inputs[0], false, op.input(1), false,
                            grad_inputs[0], true, op.input(0), false,
                            grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, op.input(1), false, grad_inputs[0], true,
                            op.input(0), false, grad_inputs[0], false,
                            grad_outputs);
  }
  return MatMulGradHelper(scope, op.input(1), true, grad_inputs[0], true,
                          grad_inputs[0], true, op.input(0), true,
                          grad_outputs);
}
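
For reference, the four branches above implement the standard matrix-calculus results for Z = MatMul(X, Y) with upstream gradient dZ; these are the same products the tests in math_grad_test check numerically:

\begin{aligned}
Z &= XY:               & dX &= dZ\,Y^{\top},        & dY &= X^{\top} dZ \\
Z &= XY^{\top}:        & dX &= dZ\,Y,               & dY &= dZ^{\top} X \\
Z &= X^{\top}Y:        & dX &= Y\,dZ^{\top},        & dY &= X\,dZ \\
Z &= X^{\top}Y^{\top}: & dX &= Y^{\top} dZ^{\top},  & dY &= dZ^{\top} X^{\top}
\end{aligned}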
Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, grad_inputs, "transpose_a", "transpose_b",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow

tensorflow/cc/gradients/math_grad_test.cc

@@ -0,0 +1,183 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
using namespace ops; // NOLINT(build/namespaces)
namespace {
// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
// to a testutil library.
class MathGradTest : public ::testing::Test {
protected:
MathGradTest() : root_(Scope::NewRootScope()) {}
  void ComputeMatMulGrad(const Output& x, const bool t_x, const Output& y,
                         const bool t_y, const Output& dz,
                         std::vector<Tensor>* out) {
    // Compute forward MatMul: z = MatMul(x, y).
    auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    TF_EXPECT_OK(root_.status());
    CHECK_NOTNULL(z.node());
    std::vector<Output> grad_outputs;
    // Call MatMulGrad which populates 'grad_outputs'.
    CallGradFunction(Operation(z.node()), {dz}, &grad_outputs);
    EXPECT_EQ(2, grad_outputs.size());
    // Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
    GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
  }

  void CallGradFunction(const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
    GradFunc grad_fn;
    TF_EXPECT_OK(
        GradOpRegistry::Global()->Lookup(op.node()->name(), &grad_fn));
    TF_EXPECT_OK(grad_fn(root_, op, grad_inputs, grad_outputs));
    TF_EXPECT_OK(root_.status());
  }

  Tensor ComputeMatMul(const Output& x, const bool t_x, const Output& y,
                       const bool t_y) {
    auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
    TF_EXPECT_OK(root_.status());
    Tensor out;
    GetTensor(root_, z, &out);
    return out;
  }

  void RandMatMulGradData(const bool tx, const bool ty,
                          std::vector<Tensor>* data) {
    // z = MatMul(x, y)
    const int m = Rand();
    const int k = Rand();
    const int n = Rand();
    // x.shape = [m, k]
    const TensorShape x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
    data->emplace_back(DT_FLOAT, x_shape);
    RandTensor(&data->back());
    // y.shape = [k, n]
    const TensorShape y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
    data->emplace_back(DT_FLOAT, y_shape);
    RandTensor(&data->back());
    // z.shape = [m, n]
    data->emplace_back(DT_FLOAT, TensorShape({m, n}));
    RandTensor(&data->back());
  }

  void RandTensor(Tensor* t) {
    test::FillFn<float>(
        t, [this](const int i) { return static_cast<float>(Rand()); });
  }

  int Rand() { return 1 + (random::New64() % 10); }
  // TODO(andydavis) Move 'GetTensors/GetTensor' to a general (i.e. not
  // gradient-specific) testutil class.
  void GetTensors(const Scope& scope, OutputList tensors,
                  std::vector<Tensor>* out) {
    SessionOptions options;
    std::unique_ptr<Session> session(NewSession(options));
    GraphDef def;
    scope.graph()->ToGraphDef(&def);

    graph::SetDefaultDevice("/cpu:0", &def);
    TF_CHECK_OK(session->Create(def));
    std::vector<string> names;
    for (const auto& t : tensors) {
      names.push_back(strings::StrCat(t.node()->name(), ":", t.index()));
    }
    TF_CHECK_OK(session->Run({}, names, {}, out));
    TF_CHECK_OK(session->Close());
  }

  void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
    std::vector<Tensor> outputs;
    GetTensors(scope, {tensor}, &outputs);
    *out = outputs[0];
  }

  Scope root_;
};
TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
  std::vector<Tensor> data;
  RandMatMulGradData(false, false, &data);
  auto x = Const(root_, data[0]);
  auto y = Const(root_, data[1]);
  auto dz = Const(root_, data[2]);

  std::vector<Tensor> grad_outputs;
  ComputeMatMulGrad(x, false, y, false, dz, &grad_outputs);

  test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, true));
  test::ExpectClose(grad_outputs[1], ComputeMatMul(x, true, dz, false));
}

TEST_F(MathGradTest, MatMulGrad_TransposeX) {
  std::vector<Tensor> data;
  RandMatMulGradData(true, false, &data);
  auto x = Const(root_, data[0]);
  auto y = Const(root_, data[1]);
  auto dz = Const(root_, data[2]);

  std::vector<Tensor> grad_outputs;
  ComputeMatMulGrad(x, true, y, false, dz, &grad_outputs);

  test::ExpectClose(grad_outputs[0], ComputeMatMul(y, false, dz, true));
  test::ExpectClose(grad_outputs[1], ComputeMatMul(x, false, dz, false));
}

TEST_F(MathGradTest, MatMulGrad_TransposeY) {
  std::vector<Tensor> data;
  RandMatMulGradData(false, true, &data);
  auto x = Const(root_, data[0]);
  auto y = Const(root_, data[1]);
  auto dz = Const(root_, data[2]);

  std::vector<Tensor> grad_outputs;
  ComputeMatMulGrad(x, false, y, true, dz, &grad_outputs);

  test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, false));
  test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, false));
}

TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
  std::vector<Tensor> data;
  RandMatMulGradData(true, true, &data);
  auto x = Const(root_, data[0]);
  auto y = Const(root_, data[1]);
  auto dz = Const(root_, data[2]);

  std::vector<Tensor> grad_outputs;
  ComputeMatMulGrad(x, true, y, true, dz, &grad_outputs);

  test::ExpectClose(grad_outputs[0], ComputeMatMul(y, true, dz, true));
  test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, true));
}

}  // namespace
}  // namespace tensorflow