sr-rs/src/gp.rs

74 lines
1.9 KiB
Rust

use rand::{seq::SliceRandom, Rng};
use serde::{Deserialize, Serialize};
use crate::ast::{Expr, Node, Op};
use crate::fitness::{complexity, mse};
/// Hyper-parameters for the genetic-programming search run by [`SymbolicRegressor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpConfig {
    /// Number of individuals generated initially and kept in every generation.
    pub pop_size: usize,
    /// Depth limit handed to random expression generation and to mutation
    /// (presumably the maximum expression-tree depth — confirm in `rand_expr`).
    pub max_depth: usize,
    /// Probability of mutating a parent, consulted only when crossover was not
    /// chosen for the slot being filled.
    pub mutation_rate: f64,
    /// Probability that a slot in the next generation is filled by crossover.
    pub crossover_rate: f64,
    /// Number of generations to evolve.
    pub gens: usize,
    /// Number of input variables available to generated expressions
    /// (presumably the column count of each row of `x` — verify against callers).
    pub num_vars: usize,
}

impl Default for GpConfig {
    /// 256 individuals evolved for 50 generations, depth limit 6, one input
    /// variable, crossover rate 0.8 and mutation rate 0.2.
    fn default() -> Self {
        GpConfig {
            pop_size: 256,
            gens: 50,
            max_depth: 6,
            num_vars: 1,
            crossover_rate: 0.8,
            mutation_rate: 0.2,
        }
    }
}
/// A candidate expression paired with the scores used to rank it.
#[derive(Debug, Clone)]
pub struct Individual {
    /// The candidate expression.
    pub expr: Expr,
    /// Fitness error; lower ranks better. `SymbolicRegressor::fit` seeds fresh
    /// individuals with `f64::INFINITY` before their first evaluation.
    pub err: f64,
    /// Complexity score (presumably filled from `fitness::complexity` during
    /// evaluation — confirm in `eval_pop`); seeded with 0 by `fit`.
    pub complx: usize,
}
pub struct SymbolicRegressor { pub cfg: GpConfig }
impl Default for SymbolicRegressor { fn default() -> Self { Self { cfg: GpConfig::default() } } }
impl SymbolicRegressor {
    /// Creates a regressor that uses `cfg` for its search.
    pub fn new(cfg: GpConfig) -> Self {
        Self { cfg }
    }

    /// Evolves a population of random expressions against the samples `(x, y)`
    /// and returns the lowest-error expression found.
    ///
    /// Each generation the population is ranked by error. The best individual is
    /// carried over unchanged (elitism), and the remaining slots are filled by
    /// crossover of two `tournament2` parents, by mutating a uniformly chosen
    /// individual, or by copying one.
    ///
    /// Individuals whose error is NaN are ranked as worse than every other
    /// individual instead of aborting the run.
    ///
    /// # Panics
    ///
    /// Panics if `self.cfg.pop_size` is zero, because there would be no
    /// individual to return.
    pub fn fit(&self, x: &[Vec<f64>], y: &[f64]) -> Expr {
        assert!(
            self.cfg.pop_size > 0,
            "GpConfig::pop_size must be at least 1"
        );
        let mut rng = rand::thread_rng();
        let mut pop: Vec<Individual> = (0..self.cfg.pop_size)
            .map(|_| Individual {
                expr: rand_expr(self.cfg.num_vars, self.cfg.max_depth, &mut rng),
                err: f64::INFINITY,
                complx: 0,
            })
            .collect();
        eval_pop(&mut pop, x, y);

        for _gen in 0..self.cfg.gens {
            // Best (lowest error) first; NaN errors sort last rather than panicking.
            pop.sort_by(Self::cmp_by_error);

            let mut next = Vec::with_capacity(pop.len());
            // Elitism: the current best always survives unchanged.
            next.push(pop[0].clone());
            while next.len() < pop.len() {
                if rng.gen::<f64>() < self.cfg.crossover_rate {
                    let (a, b) = tournament2(&pop, &mut rng);
                    next.push(crossover(&a.expr, &b.expr, &mut rng));
                } else if rng.gen::<f64>() < self.cfg.mutation_rate {
                    // Mutation is only tried when crossover was not chosen, so the
                    // effective per-slot mutation probability is
                    // (1 - crossover_rate) * mutation_rate.
                    let a = &pop[rng.gen_range(0..pop.len())];
                    next.push(mutate(&a.expr, self.cfg.num_vars, self.cfg.max_depth, &mut rng));
                } else {
                    next.push(pop[rng.gen_range(0..pop.len())].clone());
                }
            }
            pop = next;
            eval_pop(&mut pop, x, y);
        }

        // A linear scan suffices for the final answer. `min_by` returns the first of
        // equally-minimal elements, matching "stable sort then take index 0".
        pop.into_iter()
            .min_by(Self::cmp_by_error)
            .map(|best| best.expr)
            .expect("population is non-empty because pop_size >= 1 was asserted")
    }

    /// Orders individuals by ascending error, treating NaN as worse than any
    /// other value (including `+inf`) so the comparison is a total order.
    fn cmp_by_error(a: &Individual, b: &Individual) -> std::cmp::Ordering {
        let key = |e: f64| if e.is_nan() { f64::INFINITY } else { e };
        // After NaN is mapped away, partial_cmp always returns Some; the fallback is
        // only there to avoid an unwrap.
        key(a.err)
            .partial_cmp(&key(b.err))
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}
fn eval_pop(pop: &mut [Individual], x: &[Vec<f64>], y: &[f64]) {
pop.iter_mut().for_each(|ind| {
}