diff --git a/README.md b/README.md index 8515f39..b59fd35 100644 --- a/README.md +++ b/README.md @@ -1 +1,212 @@ -# memkk \ No newline at end of file +
+📃 Paper +• + Data +• + Perturbed Data +• +Project Page +
+ + +This repository provides the PyTorch implementation of the paper "On Memorization of Large Language Models in Logical Reasoning". + +Introduction: In this work, we investigate memorization of LLMs in reasoning tasks. +- We propose a memorizatioin metric for reasoning tasks and a dynamically generated logical reasoning benchmark based on Knights and Knaves (K&K) puzzles. +- LLMs could achieve high training accuracy after fine-tuning, yet fail when those puzzles are slightly perturbed, suggesting that the models heavily rely on memorization to solve those training puzzles. +- On the other hand, fine-tuning also consistently improves generalization performance. In-depth analyses with perturbation tests, cross difficulty-level transferability, probing model internals, and fine-tuning with wrong answers suggest that the LLMs learn to reason on K&K puzzles despite training data memorization. +- Finally, we use puzzle-based indicators and model-based indicators to classify puzzles solved by reasoning v.s. memorization. + ++ +
+ + +## Updates +* `10/31/2024`: Code, data, ArXiv article and project page are available. + + + +## 🛠️ Installation + +```bash +conda env create -f environment.yml +conda activate kk +``` + +## 📝 Synthetic data + ++ +
+ +### Option 1: use HF dataset + +When using our code for evaluation / fine-tuning, we import the datasets from huggingface: + +```python +import datasets +datasets.load_dataset('K-and-K/knights-and-knaves') +datasets.load_dataset('K-and-K/perturbed-knights-and-knaves') +``` + +### Option 2: generate data locally +To generate K&K data for {2,3,4,5,6,7,8}-people puzzles with a train/test split, run: + +```bash +python data_prep/data_gen_kk.py +``` + +Locally perturbed data will also be generated. The generated data will be stored in the `data` directory. + +In addition, you can use it to generate wrong answer data and wrong CoT data (including one wrong step and shuffuled CoT steps). + + + +## 🤖 Evaluation + +Some general evaluation parameters: +| Argument | Example | Description | +|-------------------------|--------------------------------|---------------------------------------------------------------------------------------------------------| +| `--max_token` | `2048` | Maximum number of tokens. | +| `--split` | `train`, `test` | Choose the data split for evaluation. | +| `--limit` | `100` | Limit the number of evaluation samples. | +| `--ntrain` | `0`, `1` | Number of demonstrations for 0-shot/few-shot prompting. | +| `--problem_type` | `clean`, `perturbed_statement`, `perturbed_leaf`, `random_pair`, `reorder_statement`, `uncommon_name`, `flip_role` | Type of problem, supporting various perturbations. | +| `--eval_nppl` | `2`,`3`,`4`,`5`,`6`,`7`,`8` | Number of people in K&K puzzles. If not set, it will evaluate all n-ppl tasks. | +| `--vllm` | `true` | Enable VLLM for faster inference for open-source models. | +| `--model` | `openai/gpt-4o-mini-2024-07-18` | The model to be evaluated. We support open-source and closed-sourced models. | + + +### Evaluation on test samples + +For each K&K task, evaluate all test samples (100 samples). + +Evaluate on test samples under 1/0-shot & with/without CoT by running: + +```bash +bash scripts/eval/run_test.sh +``` + +Evaluate under 0-shot & without CoT on 2 math-level perturbation types (`perturbed_statement`, `perturbed_leaf`): + +```bash +bash scripts/eval/eval_test_pertub.sh +``` + +### Evaluation on training samples +After fine-tuning the models following `## 4. Fine-Tuning`, we evaluate on training samples. +We evaluate the first 100 samples for the fine-tuned GPT-4o-mini, and all samples for open-source models. + +Evaluate under 0-shot & without CoT + +```bash +bash scripts/eval/eval_train.sh +``` + +Evaluation on Perturbed Training Samples: + +Evaluate under 0-shot & without CoT on 6 perturbation types (`perturbed_statement`, `perturbed_leaf` `random_pair`, `reorder_statement`, `uncommon_name`, `flip_role`): + +```bash +bash scripts/eval/eval_train_pertub.sh +``` + +#### Evaluation on closed-sourced models + +Provide API keys: +```bash +export OPENAI_API_KEY='your-api-key-here' +export ANTHROPIC_API_KEY='your-api-key-here' +``` + +Example usages for OpenAI/Anthropic models with direct prompting: +```bash +bash scripts/eval/gpt4omini_direct.sh +bash scripts/eval/claude-sonet.sh +``` + +Evaluate with cot prompting: +```bash +bash scripts/eval/gpt4omini_cot.sh +``` + +## 🚗 Fine-tuning + + +### Direct fine-tune + +To fine-tune the model directly on answers (without CoT), run: + +```bash +bash scripts/ft/ft_lm3.sh +``` + +### CoT fine-tune + +To fine-tune the model with CoT, run: + +```bash +bash scripts/ft/ft_lm3_cot.sh + +``` +You can change the saved model path `output_dir` in the above scripts. + +#### Merge fine-tuned adapter and base model + +Load the saved adapter from fine-tuning, as well as the base model, then save the merged model by running: + +```bash +bash scripts/ft/merge_adapter.sh +``` + +Make sure to change the model paths `base_model_path`, `adapter_path`, `base_model_path` in the script as needed. + +#### Fine-tune closed-sourced models +For closed-sourced models, we follow the [OpenAI finetuning API](https://platform.openai.com/docs/guides/fine-tuning) to finetune GPT-4o-mini. + + +## 🔍 Probe + +To probe the model's internal representations, update the model paths and the number of ppl in the puzzles for evaluation in the script: +```bash +bash scripts/probe/run.sh +``` + + +## 🗃️ Sample classification +Here we classify on consistenly solved v.s. non consistenly solved puzzles. + +Update the model paths and provide data with binary label of consistenly solved v.s. non consistenly solved for each training sample, and then run the following: + +Classification with puzzled-based indicators: + +```bash +bash scripts/mem_classify/model_indicator.sh +``` + +Classification with model-bases indicators: + +```bash +bash scripts/mem_classify/puzzle_indicator.sh +``` + +## 📚 Citation +If you find our work helpful, please consider citing it as follows: +```bibtex +@article{xie2024memorization, +title={On Memorization of Large Language Models in Logical Reasoning}, +author={Chulin Xie and Yangsibo Huang and Chiyuan Zhang and Da Yu and Xinyun Chen and Bill Yuchen Lin and Bo Li and Badih Ghazi and Ravi Kumar}, +year={2024}, +eprint={2410.23123}, +archivePrefix={arXiv}, +primaryClass={cs.CL}, +url={https://arxiv.org/abs/2410.23123}, +} +``` + + +## 📖 Questions +Please reach out to us if you have any suggestions or need any help in reproducing the results. You can submit an issue or pull request, or send an email to chulinx2@illinois.edu. diff --git a/data_prep/data_gen_kk.py b/data_prep/data_gen_kk.py new file mode 100644 index 0000000..58d2720 --- /dev/null +++ b/data_prep/data_gen_kk.py @@ -0,0 +1,417 @@ +import copy +import os +import sys +import importlib +import pprint + +module_path = os.path.abspath('.') +if not module_path in sys.path: + sys.path.append(module_path) +import lib_kk +importlib.reload(lib_kk) +import numpy as np +import json +import os +from utils import load_jsonl,write_jsonl, init_seed + +init_seed(42) + +import time + + +def convert_int_to_str(data): + return str(data) + + +def combine_train_data(data_folder,file_config, output_name): + result_records=[] + for config in file_config: + file_path = os.path.join(data_folder, config[0]) + records = load_jsonl(file_path) + print(f"Loaded {len(records)} records from {file_path}") + if config[1] < len(records): + records = records[:config[1]] + result_records.extend(records) + output_file=os.path.join(data_folder, output_name) + write_jsonl(output_file, result_records) + + +def format_solution_text(ans): + gold = ans.replace(" and ", "").replace(".", "") + gold_conditions=gold.split(",") + reformat_gold_conditions=[] + for condition in gold_conditions: + # Remove leading and trailing spaces + gold_condition=condition.strip() + reformat_gold_conditions.append(gold_condition) + + formatted_statements = "\n".join([f"({i+1}) {reformat_gold_conditions[i]}" for i in range(len(reformat_gold_conditions))]) + return formatted_statements + + +def generate_problems(n_problems, n_people, gen_perturb=True): + problems = [] + problem_seed=1234 + start_time = time.time() + problem_sampler = lib_kk.KKProblemSampler(problem_seed, n_people=n_people) + problems = problem_sampler.sample_valid_problems(n_problems) + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Elapsed time: {elapsed_time} seconds") + print(f'{len(problems)} valid problems generated') + if gen_perturb: + start_time = time.time() + per_stat = problem_sampler.perturb_problems(problems, perturb_type='statement', num_perturb=1) + perturbed_problems_statement = [item for inner_list in per_stat for item in inner_list] + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Elapsed time: {elapsed_time} seconds") + print(f'{len(perturbed_problems_statement)} perturbed (statement) problems generated') + + start_time = time.time() + per_stat = problem_sampler.perturb_problems(problems, perturb_type='leaf', num_perturb=1) + perturbed_problems_leaf = [item for inner_list in per_stat for item in inner_list] + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Elapsed time: {elapsed_time} seconds") + print(f'{len(perturbed_problems_leaf)} perturbed (leaf) problems generated') + + return problems, perturbed_problems_statement, perturbed_problems_leaf + else: + return problems, None, None + +def generate_wrong_problems(n_problems, n_people): + problems = [] + problem_seed=1234 + start_time = time.time() + problem_sampler = lib_kk.KKProblemSampler(problem_seed, n_people=n_people) + problems = problem_sampler.sample_invalid_problems(n_problems) + end_time = time.time() + elapsed_time = end_time - start_time + print(f"Elapsed time: {elapsed_time} seconds") + print(f'{len(problems)} valid problems with wrong answers generated') + + return problems + + + +def generate_formatted_problem(problems, item_start_idx, num_samples, random_knight_knave_pairs, flip_knight_knave_pair,uncommon_name=False, reorder_statement=False): + data =[] + problem_seed=1234 + for i in range(item_start_idx, item_start_idx+ num_samples): + problem= problems[i] + if problem is None: + continue + + formatter_seed= problem_seed+i + formatter = lib_kk.KKProblemFormatter(formatter_seed, problem) + formatted_problem = formatter.format_problem(random_knight_knave_pairs=random_knight_knave_pairs, + flip_knight_knave_pair=flip_knight_knave_pair, + random_names=True, random_saying_template=True, + uncommon_name=uncommon_name, reorder_statement=reorder_statement) + + chain_of_thoughts = lib_kk.generate_chain_of_thoughts(problem['statements']) + header, steps, footer = lib_kk.format_chain_of_thoughts(problem, formatted_problem, chain_of_thoughts, + repeat_claim_for_assumption=False, repeat_claim_for_contradiction=False) + + repeat_header, repeat_steps, repeat_footer = lib_kk.format_chain_of_thoughts(problem, formatted_problem, chain_of_thoughts, + repeat_claim_for_assumption=True, repeat_claim_for_contradiction=True) + + item= copy.deepcopy(formatted_problem) + item["solution_text_format"]= format_solution_text(item["solution_text"]) + item["cot_head"]=header + item["cot_repeat_steps"]=repeat_steps + item["cot_foot"]=footer + item["statements"]=convert_int_to_str(problem["statements"]) # convert 0/1 into "0"/"1" for future json loading + item["index"] = i + + data.append(item) + return data + + +def generate_data(num_samples_test, num_samples_train, num_samples_val, n_people): + num_problems=num_samples_test+num_samples_train+num_samples_val + + clean_problems, perturbed_problems_statement, perturbed_problems_leaf = generate_problems(num_problems, n_people, gen_perturb=True) + problems_dict={ + "clean": clean_problems, + "perturbed_statement": perturbed_problems_statement, + "perturbed_leaf": perturbed_problems_leaf + } + + random_knight_knave_pairs=False + flip_knight_knave_pair=False + uncommon_name=False + for problem_type, problems in problems_dict.items(): + item_start_idx=0 + for (split, num_samples) in [("test", num_samples_test), ("train", num_samples_train), ("val", num_samples_val)]: + if num_samples==0: + continue + data= generate_formatted_problem(problems, item_start_idx, num_samples, random_knight_knave_pairs, flip_knight_knave_pair, uncommon_name) + + config=f"people{n_people}_num{num_samples}" + + if random_knight_knave_pairs: + config +="_random_pair" + if flip_knight_knave_pair: + config +="_flip_role" + if uncommon_name: + config +="_uncommon_name" + + output_folder=f"data/{split}/{problem_type}" + os.makedirs(output_folder, exist_ok=True) + output_file = os.path.join(output_folder, f'{config}.jsonl') + with open(output_file, 'w') as file: + for item in data: + json_line = json.dumps(item) + file.write(json_line + '\n') + print(f"Data has been written to {output_file}") + item_start_idx+=num_samples + +def generate_data_language_perturb(num_samples_test, num_samples_train, num_samples_val, n_people): + num_problems=num_samples_test+num_samples_train+num_samples_val + + clean_problems, _, _ = generate_problems(num_problems, n_people, gen_perturb=False) + problems_dict={ + "clean": clean_problems, + } + perturb_list=["random_pair", "flip_role", "uncommon_name", "reorder_statement"] + + for perturb_type in perturb_list: + random_knight_knave_pairs=False + flip_knight_knave_pair=False + uncommon_name=False + reorder_statement=False + if perturb_type=="random_pair": + random_knight_knave_pairs=True + elif perturb_type=="flip_role": + flip_knight_knave_pair=True + elif perturb_type=="uncommon_name": + uncommon_name=True + elif perturb_type=="reorder_statement": + reorder_statement=True + + item_start_idx=0 + for (split, num_samples) in [("test", num_samples_test), ("train", num_samples_train), ("val", num_samples_val)]: + if num_samples==0: + continue + data= generate_formatted_problem(clean_problems, item_start_idx, num_samples, random_knight_knave_pairs, flip_knight_knave_pair, uncommon_name,reorder_statement) + + config=f"people{n_people}_num{num_samples}" + + + output_folder=f"data/{split}/{perturb_type}" + os.makedirs(output_folder, exist_ok=True) + output_file = os.path.join(output_folder, f'{config}.jsonl') + with open(output_file, 'w') as file: + for item in data: + json_line = json.dumps(item) + file.write(json_line + '\n') + print(f"Data has been written to {output_file}") + item_start_idx+=num_samples + + +def generate_formatted_wrong_problem(problems, item_start_idx, num_samples, random_knight_knave_pairs, flip_knight_knave_pair,uncommon_name=False): + data =[] + problem_seed=1234 + for i in range(item_start_idx, item_start_idx+ num_samples): + problem= problems[i] + if problem is None: + continue + + formatter_seed= problem_seed+i + formatter = lib_kk.KKProblemFormatter(formatter_seed, problem) + formatted_problem = formatter.format_problem(random_knight_knave_pairs=random_knight_knave_pairs, + flip_knight_knave_pair=flip_knight_knave_pair, + random_names=True, random_saying_template=True, + uncommon_name=uncommon_name) + + item= copy.deepcopy(formatted_problem) + item["solution_text_format"]= format_solution_text(item["solution_text"]) + item["cot_head"]="placeholder" + item["cot_repeat_steps"]=["placeholder"] + item["cot_foot"]="placeholder" + item["statements"]=convert_int_to_str(problem["statements"]) # convert 0/1 into "0"/"1" for future json loading + item["index"] = i + + data.append(item) + return data + + +def generate_wrong_data(num_samples_test, num_samples_train, num_samples_val, n_people): + num_problems=num_samples_test+num_samples_train+num_samples_val + + clean_problems = generate_wrong_problems(num_problems, n_people) + problems_dict={ + "clean": clean_problems, + } + random_knight_knave_pairs=False + flip_knight_knave_pair=False + uncommon_name=False + for problem_type, problems in problems_dict.items(): + item_start_idx=0 + for (split, num_samples) in [("test", num_samples_test), ("train", num_samples_train), ("val", num_samples_val)]: + if num_samples==0: + continue + data= generate_formatted_wrong_problem(problems, item_start_idx, num_samples, random_knight_knave_pairs, flip_knight_knave_pair, uncommon_name) + + config=f"people{n_people}_num{num_samples}" + + if random_knight_knave_pairs: + config +="_random_pair" + if flip_knight_knave_pair: + config +="_flip_role" + if uncommon_name: + config +="_uncommon_name" + + output_folder=f"data/wrong/{split}/{problem_type}" + os.makedirs(output_folder, exist_ok=True) + output_file = os.path.join(output_folder, f'{config}.jsonl') + with open(output_file, 'w') as file: + for item in data: + json_line = json.dumps(item) + file.write(json_line + '\n') + print(f"Data has been written to {output_file}") + item_start_idx+=num_samples + + + + +def generate_formatted_wrong_cot(problems, item_start_idx, num_samples, random_knight_knave_pairs, flip_knight_knave_pair,uncommon_name=False, wrong_type="shuffle" ): + data =[] + problem_seed=1234 + for i in range(item_start_idx, item_start_idx+ num_samples): + problem= problems[i] + if problem is None: + continue + + formatter_seed= problem_seed+i + rng = np.random.default_rng(formatter_seed) + formatter = lib_kk.KKProblemFormatter(formatter_seed, problem) + formatted_problem = formatter.format_problem(random_knight_knave_pairs=random_knight_knave_pairs, + flip_knight_knave_pair=flip_knight_knave_pair, + random_names=True, random_saying_template=True, + uncommon_name=uncommon_name) + + chain_of_thoughts = lib_kk.generate_chain_of_thoughts(problem['statements']) + header, steps, footer = lib_kk.format_chain_of_thoughts(problem, formatted_problem, chain_of_thoughts, + repeat_claim_for_assumption=False, repeat_claim_for_contradiction=False) + + repeat_header, repeat_steps, repeat_footer = lib_kk.format_chain_of_thoughts(problem, formatted_problem, chain_of_thoughts, + repeat_claim_for_assumption=True, repeat_claim_for_contradiction=True) + + if wrong_type=="shuffle": + rng.shuffle(repeat_steps) + item= copy.deepcopy(formatted_problem) + item["solution_text_format"]= format_solution_text(item["solution_text"]) + item["cot_head"]=header + item["cot_repeat_steps"]=repeat_steps + item["cot_foot"]=footer + item["statements"]=convert_int_to_str(problem["statements"]) # convert 0/1 into "0"/"1" for future json loading + item["index"] = i + + data.append(item) + + if wrong_type=="replace_one_step": + rng = np.random.default_rng(problem_seed) + for j , item in enumerate(data): + wrong_step_idx=rng.integers(0, len(item["cot_repeat_steps"])) + original_step=item["cot_repeat_steps"][wrong_step_idx] + + possible_replacements = [i for i in range(len((data))) if j != i] + + while True: + replace_item_idx= rng.choice(possible_replacements) + replace_item = data[replace_item_idx] + replace_step_idx=rng.integers(0, len(replace_item["cot_repeat_steps"])) + replace_step = replace_item["cot_repeat_steps"][replace_step_idx] + for name_idx, name in enumerate(replace_item["names"]): + replace_step=replace_step.replace(name, item["names"][name_idx]) + if original_step!=replace_step: + item["cot_repeat_steps"][wrong_step_idx]=replace_step + break + + return data + + +def generate_wrong_cot_data(num_samples_test, num_samples_train, num_samples_val, n_people, wrong_type="shuffle"): + num_problems=num_samples_test+num_samples_train+num_samples_val + + clean_problems, _, _ = generate_problems(num_problems, n_people, gen_perturb=False) + problems_dict={ + "clean": clean_problems, + } + random_knight_knave_pairs=False + flip_knight_knave_pair=False + uncommon_name=False + for problem_type, problems in problems_dict.items(): + item_start_idx=0 + for (split, num_samples) in [("test", num_samples_test), ("train", num_samples_train), ("val", num_samples_val)]: + if num_samples==0: + continue + data= generate_formatted_wrong_cot(problems, item_start_idx, num_samples, random_knight_knave_pairs, flip_knight_knave_pair, uncommon_name, wrong_type) + + config=f"people{n_people}_num{num_samples}" + + if random_knight_knave_pairs: + config +="_random_pair" + if flip_knight_knave_pair: + config +="_flip_role" + if uncommon_name: + config +="_uncommon_name" + + output_folder=f"data/wrong_cot_{wrong_type}/{split}/{problem_type}" + os.makedirs(output_folder, exist_ok=True) + output_file = os.path.join(output_folder, f'{config}_wrong1.jsonl') + with open(output_file, 'w') as file: + for item in data: + json_line = json.dumps(item) + file.write(json_line + '\n') + print(f"Data has been written to {output_file}") + item_start_idx+=num_samples + +#### main & leaf/statement perturbed generation +for n_people in [2]: + generate_data(num_samples_test=100,num_samples_train=200,num_samples_val=0, + n_people=n_people) + +for n_people in [3, 4,5,6,7,8]: + generate_data(num_samples_test=100,num_samples_train=1000,num_samples_val=0, + n_people=n_people) + + +#### LANAGUGE perturbation +for n_people in [2]: + generate_data_language_perturb(num_samples_test=100,num_samples_train=200,num_samples_val=0, + n_people=n_people) + + +for n_people in [3, 4,5,6,7,8]: + generate_data_language_perturb(num_samples_test=100,num_samples_train=1000,num_samples_val=0, + n_people=n_people) + + +# #### wrong CoT generation +# wrong_type="replace_one_step" + +# for n_people in [5]: +# generate_wrong_cot_data(num_samples_test=100,num_samples_train=1000,num_samples_val=0, +# n_people=n_people,wrong_type=wrong_type) + +# wrong_type="shuffle" + +# for n_people in [5]: +# generate_wrong_cot_data(num_samples_test=100,num_samples_train=1000,num_samples_val=0, +# n_people=n_people,wrong_type=wrong_type) + + +# #### wrong answer generation +# for n_people in [2]: +# generate_wrong_data(num_samples_test=100,num_samples_train=200,num_samples_val=0, +# n_people=n_people) +# for n_people in [3, 4, 5,6,7,8]: +# generate_wrong_data(num_samples_test=100,num_samples_train=1000,num_samples_val=0, +# n_people=n_people) + + + diff --git a/data_prep/hf_dataset.ipynb b/data_prep/hf_dataset.ipynb new file mode 100644 index 0000000..e4dab25 --- /dev/null +++ b/data_prep/hf_dataset.ipynb @@ -0,0 +1,142 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['quiz', 'names', 'knight_knave', 'solution', 'solution_text', 'solution_text_format', 'cot_head', 'cot_repeat_steps', 'cot_foot', 'statements', 'index'],\n", + " num_rows: 100\n", + "})" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import datasets\n", + "\n", + "train_ds = datasets.load_dataset('K-and-K/knights-and-knaves','test',split=\"2ppl\")\n", + "train_ds" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating train split: 200 examples [00:00, 39583.84 examples/s]\n", + "Generating test split: 100 examples [00:00, 10649.23 examples/s]\n" + ] + } + ], + "source": [ + "kk_dataset = datasets.load_dataset('K-and-K/knights-and-knaves',data_files={\n", + " \"train\": [\"train/people2_num200.jsonl\"],\n", + " \"test\": [\"test/people2_num100.jsonl\"],\n", + " },)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DatasetDict({\n", + " train: Dataset({\n", + " features: ['quiz', 'names', 'knight_knave', 'solution', 'solution_text', 'solution_text_format', 'cot_head', 'cot_repeat_steps', 'cot_foot', 'statements', 'index'],\n", + " num_rows: 200\n", + " })\n", + " test: Dataset({\n", + " features: ['quiz', 'names', 'knight_knave', 'solution', 'solution_text', 'solution_text_format', 'cot_head', 'cot_repeat_steps', 'cot_foot', 'statements', 'index'],\n", + " num_rows: 100\n", + " })\n", + "})" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "kk_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading data: 100%|██████████| 146k/146k [00:00<00:00, 2.16MB/s]\n", + "Generating train split: 100 examples [00:00, 17525.92 examples/s]\n" + ] + } + ], + "source": [ + "train_ds = datasets.load_dataset('K-and-K/perturbed-knights-and-knaves',data_files=\"test/random_pair/people2_num100.jsonl\")" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['quiz', 'names', 'knight_knave', 'solution', 'solution_text', 'solution_text_format', 'cot_head', 'cot_repeat_steps', 'cot_foot', 'statements', 'index'],\n", + " num_rows: 5\n", + "})" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_ds['train'].select(range(5))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "kk", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/data_prep/knight_and_knave.ipynb b/data_prep/knight_and_knave.ipynb new file mode 100644 index 0000000..cfabb5a --- /dev/null +++ b/data_prep/knight_and_knave.ipynb @@ -0,0 +1,219 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# K&K Problem Generation Library\n", + "\n", + "**NOTE: This notebook is for demonstration purpose and the latest code is refactored into `lib_kk.py`**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "import importlib\n", + "import pprint\n", + "\n", + "module_path = os.path.abspath('.')\n", + "if not module_path in sys.path:\n", + " sys.path.append(module_path)\n", + "import lib_kk\n", + "importlib.reload(lib_kk)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'all_solutions': [(True, False, True, False, True)],\n", + " 'solution': (True, False, True, False, True),\n", + " 'statements': (('or', ('telling-truth', 3), ('telling-truth', 2)),\n", + " ('not', ('telling-truth', 2)),\n", + " ('->', ('lying', 0), ('telling-truth', 3)),\n", + " ('->', ('lying', 1), ('lying', 4)),\n", + " ('not', ('lying', 0)))}\n" + ] + } + ], + "source": [ + "n_people = 5\n", + "problem_sampler = lib_kk.KKProblemSampler(1234, n_people=n_people)\n", + "problems = problem_sampler.sample_valid_problems(10)\n", + "\n", + "pprint.pprint(problems[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[(False, False, False, False, False)]\n", + "[('proposal', {'assignment': True, 'outcome': 'ok', 'person': 0}),\n", + " ('proposal',\n", + " {'assignment': True,\n", + " 'conflict_statement': (0, True),\n", + " 'outcome': 'conflict',\n", + " 'person': 3}),\n", + " ('proposal',\n", + " {'assignment': False,\n", + " 'conflict_statement': (3, False),\n", + " 'outcome': 'conflict',\n", + " 'person': 3}),\n", + " ('reconsider', {'exhausted': [3], 'person': 0}),\n", + " ('proposal', {'assignment': False, 'outcome': 'ok', 'person': 0}),\n", + " ('proposal',\n", + " {'assignment': True,\n", + " 'conflict_statement': (3, True),\n", + " 'outcome': 'conflict',\n", + " 'person': 3}),\n", + " ('proposal', {'assignment': False, 'outcome': 'ok', 'person': 3}),\n", + " ('proposal',\n", + " {'assignment': True,\n", + " 'conflict_statement': (0, False),\n", + " 'outcome': 'conflict',\n", + " 'person': 4}),\n", + " ('proposal', {'assignment': False, 'outcome': 'ok', 'person': 4}),\n", + " ('proposal',\n", + " {'assignment': True,\n", + " 'conflict_statement': (2, True),\n", + " 'outcome': 'conflict',\n", + " 'person': 2}),\n", + " ('proposal', {'assignment': False, 'outcome': 'ok', 'person': 2}),\n", + " ('proposal',\n", + " {'assignment': True,\n", + " 'conflict_statement': (1, True),\n", + " 'outcome': 'conflict',\n", + " 'person': 1}),\n", + " ('proposal', {'assignment': False, 'outcome': 'ok', 'person': 1}),\n", + " ('success', {'assignments': (False, False, False, False, False)})]\n" + ] + } + ], + "source": [ + "problem = problems[2]\n", + "\n", + "print(lib_kk.find_solution(problem['statements']))\n", + "chain_of_thoughts = lib_kk.generate_chain_of_thoughts(problem['statements'])\n", + "pprint.pprint(chain_of_thoughts)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'knight_knave': {'Knave': 'Knave',\n", + " 'Knight': 'Knight',\n", + " 'a_knave': 'a knave',\n", + " 'a_knight': 'a knight',\n", + " 'knave': 'knave',\n", + " 'knight': 'knight'},\n", + " 'names': ['David', 'Zoey', 'Alexander', 'Aurora', 'Isabella'],\n", + " 'quiz': 'A very special island is inhabited only by knights and knaves. '\n", + " 'Knights always tell the truth, and knaves always lie. You meet 5 '\n", + " 'inhabitants: David, Zoey, Alexander, Aurora, and Isabella. In '\n", + " 'David\\'s words: \"Aurora is a knave and Isabella is a knight\". Zoey '\n", + " 'expressed that Aurora is a knave if and only if Isabella is a '\n", + " 'knight. Alexander said, \"Isabella is a knight.\" Aurora commented, '\n", + " '\"David is a knight\". According to Isabella, \"Alexander is a knight '\n", + " 'if and only if Alexander is a knave\". So who is a knight and who is '\n", + " 'a knave?',\n", + " 'solution': (False, False, False, False, False),\n", + " 'solution_text': 'David is a knave, Zoey is a knave, Alexander is a knave, '\n", + " 'Aurora is a knave, and Isabella is a knave.'}\n", + "============================================================\n", + "A very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 5 inhabitants: David, Zoey, Alexander, Aurora, and Isabella. In David's words: \"Aurora is a knave and Isabella is a knight\". Zoey expressed that Aurora is a knave if and only if Isabella is a knight. Alexander said, \"Isabella is a knight.\" Aurora commented, \"David is a knight\". According to Isabella, \"Alexander is a knight if and only if Alexander is a knave\". So who is a knight and who is a knave?\n" + ] + } + ], + "source": [ + "formatter = lib_kk.KKProblemFormatter(1234, problem)\n", + "formatted_problem = formatter.format_problem(random_knight_knave_pairs=False, flip_knight_knave_pair=False)\n", + "pprint.pprint(formatted_problem)\n", + "print('='*60)\n", + "print(formatted_problem['quiz'])" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Let's think step by step, by considering whether each person is lying and if that leads to contradiction.\n", + "\\begin{enumerate}[leftmargin=*, itemsep=0pt, topsep=0pt, partopsep=0pt]\n", + "\\item Assume David is a knight. No contradiction is found in their claim that Aurora is a knave and Isabella is a knight.\n", + "\\item Aurora cannot be a knight, because this would contradict the claim of David that Aurora is a knave and Isabella is a knight.\n", + "\\item Aurora cannot be a knave, because this would contradict the false claim of their own that David is a knight.\n", + "\\item We have exhausted all possibilities for Aurora, so let us go back and reconsider David.\n", + "\\item Assume David is a knave. No contradiction is found in their false claim that Aurora is a knave and Isabella is a knight.\n", + "\\item Aurora cannot be a knight, because this would contradict the claim of their own that David is a knight.\n", + "\\item Assume Aurora is a knave. No contradiction is found in their false claim that David is a knight.\n", + "\\item Isabella cannot be a knight, because this would contradict the false claim of David that Aurora is a knave and Isabella is a knight.\n", + "\\item Assume Isabella is a knave. No contradiction is found in their false claim that Alexander is a knight if and only if Alexander is a knave.\n", + "\\item Alexander cannot be a knight, because this would contradict the claim of their own that Isabella is a knight.\n", + "\\item Assume Alexander is a knave. No contradiction is found in their false claim that Isabella is a knight.\n", + "\\item Zoey cannot be a knight, because this would contradict the claim of their own that Aurora is a knave if and only if Isabella is a knight.\n", + "\\item Assume Zoey is a knave. No contradiction is found in their false claim that Aurora is a knave if and only if Isabella is a knight.\n", + "\\end{enumerate}\n", + "This leads to a feasible solution.\n" + ] + } + ], + "source": [ + "begin, steps, end = lib_kk.format_chain_of_thoughts(problem, formatted_problem, chain_of_thoughts, \n", + " repeat_claim_for_assumption=True, repeat_claim_for_contradiction=True)\n", + "print(begin)\n", + "print('\\\\begin{enumerate}[leftmargin=*, itemsep=0pt, topsep=0pt, partopsep=0pt]')\n", + "for s in steps:\n", + " print(f'\\\\item {s}')\n", + "print('\\\\end{enumerate}')\n", + "print(end)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "kk", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/data_prep/lib_kk.py b/data_prep/lib_kk.py new file mode 100644 index 0000000..040bcd7 --- /dev/null +++ b/data_prep/lib_kk.py @@ -0,0 +1,947 @@ +"""Knight and Knave problems. + +Each person can have the following (recursive) statements: + - assertion: (telling-truth, i), (lying, i) + - negation: (not, statement) + - conjunction: (and, statement1, statement2), could support more than 2 + - disjunction: (or, statement1, statement2), could support more than 2 + - implication: (->, statement1, statement2) + - equivalence: (<=>, statement1, statement2) + +Please see the unit tests at the bottom on examples of how to use each API. +""" + +import copy +import enum +import itertools +import pprint +import unittest + +import numpy as np + + +#################################################################################### +# Problem Solving +#################################################################################### +def find_solution(statements): + """Find solutions given a list of statements.""" + n_people = len(statements) + single_statement = ('and',) + tuple(('<=>', ('telling-truth', i), statements[i]) + for i in range(len(statements))) + # brute force + solutions = [] + for assignments in itertools.product([True, False], repeat=n_people): + if test_satisfiability(single_statement, assignments): + solutions.append(assignments) + + return solutions + + +def test_satisfiability(statement, assignments): + """Dumb recursive testing.""" + if statement[0] == 'telling-truth': + return assignments[statement[1]] + if statement[0] == 'lying': + return not assignments[statement[1]] + if statement[0] == 'not': + return not test_satisfiability(statement[1], assignments) + if statement[0] == 'and': + return np.all([test_satisfiability(statement[i], assignments) + for i in range(1, len(statement))]) + if statement[0] == 'or': + return np.any([test_satisfiability(statement[i], assignments) + for i in range(1, len(statement))]) + if statement[0] == '->': + val1 = test_satisfiability(statement[1], assignments) + val2 = test_satisfiability(statement[2], assignments) + return (not val1) or val2 + if statement[0] == '<=>': + val1 = test_satisfiability(statement[1], assignments) + val2 = test_satisfiability(statement[2], assignments) + return (val1 and val2) or ((not val1) and (not val2)) + raise KeyError(f'Unknown statement: {statement}') + + +#################################################################################### +# Problem Sampling +#################################################################################### +class KKProblemSampler: + """Problem Sampler for Knight and Knave. + + Args: + rand_seed: seed for random number generators. + n_people: number of people for K&K problems. + depth_constraint: the max depth of each person's statement. The depth refer to the level of + recursion of operators such as 'and', 'or', etc. Increasing the depth would allow + increasing the difficulty. Though currently the automatic formatting of the problems + into nautral languages does not support depth more than 2. + width_constraint: the max width (number of branches in operators such as 'and', 'or') of each + person's statement. + """ + + def __init__(self, rand_seed: int, n_people: int, depth_constraint: int = 2, width_constraint: int = 2): + self.rng = np.random.default_rng(rand_seed) + self.rng_wrong = np.random.default_rng(rand_seed+1) + self.n_people = n_people + self.depth_constraint = depth_constraint + self.width_constraint = width_constraint + + def sample(self): + """Sample a single K&K problem.""" + statements = tuple(self._sample_statement(person_id, self.depth_constraint) + for person_id in range(self.n_people)) + return self._immutable_statements(statements) + + def sample_valid_problems(self, n_problems: int, max_retry: int = 1000, + skip_no_solution: bool = True, skip_multiple_solutions: bool = True): + """Sample valid (has 1 unique solution) problems. + + Args: + n_problems: how many problems to sample. + max_retry: max number of retries per problem before giving up. + skip_no_solution: skip problems without a valid solution. + skip_multiple_solutions: skip problems with more than one solutions. + + Returns + A list of problems, each a dict with keys 'statements' and 'solution'. + """ + problems = [] + unique_statements = set() + for i_problem in range(n_problems): + for _ in range(max_retry): + statements = self.sample() + if statements in unique_statements: + continue # duplicated problem, retry + solutions = find_solution(statements) + if len(solutions) == 0 and skip_no_solution: + continue # retry + if len(solutions) > 1 and skip_multiple_solutions: + continue # retry + sol = solutions[0] if len(solutions) > 0 else None + problems.append({'statements': statements, 'solution': sol, + 'all_solutions': solutions}) + unique_statements.add(statements) + break # continue to next problem + if i_problem + 1 < len(problems): + raise RuntimeError(f'Failed to generate a valid problem after {max_retry} retries.') + return problems + + def sample_flipped_solution(self, solution): + length_of_solution = len(solution) + # Randomly decide how many items to flip (at least one) + num_to_perturb = self.rng_wrong.integers(1, length_of_solution) + + # Randomly choose indices to perturb + indices_to_perturb = list(self.rng_wrong.choice(list(range(length_of_solution)), size=num_to_perturb, replace=False)) + + # Create a new solution with perturbed values + perturbed_solution = tuple( + not solution[i] if i in indices_to_perturb else solution[i] + for i in range(length_of_solution) + ) + return perturbed_solution + + + def sample_invalid_problems(self, n_problems: int, max_retry: int = 1000, + skip_no_solution: bool = True, skip_multiple_solutions: bool = True): + """Sample valid (has 1 unique solution) problems and then perturb the solution. + + Args: + n_problems: how many problems to sample. + max_retry: max number of retries per problem before giving up. + skip_no_solution: skip problems without a valid solution. + skip_multiple_solutions: skip problems with more than one solutions. + + Returns + A list of problems, each a dict with keys 'statements' and 'solution'. + """ + problems = [] + unique_statements = set() + for i_problem in range(n_problems): + for _ in range(max_retry): + statements = self.sample() + if statements in unique_statements: + continue # duplicated problem, retry + solutions = find_solution(statements) + if len(solutions) == 0 and skip_no_solution: + continue # retry + if len(solutions) > 1 and skip_multiple_solutions: + continue # retry + sol = solutions[0] if len(solutions) > 0 else None + ## perturbed + perturbed_sol=self.sample_flipped_solution(sol) + problems.append({'statements': statements, 'solution': perturbed_sol, + 'all_solutions': [perturbed_sol]}) + unique_statements.add(statements) + break # continue to next problem + if i_problem + 1 < len(problems): + raise RuntimeError(f'Failed to generate a valid problem after {max_retry} retries.') + return problems + + + def perturb_problems(self, problems, max_retry: int = 1000, perturb_type: str = 'statement', + num_perturb: int = 1): + """Perturb the problems (generated by this sampler). + + The perturbed problems will change in one place, and is guaranteed to have a different + solution. The 'leaf' perturbation type allows "small" perturbation, but it will have a + high chance of not able to generate valid perturbations when n_people is small (i.e. all + the single-step perturbations do not lead to a valid solution). One potential solution is + to enable `allow_failure` and filter out invalid ones (marked as None). + + Args: + problems: a list of problems generated by this sampler. + max_retry: max number of retries to generate an alternative and valid problem. + perturb_type: 'leaf' means perturbing only a random leaf node (i.e. not compond statements); + 'statement' means change the entire statement from a random person. + num_perturb: number of perturbations to generate. Note the actual returned perturbations + might be fewer than this number (or even an empty list), if max_retry is exhausted. + + Returns: + A list of perturbed problems. + """ + return [self._perturb_problem(p, max_retry=max_retry, perturb_type=perturb_type, num_perturb=num_perturb) + for p in problems] + + def _perturb_problem(self, problem, max_retry: int, perturb_type: str, num_perturb: int): + assert len(problem['statements']) == self.n_people # make sure parameters match + results_set = set() + results_list = [] + for _ in range(max_retry): + statements = self._copy_statements_as_mutable(problem['statements']) + if perturb_type == 'statement': + person = self.rng.integers(0, self.n_people) + statements[person] = self._sample_statement(person, depth_constraint=self.depth_constraint) + elif perturb_type == 'leaf': + person = self.rng.integers(0, self.n_people) + idx = person + container = statements + while not self._is_leaf_node(container[idx]): + container = container[idx] + idx = self.rng.integers(1, len(container)) + assert self._is_leaf_node(container[idx]) + # set depth_constraint to 1 to only sample new leaf node + container[idx] = self._sample_statement(person, depth_constraint=1) + + statements = self._immutable_statements(statements) + if len(set([statements, problem['statements']])) <= 1: + continue # perturbation is identical to the original, retry + + solutions = find_solution(statements) + if len(solutions) != 1: + continue # Not single unique solution, retry + + if len(set([solutions[0], problem['solution']])) <= 1: + continue # solution does not change after perturbation, retry + + if statements in results_set: + continue # duplicate perturbation, retry + + results_set.add(statements) + results_list.append({'statements': statements, 'solution': solutions[0]}) + if len(results_list) >= num_perturb: + break + + if len(results_list)==0: + return [None] + + return results_list + + def _copy_statements_as_mutable(self, statements): + """Make a deep copy of the statements of a problem, turning the tuples into (mutable) lists.""" + statements = copy.deepcopy(statements) + def _make_mutable(x): + if isinstance(x, tuple): + return [_make_mutable(child) for child in x] + return x + return [_make_mutable(s) for s in statements] + + def _immutable_statements(self, mutable_statements): + """Change list back to tuples.""" + def _make_immutable(x): + if isinstance(x, (list, tuple)): + return tuple(_make_immutable(child) for child in x) + if isinstance(x, np.str_): + return str(x) + if isinstance(x, np.int64): + return int(x) + return x + return tuple(_make_immutable(s) for s in mutable_statements) + + def _is_leaf_node(self, statement): + if statement[0] in ['telling-truth', 'lying']: + return True + return False + + def _sample_statement(self, person_id: int, depth_constraint: int): + """Sample a single statement.""" + dice = self.rng.integers(0, 6) + if depth_constraint == 1 or dice == 0: + while True: + knight_or_knave = self.rng.choice(['telling-truth', 'lying']) + person = self.rng.integers(0, self.n_people) + if not (knight_or_knave == 'lying' and person == person_id): + # avoid the trivially unsatisfiable statement + return (knight_or_knave, person) + + if dice == 1: + return ('not', self._sample_statement(person_id, depth_constraint-1)) + if dice in [2, 3]: + operator = ['and', 'or'][dice-2] + n_substatements = self.rng.integers(2, self.width_constraint+1) + + return (operator,) + self._sample_substatements(person_id, depth_constraint, n_substatements) + if dice in [4, 5]: + operator = ['->', '<=>'][dice-4] + return (operator,) + self._sample_substatements(person_id, depth_constraint, 2) + + def _sample_substatements(self, person_id: int, depth_constraint: int, count: int, dedup: bool = True): + """Sample substatements for an operator. + + Args: + person_id: the id of the person making the statements. + depth_constraint: the maximum depth of substatements. + count: number of substatements to generate. + dedup: if True, avoid duplicated substatements. + """ + sub_statements = [] + dedup_set = set() + while True: + stmt = self._sample_statement(person_id, depth_constraint-1) + if dedup: + if stmt in dedup_set: + continue + dedup_set.add(stmt) + + sub_statements.append(stmt) + if len(sub_statements) == count: + break + return tuple(sub_statements) + + +#################################################################################### +# Problem Formatting in natural language +#################################################################################### +COMMON_NAMES = ['Emma', 'Liam', 'Olivia', 'Noah', 'Ava', 'Ethan', 'Sophia', + 'Mason', 'Isabella', 'William', 'Mia', 'James', 'Charlotte', + 'Benjamin', 'Amelia', 'Lucas', 'Harper', 'Henry', 'Evelyn', + 'Alexander', 'Abigail', 'Michael', 'Emily', 'Daniel', 'Elizabeth', + 'Jacob', 'Sofia', 'Logan', 'Avery', 'Jackson', 'Ella', 'Sebastian', + 'Scarlett', 'Jack', 'Grace', 'Aiden', 'Chloe', 'Owen', 'Victoria', + 'Samuel', 'Riley', 'Matthew', 'Aria', 'Joseph', 'Lily', 'Luke', + 'Aurora', 'David', 'Zoey', 'Oliver', 'Penelope'] +UNCOMMON_NAMES = [ + 'Zephyr', 'Elowen', 'Caspian', 'Isolde', 'Osiris', 'Vesper', 'Thaddeus', 'Ondine', + 'Lysander', 'Xanthe', 'Oberon', 'Calliope', 'Leander', 'Eulalia', 'Florian', 'Forsythe', + 'Nephele', 'Peregrine', 'Ianthe', 'Lazarus', 'Elodie', 'Cillian', 'Ottoline', 'Evander', + 'Saffron', 'Caius', 'Zora', 'Cyprian', 'Amaryllis', 'Theron', 'Perdita', 'Ignatius', + 'Zephyrine', 'Balthazar', 'Melisande', 'Zinnia', 'Sylvester', 'Cosima', 'Leocadio', + 'Percival', 'Oceane', 'Evanthe', 'Zenobia', 'Eurydice', 'Quillan', 'Aeronwen', + 'Thorsten', 'Xiomara', 'Zephyrus', 'Ysolde' +] + +KNIGHT_KNAVE_PAIRS = [ + # NOTE: we simply add 's' to make plural, so be careful when choosing words + ['a pioneer', 'a laggard'], + ['a saint', 'a sinner'], + ['a hero', 'a villain'], + ['an angel', 'a devil'], + ['an altruist', 'an egoist'], + ['a sage', 'a fool'], +] +PREFIX = ('A very special island is inhabited only by {knight}s and {knave}s. ' + + '{Knight}s always tell the truth, and {knave}s always lie. ') +POSTFIX = 'So who is {a_knight} and who is {a_knave}?' +TEMPLATES = [ + '{name} said that {content}.', + '{name} told you that {content}.', + '{name} said, "{content}."', + '{name} stated, "{content}".', + 'According to {name}, "{content}".', + '''In {name}'s words: "{content}".''', + '{name} remarked, "{content}".', + '"{content}," {name} declared.', + '{name} was heard saying, "{content}".', + '{name} expressed that {content}.', + '"{content}" - {name}.', + 'As {name} put it, "{content}".', + '{name} asserted: "{content}".', + '"{content}," {name} mentioned.', + '{name} commented, "{content}".', + 'In a statement by {name}: "{content}".', + '{name} noted, "{content}".', + '"{content}," {name} claimed.', +] + + +class KKProblemFormatter: + + def __init__(self, rand_seed, problem): + self.rng = np.random.default_rng(rand_seed) + self.rng_perturb = np.random.default_rng(rand_seed+1) + self.problem = problem + + def format_problem(self, random_names=True, random_saying_template=True, + random_knight_knave_pairs=False, + flip_knight_knave_pair=False, uncommon_name=False, reorder_statement=False): + statements = copy.deepcopy(self.problem['statements']) + + n_people = len(statements) + names = COMMON_NAMES[:n_people] + if random_names: + if uncommon_name==False: + names = list(self.rng.choice(COMMON_NAMES, size=n_people, replace=False)) + else: + names = list(self.rng.choice(UNCOMMON_NAMES, size=n_people, replace=False)) + names = [str(x) for x in names] # convert np.str_ to str + + knight_knave = ['a knight', 'a knave'] + if random_knight_knave_pairs: + knight_knave = self.rng.choice(KNIGHT_KNAVE_PAIRS) + knight_knave = [str(x) for x in knight_knave] # convert np.str_ to str + + if flip_knight_knave_pair: + knight_knave = knight_knave[::-1] + + knight_knave = {'knight': knight_knave[0].split()[1], + 'knave': knight_knave[1].split()[1], + 'a_knight': knight_knave[0], 'a_knave': knight_knave[1]} + knight_knave['Knight'] = knight_knave['knight'].capitalize() + knight_knave['Knave'] = knight_knave['knave'].capitalize() + + text = PREFIX.format(**knight_knave) + text += f'You meet {n_people} inhabitants: ' + text += ', '.join(names[:-1]) + ', and ' + names[-1] + '.' + + text_statements=[] + for i, stmt in enumerate(statements): + tmpl = TEMPLATES[0] + if random_saying_template: + tmpl = self.rng.choice(TEMPLATES) + + content = format_statement(names, knight_knave, stmt) + text_statements.append(' ' + tmpl.format(name=names[i], content=content)) + # text += ' ' + tmpl.format(name=names[i], content=content) + + if reorder_statement: + original_order = list(range(n_people)) + # Copy the original list + shuffled_order = original_order.copy() + + # Shuffle until it's different from the original + while True: + self.rng_perturb.shuffle(shuffled_order) + if shuffled_order != original_order: + break + for i in shuffled_order: + text += text_statements[i] + else: + text += ''.join(text_statements) + + text += ' ' + POSTFIX.format(**knight_knave) + if self.problem['solution'] is None: + solution_text = 'No valid solution exists.' + else: + solution_stmts = [] + for name, indicator in zip(names, self.problem['solution']): + if indicator: + solution_stmts.append(name + ' is ' + knight_knave['a_knight']) + else: + solution_stmts.append(name + ' is ' + knight_knave['a_knave']) + solution_text = ', '.join(solution_stmts[:-1]) + ', and ' + solution_stmts[-1] + '.' + return {'quiz': text, 'names': names, 'knight_knave': knight_knave, + 'solution': self.problem['solution'], + 'solution_text': solution_text} + + +# TODO: currently we do not support formatting of problems with depth more than +# 2. We may need to use LLM or think more about what would be the best way +# to format complicated recursive statements. +def format_knight_knave(names, knight_knave, statement, negation=False): + assert statement[0] in ('telling-truth', 'lying') + text = names[statement[1]] + ' is ' + if negation: + text += 'not ' + text += {'telling-truth': knight_knave['a_knight'], + 'lying': knight_knave['a_knave']}[statement[0]] + return text + + +def format_statement(names, knight_knave, statement): + if statement[0] == 'not': + return format_knight_knave(names, knight_knave, statement[1], negation=True) + if statement[0] in ['and', 'or']: + text = (' ' + statement[0] + ' ').join( + format_knight_knave(names, knight_knave, sub_stmt) for sub_stmt in statement[1:]) + return text + if statement[0] == '->': + return ('If ' + format_knight_knave(names, knight_knave, statement[1]) + ' then ' + + format_knight_knave(names, knight_knave, statement[2])) + if statement[0] == '<=>': + return (format_knight_knave(names, knight_knave, statement[1]) + ' if and only if ' + + format_knight_knave(names, knight_knave, statement[2])) + return format_knight_knave(names, knight_knave, statement) + + +#################################################################################### +# Chain of Thoughts +#################################################################################### +def generate_chain_of_thoughts(statements, dynamic_person_order: bool = True): + """Generate reasoning steps that can solve the problem. + + Args: + statements: the statements of the K&K problem. + dynamic_person_order: if False, it will always go through the list of person in the original order. If True, + it will use a more "natural" order. For example, if person1 mention person5 and person4, then the engine will + check person5 and person4 next, instead of checking person2 next. + """ + n_people = len(statements) + tape = [] + assignments = [None] * n_people + options = {p: [False, True] for p in range(n_people)} + persons_to_consider = tuple(range(n_people)) + p_cursor = 0 + while True: + if p_cursor >= n_people: + tape.append(('success', {'assignments': tuple(assignments)})) + break + + if not options[persons_to_consider[p_cursor]]: + exhausted = [] + while p_cursor >= 0 and not options[persons_to_consider[p_cursor]]: + options[persons_to_consider[p_cursor]] = [False, True] + assignments[persons_to_consider[p_cursor]] = None + exhausted.append(persons_to_consider[p_cursor]) + p_cursor -= 1 + if p_cursor >= 0: + tape.append(('reconsider', {'person': persons_to_consider[p_cursor], 'exhausted': exhausted})) + else: + # we have exhausted all options + tape.append(('fail',)) + break + + person = persons_to_consider[p_cursor] + assignments[person] = options[person].pop() + result, stmt_id = can_be_falsified_v2(statements, assignments) + if result: + tape.append(('proposal', {'person': person, 'assignment': assignments[person], + 'outcome': 'ok'})) + # re-order the next people to consider based on who is mentioned in the current statement + mentioned_people = _find_mentioned_people(statements[person]) + p_cursor += 1 + persons_to_consider = persons_to_consider[:p_cursor] + _reorder_people_sequence( + persons_to_consider[p_cursor:], mentioned_people) + else: + tape.append(('proposal', {'person': person, 'assignment': assignments[person], + 'outcome': 'conflict', 'conflict_statement': (stmt_id, assignments[stmt_id])})) + return tape + + +def _find_mentioned_people(statement): + """Find the id of people mentioned in the statement.""" + if statement[0] in ['lying', 'telling-truth']: + return [statement[1]] + if statement[0] in ['not', 'and', 'or', '->', '<=>']: + return sum([_find_mentioned_people(s) for s in statement[1:]], []) + raise KeyError(f'Unknown statement: {statement}') + + +def _reorder_people_sequence(remaining_people, mentioned_people): + """Reorder the remaining people by brining the mentioned ones to the front.""" + # dedup and keep order + set_uniq_mention = set() + list_uniq_mention = [] + for p in mentioned_people: + if p not in set_uniq_mention: + set_uniq_mention.add(p) + list_uniq_mention.append(p) + + for p in reversed(mentioned_people): + if not p in remaining_people: + continue + idx = remaining_people.index(p) + remaining_people = (p,) + remaining_people[:idx] + remaining_people[idx+1:] + return remaining_people + + +def can_be_falsified_v2(statements, assignments): + """Test falsifiability of partial assignment (v2). + + This version enumerate all possible remaining assignments. This is less efficient than v1. But v1 has + the potential issue that it cannot easily detect self contradictory statement such as + `('<=>', ('lying', 4), ('telling-truth', 4))` when the person 4's assignment is undecided yet. + """ + n_people = len(statements) + remap = [i for i, x in enumerate(assignments) if x is None] + n_unassigned = len(remap) + + for p_idx in range(n_people): + if assignments[p_idx] is None: + continue + p_statement = statements[p_idx] + if not assignments[p_idx]: + p_statement = ('not', p_statement) + has_solution = False + + for proposal in itertools.product([True, False], repeat=n_unassigned): + new_assignments = copy.copy(assignments) + for i, x in zip(remap, proposal): + new_assignments[i] = x + if test_satisfiability(p_statement, new_assignments): + has_solution = True + break + if not has_solution: + return (False, p_idx) # this person's statement cannot be satisfied + + return (True, None) + + +class TruthOrWhatever(enum.Enum): + FALSE = 0 + TRUE = 1 + WHATEVER = 2 + + @classmethod + def from_bool(cls, val: bool): + if val: + return cls.TRUE + else: + return cls.FALSE + + def f_not(self): + if self == self.TRUE: + return self.FALSE + if self == self.FALSE: + return self.TRUE + return self.WHATEVER + + def f_and(self, other): + if self == self.WHATEVER or other == self.WHATEVER: + return self.WHATEVER + if self == self.TRUE: + return self.from_bool(other == self.TRUE) + return self.FALSE + + def f_or(self, other): + if self == self.WHATEVER or other == self.WHATEVER: + return self.WHATEVER + if self == self.FALSE: + return self.from_bool(other == self.TRUE) + return self.TRUE + + +def can_be_falsified(statements, assignments): + """Test if the (partial) assignment can be falsified.""" + def _test(stmt) -> TruthOrWhatever: + if stmt[0] in ['telling-truth', 'lying'] and assignments[stmt[1]] is None: + return TruthOrWhatever.WHATEVER + if stmt[0] == 'telling-truth': + return TruthOrWhatever.from_bool(assignments[stmt[1]] is True) + if stmt[0] == 'lying': + return TruthOrWhatever.from_bool(assignments[stmt[1]] is False) + if stmt[0] == 'not': + return _test(stmt[1]).f_not() + if stmt[0] == 'and': + val = _test(stmt[1]) + for sub_stmt in stmt[2:]: + val = val.f_and(_test(sub_stmt)) + return val + if stmt[0] == 'or': + val = _test(stmt[1]) + for sub_stmt in stmt[2:]: + val = val.f_or(_test(sub_stmt)) + return val + if stmt[0] == '->': + val1 = _test(stmt[1]) + val2 = _test(stmt[2]) + return val1.f_not().f_or(val2) + if stmt[0] == '<=>': + val1 = _test(stmt[1]) + val2 = _test(stmt[2]) + return val1.f_and(val2).f_or(val1.f_not().f_and(val2.f_not())) + raise KeyError(f'Unknown statement: {stmt}') + + for i, (stmt, assmt) in enumerate(zip(statements, assignments)): + if assmt is None: + # this person's claim does not matter + continue + if assmt and _test(stmt) == TruthOrWhatever.FALSE: + return (False, i) + if not assmt and _test(stmt) == TruthOrWhatever.TRUE: + return (False, i) + return (True, None) + + +def format_chain_of_thoughts(problem, formatted_problem, tape, + repeat_claim_for_assumption: bool = True, + repeat_claim_for_contradiction: bool = False): + """Format generate chain-of-thoughts in natural language. + + Repeating the claim makes it a bit more natural, but also increas the number of tokens needed to handle. + + Args: + problem: the K&K problem. + formatted_problem: the formatted results of the K&K problem. + tape: the generated chain of thoughts. + repeat_claim_for_assumption: whether to repeat each person's claim after we assuming they are a knight or knave. + repeat_claim_for_contradiction: whether to repeat the contradicted claim when a contradiction is found. + + Returns: + (header, [step1, step2, ...], footer). The footer contains a conclusion of success or failure. Note the final + solution is not included in the footer. If needed, problem['solution_text'] can be appended here. + """ + format_dict = copy.copy(formatted_problem['knight_knave']) + n_person = len(problem['statements']) + for p in range(n_person): + format_dict[f'P{p}'] = formatted_problem['names'][p] + + header = "Let's think step by step, by considering whether each person is lying and if that leads to contradiction." + steps = [] + for step in tape[:-1]: # last step is fail / success + if step[0] == 'proposal': + t_person = '{P' + str(step[1]['person']) + '}' + t_assignment = '{a_knight}' if step[1]['assignment'] else '{a_knave}' + if step[1]['outcome'] == 'ok': + text = 'Assume ' + t_person + ' is ' + t_assignment + '.' + if repeat_claim_for_assumption: + t_claim = format_statement(formatted_problem['names'], formatted_problem['knight_knave'], + problem['statements'][step[1]['person']]) + text += ' No contradiction is found in their ' + if not step[1]['assignment']: + text += 'false ' + text += 'claim that ' + t_claim + '.' + elif step[1]['outcome'] == 'conflict': + conflict_p, conflict_assignment = step[1]['conflict_statement'] + text = t_person + ' cannot be ' + t_assignment + ', because this would contradict the ' + if not conflict_assignment: + text += 'false ' + text += 'claim of ' + if conflict_p == step[1]['person']: + text += 'their own' + else: + text += '{P' + str(conflict_p) + '}' + if repeat_claim_for_contradiction: + t_claim = format_statement(formatted_problem['names'], formatted_problem['knight_knave'], + problem['statements'][conflict_p]) + text += ' that ' + t_claim + '.' + else: + text += '.' + else: + raise KeyError(f'Unknown outcome for CoT step: {step}') + steps.append(text) + elif step[0] == 'reconsider': + text = 'We have exhausted all possibilities for ' + t_exhausted = ['{P' + str(p_idx) + '}' for p_idx in step[1]['exhausted']] + assert len(t_exhausted) > 0 + if len(t_exhausted) == 1: + text += t_exhausted[0] + elif len(t_exhausted) == 2: + text += ' and '.join(t_exhausted) + else: + t_exhausted[-1] = 'and ' + t_exhausted[-1] + text += ', '.join(t_exhausted) + text += ', so let us go back and reconsider {P' + str(step[1]['person']) + '}.' + steps.append(text) + else: + raise KeyError(f'Unknown CoT step: {step}') + + if tape[-1][0] == 'success': + footer = 'This leads to a feasible solution.' + elif tape[-1][0] == 'fail': + footer = 'All the configurations lead to contradictions.' + else: + raise KeyError(f'Expect success or fail, but get {tape[-1]}') + + steps = [x.format(**format_dict) for x in steps] + return (header, steps, footer) + + +#################################################################################### +# Unit Testing +#################################################################################### +class TestKK(unittest.TestCase): + + def test_find_solution(self): + statements = ( + ('lying', 1), + ('and', ('telling-truth', 0), ('telling-truth', 1)) + ) + sol = find_solution(statements) + self.assertEqual(sol, [(True, False)]) + + def test_sample_problems(self): + n_people = 3 + n_problems = 5 + problem_sampler = KKProblemSampler(1234, n_people=n_people) + problems = problem_sampler.sample_valid_problems(n_problems) + self.assertEqual(len(problems), n_problems) + for problem in problems: + self.assertEqual(set(problem.keys()), set(['statements', 'solution', 'all_solutions'])) + self.assertEqual(len(problem['statements']), n_people) + + def test_format_problems(self): + problem_sampler = KKProblemSampler(1234, n_people=3) + problems = problem_sampler.sample_valid_problems(20, skip_no_solution=False) + + for problem in problems: + formatter = KKProblemFormatter(rand_seed=1234, problem=problem) + formatted_results = formatter.format_problem() + self.assertIn('quiz', formatted_results) + self.assertIn('names', formatted_results) + self.assertIn('solution', formatted_results) + self.assertIn('solution_text', formatted_results) + if problem['solution'] is None: + self.assertEqual(formatted_results['solution_text'], 'No valid solution exists.') + + def test_perturb_problems(self): + n_people = 4 + n_perturb = 3 + problem_sampler = KKProblemSampler(1234, n_people=n_people) + problems = problem_sampler.sample_valid_problems(5) + for perturb_type in ['statement', 'leaf']: + perturbed_problems = problem_sampler.perturb_problems(problems, perturb_type=perturb_type, num_perturb=n_perturb) + self.assertEqual(len(problems), len(perturbed_problems)) + for p1, p2_list in zip(problems, perturbed_problems): + self.assertEqual(len(p2_list), n_perturb) # note this can actual fail, esp for small n_people + self.assertNotEqual(p1['solution'], p2_list[0]['solution']) + n_stmt_diff = 0 + for s1, s2 in zip(p1['statements'], p2_list[0]['statements']): + if s1 != s2: + n_stmt_diff += 1 + self.assertEqual(n_stmt_diff, 1) # exactly 1 statement is different + + def test_chain_of_thoughts(self): + n_people = 5 + n_problems = 120 + problem_sampler = KKProblemSampler(1234, n_people=n_people) + problems = problem_sampler.sample_valid_problems(n_problems, skip_no_solution=False) + for p in problems: + for dynamic_person_order in [False, True]: + tape = generate_chain_of_thoughts(p['statements'], dynamic_person_order=dynamic_person_order) + if p['solution'] is None: + self.assertTupleEqual(tape[-1], ('fail',)) + else: + self.assertEqual(tape[-1][0], ('success')) + self.assertTupleEqual(tape[-1][1]['assignments'], p['solution']) + + def test_chain_of_thoughts_regression(self): + # Regression test: NOTE the correct answer is not unique and it can change when the CoT generator code + # is changed. So the failure of this test does not necessarily mean the code is incorrect. If the code + # is changed and verified to be correct, this test can be updated with the new target outputs. + statements = (('and', ('telling-truth', 2), ('lying', 3)), + ('telling-truth', 2), + ('<=>', ('lying', 4), ('telling-truth', 4)), + ('and', ('lying', 2), ('lying', 4)), + ('lying', 2)) + expected_tape = [ + ('proposal', {'person': 0, 'assignment': True, 'outcome': 'ok'}), + ('proposal', + {'person': 2, + 'assignment': True, + 'outcome': 'conflict', + 'conflict_statement': (2, True)}), + ('proposal', + {'person': 2, + 'assignment': False, + 'outcome': 'conflict', + 'conflict_statement': (0, True)}), + ('reconsider', {'person': 0, 'exhausted': [2]}), + ('proposal', {'person': 0, 'assignment': False, 'outcome': 'ok'}), + ('proposal', + {'person': 2, + 'assignment': True, + 'outcome': 'conflict', + 'conflict_statement': (2, True)}), + ('proposal', {'person': 2, 'assignment': False, 'outcome': 'ok'}), + ('proposal', {'person': 4, 'assignment': True, 'outcome': 'ok'}), + ('proposal', + {'person': 3, + 'assignment': True, + 'outcome': 'conflict', + 'conflict_statement': (3, True)}), + ('proposal', {'person': 3, 'assignment': False, 'outcome': 'ok'}), + ('proposal', + {'person': 1, + 'assignment': True, + 'outcome': 'conflict', + 'conflict_statement': (1, True)}), + ('proposal', {'person': 1, 'assignment': False, 'outcome': 'ok'}), + ('success', {'assignments': (False, False, False, False, True)}) + ] + tape = generate_chain_of_thoughts(statements, dynamic_person_order=True) + self.assertEqual(tape, expected_tape) + +def test_chain_of_thoughts_format_regression(self): + # Regression test: NOTE the correct answer is not unique and it can change when the CoT generator code + # is changed. So the failure of this test does not necessarily mean the code is incorrect. If the code + # is changed and verified to be correct, this test can be updated with the new target outputs. + problem = { + 'statements': (('and', ('telling-truth', 2), ('lying', 3)), + ('telling-truth', 2), + ('<=>', ('lying', 4), ('telling-truth', 4)), + ('and', ('lying', 2), ('lying', 4)), + ('lying', 2)), + 'solution': (False, False, False, False, True), + 'all_slutions': [(False, False, False, False, True)] + } + chain_of_thoughts = generate_chain_of_thoughts(problem['statements']) + formatted_problem = {'knight_knave': {'Knave': 'Knave', + 'Knight': 'Knight', + 'a_knave': 'a knave', + 'a_knight': 'a knight', + 'knave': 'knave', + 'knight': 'knight'}, + 'names': ['David', 'Zoey', 'Alexander', 'Aurora', 'Isabella'], + 'quiz': 'A very special island is inhabited only by knights and knaves. ' + 'Knights always tell the truth, and knaves always lie. You meet 5 ' + 'inhabitants: David, Zoey, Alexander, Aurora, and Isabella. In ' + 'David\'s words: "Alexander is a knight and Aurora is a knave". Zoey ' + 'expressed that Alexander is a knight. Alexander said, "Isabella is a ' + 'knave if and only if Isabella is a knight." Aurora commented, ' + '"Alexander is a knave and Isabella is a knave". According to ' + 'Isabella, "Alexander is a knave". So who is a knight and who is a ' + 'knave?', + 'solution': (False, False, False, False, True), + 'solution_text': 'David is a knave, Zoey is a knave, Alexander is a knave, ' + 'Aurora is a knave, and Isabella is a knight.'} + cot_format = format_chain_of_thoughts(problem, formatted_problem, chain_of_thoughts, + repeat_claim_for_assumption=True, + repeat_claim_for_contradiction=True) + expected_cot = ('Let us think step by step, by considering whether each person is lying and if that leads to contradiction.', + ['Assume David is a knight. No contradiction is found in their claim that Alexander is a knight and Aurora is a knave.', + 'Alexander cannot be a knight, because this would contradict the claim of their own.', + 'Alexander cannot be a knave, because this would contradict the claim of David.', + 'We have exhausted all possibilities for Alexander, so let us go back and reconsider David.', + 'Assume David is a knave. No contradiction is found in their false claim that Alexander is a knight and Aurora is a knave.', + 'Alexander cannot be a knight, because this would contradict the claim of their own.', + 'Assume Alexander is a knave. No contradiction is found in their false claim that Isabella is a knave if and only if Isabella is a knight.', + 'Assume Isabella is a knight. No contradiction is found in their claim that Alexander is a knave.', + 'Aurora cannot be a knight, because this would contradict the claim of their own.', + 'Assume Aurora is a knave. No contradiction is found in their false claim that Alexander is a knave and Isabella is a knave.', + 'Zoey cannot be a knight, because this would contradict the claim of their own.', + 'Assume Zoey is a knave. No contradiction is found in their false claim that Alexander is a knight.'], + 'This leads to a feasible solution.') + self.assertEqual(cot_format, expected_cot) + + cot_format = format_chain_of_thoughts(problem, formatted_problem, chain_of_thoughts, + repeat_claim_for_assumption=False, + repeat_claim_for_contradiction=False) + expected_cot = ('Let us think step by step, by considering whether each person is lying and if that leads to contradiction.', + ['Assume David is a knight.', + 'Alexander cannot be a knight, because this would contradict the claim of their own.', + 'Alexander cannot be a knave, because this would contradict the claim of David.', + 'We have exhausted all possibilities for Alexander, so let us go back and reconsider David.', + 'Assume David is a knave.', + 'Alexander cannot be a knight, because this would contradict the claim of their own.', + 'Assume Alexander is a knave.', + 'Assume Isabella is a knight.', + 'Aurora cannot be a knight, because this would contradict the claim of their own.', + 'Assume Aurora is a knave.', + 'Zoey cannot be a knight, because this would contradict the claim of their own.', + 'Assume Zoey is a knave.'], + 'This leads to a feasible solution.') + self.assertEqual(cot_format, expected_cot) + + +if __name__ == '__main__': + unittest.main() diff --git a/dataset/kk.py b/dataset/kk.py new file mode 100644 index 0000000..bbf8305 --- /dev/null +++ b/dataset/kk.py @@ -0,0 +1,138 @@ +import numpy as np +from .prompt import system_instruction, demonstration_2char, system_instruction_no_reason, demonstration_2char_no_reason + + +def num_tokens_from_string(string): + import tiktoken + """Returns the number of tokens in a text string.""" + encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") + num_tokens = len(encoding.encode(string)) + return num_tokens + + +def parse_cot_eval(pred_str, ans, + conclusion_patterns=['CONCLUSION:'], + verbose=False, + finish_patterns=["### Reason", "Let's think step by step again", "let's go back and check", "###"], + reformat_gold_conditions=None): + + def judge_string(input_str, reformat_gold_conditions, wrong_reason, finish_patterns): + correct_count = 0 + is_correct = False + beyond_id = len(reformat_gold_conditions)+1 + beyond_id_pattern = f"({beyond_id})" + + for finish_pattern in finish_patterns: + if finish_pattern in input_str: + input_str = input_str.split(finish_pattern)[0] + + if beyond_id_pattern in input_str: + is_correct = False + wrong_reason = "beyond_list" + elif "if" in input_str: + is_correct = False + wrong_reason = "contain_if" + else: + is_correct = True + for gold_condition in reformat_gold_conditions: + if gold_condition not in input_str: + is_correct = False + wrong_reason = "wrong_identity" + else: + correct_count += 1 + correct_ratio = correct_count/len(reformat_gold_conditions) + + return is_correct, wrong_reason, correct_ratio + + def check_numbers_in_string(s, N): + for i in range(1, N + 1): + if f"({i})" not in s: + return False + return True + + original_str = pred_str + pred_str = pred_str.split("### Question")[0] + pred_answer = pred_str + is_correct = False + correct_ratio = 0 + if reformat_gold_conditions is None: + gold = ans.replace(" and ", "").replace(".", "") + gold_conditions = gold.split(",") + reformat_gold_conditions = [] + for condition in gold_conditions: + gold_condition = condition.strip() # Remove leading and trailing spaces + reformat_gold_conditions.append(gold_condition) + + wrong_reason = "no_conclusion_matched" + for pattern in conclusion_patterns: + pred = pred_str.split(pattern) + if len(pred) > 1: + if len(pred[1]) > 0: # if the matched the answer is not empty + pred_answer = pred[1] + is_correct, wrong_reason, correct_ratio = judge_string( + pred_answer, reformat_gold_conditions, wrong_reason, finish_patterns) + break + if is_correct == False and wrong_reason == "no_conclusion_matched": + if check_numbers_in_string(pred_str, len(reformat_gold_conditions)): # the answer contains (1)..(2).. + is_correct, wrong_reason, correct_ratio = judge_string( + pred_str, reformat_gold_conditions, wrong_reason, finish_patterns) + if is_correct == False and verbose == True: + print("wrong_reason:",wrong_reason) + print("********* \nprediction before parse:\n", original_str) + print("********* \nprediction after parse:\n", pred_answer) + + return is_correct, pred_answer, wrong_reason, correct_ratio, reformat_gold_conditions + + +class KKProcessor: + def __init__(self, cot=True, no_linebreak=True): + self.cot = cot + self.no_linebreak = no_linebreak + + def format_example(self, test_records, idx, model_name=None): + + item = test_records[idx] + + prompt = "### Question: "+item["quiz"] + "\n" + if self.cot: + if model_name in ["deepseek-ai/deepseek-math-7b-instruct", "AI-MO/NuminaMath-7B-CoT"]: + prompt += "Please reason step by step, and put your final answer within \\boxed{}." + else: + prompt += "### Answer: Let's think step by step" + else: + if self.no_linebreak: + prompt += "### Answer:" + else: + prompt += "### Answer:\n" + answer = item["solution_text"] + return prompt, answer + + def gen_test_prompt(self, ntrain, test_records, idx, model_name=None): + if self.cot: + train_prompt = system_instruction + else: + train_prompt = system_instruction_no_reason + + if ntrain == 1: + if self.cot: + train_prompt += "\n\n"+demonstration_2char + else: + train_prompt += "\n\n"+demonstration_2char_no_reason + elif ntrain > 1: + raise NotImplementedError + + prompt_end, answer = self.format_example(test_records, idx, model_name) + prompt = train_prompt + "\n\n" + prompt_end + + return prompt, answer + + def _parse_cot_eval(self, pred_str, ans, model_name=None): + conclusion_patterns = ['CONCLUSION:', 'Conclusion:', 'conclusion:'] + + if model_name in ["deepseek-ai/deepseek-math-7b-instruct", "AI-MO/NuminaMath-7B-CoT"]: + conclusion_patterns = ['boxed{', 'CONCLUSION:', 'Conclusion:', 'conclusion:'] + + is_correct, pred_answer, wrong_reason, correct_ratio, reformat_gold_conditions = parse_cot_eval( + pred_str, ans, conclusion_patterns=conclusion_patterns, verbose=False) + + return is_correct, pred_answer, reformat_gold_conditions diff --git a/dataset/prompt.py b/dataset/prompt.py new file mode 100644 index 0000000..cb9e2d8 --- /dev/null +++ b/dataset/prompt.py @@ -0,0 +1,38 @@ +system_instruction='''Your task is to solve a logical reasoning problem. You are given set of statements from which you must logically deduce the identity of a set of characters. + +You must infer the identity of each character. First, explain your reasoning. At the end of your answer, you must clearly state the identity of each character by following the format: + +CONCLUSION: +(1) ... +(2) ... +(3) ... +''' + + +system_instruction_no_reason='''Your task is to solve a logical reasoning problem. You are given set of statements from which you must logically deduce the identity of a set of characters. + +You must infer the identity of each character. At the end of your answer, you must clearly state the identity of each character by following the format: + +CONCLUSION: +(1) ... +(2) ... +(3) ... +''' + +demonstration_2char_no_reason='''### Question: A very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 2 inhabitants: Jack, and Sophia. Jack tells you that Sophia is not a knave. Sophia says that If Jack is a knight then Sophia is a knight. So who is a knight and who is a knave? +### Answer: +CONCLUSION: +(1) Jack is a knight +(2) Sophia is a knight +''' + + + +demonstration_2char='''### Question: A very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 2 inhabitants: Ella, and Penelope. In a statement by Ella: \"Ella is a knight or Penelope is a knight\". According to Penelope, \"Ella is a knave if and only if Penelope is a knight\". So who is a knight and who is a knave? +### Answer: Let's think step by step, by considering whether each person is lying and if that leads to contradiction. Assume Ella is a knight. Penelope cannot be a knight, because this would contradict the claim of their own. Penelope cannot be a knave, because this would contradict the false claim of their own. We have exhausted all possibilities for Penelope, so let us go back and reconsider Ella. Assume Ella is a knave. Penelope cannot be a knight, because this would contradict the false claim of Ella. Assume Penelope is a knave. This leads to a feasible solution. +CONCLUSION: +(1) Ella is a knave +(2) Penelope is a knave +''' + + diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000..17006da --- /dev/null +++ b/environment.yaml @@ -0,0 +1,229 @@ +name: kk +channels: + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - asttokens=2.4.1=pyhd8ed1ab_0 + - backcall=0.2.0=pyh9f0ad1d_0 + - blas=1.0=mkl + - ca-certificates=2024.8.30=hbcca054_0 + - comm=0.2.2=pyhd8ed1ab_0 + - cuda-cudart=12.1.105=0 + - cuda-cupti=12.1.105=0 + - cuda-libraries=12.1.0=0 + - cuda-nvrtc=12.1.105=0 + - cuda-nvtx=12.1.105=0 + - cuda-opencl=12.6.37=0 + - cuda-runtime=12.1.0=0 + - cuda-version=12.6=3 + - debugpy=1.6.7=py39h6a678d5_0 + - decorator=5.1.1=pyhd8ed1ab_0 + - entrypoints=0.4=pyhd8ed1ab_0 + - executing=2.1.0=pyhd8ed1ab_0 + - filelock=3.13.1=py39h06a4308_0 + - gmp=6.2.1=h295c915_3 + - gmpy2=2.1.2=py39heeb90bb_0 + - intel-openmp=2023.1.0=hdb19cb5_46306 + - ipykernel=6.29.5=pyh3099207_0 + - ipython=8.12.0=pyh41d4057_0 + - jedi=0.19.1=pyhd8ed1ab_0 + - jinja2=3.1.4=py39h06a4308_0 + - jupyter_client=7.3.4=pyhd8ed1ab_0 + - jupyter_core=5.7.2=py39hf3d152e_0 + - ld_impl_linux-64=2.38=h1181459_1 + - libcublas=12.1.0.26=0 + - libcufft=11.0.2.4=0 + - libcufile=1.11.0.15=0 + - libcurand=10.3.7.37=0 + - libcusolver=11.4.4.55=0 + - libcusparse=12.0.2.55=0 + - libffi=3.4.4=h6a678d5_1 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libnpp=12.0.2.50=0 + - libnvjitlink=12.1.105=0 + - libnvjpeg=12.1.1.14=0 + - libsodium=1.0.18=h36c2ea0_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - llvm-openmp=14.0.6=h9e868ea_0 + - markupsafe=2.1.3=py39h5eee18b_0 + - matplotlib-inline=0.1.7=pyhd8ed1ab_0 + - mkl=2023.1.0=h213fc3f_46344 + - mpc=1.1.0=h10f8cd9_1 + - mpfr=4.0.2=hb69a4c5_1 + - mpmath=1.3.0=py39h06a4308_0 + - ncurses=6.4=h6a678d5_0 + - nest-asyncio=1.6.0=pyhd8ed1ab_0 + - networkx=3.2.1=py39h06a4308_0 + - openssl=3.0.14=h5eee18b_0 + - packaging=24.1=pyhd8ed1ab_0 + - parso=0.8.4=pyhd8ed1ab_0 + - pexpect=4.9.0=pyhd8ed1ab_0 + - pickleshare=0.7.5=py_1003 + - pip=24.2=py39h06a4308_0 + - platformdirs=4.2.2=pyhd8ed1ab_0 + - prompt-toolkit=3.0.47=pyha770c72_0 + - prompt_toolkit=3.0.47=hd8ed1ab_0 + - ptyprocess=0.7.0=pyhd3deb0d_0 + - pure_eval=0.2.3=pyhd8ed1ab_0 + - pygments=2.18.0=pyhd8ed1ab_0 + - python=3.9.19=h955ad1f_1 + - python_abi=3.9=2_cp39 + - pytorch-cuda=12.1=ha16c6d3_5 + - pytorch-mutex=1.0=cuda + - pyyaml=6.0.1=py39h5eee18b_0 + - readline=8.2=h5eee18b_0 + - setuptools=72.1.0=py39h06a4308_0 + - six=1.16.0=pyh6c4a22f_0 + - sqlite=3.45.3=h5eee18b_0 + - stack_data=0.6.2=pyhd8ed1ab_0 + - sympy=1.12=py39h06a4308_0 + - tbb=2021.8.0=hdb19cb5_0 + - tk=8.6.14=h39e8969_0 + - tornado=6.1=py39hb9d737c_3 + - traitlets=5.14.3=pyhd8ed1ab_0 + - typing_extensions=4.11.0=py39h06a4308_0 + - wcwidth=0.2.13=pyhd8ed1ab_0 + - wheel=0.43.0=py39h06a4308_0 + - xz=5.4.6=h5eee18b_1 + - yaml=0.2.5=h7b6447c_0 + - zeromq=4.3.5=h6a678d5_0 + - zlib=1.2.13=h5eee18b_1 + - pip: + - accelerate==0.33.0 + - aiohappyeyeballs==2.3.5 + - aiohttp==3.10.3 + - aiosignal==1.3.1 + - annotated-types==0.7.0 + - anthropic==0.34.2 + - anyio==4.4.0 + - async-timeout==4.0.3 + - attrs==24.2.0 + - bitsandbytes==0.44.1 + - certifi==2024.7.4 + - charset-normalizer==3.3.2 + - click==8.1.7 + - cloudpickle==3.0.0 + - cmake==3.30.2 + - contourpy==1.3.0 + - cycler==0.12.1 + - datasets==2.20.0 + - dill==0.3.8 + - diskcache==5.6.3 + - distro==1.9.0 + - docker-pycreds==0.4.0 + - docstring-parser==0.16 + - eval-type-backport==0.2.0 + - evaluate==0.4.3 + - exceptiongroup==1.2.2 + - fastapi==0.112.0 + - fonttools==4.53.1 + - frozenlist==1.4.1 + - fsspec==2024.5.0 + - gitdb==4.0.11 + - gitpython==3.1.43 + - h11==0.14.0 + - httpcore==1.0.5 + - httptools==0.6.1 + - httpx==0.27.0 + - huggingface-hub==0.24.5 + - idna==3.7 + - importlib-resources==6.4.4 + - interegular==0.3.3 + - jiter==0.5.0 + - joblib==1.4.2 + - jsonschema==4.23.0 + - jsonschema-specifications==2023.12.1 + - kiwisolver==1.4.5 + - lark==1.2.1 + - llvmlite==0.43.0 + - lm-format-enforcer==0.10.1 + - markdown-it-py==3.0.0 + - matplotlib==3.9.2 + - mdurl==0.1.2 + - msgpack==1.0.8 + - multidict==6.0.5 + - multiprocess==0.70.16 + - ninja==1.11.1.1 + - nltk==3.9.1 + - numba==0.60.0 + - numpy==1.26.4 + - nvidia-cublas-cu12==12.1.3.1 + - nvidia-cuda-cupti-cu12==12.1.105 + - nvidia-cuda-nvrtc-cu12==12.1.105 + - nvidia-cuda-runtime-cu12==12.1.105 + - nvidia-cudnn-cu12==8.9.2.26 + - nvidia-cufft-cu12==11.0.2.54 + - nvidia-curand-cu12==10.3.2.106 + - nvidia-cusolver-cu12==11.4.5.107 + - nvidia-cusparse-cu12==12.1.0.106 + - nvidia-ml-py==12.555.43 + - nvidia-nccl-cu12==2.20.5 + - nvidia-nvjitlink-cu12==12.6.20 + - nvidia-nvtx-cu12==12.1.105 + - openai==1.43.0 + - outlines==0.0.46 + - pandas==2.2.2 + - peft==0.12.0 + - pillow==10.4.0 + - prometheus-client==0.20.0 + - prometheus-fastapi-instrumentator==7.0.0 + - protobuf==5.27.3 + - psutil==6.0.0 + - py-cpuinfo==9.0.0 + - pyairports==2.1.1 + - pyarrow==17.0.0 + - pyarrow-hotfix==0.6 + - pycountry==24.6.1 + - pydantic==2.8.2 + - pydantic-core==2.20.1 + - pyparsing==3.1.4 + - python-dateutil==2.9.0.post0 + - python-dotenv==1.0.1 + - pytz==2024.1 + - pyzmq==26.1.0 + - ray==2.34.0 + - referencing==0.35.1 + - regex==2024.7.24 + - requests==2.32.3 + - rich==13.9.2 + - rpds-py==0.20.0 + - safetensors==0.4.4 + - scienceplots==2.1.1 + - scikit-learn==1.5.2 + - scipy==1.13.1 + - seaborn==0.13.2 + - sentencepiece==0.2.0 + - sentry-sdk==2.17.0 + - setproctitle==1.3.3 + - shtab==1.7.1 + - smmap==5.0.1 + - sniffio==1.3.1 + - starlette==0.37.2 + - threadpoolctl==3.5.0 + - tiktoken==0.7.0 + - tokenizers==0.19.1 + - torch==2.3.0 + - torchvision==0.18.0 + - tqdm==4.66.5 + - transformers==4.45.0.dev0 + - triton==2.3.0 + - trl==0.11.4 + - tyro==0.8.13 + - tzdata==2024.1 + - urllib3==2.2.2 + - uvicorn==0.30.6 + - uvloop==0.19.0 + - vllm==0.5.1 + - vllm-flash-attn==2.5.9 + - wandb==0.18.5 + - watchfiles==0.23.0 + - websockets==12.0 + - xformers==0.0.26.post1 + - xxhash==3.4.1 + - yarl==1.9.4 + - zipp==3.20.1 diff --git a/eval_kk.py b/eval_kk.py new file mode 100644 index 0000000..2b9d576 --- /dev/null +++ b/eval_kk.py @@ -0,0 +1,195 @@ +import argparse +import json +import os +import numpy as np +import random +import torch +import time +from dataset.kk import KKProcessor +from utils import load_eval_records, load_jsonl, write_jsonl, batch_decode_vllm, init_seed, load_llm + + + +def eval_subject(args, subject, llm, test_records, kk_proc, exist_result_records): + """Evaluate one subject.""" + + cors = [] + start_index = len(exist_result_records) + print(f"Found existing {start_index} records in {subject}") + for i in range(start_index): + cors.append(exist_result_records[i]["correct"]) + + eval_start_time = time.time() + # Prepare all prompts + prompts = [] + labels = [] + for i in range(start_index, len(test_records)): + prompt, label = kk_proc.gen_test_prompt( + args.ntrain, test_records, i, args.model + ) + prompts.append(prompt) + if i == start_index: + print(f"Sample prompt:\n{prompt}") + labels.append(label) + + # Get responses + if args.use_vllm: + responses = batch_decode_vllm(llm, prompts, batch_size=args.batch_size) + else: + responses = [] + for index, prompt in enumerate(prompts): + response = llm.query(prompt) + responses.append(response) + if index % 1 == 0: + print(f"\nResponse {index}:\n{response}") + print(f"\nLabel {index}:\n{labels[index]}") + + # Process results + for i, (prompt, label, response) in enumerate(zip(prompts, labels, responses), start=start_index): + cor, parsed_pred, reformat_gold_conditions = kk_proc._parse_cot_eval(response, label, args.model) + + if i % 1 == 0: + print(f"\nPrompt {i}:{prompt}" + f"\nResponse {i}:{response}" + f"\nPrediction {i}:{parsed_pred}" + f"\nLabel {i}:{reformat_gold_conditions}" + f"\nCorrect {i}:{cor}") + + cors.append(cor) + new_item = { + 'quiz': test_records[i]['quiz'], + 'names': test_records[i]['names'], + 'solution': test_records[i]['solution'], + 'solution_text': test_records[i]['solution_text'], + 'solution_text_format': test_records[i]['solution_text_format'], + 'index': test_records[i]['index'], + 'predicts': parsed_pred, + 'labels': reformat_gold_conditions, + 'correct': cor, + 'response': response, + 'prompts': prompt, + } + exist_result_records.append(new_item) + + eval_end_time = time.time() + eval_time = eval_end_time - eval_start_time + acc = np.mean(cors) + cors = np.array(cors) + + print("Average accuracy {:.3f} - {}".format(acc, subject)) + print(f"Total evaluation time: {eval_time:.2f} seconds") + + return cors, acc, exist_result_records + + +def load_limited_test_records(args, subject, exist_result_records): + """Load limited test records based on given arguments.""" + test_records = load_eval_records(args, subject) + + if args.limit is not None: + test_records = test_records.select(range(min(args.limit, len(test_records)))) + if args.limit <= len(exist_result_records): + return None # have finished exp + + return test_records + +def save_final_acc_results(all_cors, results, fname): + """Process final results, calculate average accuracy, and save to file.""" + if all_cors: + weighted_acc = np.mean(np.concatenate(all_cors)) + results["weighted_accuracy"] = weighted_acc + print(f"Average accuracy: {weighted_acc:.3f}") + + with open(fname, "w") as f: + json.dump(results, f) + +def load_previous_acc_results(fname): + """Load previous accuracy results.""" + acc_results = {"subject": {}} + if os.path.isfile(fname): + with open(fname, 'r', encoding='utf-8') as file: + acc_results = json.load(file) + print(f"Previous Results loaded successfully: {acc_results}") + return acc_results + +def get_subjects_to_eval(args): + """Get subjects to evaluate.""" + + subjects = [] + if args.split == "test": + if args.eval_nppl == 0: + subjects = [f"people{nppl}_num100" for nppl in range(2, 9)] + else: + subjects = [f"people{args.eval_nppl}_num100"] + elif args.split == "train": + if args.eval_nppl == 2: + subjects = ["people2_num200"] + elif args.eval_nppl > 2: + subjects = [f"people{args.eval_nppl}_num1000"] + return subjects + + +def main(args): + + model_short_name = "/".join(args.model.split("/")[-2:]) + + prefix = os.path.join( + os.path.join(args.save_dir, "{}_{}shot".format( + model_short_name, args.ntrain)) + ) + + args.config += f"_token{args.max_token}{('_cot' if args.cot else '')}" \ + f"_{args.split}{('_' + args.problem_type if args.problem_type != 'clean' else '')}" + + output_folder = os.path.join(prefix, args.config) + acc_fname = os.path.join(prefix, f"result_{args.config}.json") + os.makedirs(output_folder, exist_ok=True) + + print("args.config", args.config, "\nprefix", prefix, "\noutput_folder", output_folder) + + kk_proc = KKProcessor(cot=args.cot, no_linebreak=args.no_linebreak) + + subjects = get_subjects_to_eval(args) + acc_results = load_previous_acc_results(acc_fname) + + llm = None + all_cors = [] + for subject in subjects: + result_outfile = os.path.join(output_folder, "{}.jsonl".format(subject)) + exist_result_records = load_jsonl(result_outfile) if os.path.exists(result_outfile) else [] + test_records = load_limited_test_records(args, subject, exist_result_records) + if test_records is None: + continue + + llm = llm or load_llm(args) + + cors, acc, result_records = eval_subject(args, subject, llm, test_records, kk_proc, exist_result_records) + + write_jsonl(result_outfile, result_records) + all_cors.append(cors) + acc_results["subject"][subject] = acc + + save_final_acc_results(all_cors, acc_results, acc_fname) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluation script for KK dataset") + parser.add_argument("--ntrain", "-k", type=int, default=0, help="Number of training examples") + parser.add_argument("--data_dir", "-d", type=str, default="data", help="Data directory") + parser.add_argument("--save_dir", "-s", type=str, default="result_qa", help="Save directory") + parser.add_argument("--model", "-m", type=str, required=True, help="Model name or path") + parser.add_argument("--arch", type=str, default=None, help="Model architecture") + parser.add_argument("--config", "-c", type=str, default="", help="Configuration string") + parser.add_argument("--max_token", type=int, default=1024, help="Maximum number of tokens") + parser.add_argument("--limit", type=int, default=None, help="Limit the number of examples") + parser.add_argument("--cot", action="store_true", help="Use chain-of-thought prompting") + parser.add_argument("--no_linebreak", action="store_true", help="Remove line breaks") + parser.add_argument("--use_vllm", action="store_true", help="Use VLLM for inference") + parser.add_argument("--batch_size", type=int, default=4, help="Batch size for VLLM") + parser.add_argument("--split", type=str, default="test", choices=["test", "train"], help="Data split to use") + parser.add_argument("--eval_nppl", type=int, default=0, help="Number of people to evaluate") + parser.add_argument("--problem_type", type=str, default="clean", help="Problem perturbation type") + + args = parser.parse_args() + init_seed() + main(args) \ No newline at end of file diff --git a/figures/data-gen.png b/figures/data-gen.png new file mode 100644 index 0000000..af4a0c1 Binary files /dev/null and b/figures/data-gen.png differ diff --git a/figures/mem-score.png b/figures/mem-score.png new file mode 100644 index 0000000..d0c2e4e Binary files /dev/null and b/figures/mem-score.png differ diff --git a/finetune_kk.py b/finetune_kk.py new file mode 100644 index 0000000..5c186b4 --- /dev/null +++ b/finetune_kk.py @@ -0,0 +1,380 @@ +import os +import argparse +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import SFTConfig, SFTTrainer +import wandb +from peft import LoraConfig +from torch.nn import functional as F +from datasets import load_dataset +import random +import numpy as np +from functools import partial + + +def init_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + +class CustomSFTTrainer(SFTTrainer): + def __init__(self, response_template, *args, **kwargs): + super().__init__(*args, **kwargs) + self.response_template = response_template + self.after_answer_losses = [] + self.before_answer_losses = [] + self.current_epoch = 0 + self.steps_per_epoch = None + self.accumulated_steps = 0 + + def train(self, resume_from_checkpoint=None, **kwargs): + self.current_epoch = 0 # Reset epoch counter + self.accumulated_steps = 0 # Reset accumulated steps + return super().train(resume_from_checkpoint=resume_from_checkpoint, **kwargs) + + def compute_loss(self, model, inputs, return_outputs=False): + # Find the index of "### Answer" in the input_ids + answer_token_ids = self.tokenizer.encode( + self.response_template, add_special_tokens=False + ) + answer_token_ids = answer_token_ids[1:] + + answer_start_indices = [] + + for batch_idx, input_ids in enumerate(inputs["input_ids"]): + for i in range(len(input_ids) - len(answer_token_ids) + 1): + if ( + input_ids[i: i + len(answer_token_ids)].tolist() + == answer_token_ids + ): + answer_start_indices.append((batch_idx, i)) + break + + if not answer_start_indices: + + exit() + + return super().compute_loss(model, inputs, return_outputs) + + # Separate inputs into before and after "### Answer" + before_inputs = {k: [] for k in inputs.keys()} + after_inputs = {k: [] for k in inputs.keys()} + + for batch_idx, answer_start in answer_start_indices: + for k, v in inputs.items(): + if k == "labels": + labels_before = v[batch_idx].clone() + labels_before[answer_start:] = -100 + before_inputs[k].append(labels_before) + + labels_after = v[batch_idx].clone() + labels_after[:answer_start] = -100 + after_inputs[k].append(labels_after) + else: + before_inputs[k].append(v[batch_idx]) + after_inputs[k].append(v[batch_idx]) + + # Pad the inputs + max_before_len = max(len(seq) for seq in before_inputs["input_ids"]) + max_after_len = max(len(seq) for seq in after_inputs["input_ids"]) + + def pad_and_cut(sequences, max_len, pad_value): + return torch.stack( + [ + F.pad(seq[:max_len], (0, max_len - len(seq)), + value=pad_value) + for seq in sequences + ] + ) + + for k in before_inputs: + pad_value = 0 if k == "attention_mask" else self.tokenizer.pad_token_id + before_inputs[k] = pad_and_cut( + before_inputs[k], max_before_len, pad_value + ).to(model.device) + + for k in after_inputs: + pad_value = 0 if k == "attention_mask" else self.tokenizer.pad_token_id + after_inputs[k] = pad_and_cut(after_inputs[k], max_after_len, pad_value).to( + model.device + ) + + # Compute embeddings for the segment before "### Answer" without gradients + with torch.no_grad(): + before_outputs = model(**before_inputs) + before_loss = before_outputs.loss + + # Compute loss for the segment after "### Answer", conditioned on the segment before + after_outputs = model(**after_inputs) + after_loss = after_outputs.loss + + self.after_answer_losses.append(after_loss.item()) + self.before_answer_losses.append(before_loss.item()) + + self.accumulated_steps += 1 + # Check if an epoch has ended + if self.steps_per_epoch is None: + self.steps_per_epoch = len(self.train_dataset) // ( + self.args.train_batch_size * self.args.gradient_accumulation_steps + ) + + if ( + self.accumulated_steps % self.args.gradient_accumulation_steps == 0 + and (self.accumulated_steps // self.args.gradient_accumulation_steps) + % self.steps_per_epoch + == 0 + ): + self.on_epoch_end() + + if return_outputs: + return after_loss, (before_outputs, after_outputs) + return after_loss + + def on_epoch_end(self): + self.current_epoch += 1 + avg_after_loss = sum(self.after_answer_losses) / \ + len(self.after_answer_losses) + avg_before_loss = sum(self.before_answer_losses) / len( + self.before_answer_losses + ) + wandb.log( + { + "epoch_loss/avg_after_answer": avg_after_loss, + "epoch_loss/avg_before_answer": avg_before_loss, + }, + step=self.current_epoch * self.steps_per_epoch, + ) + + print("epoch_loss/avg_after_answer", avg_after_loss) + + self.after_answer_losses = [] + self.before_answer_losses = [] + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Fine-tune a language model on K&K with PEFT." + ) + parser.add_argument( + "--train_data", + type=str, + default="train/people3_num1000.jsonl", + help="Path to the training data file.", + ) + parser.add_argument( + "--test_data", + type=str, + default="test/people3_num100.jsonl", + help="Path to the test data file.", + ) + parser.add_argument( + "--model_checkpoint", + type=str, + default="meta-llama/Meta-Llama-3-8B", + help="Path to the model checkpoint.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="./out", + help="Output directory for the fine-tuned model.", + ) + parser.add_argument( + "--num_train_epochs", type=int, default=2, help="Number of training epochs." + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=4, + help="Training batch size per device.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=8, + help="Number of gradient accumulation steps.", + ) + parser.add_argument( + "--learning_rate", type=float, default=5e-5, help="Learning rate." + ) + parser.add_argument( + "--max_seq_length", type=int, default=256, help="Maximum sequence length." + ) + parser.add_argument("--logging_steps", type=int, + default=1, help="Logging steps.") + parser.add_argument("--eval_steps", type=int, + default=2, help="eval steps.") + parser.add_argument( + "--save_steps", + type=float, + default=0, + help="Number of updates steps before two checkpoint saves if save_strategy=steps. Should be an integer or a float in range [0,1). If smaller than 1, will be interpreted as ratio of total training steps.", + ) + parser.add_argument( + "--save_strategy", + type=str, + default="steps", + help="The checkpoint save strategy to adopt during training. Possible values are: no, epoch, steps", + ) + parser.add_argument( + "--project_name", + type=str, + default="bench-conta", + help="Wandb project name.", + ) + parser.add_argument( + "--wandb_key", + default="", + type=str, + help="API key for W&B.", + ) + parser.add_argument( + "--run_name", type=str, default="kk_ft_sol_format", help="Wandb run name." + ) + parser.add_argument("--cot_ft", action="store_true") + parser.add_argument("--add_eos", action="store_true") + + return parser.parse_args() + + +# Formatting function +def formatting_prompts_func(example, eos_token): + output_texts = [] + + from dataset.prompt import system_instruction_no_reason + + for i in range(len(example["quiz"])): + text = ( + system_instruction_no_reason + + f"\n\n### Question: {example['quiz'][i]}\n### Answer:\nCONCLUSION:\n{example['solution_text_format'][i]}" + ) + text += eos_token + output_texts.append(text) + if i == 0: + print(text) + + return output_texts + + +def formatting_prompts_func_cot(example, eos_token): + output_texts = [] + from dataset.prompt import system_instruction + + cot_head = "Let's think step by step, by considering whether each person is lying and if that leads to contradiction." + for i in range(len(example["quiz"])): + cot_steps = example["cot_repeat_steps"][i] + cot_steps = " ".join(cot_steps) + cot_foot = example["cot_foot"][i] + text = ( + system_instruction + + f"\n\n### Question: {example['quiz'][i]}\n### Answer: {cot_head} {cot_steps} {cot_foot}\nCONCLUSION:\n{example['solution_text_format'][i]}" + ) + text += eos_token + + if i == 0: + print(text) + output_texts.append(text) + return output_texts + + +def main(): + init_seed() + args = parse_args() + peft_config = LoraConfig( + r=32, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + "lm_head", + ], + ) + + # Response template and data collator + if args.cot_ft: + response_template = "\n### Answer: Let's think step by step" + else: + response_template = "\n### Answer:\n" + + # Check if CUDA is available + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # Initialize wandb + + _ = os.system("wandb login {}".format(args.wandb_key)) + os.environ["WANDB_API_KEY"] = args.wandb_key + wandb.init(project=args.project_name, name=args.run_name) + wandb.config.update(args) + + # Load dataset + kk_dataset = load_dataset('K-and-K/knights-and-knaves', data_files={ + "train": [args.train_data], + "test": [args.test_data], + },) + + model = AutoModelForCausalLM.from_pretrained( + args.model_checkpoint, + load_in_4bit=True, + device_map="auto", + ) + tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint) + tokenizer.pad_token = tokenizer.eos_token + + if args.add_eos: + eos_token = tokenizer.eos_token + else: + eos_token = "" + print("eos_token", eos_token) + + new_format_func = partial( + formatting_prompts_func_cot if args.cot_ft else formatting_prompts_func, eos_token=eos_token) + + # Initialize trainer + trainer = CustomSFTTrainer( + response_template=response_template, + model=model, + train_dataset=kk_dataset["train"], + eval_dataset=kk_dataset["test"], + formatting_func=new_format_func, + args=SFTConfig( + output_dir=args.output_dir, # Set to None to disable saving + report_to="wandb", + num_train_epochs=args.num_train_epochs, + per_device_train_batch_size=args.train_batch_size, + per_device_eval_batch_size=args.train_batch_size, + gradient_accumulation_steps=args.gradient_accumulation_steps, + learning_rate=args.learning_rate, + fp16=True, + save_strategy=args.save_strategy, + save_steps=args.save_steps, + max_seq_length=args.max_seq_length, + logging_strategy="steps", + logging_steps=args.logging_steps, + evaluation_strategy="epoch", + eval_steps=args.eval_steps, + ), + peft_config=peft_config, + ) + + # Start training + trainer.train() + trainer.save_model(os.path.join(args.output_dir, "final_model")) + # Close wandb run + wandb.finish() + + +if __name__ == "__main__": + main() diff --git a/mem_cls_model.py b/mem_cls_model.py new file mode 100644 index 0000000..7f754e2 --- /dev/null +++ b/mem_cls_model.py @@ -0,0 +1,225 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer # type: ignore +from datasets import load_dataset +import tqdm +import argparse +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score, roc_auc_score +from peft import PeftModel +import json +import os + + +def merge_adapter(base_model_path, adapter_path): + + print("Loading adapter...") + model = AutoModelForCausalLM.from_pretrained( + base_model_path, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + trust_remote_code=True, + ).cuda() + + if adapter_path != "": + tokenizer = AutoTokenizer.from_pretrained( + adapter_path, + trust_remote_code=True, + ) + + model.resize_token_embeddings(len(tokenizer)) + + model = PeftModel.from_pretrained(model, adapter_path) + model = model.merge_and_unload() + + return model + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Run LLaMA model activations on dataset" + ) + parser.add_argument( + "--base_model_path", + type=str, + default="", + ) + parser.add_argument( + "--adapter_path", + type=str, + default="", + ) + parser.add_argument( + "--data_file", + type=str, + default="", + help="Path to the dataset file (e.g., /path/to/dataset.jsonl)", + ) + parser.add_argument( + "--output_file", + type=str, + default="cls_robust_results", + help="Path to the output JSON file to save results", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + # Load the dataset + kk_dataset = load_dataset( + "json", + data_files={ + "test": [args.data_file], + }, + ) + + statements = [] + robust_metrics = [] + for i in range(len(kk_dataset["test"])): + quiz = kk_dataset["test"]["quiz"][i] + statements.append(quiz) + + metric = kk_dataset["test"]["robust_metric"][i] + robust_metrics.append(metric) + + + # Load pre-trained LLaMA model and tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.base_model_path) # "meta-llama/Meta-Llama-3-8B" + tokenizer.pad_token = tokenizer.eos_token + + model = merge_adapter(args.base_model_path, args.adapter_path) + + # Define a forward hook to capture MLP activations + mlp_activations = { + i: [] for i in range(len(model.model.layers)) + } # One list per layer + + def get_mlp_activation_hook(layer_idx): + def hook(module, input, output): + mlp_activations[layer_idx].append(output.detach().cpu().numpy()) + + return hook + + # Register hooks to all MLP layers in the transformer blocks + for i, layer in enumerate(model.model.layers): + layer.mlp.register_forward_hook(get_mlp_activation_hook(i)) + + dataset = {i: [] for i in range(len(model.model.layers))} + labels = {i: [] for i in range(len(model.model.layers))} + + # Function to process statements and capture activations + def process_statements(statements, robust_metrics): + for text, metric in tqdm.tqdm(zip (statements, robust_metrics)): + from dataset.prompt import system_instruction_no_reason + input_prompt =system_instruction_no_reason + f"\n\n### Question: {text}\n### Answer:\n" + input_ids = tokenizer( + input_prompt, + return_tensors="pt", + + ).input_ids + for i in range(len(model.model.layers)): + mlp_activations[i] = [] # Reset activations for each layer + + # Run the model forward pass + with torch.no_grad(): + _ = model(input_ids.cuda()) + + # Store activations and corresponding labels + for i in range(len(model.model.layers)): + if mlp_activations[i]: # Check if activations were captured + dataset[i].append( + mlp_activations[i][0] + ) # Use the first batch output + labels[i].append(metric) + + # Process statements + process_statements(statements, robust_metrics ) # Label 1 for correct + + # Train classifiers for each layer's activations + classifiers = [] + accuracy_per_layer_train = [] # To store train accuracy + accuracy_per_layer_test = [] # To store test accuracy + + results = {} # Dictionary to store accuracy results + + # pdb.set_trace() + + # Splitting the data for each layer and training a classifier + for i in tqdm.tqdm(range(len(model.model.layers))): + X_layer = dataset[i] + y_layer = labels[i] + # Flatten the activations for the classifier + # X_layer = [x.flatten() for x in X_layer] + X_layer = [x.sum(axis=(0, 1)) for x in X_layer] + + # Split data into train and test sets + X_train, X_test, y_train, y_test = train_test_split( + X_layer, y_layer, test_size=0.2, random_state=42 + ) + + # import pdb + # pdb.set_trace() + # Initialize and train a simple logistic regression classifier + clf = LogisticRegression(max_iter=100000) + clf.fit(X_train, y_train) + + # Report train accuracy + y_train_pred = clf.predict(X_train) + train_accuracy = accuracy_score(y_train, y_train_pred) + accuracy_per_layer_train.append(train_accuracy) + + # Report test accuracy + y_test_pred = clf.predict(X_test) + test_accuracy = accuracy_score(y_test, y_test_pred) + accuracy_per_layer_test.append(test_accuracy) + + train_probs = clf.predict_proba(X_train) + test_probs = clf.predict_proba(X_test) + + train_auc= roc_auc_score(y_train, train_probs[:, 1]), + test_auc = roc_auc_score(y_test, test_probs[:, 1]), + + + classifiers.append(clf) # Save the classifier + # Store results for this layer + results[f"layer_{i}"] = { + "train_accuracy": train_accuracy, + "test_accuracy": test_accuracy, + "train_auc": train_auc, + "test_auc": test_auc, + } + + print(f"Layer {i} train accuracy: {train_accuracy:.4f}") + print(f"Layer {i} classifier test accuracy: {test_accuracy:.4f}") + + # Save results to JSON + if args.adapter_path != "": + fname = ( + "-".join(args.adapter_path.split("/")[1:-1]) + .replace("_total_10ep", "") + .replace("_total_100ep", "") + ) + else: + fname = args.base_model_path.split("/")[-1] + # "base" + + if "meta-llama/Meta-Llama-3-8B" in args.base_model_path: + fname = "base_" + fname += args.data_file.split("/")[2].replace("_0shot", "") + + + if "leaf" in args.data_file: + fname += "_leaf" + elif "statement" in args.data_file: + fname += "_statement" + + with open(os.path.join(args.output_file, f"sysprompt_{fname}.json"), "w") as f: + json.dump(results, f, indent=4) + + print(f"Results saved to {args.output_file}") + + +if __name__ == "__main__": + main() diff --git a/mem_cls_puzzle.py b/mem_cls_puzzle.py new file mode 100644 index 0000000..c1a8c1c --- /dev/null +++ b/mem_cls_puzzle.py @@ -0,0 +1,269 @@ +import nltk + +# Download necessary NLTK data +nltk.download("punkt", quiet=True) +nltk.download("stopwords", quiet=True) +nltk.download("punkt_tab") + +import pandas as pd +from sklearn.model_selection import train_test_split +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import accuracy_score, classification_report +from sklearn.metrics import roc_auc_score +from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer +import numpy as np +import argparse +import string +from nltk.corpus import stopwords +from nltk.tokenize import word_tokenize +import os +import json + +def preprocess_text(text): + try: + # If text is a list, concatenate all elements into a single string + if isinstance(text, list): + text = " ".join(text) + + # Lowercase and remove punctuation + text = text.lower().translate(str.maketrans("", "", string.punctuation)) + tokens = word_tokenize(text) + stop_words = set(stopwords.words("english")) + tokens = [word for word in tokens if word not in stop_words] + return " ".join(tokens) + except Exception as e: + print(f"Error processing text: {e}") + return "" + + +def vectorize_text(train, test, text_field="quiz", method="tfidf", num_ppl=5): + column_names = train.columns.tolist() + if f"clean_{text_field}" in column_names: + text_field = f"clean_{text_field}" # use clean data's field (not perturbed data) + + if method == "tfidf": + vectorizer = TfidfVectorizer(max_features=5000) + train_feature = vectorizer.fit_transform(train["processed_text"]) + test_feature = vectorizer.transform(test["processed_text"]) + train_feature = train_feature.toarray() + test_feature = test_feature.toarray() + elif method == "bow": + vectorizer = CountVectorizer(max_features=5000) + train_feature = vectorizer.fit_transform(train["processed_text"]) + test_feature = vectorizer.transform(test["processed_text"]) + train_feature = train_feature.toarray() + test_feature = test_feature.toarray() + + elif method == "charlength": + train_feature = np.asarray( + [len(s) for s in train["processed_text"].values] + ).reshape(-1, 1) + test_feature = np.asarray( + [len(s) for s in test["processed_text"].values] + ).reshape(-1, 1) + elif method == "wordlength": + train_feature = np.asarray( + [len(s.split(" ")) for s in train["processed_text"].values] + ).reshape(-1, 1) + test_feature = np.asarray( + [len(s.split(" ")) for s in test["processed_text"].values] + ).reshape(-1, 1) + + + + return train_feature, test_feature + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Run classification for memorization.") + parser.add_argument( + "--train_split", type=float, default=0.8, help="Fraction for training" + ) + parser.add_argument( + "--method", + type=str, + choices=["tfidf", "bow", "wordlength", "charlength", "combine",], + default="charlength", + help="Vectorization method", + ) + parser.add_argument( + "--text_field", + type=str, + choices=[ + "quiz", + "names", + "solution", + "solution_text", + "solution_text_format", + "cot_steps", + "cot_repeat_steps", + "statements", + "response", + "all_fields", + "state_quiz", + "state_quiz_resp", + "quiz_resp", + "state_resp", + ], + default="quiz", + help="The field to featurize", + ) + parser.add_argument( + "--input_file", + type=str, + default="", + help="Path to data jsonl file", + ) + parser.add_argument( + "--output_dir", type=str, default="result/", help="Directory to save output CSV" + ) + parser.add_argument("--no_balance_label", action="store_true") + + return parser.parse_args() + + +def prepare_cls_data(df, train_split=0.8): + return train_test_split( + df, + test_size=1 - train_split, + stratify=df["label"], + random_state=42, + ) + +def train_and_evaluate(train_feature, test_feature, train_label, test_label): + model = LogisticRegression(random_state=42,max_iter=10000) + model.fit(train_feature, train_label) + + train_pred = model.predict(train_feature) + test_pred = model.predict(test_feature) + + # Predict probabilities instead of labels + train_probs = model.predict_proba(train_feature) + test_probs = model.predict_proba(test_feature) + + + evaluation= { + "train_accuracy": accuracy_score(train_label, train_pred), + "test_accuracy": accuracy_score(test_label, test_pred), + "train_auc": roc_auc_score(train_label, train_probs[:, 1]), + "test_auc":roc_auc_score(test_label, test_probs[:, 1]), + + } + report= classification_report(test_label, test_pred,output_dict=True) + evaluation.update(report) + return evaluation + + +def main(): + args = parse_arguments() + data = pd.read_json(args.input_file, lines=True) + data["label"] = data["robust_metric"] + num_ppl= int(args.input_file.split("/")[-1].split("_")[0].replace("people","")) + print(num_ppl) + + if args.no_balance_label==False: + # Separate the data by label + data_0 = data[data["label"] == 0] + data_1 = data[data["label"] == 1] + + # Determine the size of the smaller class + min_size = min(len(data_0), len(data_1)) + + # Sample from each class to balance the dataset + balanced_data_0 = data_0.sample(n=min_size, random_state=42) + balanced_data_1 = data_1.sample(n=min_size, random_state=42) + + # Concatenate the balanced datasets + balanced_data = pd.concat([balanced_data_0, balanced_data_1]) + + # Shuffle the balanced dataset + data = balanced_data.sample(frac=1, random_state=42).reset_index(drop=True) + + train, test = prepare_cls_data(data, args.train_split) + + + methods=[] + if args.method=="combine": + methods=["tfidf", "bow", "wordlength" , "charlength",] + + else: + methods=[args.method] + + train_feature_list=[] + test_feature_list=[] + if args.text_field =="all_fields": + text_fields = ["statements", "quiz" ,"response" , "cot_repeat_steps", "cot_steps", ] + else: + text_fields = [args.text_field] + + for text_field in text_fields: + for method in methods: + + print(f"Processing {text_field} with {method}") + + train_feature, test_feature = vectorize_text( + train, test, text_field=text_field, method=method, num_ppl=num_ppl + ) + train_feature_list.append(train_feature) + test_feature_list.append(test_feature) + + # Initialize an empty array + concatenated_features = train_feature_list[0] + print(len(train_feature_list)) + # Use a for loop to concatenate the features + if len(train_feature_list)>1: + for i, feature in enumerate(train_feature_list[1:]): + concatenated_features = np.concatenate((concatenated_features, feature), axis=1) + train_feature=concatenated_features + + + print(len(test_feature_list)) + concatenated_features = test_feature_list[0] + if len(test_feature_list)>1: + for feature in test_feature_list[1:]: + concatenated_features = np.concatenate((concatenated_features, feature), axis=1) + test_feature=concatenated_features + + + print("Train_feature shape", train_feature.shape) + print("Test_feature shape", test_feature.shape) + + + evaluation={} + evaluation["method"] = args.method + evaluation["text_field"] = args.text_field + evaluation["input_file"] = args.input_file + evaluation_results = train_and_evaluate( + train_feature, + test_feature, + train["label"], + test["label"], + ) + evaluation.update(evaluation_results) + + print(evaluation) + + # # TODO: save eval results + os.makedirs(args.output_dir, exist_ok=True) + if args.no_balance_label: + output_file = os.path.join(args.output_dir, f"results_{num_ppl}_unbalanced.jsonl") + else: + output_file = os.path.join(args.output_dir, f"results_{num_ppl}_balanced.jonsl") + # Read existing data + existing_data = [] + if os.path.exists(output_file): + with open(output_file, 'r') as file: + for line in file: + existing_data.append(json.loads(line)) + existing_data.append(evaluation) + + # Write all data back to the file + with open(output_file, 'w') as file: + for item in existing_data: + json.dump(item, file) + file.write('\n') + + + +if __name__ == "__main__": + main() diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/anthropic.py b/models/anthropic.py new file mode 100644 index 0000000..aa13f69 --- /dev/null +++ b/models/anthropic.py @@ -0,0 +1,60 @@ +import time +import tiktoken +from copy import deepcopy + +from .base import LLMBase +import anthropic + + +class Claude(LLMBase): + def __init__( + self, + model_path=None, + max_attempts=100, + max_tokens=2048, + temperature=0, + ): + self.client = anthropic.Anthropic() + self.max_attempts = max_attempts + self.delay_seconds = 1 + + self.model = model_path.replace("anthropic/", "") + + + self.parameters = {"max_tokens": max_tokens, "temperature": temperature} + + self.num_tokens = 0 + + def query(self, prompt): + pred = self.chat_query(prompt) + return pred + + def chat_query(self, prompt, messages=None): + + n_attempt = 0 + params = deepcopy(self.parameters) + + if messages is None: + messages = [{"role": "user", "content": prompt}] + + print("messages", messages) + while n_attempt < self.max_attempts: + try: + completion = self.client.messages.create( + model=self.model, + messages=messages, + **params, + ) + response = completion.content[0].text + return response + except Exception as e: + # Catch any exception that might occur and print an error message + print(f"An error occurred: {e}, retry {n_attempt}") + n_attempt += 1 + time.sleep(self.delay_seconds * n_attempt) + + if n_attempt == self.max_attempts: + print("Max number of attempts reached") + return "" + + return "" diff --git a/models/base.py b/models/base.py new file mode 100644 index 0000000..58561a3 --- /dev/null +++ b/models/base.py @@ -0,0 +1,33 @@ +class LLMBase: + def __init__(self, model_path=None, api_key=None): + """ + Initialize a Large Language Model (LLM). + + Parameters: + + - model_path (str): The file path or URL to the model. Default is None. + - api_key (str): The API key for querying closed-source models. Default is None. + + """ + + self.model_path = model_path # file path or URL that points to the model + self.api_key = api_key # API key for accessing LLMs (e.g., ChatGPT) + self.num_tokens=0 + self.load_model() + + def load_model(self): + pass + + def query(self, text): + """ + Query a model with a given text prompt. + + Parameters: + - text (str): The text prompt to query the model. + + Returns: + - str: The model's output. + """ + pass + + diff --git a/models/hf.py b/models/hf.py new file mode 100644 index 0000000..e5fefc3 --- /dev/null +++ b/models/hf.py @@ -0,0 +1,119 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer +import torch +import numpy as np +from .base import LLMBase +from vllm import LLM, SamplingParams +import os, json + + +class CasualLM(LLMBase): + """Huggingface Casual Language Models. + + Parameters: + - model_path (str): The path/name for the desired language model. + - arch (str, optional): The model architecture if different from model_path. + - use_vllm (bool): Whether to use vLLM for inference. + - max_tokens (int): Maximum number of tokens to generate. + """ + + def __init__( + self, + model_path=None, + arch=None, + use_vllm=False, + max_tokens=2048, + ): + self.arch = arch if arch is not None else model_path + self.tokenizer_use_fast = True + self.max_tokens = max_tokens + self.use_vllm=use_vllm + super().__init__(model_path=model_path) + + def load_model(self, model_path=None): + if model_path is None: + model_path = self.model_path + if self.use_vllm: + self.model = LLM( + model=model_path, + tokenizer=model_path, + gpu_memory_utilization=0.9, + ) + + self.tokenizer = AutoTokenizer.from_pretrained(self.arch) + else: + + torch_dtype = torch.bfloat16 + model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch_dtype, + load_in_8bit=False, + low_cpu_mem_usage=True, + device_map="auto", + ).eval() + + tokenizer = AutoTokenizer.from_pretrained(self.arch) + tokenizer.padding_side = "left" + tokenizer.pad_token = tokenizer.eos_token + model.generation_config.pad_token_id = model.generation_config.eos_token_id + + self.model = model + self.tokenizer = tokenizer + + print( + f"> Loading the provided {self.arch} checkpoint from '{model_path}'." + ) + + def query(self, prompt): + return self.query_generation(prompt) + + @torch.no_grad() + def query_generation(self, prompt): + try: + if self.use_vllm: + sampling_params = SamplingParams(max_tokens=self.max_tokens) + outputs = self.model.generate( + [prompt], sampling_params, + ) + pred = outputs[0].outputs[0].text + else: + if self.model_path in [ + "deepseek-ai/deepseek-math-7b-instruct", + "AI-MO/NuminaMath-7B-CoT", + "microsoft/Phi-3-mini-4k-instruct", + "microsoft/Phi-3-medium-4k-instruct", + ]: + messages = [{"role": "user", "content": prompt}] + print(messages) + input_tensor = self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, return_tensors="pt" + ) + outputs = self.model.generate( + input_tensor.to(self.model.device), + max_new_tokens=self.max_tokens, + ) + pred = self.tokenizer.decode( + outputs[0][input_tensor.shape[1] :], skip_special_tokens=True + ) + else: + model_inputs = self.tokenizer(prompt, return_tensors="pt").to( + self.model.device + ) + generated_ids = self.model.generate( + **model_inputs, max_new_tokens=self.max_tokens + ) + pred = self.tokenizer.batch_decode( + generated_ids[:, model_inputs["input_ids"].shape[1] :], + skip_special_tokens=True, + clean_up_tokenization_spaces=True, + )[0] + except Exception as e: + print(e) + pred = "" + return pred + + + +if __name__ == "__main__": + model = CasualLM("deepseek-ai/deepseek-math-7b-instruct") + print(model.query("what is your name?")) + print("DONE") diff --git a/models/openai.py b/models/openai.py new file mode 100644 index 0000000..d218c6e --- /dev/null +++ b/models/openai.py @@ -0,0 +1,55 @@ +import time +import tiktoken +from copy import deepcopy +from .base import LLMBase +import openai + + +class ChatGPT(LLMBase): + def __init__( + self, + model_path=None, + max_attempts=100, + max_tokens=2048, + temperature=0, + ): + + self.client = openai.Client() + self.max_attempts = max_attempts + self.delay_seconds = 1 + self.model = model_path.replace("openai/", "") + self.parameters = {"max_tokens": max_tokens, "temperature": temperature} + self.num_tokens = 0 + + def query(self, prompt): + pred = self.chat_query(prompt) + + return pred + + def chat_query(self, prompt, messages=None): + + n_attempt = 0 + params = deepcopy(self.parameters) + + if messages is None: + messages = [{"role": "user", "content": prompt}] + + print("messages", messages) + while n_attempt < self.max_attempts: + try: + completion = self.client.chat.completions.create( + model=self.model, messages=messages, **params + ) + response = completion.choices[0].message.content + return response + except Exception as e: + # Catch any exception that might occur and print an error message + print(f"An error occurred: {e}, retry {n_attempt}") + n_attempt += 1 + time.sleep(self.delay_seconds * n_attempt) + + if n_attempt == self.max_attempts: + print("Max number of attempts reached") + return "" + + return "" diff --git a/probe.py b/probe.py new file mode 100644 index 0000000..941da52 --- /dev/null +++ b/probe.py @@ -0,0 +1,211 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset +import tqdm +import argparse +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score +from peft import PeftModel +import json +import os +import random + + +def merge_adapter(base_model_path, adapter_path): + + print("Loading adapter...") + model = AutoModelForCausalLM.from_pretrained( + base_model_path, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + trust_remote_code=True, + ).cuda() + + if adapter_path != "": + tokenizer = AutoTokenizer.from_pretrained( + adapter_path, + trust_remote_code=True, + ) + + model.resize_token_embeddings(len(tokenizer)) + + model = PeftModel.from_pretrained(model, adapter_path) + model = model.merge_and_unload() + + return model + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Run LLaMA model activations on dataset" + ) + parser.add_argument( + "--base_model_path", + type=str, + default="", + ) + parser.add_argument( + "--adapter_path", + type=str, + default="", + ) + parser.add_argument( + "--nppl_eval", + type=int, + default=2, + help="# ppl task for probing", + ) + parser.add_argument( + "--output_file", + type=str, + default="probe_results", + help="Path to the output JSON file to save results", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + # Load the dataset + kk_dataset = load_dataset( + "json", + data_files={ + "test": [f"data/test/clean/people{args.nppl_eval}_num100.jsonl"], + }, + ) + + statement_wrong = [] + statement_correct = [] + for i in range(len(kk_dataset["test"])): + quiz = kk_dataset["test"]["quiz"][i] + names = kk_dataset["test"]['names'][i] + solutions = kk_dataset["test"]['solution'][i] + random_names = random.sample(names, 2) + + for name, is_knight in zip(names, solutions): + if name in random_names: + role = 'knight' if is_knight else 'knave' + wrong_role = 'knave' if is_knight else 'knight' + statement_correct.append(f'{quiz} {name} is {role}.') + statement_wrong.append(f'{quiz} {name} is {wrong_role}.') + else: + continue + print(len(statement_correct)) + + # Load pre-trained LLaMA model and tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.base_model_path) + tokenizer.pad_token = tokenizer.eos_token + + model = merge_adapter(args.base_model_path, args.adapter_path) + + # Define a forward hook to capture MLP activations + mlp_activations = { + i: [] for i in range(len(model.model.layers)) + } # One list per layer + + def get_mlp_activation_hook(layer_idx): + def hook(module, input, output): + mlp_activations[layer_idx].append(output.detach().cpu().numpy()) + + return hook + + # Register hooks to all MLP layers in the transformer blocks + for i, layer in enumerate(model.model.layers): + layer.mlp.register_forward_hook(get_mlp_activation_hook(i)) + + dataset = {i: [] for i in range(len(model.model.layers))} + labels = {i: [] for i in range(len(model.model.layers))} + + # Function to process statements and capture activations + def process_statements(statements, label): + for text in tqdm.tqdm(statements): + input_ids = tokenizer( + text, + return_tensors="pt", + ).input_ids + for i in range(len(model.model.layers)): + mlp_activations[i] = [] # Reset activations for each layer + + # Run the model forward pass + with torch.no_grad(): + _ = model(input_ids.cuda()) + + # Store activations and corresponding labels + for i in range(len(model.model.layers)): + if mlp_activations[i]: # Check if activations were captured + dataset[i].append( + mlp_activations[i][0] + ) # Use the first batch output + labels[i].append(label) + + # Process correct and wrong statements + process_statements(statement_correct, 1) # Label 1 for correct + process_statements(statement_wrong, 0) # Label 0 for wrong + + # Train classifiers for each layer's activations + classifiers = [] + accuracy_per_layer_train = [] # To store train accuracy + accuracy_per_layer_test = [] # To store test accuracy + + results = {} # Dictionary to store accuracy results + + # pdb.set_trace() + + # Splitting the data for each layer and training a classifier + for i in tqdm.tqdm(range(len(model.model.layers))): + X_layer = dataset[i] + y_layer = labels[i] + # Flatten the activations for the classifier + # X_layer = [x.flatten() for x in X_layer] + X_layer = [x.sum(axis=(0, 1)) for x in X_layer] + + # Split data into train and test sets + X_train, X_test, y_train, y_test = train_test_split( + X_layer, y_layer, test_size=0.2, random_state=42 + ) + + # import pdb + # pdb.set_trace() + # Initialize and train a simple logistic regression classifier + clf = LogisticRegression(max_iter=1000) + clf.fit(X_train, y_train) + + # Report train accuracy + y_train_pred = clf.predict(X_train) + train_accuracy = accuracy_score(y_train, y_train_pred) + accuracy_per_layer_train.append(train_accuracy) + + # Report test accuracy + y_test_pred = clf.predict(X_test) + test_accuracy = accuracy_score(y_test, y_test_pred) + accuracy_per_layer_test.append(test_accuracy) + + classifiers.append(clf) # Save the classifier + # Store results for this layer + results[f"layer_{i}"] = { + "train_accuracy": train_accuracy, + "test_accuracy": test_accuracy, + } + + print(f"Layer {i} prober accuracy: {train_accuracy:.4f}") + + + # Save results to JSON + if args.adapter_path != "": + fname = ( + "-".join(args.adapter_path.split("/")[1:-1]) + .replace("_total_10ep", "") + .replace("_total_100ep", "") + ) + else: + fname = "base" + with open(os.path.join(args.output_file, f"nppl{args.nppl_eval}-{fname}.json"), "w") as f: + json.dump(results, f, indent=4) + + print(f"Results saved to {args.output_file}") + + +if __name__ == "__main__": + main() diff --git a/scripts/eval/claude-sonet.sh b/scripts/eval/claude-sonet.sh new file mode 100644 index 0000000..24fcd06 --- /dev/null +++ b/scripts/eval/claude-sonet.sh @@ -0,0 +1,21 @@ +model="anthropic/claude-3-5-sonnet-20240620" + + +config="vllm" + +max_token=2048 +ntrain=0 +num_limit=100 + +echo "Processing num_limit: $num_limit" +for eval_nppl in 2 3 4 5 6 7 8; +do + echo "Processing ntrain: $ntrain" + + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split "test" --eval_nppl ${eval_nppl} + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split "test" --problem_type "perturbed_statement" --eval_nppl ${eval_nppl} + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split "test" --problem_type "perturbed_leaf" --eval_nppl ${eval_nppl} + # python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --cot --limit ${num_limit} --split "test" --eval_nppl ${eval_nppl} +done + + diff --git a/scripts/eval/eval_test.sh b/scripts/eval/eval_test.sh new file mode 100644 index 0000000..ac8c30c --- /dev/null +++ b/scripts/eval/eval_test.sh @@ -0,0 +1,31 @@ +config="vllm" +max_token=2048 +num_limit=100 + +echo "Processing num_limit: $num_limit" + + +# no \n after `### Answer:` in the prompt (--no_linebreak) for Meta-Llama-3-8B base model +model="meta-llama/Meta-Llama-3-8B" +python eval_kk.py --no_linebreak --batch_size 8 --model ${model} --max_token ${max_token} --arch ${model} --ntrain 0 --config ${config} --use_vllm --cot --limit ${num_limit} --split "test" +python eval_kk.py --no_linebreak --batch_size 8 --model ${model} --max_token ${max_token} --arch ${model} --ntrain 0 --config ${config} --use_vllm --limit ${num_limit} --split "test" +python eval_kk.py --no_linebreak --batch_size 8 --model ${model} --max_token ${max_token} --arch ${model} --ntrain 1 --config ${config} --use_vllm --limit ${num_limit} --split "test" +python eval_kk.py --no_linebreak--batch_size 8 --model ${model} --max_token ${max_token} --arch ${model} --ntrain 1 --config ${config} --use_vllm --cot --limit ${num_limit} --split "test" + + +models=( +"AI-MO/NuminaMath-7B-CoT" +"deepseek-ai/deepseek-math-7b-instruct" +"microsoft/Phi-3-medium-4k-instruct" +"microsoft/Phi-3-mini-4k-instruct" +) +# Iterate over each model +for model in "${models[@]}"; +do + echo "Processing model: $model" + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --arch ${model} --ntrain 0 --config ${config} --use_vllm --cot --limit ${num_limit} --split "test" + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --arch ${model} --ntrain 0 --config ${config} --use_vllm --limit ${num_limit} --split "test" + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --arch ${model} --ntrain 1 --config ${config} --use_vllm --limit ${num_limit} --split "test" + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --arch ${model} --ntrain 1 --config ${config} --use_vllm --cot --limit ${num_limit} --split "test" + +done diff --git a/scripts/eval/eval_test_perturb.sh b/scripts/eval/eval_test_perturb.sh new file mode 100644 index 0000000..c6e3b8c --- /dev/null +++ b/scripts/eval/eval_test_perturb.sh @@ -0,0 +1,28 @@ +config="vllm" +max_token=2048 +num_limit=100 + +echo "Processing num_limit: $num_limit" + +# no \n after `### Answer:` in the prompt (--no_linebreak) for Meta-Llama-3-8B base model +model="meta-llama/Meta-Llama-3-8B" +python eval_kk.py --no_linebreak --batch_size 8 --model ${model} --max_token ${max_token} --arch ${model} --ntrain 0 --config ${config} --use_vllm --limit ${num_limit} --split "test" --problem_type "perturbed_leaf" +python eval_kk.py --no_linebreak --batch_size 8 --model ${model} --max_token ${max_token} --arch ${model} --ntrain 0 --config ${config} --use_vllm --limit ${num_limit} --split "test" --problem_type "perturbed_statement" + + + +models=( +"AI-MO/NuminaMath-7B-CoT" +"deepseek-ai/deepseek-math-7b-instruct" +"microsoft/Phi-3-medium-4k-instruct" +"microsoft/Phi-3-mini-4k-instruct" +) +# Iterate over each model +for model in "${models[@]}"; +do + echo "Processing model: $model" + + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --arch ${model} --ntrain 0 --config ${config} --use_vllm --limit ${num_limit} --split "test" --problem_type "perturbed_leaf" + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --arch ${model} --ntrain 0 --config ${config} --use_vllm --limit ${num_limit} --split "test" --problem_type "perturbed_statement" + +done diff --git a/scripts/eval/eval_train.sh b/scripts/eval/eval_train.sh new file mode 100644 index 0000000..1c9b519 --- /dev/null +++ b/scripts/eval/eval_train.sh @@ -0,0 +1,23 @@ +config="vllm" +arch="meta-llama/Meta-Llama-3-8B" + +max_token=2048 + +models=( + # add "YOUR_FINETUNED_MODEL_PATH" + "ftllama/3ppl-direct-FT-50ep" +) +num_limit=100 # remove --num_limit if you want to evaluate on full dataset + +for eval_nppl in 2 3 4 5 6 7 8; +do + echo "Processing eval_nppl: $eval_nppl" + + for model in "${models[@]}"; + do + echo "Processing model: $model" + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --arch ${model} --ntrain 0 --config ${config} --use_vllm --limit ${num_limit} --split "train" --problem_type "clean" --eval_nppl ${eval_nppl} + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --arch ${model} --ntrain 0 --config ${config} --use_vllm --cot --limit ${num_limit} --split "train" --problem_type "clean" --eval_nppl ${eval_nppl} + done +done + diff --git a/scripts/eval/eval_train_pertub.sh b/scripts/eval/eval_train_pertub.sh new file mode 100644 index 0000000..b9c28ce --- /dev/null +++ b/scripts/eval/eval_train_pertub.sh @@ -0,0 +1,28 @@ +config="vllm" +max_token=2048 +arch="meta-llama/Meta-Llama-3-8B" +models=( + # add "YOUR_FINETUNED_MODEL_PATH" + "ftllama/3ppl-direct-FT-50ep" +) + +num_limit=100 # remove --num_limit if you want to evaluate on full dataset + +for eval_nppl in 2 3 4 5 6 7 8; +do + echo "Processing eval_nppl: $eval_nppl" + + for model in "${models[@]}"; + do + echo "Processing model: $model" + + for problem_type in "perturbed_statement" "perturbed_leaf" "random_pair" "reorder_statement" "uncommon_name" "flip_role"; + do + + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --arch ${arch} --ntrain 0 --config ${config} --use_vllm --limit ${num_limit} --split "train" --problem_type ${problem_type} --eval_nppl ${eval_nppl} + + + done + done +done + diff --git a/scripts/eval/gpt4omini_cot.sh b/scripts/eval/gpt4omini_cot.sh new file mode 100644 index 0000000..38a69cf --- /dev/null +++ b/scripts/eval/gpt4omini_cot.sh @@ -0,0 +1,29 @@ +model="openai/gpt-4o-mini-2024-07-18" +# model="YOUT-COT-FTED-MODEL-PATH" + + +config="vllm" +num_limit=100 + + +max_token=2048 +ntrain=0 + +for eval_nppl in 2 3 4 5 6 7 8; +do + echo "Processing eval_nppl: $eval_nppl" + for split in "train" "test"; + do + echo "Processing split: $split" + + python eval_kk.py --cot --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split ${split} --problem_type "clean" --eval_nppl ${eval_nppl} + python eval_kk.py --cot --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split ${split} --problem_type "random_pair" --eval_nppl ${eval_nppl} + python eval_kk.py --cot --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split ${split} --problem_type "reorder_statement" --eval_nppl ${eval_nppl} + python eval_kk.py --cot --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split ${split} --problem_type "uncommon_name" --eval_nppl ${eval_nppl} + python eval_kk.py --cot --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split ${split} --problem_type "flip_role" --eval_nppl ${eval_nppl} + python eval_kk.py --cot --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split ${split} --problem_type "perturbed_statement" --eval_nppl ${eval_nppl} + python eval_kk.py --cot --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split ${split} --problem_type "perturbed_leaf" --eval_nppl ${eval_nppl} + done +done + + diff --git a/scripts/eval/gpt4omini_direct.sh b/scripts/eval/gpt4omini_direct.sh new file mode 100644 index 0000000..3294e0b --- /dev/null +++ b/scripts/eval/gpt4omini_direct.sh @@ -0,0 +1,28 @@ +model="openai/gpt-4o-mini-2024-07-18" + + +config="vllm" +num_limit=100 + + +max_token=2048 +ntrain=0 + +for eval_nppl in 2 3 4 5 6 7 8; +do + echo "Processing eval_nppl: $eval_nppl" + for split in "train" "test"; + do + echo "Processing split: $split" + + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split ${split} --problem_type "clean" --eval_nppl ${eval_nppl} + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split ${split} --problem_type "random_pair" --eval_nppl ${eval_nppl} + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split ${split} --problem_type "reorder_statement" --eval_nppl ${eval_nppl} + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split ${split} --problem_type "uncommon_name" --eval_nppl ${eval_nppl} + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split ${split} --problem_type "flip_role" --eval_nppl ${eval_nppl} + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split ${split} --problem_type "perturbed_statement" --eval_nppl ${eval_nppl} + python eval_kk.py --batch_size 8 --model ${model} --max_token ${max_token} --ntrain ${ntrain} --config ${config} --limit ${num_limit} --split ${split} --problem_type "perturbed_leaf" --eval_nppl ${eval_nppl} + done +done + + diff --git a/scripts/ft/ft_lm3.sh b/scripts/ft/ft_lm3.sh new file mode 100644 index 0000000..94df38b --- /dev/null +++ b/scripts/ft/ft_lm3.sh @@ -0,0 +1,10 @@ +python finetune_kk.py \ + --train_data "train/people3_num1000.jsonl" \ + --test_data "test/people3_num100.jsonl" \ + --run_name kk_ft_ppl3 \ + --output_dir ./result/out_nocot/train3 \ + --num_train_epochs 50 \ + --save_strategy steps \ + --save_steps 0.2 \ + --max_seq_length 512 \ + --eval_steps 5 \ diff --git a/scripts/ft/ft_lm3_cot.sh b/scripts/ft/ft_lm3_cot.sh new file mode 100644 index 0000000..19ce57e --- /dev/null +++ b/scripts/ft/ft_lm3_cot.sh @@ -0,0 +1,11 @@ +python finetune_kk.py \ + --train_data "train/people3_num1000.jsonl" \ + --test_data "test/people3_num100.jsonl" \ + --run_name kk_ft_ppl3_cot \ + --output_dir ./result/out_cot/train3 \ + --cot_ft \ + --num_train_epochs 50 \ + --save_strategy steps \ + --save_steps 0.2 \ + --max_seq_length 512 \ + --eval_steps 5 \ diff --git a/scripts/ft/merge_adapter.sh b/scripts/ft/merge_adapter.sh new file mode 100644 index 0000000..2546825 --- /dev/null +++ b/scripts/ft/merge_adapter.sh @@ -0,0 +1,4 @@ +base_model_path="meta-llama/Meta-Llama-3-8B" # Base model path +adapter_path="" # Adapter path from fine-tuning +target_model_path="" # Merged model save path +python merge.py --base_model_path $base_model_path --adapter_path $adapter_path --target_model_path $target_model_path diff --git a/scripts/mem_classify/model_indicator.sh b/scripts/mem_classify/model_indicator.sh new file mode 100644 index 0000000..30d2f36 --- /dev/null +++ b/scripts/mem_classify/model_indicator.sh @@ -0,0 +1,19 @@ +nppl=5 +epoch=50 + +for perturb_type in "leaf" "statement" +do + for nppl in 3 5 + do + ft_model_path="ftllama/${nppl}ppl-direct-FT-${epoch}ep" + base_model_path="meta-llama/Meta-Llama-3-8B" + for model_path in ${base_model_path} ${ft_model_path} + do + data_file="result_qa/ftllama/${nppl}ppl-direct-FT-${epoch}ep_0shot/vllm_token2048_train_perturbed_${perturb_type}/people${nppl}_num1000_classify.jsonl" + echo $model_path + echo $data_file + python mem_cls_model.py --base_model_path ${model_path} --data_file ${data_file} + done + done +done + diff --git a/scripts/mem_classify/puzzle_indicator.sh b/scripts/mem_classify/puzzle_indicator.sh new file mode 100644 index 0000000..f727916 --- /dev/null +++ b/scripts/mem_classify/puzzle_indicator.sh @@ -0,0 +1,19 @@ + +input_files=( +# process data that contains the labels of Consistenly Solved v.s. non Consistenly Solved Puzzles +"result_qa/ftopenai/ppl3-1000-cot-repeat-5ep_0shot/vllm_token2048_cot_train_perturbed_leaf/people3_num1000_classify.jsonl" + +) +for input_file in "${input_files[@]}"; + do + echo $input_file + for text_field in "all_fields" "quiz" "response" "cot_repeat_steps" + do + for method in "combine" "tfidf" "bow" "wordlength" "charlength" + do + python mem_cls_puzzle.py --method ${method} --text_field ${text_field} --input_file ${input_file} --no_balance_label + done + done + done +done + diff --git a/scripts/probe/run.sh b/scripts/probe/run.sh new file mode 100644 index 0000000..98599b1 --- /dev/null +++ b/scripts/probe/run.sh @@ -0,0 +1,5 @@ +nppl_eval=3 +adapter_path="" +base_model_path="" + +python probe.py --base_model_path $base_model_path --adapter_path $adapter_path --nppl_eval $nppl_eval diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..615047f --- /dev/null +++ b/utils.py @@ -0,0 +1,82 @@ +import argparse +import json +import os +import numpy as np +import pandas as pd +import random +import torch +import time +import datasets + +def load_jsonl(file_path): + records = [] + with open(file_path, "r") as file: + for line in file: + records.append(json.loads(line)) + return records + +def write_jsonl(output_file, data): + + with open(output_file, "w") as file: + for item in data: + json_line = json.dumps(item) + file.write(json_line + "\n") + +def batch_decode_vllm(llm, prompts, batch_size=32): + """ + Perform batch decoding using vLLM. + + Args: + - llm: The vLLM model instance + - prompts: List of prompts to process + - batch_size: Number of prompts to process in each batch + + Returns: + - List of generated responses + """ + from vllm import SamplingParams # type: ignore + + all_responses = [] + for i in range(0, len(prompts), batch_size): + batch_prompts = prompts[i : i + batch_size] + sampling_params = SamplingParams(max_tokens=llm.max_tokens, temperature=0) + outputs = llm.model.generate( + batch_prompts, sampling_params + ) + responses = [output.outputs[0].text for output in outputs] + all_responses.extend(responses) + return all_responses + + +def init_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + +def load_llm(args): + if "openai" in args.model: + from models.openai import ChatGPT + llm = ChatGPT(model_path=args.model, max_tokens=args.max_token) + elif "anthropic" in args.model: + from models.anthropic import Claude + llm = Claude(model_path=args.model, max_tokens=args.max_token) + else: + from models.hf import CasualLM + llm = CasualLM( + model_path=args.model, + arch=args.arch, + use_vllm=args.use_vllm, + max_tokens=args.max_token, + ) + return llm + +def load_eval_records(args, subject): + if args.problem_type != "clean": + records = datasets.load_dataset('K-and-K/perturbed-knights-and-knaves',data_files=f"{args.split}/{args.problem_type}/{subject}.jsonl")["train"] + else: + records = datasets.load_dataset('K-and-K/knights-and-knaves',data_files=f"{args.split}/{subject}.jsonl")["train"] + return records \ No newline at end of file