74 lines
1.9 KiB
Rust
74 lines
1.9 KiB
Rust
use rand::{seq::SliceRandom, Rng};
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
|
|
use crate::ast::{Expr, Node, Op};
|
|
use crate::fitness::{complexity, mse};
|
|
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct GpConfig {
|
|
pub pop_size: usize,
|
|
pub max_depth: usize,
|
|
pub mutation_rate: f64,
|
|
pub crossover_rate: f64,
|
|
pub gens: usize,
|
|
pub num_vars: usize,
|
|
}
|
|
|
|
|
|
impl Default for GpConfig {
|
|
fn default() -> Self {
|
|
Self { pop_size: 256, max_depth: 6, mutation_rate: 0.2, crossover_rate: 0.8, gens: 50, num_vars: 1 }
|
|
}
|
|
}
|
|
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct Individual { pub expr: Expr, pub err: f64, pub complx: usize }
|
|
|
|
|
|
/// Symbolic regression driven by genetic programming; all tuning lives in [`GpConfig`].
pub struct SymbolicRegressor {
    /// Search settings used by `fit`.
    pub cfg: GpConfig,
}
|
|
|
|
|
|
impl Default for SymbolicRegressor { fn default() -> Self { Self { cfg: GpConfig::default() } } }
|
|
|
|
|
|
impl SymbolicRegressor {
|
|
pub fn new(cfg: GpConfig) -> Self { Self { cfg } }
|
|
|
|
|
|
pub fn fit(&self, x: &[Vec<f64>], y: &[f64]) -> Expr {
|
|
let mut rng = rand::thread_rng();
|
|
let mut pop: Vec<Individual> = (0..self.cfg.pop_size)
|
|
.map(|_| Individual { expr: rand_expr(self.cfg.num_vars, self.cfg.max_depth, &mut rng), err: f64::INFINITY, complx: 0 })
|
|
.collect();
|
|
eval_pop(&mut pop, x, y);
|
|
for _gen in 0..self.cfg.gens {
|
|
pop.sort_by(|a,b| a.err.partial_cmp(&b.err).unwrap());
|
|
let elite = pop[0].clone();
|
|
let mut next = vec![elite];
|
|
while next.len() < pop.len() {
|
|
if rng.gen::<f64>() < self.cfg.crossover_rate {
|
|
let (a,b) = tournament2(&pop, &mut rng);
|
|
next.push(crossover(&a.expr, &b.expr, &mut rng));
|
|
} else if rng.gen::<f64>() < self.cfg.mutation_rate {
|
|
let a = &pop[rng.gen_range(0..pop.len())];
|
|
next.push(mutate(&a.expr, self.cfg.num_vars, self.cfg.max_depth, &mut rng));
|
|
} else {
|
|
next.push(pop[rng.gen_range(0..pop.len())].clone());
|
|
}
|
|
}
|
|
pop = next;
|
|
eval_pop(&mut pop, x, y);
|
|
}
|
|
pop.sort_by(|a,b| a.err.partial_cmp(&b.err).unwrap());
|
|
pop[0].expr.clone()
|
|
}
|
|
}
|
|
|
|
|
|
fn eval_pop(pop: &mut [Individual], x: &[Vec<f64>], y: &[f64]) {
|
|
pop.iter_mut().for_each(|ind| {
|
|
}
|