gnn_tracking_hpo.trainable
Module Contents
Classes
HPOTrainable | Adds 'restore' capabilities on top of tune.Trainable.
DefaultTrainable | A wrapper around TCNTrainer for use with Ray Tune.
PretrainedECTCNTrainable | A wrapper around TCNTrainer for use with Ray Tune.
ECTrainable | A wrapper around TCNTrainer for use with Ray Tune.
- class gnn_tracking_hpo.trainable.HPOTrainable
Bases: ray.tune.Trainable, abc.ABC
Adds 'restore' capabilities on top of tune.Trainable.
- classmethod reinstate(project: str, hash: str, *, epoch=-1, n_graphs: int | None = None, config_override: dict[str, Any] | None = None)
Load the config from wandb and restore the on-disk checkpoint.
This differs from tune.Trainable.restore, which is called on an instance, i.e., on an object that has already been initialized with a config.
- Parameters:
project – The wandb project name; this should also match the local folder that holds the checkpoints.
hash – The wandb run hash.
epoch – The epoch to restore. If -1, restore the last epoch. If 0, do not restore any checkpoint.
n_graphs – Total number of samples to load. None uses the values from training.
config_override – Update the config with these values.
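The split between the classmethod reinstate and the instance method tune.Trainable.restore is the usual alternate-constructor pattern: the class fetches a config first and only then builds an instance. The sketch below is a stdlib-only stand-in that touches neither wandb nor Ray. FAKE_WANDB_CONFIGS, FAKE_CHECKPOINT_EPOCHS, resolve_epoch, and SketchTrainable are hypothetical names, n_graphs is omitted for brevity, and the epoch rule simply mirrors the documented semantics (-1 means the last epoch, 0 means no checkpoint).

```python
from __future__ import annotations

from typing import Any

# Hypothetical stand-in for configs fetched from wandb, keyed by (project, run hash).
FAKE_WANDB_CONFIGS: dict[tuple[str, str], dict[str, Any]] = {
    ("my-project", "abc123"): {"lr": 1e-3, "n_graphs_train": 100},
}

# Hypothetical stand-in for the checkpoint epochs found in the local project folder.
FAKE_CHECKPOINT_EPOCHS: dict[tuple[str, str], list[int]] = {
    ("my-project", "abc123"): [1, 2, 5],
}


def resolve_epoch(available: list[int], epoch: int) -> int | None:
    """Mirror the documented rule: -1 -> last epoch, 0 -> no checkpoint."""
    if epoch == 0:
        return None
    if epoch == -1:
        if not available:
            raise FileNotFoundError("no checkpoints found")
        return max(available)
    if epoch not in available:
        raise FileNotFoundError(f"no checkpoint for epoch {epoch}")
    return epoch


class SketchTrainable:
    def __init__(self, config: dict[str, Any]) -> None:
        # An instance always needs a config, which is why an instance-level
        # restore() cannot be the entry point for reloading a past run.
        self.config = config
        self.restored_epoch: int | None = None

    @classmethod
    def reinstate(
        cls,
        project: str,
        hash: str,
        *,
        epoch: int = -1,
        config_override: dict[str, Any] | None = None,
    ) -> SketchTrainable:
        # Copy so that overrides never mutate the stored config.
        config = dict(FAKE_WANDB_CONFIGS[(project, hash)])
        config.update(config_override or {})
        instance = cls(config)
        instance.restored_epoch = resolve_epoch(
            FAKE_CHECKPOINT_EPOCHS[(project, hash)], epoch
        )
        return instance
```

For example, SketchTrainable.reinstate("my-project", "abc123", epoch=0, config_override={"lr": 0.1}) yields an instance with the overridden learning rate and no restored checkpoint.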
- class gnn_tracking_hpo.trainable.DefaultTrainable
Bases: HPOTrainable
A wrapper around TCNTrainer for use with Ray Tune.
- dispatcher_id: int = 0
- setup(config: dict[str, Any])
- hook_before_trainer_setup()
- get_model() -> torch.nn.Module
- get_edge_loss_function() -> tuple[torch.nn.Module, float]
- get_potential_loss_function() -> tuple[torch.nn.Module, dict[str, float]]
- get_background_loss_function() -> tuple[torch.nn.Module, float]
- get_loss_functions() -> dict[str, tuple[torch.nn.Module, Any]]
- get_cluster_functions() -> dict[str, Any]
- get_lr_scheduler()
- get_optimizer()
- get_loaders()
- get_trainer() -> gnn_tracking.training.tcn_trainer.TCNTrainer
- step()
- save_checkpoint(checkpoint_dir)
- load_checkpoint(checkpoint_path, **kwargs)
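The get_* methods read like the hooks of a template method: setup asks each one for a component, so a subclass only overrides the pieces it changes. The sketch below assumes that structure (the actual setup may wire things differently) and replaces torch modules and TCNTrainer with plain callables. SketchDefaultTrainable, SketchNoBackground, the config keys, and the reading of each loss tuple as (loss_fn, weight) are assumptions for illustration only.

```python
from __future__ import annotations

from typing import Any, Callable

LossFn = Callable[[float], float]


class SketchDefaultTrainable:
    """Stdlib stand-in for a setup/step flow; not the real DefaultTrainable."""

    def setup(self, config: dict[str, Any]) -> None:
        self.tc = config
        # The hook runs first, so subclasses can adjust self.tc before
        # any component is built.
        self.hook_before_trainer_setup()
        # Each component comes from an overridable factory method.
        self.model = self.get_model()
        self.loss_functions = self.get_loss_functions()
        self.epoch = 0

    def hook_before_trainer_setup(self) -> None:
        """No-op by default."""

    def get_model(self) -> Callable[[float], float]:
        return lambda x: self.tc["scale"] * x

    def get_edge_loss_function(self) -> tuple[LossFn, float]:
        return (lambda pred: pred**2, self.tc.get("edge_weight", 1.0))

    def get_background_loss_function(self) -> tuple[LossFn, float]:
        return (lambda pred: abs(pred), self.tc.get("background_weight", 0.5))

    def get_loss_functions(self) -> dict[str, tuple[LossFn, float]]:
        return {
            "edge": self.get_edge_loss_function(),
            "background": self.get_background_loss_function(),
        }

    def step(self) -> dict[str, float]:
        # Stand-in for one training epoch: evaluate each loss once and
        # combine them with their weights.
        pred = self.model(1.0)
        losses = {name: fn(pred) for name, (fn, _) in self.loss_functions.items()}
        total = sum(w * losses[name] for name, (_, w) in self.loss_functions.items())
        self.epoch += 1
        return {"epoch": self.epoch, "total": total, **losses}


class SketchNoBackground(SketchDefaultTrainable):
    # Overriding a single hook changes what setup() assembles
    # without touching setup() itself.
    def get_loss_functions(self) -> dict[str, tuple[LossFn, float]]:
        return {"edge": self.get_edge_loss_function()}
```

Ray Tune calls setup itself when it constructs a trainable; in this stand-in you call setup(...) and then step() by hand.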
- class gnn_tracking_hpo.trainable.PretrainedECTCNTrainable
Bases: DefaultTrainable
A wrapper around TCNTrainer for use with Ray Tune.
- property _is_continued_run: bool
Whether we're restoring a model from a previous run and continuing training.
- property _need_edge_loss: bool
- hook_before_trainer_setup()
Ensure that the EC config is restored before anything else is initialized.
- _update_config_from_restored_tc() -> None
- _update_edge_loss_config_from_ec() -> None
When we use an unfrozen pretrained EC, make sure the loss function configuration stays the same.
- get_loss_functions() -> dict[str, tuple[torch.nn.Module, Any]]
- get_trainer() -> gnn_tracking.training.tcn_trainer.TCNTrainer
- _get_new_model() -> torch.nn.Module
- _get_restored_model() -> torch.nn.Module
Load a previously trained model to continue training it.
- get_model() -> torch.nn.Module
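The docstring of _update_edge_loss_config_from_ec says that an unfrozen pretrained EC must keep its loss configuration. One way to express that is a pure function that copies the relevant keys from the restored EC config into the current config. This is a sketch under stated assumptions: the key names in EDGE_LOSS_KEYS and the ec_freeze flag (treated as frozen when absent) are hypothetical, and the real method mutates the trainable's config in place rather than returning a copy.

```python
from __future__ import annotations

from typing import Any

# Hypothetical names of the edge-loss settings; the real keys may differ.
EDGE_LOSS_KEYS = ("edge_loss_weight", "edge_loss_pos_weight")


def update_edge_loss_config_from_ec(
    tc: dict[str, Any], ec_config: dict[str, Any]
) -> dict[str, Any]:
    """Return a copy of tc whose edge-loss settings match the pretrained EC.

    Per the docstring, this only matters when the EC is unfrozen; the
    hypothetical "ec_freeze" flag is assumed to default to True (frozen).
    """
    updated = dict(tc)
    if updated.get("ec_freeze", True):
        return updated
    for key in EDGE_LOSS_KEYS:
        if key in ec_config:
            updated[key] = ec_config[key]
    return updated
```

Running this inside hook_before_trainer_setup, before any loss function is built, matches that hook's stated purpose of bringing back the EC config first.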
- class gnn_tracking_hpo.trainable.ECTrainable
Bases: DefaultTrainable
A wrapper around TCNTrainer for use with Ray Tune.
- property _is_continued_run: bool
Whether we're restoring a model from a previous run and continuing training.
- get_loss_functions() -> dict[str, Any]
- get_cluster_functions() -> dict[str, Any]
- get_trainer() -> gnn_tracking.training.tcn_trainer.TCNTrainer
- _get_new_model() -> torch.nn.Module
New model to be trained (rather than continuing training a pretrained one).
- _get_restored_model() -> torch.nn.Module
Load a previously trained model to continue training it.
- get_model() -> torch.nn.Module
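Both PretrainedECTCNTrainable and ECTrainable pair _is_continued_run with _get_new_model and _get_restored_model, which suggests that get_model just dispatches between the two factories. The sketch below illustrates that dispatch with stdlib stand-ins. SketchModel, FAKE_CHECKPOINTS, SketchECTrainable, and the continue_hash config key are hypothetical; the real property may inspect the config differently, and the real models are torch.nn.Module instances.

```python
from __future__ import annotations

from typing import Any


class SketchModel:
    """Hypothetical stand-in for a torch.nn.Module with restorable state."""

    def __init__(self, weights: list[float] | None = None) -> None:
        self.weights = weights if weights is not None else [0.0, 0.0]


# Hypothetical checkpoint store keyed by run hash.
FAKE_CHECKPOINTS: dict[str, list[float]] = {"abc123": [0.4, -1.2]}


class SketchECTrainable:
    def __init__(self, tc: dict[str, Any]) -> None:
        self.tc = tc

    @property
    def _is_continued_run(self) -> bool:
        # Assumption: a continued run is signalled by a hypothetical config key.
        return bool(self.tc.get("continue_hash"))

    def _get_new_model(self) -> SketchModel:
        # Fresh, untrained model.
        return SketchModel()

    def _get_restored_model(self) -> SketchModel:
        # Copy the stored weights so training never mutates the checkpoint.
        return SketchModel(list(FAKE_CHECKPOINTS[self.tc["continue_hash"]]))

    def get_model(self) -> SketchModel:
        # Single entry point for setup(); the branch lives here so neither
        # factory needs to know about the other.
        if self._is_continued_run:
            return self._get_restored_model()
        return self._get_new_model()
```

Keeping the branch in get_model means a subclass can replace either factory independently, for example to change how a restored model is loaded without affecting new runs.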