:py:mod:`gnn_tracking_hpo.trainable`
====================================

.. py:module:: gnn_tracking_hpo.trainable


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   gnn_tracking_hpo.trainable.HPOTrainable
   gnn_tracking_hpo.trainable.DefaultTrainable
   gnn_tracking_hpo.trainable.PretrainedECTCNTrainable
   gnn_tracking_hpo.trainable.ECTrainable




.. py:class:: HPOTrainable


   Bases: :py:obj:`ray.tune.Trainable`, :py:obj:`abc.ABC`

   Add further 'restore' capabilities to `tune.Trainable`.

   .. py:method:: reinstate(project: str, hash: str, *, epoch=-1, n_graphs: int | None = None, config_override: dict[str, Any] | None = None)
      :classmethod:

      Load config from wandb and restore on-disk checkpoint.

      This differs from `tune.Trainable.restore`, which is called on an
      instance, i.e., one that has already been initialized with a config.

      :param project: The wandb project name, which should also correspond to the
                      local folder with the checkpoints.
      :param hash: The wandb run hash.
      :param epoch: The epoch to restore. If -1, restore the last epoch. If 0, do not
                    restore any checkpoint.
      :param n_graphs: Total number of samples to load. ``None`` uses the values from
                       training.
      :param config_override: Update the config with these values.
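
      The convention for ``epoch`` described above can be illustrated with a
      minimal sketch. The helper ``resolve_epoch`` and the list of available
      epochs are hypothetical and not part of this module; the actual lookup
      performed by ``reinstate`` may differ.

      .. code-block:: python

         from __future__ import annotations


         def resolve_epoch(epoch: int, available: list[int]) -> int | None:
             """Hypothetical helper: pick which checkpointed epoch to restore.

             Mirrors the documented convention: -1 selects the last epoch,
             0 means that no checkpoint is restored.
             """
             if epoch == 0:
                 # Documented special case: do not restore any checkpoint.
                 return None
             if epoch == -1:
                 # Documented special case: restore the last epoch.
                 if not available:
                     raise ValueError("No checkpoints available to restore")
                 return max(available)
             if epoch not in available:
                 # Assumption: an explicit epoch must have a checkpoint on disk.
                 raise ValueError(f"No checkpoint for epoch {epoch}")
             return epoch

      Because ``reinstate`` is a classmethod, it would be called on a concrete
      subclass rather than on an instance, for example
      ``DefaultTrainable.reinstate("my-project", "abc123", epoch=-1)``, where
      the project name and run hash are placeholders.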



.. py:class:: DefaultTrainable


   Bases: :py:obj:`HPOTrainable`

   A wrapper around `TCNTrainer` for use with Ray Tune.

   .. py:attribute:: dispatcher_id
      :type: int
      :value: 0

      

   .. py:method:: setup(config: dict[str, Any])


   .. py:method:: hook_before_trainer_setup()


   .. py:method:: get_model() -> torch.nn.Module


   .. py:method:: get_edge_loss_function() -> tuple[torch.nn.Module, float]


   .. py:method:: get_potential_loss_function() -> tuple[torch.nn.Module, dict[str, float]]


   .. py:method:: get_background_loss_function() -> tuple[torch.nn.Module, float]


   .. py:method:: get_loss_functions() -> dict[str, tuple[torch.nn.Module, Any]]


   .. py:method:: get_cluster_functions() -> dict[str, Any]


   .. py:method:: get_lr_scheduler()


   .. py:method:: get_optimizer()


   .. py:method:: get_loaders()


   .. py:method:: get_trainer() -> gnn_tracking.training.tcn_trainer.TCNTrainer


   .. py:method:: step()


   .. py:method:: save_checkpoint(checkpoint_dir)


   .. py:method:: load_checkpoint(checkpoint_path, **kwargs)



.. py:class:: PretrainedECTCNTrainable


   Bases: :py:obj:`DefaultTrainable`

   A wrapper around `TCNTrainer` for use with Ray Tune.

   .. py:property:: _is_continued_run
      :type: bool

      Whether we are restoring a model from a previous run and continuing it.

   .. py:property:: _need_edge_loss
      :type: bool


   .. py:method:: hook_before_trainer_setup()

      Ensure that we bring back the EC config before we initialize anything.


   .. py:method:: _update_config_from_restored_tc() -> None


   .. py:method:: _update_edge_loss_config_from_ec() -> None

      When we use an unfrozen pretrained EC, make sure the loss function
      configuration stays the same.


   .. py:method:: get_loss_functions() -> dict[str, tuple[torch.nn.Module, Any]]


   .. py:method:: get_trainer() -> gnn_tracking.training.tcn_trainer.TCNTrainer


   .. py:method:: _get_new_model() -> torch.nn.Module


   .. py:method:: _get_restored_model() -> torch.nn.Module

      Load a previously trained model to continue training.


   .. py:method:: get_model() -> torch.nn.Module



.. py:class:: ECTrainable


   Bases: :py:obj:`DefaultTrainable`

   A wrapper around `TCNTrainer` for use with Ray Tune.

   .. py:property:: _is_continued_run
      :type: bool

      Whether we are restoring a model from a previous run and continuing it.

   .. py:method:: get_loss_functions() -> dict[str, Any]


   .. py:method:: get_cluster_functions() -> dict[str, Any]


   .. py:method:: get_trainer() -> gnn_tracking.training.tcn_trainer.TCNTrainer


   .. py:method:: _get_new_model() -> torch.nn.Module

      Create a new model to be trained (rather than continuing to train a
      pretrained one).


   .. py:method:: _get_restored_model() -> torch.nn.Module

      Load a previously trained model to continue training.


   .. py:method:: get_model() -> torch.nn.Module



