fixing style nits in gradient checker and gradients_util
commit 0ef9935e9f
parent 4ffd477305
gradient_checker.cc
@@ -106,10 +106,9 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward,

 // ========================= End Helper Functions==============================

-
 Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
-                         std::vector<AbstractTensorHandle*> inputs,
+                         absl::Span<AbstractTensorHandle*> inputs,
                          int input_index, bool use_function,
                          AbstractTensorHandle** numerical_grad) {
   GradientRegistry registry;
   AbstractTensorHandle* theta =
       inputs[input_index];  // parameter we are grad checking
@@ -123,15 +122,15 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
   memcpy(&theta_data[0], TF_TensorData(theta_tensor),
          TF_TensorByteSize(theta_tensor));

-  // Initialize space for the numerical gradient
+  // Initialize space for the numerical gradient.
   float dtheta_approx[num_elems];

-  // Get theta shape and store in theta_dims
+  // Get theta shape and store in theta_dims.
   int num_dims = TF_NumDims(theta_tensor);
   int64_t theta_dims[num_dims];
   GetDims(theta_tensor, theta_dims);

-  // Initialize auxilary data structures
+  // Initialize auxilary data structures.
   float thetaPlus_data[num_elems];
   float thetaMinus_data[num_elems];
   std::vector<AbstractTensorHandle*> f_outputs(1);
@@ -160,13 +159,13 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,

     // Get f(theta + eps):
     inputs[input_index] = thetaPlus.get();
-    TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, absl::MakeSpan(inputs),
+    TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs,
                                       absl::MakeSpan(f_outputs), use_function));
     AbstractTensorHandle* fPlus = f_outputs[0];

     // Get f(theta - eps):
     inputs[input_index] = thetaMinus.get();
-    TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, absl::MakeSpan(inputs),
+    TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs,
                                       absl::MakeSpan(f_outputs), use_function));
     AbstractTensorHandle* fMinus = f_outputs[0];
@@ -191,7 +190,7 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
     dtheta_approx[i] = grad_data[0];
   }

-  // Populate *numerical_grad with the data from dtheta_approx
+  // Populate *numerical_grad with the data from dtheta_approx.
  TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat(ctx, dtheta_approx, theta_dims,
                                               num_dims, numerical_grad));
  return Status::OK();
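For reference, the loop these hunks edit implements the standard central-difference rule: each entry of the numerical gradient is (f(theta + eps) - f(theta - eps)) / (2 * eps). A minimal standalone C++ sketch of that rule, using a hypothetical scalar f rather than the TensorFlow API above:

#include <cstdio>

// f(theta) = theta^2, so the exact derivative is 2 * theta.
static double f(double theta) { return theta * theta; }

int main() {
  const double theta = 3.0;
  const double eps = 1e-4;
  // Central difference: perturb theta in both directions, divide by 2 * eps.
  const double dtheta_approx = (f(theta + eps) - f(theta - eps)) / (2.0 * eps);
  std::printf("approx = %f, exact = %f\n", dtheta_approx, 2.0 * theta);
  return 0;
}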
gradient_checker.h
@@ -33,10 +33,6 @@ limitations under the License.
 namespace tensorflow {
 namespace gradients {

-using Model = std::function<Status(
-    AbstractContext*, absl::Span<AbstractTensorHandle* const>,
-    absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
-
 /* Returns numerical grad inside `dtheta_approx` given `forward` model and
  * parameter specified by `input_index`.
  *
@@ -49,7 +45,7 @@ using Model = std::function<Status(
  * hold the numerical gradient data at the end of the function.
  */
 Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
-                         std::vector<AbstractTensorHandle*> inputs,
+                         absl::Span<AbstractTensorHandle*> inputs,
                          int input_index, bool use_function,
                          AbstractTensorHandle** numerical_grad);

gradient_checker_test.cc
@@ -81,7 +81,7 @@ TEST_P(GradientCheckerTest, TestGradCheckMatMul) {

   AbstractTensorHandle* grad_approx;
   Status s = CalcNumericalGrad(
-      ctx.get(), MatMulModel, inputs, /*input_index=*/0,
+      ctx.get(), MatMulModel, absl::MakeSpan(inputs), /*input_index=*/0,
       /*use_function=*/!std::get<2>(GetParam()), &grad_approx);
   ASSERT_EQ(errors::OK, s.code()) << s.error_message();

@@ -136,7 +136,8 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) {
   float dapprox[1] = {0};
   AbstractTensorHandle* g;

-  Status s = CalcNumericalGrad(ctx.get(), MulModel, inputs, /*input_index=*/0,
+  Status s = CalcNumericalGrad(ctx.get(), MulModel, absl::MakeSpan(inputs),
+                               /*input_index=*/0,
                                /*use_function=*/!std::get<2>(GetParam()), &g);
   ASSERT_EQ(errors::OK, s.code()) << s.error_message();

@@ -213,7 +214,8 @@ TEST_P(GradientCheckerTest, TestGradCheckSoftmax) {

   // Run numerical gradient approximation using the GradientChecker API.
   AbstractTensorHandle* g;  // Will contain numerical approximation data.
-  s = CalcNumericalGrad(ctx.get(), SoftmaxModel, inputs, /*input_index=*/0,
+  s = CalcNumericalGrad(ctx.get(), SoftmaxModel, absl::MakeSpan(inputs),
+                        /*input_index=*/0,
                         /*use_function=*/!std::get<2>(GetParam()), &g);
   ASSERT_EQ(errors::OK, s.code()) << s.error_message();

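All three test updates follow the same call-site pattern: the test keeps owning storage in a std::vector and passes CalcNumericalGrad a non-owning, mutable view via absl::MakeSpan. A minimal sketch of that pattern with a hypothetical DoubleAt callee (requires Abseil):

#include <vector>
#include "absl/types/span.h"

// Hypothetical stand-in: like CalcNumericalGrad, it takes a mutable view so
// it can overwrite one element in place without copying the container.
void DoubleAt(absl::Span<int> xs, int index) { xs[index] *= 2; }

int main() {
  std::vector<int> inputs = {1, 2, 3};
  DoubleAt(absl::MakeSpan(inputs), /*index=*/1);  // view over the vector, no copy
  return inputs[1] == 4 ? 0 : 1;  // the mutation is visible to the caller
}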
gradients_util.h
@@ -35,15 +35,6 @@ limitations under the License.
 namespace tensorflow {
 namespace gradients {

-TFE_TensorHandle* ScalarTensorHandleHelper(TFE_Context* ctx, float value);
-
-TFE_TensorHandle* TensorHandleWithDimsFloatHelper(TFE_Context* ctx,
-                                                  float data[], int64_t dims[],
-                                                  int num_dims);
-
-TFE_TensorHandle* TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
-                                                int64_t dims[], int num_dims);
-
 // Get a scalar TensorHandle with given value
 Status ScalarTensorHandle(AbstractContext* ctx, float value,
                           AbstractTensorHandle** tensor);
@@ -79,14 +70,6 @@ Status UpdateWeights(AbstractContext* ctx,
                      std::vector<AbstractTensorHandle*>& weights,
                      AbstractTensorHandle* learning_rate);

-// Helper function for RunModel to build the function for graph mode.
-AbstractContext* BuildFunction(const char* fn_name);
-
-// Helper function for RunModel to add params for graph mode.
-Status CreateParamsForInputs(AbstractContext* ctx,
-                             absl::Span<AbstractTensorHandle* const> inputs,
-                             std::vector<AbstractTensorHandle*>* params);
-
 using Model = std::function<Status(
     AbstractContext*, absl::Span<AbstractTensorHandle* const>,
     absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
@@ -98,6 +81,7 @@ Status RunModel(Model model, AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle*> outputs, bool use_function,
                 const GradientRegistry& registry);

+// Builds context and returns inside *ctx.
 Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);

 } // namespace gradients
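The Model alias kept in gradients_util.h is a std::function over one fixed signature, which is what lets the checker and the tests pass arbitrary forward models around. A standalone sketch of the same pattern with simplified, hypothetical types (int standing in for tensorflow::Status, 0 meaning OK):

#include <cstddef>
#include <functional>
#include <vector>

using Status = int;  // stand-in for tensorflow::Status; 0 == OK
using Model =
    std::function<Status(const std::vector<float>&, std::vector<float>*)>;

// Any callable matching the alias signature can be passed around by value.
Status RunModel(const Model& model, const std::vector<float>& inputs,
                std::vector<float>* outputs) {
  return model(inputs, outputs);
}

int main() {
  Model square = [](const std::vector<float>& in, std::vector<float>* out) {
    out->resize(in.size());
    for (std::size_t i = 0; i < in.size(); ++i) (*out)[i] = in[i] * in[i];
    return 0;
  };
  std::vector<float> out;
  Status s = RunModel(square, {2.0f}, &out);
  return (s == 0 && out[0] == 4.0f) ? 0 : 1;
}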