codes.train package#

Submodules#

codes.train.train_fcts module#

class codes.train.train_fcts.DummyLock#

Bases: object

acquire()#
release()#
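
DummyLock is a no-op stand-in for threading.Lock: it exposes the same acquire()/release() interface but does nothing, so sequential code paths can pass it wherever parallel code would pass a real lock. A minimal sketch of the pattern (illustrative only, not the actual implementation):

```python
import threading

class NoOpLock:
    """Illustrative no-op lock exposing the same interface as threading.Lock."""

    def acquire(self):
        pass  # nothing to do in a single-threaded run

    def release(self):
        pass

def guarded_step(threadlock=NoOpLock()):
    # Callers lock/unlock unconditionally; with the dummy lock this is free.
    threadlock.acquire()
    try:
        pass  # thread-sensitive work, e.g. seeding or checkpoint saving
    finally:
        threadlock.release()

guarded_step()                              # sequential run: no-op lock
guarded_step(threadlock=threading.Lock())   # parallel run: real lock
```
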
codes.train.train_fcts.create_task_list_for_surrogate(config, surr_name)#

Creates a list of training tasks for a specific surrogate model based on the configuration file.

Parameters:
  • config (dict) – The configuration dictionary taken from the config file.

  • surr_name (str) – The name of the surrogate model.

Returns:

A list of training tasks for the surrogate model.

Return type:

list
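
A hedged usage sketch, assuming the configuration lives in a YAML file and that "FullyConnected" is one of the configured surrogate names (both are placeholders for your own setup):

```python
import yaml

from codes.train.train_fcts import create_task_list_for_surrogate

# Load the benchmark configuration (path and schema are illustrative).
with open("config.yaml") as f:
    config = yaml.safe_load(f)

# Each entry is a task tuple later consumed by
# sequential_training() or parallel_training().
tasks = create_task_list_for_surrogate(config, "FullyConnected")
print(f"{len(tasks)} training tasks queued")
```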

codes.train.train_fcts.parallel_training(tasks, device_list, task_list_filepath)#

Execute the queued training tasks across multiple devices using worker threads.

Parameters:
  • tasks (list[tuple]) – Output of create_task_list_for_surrogate().

  • device_list (list[str]) – Devices allocated to training (e.g. ["cuda:0", "cuda:1"]).

  • task_list_filepath (str) – Path to the persisted JSON task list that tracks progress.

Returns:

Elapsed wall-clock time reported by the shared progress bar.

Return type:

float
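
A possible call, assuming two CUDA devices and an illustrative path for the persisted task list:

```python
from codes.train.train_fcts import parallel_training

elapsed = parallel_training(
    tasks,                                      # from create_task_list_for_surrogate()
    device_list=["cuda:0", "cuda:1"],           # one worker thread per device
    task_list_filepath="tasks/task_list.json",  # illustrative path; progress is tracked here
)
print(f"Parallel training finished after {elapsed:.1f} s")
```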

codes.train.train_fcts.sequential_training(tasks, device_list, task_list_filepath)#

Run all training tasks sequentially on a single device.

Parameters:
  • tasks (list[tuple]) – Task specification tuples generated from the config.

  • device_list (list[str]) – Contains exactly one element (typically "cpu" or a single CUDA id).

  • task_list_filepath (str) – Path to the JSON file used to resume interrupted runs.

Returns:

Total elapsed time once all tasks finish.

Return type:

float
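
The sequential variant takes the same arguments but expects a single-element device list (paths are again illustrative):

```python
from codes.train.train_fcts import sequential_training

elapsed = sequential_training(
    tasks,
    device_list=["cpu"],                        # exactly one device
    task_list_filepath="tasks/task_list.json",
)
print(f"Sequential training finished after {elapsed:.1f} s")
```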

codes.train.train_fcts.train_and_save_model(surr_name, mode, metric, training_id, seed=None, epochs=None, device='cpu', position=1, threadlock=DummyLock())#

Train and save a model for a specific benchmark mode.

Parameters:
  • surr_name (str) – The name of the surrogate model.

  • mode (str) – The benchmark mode.

  • metric (int) – The metric value associated with the benchmark mode; its interpretation depends on the mode.

  • training_id (str) – The training ID for the current training session.

  • seed (int, optional) – Random seed for training.

  • epochs (int, optional) – Number of training epochs.

  • device (str, optional) – Device to run training on.

  • position (int, optional) – Model position in the task list.

  • threadlock (threading.Lock, optional) – Lock used to keep multi-threaded training deterministic; defaults to a no-op DummyLock.
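
A single task can also be run directly. The surrogate name, mode, metric, and training ID below are placeholders whose exact values depend on the benchmark configuration:

```python
from codes.train.train_fcts import train_and_save_model

train_and_save_model(
    surr_name="FullyConnected",  # illustrative surrogate name
    mode="main",                 # illustrative benchmark mode
    metric=0,                    # metric value tied to the chosen mode
    training_id="training_1",    # illustrative session ID
    seed=42,
    epochs=100,
    device="cuda:0",
)
```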

codes.train.train_fcts.worker(task_queue, device, device_idx, overall_progress_bar, task_list_filepath, errors_encountered, threadlock)#

Worker function to process tasks from the task queue on the given device.

Parameters:
  • task_queue (Queue) – The in-memory queue containing the training tasks.

  • device (str) – The device to use for training.

  • device_idx (int) – The index of the device in the device list.

  • overall_progress_bar (tqdm) – The overall progress bar for the training.

  • task_list_filepath (str) – The filepath to the JSON task list.

  • errors_encountered (list[bool]) – A shared mutable flag array indicating if an error has occurred (True if at least one task failed).

  • threadlock (threading.Lock) – A lock to prevent threading issues with PyTorch.
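
worker() is normally driven by parallel_training() rather than called directly. The underlying pattern is the standard one-consumer-thread-per-device loop over a shared queue; a simplified sketch (not the actual implementation) is shown below:

```python
import threading
from queue import Empty, Queue

from tqdm import tqdm

def run_workers(tasks, device_list, process_task):
    """Simplified illustration: fan tasks out to one consumer thread per device."""
    task_queue: Queue = Queue()
    for task in tasks:
        task_queue.put(task)

    overall_progress_bar = tqdm(total=len(tasks), desc="Overall progress")
    threadlock = threading.Lock()
    errors_encountered = [False]  # shared mutable flag

    def consume(device):
        while True:
            try:
                task = task_queue.get_nowait()
            except Empty:
                return  # queue drained, thread exits
            try:
                process_task(task, device, threadlock)  # e.g. wraps train_and_save_model()
            except Exception:
                errors_encountered[0] = True
            finally:
                overall_progress_bar.update(1)
                task_queue.task_done()

    threads = [threading.Thread(target=consume, args=(d,)) for d in device_list]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    overall_progress_bar.close()
    return errors_encountered[0]
```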

Module contents#

class codes.train.DummyLock#

Bases: object

acquire()#
release()#
codes.train.create_task_list_for_surrogate(config, surr_name)#

Creates a list of training tasks for a specific surrogate model based on the configuration file.

Parameters:
  • config (dict) – The configuration dictionary taken from the config file.

  • surr_name (str) – The name of the surrogate model.

Returns:

A list of training tasks for the surrogate model.

Return type:

list

codes.train.parallel_training(tasks, device_list, task_list_filepath)#

Execute the queued training tasks across multiple devices using worker threads.

Parameters:
  • tasks (list[tuple]) – Output of create_task_list_for_surrogate().

  • device_list (list[str]) – Devices allocated to training (e.g. ["cuda:0", "cuda:1"]).

  • task_list_filepath (str) – Path to the persisted JSON task list that tracks progress.

Returns:

Elapsed wall-clock time reported by the shared progress bar.

Return type:

float

codes.train.sequential_training(tasks, device_list, task_list_filepath)#

Run all training tasks sequentially on a single device.

Parameters:
  • tasks (list[tuple]) – Task specification tuples generated from the config.

  • device_list (list[str]) – Contains exactly one element (typically "cpu" or a single CUDA id).

  • task_list_filepath (str) – Path to the JSON file used to resume interrupted runs.

Returns:

Total elapsed time once all tasks finish.

Return type:

float

codes.train.train_and_save_model(surr_name, mode, metric, training_id, seed=None, epochs=None, device='cpu', position=1, threadlock=DummyLock())#

Train and save a model for a specific benchmark mode.

Parameters:
  • surr_name (str) – The name of the surrogate model.

  • mode (str) – The benchmark mode.

  • metric (int) – The metric value associated with the benchmark mode; its interpretation depends on the mode.

  • training_id (str) – The training ID for the current training session.

  • seed (int, optional) – Random seed for training.

  • epochs (int, optional) – Number of training epochs.

  • device (str, optional) – Device to run training on.

  • position (int, optional) – Model position in the task list.

  • threadlock (threading.Lock, optional) – Lock used to keep multi-threaded training deterministic; defaults to a no-op DummyLock.

codes.train.worker(task_queue, device, device_idx, overall_progress_bar, task_list_filepath, errors_encountered, threadlock)#

Worker function to process tasks from the task queue on the given device.

Parameters:
  • task_queue (Queue) – The in-memory queue containing the training tasks.

  • device (str) – The device to use for training.

  • device_idx (int) – The index of the device in the device list.

  • overall_progress_bar (tqdm) – The overall progress bar for the training.

  • task_list_filepath (str) – The filepath to the JSON task list.

  • errors_encountered (list[bool]) – A shared mutable flag array indicating if an error has occurred (True if at least one task failed).

  • threadlock (threading.Lock) – A lock to prevent threading issues with PyTorch.
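
Since the package re-exports these names, a typical end-to-end flow can be written against codes.train directly. The file paths, config keys ("surrogates", "devices"), and fallback values below are assumptions about the config schema, not guaranteed names:

```python
import yaml

from codes.train import (
    create_task_list_for_surrogate,
    parallel_training,
    sequential_training,
)

with open("config.yaml") as f:
    config = yaml.safe_load(f)

# Gather tasks for every surrogate named in the config ("surrogates" key assumed).
tasks = []
for surr_name in config.get("surrogates", []):
    tasks.extend(create_task_list_for_surrogate(config, surr_name))

# Use all configured devices ("devices" key assumed); fall back to CPU.
devices = config.get("devices", ["cpu"])
task_list_filepath = "tasks/task_list.json"  # illustrative path

if len(devices) > 1:
    elapsed = parallel_training(tasks, devices, task_list_filepath)
else:
    elapsed = sequential_training(tasks, devices, task_list_filepath)

print(f"Training finished after {elapsed:.1f} s")
```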