Tutorial 2: Integrate DLPFC data with MaskGraphene
In this tutorial, we show how to use MaskGraphene (MG) to integrate multiple sections of the dorsolateral prefrontal cortex (DLPFC) dataset. As an example, we analyze the 151673/151674 sample pair.
The data, including manual annotations, were obtained from the spatialLIBD website. Before running the model, please download the input data from the Zenodo link.
Loading packages
[1]:
import logging
import numpy as np
from tqdm import tqdm
import torch
import pickle
import sys
import os
import scanpy as sc
import sklearn.metrics.pairwise
# Put the directory one level above the current working directory at the
# front of the import path, so the project imports that follow
# (utils, datasets, models) can be resolved from there.
parent_dir = os.path.dirname(os.path.abspath(os.getcwd()))
sys.path.insert(0, parent_dir)
from utils import (
build_args_ST,
create_optimizer
)
from datasets.st_loading_utils import visualization_umap_spatial, create_dictionary_mnn
from models import build_model_ST
/home/huy21/anaconda3/envs/MaskGraphene/lib/python3.9/site-packages/torchdata/datapipes/__init__.py:18: UserWarning:
################################################################################
WARNING!
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################
deprecation_warning()
/home/huy21/anaconda3/envs/MaskGraphene/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
[2]:
args = build_args_ST()

# Encoder hidden width(s) and learning rate; both are also kept as plain
# notebook-level names.
num_hidden = [50]
lr = 0.0003

# Configuration for the 151673/151674 DLPFC pair, applied to `args` in the
# order listed below. args.num_remasking is not set here, so whatever
# build_args_ST provides is used.
#### remember to change the two *_dir paths to your data path/link path
_mg_config = {
    "section_ids": ["151673", "151674"],
    "max_epoch": 2000,
    "max_epoch_triplet": 500,
    "dataset_name": "DLPFC",
    "num_hidden": num_hidden,
    "num_layers": len(num_hidden),  # one encoder layer per entry of num_hidden
    "alpha_l": 1,
    "lam": 1,
    "loss_fn": "sce",
    "mask_rate": 0.4,
    "in_drop": 0.1,
    "attn_drop": 0.05,
    "remask_rate": 0.1,
    "seeds": [2024],
    "hvgs": 5000,
    "lr": lr,
    "device": 0,
    "activation": "prelu",
    "negative_slope": 0.2,
    "num_dec_layers": 1,
    "st_data_dir": "../../spatial_benchmarking/benchmarking_data/DLPFC12",
    "hl_dir": "../hard_links/DLPFC",
}
for _attr_name, _attr_value in _mg_config.items():
    setattr(args, _attr_name, _attr_value)
MG training
Load a consecutive DLPFC section pair
[ ]:
import dgl
import scipy
import anndata
from datasets.data_proc import load_ST_dataset
# Read the dataset name from `args.dataset_name`, the attribute assigned in the
# configuration cell; `args.dataset` is never set there and may hold an
# unrelated default from build_args_ST.
dataset_name = args.dataset_name
section_ids = args.section_ids
# load_ST_dataset returns the graph, a (feature count, class count) pair and,
# presumably, the AnnData of all sections concatenated (it is later passed
# to MG as the adata) -- confirm in datasets.data_proc.
graph, (num_features, num_cls), ad_concat = load_ST_dataset(
    dataset_name=dataset_name, section_ids=section_ids, args_=args
)
# The model builder reads these from args.
args.num_features = num_features
args.num_class = num_cls
# Node feature matrix used as model input.
x = graph.ndata["feat"]
Build the model
[5]:
model = build_model_ST(args)
print(model)

# A negative args.device selects the CPU; otherwise the integer is passed to
# .to() as a (CUDA) device index.
device = args.device if args.device >= 0 else "cpu"
model.to(device)

optim_type = args.optimizer
lr = args.lr
weight_decay = args.weight_decay
optimizer = create_optimizer(optim_type, model, lr, weight_decay)

use_scheduler = args.scheduler
max_epoch = args.max_epoch
max_epoch_triplet = args.max_epoch_triplet


def cosine_lr_factor(epoch):
    """Return the cosine-annealing LR multiplier for ``epoch``.

    The factor is 1.0 at epoch 0 and 0.0 at epoch ``max_epoch``.
    """
    # NOTE(review): cos is periodic, so if a scheduler built on this factor is
    # stepped past max_epoch (e.g. if MG_triplet keeps stepping it), the
    # factor rises again -- confirm that is intended.
    return (1 + np.cos(epoch * np.pi / max_epoch)) * 0.5


if use_scheduler:
    logging.critical("Use scheduler")
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_lr_factor)
else:
    scheduler = None

# The model is already on `device`; move the graph and features to match.
graph = graph.to(device)
x = x.to(device)
=== Use sce_loss and alpha_l=1 ===
num_encoder_params: 250150, num_decoder_params: 265000, num_params_in_total: 577014
PreModel(
(encoder): GAT(
(gat_layers): ModuleList(
(0): GATConv(
(fc): Linear(in_features=5000, out_features=50, bias=False)
(feat_drop): Dropout(p=0.1, inplace=False)
(attn_drop): Dropout(p=0.05, inplace=False)
(leaky_relu): LeakyReLU(negative_slope=0.2)
)
)
(head): Identity()
)
(decoder): GAT(
(gat_layers): ModuleList(
(0): GATConv(
(fc): Linear(in_features=50, out_features=5000, bias=False)
(feat_drop): Dropout(p=0.1, inplace=False)
(attn_drop): Dropout(p=0.05, inplace=False)
(leaky_relu): LeakyReLU(negative_slope=0.2)
)
)
(head): Identity()
)
(encoder_to_decoder): Linear(in_features=50, out_features=50, bias=False)
(projector): Sequential(
(0): Linear(in_features=50, out_features=512, bias=True)
(1): PReLU(num_parameters=1)
(2): Linear(in_features=512, out_features=50, bias=True)
)
(projector_ema): Sequential(
(0): Linear(in_features=50, out_features=512, bias=True)
(1): PReLU(num_parameters=1)
(2): Linear(in_features=512, out_features=50, bias=True)
)
(predictor): Sequential(
(0): PReLU(num_parameters=1)
(1): Linear(in_features=50, out_features=50, bias=True)
)
(encoder_ema): GAT(
(gat_layers): ModuleList(
(0): GATConv(
(fc): Linear(in_features=5000, out_features=50, bias=False)
(feat_drop): Dropout(p=0.1, inplace=False)
(attn_drop): Dropout(p=0.05, inplace=False)
(leaky_relu): LeakyReLU(negative_slope=0.2)
)
)
(head): Identity()
)
)
Training
[6]:
from maskgraphene_main import MG, MG_triplet
# Stage 1: MG training for `max_epoch` epochs on the loaded sections.
# Returns the trained model and an updated AnnData (ad_concat_1).
model, ad_concat_1 = MG(
    model, graph, x, optimizer, max_epoch, device, ad_concat, scheduler,
    logger=None, key_="MG",
)

# Stage 2: MG_triplet for `max_epoch_triplet` epochs, continuing from the
# stage-1 model and AnnData. The same optimizer and scheduler objects are
# reused. key_="MG_triplet" is presumably the slot later read back via
# use_key="MG_triplet" during evaluation -- confirm in maskgraphene_main.
model, ad_concat_2 = MG_triplet(
    model,
    graph,
    x,
    optimizer,
    max_epoch_triplet,
    device,
    adata_concat_=ad_concat_1,
    scheduler=scheduler,
    logger=None,
    key_="MG_triplet",
)
CRITICAL:root:start training..
100%|██████████| 2000/2000 [02:13<00:00, 14.93it/s]
CRITICAL:root:start training..
# Epoch 499: train_loss: 0.5178: 100%|██████████| 500/500 [00:38<00:00, 13.02it/s]
Evaluate
[7]:
# Figures go to ./temp/<dataset_name><section ids joined by "_">.
exp_fig_dir = os.path.join("./temp", dataset_name + "_".join(section_ids))
# exist_ok=True replaces the racy exists()-then-makedirs() check.
os.makedirs(exp_fig_dir, exist_ok=True)

ari_ = visualization_umap_spatial(
    ad_temp=ad_concat_2,
    section_ids=section_ids,
    exp_fig_dir=exp_fig_dir,
    dataset_name=dataset_name,
    num_iter="0",
    identifier="stage2",
    num_class=args.num_class,
    use_key="MG_triplet",
)
# ari_[i] is reported as the ARI of section_ids[i].
for section_id, ari in zip(section_ids, ari_):
    print(section_id, ', ARI = %01.3f' % ari)
WARNING:rpy2.rinterface_lib.callbacks:R[write to console]: __ __
____ ___ _____/ /_ _______/ /_
/ __ `__ \/ ___/ / / / / ___/ __/
/ / / / / / /__/ / /_/ (__ ) /_
/_/ /_/ /_/\___/_/\__,_/____/\__/ version 6.0.0
Type 'citation("mclust")' for citing this R package in publications.
fitting ...
|======================================================================| 100%
/home/huy21/anaconda3/envs/MaskGraphene/lib/python3.9/site-packages/scanpy/plotting/_utils.py:430: ImplicitModificationWarning: Trying to modify attribute `._uns` of view, initializing view as actual.
adata.uns[value_to_plot + '_colors'] = colors_list
/home/huy21/anaconda3/envs/MaskGraphene/lib/python3.9/site-packages/scanpy/plotting/_utils.py:430: ImplicitModificationWarning: Trying to modify attribute `._uns` of view, initializing view as actual.
adata.uns[value_to_plot + '_colors'] = colors_list
151673 , ARI = 0.626
151674 , ARI = 0.645
[ ]: