Add C++ gradients to c_api.
#6268 This CL does the following:
(1) Adds a TF_AddGradients function to the C API, which adds gradient nodes for the specified inputs.
(2) Adds an internal constructor for Scope, needed to create a scope from an existing graph in the C API.
(3) Adds an overload of AddSymbolicGradients that assumes OnesLike when grad_inputs aren't provided.
(4) Improves the error message when a gradient isn't defined for an op.
Change: 153092774
commit 908d5b6ede
parent 59ccf014e8
tensorflow/c/BUILD

@@ -38,6 +38,9 @@ tf_cuda_library(
        ],
        "//conditions:default": [
            "//tensorflow/cc/saved_model:loader",
+           "//tensorflow/cc:gradients",
+           "//tensorflow/cc:ops",
+           "//tensorflow/cc:scope_internal",
            "//tensorflow/core:core_cpu",
            "//tensorflow/core:framework",
            "//tensorflow/core:lib",
@@ -97,21 +100,22 @@ tf_cc_test(
    # linkstatic = tf_kernel_tests_linkstatic(),
    deps = [
        ":c_api",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:grad_ops",
        "//tensorflow/cc/saved_model:signature_constants",
        "//tensorflow/cc/saved_model:tag_constants",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:direct_session",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:proto_text",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "//tensorflow/core/kernels:array",
        "//tensorflow/core/kernels:control_flow_ops",
        "//tensorflow/core/kernels:math",
        "//third_party/eigen3",
    ],
)
tensorflow/c/c_api.cc

@@ -21,6 +21,9 @@ limitations under the License.
#include <vector>

#ifndef __ANDROID__
+#include "tensorflow/cc/framework/gradients.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/saved_model/loader.h"
#endif
#include "tensorflow/core/common_runtime/shape_refiner.h"
@@ -2113,6 +2116,75 @@ void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,

void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }

#ifndef __ANDROID__
namespace {

void OutputsFromTFOutputs(TF_Output* tf_outputs, int n, TF_Status* status,
                          std::vector<tensorflow::Output>* outputs) {
  outputs->resize(n);
  for (int i = 0; i < n; i++) {
    const TF_Output& tf_output = tf_outputs[i];
    (*outputs)[i] = tensorflow::Output(&tf_output.oper->node, tf_output.index);
  }
}

void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
                          TF_Output* tf_outputs) {
  for (int i = 0; i < outputs.size(); i++) {
    tf_outputs[i].oper = ToOperation(outputs[i].node());
    tf_outputs[i].index = outputs[i].index();
  }
}

}  // namespace
#endif  // __ANDROID__

void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
                     TF_Output* dx, TF_Status* status, TF_Output* dy) {
#ifdef __ANDROID__
  status->status = tensorflow::errors::Unimplemented(
      "Adding gradients is not supported in Android. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
#else
  std::vector<tensorflow::Output> y_arg;
  std::vector<tensorflow::Output> x_arg;
  std::vector<tensorflow::Output> dy_arg;
  OutputsFromTFOutputs(y, ny, status, &y_arg);
  OutputsFromTFOutputs(x, nx, status, &x_arg);

  {
    // We need to hold on to the lock while we have a scope that uses TF_Graph.
    mutex_lock graph_lock(g->mu);

    const int max_node_id_before = g->graph.num_node_ids();

    tensorflow::Scope scope =
        NewInternalScope(&g->graph, &status->status, &g->refiner);

    if (dx != nullptr) {
      std::vector<tensorflow::Output> dx_arg;
      OutputsFromTFOutputs(dx, ny, status, &dx_arg);
      status->status =
          AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
    } else {
      status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
    }

    // Update g->name_map with the name_map from the scope, which will contain
    // the new gradient ops.
    for (int i = max_node_id_before; i < g->graph.num_node_ids(); ++i) {
      Node* n = g->graph.FindNodeId(i);
      if (n == nullptr) continue;
      g->name_map[n->name()] = n;
    }
  }

  // Unpack the results from dy_arg.
  TFOutputsFromOutputs(dy_arg, dy);
#endif  // __ANDROID__
}

// TF_Session functions ----------------------------------------------

TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
tensorflow/c/c_api.h

@@ -898,7 +898,7 @@ typedef struct TF_WhileParams {
// TF_FinishWhile() or TF_AbortWhile().
//
// Missing functionality (TODO):
-// - Gradients (not yet implmented for any ops)
+// - Gradients
// - Reference-type inputs
// - Directly referencing external tensors from the cond/body graphs (this is
//   possible in the Python API)

@@ -921,7 +921,22 @@ void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
// called after a successful TF_NewWhile() call.
void TF_AbortWhile(const TF_WhileParams* params);

-// TODO(andydavis): Function to add gradients to a graph.
+// Adds operations to compute the partial derivatives of the sum of `y`s
+// w.r.t. `x`s, i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...
+// `dx` are used as initial gradients (which represent the symbolic partial
+// derivatives of some loss function `L` w.r.t. `y`).
+// `dx` must be nullptr or have size `ny`.
+// If `dx` is nullptr, the implementation will use `OnesLike` for all shapes
+// in `y` as the initial gradients.
+// The partial derivatives are returned in `dy`. `dy` should be allocated to
+// size `nx`.
+//
+// WARNING: This function does not yet support all the gradients that Python
+// supports. See
+// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md
+// for instructions on how to add more C++ gradients.
+void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
+                     TF_Output* dx, TF_Status* status, TF_Output* dy);

// TODO(josh11b): Register OpDef, available to all operations added
// to this graph.
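A brief usage sketch for the new entry point (illustrative, not part of the commit; `graph`, `z_op`, and `x_op` are assumed to be an existing TF_Graph and forward-pass operations):

#include <stdio.h>
#include "tensorflow/c/c_api.h"

void AddGradientNodes(TF_Graph* graph, TF_Operation* z_op, TF_Operation* x_op) {
  TF_Status* status = TF_NewStatus();
  TF_Output y[1] = {{z_op, 0}};  // outputs to differentiate
  TF_Output x[1] = {{x_op, 0}};  // inputs to differentiate w.r.t.
  TF_Output dy[1];               // receives d(sum of y)/dx; needs nx slots
  // dx == NULL: gradients are seeded with OnesLike for every shape in y.
  TF_AddGradients(graph, y, 1, x, 1, /*dx=*/NULL, status, dy);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "TF_AddGradients: %s\n", TF_Message(status));
  }
  TF_DeleteStatus(status);
}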
tensorflow/c/c_api_test.cc

@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/core/util/equal_graph_def.h"

using tensorflow::int32;
using tensorflow::string;
@@ -1525,6 +1526,281 @@ TEST_F(CApiWhileLoopTest, BadTypes) {
  TF_AbortWhile(params_.get());
}

REGISTER_OP("TestOpWithNoGradient")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {float, double}")
    .Doc(R"doc(
Test op with no grad registered.

x: input
y: output
)doc");

class CApiGradientsTest : public ::testing::Test {
 protected:
  CApiGradientsTest()
      : s_(TF_NewStatus()),
        graph_(TF_NewGraph()),
        expected_graph_(TF_NewGraph()) {}

  ~CApiGradientsTest() override {
    TF_DeleteGraph(graph_);
    TF_DeleteGraph(expected_graph_);
    TF_DeleteStatus(s_);
  }

  void TestGradientsSuccess(bool grad_inputs_provided) {
    TF_Output inputs[2];
    TF_Output outputs[1];
    TF_Output grad_outputs[2];
    TF_Output expected_grad_outputs[2];

    BuildSuccessGraph(inputs, outputs);
    BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs);

    AddGradients(grad_inputs_provided, inputs, 2, outputs, 1, grad_outputs);

    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

    // Compare that the graphs match.
    GraphDef expected_gdef;
    GraphDef gdef;
    EXPECT_TRUE(GetGraphDef(expected_graph_, &expected_gdef));
    EXPECT_TRUE(GetGraphDef(graph_, &gdef));
    TF_EXPECT_GRAPH_EQ(expected_gdef, gdef);

    // Compare that the output of the gradients of both graphs match.
    RunGraphsAndCompareOutputs(grad_outputs, expected_grad_outputs);
  }

  void TestGradientsError(bool grad_inputs_provided) {
    TF_Output inputs[1];
    TF_Output outputs[1];
    TF_Output grad_outputs[1];

    BuildErrorGraph(inputs, outputs);

    AddGradients(grad_inputs_provided, inputs, 1, outputs, 1, grad_outputs);

    string expected_msg =
        "No gradient defined for op: TestOpWithNoGradient. Please see "
        "https://www.tensorflow.org/code/"
        "tensorflow/cc/gradients/README.md"
        " for instructions on how to add C++ gradients.";
    EXPECT_EQ(expected_msg, TF_Message(s_));
  }

  // Run the graph and ensure that the gradient values are as expected.
  void RunGraphsAndCompareOutputs(TF_Output* grad_outputs,
                                  TF_Output* expected_grad_outputs) {
    std::unique_ptr<CSession> csession(new CSession(graph_, s_));
    std::unique_ptr<CSession> expected_csession(
        new CSession(expected_graph_, s_));

    std::vector<TF_Output> grad_outputs_vec;
    grad_outputs_vec.assign(grad_outputs, grad_outputs + 2);
    csession->SetOutputs(grad_outputs_vec);
    csession->Run(s_);
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    TF_Tensor* out0 = csession->output_tensor(0);
    TF_Tensor* out1 = csession->output_tensor(1);

    std::vector<TF_Output> expected_grad_outputs_vec;
    expected_grad_outputs_vec.assign(expected_grad_outputs,
                                     expected_grad_outputs + 2);
    expected_csession->SetOutputs(expected_grad_outputs_vec);
    expected_csession->Run(s_);
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    TF_Tensor* expected_out0 = expected_csession->output_tensor(0);
    TF_Tensor* expected_out1 = expected_csession->output_tensor(1);

    CompareTensors(out0, expected_out0);
    CompareTensors(out1, expected_out1);
  }

  void CompareTensors(TF_Tensor* a, TF_Tensor* b) {
    float* a_data = static_cast<float*>(TF_TensorData(a));
    float* b_data = static_cast<float*>(TF_TensorData(b));
    EXPECT_EQ(*a_data, *b_data);
  }

  void AddGradients(bool grad_inputs_provided, TF_Output* inputs, int ninputs,
                    TF_Output* outputs, int noutputs,
                    TF_Output* grad_outputs) {
    if (grad_inputs_provided) {
      TF_Output grad_inputs[1];
      const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0};
      TF_Operation* grad_inputs_op =
          FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs");
      grad_inputs[0] = TF_Output{grad_inputs_op, 0};
      TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, grad_inputs,
                      s_, grad_outputs);
    } else {
      TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, nullptr, s_,
                      grad_outputs);
    }
  }

  void BuildErrorGraph(TF_Output* inputs, TF_Output* outputs) {
    const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
    TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0");
    TF_Operation* nograd = NoGradientOp(graph_, s_, const0, "NoGrad");
    inputs[0] = TF_Output{const0, 0};
    outputs[0] = TF_Output{nograd, 0};
    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  }

  void BuildSuccessGraph(TF_Output* inputs, TF_Output* outputs) {
    // Construct the following graph:
    //            |
    //           z|
    //            |
    //          MatMul
    //         /       \
    //        ^         ^
    //        |         |
    //       x|        y|
    //        |         |
    //        |         |
    //      Const_0    Const_1
    //
    const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
    const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
    TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0");
    TF_Operation* const1 = FloatConst2x2(graph_, s_, const1_val, "Const_1");
    TF_Operation* matmul = MatMul(graph_, s_, const0, const1, "MatMul");
    inputs[0] = TF_Output{const0, 0};
    inputs[1] = TF_Output{const1, 0};
    outputs[0] = TF_Output{matmul, 0};
    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  }

  void BuildExpectedGraph(bool grad_inputs_provided,
                          TF_Output* expected_grad_outputs) {
    // The expected graph looks like this if grad_inputs_provided.
    // If grad_inputs_provided is false, Const_0 will be a OnesLike op.
    //      ^               ^
    //    dy|             dx|       // MatMul Gradient Graph
    //      |               |
    //   MatMul_2        MatMul_1
    //   ^    ^          ^     ^
    //   |    |----------|     |
    //   |          ^          |
    //   |        dz|          |
    //   |          |          |
    //   |       Const_3       |
    //   |                     |
    //   |          ^          |
    //   |         z|          |    // MatMul Forward Graph
    //   |          |          |
    //   |        MatMul       |
    //   |       /      \      |
    //   |      ^        ^     |
    //   |      |        |     |
    //   |-----x|       y|-----|
    //          |        |
    //          |        |
    //       Const_0  Const_1
    //
    const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
    const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
    TF_Operation* const0 =
        FloatConst2x2(expected_graph_, s_, const0_val, "Const_0");
    TF_Operation* const1 =
        FloatConst2x2(expected_graph_, s_, const1_val, "Const_1");
    TF_Operation* matmul =
        MatMul(expected_graph_, s_, const0, const1, "MatMul");

    TF_Operation* const3;
    if (grad_inputs_provided) {
      const float const3_val[] = {1.0, 1.0, 1.0, 1.0};
      const3 = FloatConst2x2(expected_graph_, s_, const3_val, "GradInputs");
    } else {
      const3 = OnesLike(expected_graph_, s_, matmul, "OnesLike");
    }

    TF_Operation* matmul1 =
        MatMul(expected_graph_, s_, const3, const1, "MatMul_1", false, true);
    TF_Operation* matmul2 =
        MatMul(expected_graph_, s_, const0, const3, "MatMul_2", true, false);
    expected_grad_outputs[0] = {matmul1, 0};
    expected_grad_outputs[1] = {matmul2, 0};
  }

  TF_Tensor* FloatTensor2x2(const float* values) {
    const int64_t dims[2] = {2, 2};
    TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(float) * 4);
    memcpy(TF_TensorData(t), values, sizeof(float) * 4);
    return t;
  }

  TF_Operation* FloatConst2x2(TF_Graph* graph, TF_Status* s,
                              const float* values, const char* name) {
    unique_tensor_ptr tensor(FloatTensor2x2(values), TF_DeleteTensor);
    TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
    TF_SetAttrTensor(desc, "value", tensor.get(), s);
    if (TF_GetCode(s) != TF_OK) return nullptr;
    TF_SetAttrType(desc, "dtype", TF_FLOAT);
    TF_Operation* op = TF_FinishOperation(desc, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    return op;
  }

  TF_Operation* MatMul(TF_Graph* graph, TF_Status* s, TF_Operation* l,
                       TF_Operation* r, const char* name,
                       bool transpose_a = false, bool transpose_b = false) {
    TF_OperationDescription* desc = TF_NewOperation(graph, "MatMul", name);
    if (transpose_a) {
      TF_SetAttrBool(desc, "transpose_a", 1);
    }
    if (transpose_b) {
      TF_SetAttrBool(desc, "transpose_b", 1);
    }
    TF_AddInput(desc, {l, 0});
    TF_AddInput(desc, {r, 0});
    TF_Operation* op = TF_FinishOperation(desc, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    return op;
  }

  TF_Operation* OnesLike(TF_Graph* graph, TF_Status* s, TF_Operation* in,
                         const char* name) {
    TF_OperationDescription* desc = TF_NewOperation(graph, "OnesLike", name);
    TF_AddInput(desc, {in, 0});
    TF_Operation* op = TF_FinishOperation(desc, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    return op;
  }

  TF_Operation* NoGradientOp(TF_Graph* graph, TF_Status* s, TF_Operation* in,
                             const char* name) {
    TF_OperationDescription* desc =
        TF_NewOperation(graph, "TestOpWithNoGradient", name);
    TF_AddInput(desc, {in, 0});
    TF_Operation* op = TF_FinishOperation(desc, s);
    EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
    return op;
  }

  TF_Status* s_;
  TF_Graph* graph_;
  TF_Graph* expected_graph_;
};

TEST_F(CApiGradientsTest, Gradients_GradInputs) { TestGradientsSuccess(true); }

TEST_F(CApiGradientsTest, Gradients_NoGradInputs) {
  TestGradientsSuccess(false);
}

TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_GradInputs) {
  TestGradientsError(true);
}

TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
  TestGradientsError(false);
}

// Create a tensor with values of type TF_INT8 provided by `values`.
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
  int64_t num_values = 1;
tensorflow/cc/BUILD

@@ -122,7 +122,10 @@ cc_library_with_android_deps(

cc_library_with_android_deps(
    name = "scope",
-   srcs = ["framework/scope.cc"],
+   srcs = [
+       "framework/scope.cc",
+       "framework/scope_internal.h",
+   ],
    hdrs = ["framework/scope.h"],
    android_deps = ["//tensorflow/core:android_tensorflow_lib"],
    common_deps = [

@@ -136,6 +139,15 @@ cc_library_with_android_deps(
    ],
)

cc_library_with_android_deps(
    name = "scope_internal",
    hdrs = ["framework/scope_internal.h"],
    common_deps = [
        ":scope",
    ],
    deps = [],
)

tf_cc_test(
    name = "framework_scope_test",
    srcs = ["framework/scope_test.cc"],
tensorflow/cc/framework/grad_op_registry.cc

@@ -32,7 +32,13 @@ bool GradOpRegistry::Register(const string& op, GradFunc func) {
Status GradOpRegistry::Lookup(const string& op, GradFunc* func) const {
  auto iter = registry_.find(op);
  if (iter == registry_.end()) {
-    return errors::NotFound("No gradient defined for op: ", op);
+    const string error_msg =
+        "No gradient defined for op: " + op +
+        ". Please see "
+        "https://www.tensorflow.org/code/"
+        "tensorflow/cc/gradients/README.md"
+        " for instructions on how to add C++ gradients.";
+    return errors::NotFound(error_msg);
  }
  *func = iter->second;
  return Status::OK();
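The improved message points developers at the C++ gradient registry. For context, a hedged sketch (not part of this commit) of what registering a C++ gradient looks like via GradOpRegistry; SquareGrad below is illustrative, modeled on the existing gradient files:

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
namespace ops {
namespace {

// Illustrative gradient for Square(x): grad * 2 * x.
Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, Mul(scope, grad_inputs[0], two), op.input(0));
  grad_outputs->push_back(dydx);
  return scope.status();
}

// Registration makes GradOpRegistry::Lookup("Square", ...) succeed.
REGISTER_GRADIENT_OP("Square", SquareGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow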
tensorflow/cc/framework/gradients.cc

@@ -367,6 +367,19 @@ Status AddSymbolicGradients(const Scope& scope,
  return builder.AddGradients();
}

Status AddSymbolicGradients(const Scope& scope,
                            const std::vector<Output>& outputs,
                            const std::vector<Output>& inputs,
                            std::vector<Output>* grad_outputs) {
  std::vector<Output> grad_inputs;
  grad_inputs.reserve(outputs.size());
  for (const Output& output : outputs) {
    grad_inputs.emplace_back(ops::OnesLike(scope, output));
  }
  return AddSymbolicGradients(scope, outputs, inputs, grad_inputs,
                              grad_outputs);
}

Output NoGradient() { return SymbolicGradientBuilder::NoGradient(); }

}  // end namespace tensorflow
tensorflow/cc/framework/gradients.h

@@ -27,16 +27,19 @@ namespace tensorflow {
/// derivatives of some loss function 'L' w.r.t 'outputs'), adds gradient nodes
/// to the graph associated with 'scope', which compute (and return in
/// 'grad_outputs') the symbolic partial derivatives of 'L' w.r.t 'inputs'.
///
-// TODO(andydavis) Add overload of this function with no 'grad_inputs' arg.
-// Implementation will fill in 'OnesLike' for all shapes in 'outputs'.
Status AddSymbolicGradients(const Scope& scope,
                            const std::vector<Output>& outputs,
                            const std::vector<Output>& inputs,
                            const std::vector<Output>& grad_inputs,
                            std::vector<Output>* grad_outputs);

+// Same as above, but uses 'OnesLike' for all shapes in
+// 'outputs' as grad_inputs.
+Status AddSymbolicGradients(const Scope& scope,
+                            const std::vector<Output>& outputs,
+                            const std::vector<Output>& inputs,
+                            std::vector<Output>* grad_outputs);

/// Returns a sentinel Output that represents 'no gradient' (i.e. no gradient
/// flows along some graph edge during backpropagation).
/// Can be returned in 'grad_outputs' by an invocation of 'AddSymbolicGradients'
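A short sketch of the new overload in action (mirroring the OneMatMul_InferGradInputs test below; assumes the standard cc/ops headers):

#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"

tensorflow::Status MatMulGradExample(std::vector<tensorflow::Output>* grads) {
  using namespace tensorflow;
  using namespace tensorflow::ops;
  Scope scope = Scope::NewRootScope();
  auto x = Const(scope, {{1.0f, 2.0f}, {3.0f, 4.0f}});
  auto y = Const(scope, {{1.0f, 0.0f}, {0.0f, 1.0f}});
  auto z = MatMul(scope, x, y);
  // New overload: no grad_inputs argument; dz is implicitly OnesLike(z).
  return AddSymbolicGradients(scope, {z}, {x, y}, grads);
}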
tensorflow/cc/framework/gradients_test.cc

@@ -40,7 +40,7 @@ class GradientsTest : public ::testing::Test {
    TF_ASSERT_OK(scope_test_.ToGraphDef(&gdef_test));
    GraphDef gdef_exp;
    TF_ASSERT_OK(scope_expected_.ToGraphDef(&gdef_exp));
-   TF_EXPECT_GRAPH_EQ(gdef_test, gdef_exp);
+   TF_EXPECT_GRAPH_EQ(gdef_exp, gdef_test);
  }

  Scope scope_expected_;
@@ -98,6 +98,32 @@ TEST_F(GradientsTest, OneMatMul) {
  CompareTestAndExpectedGraphs();
}

TEST_F(GradientsTest, OneMatMul_InferGradInputs) {
  for (const bool expected : {false, true}) {
    const Scope& scope = expected ? scope_expected_ : scope_test_;
    // Construct forward graph.
    auto x = Const(scope, {{1.0, 2.0}, {3.0, 4.0}});
    auto y = Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
    auto z = MatMul(scope, x, y);
    TF_ASSERT_OK(scope.status());
    CHECK_NOTNULL(z.node());

    if (expected) {
      // Construct backward graph.
      // The gradients function adds a OnesLike to create a dz of ones with
      // the shape of z.
      auto dz = OnesLike(scope, z);
      auto dx = MatMul(scope, dz, y, MatMul::TransposeB(true));
      auto dy = MatMul(scope, x, dz, MatMul::TransposeA(true));
    } else {
      // Call AddSymbolicGradients.
      std::vector<Output> grad_outputs;
      TF_ASSERT_OK(AddSymbolicGradients(scope, {z}, {x, y}, &grad_outputs));
    }
  }
  CompareTestAndExpectedGraphs();
}

TEST_F(GradientsTest, TwoMatMuls_Chained) {
  for (const bool expected : {false, true}) {
    const Scope& scope = expected ? scope_expected_ : scope_test_;
tensorflow/cc/framework/scope.cc

@@ -16,7 +16,7 @@ limitations under the License.
#include <algorithm>
#include <vector>

-#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph_constructor.h"

@@ -25,6 +25,20 @@ limitations under the License.
namespace tensorflow {

class Scope::Impl {
 public:
  // A NameMap is used to keep track of suffixes for names used in a scope. A
  // name that has not been used so far in a scope will get no suffix. Later
  // uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes
  // can share the same NameMap. For instance, a new scope created using
  // WithControlDependencies() would share the same NameMap with the parent.
  typedef std::unordered_map<string, int> NameMap;

  Impl(const std::shared_ptr<Graph>& graph,
       const std::shared_ptr<Status>& status,
       const std::shared_ptr<NameMap>& name_map,
       const std::shared_ptr<ShapeRefiner>& refiner);

 private:
  friend class Scope;

@@ -40,14 +54,6 @@ class Scope::Impl {
    enum class Colocate;
  };

-  // A NameMap is used to keep track of suffixes for names used in a scope. A
-  // name that has not been used so far in a scope will get no suffix. Later
-  // uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes
-  // can share the same NameMap. For instance, a new scope created using
-  // WithControlDependencies() should would share the same NameMap with the
-  // parent.
-  typedef std::unordered_map<string, int> NameMap;

  Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner);
  Impl(const Scope& other, Tags::ScopeName, const string& name,
       bool copy_names);

@@ -116,6 +122,17 @@ Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,
      scope_used_(nullptr),
      colocation_constraints_() {}

Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
                  const std::shared_ptr<Status>& status,
                  const std::shared_ptr<NameMap>& name_map,
                  const std::shared_ptr<ShapeRefiner>& refiner)
    : graph_(graph),
      status_(status),
      name_map_(name_map),
      refiner_(refiner),
      scope_used_(nullptr),
      colocation_constraints_() {}

Scope Scope::NewRootScope() {
  Graph* graph = new Graph(OpRegistry::Global());
  ShapeRefiner* refiner =

@@ -277,7 +294,7 @@ std::shared_ptr<Graph> Scope::graph_as_shared_ptr() const {
  return impl()->graph_;
}

-Status Scope::status() const { return *impl()->status_; };
+Status Scope::status() const { return *impl()->status_; }

const std::vector<Operation>& Scope::control_deps() const {
  return impl()->control_deps_;

@@ -464,4 +481,26 @@ CompositeOpScopes Scope::GetCompositeOpScopes(
  }
}

class InternalScope {
 public:
  // NewScope doesn't take ownership of the inputs.
  static Scope NewScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
    Scope::Impl::NameMap* name_map = new Scope::Impl::NameMap;
    for (const Node* node : graph->nodes()) {
      (*name_map)[node->name()] = 0;
    }
    // We provide null destructors for these shared ptrs (except for name_map)
    // since the caller owns them and doesn't want the scope to destroy them.
    return Scope(new Scope::Impl(
        std::shared_ptr<Graph>(graph, [](Graph*) {}),
        std::shared_ptr<Status>(status, [](Status*) {}),
        std::shared_ptr<Scope::Impl::NameMap>(name_map),
        std::shared_ptr<ShapeRefiner>(refiner, [](ShapeRefiner*) {})));
  }
};

Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
  return InternalScope::NewScope(graph, status, refiner);
}

}  // namespace tensorflow
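The null-destructor trick above is just the standard non-owning shared_ptr pattern; a tiny self-contained sketch (illustrative, not from this commit):

#include <memory>

struct Resource { int value = 42; };

int main() {
  Resource owned_elsewhere;  // caller-owned; must outlive the alias
  // No-op deleter: the shared_ptr participates in APIs that expect shared
  // ownership but never frees the object.
  std::shared_ptr<Resource> alias(&owned_elsewhere, [](Resource*) {});
  return alias->value == 42 ? 0 : 1;
}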
tensorflow/cc/framework/scope.h

@@ -204,6 +204,7 @@ class Scope {
  const std::vector<Operation>& control_deps() const;

 private:
+  friend class InternalScope;
  class Impl;
  std::unique_ptr<Impl> impl_;
  Impl* impl() { return impl_.get(); }
tensorflow/cc/framework/scope_internal.h (new file, 33 lines)
@@ -0,0 +1,33 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_

#include "tensorflow/cc/framework/scope.h"

namespace tensorflow {

class ShapeRefiner;

// NewInternalScope returns a new scope which doesn't take ownership of
// graph, status, name_map, and refiner.
// This is intended to enable the C API (which is used by other language
// bindings) to create a Scope and access C++ functionality (i.e. gradients).
Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner);

}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_
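A hedged sketch of the intended use over a caller-owned graph (the ShapeRefiner construction mirrors Scope::NewRootScope and is an assumption, not part of this commit):

#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/graph/graph.h"

void BuildOnExistingGraph() {
  tensorflow::Graph graph(tensorflow::OpRegistry::Global());
  tensorflow::Status status;
  tensorflow::ShapeRefiner refiner(graph.versions(), graph.op_registry());
  // The scope borrows graph/status/refiner; destroying it frees none of them.
  tensorflow::Scope scope =
      tensorflow::NewInternalScope(&graph, &status, &refiner);
  auto c = tensorflow::ops::Const(scope, 1.0f);  // op lands in `graph`
}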
tensorflow/contrib/cmake/tf_python.cmake

@@ -689,6 +689,14 @@ set (pywrap_tensorflow_internal_src
    "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h"
    "${tensorflow_source_dir}/tensorflow/c/tf_status_helper.cc"
    "${tensorflow_source_dir}/tensorflow/c/tf_status_helper.h"
+   "${tensorflow_source_dir}/tensorflow/cc/framework/gradients.h"
+   "${tensorflow_source_dir}/tensorflow/cc/framework/gradients.cc"
+   "${tensorflow_source_dir}/tensorflow/cc/framework/grad_op_registry.h"
+   "${tensorflow_source_dir}/tensorflow/cc/framework/grad_op_registry.cc"
+   "${tensorflow_source_dir}/tensorflow/cc/framework/ops.h"
+   "${tensorflow_source_dir}/tensorflow/cc/framework/ops.cc"
+   "${tensorflow_source_dir}/tensorflow/cc/framework/scope_internal.h"
+   "${tensorflow_source_dir}/tensorflow/cc/framework/scope.cc"
    "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.cc"
)

@@ -707,6 +715,7 @@ if(WIN32)
    $<TARGET_OBJECTS:tf_core_lib>
    $<TARGET_OBJECTS:tf_core_cpu>
    $<TARGET_OBJECTS:tf_core_framework>
+   $<TARGET_OBJECTS:tf_cc_ops>
    $<TARGET_OBJECTS:tf_core_ops>
    $<TARGET_OBJECTS:tf_core_direct_session>
    $<TARGET_OBJECTS:tf_tools_transform_graph_lib>

@@ -742,6 +751,7 @@ add_library(pywrap_tensorflow_internal SHARED
    $<TARGET_OBJECTS:tf_core_lib>
    $<TARGET_OBJECTS:tf_core_cpu>
    $<TARGET_OBJECTS:tf_core_framework>
+   $<TARGET_OBJECTS:tf_cc_ops>
    $<TARGET_OBJECTS:tf_core_ops>
    $<TARGET_OBJECTS:tf_core_direct_session>
    $<TARGET_OBJECTS:tf_tools_transform_graph_lib>