
Commit

update inference
weihaox committed Apr 11, 2024
1 parent 08ee563 commit d5b9242
Showing 11 changed files with 1,144 additions and 1 deletion.
73 changes: 73 additions & 0 deletions README.md
@@ -43,6 +43,8 @@
<p>UMBRAE decodes multimodal explanations from brain signals. (1) We introduce a <b>universal brain encoder</b> for multimodal-brain alignment and recover conceptual and spatial details by using multimodal large language models. (2) We introduce <b>cross-subject training</b> to overcome unique brain patterns of different individuals. This allows brain signals from multiple subjects to be trained within the same model. (3) Our method supports <b>weakly-supervised subject adaptation</b>, enabling the training of a model for a new subject in a data-efficient manner. (4) For evaluation, we introduce <b>BrainHub</b>, a brain understanding benchmark, based on NSD and COCO.

## News :triangular_flag_on_post:
- [2024/04/12] Inference and pretrained model available. Training code is coming soon.
- [2024/04/11] <a href="https://github.com/weihaox/BrainHub">BrainHub</a> is available.
- [2024/03/15] Both <a href="https://weihaox.github.io/UMBRAE">project</a> and <a href="https://arxiv.org/pdf/2404.07202">arXiv</a> are available.

## Method
@@ -53,6 +55,77 @@ Overview of UMBRAE. Our brain encoder includes subject-specific tokenizers and a
<img src="docs/images/overview.png" width="90%"/>
</tr></div>
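
As a rough mental model of the cross-subject design, the sketch below routes each subject's voxels through its own tokenizer and then through one shared encoder. It is only an illustrative sketch: the class names (`SubjectTokenizer`, `CrossSubjectEncoder`), the transformer depth, and the projection head are assumptions, with the token count and dimensions simply mirroring the defaults used in `inference.py`; the actual architecture is the `BrainX` model in `umbrae/model.py`.

```python
import torch
import torch.nn as nn


class SubjectTokenizer(nn.Module):
    """Illustrative subject-specific tokenizer: maps one subject's flattened
    fMRI voxels to a fixed number of tokens in a shared embedding space."""
    def __init__(self, num_voxels: int, hidden_dim: int = 1024, num_tokens: int = 256):
        super().__init__()
        self.num_tokens, self.hidden_dim = num_tokens, hidden_dim
        self.proj = nn.Linear(num_voxels, num_tokens * hidden_dim)

    def forward(self, voxels: torch.Tensor) -> torch.Tensor:
        # (batch, num_voxels) -> (batch, num_tokens, hidden_dim)
        return self.proj(voxels).view(-1, self.num_tokens, self.hidden_dim)


class CrossSubjectEncoder(nn.Module):
    """Illustrative shared backbone: one tokenizer per subject feeding a
    common transformer and a head that maps into the MLLM feature space."""
    def __init__(self, voxels_per_subj: dict, hidden_dim: int = 1024, out_dim: int = 1024):
        super().__init__()
        self.tokenizers = nn.ModuleDict(
            {f'fmri{s}': SubjectTokenizer(n, hidden_dim) for s, n in voxels_per_subj.items()})
        layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=8, batch_first=True)
        self.backbone = nn.TransformerEncoder(layer, num_layers=4)
        self.head = nn.Linear(hidden_dim, out_dim)

    def forward(self, voxels: torch.Tensor, modal: str) -> torch.Tensor:
        tokens = self.tokenizers[modal](voxels)   # subject-specific routing
        return self.head(self.backbone(tokens))   # weights shared across subjects


# brain signals from different subjects pass through the same shared encoder
encoder = CrossSubjectEncoder({1: 15724, 2: 14278})
features = encoder(torch.randn(2, 15724), modal='fmri1')  # -> (2, 256, 1024)
```

Because the backbone weights are shared, signals from multiple subjects can be trained within the same model, while adapting to a new subject only requires learning a new tokenizer.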


## Installation

### Environment

```bash
conda create -n brainx python=3.10
conda activate brainx
pip install -r requirements.txt
```

### Download Data and Checkpoints

The training and inference scripts can automatically download the dataset if the designated path is empty, but this process can be quite slow, so you may prefer to download all data in advance with the scripts below. Before downloading, please fill out the NSD [Data Access form](https://forms.gle/xue2bCdM9LaFNMeb7) and agree to the [Terms and Conditions](https://cvnlab.slite.page/p/IB6BSeW_7o/Terms-and-Conditions).

Checkpoints are hosted on [Hugging Face](https://huggingface.co/datasets/weihaox/brainx).

```bash
bash download_data.sh
bash download_checkpoint.sh
```

## Inference

Our method inherits the multimodal understanding capabilities of MLLMs, which allows switching between different tasks simply by changing the prompt. You can either use the prompts listed in our paper or create customised instructions according to your needs.

```bash
prompt_caption='Describe this image <image> as simply as possible.'
prompt_ground='Please interpret this image and give coordinates [x1,y1,x2,y2] for each object you mention.'

python inference.py --fmri_encoder 'brainx' --subj 1 --prompt "$prompt_ground" \
    --data_path 'nsd_data' --brainx_path 'train_logs/brainx.pth' \
    --save_path 'evaluation/eval_caption/caption_results/umbrae/sub01_dim1024'
```

Given that identified classes might be named differently, or might simply be absent from the ground-truth labels, we evaluate bounding boxes through REC (referring expression comprehension). We use the prompt `"Locate <expr> in <image> and provide its coordinates, please"`, but others such as `"Can you point out <expr> in the image and provide the bounding boxes of its location?"` should also work.

```bash
python inference_rec.py --data_path 'nsd_data' --fmri_encoder 'brainx' \
    --subj 1 --brainx_path 'train_logs/brainx.pth' \
    --save_path 'evaluation/eval_bbox_rec/rec_results/umbrae/sub01_dim1024'
```

## Evaluation

The benchmark, including ground-truth data, evaluation scripts, and baseline results, is hosted in [BrainHub](https://github.com/weihaox/BrainHub).

1. Download `brainhub` to the root path: `git clone https://github.com/weihaox/BrainHub`

2. Process ground-truth test images: `python processing/decode_images.py`

3. Run evaluation for brain captioning and grounding:

```bash
cd BrainHub
for sub in 1 2 5 7
do
    python eval_caption.py ../umbrae/evaluation/caption_results/umbrae/sub0${sub}_dim1024/fmricap.json \
        caption/images --references_json caption/fmri_cococap.json
done
```

```bash
for sub in 1 2 5 7
do
    python eval_bbox_rec.py --path_out "../umbrae/evaluation/bbox_results/umbrae/sub0${sub}_dim1024"
done
```
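
For reference, `inference.py` writes `fmricap.json` as a flat JSON object mapping the stringified test-image index to the predicted caption. The snippet below is a small sanity check you can run before the benchmark; the path is illustrative and should be adjusted to the `--save_path` used in your inference run.

```python
import json

# illustrative path: point this at the fmricap.json produced by your inference run
result_file = 'umbrae/evaluation/caption_results/umbrae/sub01_dim1024/fmricap.json'

with open(result_file) as f:
    captions = json.load(f)  # {"0": "a caption ...", "1": "a caption ...", ...}

print(len(captions), 'captions')
for idx in list(captions)[:3]:
    print(idx, captions[idx])
```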

We also provide baseline results associated with BrainHub, including the captioning results from [SDRecon](https://github.com/yu-takagi/StableDiffusionReconstruction), [BrainCap](https://arxiv.org/abs/2305.11560), and [OneLLM](https://onellm.csuhan.com/), as well as the captioning and grounding results from [UMBRAE](https://weihaox.github.io/UMBRAE/).

## Citation

```bibtex
2 changes: 1 addition & 1 deletion docs/index.html
@@ -29,7 +29,7 @@ <h3 style="text-align:center"><em>Multimodal Brain Decoding from our UMBRAE appl
<p style="text-align: center;">
<a href="https://arxiv.org/pdf/2404.07202" target="_blank">[Paper]</a> &nbsp;&nbsp;&nbsp;&nbsp;
<a href="https://github.com/weihaox/UMBRAE" target="_blank">[Code]</a> &nbsp;&nbsp;&nbsp;&nbsp;
<a href="https://github.com/weihaox/UMBRAE" target="_blank">[Data]</a>
<a href="https://github.com/weihaox/BrainHub" target="_blank">[BrainHub]</a>
</p>
</font>
</div>
10 changes: 10 additions & 0 deletions umbrae/download_checkpoint.sh
@@ -0,0 +1,10 @@
#!/bin/bash
# ------------------------------------------------------------------
# @File : download_checkpoint.sh
# @Time : 2024/03/16 17:30:00
# @Author : Weihao Xia ([email protected])
# @Version : 1.0
# @Desc : download Checkpoints from Hugging Face
# ------------------------------------------------------------------

python -c 'from huggingface_hub import snapshot_download; snapshot_download(repo_id="weihaox/brainx", repo_type="dataset", local_dir="./", ignore_patterns=["all_images.pt", ".gitattributes"])'
47 changes: 47 additions & 0 deletions umbrae/download_data.sh
@@ -0,0 +1,47 @@
#!/bin/bash
# ------------------------------------------------------------------
# @File : download_data.sh
# @Time : 2024/02/13 22:00:00
# @Author : Weihao Xia ([email protected])
# @Version : 1.0
# @Desc : download the Natural Scenes Dataset from Hugging Face
# ------------------------------------------------------------------

# set the destination directories
destination="nsd"

subdirs=("train" "test" "val")

for subdir in "${subdirs[@]}"; do
    mkdir -p "${destination}/webdataset_avg_split/${subdir}/"
done

# destinations used by the wget calls below
train_destination="${destination}/webdataset_avg_split/train/"
val_destination="${destination}/webdataset_avg_split/val/"
test_destination="${destination}/webdataset_avg_split/test/"

declare -a i_values=(1 2 5 7)

# Download the train set
for i in "${i_values[@]}"; do
    for j in {0..17}; do
        url="https://huggingface.co/datasets/pscotti/naturalscenesdataset/resolve/main/webdataset_avg_split/train/train_subj0${i}_${j}.tar"
        wget -P "$train_destination" "$url"
    done
done

# Download the validation set
for i in "${i_values[@]}"; do
    url="https://huggingface.co/datasets/pscotti/naturalscenesdataset/resolve/main/webdataset_avg_split/val/val_subj0${i}_0.tar"
    wget -P "$val_destination" "$url"
done

# Download the test set
for i in "${i_values[@]}"; do
    for j in {0..1}; do
        url="https://huggingface.co/datasets/pscotti/naturalscenesdataset/resolve/main/webdataset_avg_split/test/test_subj0${i}_${j}.tar"
        wget -P "$test_destination" "$url"
    done
done


# download test set images (just for evaluation)
wget -P "../brainhub/caption" "https://huggingface.co/datasets/weihaox/brainx/resolve/main/all_images.pt"
210 changes: 210 additions & 0 deletions umbrae/inference.py
@@ -0,0 +1,210 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File    : inference.py
@Time : 2024/02/15 16:54:30
@Author : Weihao Xia
@Version : 1.0
@Desc :
python inference.py --fmri_encoder 'brainx' --subj 1 \
--data_path '/home/wx258/project/nsd_data' \
--brainx_path 'train_logs/brainx.pth' \
--prompt 'Describe this image <image> as simply as possible.' \
--save_path 'evaluation/eval_caption/caption_results/brainx_sub01_dim1024'
'''

import os
import json
import time
import argparse
import braceexpand
import webdataset as wds
import utils

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from torchvision.transforms import ToPILImage

from model import BrainX
from utils import postprocess

parser = argparse.ArgumentParser()
parser.add_argument('--shikra_path', default='model_weights/shikra-7b')
parser.add_argument('--brainx_path', default='train_logs/training_demo/best.pth')
parser.add_argument('--adapter_path', default='model_weights/mm_projector.bin')
parser.add_argument('--fmri_encoder', type=str, default='brainx', help='type of brainnet', choices=['brainx'])
parser.add_argument('--use_norm', type=bool, default=False, help='whether to use norm layer in the model')
parser.add_argument('--use_token', type=bool, default=False, help='whether to use learnable token in the model')
parser.add_argument('--feat_dim', type=int, help='output dimension of the fmri encoder', default=1024, choices=[1024, 4096])
parser.add_argument('--data_path', type=str, default="nsd_data", help='path to nsd data')
parser.add_argument('--save_path', type=str, default='results', help='path to save results')
parser.add_argument('--save_image', type=bool, default=False, help='save image or not')
parser.add_argument('--prompt', required=True, help='prompt for the model')
parser.add_argument('--subj', type=int, default=1, choices=[1, 2, 5, 7])
parser.add_argument('--seed', type=int, default=42)
args = parser.parse_args()

# create global variables without the args prefix
for attribute_name in vars(args).keys():
    globals()[attribute_name] = getattr(args, attribute_name)

# need non-deterministic CuDNN for conv3D to work
utils.seed_everything(seed, cudnn_deterministic=False)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# prepare models and data loaders
print('prepare NSD webdataset data...')
val_url = f"{data_path}/webdataset_avg_split/test/test_subj0{subj}_" + "{0..1}.tar"
meta_url = f"{data_path}/webdataset_avg_split/metadata_subj0{subj}.json"
num_val = 982

# result_dir = os.path.join(os.path.dirname(__file__), 'results/sub{:02d}_dim{}'.format(subj, feat_dim))
result_dir = os.path.join(save_path, 'sub{:02d}_dim{}'.format(subj, feat_dim))
os.makedirs(result_dir, exist_ok=True)

# save config in a json file
args_dict = vars(args)
with open(os.path.join(result_dir, 'config.json'), 'w') as file:
    json.dump(args_dict, file, indent=4)

print('prepare train and validation dataloaders...')
to_tuple = ["voxels", "images"]
val_batch_size = 1
split_by_node = lambda urls: urls
val_url = list(braceexpand.braceexpand(val_url))
val_data = wds.WebDataset(val_url, resampled=False, cache_dir=data_path, nodesplitter=split_by_node) \
    .decode("torch") \
    .rename(images="jpg;png", voxels='nsdgeneral.npy', trial="trial.npy", coco="coco73k.npy", reps="num_uniques.npy") \
    .to_tuple(*to_tuple) \
    .batched(val_batch_size, partial=False)

val_dl = torch.utils.data.DataLoader(val_data, batch_size=None, num_workers=1, shuffle=False)

voxels_per_subj = {1: 15724, 2: 14278, 3: 15226, 4: 13153, 5: 13039, 6: 17907, 7: 12682, 8: 14386}
num_voxels = voxels_per_subj.get(subj)

kwargs = {'hidden_dim': 1024, 'out_dim': feat_dim, 'num_latents': 256, 'use_norm': use_norm, 'use_token': use_token}

if fmri_encoder == 'brainx':
    voxel2emb = BrainX(**kwargs)
else:
    raise ValueError("The fmri encoder is not implemented.")
voxel2emb.to(device)

checkpoint = torch.load(brainx_path, map_location='cpu')
voxel2emb.load_state_dict(checkpoint['model_state_dict'], strict=False)
voxel2emb.eval()

# inference: predict image features from fmri
print('inference: predict image features from fmri...')
emb_voxel_list, image_list = [], []
for val_i, (voxel, image) in enumerate(val_dl):
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            # repeat_index = val_i % 3
            # voxel = voxel[:,repeat_index].float()
            voxel = torch.mean(voxel, axis=1).float()

            emb_voxel = voxel2emb(voxel.to(device), modal=f'fmri{subj}')

            emb_voxel_list.append(emb_voxel)
            if save_image:
                image_list.append(image)  # for visualization

# assign image features to the predicted features from fmri
image_features = torch.cat(emb_voxel_list, dim=0)
print(f"image_features.shape: {image_features.shape}")

if save_image:
    image_list = torch.cat(image_list, dim=0)

# load llama with the fine-tuned shikra model
finetuned_llama = shikra_path # 'model_weights/shikra-7b' # shikra
tokenizer = LlamaTokenizer.from_pretrained(finetuned_llama, padding_side='left')
model = LlamaForCausalLM.from_pretrained(finetuned_llama)
model.to(device)

if feat_dim == 1024:
    # load mm_projector to map 1024-dim features into the 4096-dim LLM embedding space
    mm_projector = torch.nn.Linear(1024, 4096)
    mm_projector_weights = torch.load(adapter_path, map_location='cpu')
    if adapter_path == 'model_weights/mm_projector.bin':
        adjusted_state_dict = {k.split('.')[-1]: v for k, v in mm_projector_weights.items()}
        mm_projector.load_state_dict(adjusted_state_dict)
    else:
        mm_projector.load_state_dict(mm_projector_weights['model_state_dict'], strict=False)

    mm_projector.to(device)
    image_features = mm_projector(image_features.to(torch.float32))
    print(f"image_features.shape: {image_features.shape}")

# process prompt
system = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER:"
user_image = " <im_start>" + "<im_patch>" * 256 + "<im_end> "

if '<image>' in prompt:
    user_prompt = prompt.replace('<image>', user_image)
else:
    user_prompt = prompt + user_image
input_text = system + user_prompt + " ASSISTANT:"

input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
inputs_embeds = model.model.embed_tokens(input_ids)

gen_kwargs = dict(
    use_cache=True,
    do_sample=False,
    pad_token_id=2,  # tokenizer.pad_token_id
    bos_token_id=1,  # tokenizer.bos_token_id
    eos_token_id=2,  # tokenizer.eos_token_id
    max_new_tokens=512,
)

cap_result = {}
for cur_image_idx in range(image_features.shape[0]):
    # splice the predicted brain features into the prompt embeddings at the <im_patch> positions
    new_input_embeds = []
    for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
        cur_image_features = image_features[cur_image_idx]
        num_patches = cur_image_features.shape[0]
        image_start_tokens = torch.where(cur_input_ids == 32001)[0]
        for image_start_token_pos in image_start_tokens:
            cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device)
            num_patches = cur_image_features.shape[0]
            if cur_input_ids[image_start_token_pos + num_patches + 1] != 32002:
                raise ValueError("The image end token should follow the image start token.")

            cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos + 1], cur_image_features,
                                              cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
        new_input_embeds.append(cur_new_input_embeds)
    inputs_embeds = torch.stack(new_input_embeds, dim=0)

    st_time = time.time()
    with torch.inference_mode():
        with torch.autocast(dtype=torch.float16, device_type='cuda'):
            output_ids = model.generate(inputs_embeds=inputs_embeds.float(), **gen_kwargs)
    print(f"done generated in {time.time() - st_time} seconds")

    response = tokenizer.batch_decode(output_ids)[0]

    # print(f"input: {input_text}")
    print(f"response: {response.strip(' <s></s>')}")

    # save response in a txt file
    with open(os.path.join(result_dir, 'response.txt'), 'a') as f:
        f.write(f'response_{cur_image_idx}: \n')
        f.write(response + '\n')

    # save response in a json file
    bbox, caption = utils.extract_id_bbox_caption(response)
    cap_result[str(cur_image_idx)] = caption  # {image_name: caption; ...}

    # save processed image (only for bbox tasks)
    if save_image:
        _, processed_image = postprocess(response, image=ToPILImage()(image_list[cur_image_idx]), width=5)
        if processed_image is not None:
            output_path = os.path.join(result_dir, f'{cur_image_idx}_prompt.png')
            processed_image.save(output_path)

with open(os.path.join(result_dir, 'fmricap.json'), 'w') as f:
    json.dump(cap_result, f)
