Extending The Benchmark#
Plugging in new datasets or surrogates should not require rewriting orchestration code. This guide captures the supported extension points and the conventions that keep everything interoperable.
Add a dataset#
Use datasets/_data_generation/make_new_dataset.py as a working example. It shows how to generate synthetic trajectories, parameters, and timesteps before calling codes.create_dataset.
Shape your arrays

- `data`: `(n_samples, n_timesteps, n_quantities)` — the raw trajectories.
- `params` (optional): `(n_samples, n_parameters)` — per-trajectory parameters.
- `timesteps`: `(n_timesteps,)` — shared timeline for every trajectory.
- `labels` (optional): list of quantity names with length `n_quantities`.
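For orientation, a minimal synthetic example with these shapes might look like the sketch below (random values purely for shape-checking; `make_new_dataset.py` shows a realistic generation pipeline). The variable names match the `create_dataset` call in the next step.

```python
import numpy as np

# Arbitrary sizes, random values -- for shape-checking only.
n_samples, n_timesteps, n_quantities, n_parameters = 100, 50, 3, 2

full_dataset = np.random.rand(n_samples, n_timesteps, n_quantities)  # raw trajectories
full_params = np.random.rand(n_samples, n_parameters)                # per-trajectory parameters
timesteps = np.linspace(0.0, 1.0, n_timesteps)                       # shared timeline
labels = [f"quantity_{i}" for i in range(n_quantities)]              # one name per quantity
```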
Call `create_dataset`

```python
from codes import create_dataset

create_dataset(
    "my_new_dataset",
    data=full_dataset,
    params=full_params,
    timesteps=timesteps,
    labels=labels,
    split=(0.7, 0.1, 0.2),  # train/test/val ratios
)
```
`create_dataset` writes `datasets/my_new_dataset/data.hdf5` with the following groups: `train`, `test`, `val`, optional `*_params`, and `timesteps`. The helper also ensures consistent shuffling and folder creation.

Register the download link (optional) in `datasets/data_sources.yaml` so `scripts/download_datasets.py` knows where to fetch the data. The docs pull this file automatically, so your dataset will appear in the catalog without extra work.

Reference the dataset in configs via `dataset.name`. Log transform / normalization flags in the config can stay unchanged unless your data needs special treatment.
Once the folder exists, all CLI entry points (run_training.py, run_eval.py, run_tuning.py) will automatically pick it up based on dataset.name.
Add a surrogate model#
Every surrogate must inherit from codes.surrogates.AbstractSurrogate.AbstractSurrogateModel. This class wires together data preparation, training, logging, checkpointing, and evaluation.
Implement the class under `codes/surrogates/<YourModel>/<file>.py`.

- `forward(self, inputs) -> tuple[Tensor, Tensor]`: receives exactly what `prepare_data` emits (for example `(branch_input, trunk_input, targets)` in MultiONet). Return `(predictions, targets)` in that order so the shared training/validation utilities can compute metrics without model-specific branching.
- `prepare_data(self, datasets, metadata, …) -> tuple[DataLoader, DataLoader]`: build the train/validation dataloaders from the benchmark datasets. This is where you slice parameter sets, apply custom transforms, or pack tuples that your `forward` expects.
- `fit(self, train_loader, val_loader, …)`: implements the training loop. `AbstractSurrogateModel.train_and_save_model` sets up everything (optimizers, schedulers, loggers) before calling `fit`, so you can focus on iterating over the loaders and calling `self.validate`.
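For orientation, a minimal skeleton of such a subclass is sketched below. The constructor signature, the import path, and the class and argument names are assumptions for illustration; mirror an existing surrogate such as MultiONet for the authoritative layout.

```python
# Hypothetical skeleton -- constructor and helper signatures are assumptions.
import torch
from torch import nn
from torch.utils.data import DataLoader

from codes.surrogates.AbstractSurrogate import AbstractSurrogateModel  # path per the note above


class MySurrogate(AbstractSurrogateModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Placeholder network; size the final layer to match the target shape.
        self.net = nn.Sequential(nn.LazyLinear(128), nn.ReLU(), nn.LazyLinear(1))

    def prepare_data(self, datasets, metadata, **kwargs) -> tuple[DataLoader, DataLoader]:
        # Build train/validation loaders that emit exactly what forward() expects,
        # e.g. (inputs, targets); copy the slicing/transform logic from an existing surrogate.
        raise NotImplementedError

    def forward(self, inputs) -> tuple[torch.Tensor, torch.Tensor]:
        x, targets = inputs          # exactly what prepare_data() emits
        predictions = self.net(x)
        return predictions, targets  # always (predictions, targets)

    def fit(self, train_loader, val_loader, **kwargs):
        # Training loop; see the sketch under "Reuse shared helpers" below.
        raise NotImplementedError
```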
Reuse shared helpers

- `self.setup_progress_bar(...)` draws status updates without clashing with the multi-process trainer.
- `self.predict(loader, ...)` is the canonical way to produce predictions/targets during validation and evaluation—this ensures consistent batching, buffer pre-allocation, and shape handling for every surrogate. Test your model with it to guarantee compatibility.
- `self.validate(...)` bundles metric computation, Optuna pruning hooks, checkpointing, and logging. Call it from `fit` (typically once per epoch) after computing validation losses via `self.predict`.
- Other protected helpers (`save`, `load`, `denormalise`, optimizer/scheduler factories) already encode CODES conventions; override only when a model's requirements truly differ.
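Putting those helpers together, a `fit` loop might look roughly like the sketch below. The argument lists of `predict`, `validate`, and `setup_progress_bar` are assumptions; check the base class or an existing surrogate for the real signatures and for the shared optimizer/scheduler factories.

```python
# Hypothetical fit() body -- helper argument lists below are assumptions.
import torch


def fit(self, train_loader, val_loader, epochs):
    optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)  # or use the shared factories
    # self.setup_progress_bar(...) can wrap this loop for status output.

    for epoch in range(epochs):
        self.train()
        for batch in train_loader:
            optimizer.zero_grad()
            predictions, targets = self(batch)  # forward() returns (predictions, targets)
            loss = torch.nn.functional.mse_loss(predictions, targets)
            loss.backward()
            optimizer.step()

        # Compute validation outputs through the canonical predict() helper
        # (assumed here to return a (predictions, targets) pair) ...
        val_preds, val_targets = self.predict(val_loader)
        val_loss = torch.nn.functional.mse_loss(val_preds, val_targets)
        # ... then let validate() handle metrics, Optuna pruning, checkpointing,
        # and logging (argument names are assumed).
        self.validate(epoch=epoch, train_loss=loss.item(), val_loss=val_loss.item())
```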
Register the surrogate by appending `AbstractSurrogateModel.register("MySurrogateName")` at the end of the file that defines the class. Use the string you want users to reference in `config.yaml`. Without this hook the CLI cannot instantiate your model.
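Continuing the placeholder names from the skeleton above, the very end of the module would then contain just this call:

```python
# Bottom of codes/surrogates/<YourModel>/<file>.py, after the class definition.
# "MySurrogateName" is a placeholder and must match the name users put in config.yaml.
AbstractSurrogateModel.register("MySurrogateName")
```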
Expose configuration

- Create a companion config dataclass (e.g., `deeponet_config.py`'s `MultiONetBaseConfig`). Inherit from `AbstractSurrogateBaseConfig` whenever possible so shared hyperparameters (learning rate, optimizer, scheduler, loss, activation) remain documented and gain sensible defaults.
- Keep model-specific knobs inside that dataclass rather than the global config; users set them via the `surrogate_configs` section without touching other surrogates. The main benchmark config should stay model-agnostic unless you intentionally expose a cross-surrogate toggle.
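A hedged sketch of such a companion config; the import path and the model-specific fields (`hidden_size`, `num_layers`) are assumptions for illustration only.

```python
from dataclasses import dataclass

# Import path is assumed; place the dataclass next to the surrogate,
# e.g. in a my_surrogate_config.py module.
from codes.surrogates import AbstractSurrogateBaseConfig


@dataclass
class MySurrogateConfig(AbstractSurrogateBaseConfig):
    # Learning rate, optimizer, scheduler, loss, and activation come from the
    # base config with sensible defaults; only model-specific knobs live here.
    hidden_size: int = 256
    num_layers: int = 4
```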
Checkpointing + evaluation

- `AbstractSurrogateModel` already serializes weights, optimizer state, and scheduler state, and `run_eval.py` expects that layout. If you modify the format, verify that evaluation still loads checkpoints without custom flags.
- `predict` and the evaluation pipeline assume consistent output shapes, so avoid ad-hoc reshaping in downstream scripts—handle it in `forward` or `prepare_data`. This uniform path is what keeps the benchmark fair across models.
Before submitting a PR or relying on the surrogate in large runs, train it with the minimal config and confirm that run_eval.py + predict behave as expected.
Customize benchmark modes#
New benchmark modes (beyond interpolation, extrapolation, sparse, batch scaling, uncertainty, and iterative) follow the same pattern: add a block under the top-level config that contains an `enabled: bool` flag plus the parameters you want to sweep. Detailed documentation for additional modes is coming soon; refer to Running Benchmarks for the currently supported modalities and CLI flags.