C++ Gradients: Port a couple of gradient functions to the new C++ graph building API.
Change: 129016020
This commit is contained in:
parent
01336ac329
commit
84391613c3
tensorflow/cc
@ -73,6 +73,46 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grad_op_registry",
|
||||
srcs = ["framework/grad_op_registry.cc"],
|
||||
hdrs = ["framework/grad_op_registry.h"],
|
||||
deps = [
|
||||
":ops",
|
||||
":scope",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "math_grad",
|
||||
srcs = ["gradients/math_grad.cc"],
|
||||
deps = [
|
||||
":cc_ops",
|
||||
":grad_op_registry",
|
||||
":ops",
|
||||
":scope",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "gradients/math_grad_test",
|
||||
deps = [
|
||||
":cc_ops",
|
||||
":grad_op_registry",
|
||||
":math_grad",
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrappers_cc(
|
||||
name = "cc_ops",
|
||||
op_lib_names = [
|
||||
|
42
tensorflow/cc/framework/grad_op_registry.cc
Normal file
42
tensorflow/cc/framework/grad_op_registry.cc
Normal file
@ -0,0 +1,42 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
|
||||
// static
|
||||
GradOpRegistry* GradOpRegistry::Global() {
|
||||
static GradOpRegistry* grad_op_registry = new GradOpRegistry;
|
||||
return grad_op_registry;
|
||||
}
|
||||
|
||||
bool GradOpRegistry::Register(const string& op, GradFunc func) {
|
||||
CHECK(registry_.insert({op, func}).second) << "Existing gradient for " << op;
|
||||
return true;
|
||||
}
|
||||
|
||||
Status GradOpRegistry::Lookup(const string& op, GradFunc* func) {
|
||||
auto iter = registry_.find(op);
|
||||
if (iter == registry_.end()) {
|
||||
return errors::NotFound("No gradient defined for op: ", op);
|
||||
}
|
||||
*func = iter->second;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // end namespace ops
|
||||
} // namespace tensorflow
|
75
tensorflow/cc/framework/grad_op_registry.h
Normal file
75
tensorflow/cc/framework/grad_op_registry.h
Normal file
@ -0,0 +1,75 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
|
||||
// GradFunc is the signature for all gradient functions in GradOpRegistry.
|
||||
// Implementations should add operations to compute the gradient outputs of 'op'
|
||||
// (returned in 'grad_outputs') using 'scope' and 'grad_inputs'.
|
||||
typedef Status (*GradFunc)(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs);
|
||||
|
||||
// GradOpRegistry maintains a static registry of gradient functions.
|
||||
// Gradient functions are indexed in the registry by the forward op name (i.e.
|
||||
// "MatMul" -> MatMulGrad func).
|
||||
class GradOpRegistry {
|
||||
public:
|
||||
// Registers 'func' as the the gradient function for 'op'.
|
||||
// Returns true if registration was succesful, check fails otherwise.
|
||||
bool Register(const string& op, GradFunc func);
|
||||
|
||||
// Sets 'func' to the gradient function for 'op' and returns Status OK if
|
||||
// the gradient function for 'op' exists in the registry.
|
||||
// Note that 'func' can be null for ops that have registered no-gradient with
|
||||
// the registry.
|
||||
// Returns error status otherwise.
|
||||
Status Lookup(const string& op, GradFunc* func);
|
||||
|
||||
// Returns a pointer to the global gradient function registry.
|
||||
static GradOpRegistry* Global();
|
||||
|
||||
private:
|
||||
std::unordered_map<string, GradFunc> registry_;
|
||||
};
|
||||
|
||||
} // namespace ops
|
||||
|
||||
// Macros used to define gradient functions for ops.
|
||||
#define REGISTER_GRADIENT_OP(name, fn) \
|
||||
REGISTER_GRADIENT_OP_UNIQ_HELPER(__COUNTER__, name, fn)
|
||||
|
||||
#define REGISTER_NO_GRADIENT_OP(name) \
|
||||
REGISTER_GRADIENT_OP_UNIQ_HELPER(__COUNTER__, name, nullptr)
|
||||
|
||||
#define REGISTER_GRADIENT_OP_UNIQ_HELPER(ctr, name, fn) \
|
||||
REGISTER_GRADIENT_OP_UNIQ(ctr, name, fn)
|
||||
|
||||
#define REGISTER_GRADIENT_OP_UNIQ(ctr, name, fn) \
|
||||
static bool unused_ret_val_##ctr = \
|
||||
::tensorflow::ops::GradOpRegistry::Global()->Register(name, fn)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_
|
@ -18,6 +18,44 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
|
||||
Operation::Operation(Node* n) : inputs_(GetInputs(n)), node_(n) {}
|
||||
|
||||
Output Operation::input(int i) const {
|
||||
CHECK_NOTNULL(node_);
|
||||
CHECK_GE(i, 0);
|
||||
CHECK_LT(i, node_->num_inputs());
|
||||
// Handle the case where the input was unknown at the time this
|
||||
// Operation was constructed.
|
||||
if (inputs_[i].first == nullptr && inputs_[i].second == -1) {
|
||||
for (const Edge* e : node_->in_edges()) {
|
||||
if (e->IsControlEdge()) continue;
|
||||
if (e->dst_input() == i) {
|
||||
return Output(e->src(), e->src_output());
|
||||
}
|
||||
}
|
||||
}
|
||||
return Output(inputs_[i].first, inputs_[i].second);
|
||||
}
|
||||
|
||||
Output Operation::output(int i) const {
|
||||
CHECK_NOTNULL(node_);
|
||||
CHECK_GE(i, 0);
|
||||
CHECK_LT(i, node_->num_outputs());
|
||||
return Output(node_, i);
|
||||
}
|
||||
|
||||
Operation::Inputs Operation::GetInputs(Node* node) {
|
||||
Operation::Inputs inputs;
|
||||
if (node != nullptr) {
|
||||
inputs.resize(node->num_inputs(), {nullptr, -1});
|
||||
for (const Edge* e : node->in_edges()) {
|
||||
if (e->IsControlEdge()) continue;
|
||||
inputs[e->dst_input()] = std::make_pair(e->src(), e->src_output());
|
||||
}
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
Input::Initializer::Initializer(
|
||||
const std::initializer_list<Input::Initializer>& v) {
|
||||
if (v.size() < 1) {
|
||||
|
@ -27,17 +27,29 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
|
||||
class Output;
|
||||
|
||||
// Represents a node in the computation graph.
|
||||
class Operation {
|
||||
public:
|
||||
Operation() : node_(nullptr) {}
|
||||
explicit Operation(Node* n) : node_(n) {}
|
||||
explicit Operation(Node* n);
|
||||
|
||||
int num_inputs() const { return node_->num_inputs(); }
|
||||
DataType input_type(int o) const { return node_->input_type(o); }
|
||||
Output input(int i) const;
|
||||
|
||||
int num_outputs() const { return node_->num_outputs(); }
|
||||
DataType output_type(int o) const { return node_->output_type(o); }
|
||||
Output output(int i) const;
|
||||
|
||||
Node* node() const { return node_; }
|
||||
|
||||
private:
|
||||
typedef std::vector<std::pair<Node*, int64>> Inputs;
|
||||
static Inputs GetInputs(Node* node);
|
||||
|
||||
Inputs inputs_;
|
||||
Node* node_;
|
||||
};
|
||||
|
||||
@ -81,7 +93,7 @@ class Input {
|
||||
tensor = t;
|
||||
}
|
||||
|
||||
explicit Initializer(const Tensor& t) : tensor(t) {}
|
||||
Initializer(const Tensor& t) : tensor(t) {} // NOLINT(runtime/explicit)
|
||||
|
||||
// Construct from a scalar value and an explicit shape
|
||||
template <typename T, typename = typename std::enable_if<
|
||||
|
91
tensorflow/cc/gradients/math_grad.cc
Normal file
91
tensorflow/cc/gradients/math_grad.cc
Normal file
@ -0,0 +1,91 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
|
||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace ops {
|
||||
namespace {
|
||||
|
||||
// TODO(andydavis) Move this to a more appropriate file.
|
||||
REGISTER_NO_GRADIENT_OP("Const");
|
||||
|
||||
// MatMulGrad helper function used to compute two MatMul operations
|
||||
// based on input matrix transposition combinations.
|
||||
Status MatMulGradHelper(const Scope& scope, const Output& x0, const bool adj_x0,
|
||||
const Output& x1, const bool adj_x1, const Output& y0,
|
||||
const bool adj_y0, const Output& y1, const bool adj_y1,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
auto dx =
|
||||
MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
|
||||
grad_outputs->push_back(dx);
|
||||
auto dy =
|
||||
MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
|
||||
grad_outputs->push_back(dy);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// MatMulGrad common used to read and check node attr state, and determine
|
||||
// proper MatMul products for gradients based on input matrix transposition
|
||||
// combinations.
|
||||
// TODO(andydavis) Re-use this function for BatchMatMulGrad.
|
||||
Status MatMulGradCommon(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
const string& attr_adj_x, const string& attr_adj_y,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
DataType dtype;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), "T", &dtype));
|
||||
if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
|
||||
return errors::Unimplemented(
|
||||
"MatMul gradient for complex data type is not supported yet.");
|
||||
}
|
||||
|
||||
bool ta;
|
||||
bool tb;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_x, &ta));
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(op.output(0).node()->def(), attr_adj_y, &tb));
|
||||
|
||||
if (!ta && !tb) {
|
||||
return MatMulGradHelper(scope, grad_inputs[0], false, op.input(1), true,
|
||||
op.input(0), true, grad_inputs[0], false,
|
||||
grad_outputs);
|
||||
} else if (!ta && tb) {
|
||||
return MatMulGradHelper(scope, grad_inputs[0], false, op.input(1), false,
|
||||
grad_inputs[0], true, op.input(0), false,
|
||||
grad_outputs);
|
||||
} else if (ta && !tb) {
|
||||
return MatMulGradHelper(scope, op.input(1), false, grad_inputs[0], true,
|
||||
op.input(0), false, grad_inputs[0], false,
|
||||
grad_outputs);
|
||||
}
|
||||
return MatMulGradHelper(scope, op.input(1), true, grad_inputs[0], true,
|
||||
grad_inputs[0], true, op.input(0), true,
|
||||
grad_outputs);
|
||||
}
|
||||
|
||||
Status MatMulGrad(const Scope& scope, const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
return MatMulGradCommon(scope, op, grad_inputs, "transpose_a", "transpose_b",
|
||||
grad_outputs);
|
||||
}
|
||||
|
||||
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
183
tensorflow/cc/gradients/math_grad_test.cc
Normal file
183
tensorflow/cc/gradients/math_grad_test.cc
Normal file
@ -0,0 +1,183 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/framework/grad_op_registry.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/graph/default_device.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
using namespace ops; // NOLINT(build/namespaces)
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO(andydavis) Test gradient function against numeric gradients output.
|
||||
// TODO(andydavis) As more gradients are added move common test functions
|
||||
// to a testutil library.
|
||||
class MathGradTest : public ::testing::Test {
|
||||
protected:
|
||||
MathGradTest() : root_(Scope::NewRootScope()) {}
|
||||
|
||||
void ComputeMatMulGrad(const Output& x, const bool t_x, const Output& y,
|
||||
const bool t_y, const Output& dz,
|
||||
std::vector<Tensor>* out) {
|
||||
// Compute forward MatMul: z = MatMul(x, y).
|
||||
auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
|
||||
TF_EXPECT_OK(root_.status());
|
||||
CHECK_NOTNULL(z.node());
|
||||
std::vector<Output> grad_outputs;
|
||||
// Call MatMulGrad which populates 'grad_outputs'.
|
||||
CallGradFunction(Operation(z.node()), {dz}, &grad_outputs);
|
||||
EXPECT_EQ(2, grad_outputs.size());
|
||||
// Run graph and return MatMul gradient tensors for 'dx' and 'dy' in 'out'.
|
||||
GetTensors(root_, {grad_outputs[0], grad_outputs[1]}, out);
|
||||
}
|
||||
|
||||
void CallGradFunction(const Operation& op,
|
||||
const std::vector<Output>& grad_inputs,
|
||||
std::vector<Output>* grad_outputs) {
|
||||
GradFunc grad_fn;
|
||||
TF_EXPECT_OK(GradOpRegistry::Global()->Lookup(op.node()->name(), &grad_fn));
|
||||
TF_EXPECT_OK(grad_fn(root_, op, grad_inputs, grad_outputs));
|
||||
TF_EXPECT_OK(root_.status());
|
||||
}
|
||||
|
||||
Tensor ComputeMatMul(const Output& x, const bool t_x, const Output& y,
|
||||
const bool t_y) {
|
||||
auto z = MatMul(root_, x, y, MatMul::TransposeA(t_x).TransposeB(t_y));
|
||||
TF_EXPECT_OK(root_.status());
|
||||
Tensor out;
|
||||
GetTensor(root_, z, &out);
|
||||
return out;
|
||||
}
|
||||
|
||||
void RandMatMulGradData(const bool tx, const bool ty,
|
||||
std::vector<Tensor>* data) {
|
||||
// z = MatMul(x, y)
|
||||
const int m = Rand();
|
||||
const int k = Rand();
|
||||
const int n = Rand();
|
||||
// x.shape = [m, k]
|
||||
const TensorShape x_shape = tx ? TensorShape({k, m}) : TensorShape({m, k});
|
||||
data->emplace_back(DT_FLOAT, x_shape);
|
||||
RandTensor(&data->back());
|
||||
// y.shape = [k, n]
|
||||
const TensorShape y_shape = ty ? TensorShape({n, k}) : TensorShape({k, n});
|
||||
data->emplace_back(DT_FLOAT, y_shape);
|
||||
RandTensor(&data->back());
|
||||
// z.shape = [m, n]
|
||||
data->emplace_back(DT_FLOAT, TensorShape({m, n}));
|
||||
RandTensor(&data->back());
|
||||
}
|
||||
|
||||
void RandTensor(Tensor* t) {
|
||||
test::FillFn<float>(
|
||||
t, [this](const int i) { return static_cast<float>(Rand()); });
|
||||
}
|
||||
|
||||
int Rand() { return 1 + (random::New64() % 10); }
|
||||
|
||||
// TODO(andydavis) Move 'GetTensors/GetTensor' to some testutil class.
|
||||
// Note: they should be moved to a general/non-grad specific testutil class.
|
||||
void GetTensors(const Scope& scope, OutputList tensors,
|
||||
std::vector<Tensor>* out) {
|
||||
SessionOptions options;
|
||||
std::unique_ptr<Session> session(NewSession(options));
|
||||
GraphDef def;
|
||||
scope.graph()->ToGraphDef(&def);
|
||||
|
||||
graph::SetDefaultDevice("/cpu:0", &def);
|
||||
|
||||
TF_CHECK_OK(session->Create(def));
|
||||
std::vector<string> names;
|
||||
for (const auto& t : tensors) {
|
||||
names.push_back(strings::StrCat(t.node()->name(), ":", t.index()));
|
||||
}
|
||||
TF_CHECK_OK(session->Run({}, names, {}, out));
|
||||
TF_CHECK_OK(session->Close());
|
||||
}
|
||||
|
||||
void GetTensor(const Scope& scope, Output tensor, Tensor* out) {
|
||||
std::vector<Tensor> outputs;
|
||||
GetTensors(scope, {tensor}, &outputs);
|
||||
*out = outputs[0];
|
||||
}
|
||||
|
||||
Scope root_;
|
||||
};
|
||||
|
||||
TEST_F(MathGradTest, MatMulGrad_NoTranspose) {
|
||||
std::vector<Tensor> data;
|
||||
RandMatMulGradData(false, false, &data);
|
||||
auto x = Const(root_, data[0]);
|
||||
auto y = Const(root_, data[1]);
|
||||
auto dz = Const(root_, data[2]);
|
||||
|
||||
std::vector<Tensor> grad_outputs;
|
||||
ComputeMatMulGrad(x, false, y, false, dz, &grad_outputs);
|
||||
|
||||
test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, true));
|
||||
test::ExpectClose(grad_outputs[1], ComputeMatMul(x, true, dz, false));
|
||||
}
|
||||
|
||||
TEST_F(MathGradTest, MatMulGrad_TransposeX) {
|
||||
std::vector<Tensor> data;
|
||||
RandMatMulGradData(true, false, &data);
|
||||
auto x = Const(root_, data[0]);
|
||||
auto y = Const(root_, data[1]);
|
||||
auto dz = Const(root_, data[2]);
|
||||
|
||||
std::vector<Tensor> grad_outputs;
|
||||
ComputeMatMulGrad(x, true, y, false, dz, &grad_outputs);
|
||||
|
||||
test::ExpectClose(grad_outputs[0], ComputeMatMul(y, false, dz, true));
|
||||
test::ExpectClose(grad_outputs[1], ComputeMatMul(x, false, dz, false));
|
||||
}
|
||||
|
||||
TEST_F(MathGradTest, MatMulGrad_TransposeY) {
|
||||
std::vector<Tensor> data;
|
||||
RandMatMulGradData(false, true, &data);
|
||||
auto x = Const(root_, data[0]);
|
||||
auto y = Const(root_, data[1]);
|
||||
auto dz = Const(root_, data[2]);
|
||||
|
||||
std::vector<Tensor> grad_outputs;
|
||||
ComputeMatMulGrad(x, false, y, true, dz, &grad_outputs);
|
||||
|
||||
test::ExpectClose(grad_outputs[0], ComputeMatMul(dz, false, y, false));
|
||||
test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, false));
|
||||
}
|
||||
|
||||
TEST_F(MathGradTest, MatMulGrad_TransposeX_TransposeY) {
|
||||
std::vector<Tensor> data;
|
||||
RandMatMulGradData(true, true, &data);
|
||||
auto x = Const(root_, data[0]);
|
||||
auto y = Const(root_, data[1]);
|
||||
auto dz = Const(root_, data[2]);
|
||||
|
||||
std::vector<Tensor> grad_outputs;
|
||||
ComputeMatMulGrad(x, true, y, true, dz, &grad_outputs);
|
||||
|
||||
test::ExpectClose(grad_outputs[0], ComputeMatMul(y, true, dz, true));
|
||||
test::ExpectClose(grad_outputs[1], ComputeMatMul(dz, true, x, true));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user