fixing style nits in gradient checker and gradients_util

amturati 2020-09-04 18:09:36 +00:00
parent 4ffd477305
commit 0ef9935e9f
4 changed files with 14 additions and 33 deletions

View File

@@ -106,10 +106,9 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward,
 // ========================= End Helper Functions==============================
 
 Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
-                         std::vector<AbstractTensorHandle*> inputs,
+                         absl::Span<AbstractTensorHandle*> inputs,
                          int input_index, bool use_function,
                          AbstractTensorHandle** numerical_grad) {
-  GradientRegistry registry;
   AbstractTensorHandle* theta =
       inputs[input_index];  // parameter we are grad checking
@@ -123,15 +122,15 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
   memcpy(&theta_data[0], TF_TensorData(theta_tensor),
          TF_TensorByteSize(theta_tensor));
 
-  // Initialize space for the numerical gradient
+  // Initialize space for the numerical gradient.
   float dtheta_approx[num_elems];
 
-  // Get theta shape and store in theta_dims
+  // Get theta shape and store in theta_dims.
   int num_dims = TF_NumDims(theta_tensor);
   int64_t theta_dims[num_dims];
   GetDims(theta_tensor, theta_dims);
 
-  // Initialize auxiliary data structures
+  // Initialize auxiliary data structures.
   float thetaPlus_data[num_elems];
   float thetaMinus_data[num_elems];
   std::vector<AbstractTensorHandle*> f_outputs(1);
@@ -160,13 +159,13 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
     // Get f(theta + eps):
     inputs[input_index] = thetaPlus.get();
-    TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, absl::MakeSpan(inputs),
+    TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs,
                                       absl::MakeSpan(f_outputs), use_function));
     AbstractTensorHandle* fPlus = f_outputs[0];
 
     // Get f(theta - eps):
     inputs[input_index] = thetaMinus.get();
-    TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, absl::MakeSpan(inputs),
+    TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs,
                                       absl::MakeSpan(f_outputs), use_function));
     AbstractTensorHandle* fMinus = f_outputs[0];
@@ -191,7 +190,7 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
     dtheta_approx[i] = grad_data[0];
   }
 
-  // Populate *numerical_grad with the data from dtheta_approx
+  // Populate *numerical_grad with the data from dtheta_approx.
   TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat(ctx, dtheta_approx, theta_dims,
                                                num_dims, numerical_grad));
   return Status::OK();
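
For reference, the f(theta + eps) / f(theta - eps) calls above implement a central-difference approximation of the gradient. A minimal self-contained sketch of the same scheme in plain C++ (no TensorFlow types; the callable f stands in for the Model callback):

#include <cstddef>
#include <vector>

// Central differences: dtheta_approx[i] ~= (f(theta+eps) - f(theta-eps)) / (2*eps),
// the same scheme the loop above implements via thetaPlus/thetaMinus.
template <typename F>
std::vector<float> NumericalGrad(F f, std::vector<float> theta,
                                 float eps = 1e-2f) {
  std::vector<float> dtheta_approx(theta.size());
  for (std::size_t i = 0; i < theta.size(); ++i) {
    const float saved = theta[i];
    theta[i] = saved + eps;
    const float f_plus = f(theta);   // f(theta + eps)
    theta[i] = saved - eps;
    const float f_minus = f(theta);  // f(theta - eps)
    theta[i] = saved;                // restore before the next element
    dtheta_approx[i] = (f_plus - f_minus) / (2.0f * eps);
  }
  return dtheta_approx;
}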

View File

@@ -33,10 +33,6 @@ limitations under the License.
 namespace tensorflow {
 namespace gradients {
 
-using Model = std::function<Status(
-    AbstractContext*, absl::Span<AbstractTensorHandle* const>,
-    absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
-
 /* Returns numerical grad inside `dtheta_approx` given `forward` model and
  * parameter specified by `input_index`.
  *
@@ -49,7 +45,7 @@ using Model = std::function<Status(
  * hold the numerical gradient data at the end of the function.
  */
Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
-                         std::vector<AbstractTensorHandle*> inputs,
+                         absl::Span<AbstractTensorHandle*> inputs,
                          int input_index, bool use_function,
                          AbstractTensorHandle** numerical_grad);
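
The move from std::vector<AbstractTensorHandle*> to absl::Span<AbstractTensorHandle*> means CalcNumericalGrad now borrows the caller's storage rather than receiving a copy, which is why the body can swap inputs[input_index] in place and why the call sites below wrap their vectors with absl::MakeSpan. A small sketch of that aliasing behavior, assuming only absl (names are illustrative, not TF API):

#include <vector>
#include "absl/types/span.h"

// A mutable Span views the caller's buffer: writing through it replaces
// the caller's element, and no copy of the vector is made at the call.
void ReplaceFirst(absl::Span<int*> inputs, int* replacement) {
  inputs[0] = replacement;
}

int main() {
  int a = 1, b = 2;
  std::vector<int*> inputs = {&a, &a};
  ReplaceFirst(absl::MakeSpan(inputs), &b);
  // inputs[0] now points at b; inputs itself was never copied.
  return 0;
}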

View File

@@ -81,7 +81,7 @@ TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
 
   AbstractTensorHandle* grad_approx;
   Status s = CalcNumericalGrad(
-      ctx.get(), MatMulModel, inputs, /*input_index=*/0,
+      ctx.get(), MatMulModel, absl::MakeSpan(inputs), /*input_index=*/0,
       /*use_function=*/!std::get<2>(GetParam()), &grad_approx);
   ASSERT_EQ(errors::OK, s.code()) << s.error_message();
@@ -136,7 +136,8 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) {
 
   float dapprox[1] = {0};
   AbstractTensorHandle* g;
-  Status s = CalcNumericalGrad(ctx.get(), MulModel, inputs, /*input_index=*/0,
+  Status s = CalcNumericalGrad(ctx.get(), MulModel, absl::MakeSpan(inputs),
+                               /*input_index=*/0,
                                /*use_function=*/!std::get<2>(GetParam()), &g);
   ASSERT_EQ(errors::OK, s.code()) << s.error_message();
@@ -213,7 +214,8 @@ TEST_P(GradientCheckerTest, TestGradCheckSoftmax) {
 
   // Run numerical gradient approximation using the GradientChecker API.
   AbstractTensorHandle* g;  // Will contain numerical approximation data.
-  s = CalcNumericalGrad(ctx.get(), SoftmaxModel, inputs, /*input_index=*/0,
+  s = CalcNumericalGrad(ctx.get(), SoftmaxModel, absl::MakeSpan(inputs),
+                        /*input_index=*/0,
                         /*use_function=*/!std::get<2>(GetParam()), &g);
   ASSERT_EQ(errors::OK, s.code()) << s.error_message();
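
These TestGradCheck* cases pair the numerical value against a known analytical gradient and assert closeness within a tolerance. A self-contained illustration of that check for f(x) = x^2, whose exact derivative is 2x (plain C++, no gtest):

#include <cassert>
#include <cmath>

int main() {
  const float x = 3.0f, eps = 1e-2f;
  auto f = [](float v) { return v * v; };
  // Central-difference approximation vs. the analytical gradient 2x.
  const float approx = (f(x + eps) - f(x - eps)) / (2.0f * eps);
  const float exact = 2.0f * x;
  assert(std::fabs(approx - exact) < 1e-3f);
  return 0;
}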

View File

@@ -35,15 +35,6 @@ limitations under the License.
 namespace tensorflow {
 namespace gradients {
 
-TFE_TensorHandle* ScalarTensorHandleHelper(TFE_Context* ctx, float value);
-
-TFE_TensorHandle* TensorHandleWithDimsFloatHelper(TFE_Context* ctx,
-                                                  float data[], int64_t dims[],
-                                                  int num_dims);
-
-TFE_TensorHandle* TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
-                                                int64_t dims[], int num_dims);
-
 // Get a scalar TensorHandle with given value
 Status ScalarTensorHandle(AbstractContext* ctx, float value,
                           AbstractTensorHandle** tensor);
@@ -79,14 +70,6 @@ Status UpdateWeights(AbstractContext* ctx,
                      std::vector<AbstractTensorHandle*>& weights,
                      AbstractTensorHandle* learning_rate);
 
-// Helper function for RunModel to build the function for graph mode.
-AbstractContext* BuildFunction(const char* fn_name);
-
-// Helper function for RunModel to add params for graph mode.
-Status CreateParamsForInputs(AbstractContext* ctx,
-                             absl::Span<AbstractTensorHandle* const> inputs,
-                             std::vector<AbstractTensorHandle*>* params);
-
 using Model = std::function<Status(
     AbstractContext*, absl::Span<AbstractTensorHandle* const>,
     absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
@@ -98,6 +81,7 @@ Status RunModel(Model model, AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle*> outputs, bool use_function,
                 const GradientRegistry& registry);
 
+// Builds context and returns inside *ctx.
 Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
 
 }  // namespace gradients
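
For orientation, the Model alias retained above is the callback shape every forward model in these tests satisfies. A hypothetical pass-through model with that shape, sketched against the surrounding TF eager C API headers (ownership/refcounting elided; real models such as MulModel run an op instead):

// Matches: std::function<Status(AbstractContext*,
//                               absl::Span<AbstractTensorHandle* const>,
//                               absl::Span<AbstractTensorHandle*>,
//                               const GradientRegistry&)>
// Hypothetical model: copies the first input handle to the first output
// slot and reports success.
Status IdentityModel(AbstractContext* ctx,
                     absl::Span<AbstractTensorHandle* const> inputs,
                     absl::Span<AbstractTensorHandle*> outputs,
                     const GradientRegistry& registry) {
  outputs[0] = inputs[0];
  return Status::OK();
}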