Tutorial 4: Integrate MERFISH mouse brain (MB) data with MaskGraphene

[1]:
import logging
import numpy as np
from tqdm import tqdm
import torch
import pickle
import sys
import os
import scanpy as sc
import sklearn.metrics.pairwise

# Resolve the directory one level above the current working directory and put
# it at the front of the module search path. The project imports that follow
# (utils, datasets, models) are presumably resolved from there.
parent_rel = os.path.join(os.getcwd(), os.pardir)
parent_dir = os.path.abspath(parent_rel)
sys.path.insert(0, parent_dir)

from utils import (
    build_args_ST,
    create_optimizer
)
from datasets.st_loading_utils import visualization_umap_spatial, create_dictionary_mnn
from models import build_model_ST
/home/huy21/anaconda3/envs/MaskGraphene/lib/python3.9/site-packages/torchdata/datapipes/__init__.py:18: UserWarning:
################################################################################
WARNING!
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################

  deprecation_warning()
/home/huy21/anaconda3/envs/MaskGraphene/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
[2]:
# Start from the project's default arguments and override the values used in
# this tutorial run.
args = build_args_ST()

# --- Dataset -----------------------------------------------------------------
args.dataset = "MB"
# Ten sections, identified by the strings '0' .. '9'.
args.section_ids = [str(section_idx) for section_idx in range(10)]
# Alternative kept from the original notebook (disabled): use only sections
# '0' .. '8' together with args.num_class = 8.

# --- Training schedule ---------------------------------------------------------
args.max_epoch = 3000          # epochs passed to the first training stage
args.max_epoch_triplet = 500   # epochs passed to the triplet training stage
args.lr = 0.0003               # learning rate handed to the optimizer
args.seeds = [2024]

# --- Model / loss --------------------------------------------------------------
args.num_hidden = "512,32"     # comma-separated layer widths
args.loss_fn = "sce"
args.alpha_l = 3
args.lam = 1
args.in_drop = 0
args.attn_drop = 0

# --- Masking -------------------------------------------------------------------
args.mask_rate = 0.5
args.remask_rate = 0.5
args.num_remasking = 2

# --- Paths ---------------------------------------------------------------------
# Remember to point these at your own data directory and hard-link directory.
args.st_data_dir = "../../spatial_benchmarking/benchmarking_data/merfish_mouse_brain"
args.hl_dir = "../hard_links/MB"
[ ]:
import dgl
import scipy
import anndata
from datasets.data_proc import load_ST_dataset

# Keep the dataset name and section ids as notebook-level names; the figure
# output cell uses them again.
dataset_name, section_ids = args.dataset, args.section_ids

# load_ST_dataset returns the graph, a (feature count, class count) pair and
# the concatenated AnnData object for the requested sections.
graph, (num_features, num_cls), ad_concat = load_ST_dataset(
    dataset_name=dataset_name,
    section_ids=section_ids,
    args_=args,
)

# Record the dataset-derived sizes on `args` before the model is built from it.
args.num_features, args.num_class = num_features, num_cls

# Node feature matrix stored on the graph under the "feat" key.
x = graph.ndata["feat"]
[4]:
# Build the model described by `args` and print its layer structure.
model = build_model_ST(args)
print(model)

# A negative `args.device` selects the CPU; any other value is used as-is.
device = args.device if args.device >= 0 else "cpu"

# Module.to() moves the parameters in place, so one call is enough. The
# original cell repeated this call after the scheduler was created.
model.to(device)

optim_type = args.optimizer
lr = args.lr
weight_decay = args.weight_decay
optimizer = create_optimizer(optim_type, model, lr, weight_decay)

use_scheduler = args.scheduler
max_epoch = args.max_epoch
max_epoch_triplet = args.max_epoch_triplet


def cosine_lr_factor(epoch):
    """Return the learning-rate multiplier for `epoch`.

    The factor follows half a cosine period: 1.0 at epoch 0, falling to 0.0
    at epoch `max_epoch`. `max_epoch` is read from the notebook namespace
    when the function is called.
    """
    return (1 + np.cos(epoch * np.pi / max_epoch)) * 0.5


if use_scheduler:
    logging.critical("Use scheduler")
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_lr_factor)
else:
    scheduler = None

# Move the graph and its feature matrix to the same device as the model.
graph = graph.to(device)
x = x.to(device)
=== Use sce_loss and alpha_l=3 ===
num_encoder_params: 148064, num_decoder_params: 148730, num_params_in_total: 332474
PreModel(
  (encoder): GAT(
    (gat_layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=254, out_features=512, bias=False)
        (feat_drop): Dropout(p=0, inplace=False)
        (attn_drop): Dropout(p=0, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
        (activation): ELU(alpha=1.0)
      )
      (1): GATConv(
        (fc): Linear(in_features=512, out_features=32, bias=False)
        (feat_drop): Dropout(p=0, inplace=False)
        (attn_drop): Dropout(p=0, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
      )
    )
    (head): Identity()
  )
  (decoder): GAT(
    (gat_layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=32, out_features=512, bias=False)
        (feat_drop): Dropout(p=0, inplace=False)
        (attn_drop): Dropout(p=0, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
        (activation): ELU(alpha=1.0)
      )
      (1): GATConv(
        (fc): Linear(in_features=512, out_features=254, bias=False)
        (feat_drop): Dropout(p=0, inplace=False)
        (attn_drop): Dropout(p=0, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
      )
    )
    (head): Identity()
  )
  (encoder_to_decoder): Linear(in_features=32, out_features=32, bias=False)
  (projector): Sequential(
    (0): Linear(in_features=32, out_features=512, bias=True)
    (1): PReLU(num_parameters=1)
    (2): Linear(in_features=512, out_features=32, bias=True)
  )
  (projector_ema): Sequential(
    (0): Linear(in_features=32, out_features=512, bias=True)
    (1): PReLU(num_parameters=1)
    (2): Linear(in_features=512, out_features=32, bias=True)
  )
  (predictor): Sequential(
    (0): PReLU(num_parameters=1)
    (1): Linear(in_features=32, out_features=32, bias=True)
  )
  (encoder_ema): GAT(
    (gat_layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=254, out_features=512, bias=False)
        (feat_drop): Dropout(p=0, inplace=False)
        (attn_drop): Dropout(p=0, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
        (activation): ELU(alpha=1.0)
      )
      (1): GATConv(
        (fc): Linear(in_features=512, out_features=32, bias=False)
        (feat_drop): Dropout(p=0, inplace=False)
        (attn_drop): Dropout(p=0, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
      )
    )
    (head): Identity()
  )
)
[5]:
from maskgraphene_main import MG, MG_triplet

# Stage 1: train for `max_epoch` epochs; results are returned under key "MG".
model, ad_concat_1 = MG(
    model, graph, x, optimizer, max_epoch, device, ad_concat, scheduler,
    logger=None,
    key_="MG",
)

# Stage 2: triplet training for `max_epoch_triplet` epochs, starting from the
# stage-1 model and AnnData object; results are returned under key "MG_triplet".
# NOTE(review): the same `optimizer` and `scheduler` objects are reused here,
# and the scheduler's cosine factor is defined over `max_epoch`. Confirm this
# is the intended learning-rate behaviour for the second stage.
model, ad_concat_2 = MG_triplet(
    model, graph, x, optimizer, max_epoch_triplet, device,
    adata_concat_=ad_concat_1,
    scheduler=scheduler,
    logger=None,
    key_="MG_triplet",
)
CRITICAL:root:start training..
100%|██████████| 3000/3000 [03:46<00:00, 13.22it/s]
CRITICAL:root:start training..
# Epoch 499: train_loss: 0.1089: 100%|██████████| 500/500 [03:55<00:00,  2.13it/s]
[6]:
import pandas as pd

# Output directory for this dataset / section combination. The directory name
# is the dataset name directly followed by the '_'-joined section ids, exactly
# as in the original cell. It is computed once and created if missing.
exp_fig_dir = os.path.join("./temp", dataset_name + '_'.join(section_ids))
os.makedirs(exp_fig_dir, exist_ok=True)

# Store the spatial coordinates as a DataFrame with named columns, indexed
# like `obs`, before handing the object to the visualisation helper.
ad_concat_2.obsm['spatial'] = pd.DataFrame(
    ad_concat_2.obsm['spatial'],
    columns=['X', 'Y'],
    index=ad_concat_2.obs.index
)

# The helper is given the "MG_triplet" key and the output directory. Its
# return value is indexed by section position below.
ari_ = visualization_umap_spatial(
    ad_temp=ad_concat_2,
    section_ids=section_ids,
    exp_fig_dir=exp_fig_dir,
    dataset_name='merfishMB',
    num_iter="0",
    identifier="stage2",
    num_class=args.num_class,
    use_key="MG_triplet",
)

# Report one ARI per section. Iterating over `section_ids` instead of a
# hard-coded range(10) keeps this working when the section list is changed.
for i, section_id in enumerate(section_ids):
    print(section_id, ', ARI = %01.3f' % ari_[i])
WARNING:rpy2.rinterface_lib.callbacks:R[write to console]:                    __           __
   ____ ___  _____/ /_  _______/ /_
  / __ `__ \/ ___/ / / / / ___/ __/
 / / / / / / /__/ / /_/ (__  ) /_
/_/ /_/ /_/\___/_/\__,_/____/\__/   version 6.0.0
Type 'citation("mclust")' for citing this R package in publications.

fitting ...
  |======================================================================| 100%
0 , ARI = 0.540
1 , ARI = 0.556
2 , ARI = 0.415
3 , ARI = 0.520
4 , ARI = 0.548
5 , ARI = 0.599
6 , ARI = 0.577
7 , ARI = 0.643
8 , ARI = 0.633
9 , ARI = 0.599
_images/Tutorial_4_MaskGraphene_on_MB_6_2.png