Initialize and pre-train a smaller model #10

Open
win10ogod opened this issue Jul 10, 2024 · 2 comments

@win10ogod

Initialize and pre-train a smaller model
Please try initializing and pre-training a smaller model.

@EthanC111
Collaborator

Thank you for your interest! Releasing a quantized model is on our TODO list!
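In the meantime, for anyone who wants to experiment locally, a rough sketch of a 4-bit load of the HF-format checkpoint with bitsandbytes is below; the quantization settings and the model path are assumptions, not an official recipe.

import torch
from transformers import BitsAndBytesConfig, ChameleonForCausalLM

# Assumed 4-bit settings via bitsandbytes; not an official Anole recipe
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = ChameleonForCausalLM.from_pretrained(
    "path/to/anole-hf",  # hypothetical path to the HF-format checkpoint
    quantization_config=bnb_config,
    device_map="auto",
)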

@win10ogod
Author

@EthanC111 Could you please take a look at this approach? It creates a smaller model configuration, transfers compatible weights from the full model, and pre-trains it:

import copy

import torch
import deepspeed
import jsonlines

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from transformers import ChameleonForCausalLM, Trainer, TrainingArguments, AutoConfig

from constants_training import (
    ANOLE_PATH_HF,
    ANOLE_PATH_HF_TRAINED,
    DATASET_TOKENIZED_PATH
)

# Define the dataset class
class TokenizedDataset(Dataset):
    def __init__(self, filepath):
        self.tokenized_data = []
        with jsonlines.open(filepath) as reader:
            for obj in reader:
                self.tokenized_data.append(torch.tensor(obj['text_tokens'] + obj['image_tokens'], dtype=torch.long))
    
    def __len__(self):
        return len(self.tokenized_data)
    
    def __getitem__(self, idx):
        return self.tokenized_data[idx]

# Define custom collate function for DataLoader
def collate_fn(batch):
    # Pad input_ids with a valid token id (0 here; swap in the tokenizer's pad id if different),
    # since -100 is not a valid index for the embedding layer
    batch_inputs_padded = pad_sequence(batch, batch_first=True, padding_value=0)

    # Attention mask: 1 for real tokens, 0 for padding
    lengths = torch.tensor([len(seq) for seq in batch])
    attention_masks = (torch.arange(batch_inputs_padded.size(1))[None, :] < lengths[:, None]).long()

    # Labels: copy of the inputs with padded positions set to -100 so the loss ignores them
    labels = batch_inputs_padded.clone()
    labels[attention_masks == 0] = -100

    return {'input_ids': batch_inputs_padded, 'attention_mask': attention_masks, 'labels': labels}

# Function to create a smaller model configuration
def create_smaller_config(original_config, original_num_params, target_params=500_000_000):
    smaller_config = copy.deepcopy(original_config)

    # Parameter count scales roughly with num_layers * hidden_size * intermediate_size,
    # so shrink each dimension by the cube root of the target ratio
    # (a simplified approach that may need fine-tuning)
    scale_factor = (target_params / original_num_params) ** (1 / 3)

    # Keep hidden_size divisible by the number of attention heads
    num_heads = smaller_config.num_attention_heads
    smaller_config.hidden_size = max(num_heads, int(smaller_config.hidden_size * scale_factor) // num_heads * num_heads)
    smaller_config.intermediate_size = int(smaller_config.intermediate_size * scale_factor)
    smaller_config.num_hidden_layers = max(1, int(smaller_config.num_hidden_layers * scale_factor))

    return smaller_config

# Load the original model and its configuration
original_config = AutoConfig.from_pretrained(ANOLE_PATH_HF)
large_model = ChameleonForCausalLM.from_pretrained(ANOLE_PATH_HF)

# Create a smaller model configuration sized off the original parameter count
smaller_config = create_smaller_config(original_config, large_model.num_parameters())

# Initialize the smaller model
model = ChameleonForCausalLM(config=smaller_config)

# Transfer weights where possible: copy only the parameters whose names and shapes match,
# since load_state_dict raises on shape mismatches even with strict=False
small_state_dict = model.state_dict()
compatible_weights = {
    name: weight for name, weight in large_model.state_dict().items()
    if name in small_state_dict and weight.shape == small_state_dict[name].shape
}
model.load_state_dict(compatible_weights, strict=False)

print(model)

# Initialize the dataset
dataset = TokenizedDataset(DATASET_TOKENIZED_PATH)

# Define training arguments
training_args = TrainingArguments(
    output_dir=ANOLE_PATH_HF_TRAINED,
    learning_rate=1e-3,
    num_train_epochs=10,
    per_device_train_batch_size=1,
    save_steps=3000,
    fp16=False,
    logging_strategy="steps",
    logging_steps=1,  # Log every step
    deepspeed="ds_config.json"
)

# Initialize the Trainer with custom collate_fn
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collate_fn
)

# Train the model
trainer.train()

# Save the model
torch.save(model.state_dict(), ANOLE_PATH_HF_TRAINED / 'pytorch_model.bin')
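
For completeness, the training arguments reference a ds_config.json that isn't shown here. A minimal sketch of what it could contain is below, written as a Python dict (TrainingArguments also accepts a dict via deepspeed=ds_config); the ZeRO stage-2 settings and "auto" values are assumptions, not the repo's actual config.

# Hypothetical DeepSpeed config; the actual ds_config.json may differ
ds_config = {
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": True,
        "contiguous_gradients": True
    },
    "bf16": {"enabled": "auto"}
}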
