gnn_tracking_hpo.trainable#

Module Contents#

Classes#

HPOTrainable

Add additional 'restore' capabilities to tune.Trainable.

DefaultTrainable

A wrapper around TCNTrainer for use with Ray Tune.

PretrainedECTCNTrainable

A wrapper around TCNTrainer for use with Ray Tune.

ECTrainable

A wrapper around TCNTrainer for use with Ray Tune.

class gnn_tracking_hpo.trainable.HPOTrainable#

Bases: ray.tune.Trainable, abc.ABC

Add additional ‘restore’ capabilities to tune.Trainable.

classmethod reinstate(project: str, hash: str, *, epoch=-1, n_graphs: int | None = None, config_override: dict[str, Any] | None = None)#

Load config from wandb and restore on-disk checkpoint.

This differs from tune.Trainable.restore, which is called on an instance, i.e., one that has already been initialized with a config.

Parameters:
  • project – The wandb project name, which should also correspond to the local folder containing the checkpoints.

  • hash – The wandb run hash.

  • epoch – The epoch to restore. If -1, restore the last epoch. If 0, do not restore any checkpoint.

  • n_graphs – Total number of samples to load. If None, use the values from training.

  • config_override – Update the config with these values.
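A usage sketch for the parameters above: the dictionary collects illustrative keyword values for `reinstate` (the `"lr"` key and all concrete values are made-up placeholders, not from the source), and the commented-out call shows how a concrete subclass such as `DefaultTrainable` would be restored once `gnn_tracking_hpo` is installed.

```python
# Illustrative keyword values for HPOTrainable.reinstate; all concrete
# values here are placeholders chosen to match the documented semantics.
reinstate_kwargs = {
    "epoch": -1,  # -1 restores the last epoch; 0 restores no checkpoint
    "n_graphs": None,  # None reuses the sample counts from training
    "config_override": {"lr": 1e-4},  # merged over the config loaded from wandb
}

# With gnn_tracking_hpo available, a subclass would be restored like this
# (commented out so the sketch stays self-contained):
# from gnn_tracking_hpo.trainable import DefaultTrainable
# trainable = DefaultTrainable.reinstate(
#     project="my-project",  # wandb project == local checkpoint folder
#     hash="abc123de",       # wandb run hash (placeholder)
#     **reinstate_kwargs,
# )
```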

class gnn_tracking_hpo.trainable.DefaultTrainable#

Bases: HPOTrainable

A wrapper around TCNTrainer for use with Ray Tune.

dispatcher_id: int = 0#
setup(config: dict[str, Any])#
hook_before_trainer_setup()#
get_model() torch.nn.Module#
get_edge_loss_function() tuple[torch.nn.Module, float]#
get_potential_loss_function() tuple[torch.nn.Module, dict[str, float]]#
get_background_loss_function() tuple[torch.nn.Module, float]#
get_loss_functions() dict[str, tuple[torch.nn.Module, Any]]#
get_cluster_functions() dict[str, Any]#
get_lr_scheduler()#
get_optimizer()#
get_loaders()#
get_trainer() gnn_tracking.training.tcn_trainer.TCNTrainer#
step()#
save_checkpoint(checkpoint_dir)#
load_checkpoint(checkpoint_path, **kwargs)#
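The `get_*` methods above follow a template-method pattern: `setup` composes the pieces, and subclasses override individual getters. The toy stand-in below mirrors only that structure (the real `setup` builds a TCNTrainer and many more components; all names here are invented for illustration).

```python
# Toy stand-in for the template-method pattern used by DefaultTrainable:
# setup() calls the get_* hooks, and subclasses override single hooks.
class ToyTrainable:
    def setup(self, config):
        self.config = config
        self.model = self.get_model()
        self.loss_functions = self.get_loss_functions()

    def get_model(self):
        # Placeholder for building a torch.nn.Module
        return "default-model"

    def get_loss_functions(self):
        # Placeholder for a dict of (loss module, weight) pairs
        return {"edge": ("bce", 1.0)}


class ToyECOnly(ToyTrainable):
    # Override a single hook, as e.g. ECTrainable does with get_loss_functions
    def get_loss_functions(self):
        return {"edge": ("focal", 1.0)}


toy = ToyECOnly()
toy.setup({"lr": 1e-3})
# toy.model is still the base default; only the loss hook changed
```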
class gnn_tracking_hpo.trainable.PretrainedECTCNTrainable#

Bases: DefaultTrainable

A wrapper around TCNTrainer for use with Ray Tune.

property _is_continued_run: bool#

Whether we are restoring a model from a previous run and continuing training.

property _need_edge_loss: bool#
hook_before_trainer_setup()#

Ensure that the EC config is brought back before anything is initialized.

_update_config_from_restored_tc() None#
_update_edge_loss_config_from_ec() None#

When we use an unfrozen pretrained EC, make sure the loss function configuration stays the same.

get_loss_functions() dict[str, tuple[torch.nn.Module, Any]]#
get_trainer() gnn_tracking.training.tcn_trainer.TCNTrainer#
_get_new_model() torch.nn.Module#
_get_restored_model() torch.nn.Module#

Load a previously trained model to continue training.

get_model() torch.nn.Module#
class gnn_tracking_hpo.trainable.ECTrainable#

Bases: DefaultTrainable

A wrapper around TCNTrainer for use with Ray Tune.

property _is_continued_run: bool#

Whether we are restoring a model from a previous run and continuing training.

get_loss_functions() dict[str, Any]#
get_cluster_functions() dict[str, Any]#
get_trainer() gnn_tracking.training.tcn_trainer.TCNTrainer#
_get_new_model() torch.nn.Module#

New model to be trained (rather than continuing training a pretrained one).

_get_restored_model() torch.nn.Module#

Load a previously trained model to continue training.

get_model() torch.nn.Module#
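The class names suggest a two-stage workflow: first tune the edge classifier alone with ECTrainable, then reuse its checkpoint via PretrainedECTCNTrainable, whose `hook_before_trainer_setup` brings the EC config back in. A hedged sketch of how the two stages might be wired together (the config keys, project name, and hash below are hypothetical, not defined by the source):

```python
# Hypothetical two-stage HPO workflow; every key and value here is a
# placeholder used only to show how stage 2 would point at stage 1.
ec_stage = {
    "trainable": "ECTrainable",
    "project": "ec",  # wandb project / local checkpoint folder (placeholder)
}
tcn_stage = {
    "trainable": "PretrainedECTCNTrainable",
    "ec_project": ec_stage["project"],  # stage-1 checkpoints to restore from
    "ec_hash": "<run-hash>",  # placeholder; taken from the stage-1 wandb run
}
```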