model
- class model.CANAL_model(gpu_option='0', CLASS=7, EMBEDDING_DIM=64, VALIDATE_EVERY=1, PATIENCE=10, BATCH_SIZE=14, EPOCHS=100, LEARNING_RATE=0.0002)
Bases: object
Construct the CANAL model, which continually adapts a pre-trained language model to universal annotation of scRNA-seq data.
- evaluation(pred_cell_type, true_celltype, novel_celltype=None)
If the test dataset has ground-truth cell-type labels, evaluate the performance of the CANAL model. If the test data contain no novel cells, we report total annotation accuracy, F1 score, and ARI; if novel cells exist, we report the H-score, total annotation accuracy, annotation accuracy of known cells, and accuracy of unknown cells.
- pred_cell_type (array):
predicted cell types of the test dataset from the CANAL model
- true_celltype (array):
true cell types of the test dataset
- novel_celltype (array):
novel cell types in the test dataset that never appeared in the fine-tuning data stream. The default value is None.
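A minimal sketch of the open-set metrics, assuming the standard open-set definition of the H-score (harmonic mean of known-cell accuracy and unknown-cell accuracy) and assuming a novel cell counts as correct only when it is predicted as "Unassigned". The helpers `accuracy` and `open_set_scores` are hypothetical and not part of `model`. For the case without novel cells, F1 score and ARI are available as `sklearn.metrics.f1_score` and `sklearn.metrics.adjusted_rand_score`; which F1 averaging mode CANAL uses is not stated here.

```python
from typing import Dict, Sequence

UNASSIGNED = "Unassigned"  # label assumed for cells rejected as novel


def accuracy(pred: Sequence[str], true: Sequence[str]) -> float:
    """Fraction of cells whose predicted label equals the reference label."""
    if not true or len(pred) != len(true):
        raise ValueError("pred and true must be non-empty and of equal length")
    return sum(p == t for p, t in zip(pred, true)) / len(true)


def open_set_scores(pred: Sequence[str], true: Sequence[str],
                    novel_celltype: Sequence[str]) -> Dict[str, float]:
    """Known/unknown accuracy, H-score and total accuracy (assumed definitions)."""
    novel = set(novel_celltype)
    known_idx = [i for i, t in enumerate(true) if t not in novel]
    novel_idx = [i for i, t in enumerate(true) if t in novel]
    if not known_idx or not novel_idx:
        raise ValueError("need at least one known and one novel cell")
    # A known cell is correct only if it receives its exact ground-truth label.
    acc_known = sum(pred[i] == true[i] for i in known_idx) / len(known_idx)
    # A novel cell is correct if it is rejected as UNASSIGNED.
    acc_unknown = sum(pred[i] == UNASSIGNED for i in novel_idx) / len(novel_idx)
    denom = acc_known + acc_unknown
    h_score = 0.0 if denom == 0 else 2 * acc_known * acc_unknown / denom
    # Total accuracy: the reference label of every novel cell is UNASSIGNED.
    reference = [UNASSIGNED if t in novel else t for t in true]
    return {"acc_known": acc_known, "acc_unknown": acc_unknown,
            "h_score": h_score, "total_acc": accuracy(pred, reference)}
```

For predictions `["T", "B", "Unassigned", "Unassigned", "T"]` against `["T", "B", "NK", "NK", "B"]` with `novel_celltype=["NK"]`, known accuracy is 2/3, unknown accuracy is 1.0, and both the H-score and the total accuracy are 0.8.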
- predict(adata_predict, ckpt_dir, experiments, stage_num, dataset, novel=False, temperature=0.8)
Predict cell types of an unlabeled dataset using a CANAL model that has completed fine-tuning at a given stage.
- adata_predict (AnnData object):
unlabeled test data to annotate
- ckpt_dir:
path of the saved fine-tuned CANAL model
- experiments (str):
name of the experiment
- stage_num (int):
the fine-tuning stage the CANAL model has completed
- dataset (str):
name of the dataset used during the given stage
- novel (bool):
whether to detect novel cells in the test data, default value is False. If True, cells whose uncertainty score exceeds an automatically determined threshold are annotated as “Unassigned”.
- temperature (float):
the temperature parameter in the energy function, used when novel cell detection is enabled. The default value is 0.8.
- Returns:
an array of predicted cell types for the test dataset
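The `novel` and `temperature` options can be illustrated with a sketch of energy-based rejection, assuming the uncertainty score is the free energy E(x) = -T · logsumexp(f(x) / T) of the classifier logits f(x). How CANAL determines the threshold automatically is not described here, so this sketch takes `threshold` as an explicit, hypothetical argument; `energy_score` and `annotate_with_rejection` are illustrative names.

```python
import math
from typing import List, Sequence


def energy_score(logits: Sequence[float], temperature: float = 0.8) -> float:
    """Free energy E = -T * logsumexp(logits / T); higher means more uncertain."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max before exponentiating for stability
    lse = m + math.log(sum(math.exp(s - m) for s in scaled))
    return -temperature * lse


def annotate_with_rejection(logits_per_cell: Sequence[Sequence[float]],
                            class_names: Sequence[str], threshold: float,
                            temperature: float = 0.8) -> List[str]:
    """Argmax label per cell, or "Unassigned" when the energy exceeds threshold."""
    labels = []
    for logits in logits_per_cell:
        if energy_score(logits, temperature) > threshold:
            labels.append("Unassigned")
        else:
            best = max(range(len(logits)), key=lambda j: logits[j])
            labels.append(class_names[best])
    return labels
```

For example, `annotate_with_rejection([[8.0, 0.0, 0.0], [0.1, 0.0, 0.1]], ["T", "B", "NK"], threshold=-5.0)` returns `["T", "Unassigned"]`; the threshold value is arbitrary and only for illustration.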
- train(experiments, pre_dataset, dataset, adata, cell_type, current_stage, is_final_stage, ckpt_dir, rehearsal_size=1000, highly_variable_idx=None, lambda_KD=0.1, SEED=1)
Train the CANAL model at a given stage.
- experiments (str):
name of the experiment
- pre_dataset (str):
name of the dataset used at the previous stage
- dataset (str):
name of the current dataset
- adata (AnnData object):
current data used to fine-tune the model
- cell_type (array):
cell types corresponding to the current data
- current_stage (int):
the current stage of fine-tuning
- is_final_stage (bool):
whether the current stage is the final stage
- ckpt_dir:
path to save the model
- rehearsal_size (int):
total number of cells that can be preserved in the example bank, default = 1000
- highly_variable_idx (array):
indices of highly variable genes; should be provided at the initial stage
- lambda_KD (float):
training weight for the representation knowledge distillation loss, default = 0.1
- gpu_option (str):
which GPU device to run the model on, default = ‘0’; set through the CANAL_model constructor rather than train()
- SEED (int):
random seed for the current run, default = 1
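To show where `lambda_KD` enters training, here is a sketch of a per-cell stage loss: cross-entropy on the cell-type label plus `lambda_KD` times a representation distillation term. The distillation term is assumed to be the mean squared difference between the current model's embedding and the previous-stage model's embedding of the same cell; the exact form used by CANAL may differ, and all function names are hypothetical.

```python
import math
from typing import Optional, Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """-log softmax(logits)[target], computed with a stable log-sum-exp."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target]


def representation_kd(new_emb: Sequence[float], old_emb: Sequence[float]) -> float:
    """Assumed distillation term: mean squared difference between embeddings."""
    if len(new_emb) != len(old_emb):
        raise ValueError("embeddings must have the same dimension")
    return sum((a - b) ** 2 for a, b in zip(new_emb, old_emb)) / len(new_emb)


def stage_loss(logits: Sequence[float], target: int, new_emb: Sequence[float],
               old_emb: Optional[Sequence[float]] = None,
               lambda_KD: float = 0.1) -> float:
    """Cross-entropy plus lambda_KD * distillation; no KD term without a previous model."""
    loss = cross_entropy(logits, target)
    if old_emb is not None:  # e.g. the initial stage has no previous-stage model
        loss += lambda_KD * representation_kd(new_emb, old_emb)
    return loss
```

A larger `lambda_KD` keeps the new representation closer to the previous stage's, trading plasticity on the current dataset for less forgetting.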
- class model.Identity(SEQ_LEN, dropout=0.0, h_dim=64, out_dim=10)
Bases: Module
Construct the network output layer.
- forward(x)
Forward propagation
- training: bool
- class model.SCDataset(data, label, CLASS)
Bases: Dataset
Construct an SCDataset from AnnData.
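A minimal stand-in for `SCDataset`, assuming the scBERT tokenization convention (CANAL's default `CLASS=7` appears to match scBERT's): each expression value is truncated to an integer bin, capped at `CLASS - 2`, and a trailing 0 token is appended, so the sequence length is the number of genes plus one. This is an assumption about the preprocessing, not documented behavior, and `SCDatasetSketch` is a hypothetical plain-Python class that only mimics the `Dataset` protocol (`__len__` and `__getitem__`).

```python
from typing import List, Sequence, Tuple


class SCDatasetSketch:
    """Hypothetical stand-in for model.SCDataset (scBERT-style tokens assumed)."""

    def __init__(self, data: Sequence[Sequence[float]], label: Sequence[int],
                 CLASS: int = 7):
        if len(data) != len(label):
            raise ValueError("data and label must have the same number of cells")
        self.data = data
        self.label = label
        self.CLASS = CLASS

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[List[int], int]:
        # Truncate each expression value to an integer bin, capped at CLASS - 2.
        tokens = [min(int(x), self.CLASS - 2) for x in self.data[index]]
        tokens.append(0)  # extra trailing token, as in scBERT
        return tokens, self.label[index]
```

For example, `SCDatasetSketch([[0.0, 2.7, 9.3]], [1])[0]` gives `([0, 2, 5, 0], 1)`.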
- model.example_bank_update(example_bank_previous, adata, new_model_annotation, embedding, prototype, current_label_dict, current_label_set, current_label, each_class_num, current_stage)
Update the current example bank via prototypes.
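The prototype-based update can be sketched as follows, assuming each class prototype is the mean embedding of that class, and that the bank keeps, per class, the `each_class_num` cells nearest to the prototype with `each_class_num = rehearsal_size // number of classes`. This nearest-to-mean rule is an assumption; the sketch also ignores how `example_bank_previous` is merged, and `per_class_budget` and `select_by_prototype` are hypothetical names.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence


def per_class_budget(rehearsal_size: int, n_classes: int) -> int:
    """Split the total example-bank size evenly across the classes seen so far."""
    return rehearsal_size // n_classes


def select_by_prototype(embedding: Sequence[Sequence[float]],
                        labels: Sequence[str],
                        each_class_num: int) -> Dict[str, List[int]]:
    """For each class, return indices of the cells closest to its mean embedding."""
    by_class: Dict[str, List[int]] = defaultdict(list)
    for i, lab in enumerate(labels):
        by_class[lab].append(i)
    selected = {}
    for lab, idx in by_class.items():
        dim = len(embedding[idx[0]])
        # Prototype = mean embedding of the class.
        proto = [sum(embedding[i][d] for i in idx) / len(idx) for d in range(dim)]
        ranked = sorted(idx, key=lambda i: math.dist(embedding[i], proto))
        selected[lab] = ranked[:each_class_num]
    return selected
```

With the default `rehearsal_size=1000` and 7 cell types seen so far, each class would keep `per_class_budget(1000, 7) == 142` cells under this rule.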
- model.get_embedding(data, CLASS, model)
Get the samples’ representations from the penultimate layer.
- model.save_model(experiments, dataset, stage_num, is_final_stage, example_bank, label_dict, highly_variable_idx, model, ckpt_dir)
Save the current model, cell-type library, and example bank for continual learning and evaluation.
- experiments (str):
name of the experiment
- dataset (str):
name of the dataset used at the current stage
- stage_num (int):
current stage of fine-tuning
- is_final_stage (bool):
whether the current stage is the final stage
- example_bank (AnnData object):
the selected examples to be reviewed repeatedly in later stages
- label_dict (array):
cell-type library with all the cell types that have appeared so far
- highly_variable_idx (array):
indices of highly variable genes
- model:
the fine-tuned CANAL model after this stage
- ckpt_dir:
path to save the model
- model.setup_seed(seed)
Set the random seed for the current run.
- seed (int):
random seed
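A common way to seed a PyTorch-based model is sketched below; which generators CANAL actually seeds is an assumption. NumPy and PyTorch are seeded only if they are installed, so the sketch also runs with the standard library alone, and `setup_seed_sketch` is a hypothetical name.

```python
import importlib.util
import os
import random


def setup_seed_sketch(seed: int) -> None:
    """Seed Python, and NumPy/PyTorch if available, for reproducible runs."""
    os.environ["PYTHONHASHSEED"] = str(seed)  # only affects processes started later
    random.seed(seed)
    if importlib.util.find_spec("numpy") is not None:
        import numpy as np
        np.random.seed(seed)
    if importlib.util.find_spec("torch") is not None:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # safe to call without a GPU
        # Trade cuDNN autotuning speed for deterministic kernels.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
```

Calling `setup_seed_sketch(1)` before two runs makes `random.random()` return the same value in both.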