Function rustml::opt::opt

pub fn opt<O, D>(f: &O, fd: &D, init: &[f64], opts: OptParams<f64>) -> OptResult<f64> where O: Fn(&[f64]) -> f64, D: Fn(&[f64]) -> Vec<f64>

Minimizes an objective using gradient descent.

The objective f is minimized using a standard gradient descent algorithm. The argument fd must return the partial derivative of f with respect to each parameter, evaluated at the current parameters; it is called once per iteration. The argument init contains the initial parameters, and opts contains the options for the gradient descent algorithm.

If $f(\theta_0, \dots, \theta_n)$ is the objective to be minimized over the parameters $\theta_0, \dots, \theta_n$, the algorithm works as follows:

$[\theta_0, \dots, \theta_n] \leftarrow$ init
$\alpha \leftarrow$ opts.alpha
$\epsilon \leftarrow$ opts.epsilon

for i = 1 to opts.iter do
$tmp \leftarrow [\theta_0, \dots, \theta_n] - \alpha \left[ \frac{\partial}{\partial \theta_0} f(\theta_0, \dots, \theta_n) , \dots, \frac{\partial}{\partial \theta_n} f(\theta_0, \dots, \theta_n) \right]$
if $|{tmp}_j - \theta_j| \leq \epsilon$ for all $j = 0, \dots, n$ then stop
$[\theta_0, \dots, \theta_n] \leftarrow tmp$

done

The gradient vector $\left[ \frac{\partial}{\partial \theta_0} f(\theta_0, \dots, \theta_n) , \dots, \frac{\partial}{\partial \theta_n} f(\theta_0, \dots, \theta_n) \right]$ must be returned by fd.

If alpha is not specified in opts, the value 0.1 is used. If the number of iterations is not specified, the value 1000 is used. If epsilon is not specified, no stopping criterion is checked and all iterations are run.
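
The loop described above can be sketched as a small stand-alone function. This is only an illustration of the pseudocode, not the rustml implementation: the name gradient_descent and its parameter list are hypothetical, and the options are passed as plain arguments instead of an OptParams value.

```rust
/// Hypothetical stand-alone sketch of the update loop described above.
/// It is not the rustml implementation and does not use OptParams.
fn gradient_descent<D>(
    fd: D,
    init: &[f64],
    alpha: f64,
    iter: usize,
    epsilon: Option<f64>,
) -> Vec<f64>
where
    D: Fn(&[f64]) -> Vec<f64>,
{
    let mut theta = init.to_vec();
    for _ in 0..iter {
        let grad = fd(&theta);
        // tmp <- theta - alpha * gradient
        let tmp: Vec<f64> = theta
            .iter()
            .zip(grad.iter())
            .map(|(t, g)| t - alpha * g)
            .collect();
        // Stop if every parameter would change by at most epsilon.
        if let Some(eps) = epsilon {
            let converged = tmp
                .iter()
                .zip(theta.iter())
                .all(|(new, old)| (new - old).abs() <= eps);
            if converged {
                break;
            }
        }
        theta = tmp;
    }
    theta
}

fn main() {
    // Minimize (x - 2)^2 from x = 4 with alpha = 0.1 and 10 iterations.
    let theta = gradient_descent(|p| vec![2.0 * (p[0] - 2.0)], &[4.0], 0.1, 10, None);
    println!("solution: {:?}", theta);
}
```

Passing epsilon as an Option mirrors the rule that no stopping criterion is checked when epsilon is not specified.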

Example

use rustml::opt::*;
use num::pow;

// set the number of iterations to 10
let opts = empty_opts().iter(10);

let r = opt(
    &|p| pow(p[0] - 2.0, 2),       // objective to be minimized: (x-2)^2
    &|p| vec![2.0 * (p[0] - 2.0)], // derivative
    &[4.0],                        // initial parameters
    opts                           // optimization options
);

for (iter, i) in r.fvals.iter().enumerate() {
    println!("error after iteration {} was {}", iter + 1, i.1);
}
println!("solution: {:?}", r.params);
// compare the absolute distance to the true minimum at x = 2
assert!((r.params[0] - 2.0).abs() <= 0.3);
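
As a check on the final assertion, the iterates of this example can be worked out by hand. This assumes the default $\alpha = 0.1$ applies because opts only sets the number of iterations, and that no stopping criterion is active. Each update is then

$\theta \leftarrow \theta - 0.1 \cdot 2(\theta - 2) = 2 + 0.8(\theta - 2)$

Starting from $\theta = 4$, after $k$ iterations $\theta = 2 + 2 \cdot 0.8^k$. After 10 iterations this gives $\theta \approx 2.215$, which lies within 0.3 of the minimum at 2.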