Initial repo files from ChatGPT
This commit is contained in:
parent
2856e00b81
commit
591536240d
|
|
@ -0,0 +1,32 @@
|
|||
version: "3.9"
|
||||
services:
|
||||
rust-dev:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: docker/Dockerfile.rust-dev
|
||||
volumes:
|
||||
- ./:/work
|
||||
working_dir: /work
|
||||
command: bash -lc "cargo test --release"
|
||||
|
||||
|
||||
wasm:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: docker/Dockerfile.wasm
|
||||
volumes:
|
||||
- ./:/work
|
||||
working_dir: /work
|
||||
command: bash -lc "make wasm && python3 -m http.server -d bindings/wasm/pkg 8000"
|
||||
ports:
|
||||
- "8000:8000"
|
||||
|
||||
|
||||
python:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: docker/Dockerfile.rust-dev
|
||||
volumes:
|
||||
- ./:/work
|
||||
working_dir: /work
|
||||
command: bash -lc "make py && python -c 'from symreg_rs import PySymbolicRegressor as SR; print(SR().fit_predict([[0.0],[1.0]],[0.0,1.0]))'"
|
||||
|
|
@ -1,229 +1,10 @@
|
|||
# ---> Rust
|
||||
# Generated by Cargo
|
||||
# will have compiled files and executables
|
||||
debug/
|
||||
target/
|
||||
|
||||
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
|
||||
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
|
||||
Cargo.lock
|
||||
|
||||
# These are backup files generated by rustfmt
|
||||
**/*.rs.bk
|
||||
|
||||
# MSVC Windows builds of rustc generate these, which store debugging information
|
||||
*.pdb
|
||||
|
||||
# ---> Emacs
|
||||
# -*- mode: gitignore; -*-
|
||||
*~
|
||||
\#*\#
|
||||
/.emacs.desktop
|
||||
/.emacs.desktop.lock
|
||||
*.elc
|
||||
auto-save-list
|
||||
tramp
|
||||
.\#*
|
||||
|
||||
# Org-mode
|
||||
.org-id-locations
|
||||
*_archive
|
||||
|
||||
# flymake-mode
|
||||
*_flymake.*
|
||||
|
||||
# eshell files
|
||||
/eshell/history
|
||||
/eshell/lastdir
|
||||
|
||||
# elpa packages
|
||||
/elpa/
|
||||
|
||||
# reftex files
|
||||
*.rel
|
||||
|
||||
# AUCTeX auto folder
|
||||
/auto/
|
||||
|
||||
# cask packages
|
||||
.cask/
|
||||
dist/
|
||||
|
||||
# Flycheck
|
||||
flycheck_*.el
|
||||
|
||||
# server auth directory
|
||||
/server/
|
||||
|
||||
# projectiles files
|
||||
.projectile
|
||||
|
||||
# directory configuration
|
||||
.dir-locals.el
|
||||
|
||||
# network security
|
||||
/network-security.data
|
||||
|
||||
|
||||
# ---> Python
|
||||
# Byte-compiled / optimized / DLL files
|
||||
**/*.pyc
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
*.pyo
|
||||
.vscode/
|
||||
.idea/
|
||||
.DS_Store
|
||||
bindings/python/target/
|
||||
bindings/python/*.egg-info/
|
||||
bindings/wasm/pkg/
|
||||
|
|
|
|||
|
|
@ -0,0 +1,55 @@
|
|||
[package]
|
||||
name = "symreg-rs"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
description = "Fast, extensible symbolic regression in Rust with Python bindings and WASM target"
|
||||
repository = "https://example.org/your/repo"
|
||||
|
||||
|
||||
[lib]
|
||||
name = "symreg_rs"
|
||||
crate-type = ["rlib", "cdylib"]
|
||||
|
||||
|
||||
[features]
|
||||
default = ["simd"]
|
||||
simd = []
|
||||
python = ["pyo3/extension-module", "pyo3/macros"]
|
||||
wasm = ["wasm-bindgen"]
|
||||
|
||||
|
||||
[dependencies]
|
||||
rand = "0.8"
|
||||
rand_distr = "0.4"
|
||||
smallvec = "1"
|
||||
bitflags = "2"
|
||||
thiserror = "1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
rayon = "1"
|
||||
ordered-float = "4"
|
||||
|
||||
|
||||
# algebra & optimization
|
||||
nalgebra = { version = "0.33", default-features = false, features=["std"] }
|
||||
argmin = { version = "0.8", optional = true }
|
||||
|
||||
|
||||
# optional rewriting via e-graphs
|
||||
egg = { version = "0.9", optional = true }
|
||||
|
||||
|
||||
# bindings
|
||||
pyo3 = { version = "0.21", optional = true, features=["abi3", "abi3-py38"] }
|
||||
wasm-bindgen = { version = "0.2", optional = true }
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.5"
|
||||
|
||||
|
||||
[profile.release]
|
||||
lto = true
|
||||
codegen-units = 1
|
||||
opt-level = 3
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
.PHONY: build test bench fmt lint py sdist wheel wasm demo

# Build the Rust core in release mode.
build:
	cargo build --release

# Run the full test suite (release mode, matching the docker-compose command).
test:
	cargo test --all --release

# Criterion micro-benchmarks.
bench:
	cargo bench

# Format the workspace. Clippy is kept here for backward compatibility:
# callers may already rely on `make fmt` failing on lint warnings.
fmt:
	cargo fmt --all
	cargo clippy --all-targets --all-features -- -D warnings

# Lint only. Was declared in .PHONY but the target itself was missing.
lint:
	cargo clippy --all-targets --all-features -- -D warnings

# NOTE(review): maturin's -m/--manifest-path conventionally points at a
# Cargo.toml; confirm the pinned maturin version accepts a pyproject.toml here.
py:
	maturin develop -m bindings/python/pyproject.toml

sdist:
	maturin sdist -m bindings/python/pyproject.toml

wheel:
	maturin build -m bindings/python/pyproject.toml --release

wasm:
	wasm-pack build bindings/wasm --target web --release

# Serve the built wasm package for the browser demo.
demo:
	python3 -m http.server -d bindings/wasm/pkg 8000
|
||||
65
README.md
65
README.md
|
|
@ -1,3 +1,68 @@
|
|||
# sr-rs
|
||||
|
||||
SR-RS: A portable symbolic regression engine in Rust, patterned after Cranmer's SymbolicRegression.jl.
|
||||
|
||||
**Rust symbolic regression** with:
|
||||
|
||||
|
||||
- 🦀 **Core in Rust** (fast evaluators, SIMD where available)
|
||||
- 🐍 **Python bindings** (PyO3, scikit-learn-style API)
|
||||
- 🌐 **WASM target** (`wasm-bindgen`, in-browser demo)
|
||||
- ♻️ **Rewrite-based simplification** (optional `egg` integration)
|
||||
- 🎯 Multi-objective search (error vs. complexity), constant-fitting, protected operators, and export to SymPy/LaTeX
|
||||
|
||||
|
||||
## Quick start
|
||||
|
||||
|
||||
### Native (Rust)
|
||||
```bash
|
||||
cargo run --example quickstart
```

### Python
|
||||
```bash
# inside container or host with Rust toolchain
make py
python -c "from symreg_rs import PySymbolicRegressor as SR; print(SR().fit_predict([[0.0],[1.0]], [0.0,1.0]))"
```

### WebAssembly demo

```bash
make wasm && make demo
# open http://localhost:8000/index.html
```

## Design highlights
|
||||
|
||||
Expression AST with typed operations and arity checks
|
||||
|
||||
Fast vectorized evaluator (feature simd)
|
||||
|
||||
GP loop with tournament selection, variation (mutate/crossover), and Pareto archiving
|
||||
|
||||
Constant optimization via LM/Argmin (optional)
|
||||
|
||||
Simplification via identities + (optional) egg equality saturation
|
||||
|
||||
See docs/DESIGN.md and docs/PORTING.md for details and porting guidance.
|
||||
|
||||
9) Usage snippets
|
||||
### Rust

```rust
use symreg_rs::{GpConfig, SymbolicRegressor};
let x = vec![vec![0.0], vec![1.0]];
let y = vec![0.0, 1.0];
let sr = SymbolicRegressor::default();
let expr = sr.fit(&x, &y);
```

### Python

```python
from symreg_rs import PySymbolicRegressor as SR
sr = SR(pop_size=256, gens=20, num_vars=1)
preds = sr.fit_predict([[0.0],[1.0]],[0.0,1.0])
```
|
||||
10) Next steps (recommended)
|
||||
|
||||
Wire up constant fitting (Argmin LM) in fitness.rs.
|
||||
|
||||
Add exporters: to SymPy string and LaTeX.
|
||||
|
||||
Implement subtree crossover/mutation and depth/size limits.
|
||||
|
||||
Add vectorized evaluator (std::simd) + rayon parallelism.
|
||||
|
||||
Optional: integrate egg for global simplification.
|
||||
|
||||
This scaffold should compile with minor tweaks (crate names/versions). From here, we can iterate on the porting branch to reach feature parity with your preferred SR baseline.
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
[build-system]
|
||||
requires = ["maturin>=1.4"]
|
||||
build-backend = "maturin"
|
||||
|
||||
|
||||
[project]
|
||||
name = "symreg-rs"
|
||||
version = "0.1.0"
|
||||
description = "Rust symbolic regression with Python bindings"
|
||||
authors = [{name="Your Name", email="you@example.com"}]
|
||||
requires-python = ">=3.8"
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Rust",
|
||||
]
|
||||
|
||||
|
||||
[tool.maturin]
|
||||
features = ["python"]
|
||||
module-name = "symreg_rs"
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
use pyo3::prelude::*;
|
||||
use symreg_rs::{GpConfig, SymbolicRegressor};
|
||||
|
||||
|
||||
/// Python-facing wrapper around the core [`SymbolicRegressor`].
#[pyclass]
struct PySymbolicRegressor { inner: SymbolicRegressor }

#[pymethods]
impl PySymbolicRegressor {
    /// Create a regressor; any argument left as `None` keeps the
    /// corresponding `GpConfig` default.
    #[new]
    fn new(pop_size: Option<usize>, gens: Option<usize>, num_vars: Option<usize>) -> Self {
        let mut cfg = GpConfig::default();
        if let Some(v) = pop_size { cfg.pop_size = v; }
        if let Some(v) = gens { cfg.gens = v; }
        if let Some(v) = num_vars { cfg.num_vars = v; }
        Self { inner: SymbolicRegressor::new(cfg) }
    }

    /// Fit on `(x, y)` and return predictions for `predict_x`,
    /// falling back to the training rows when `predict_x` is `None`.
    // NOTE(review): relies on pyo3 treating trailing Option<T> parameters as
    // optional with a None default; newer pyo3 releases expect an explicit
    // #[pyo3(signature = ...)] — confirm against the pinned pyo3 0.21.
    fn fit_predict(&self, x: Vec<Vec<f64>>, y: Vec<f64>, predict_x: Option<Vec<Vec<f64>>>) -> PyResult<Vec<f64>> {
        let expr = self.inner.fit(&x, &y);
        let px = predict_x.unwrap_or(x);
        Ok(px.iter().map(|row| symreg_rs::eval::eval_expr(&expr, row)).collect())
    }
}

/// Python module entry point: exposes `symreg_rs.PySymbolicRegressor`.
// NOTE(review): the `&PyModule` GIL-ref form is deprecated in pyo3 0.21 in
// favour of `&Bound<'_, PyModule>`; it still works but is worth migrating.
#[pymodule]
fn symreg_rs(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<PySymbolicRegressor>()?;
    Ok(())
}
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head><meta charset="utf-8"><title>symreg-rs WASM demo</title></head>
|
||||
<body>
|
||||
<h1>symreg-rs (WASM)</h1>
|
||||
<button id="run">Run demo</button>
|
||||
<pre id="out"></pre>
|
||||
<script type="module">
|
||||
import init, { fit_mse } from './pkg/symreg_rs_wasm.js';
|
||||
await init();
|
||||
document.getElementById('run').onclick = () => {
|
||||
const xs = []; const ys = [];
|
||||
for (let i=0;i<50;i++){ const v=i/10; xs.push(v); ys.push(Math.sin(v)); }
|
||||
const preds = fit_mse(xs, 50, 1, ys, 10);
|
||||
document.getElementById('out').textContent = JSON.stringify(preds.slice(0,10), null, 2);
|
||||
};
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
{
|
||||
"name": "symreg-rs-wasm",
|
||||
"version": "0.1.0",
|
||||
"scripts": {
|
||||
"build": "wasm-pack build --target web --release",
|
||||
"serve": "python3 -m http.server -d pkg 8000"
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
use wasm_bindgen::prelude::*;
|
||||
use symreg_rs::{GpConfig, SymbolicRegressor};
|
||||
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub fn fit_mse(x: Vec<f64>, n_rows: usize, n_cols: usize, y: Vec<f64>, gens: usize) -> Vec<f64> {
|
||||
let mut rows = Vec::with_capacity(n_rows);
|
||||
for r in 0..n_rows {
|
||||
let start = r*n_cols; rows.push(x[start..start+n_cols].to_vec());
|
||||
}
|
||||
let cfg = GpConfig { gens, num_vars: n_cols, ..Default::default() };
|
||||
let sr = SymbolicRegressor::new(cfg);
|
||||
let expr = sr.fit(&rows, &y);
|
||||
rows.iter().map(|row| symreg_rs::eval::eval_expr(&expr, row)).collect()
|
||||
}
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
version: "3.9"
|
||||
services:
|
||||
rust-dev:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: docker/Dockerfile.rust-dev
|
||||
volumes:
|
||||
- ./:/work
|
||||
working_dir: /work
|
||||
command: bash -lc "cargo test --release"
|
||||
|
||||
|
||||
wasm:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: docker/Dockerfile.wasm
|
||||
volumes:
|
||||
- ./:/work
|
||||
working_dir: /work
|
||||
command: bash -lc "make wasm && python3 -m http.server -d bindings/wasm/pkg 8000"
|
||||
ports:
|
||||
- "8000:8000"
|
||||
|
||||
|
||||
python:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: docker/Dockerfile.rust-dev
|
||||
volumes:
|
||||
- ./:/work
|
||||
working_dir: /work
|
||||
command: bash -lc "make py && python -c 'from symreg_rs import PySymbolicRegressor as SR; print(SR().fit_predict([[0.0],[1.0]],[0.0,1.0]))'"
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
FROM python:3.11-slim
|
||||
RUN apt-get update && apt-get install -y build-essential python3-dev && rm -rf /var/lib/apt/lists/*
|
||||
# wheel build happens via maturin in rust-dev; this image is for runtime
|
||||
RUN pip install numpy
|
||||
WORKDIR /app
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
FROM rust:1.80
RUN apt-get update && apt-get install -y python3 python3-pip pkg-config libssl-dev && rm -rf /var/lib/apt/lists/*
# rust:1.80 is Debian bookworm, whose system Python is marked externally
# managed (PEP 668); a bare `pip3 install` fails there, so opt out explicitly
# for this single-purpose container image.
RUN pip3 install --break-system-packages maturin
# wasm-pack is a Rust tool; install it via cargo rather than pip.
RUN cargo install wasm-pack
WORKDIR /work
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
FROM rust:1.80
|
||||
RUN cargo install wasm-pack
|
||||
WORKDIR /work
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
#!/usr/bin/env bash
# Container entrypoint: build the Python wheel (Makefile `wheel` target, which
# runs maturin) before handing control to the requested command.
set -euo pipefail
make wheel
# Replace this shell with whatever command the container was started with.
exec "$@"
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
#!/usr/bin/env bash
# Container entrypoint: compile the Rust core in release mode before handing
# control to the requested command.
set -euo pipefail
cargo build --release
# Replace this shell with whatever command the container was started with.
exec "$@"
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
# BENCHMARKS
|
||||
|
||||
|
||||
- Use `criterion` for microbenchmarks (eval throughput).
|
||||
- Dataset-level timing: fit to Friedman1 and report wall-clock and MSE.
|
||||
- Compare feature toggles: scalar vs SIMD, parallel vs single-thread.
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
|
||||
## 8) Branching model (with a `porting` branch)
|
||||
|
||||
|
||||
- `main`: stable, tagged releases.
|
||||
- `develop`: integration of ready features.
|
||||
- `porting`: **all work specific to parity with SymbolicRegression.jl** (feature-by-feature ports, benchmarks, notes). Merge into `develop` via PRs when each slice reaches MVP.
|
||||
- Per-feature branches from `porting` (e.g., `porting/const-fitting`, `porting/egg-rewrites`).
|
||||
|
||||
|
||||
**Initial commands:**
|
||||
```bash
|
||||
git init
|
||||
git checkout -b main
|
||||
git add . && git commit -m "feat: repo scaffold"
|
||||
git checkout -b develop
|
||||
git checkout -b porting
|
||||
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
# DESIGN
|
||||
|
||||
|
||||
## Goals
|
||||
- Fast core with predictable performance (no runtime JIT)
|
||||
- Clean APIs for Python and Web
|
||||
- Extensibility: pluggable operators, fitness terms, search strategies
|
||||
|
||||
|
||||
## Key components
|
||||
- AST (`ast.rs`), evaluator (`eval.rs`), operators (`ops.rs`)
|
||||
- Fitness metrics (`fitness.rs`), GP algorithm (`gp.rs`)
|
||||
- Simplification (`simplify/`), utilities (`utils.rs`)
|
||||
|
||||
|
||||
## Error/Complexity tradeoff
|
||||
- Use bi-objective selection or scalarized cost; Pareto archive maintained per generation.
|
||||
|
||||
|
||||
## Safety
|
||||
- Protected operators for division/log; numeric guards where needed.
|
||||
|
||||
|
||||
## Determinism
|
||||
- Seeded RNG exposed in configs for reproducible runs.
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
# Porting plan (from SymbolicRegression.jl concepts)
|
||||
|
||||
|
||||
This document maps major features/design choices to Rust equivalents and notes feasibility.
|
||||
|
||||
|
||||
## Search & representation
|
||||
- **Representation:** tree/graph AST with typed ops. Rust enums for ops and nodes. Feasible (done in MVP).
|
||||
- **Search:** GP with tournament selection + crossover/mutation, plus Pareto archiving. Extensible to age-fitness, lexicase.
|
||||
- **Constants:** LM/variable projection via `argmin`/`nalgebra`. Start with LM; later add autodiff for local gradients.
|
||||
|
||||
|
||||
## Evaluation performance
|
||||
- **Julia’s fused loops & SIMD:** replicate with `std::simd` and hand-rolled evaluators. Bench and specialize hot ops.
|
||||
- **Parallelism:** use `rayon` for population eval; thread support optional in WASM (SharedArrayBuffer route).
|
||||
|
||||
|
||||
## Simplification & pruning
|
||||
- **Local rules:** identities (e.g., `x+0→x`, `x*1→x`, `sin(0)→0`).
|
||||
- **Global rewriting:** integrate `egg` as optional feature for equality saturation + cost-based extraction.
|
||||
|
||||
|
||||
## Multi-objective
|
||||
- Maintain Pareto front on (error, complexity). Provide knobs for complexity penalty and max size/depth.
|
||||
|
||||
|
||||
## Export & interop
|
||||
- **Python:** PyO3 module `symreg_rs` with sklearn-like API (`fit`, `predict`, `score`).
|
||||
- **WASM:** thin bindgen façade for browser demos; avoid heavy deps.
|
||||
- **SymPy/LaTeX:** stringifier to SymPy code, plus LaTeX pretty-printing.
|
||||
|
||||
|
||||
## Test & benchmark
|
||||
- Reuse standard SRBench datasets where licensing permits; include Friedman1 synthetic.
|
||||
- Add criterion benches for evaluator and variation operators.
|
||||
|
||||
|
||||
## Milestones
|
||||
1. MVP (this scaffold): scalar eval, GP loop, Python & WASM hello world.
|
||||
2. Vectorized eval + `rayon` parallel pop eval; constant fitting.
|
||||
3. Simplifier + export; sklearn-compatible estimator.
|
||||
4. E-graph integration; advanced search strategies; docs + notebooks.
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
# ROADMAP
|
||||
|
||||
|
||||
- [ ] SIMD evaluator
|
||||
- [ ] Parallel pop evaluation (rayon)
|
||||
- [ ] LM constant fitting
|
||||
- [ ] SymPy/LaTeX exporters
|
||||
- [ ] sklearn API parity (fit/predict/score, get_params/set_params)
|
||||
- [ ] E-graph simplification (egg)
|
||||
- [ ] Web demo with worker + shared memory
|
||||
- [ ] SRBench harness integration
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
use symreg_rs::{GpConfig, SymbolicRegressor};
|
||||
|
||||
|
||||
fn main() {
|
||||
let x: Vec<Vec<f64>> = (0..50).map(|i| vec![i as f64 / 10.0]).collect();
|
||||
let y: Vec<f64> = x.iter().map(|v| v[0].sin()).collect();
|
||||
let cfg = GpConfig { num_vars: 1, gens: 10, ..Default::default() };
|
||||
let sr = SymbolicRegressor::new(cfg);
|
||||
let expr = sr.fit(&x, &y);
|
||||
println!("best expr nodes={}, root={}", expr.nodes.len(), expr.root);
|
||||
}
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
||||
/// Operator set for expression nodes: four binary arithmetic ops and four
/// unary transcendental functions.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum Op {
    Add, Sub, Mul, Div, Sin, Cos, Exp, Log,
    // extendable: Tan, Abs, Pow, etc.
}

impl Op {
    /// Number of operands the operator consumes (2 for arithmetic,
    /// 1 for the transcendental functions). The match is exhaustive on
    /// purpose: adding a variant forces an update here.
    pub fn arity(&self) -> usize {
        use Op::*;
        match self {
            Add | Sub | Mul | Div => 2,
            Sin | Cos | Exp | Log => 1,
        }
    }
}
|
||||
|
||||
|
||||
/// A single node in the flattened expression tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Node {
    // Input variable, identified by column index into a sample row.
    Var(usize),
    // Literal floating-point constant.
    Const(f64),
    // Operator application; the SmallVec holds indices of the child nodes
    // within the owning `Expr::nodes` arena (inline up to 4 children).
    Call(Op, smallvec::SmallVec<[usize; 4]>), // children indices
}

/// Expression stored as a flat node arena plus a root index.
// NOTE(review): the evaluator appears to assume children precede their
// parents in `nodes` (topological order) — confirm that all builders
// (rand_expr, crossover, mutate) guarantee this. Also note `Default`
// yields an empty arena with root 0, which is not evaluable.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Expr {
    pub nodes: Vec<Node>,
    pub root: usize,
}
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
use crate::ast::{Expr, Node, Op};
|
||||
use crate::ops::{apply_unary, protected_div};
|
||||
|
||||
|
||||
/// Evaluate `expr` for a single input row `x` (one value per variable).
///
/// Walks the node arena front-to-back, memoising each node's value in
/// `cache`; this is only correct if every `Call` node's children appear
/// earlier in the arena than the node itself (children before parents).
// NOTE(review): `x[*k]` panics if a Var index exceeds the row length —
// callers must supply rows of at least `num_vars` values.
pub fn eval_expr(expr: &Expr, x: &[f64]) -> f64 {
    // simple stack-based evaluator (scalar); SIMD/vectorized path can be added
    let mut cache: Vec<f64> = vec![0.0; expr.nodes.len()];
    for (i, node) in expr.nodes.iter().enumerate() {
        let val = match node {
            Node::Var(k) => x[*k],
            Node::Const(c) => *c,
            Node::Call(op, ch) => {
                match op {
                    Op::Add => cache[ch[0]] + cache[ch[1]],
                    Op::Sub => cache[ch[0]] - cache[ch[1]],
                    Op::Mul => cache[ch[0]] * cache[ch[1]],
                    // protected: a near-zero divisor returns the numerator
                    Op::Div => protected_div(cache[ch[0]], cache[ch[1]]),
                    // all remaining operators are unary
                    _ => apply_unary(*op, cache[ch[0]]),
                }
            }
        };
        cache[i] = val;
    }
    cache[expr.root]
}
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
use crate::ast::Expr;
|
||||
use crate::eval::eval_expr;
|
||||
|
||||
|
||||
pub fn mse(expr: &Expr, x: &[Vec<f64>], y: &[f64]) -> f64 {
|
||||
let n = x.len();
|
||||
let mut s = 0.0;
|
||||
for i in 0..n {
|
||||
let yi = eval_expr(expr, &x[i]);
|
||||
let d = yi - y[i];
|
||||
s += d * d;
|
||||
}
|
||||
s / n as f64
|
||||
}
|
||||
|
||||
|
||||
/// Structural complexity of an expression: its total node count.
pub fn complexity(expr: &Expr) -> usize { expr.nodes.len() }
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
use rand::{seq::SliceRandom, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
||||
use crate::ast::{Expr, Node, Op};
|
||||
use crate::fitness::{complexity, mse};
|
||||
|
||||
|
||||
/// Hyper-parameters for the GP search loop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpConfig {
    // Number of individuals per generation.
    pub pop_size: usize,
    // Maximum tree depth for generated expressions.
    pub max_depth: usize,
    // Probability of mutation (checked only when crossover was not chosen).
    pub mutation_rate: f64,
    // Probability of producing an offspring via crossover.
    pub crossover_rate: f64,
    // Number of generations to run.
    pub gens: usize,
    // Number of input variables (columns of x).
    pub num_vars: usize,
}

impl Default for GpConfig {
    fn default() -> Self {
        Self { pop_size: 256, max_depth: 6, mutation_rate: 0.2, crossover_rate: 0.8, gens: 50, num_vars: 1 }
    }
}
|
||||
|
||||
|
||||
/// One member of the population: an expression plus its cached fitness stats.
#[derive(Debug, Clone)]
pub struct Individual { pub expr: Expr, pub err: f64, pub complx: usize }

/// GP-driven symbolic regressor; all behaviour is controlled via `cfg`.
pub struct SymbolicRegressor { pub cfg: GpConfig }

impl Default for SymbolicRegressor { fn default() -> Self { Self { cfg: GpConfig::default() } } }
||||
|
||||
|
||||
impl SymbolicRegressor {
|
||||
pub fn new(cfg: GpConfig) -> Self { Self { cfg } }
|
||||
|
||||
|
||||
pub fn fit(&self, x: &[Vec<f64>], y: &[f64]) -> Expr {
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut pop: Vec<Individual> = (0..self.cfg.pop_size)
|
||||
.map(|_| Individual { expr: rand_expr(self.cfg.num_vars, self.cfg.max_depth, &mut rng), err: f64::INFINITY, complx: 0 })
|
||||
.collect();
|
||||
eval_pop(&mut pop, x, y);
|
||||
for _gen in 0..self.cfg.gens {
|
||||
pop.sort_by(|a,b| a.err.partial_cmp(&b.err).unwrap());
|
||||
let elite = pop[0].clone();
|
||||
let mut next = vec![elite];
|
||||
while next.len() < pop.len() {
|
||||
if rng.gen::<f64>() < self.cfg.crossover_rate {
|
||||
let (a,b) = tournament2(&pop, &mut rng);
|
||||
next.push(crossover(&a.expr, &b.expr, &mut rng));
|
||||
} else if rng.gen::<f64>() < self.cfg.mutation_rate {
|
||||
let a = &pop[rng.gen_range(0..pop.len())];
|
||||
next.push(mutate(&a.expr, self.cfg.num_vars, self.cfg.max_depth, &mut rng));
|
||||
} else {
|
||||
next.push(pop[rng.gen_range(0..pop.len())].clone());
|
||||
}
|
||||
}
|
||||
pop = next;
|
||||
eval_pop(&mut pop, x, y);
|
||||
}
|
||||
pop.sort_by(|a,b| a.err.partial_cmp(&b.err).unwrap());
|
||||
pop[0].expr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn eval_pop(pop: &mut [Individual], x: &[Vec<f64>], y: &[f64]) {
|
||||
pop.iter_mut().for_each(|ind| {
|
||||
}
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
//! symreg-rs: symbolic regression core — AST, evaluator, operators, fitness,
//! GP search loop, and simplification utilities.
pub mod ast;
pub mod eval;
pub mod ops;
pub mod fitness;
pub mod gp;
pub mod simplify;
pub mod utils;

// Re-export the primary user-facing API at the crate root.
pub use crate::gp::{GpConfig, SymbolicRegressor};
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
use crate::ast::Op;
|
||||
|
||||
|
||||
/// Division that never produces inf/NaN from a tiny divisor: when |b| is
/// below 1e-12 the numerator is returned unchanged instead of dividing.
pub fn protected_div(a: f64, b: f64) -> f64 {
    if b.abs() >= 1e-12 {
        a / b
    } else {
        // divisor effectively zero — fall back to the numerator
        a
    }
}
|
||||
|
||||
|
||||
pub fn apply_unary(op: Op, x: f64) -> f64 {
|
||||
match op {
|
||||
Op::Sin => x.sin(),
|
||||
Op::Cos => x.cos(),
|
||||
Op::Exp => x.exp(),
|
||||
Op::Log => if x.abs() < 1e-12 { 0.0 } else { x.abs().ln() },
|
||||
_ => unreachable!("not unary"),
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1 @@
|
|||
pub mod rules;
|
||||
|
|
@ -0,0 +1 @@
|
|||
// placeholder for local simplification rules; egg-based rules can live in crates/egg-rewrites
|
||||
Loading…
Reference in New Issue