diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 6e39deee636..af96ce70b69 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -38,6 +38,9 @@ tf_cuda_library(
         ],
         "//conditions:default": [
             "//tensorflow/cc/saved_model:loader",
+            "//tensorflow/cc:gradients",
+            "//tensorflow/cc:ops",
+            "//tensorflow/cc:scope_internal",
             "//tensorflow/core:core_cpu",
             "//tensorflow/core:framework",
             "//tensorflow/core:lib",
@@ -97,21 +100,22 @@ tf_cc_test(
     # linkstatic = tf_kernel_tests_linkstatic(),
     deps = [
         ":c_api",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:grad_ops",
         "//tensorflow/cc/saved_model:signature_constants",
         "//tensorflow/cc/saved_model:tag_constants",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:direct_session",
         "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core:proto_text",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
-        "//tensorflow/core:testlib",
         "//tensorflow/core/kernels:array",
         "//tensorflow/core/kernels:control_flow_ops",
         "//tensorflow/core/kernels:math",
-        "//third_party/eigen3",
     ],
 )
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index a8c360541c1..b2c6b13b690 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -21,6 +21,9 @@ limitations under the License.
 #include <vector>
 
 #ifndef __ANDROID__
+#include "tensorflow/cc/framework/gradients.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope_internal.h"
 #include "tensorflow/cc/saved_model/loader.h"
 #endif
 #include "tensorflow/core/common_runtime/shape_refiner.h"
@@ -2113,6 +2116,75 @@ void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
 
 void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
 
+#ifndef __ANDROID__
+namespace {
+
+void OutputsFromTFOutputs(TF_Output* tf_outputs, int n, TF_Status* status,
+                          std::vector<tensorflow::Output>* outputs) {
+  outputs->resize(n);
+  for (int i = 0; i < n; i++) {
+    const TF_Output& tf_output = tf_outputs[i];
+    (*outputs)[i] = tensorflow::Output(&tf_output.oper->node, tf_output.index);
+  }
+}
+
+void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
+                          TF_Output* tf_outputs) {
+  for (int i = 0; i < outputs.size(); i++) {
+    tf_outputs[i].oper = ToOperation(outputs[i].node());
+    tf_outputs[i].index = outputs[i].index();
+  }
+}
+
+}  // namespace
+#endif  // __ANDROID__
+
+void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
+                     TF_Output* dx, TF_Status* status, TF_Output* dy) {
+#ifdef __ANDROID__
+  status->status = tensorflow::errors::Unimplemented(
+      "Adding gradients is not supported in Android. File a bug at "
+      "https://github.com/tensorflow/tensorflow/issues if this feature is "
+      "important to you");
+#else
+  std::vector<tensorflow::Output> y_arg;
+  std::vector<tensorflow::Output> x_arg;
+  std::vector<tensorflow::Output> dy_arg;
+  OutputsFromTFOutputs(y, ny, status, &y_arg);
+  OutputsFromTFOutputs(x, nx, status, &x_arg);
+
+  {
+    // We need to hold on to the lock while we have a scope that uses TF_Graph.
+    mutex_lock graph_lock(g->mu);
+
+    const int max_node_id_before = g->graph.num_node_ids();
+
+    tensorflow::Scope scope =
+        NewInternalScope(&g->graph, &status->status, &g->refiner);
+
+    if (dx != nullptr) {
+      std::vector<tensorflow::Output> dx_arg;
+      OutputsFromTFOutputs(dx, ny, status, &dx_arg);
+      status->status =
+          AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
+    } else {
+      status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
+    }
+
+    // Update g->name_map with the name_map from the scope, which will contain
+    // the new gradient ops.
+    for (int i = max_node_id_before; i < g->graph.num_node_ids(); ++i) {
+      Node* n = g->graph.FindNodeId(i);
+      if (n == nullptr) continue;
+      g->name_map[n->name()] = n;
+    }
+  }
+
+  // Unpack the results from dy_arg.
+  TFOutputsFromOutputs(dy_arg, dy);
+#endif  // __ANDROID__
+}
+
 // TF_Session functions ----------------------------------------------
 
 TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 9b08f9d9814..88438a35854 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -898,7 +898,7 @@ typedef struct TF_WhileParams {
 // TF_FinishWhile() or TF_AbortWhile().
 //
 // Missing functionality (TODO):
-// - Gradients (not yet implmented for any ops)
+// - Gradients
 // - Reference-type inputs
 // - Directly referencing external tensors from the cond/body graphs (this is
 //   possible in the Python API)
@@ -921,7 +921,22 @@ void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
 // called after a successful TF_NewWhile() call.
 void TF_AbortWhile(const TF_WhileParams* params);
 
-// TODO(andydavis): Function to add gradients to a graph.
+// Adds operations to compute the partial derivatives of the sum of `y`s
+// w.r.t `x`s, i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2, ...
+// The `dx` values are used as initial gradients (which represent the symbolic
+// partial derivatives of some loss function `L` w.r.t. `y`).
+// `dx` must be nullptr or have size `ny`.
+// If `dx` is nullptr, the implementation will use ones with the shapes of
+// the `y`s (i.e., `OnesLike`) as the initial gradients.
+// The partial derivatives are returned in `dy`. `dy` should be allocated to
+// size `nx`.
+//
+// WARNING: This function does not yet support all the gradients that Python
+// supports. See
+// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md
+// for instructions on how to add more C++ gradients.
+void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
+                     TF_Output* dx, TF_Status* status, TF_Output* dy);
 
 // TODO(josh11b): Register OpDef, available to all operations added
 // to this graph.
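
For orientation (not part of the patch): a minimal sketch of calling the new entry point from C. The operation names "MatMul" and "Const_0" and the wrapper AddGradientsExample are hypothetical; it assumes `graph` already contains ops with those names.

#include <stdio.h>

#include "tensorflow/c/c_api.h"

/* Computes the gradient of the "MatMul" output w.r.t. the "Const_0" output.
 * Passing NULL for dx makes the implementation seed the gradient with ones,
 * per the header comment. dy must have room for one TF_Output per x (nx). */
static void AddGradientsExample(TF_Graph* graph) {
  TF_Status* s = TF_NewStatus();
  TF_Output y = {TF_GraphOperationByName(graph, "MatMul"), 0};
  TF_Output x = {TF_GraphOperationByName(graph, "Const_0"), 0};
  TF_Output dy;
  TF_AddGradients(graph, &y, 1, &x, 1, /*dx=*/NULL, s, &dy);
  if (TF_GetCode(s) != TF_OK) fprintf(stderr, "%s\n", TF_Message(s));
  TF_DeleteStatus(s);
}
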
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index d846daa71b3..0ddc59db20e 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -38,6 +38,7 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/core/util/equal_graph_def.h"
 
 using tensorflow::int32;
 using tensorflow::string;
@@ -1525,6 +1526,281 @@ TEST_F(CApiWhileLoopTest, BadTypes) {
   TF_AbortWhile(params_.get());
 }
 
+REGISTER_OP("TestOpWithNoGradient")
+    .Input("x: T")
+    .Output("y: T")
+    .Attr("T: {float, double}")
+    .Doc(R"doc(
+Test op with no grad registered.
+
+x: input
+y: output
+)doc");
+
+class CApiGradientsTest : public ::testing::Test {
+ protected:
+  CApiGradientsTest()
+      : s_(TF_NewStatus()),
+        graph_(TF_NewGraph()),
+        expected_graph_(TF_NewGraph()) {}
+
+  ~CApiGradientsTest() override {
+    TF_DeleteGraph(graph_);
+    TF_DeleteGraph(expected_graph_);
+    TF_DeleteStatus(s_);
+  }
+
+  void TestGradientsSuccess(bool grad_inputs_provided) {
+    TF_Output inputs[2];
+    TF_Output outputs[1];
+    TF_Output grad_outputs[2];
+    TF_Output expected_grad_outputs[2];
+
+    BuildSuccessGraph(inputs, outputs);
+    BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs);
+
+    AddGradients(grad_inputs_provided, inputs, 2, outputs, 1, grad_outputs);
+
+    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+    // Compare that the graphs match.
+    GraphDef expected_gdef;
+    GraphDef gdef;
+    EXPECT_TRUE(GetGraphDef(expected_graph_, &expected_gdef));
+    EXPECT_TRUE(GetGraphDef(graph_, &gdef));
+    TF_EXPECT_GRAPH_EQ(expected_gdef, gdef);
+
+    // Compare that the output of the gradients of both graphs match.
+    RunGraphsAndCompareOutputs(grad_outputs, expected_grad_outputs);
+  }
+
+  void TestGradientsError(bool grad_inputs_provided) {
+    TF_Output inputs[1];
+    TF_Output outputs[1];
+    TF_Output grad_outputs[1];
+
+    BuildErrorGraph(inputs, outputs);
+
+    AddGradients(grad_inputs_provided, inputs, 1, outputs, 1, grad_outputs);
+
+    string expected_msg =
+        "No gradient defined for op: TestOpWithNoGradient. Please see "
+        "https://www.tensorflow.org/code/"
+        "tensorflow/cc/gradients/README.md"
+        " for instructions on how to add C++ gradients.";
+    EXPECT_EQ(expected_msg, TF_Message(s_));
+  }
+
+  // Run the graph and ensure that the gradient values are as expected.
+  void RunGraphsAndCompareOutputs(TF_Output* grad_outputs,
+                                  TF_Output* expected_grad_outputs) {
+    std::unique_ptr<CSession> csession(new CSession(graph_, s_));
+    std::unique_ptr<CSession> expected_csession(
+        new CSession(expected_graph_, s_));
+
+    std::vector<TF_Output> grad_outputs_vec;
+    grad_outputs_vec.assign(grad_outputs, grad_outputs + 2);
+    csession->SetOutputs(grad_outputs_vec);
+    csession->Run(s_);
+    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+    TF_Tensor* out0 = csession->output_tensor(0);
+    TF_Tensor* out1 = csession->output_tensor(1);
+
+    std::vector<TF_Output> expected_grad_outputs_vec;
+    expected_grad_outputs_vec.assign(expected_grad_outputs,
+                                     expected_grad_outputs + 2);
+    expected_csession->SetOutputs(expected_grad_outputs_vec);
+    expected_csession->Run(s_);
+    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+    TF_Tensor* expected_out0 = expected_csession->output_tensor(0);
+    TF_Tensor* expected_out1 = expected_csession->output_tensor(1);
+
+    CompareTensors(out0, expected_out0);
+    CompareTensors(out1, expected_out1);
+  }
+
+  void CompareTensors(TF_Tensor* a, TF_Tensor* b) {
+    float* a_data = static_cast<float*>(TF_TensorData(a));
+    float* b_data = static_cast<float*>(TF_TensorData(b));
+    EXPECT_EQ(*a_data, *b_data);
+  }
+
+  void AddGradients(bool grad_inputs_provided, TF_Output* inputs, int ninputs,
+                    TF_Output* outputs, int noutputs,
+                    TF_Output* grad_outputs) {
+    if (grad_inputs_provided) {
+      TF_Output grad_inputs[1];
+      const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0};
+      TF_Operation* grad_inputs_op =
+          FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs");
+      grad_inputs[0] = TF_Output{grad_inputs_op, 0};
+      TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, grad_inputs,
+                      s_, grad_outputs);
+    } else {
+      TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, nullptr, s_,
+                      grad_outputs);
+    }
+  }
+
+  void BuildErrorGraph(TF_Output* inputs, TF_Output* outputs) {
+    const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
+    TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0");
+    TF_Operation* nograd = NoGradientOp(graph_, s_, const0, "NoGrad");
+    inputs[0] = TF_Output{const0, 0};
+    outputs[0] = TF_Output{nograd, 0};
+    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+  }
+
+  void BuildSuccessGraph(TF_Output* inputs, TF_Output* outputs) {
+    // Construct the following graph:
+    //        |
+    //       z|
+    //        |
+    //      MatMul
+    //     /      \
+    //    ^        ^
+    //    |        |
+    //   x|       y|
+    //    |        |
+    //    |        |
+    // Const_0   Const_1
+    //
+    const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
+    const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
+    TF_Operation* const0 = FloatConst2x2(graph_, s_, const0_val, "Const_0");
+    TF_Operation* const1 = FloatConst2x2(graph_, s_, const1_val, "Const_1");
+    TF_Operation* matmul = MatMul(graph_, s_, const0, const1, "MatMul");
+    inputs[0] = TF_Output{const0, 0};
+    inputs[1] = TF_Output{const1, 0};
+    outputs[0] = TF_Output{matmul, 0};
+    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+  }
+
+  void BuildExpectedGraph(bool grad_inputs_provided,
+                          TF_Output* expected_grad_outputs) {
+    // The expected graph looks like this if grad_inputs_provided.
+    // If grad_inputs_provided is false, Const_3 will be a OnesLike op.
+    //       ^               ^
+    //    dy|             dx|          // MatMul Gradient Graph
+    //      |               |
+    //   MatMul_2        MatMul_1
+    //   ^      ^        ^      ^
+    //   |      |--------|      |
+    //   |           ^          |
+    //   |         dz|          |
+    //   |           |          |
+    //   |        Const_3       |
+    //   |                      |
+    //   |           ^          |
+    //   |          z|          |     // MatMul Forward Graph
+    //   |           |          |
+    //   |         MatMul       |
+    //   |        /      \      |
+    //   |       ^        ^     |
+    //   |       |        |     |
+    //   |------x|       y|-----|
+    //           |        |
+    //           |        |
+    //        Const_0   Const_1
+    //
+    const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
+    const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
+    TF_Operation* const0 =
+        FloatConst2x2(expected_graph_, s_, const0_val, "Const_0");
+    TF_Operation* const1 =
+        FloatConst2x2(expected_graph_, s_, const1_val, "Const_1");
+    TF_Operation* matmul =
+        MatMul(expected_graph_, s_, const0, const1, "MatMul");
+
+    TF_Operation* const3;
+    if (grad_inputs_provided) {
+      const float const3_val[] = {1.0, 1.0, 1.0, 1.0};
+      const3 = FloatConst2x2(expected_graph_, s_, const3_val, "GradInputs");
+    } else {
+      const3 = OnesLike(expected_graph_, s_, matmul, "OnesLike");
+    }
+
+    TF_Operation* matmul1 =
+        MatMul(expected_graph_, s_, const3, const1, "MatMul_1", false, true);
+    TF_Operation* matmul2 =
+        MatMul(expected_graph_, s_, const0, const3, "MatMul_2", true, false);
+    expected_grad_outputs[0] = {matmul1, 0};
+    expected_grad_outputs[1] = {matmul2, 0};
+  }
+
+  TF_Tensor* FloatTensor2x2(const float* values) {
+    const int64_t dims[2] = {2, 2};
+    TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, sizeof(float) * 4);
+    memcpy(TF_TensorData(t), values, sizeof(float) * 4);
+    return t;
+  }
+
+  TF_Operation* FloatConst2x2(TF_Graph* graph, TF_Status* s,
+                              const float* values, const char* name) {
+    unique_tensor_ptr tensor(FloatTensor2x2(values), TF_DeleteTensor);
+    TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
+    TF_SetAttrTensor(desc, "value", tensor.get(), s);
+    if (TF_GetCode(s) != TF_OK) return nullptr;
+    TF_SetAttrType(desc, "dtype", TF_FLOAT);
+    TF_Operation* op = TF_FinishOperation(desc, s);
+    EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    return op;
+  }
+
+  TF_Operation* MatMul(TF_Graph* graph, TF_Status* s, TF_Operation* l,
+                       TF_Operation* r, const char* name,
+                       bool transpose_a = false, bool transpose_b = false) {
+    TF_OperationDescription* desc = TF_NewOperation(graph, "MatMul", name);
+    if (transpose_a) {
+      TF_SetAttrBool(desc, "transpose_a", 1);
+    }
+    if (transpose_b) {
+      TF_SetAttrBool(desc, "transpose_b", 1);
+    }
+    TF_AddInput(desc, {l, 0});
+    TF_AddInput(desc, {r, 0});
+    TF_Operation* op = TF_FinishOperation(desc, s);
+    EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    return op;
+  }
+
+  TF_Operation* OnesLike(TF_Graph* graph, TF_Status* s, TF_Operation* in,
+                         const char* name) {
+    TF_OperationDescription* desc = TF_NewOperation(graph, "OnesLike", name);
+    TF_AddInput(desc, {in, 0});
+    TF_Operation* op = TF_FinishOperation(desc, s);
+    EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    return op;
+  }
+
+  TF_Operation* NoGradientOp(TF_Graph* graph, TF_Status* s, TF_Operation* in,
+                             const char* name) {
+    TF_OperationDescription* desc =
+        TF_NewOperation(graph, "TestOpWithNoGradient", name);
+    TF_AddInput(desc, {in, 0});
+    TF_Operation* op = TF_FinishOperation(desc, s);
+    EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+    return op;
+  }
+
+  TF_Status* s_;
+  TF_Graph* graph_;
+  TF_Graph* expected_graph_;
+};
+
+TEST_F(CApiGradientsTest, Gradients_GradInputs) { TestGradientsSuccess(true); }
+
+TEST_F(CApiGradientsTest, Gradients_NoGradInputs) {
+  TestGradientsSuccess(false);
+}
+
+TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_GradInputs) {
+  TestGradientsError(true);
+}
+
+TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
+  TestGradientsError(false);
+}
+
 // Create a tensor with values of type TF_INT8 provided by `values`.
 TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
   int64_t num_values = 1;
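
An aside on why BuildExpectedGraph above has that shape (standard matrix calculus, not introduced by this patch): for z = MatMul(x, y) with incoming gradient dz, the matmul gradients are dx = MatMul(dz, y, transpose_b=true) and dy = MatMul(x, dz, transpose_a=true), which is exactly what MatMul_1 and MatMul_2 encode, with Const_3 playing the role of dz. With the test's constants x = [[1, 2], [3, 4]], y = identity, and dz = all-ones, this gives dx = dz * y^T = [[1, 1], [1, 1]] and dy = x^T * dz = [[4, 4], [6, 6]], the values RunGraphsAndCompareOutputs checks against.
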
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index aaebdded9a5..42fa139282a 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -122,7 +122,10 @@ cc_library_with_android_deps(
 
 cc_library_with_android_deps(
     name = "scope",
-    srcs = ["framework/scope.cc"],
+    srcs = [
+        "framework/scope.cc",
+        "framework/scope_internal.h",
+    ],
     hdrs = ["framework/scope.h"],
     android_deps = ["//tensorflow/core:android_tensorflow_lib"],
     common_deps = [
@@ -136,6 +139,15 @@ cc_library_with_android_deps(
     ],
 )
 
+cc_library_with_android_deps(
+    name = "scope_internal",
+    hdrs = ["framework/scope_internal.h"],
+    common_deps = [
+        ":scope",
+    ],
+    deps = [],
+)
+
 tf_cc_test(
     name = "framework_scope_test",
     srcs = ["framework/scope_test.cc"],
diff --git a/tensorflow/cc/framework/grad_op_registry.cc b/tensorflow/cc/framework/grad_op_registry.cc
index 0d6a377b507..254705736e7 100644
--- a/tensorflow/cc/framework/grad_op_registry.cc
+++ b/tensorflow/cc/framework/grad_op_registry.cc
@@ -32,7 +32,13 @@ bool GradOpRegistry::Register(const string& op, GradFunc func) {
 Status GradOpRegistry::Lookup(const string& op, GradFunc* func) const {
   auto iter = registry_.find(op);
   if (iter == registry_.end()) {
-    return errors::NotFound("No gradient defined for op: ", op);
+    const string error_msg =
+        "No gradient defined for op: " + op +
+        ". Please see "
+        "https://www.tensorflow.org/code/"
+        "tensorflow/cc/gradients/README.md"
+        " for instructions on how to add C++ gradients.";
+    return errors::NotFound(error_msg);
   }
   *func = iter->second;
   return Status::OK();
diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc
index 2c60f947a55..4ada9351caa 100644
--- a/tensorflow/cc/framework/gradients.cc
+++ b/tensorflow/cc/framework/gradients.cc
@@ -367,6 +367,19 @@ Status AddSymbolicGradients(const Scope& scope,
   return builder.AddGradients();
 }
 
+Status AddSymbolicGradients(const Scope& scope,
+                            const std::vector<Output>& outputs,
+                            const std::vector<Output>& inputs,
+                            std::vector<Output>* grad_outputs) {
+  std::vector<Output> grad_inputs;
+  grad_inputs.reserve(outputs.size());
+  for (const Output& output : outputs) {
+    grad_inputs.emplace_back(ops::OnesLike(scope, output));
+  }
+  return AddSymbolicGradients(scope, outputs, inputs, grad_inputs,
+                              grad_outputs);
+}
+
 Output NoGradient() { return SymbolicGradientBuilder::NoGradient(); }
 
 }  // end namespace tensorflow
diff --git a/tensorflow/cc/framework/gradients.h b/tensorflow/cc/framework/gradients.h
index d076bc43b4f..717f6f0636d 100644
--- a/tensorflow/cc/framework/gradients.h
+++ b/tensorflow/cc/framework/gradients.h
@@ -27,16 +27,19 @@ namespace tensorflow {
 /// derivatives of some loss function 'L' w.r.t 'outputs'), adds gradient nodes
 /// to the graph associated with 'scope', which compute (and return in
 /// 'grad_outputs') the symbolic partial derivatives of 'L' w.r.t 'inputs'.
-///
-
-// TODO(andydavis) Add overload of this function with no 'grad_inputs' arg.
-// Implementation will fill in 'OnesLike' for all shapes in 'outputs'.
 Status AddSymbolicGradients(const Scope& scope,
                             const std::vector<Output>& outputs,
                             const std::vector<Output>& inputs,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs);
 
+// Same as above, but uses 'OnesLike' for all shapes in
+// 'outputs' as grad_inputs.
+Status AddSymbolicGradients(const Scope& scope,
+                            const std::vector<Output>& outputs,
+                            const std::vector<Output>& inputs,
+                            std::vector<Output>* grad_outputs);
+
 /// Returns a sentinel Output that represents 'no gradient' (i.e. no gradient
 /// flows along some graph edge during backpropagation).
 /// Can be returned in 'grad_outputs' by an invocation of 'AddSymbolicGradients'
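
For orientation (not part of the patch): a minimal sketch of the new overload from C++. GradientExample is a hypothetical name; it assumes the standard cc/ops headers and the gradient kernels ("//tensorflow/cc:grad_ops") are linked in.

#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"

tensorflow::Status GradientExample() {
  using namespace tensorflow;       // NOLINT
  using namespace tensorflow::ops;  // NOLINT
  Scope scope = Scope::NewRootScope();
  auto x = Const(scope, {{1.0, 2.0}, {3.0, 4.0}});
  auto y = Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
  auto z = MatMul(scope, x, y);
  // No grad_inputs argument: OnesLike(z) is used as the seed gradient.
  std::vector<Output> grad_outputs;
  return AddSymbolicGradients(scope, {z}, {x, y}, &grad_outputs);
}
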
diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc
index 6c2c2fcd1e2..7783bdce3a7 100644
--- a/tensorflow/cc/framework/gradients_test.cc
+++ b/tensorflow/cc/framework/gradients_test.cc
@@ -40,7 +40,7 @@ class GradientsTest : public ::testing::Test {
     TF_ASSERT_OK(scope_test_.ToGraphDef(&gdef_test));
     GraphDef gdef_exp;
     TF_ASSERT_OK(scope_expected_.ToGraphDef(&gdef_exp));
-    TF_EXPECT_GRAPH_EQ(gdef_test, gdef_exp);
+    TF_EXPECT_GRAPH_EQ(gdef_exp, gdef_test);
   }
 
   Scope scope_expected_;
@@ -98,6 +98,32 @@ TEST_F(GradientsTest, OneMatMul) {
   CompareTestAndExpectedGraphs();
 }
 
+TEST_F(GradientsTest, OneMatMul_InferGradInputs) {
+  for (const bool expected : {false, true}) {
+    const Scope& scope = expected ? scope_expected_ : scope_test_;
+    // Construct forward graph.
+    auto x = Const(scope, {{1.0, 2.0}, {3.0, 4.0}});
+    auto y = Const(scope, {{1.0, 0.0}, {0.0, 1.0}});
+    auto z = MatMul(scope, x, y);
+    TF_ASSERT_OK(scope.status());
+    CHECK_NOTNULL(z.node());
+
+    if (expected) {
+      // Construct backward graph.
+      // The gradients function adds a OnesLike to create a dz of ones with
+      // the shape of z.
+      auto dz = OnesLike(scope, z);
+      auto dx = MatMul(scope, dz, y, MatMul::TransposeB(true));
+      auto dy = MatMul(scope, x, dz, MatMul::TransposeA(true));
+    } else {
+      // Call AddSymbolicGradients.
+      std::vector<Output> grad_outputs;
+      TF_ASSERT_OK(AddSymbolicGradients(scope, {z}, {x, y}, &grad_outputs));
+    }
+  }
+  CompareTestAndExpectedGraphs();
+}
+
 TEST_F(GradientsTest, TwoMatMuls_Chained) {
   for (const bool expected : {false, true}) {
     const Scope& scope = expected ? scope_expected_ : scope_test_;
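
Related context (not part of the patch): the "no gradient defined" error path above goes away once a gradient is registered via GradOpRegistry. A hedged sketch of what that looks like, modeled on the existing files under tensorflow/cc/gradients/; the function TestOpGrad is illustrative only, and actually registering a gradient for "TestOpWithNoGradient" would of course defeat the error-path tests.

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

// Forwards the incoming gradient unchanged, as for an identity-like op.
Status TestOpGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(grad_inputs[0]);
  return scope.status();
}
REGISTER_GRADIENT_OP("TestOpWithNoGradient", TestOpGrad);

}  // namespace
}  // namespace ops
}  // namespace tensorflow
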
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index 571c6e1e579..8b7fc1406f0 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -16,7 +16,7 @@ limitations under the License.
 #include <algorithm>
 #include <vector>
 
-#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/framework/scope_internal.h"
 #include "tensorflow/core/common_runtime/shape_refiner.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/graph/graph_constructor.h"
@@ -25,6 +25,20 @@ limitations under the License.
 namespace tensorflow {
 
 class Scope::Impl {
+ public:
+  // A NameMap is used to keep track of suffixes for names used in a scope. A
+  // name that has not been used so far in a scope will get no suffix. Later
+  // uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes
+  // can share the same NameMap. For instance, a new scope created using
+  // WithControlDependencies() would share the same NameMap with the parent.
+  typedef std::unordered_map<string, int> NameMap;
+
+  Impl(const std::shared_ptr<Graph>& graph,
+       const std::shared_ptr<Status>& status,
+       const std::shared_ptr<NameMap>& name_map,
+       const std::shared_ptr<ShapeRefiner>& refiner);
+
  private:
   friend class Scope;
 
@@ -40,14 +54,6 @@ class Scope::Impl {
     enum class Colocate;
   };
 
-  // A NameMap is used to keep track of suffixes for names used in a scope. A
-  // name that has not been used so far in a scope will get no suffix. Later
-  // uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes
-  // can share the same NameMap. For instance, a new scope created using
-  // WithControlDependencies() should would share the same NameMap with the
-  // parent.
-  typedef std::unordered_map<string, int> NameMap;
-
   Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner);
   Impl(const Scope& other, Tags::ScopeName, const string& name,
        bool copy_names);
@@ -116,6 +122,17 @@ Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,
       scope_used_(nullptr),
       colocation_constraints_() {}
 
+Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
+                  const std::shared_ptr<Status>& status,
+                  const std::shared_ptr<NameMap>& name_map,
+                  const std::shared_ptr<ShapeRefiner>& refiner)
+    : graph_(graph),
+      status_(status),
+      name_map_(name_map),
+      refiner_(refiner),
+      scope_used_(nullptr),
+      colocation_constraints_() {}
+
 Scope Scope::NewRootScope() {
   Graph* graph = new Graph(OpRegistry::Global());
   ShapeRefiner* refiner =
@@ -277,7 +294,7 @@ std::shared_ptr<Graph> Scope::graph_as_shared_ptr() const {
   return impl()->graph_;
 }
 
-Status Scope::status() const { return *impl()->status_; };
+Status Scope::status() const { return *impl()->status_; }
 
 const std::vector<Operation>& Scope::control_deps() const {
   return impl()->control_deps_;
@@ -464,4 +481,26 @@ CompositeOpScopes Scope::GetCompositeOpScopes(
   }
 }
 
+class InternalScope {
+ public:
+  // NewScope doesn't take ownership of the inputs.
+  static Scope NewScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
+    Scope::Impl::NameMap* name_map = new Scope::Impl::NameMap;
+    for (const Node* node : graph->nodes()) {
+      (*name_map)[node->name()] = 0;
+    }
+    // We provide no-op deleters for these shared_ptrs (except for name_map)
+    // since the caller owns them and doesn't want the scope to destroy them.
+    return Scope(new Scope::Impl(
+        std::shared_ptr<Graph>(graph, [](Graph*) {}),
+        std::shared_ptr<Status>(status, [](Status*) {}),
+        std::shared_ptr<Scope::Impl::NameMap>(name_map),
+        std::shared_ptr<ShapeRefiner>(refiner, [](ShapeRefiner*) {})));
+  }
+};
+
+Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
+  return InternalScope::NewScope(graph, status, refiner);
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h
index ce70da70963..ec3543772d8 100644
--- a/tensorflow/cc/framework/scope.h
+++ b/tensorflow/cc/framework/scope.h
@@ -204,6 +204,7 @@ class Scope {
   const std::vector<Operation>& control_deps() const;
 
  private:
+  friend class InternalScope;
   class Impl;
   std::unique_ptr<Impl> impl_;
   Impl* impl() { return impl_.get(); }
diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h
new file mode 100644
index 00000000000..f2a911877f0
--- /dev/null
+++ b/tensorflow/cc/framework/scope_internal.h
@@ -0,0 +1,33 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_
+#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_
+
+#include "tensorflow/cc/framework/scope.h"
+
+namespace tensorflow {
+
+class ShapeRefiner;
+
+// NewInternalScope returns a new scope which doesn't take ownership of
+// graph, status, name_map, and refiner.
+// This is intended to enable the C API (which is used by other language
+// bindings) to create a Scope and access C++ functionality (i.e. gradients).
+Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner);
+
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_
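
A side note on the ownership trick in InternalScope::NewScope above (general C++, not TensorFlow-specific): a shared_ptr can borrow a caller-owned object by pairing the raw pointer with a deleter that does nothing. A minimal standalone sketch:

#include <cassert>
#include <memory>

int main() {
  int value = 42;  // Owned by this frame, not by the shared_ptr below.
  {
    // The lambda deleter is a no-op, so destroying `borrowed` leaves `value`
    // untouched -- exactly how NewScope borrows graph, status, and refiner.
    std::shared_ptr<int> borrowed(&value, [](int*) {});
    assert(*borrowed == 42);
  }
  return value == 42 ? 0 : 1;
}
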
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index c0c6cd54ac5..9fcb25c4f74 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -689,6 +689,14 @@ set (pywrap_tensorflow_internal_src
     "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h"
     "${tensorflow_source_dir}/tensorflow/c/tf_status_helper.cc"
     "${tensorflow_source_dir}/tensorflow/c/tf_status_helper.h"
+    "${tensorflow_source_dir}/tensorflow/cc/framework/gradients.h"
+    "${tensorflow_source_dir}/tensorflow/cc/framework/gradients.cc"
+    "${tensorflow_source_dir}/tensorflow/cc/framework/grad_op_registry.h"
+    "${tensorflow_source_dir}/tensorflow/cc/framework/grad_op_registry.cc"
+    "${tensorflow_source_dir}/tensorflow/cc/framework/ops.h"
+    "${tensorflow_source_dir}/tensorflow/cc/framework/ops.cc"
+    "${tensorflow_source_dir}/tensorflow/cc/framework/scope_internal.h"
+    "${tensorflow_source_dir}/tensorflow/cc/framework/scope.cc"
     "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.cc"
 )
 
@@ -707,6 +715,7 @@ if(WIN32)
         $
         $
        $
+        $
         $
         $
         $
@@ -742,6 +751,7 @@ add_library(pywrap_tensorflow_internal SHARED
     $
     $
     $
+    $
     $
     $
     $