gnn_tracking_hpo.restore

Module Contents

Functions

restore_model(→ torch.nn.Module)

Load pre-trained edge classifier

gnn_tracking_hpo.restore.restore_model(trainable_cls, tune_dir: str, run_hash: str, epoch: int = -1, *, config_update: dict | None = None, freeze: bool = True) → torch.nn.Module

Load pre-trained edge classifier

Parameters:
  • tune_dir (str) – Name of the Ray Tune output directory

  • run_hash (str) – Hash of the run

  • epoch (int, optional) – Epoch to load. Defaults to -1 (last epoch).

  • config_update (dict, optional) – Update the config with this dict.
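The semantics of `epoch` and `config_update` can be illustrated with a minimal, standalone sketch. It does not call gnn_tracking_hpo and is not its implementation. The helper names `select_checkpoint` and `merge_config`, and the assumption that a run directory contains subdirectories named `checkpoint_<N>`, are hypothetical and chosen only for illustration.

```python
from __future__ import annotations

from pathlib import Path


def select_checkpoint(run_dir: Path, epoch: int = -1) -> Path:
    """Pick a checkpoint subdirectory of a run directory (hypothetical helper).

    ASSUMPTION: checkpoints live in subdirectories named ``checkpoint_<N>``.
    This layout is not taken from gnn_tracking_hpo.
    """
    checkpoints = sorted(
        (
            p
            for p in run_dir.iterdir()
            if p.is_dir() and p.name.startswith("checkpoint_")
        ),
        # Sort numerically by the suffix so checkpoint_10 comes after checkpoint_2
        key=lambda p: int(p.name.split("_")[-1]),
    )
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoints found in {run_dir}")
    # Python-style indexing: -1 selects the last (most recent) checkpoint
    return checkpoints[epoch]


def merge_config(config: dict, config_update: dict | None = None) -> dict:
    """Return a copy of ``config`` with ``config_update`` applied on top (hypothetical helper)."""
    merged = dict(config)  # shallow copy so the caller's dict is not mutated
    if config_update:
        merged.update(config_update)
    return merged
```

This sketch treats `epoch` as a list index. That index only coincides with the actual epoch number if a checkpoint was written after every epoch, so check how your runs were saved before relying on that correspondence.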