diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
new file mode 100644
index 0000000..f57fca0
--- /dev/null
+++ b/.github/workflows/python-publish.yml
@@ -0,0 +1,41 @@
+# This workflow will upload a Python Package using Twine when a release is created
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries

+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
+name: Upload Python Package
+
+on:
+  release:
+    types: [published]
+
+permissions:
+  contents: read
+
+jobs:
+  deploy:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python
+      uses: actions/setup-python@v3
+      with:
+        python-version: '3.10'
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install build
+    - name: Install build dependencies
+      run: pip install numpy
+    - name: Build package
+      run: python -m build
+    - name: Publish package
+      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+      with:
+        user: __token__
+        password: ${{ secrets.PYPI_API_TOKEN }}
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 82f9275..9941588 100644
--- a/.gitignore
+++ b/.gitignore
@@ -106,10 +106,8 @@ ipython_config.py
 #pdm.lock
 # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
 # in version control.
-# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+# https://pdm.fming.dev/#use-with-ide
 .pdm.toml
-.pdm-python
-.pdm-build/
 
 # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
 __pypackages__/
@@ -159,4 +157,7 @@ cython_debug/
 #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 #  and can be added to the global gitignore or merged into this file. For a more nuclear
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
+
+# data directory
+data/
\ No newline at end of file
diff --git a/Evaluate_Description-Embedding_Body.ipynb b/Evaluate_Description-Embedding_Body.ipynb
new file mode 100644
index 0000000..1da16cb
--- /dev/null
+++ b/Evaluate_Description-Embedding_Body.ipynb
@@ -0,0 +1,475 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ed004c90",
+   "metadata": {
+    "ExecuteTime": {
+     "start_time": "2024-01-17T12:55:35.436360Z"
+    },
+    "is_executing": true
+   },
+   "outputs": [],
+   "source": [
+    "# This script seeks a better alternative to the current labels used in the FusionModelBody.label_embedding_model_body.\n",
+    "# For this purpose, it evaluates embeddings of class descriptions as an alternative to the currently implemented default, which embeds label texts.\n",
+    "\n",
+    "from fusionsent import FusionSentModel, Trainer, TrainingArguments\n",
+    "from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer\n",
+    "from datasets import load_dataset, Dataset\n",
+    "from transformers import AutoTokenizer\n",
+    "import numpy as np\n",
+    "import openai  # Please note that openai is not listed in our requirements.txt file. Run 'pip install openai' to install the package.\n",
+    "import torch\n",
+    "import json\n",
+    "import os"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "9ef91033",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Setting environment variables\n",
+    "cwd = os.path.abspath(os.getcwd())\n",
+    "os.environ['WORLD_SIZE'] = str(torch.cuda.device_count())\n",
+    "os.environ['MASTER_ADDR'] = 'localhost'\n",
+    "os.environ['MASTER_PORT'] = '29500'"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3351d25b",
+   "metadata": {},
+   "source": [
+    "# Load and Prepare All Datasets"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "cb6f76dd",
+   "metadata": {},
+   "source": [
+    "*1. Download original data.*"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "16a4ced9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# The below are the exact datasets used for training in the original SetFit paper.\n",
+    "# If they do not already exist, we load them all and store them locally in order to add label descriptions.\n",
+    "dataset_ids_binary_label: list[str] = [\"CR\", \"emotion\", \"enron_spam\"]\n",
+    "dataset_ids_nonbinary_label: list[str] = [\"sst5\", \"amazon_counterfactual\", \"emotion\", \"ag_news\"]\n",
+    "dataset_ids = dataset_ids_binary_label + dataset_ids_nonbinary_label\n",
+    "data_dir_original = \"./data/original\"\n",
+    "datasets_original = {}\n",
+    "\n",
+    "for dataset_id in dataset_ids:\n",
+    "    print(f\"Loading dataset: '{dataset_id}'\")\n",
+    "    datasets_original[dataset_id] = {}\n",
+    "    for split in [\"train\", \"test\"]:\n",
+    "        try:\n",
+    "            dataset_split = load_dataset(f\"SetFit/{dataset_id}\", split=split)\n",
+    "            datasets_original[dataset_id][split] = dataset_split\n",
+    "        except ValueError as e:\n",
+    "            print(f\"Could not load dataset '{dataset_id}'. An error occurred: {e}\")\n",
+    "            datasets_original.pop(dataset_id)\n",
+    "            break\n",
+    "print(\"-- Done --\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a0560b42",
+   "metadata": {},
+   "source": [
+    "*2. Generate label descriptions via OpenAI and save them to files.*"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c2ff0ad8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# TODO: Fix generation for the 'enron_spam' and 'ag_news' datasets.\n",
+    "data_dir_label_descriptions = \"./data/label_descriptions\"\n",
+    "label_description_file_template = \"{}_label_descriptions.json\"\n",
+    "os.makedirs(data_dir_label_descriptions, exist_ok=True)\n",
+    "\n",
+    "openai_api_key = \"your-openai-key\"\n",
+    "open_ai_model = \"gpt-4-0125-preview\"\n",
+    "regenerate = False\n",
+    "\n",
+    "def get_label_description(dataset_name: str, label: str, label_text: str, examples: list[str]) -> str:\n",
+    "    try:\n",
+    "        client = openai.OpenAI(api_key=openai_api_key)\n",
+    "        completion = client.chat.completions.create(\n",
+    "            model=open_ai_model,\n",
+    "            messages= [\n",
+    "                {\n",
+    "                    \"role\": \"system\", \n",
+    "                    \"content\": \"\"\"\n",
+    "                        You are a scientific research assistant in the area of Natural Language Processing.\n",
+    "                        Your purpose is to write comprehensive, concise, and short descriptions for a given label of a dataset.\n",
+    "                        For each label, you will be provided some examples of data samples that are annotated with the respective label.\n",
+    "                        Rules:\n",
+    "                        1. Be concise in your descriptions.\n",
+    "                        2. Each description should be exactly one sentence long.\n",
+    "                        Not complying with the rules will result in termination.\n",
+    "                    \"\"\"\n",
+    "                },\n",
+    "                {\n",
+    "                    \"role\": \"user\",\n",
+    "                    \"content\": f\"\"\"\n",
+    "                        Dataset name: '{dataset_name}'\\n\n",
+    "                        Label key: '{label}'\\n\n",
+    "                        Label name: '{label_text}'\\n\n",
+    "                        ---\\n\\n\n",
+    "                        Example Samples annotated with '{label_text}':\\n\\n\n",
+    "                        {examples}\\n\\n\n",
+    "                        ---\\n\\n\n",
+    "                        Please describe the essence of the label '{label}': '{label_text}' in one sentence:\n",
+    "                    \"\"\"\n",
+    "                }\n",
+    "            ]\n",
+    "        )\n",
+    "        if completion.choices and completion.choices[0].message and completion.choices[0].message.content:\n",
+    "            response = completion.choices[0].message.content\n",
+    "            print(f\"Obtained description for {dataset_name}/{label_text}: {response}\")\n",
+    "            return response\n",
+    "        else:\n",
+    "            raise Exception(\"Invalid response from OpenAI: No content in the response.\")\n",
+    "    except Exception as e:\n",
+    "        raise Exception(f\"Unexpected error with the response from OpenAI: {str(e)}\")\n",
+    "\n",
+    "for dataset_id in dataset_ids:\n",
+    "    description_file_path = os.path.join(data_dir_label_descriptions, label_description_file_template.format(dataset_id))\n",
+    "    if (not regenerate) and os.path.exists(description_file_path):\n",
+    "        print(f\"Skipped label generation for '{dataset_id}' dataset (File already exists).\")\n",
+    "        continue\n",
+    "    # Samples from SetFit/enron_spam are too large.\n",
+    "    if dataset_id == \"enron_spam\":\n",
+    "        continue\n",
+    "\n",
+    "    # Process the dataset to get a label-to-data mapping\n",
+    "    label_to_data = {}\n",
+    "    label_to_label_text = {}\n",
+    "    for item in datasets_original[dataset_id][\"train\"]:\n",
+    "        label = item['label']\n",
+    "        text = item['text']\n",
+    "        if label not in label_to_data:\n",
+    "            label_to_data[label] = []\n",
+    "        if label not in label_to_label_text:\n",
+    "            label_to_label_text[label] = item[\"label_text\"]\n",
+    "        label_to_data[label].append(text)\n",
+    "\n",
+    "    # Sample at most 5 examples per label (because of OpenAI token rate limits) and generate label descriptions\n",
+    "    label_to_description = {}\n",
+    "    hasEncounteredError = False\n",
+    "    for label, examples in label_to_data.items():\n",
+    "        sampled_examples: list[str] = np.random.choice(examples, size=5, replace=False).tolist()\n",
+    "        while sum([len(t) for t in sampled_examples]) > 100:\n",
+    "            sampled_examples = sampled_examples[:-1]\n",
+    "            #print(sum([len(t) for t in sampled_examples]))\n",
+    "        try:\n",
+    "            description = get_label_description(dataset_id, label, label_to_label_text[label], sampled_examples)\n",
+    "        except Exception:\n",
+    "            hasEncounteredError = True\n",
+    "            break\n",
+    "        label_to_description[label] = description\n",
+    "\n",
+    "    if hasEncounteredError:\n",
+    "        print(f\"An error occurred during label description generation for dataset '{dataset_id}'. Skipping...\")\n",
+    "        continue\n",
+    "\n",
+    "    # Save the label-to-description mappings\n",
+    "    with open(description_file_path, 'w') as f:\n",
+    "        json.dump(label_to_description, f, indent=2, ensure_ascii=False)\n",
+    "    \n",
+    "    print(f\"Saved label descriptions for '{dataset_id}' dataset.\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e10bc9d3",
+   "metadata": {},
+   "source": [
+    "*3. Format the datasets so they can be passed into the FusionSent model.*"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "id": "6519f3c6",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "1\n",
+      "[\"The label '1', denoted as 'positive', applies to data samples expressing favorable, satisfactory, or beneficial opinions, experiences, or outcomes.\"]\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "1\n",
+      "[\"The label '1', denoted as 'positive', applies to data samples expressing favorable, satisfactory, or beneficial opinions, experiences, or outcomes.\"]\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "1\n",
+      "['positive']\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "1\n",
+      "['positive']\n",
+      "Successfully formatted dataset 'CR'.\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "[1 0 0 0 0 0]\n",
+      "[\"The essence of the label '0': 'sadness' is characterized by feelings of hopelessness, disappointment, melancholy, and vulnerability, often accompanied by a sense of isolation or being overwhelmed.\"]\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "[1 0 0 0 0 0]\n",
+      "[\"The essence of the label '0': 'sadness' is characterized by feelings of hopelessness, disappointment, melancholy, and vulnerability, often accompanied by a sense of isolation or being overwhelmed.\"]\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "[1 0 0 0 0 0]\n",
+      "['sadness']\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "[1 0 0 0 0 0]\n",
+      "['sadness']\n",
+      "Successfully formatted dataset 'emotion'.\n",
+      "Skipping formatting dataset 'enron_spam': Description file not found.\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "1\n",
+      "[\"The label 'joy' encompasses examples demonstrating feelings of happiness, satisfaction, gladness, or positive emotional states experienced by individuals.\"]\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "1\n",
+      "[\"The label 'joy' encompasses examples demonstrating feelings of happiness, satisfaction, gladness, or positive emotional states experienced by individuals.\"]\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "1\n",
+      "['spam']\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "1\n",
+      "['spam']\n",
+      "Successfully formatted dataset 'enron_spam'.\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "[0 0 0 0 1]\n",
+      "[\"The label '4': 'very positive' is used for data samples that express strong or intense positive sentiments, enthusiasm, or approval.\"]\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "[0 1 0 0 0]\n",
+      "[\"The label '1', 'negative', is used for reviews or comments that express dissatisfaction, disapproval, or disappointment regarding a subject.\"]\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "[0 0 0 0 1]\n",
+      "['very positive']\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "[0 1 0 0 0]\n",
+      "['negative']\n",
+      "Successfully formatted dataset 'sst5'.\n",
+      "Skipping formatting dataset 'amazon_counterfactual': Key 'train' and/or 'test' not found.\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "[1 0 0 0 0 0]\n",
+      "[\"The essence of the label '0': 'sadness' is characterized by feelings of hopelessness, disappointment, melancholy, and vulnerability, often accompanied by a sense of isolation or being overwhelmed.\"]\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "[1 0 0 0 0 0]\n",
+      "[\"The essence of the label '0': 'sadness' is characterized by feelings of hopelessness, disappointment, melancholy, and vulnerability, often accompanied by a sense of isolation or being overwhelmed.\"]\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "[1 0 0 0 0 0]\n",
+      "['sadness']\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "[1 0 0 0 0 0]\n",
+      "['sadness']\n",
+      "Successfully formatted dataset 'emotion'.\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "[0 0 1 0]\n",
+      "[\"The label '2': 'Business' encompasses news and information related to commerce, trade, financial markets, companies, and economic trends.\"]\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "[0 0 1 0]\n",
+      "[\"The label '2': 'Business' encompasses news and information related to commerce, trade, financial markets, companies, and economic trends.\"]\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "[0 0 1 0]\n",
+      "['Business']\n",
+      "Warning: Limiting dataset size to 250 elements for testing!\n",
+      "[0 0 1 0]\n",
+      "['Business']\n",
+      "Successfully formatted dataset 'ag_news'.\n"
+     ]
+    }
+   ],
+   "source": [
+    "formatted_datasets = {}\n",
+    "def format_dataset(original_dataset, label_to_description=None) -> Dataset:\n",
+    "    \"\"\"\n",
+    "    Creates a Dataset object with label encoding and optional label descriptions.\n",
+    "    \"\"\"\n",
+    "    input_texts = [d['text'] for d in original_dataset]\n",
+    "    raw_labels = [d['label'] for d in original_dataset]\n",
+    "\n",
+    "    # Check if labels are binary (single value) or multi-class (list of labels)\n",
+    "    if all(raw_label in [0,1] and not isinstance(raw_label, list) for raw_label in raw_labels):\n",
+    "        # Binary case\n",
+    "        label_encoder = LabelEncoder()\n",
+    "        labels = label_encoder.fit_transform(raw_labels)\n",
+    "    else:\n",
+    "        # Multi-class / multi-label case\n",
+    "        label_encoder = MultiLabelBinarizer()\n",
+    "        labels = label_encoder.fit_transform([[raw_label] for raw_label in raw_labels])\n",
+    "\n",
+    "    # Either select the label text or the label description as the 'label_description' text\n",
+    "    if label_to_description is None:\n",
+    "        label_descriptions = [[d['label_text']] for d in original_dataset]\n",
+    "    else:\n",
+    "        label_descriptions = [[label_to_description[str(d['label'])]] for d in original_dataset]\n",
+    "\n",
+    "    # Limit to 250 elements for testing.\n",
+    "    # TODO: Deal with error in setfit.\n",
+    "    # Error occurs in setfit.sampler, line 29: 'idxs = np.stack(np.triu_indices(n, k), axis=-1)',\n",
+    "    # with n being the sample size, and k=1 if sampled with replacement, 0 otherwise.\n",
+    "    # Reason: out-of-memory; the latest numpy and setfit versions do not fix this.\n",
+    "    input_texts = input_texts[:250]\n",
+    "    labels = labels[:250]\n",
+    "    label_descriptions = label_descriptions[:250]\n",
+    "    print(\"Warning: Limiting dataset size to 250 elements for testing!\")\n",
+    "    print(labels[0])\n",
+    "    print(label_descriptions[0])\n",
+    "    return Dataset.from_dict({\n",
+    "        \"text\": input_texts,\n",
+    "        \"label\": labels,\n",
+    "        \"label_description\": label_descriptions\n",
+    "    })\n",
+    "\n",
+    "for dataset_id in dataset_ids:\n",
+    "    # Load label descriptions\n",
+    "    description_file_path = os.path.join(data_dir_label_descriptions, label_description_file_template.format(dataset_id))\n",
+    "    try:\n",
+    "        with open(description_file_path, 'r') as f:\n",
+    "            label_to_description = json.load(f)\n",
+    "    except FileNotFoundError:\n",
+    "        print(f\"Skipping formatting dataset '{dataset_id}': Description file not found.\")\n",
+    "        continue\n",
+    "    \n",
+    "    # Format the train and test datasets, once with the descriptions in \"label_description\", and once with the label texts instead.\n",
+    "    try: \n",
+    "        formatted_datasets[dataset_id] = {}\n",
+    "        formatted_datasets[dataset_id][\"label_description\"] = {\n",
+    "            \"train\": format_dataset(datasets_original[dataset_id][\"train\"], label_to_description),\n",
+    "            \"test\": format_dataset(datasets_original[dataset_id][\"test\"], label_to_description)\n",
+    "        }\n",
+    "        formatted_datasets[dataset_id][\"label_text\"] = {\n",
+    "            \"train\": format_dataset(datasets_original[dataset_id][\"train\"]),\n",
+    "            \"test\": format_dataset(datasets_original[dataset_id][\"test\"])\n",
+    "        }\n",
+    "        print(f\"Successfully formatted dataset '{dataset_id}'.\")\n",
+    "\n",
+    "    except KeyError:\n",
+    "        print(f\"Skipping formatting dataset '{dataset_id}': Key 'train' and/or 'test' not found.\")\n",
+    "        formatted_datasets.pop(dataset_id)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "852e0eef",
+   "metadata": {},
+   "source": [
+    "# Train & Evaluate FusionSent Model "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "038dad2f",
+   "metadata": {},
+   "source": [
+    "*1. Set up the model, tokenizer, and training arguments.*"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "id": "ffad2066",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "model_id = \"malteos/scincl\"\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
+    "training_args = TrainingArguments(\n",
+    "    batch_sizes=(10,15),\n",
+    "    num_epochs=(1,3),\n",
+    "    sampling_strategies=\"undersampling\",\n",
+    "    use_setfit_body=False  # In this experiment, we only want to evaluate different label_embedding sub-models, so we don't need the 'setfit' body.\n",
+    "    )\n",
+    "\n",
+    "def getFreshModel() -> FusionSentModel:\n",
+    "    return FusionSentModel.from_pretrained(pretrained_model_name_or_path=model_id, multi_target_strategy=\"one-vs-rest\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b355d506",
+   "metadata": {},
+   "source": [
+    "*2. Train and evaluate one dataset after another.*\n",
+    "\n",
+    "*Please choose an appropriate subset of all the datasets in `target_datasets`.*"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f6b6e37a",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "target_datasets = dataset_ids[:1]  # Select applicable datasets (only the first, for testing)\n",
+    "\n",
+    "for dataset_id in target_datasets:\n",
+    "    for dataset_key, dataset in formatted_datasets[dataset_id].items():\n",
+    "        # Define Trainer and start training\n",
+    "        trainer = Trainer(\n",
+    "            model=getFreshModel(),\n",
+    "            args=training_args,\n",
+    "            train_dataset=dataset[\"train\"],\n",
+    "            eval_dataset=dataset[\"test\"],\n",
+    "            eval_metrics={\n",
+    "                'metric_names': ['f1', 'precision', 'recall', 'accuracy'],\n",
+    "                'metric_args': {'average': 'micro'}\n",
+    "            }\n",
+    "        )\n",
+    "        print(f\"Training FusionSent on dataset '{dataset_id}', with {dataset_key}.\")\n",
+    "        trainer.train()\n",
+    "        # Evaluate the current model\n",
+    "        eval_scores = trainer.evaluate(\n",
+    "            x_eval=[item['text'] for item in dataset[\"test\"]],\n",
+    "            y_eval=[item['label'] for item in dataset[\"test\"]]\n",
+    "        )\n",
+    "        print(f\"Evaluation results for '{dataset_id}' with {dataset_key}: {eval_scores}\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.15"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/FusionSent_visualization.png b/FusionSent_visualization.png
new file mode 100644
index 0000000..976e814
Binary files /dev/null and b/FusionSent_visualization.png differ
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 0000000..261eeb9
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 8d9bc09..98b3867 100644 --- a/README.md +++ b/README.md @@ -1 +1,130 @@ -# FusionSent \ No newline at end of file +# FusionSent: A Fusion-Based Multi-Task Sentence Embedding Model + +Welcome to the FusionSent repository. FusionSent is an efficient few-shot learning model designed for multi-label classification of scientific documents with many classes. + +![Training Process of FusionSent](./FusionSent_visualization.png) + +**Figure 1**: The training process of FusionSent comprises three steps: + +1. Fine-tune two different sentence embedding models from the same Pre-trained Language Model (PLM), with parameters θ₁, θ₂ respectively. + - θ₁ is fine-tuned on pairs of training sentences using cosine similarity loss, and θ₂ is fine-tuned on pairs of training sentences and their corresponding label texts, using contrastive loss. + - Label texts can consist of simple label/class names or more extensive texts that semantically describe the meaning of a label/class. 
+2. Merge parameter sets θ₁, θ₂ into θ₃ using Spherical Linear Interpolation (SLERP).
+3. Freeze θ₃ to embed the training sentences, which are then used as input features to train a classification head.
+
+By fine-tuning sentence embedding models using contrastive learning, FusionSent achieves high performance even with limited labeled data. The model initially leverages two distinct sub-models: one uses regular contrastive learning on pairs of training items (['setfit'](https://github.com/huggingface/setfit)), and the other uses contrastive learning on pairs of items and their class descriptions ('label_embedding'). These two sub-models are then fused, via (spherical) linear interpolation, to create the robust FusionSent model that excels in diverse classification tasks. For detailed insights into the model and its performance, please refer to our [published paper](#).
+
+## Overview
+
+`FusionSent` is integrated with the [Hugging Face Hub](https://huggingface.co/) and provides two main classes:
+
+- **FusionSentModel**: This class encapsulates the dual fine-tuning process of the two sentence embedding models ('setfit' and 'label_embedding') and their fusion into a single model ('fusionsent'). It is the core model class for embedding sentences and performing classification tasks.
+- **Trainer**: Manages the training and evaluation process, including loading, preparing, and cleaning the datasets.
+
+## Installation
+
+To install the `fusionsent` package, use pip:
+
+```bash
+pip install fusionsent
+```
+
+## Usage Example
+
+```python
+from fusionsent.training_args import TrainingArguments
+from fusionsent.modeling import FusionSentModel
+from fusionsent.trainer import Trainer
+from datasets import Dataset
+
+# Example dataset objects with sentences belonging to classes: ["Computer Science", "Physics", "Philosophy"]
+train_dataset = Dataset.from_dict({
+    "text": [
+        "Algorithms and data structures form the foundation of computer science.",
+        "Quantum mechanics explores the behavior of particles at subatomic scales.",
+        "The study of ethics is central to philosophical inquiry."
+    ],
+    "label": [
+        [1, 0, 0],  # Computer Science
+        [0, 1, 0],  # Physics
+        [0, 0, 1]   # Philosophy
+    ],
+    "label_description": [
+        ["Computer Science"],
+        ["Physics"],
+        ["Philosophy"]
+    ]
+})
+
+eval_dataset = Dataset.from_dict({
+    "text": [
+        "Artificial intelligence is transforming the landscape of technology.",
+        "General relativity revolutionized our understanding of gravity.",
+        "Epistemology questions the nature and limits of human knowledge."
+    ],
+    "label": [
+        [1, 0, 0],  # Computer Science
+        [0, 1, 0],  # Physics
+        [0, 0, 1]   # Philosophy
+    ],
+    "label_description": [
+        ["Computer Science"],
+        ["Physics"],
+        ["Philosophy"]
+    ]
+})
+
+# Load the model (the labels above are one-hot multi-label vectors, hence the one-vs-rest strategy).
+model_id = "malteos/scincl"
+model = FusionSentModel.from_pretrained(pretrained_model_name_or_path=model_id, multi_target_strategy="one-vs-rest")
+
+# Set training arguments.
+training_args = TrainingArguments(
+    batch_sizes=(16, 1),
+    num_epochs=(1, 3),
+    sampling_strategies="undersampling"
+)
+
+# Initialize trainer.
+trainer = Trainer(
+    model=model,
+    args=training_args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset
+)
+
+# Train the model.
+trainer.train()
+
+# Evaluate the model.
+eval_scores = trainer.evaluate(
+    x_eval=eval_dataset["text"],
+    y_eval=eval_dataset["label"]
+)
+
+# Perform inference.
+texts = [
+    "Computational complexity helps us understand the efficiency of algorithms.",
+    "Thermodynamics studies the relationships between heat, work, and energy.",
+    "Existentialism explores the freedom and responsibility of individual existence."
+]
+predictions = model.predict(texts)
+print(predictions)
+```
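+
+Each sub-model ('setfit', 'label_embedding', 'fusion') can also be selected at prediction time. Below is a minimal sketch, reusing the trained `model` and `texts` from above and assuming all three sub-models were trained:
+
+```python
+# Compare predictions from the three sub-models ("prediction strategies").
+for strategy in ["setfit", "label_embedding", "fusion"]:
+    model.set_prediction_strategy(strategy)
+    print(strategy, model.predict(texts))
+```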
+
+For a more elaborate example, please refer to the [Jupyter notebook of a Description-Embedding Experiment](./Evaluate_Description-Embedding_Body.ipynb).
+
+## Citation
+
+If you use FusionSent in your research, please cite the following paper:
+
+```bibtex
+@article{...,
+  title={...},
+  author={...},
+  journal={...},
+  year={...}
+}
+```
+
+For additional details and advanced configurations, please refer to the original paper linked at the beginning of this document.
\ No newline at end of file
diff --git a/fusionsent/__init__.py b/fusionsent/__init__.py
new file mode 100644
index 0000000..5a7aafc
--- /dev/null
+++ b/fusionsent/__init__.py
@@ -0,0 +1,4 @@
+from ._version import __version__
+from .modeling import FusionSentModel
+from .trainer import Trainer
+from .training_args import TrainingArguments
\ No newline at end of file
diff --git a/fusionsent/merging_methods.py b/fusionsent/merging_methods.py
new file mode 100644
index 0000000..905e6e3
--- /dev/null
+++ b/fusionsent/merging_methods.py
@@ -0,0 +1,138 @@
+# This module provides different functionalities for merging two sets of model parameters.
+
+from typing import Union
+import numpy as np
+import torch
+
+def merge_models(
+        model_state_dict0,
+        model_state_dict1,
+        merging_method: str='slerp',
+        t: Union[float, np.ndarray]=0.5,
+        DOT_THRESHOLD: float = 0.9995,
+        eps: float = 1e-8
+    ):
+    """
+    Merges two model state dictionaries using a specified merging method.
+
+    Args:
+        model_state_dict0: State dictionary of the first model.
+        model_state_dict1: State dictionary of the second model.
+        merging_method (str): Method to be used for merging (either 'slerp' [default], or 'lerp').
+        t (Union[float, np.ndarray]): Interpolation factor, can be a float or ndarray.
+        DOT_THRESHOLD (float): Threshold to consider vectors as collinear (used only if merging_method = 'slerp').
+        eps (float): Small value to prevent division by zero (used only if merging_method = 'slerp').
+
+    Returns:
+        fused_parameter_dict (dict): Dictionary containing the merged parameters.
+    """
+    fused_parameter_dict = {}
+    if merging_method == 'slerp':
+        for key in model_state_dict1:
+            fused_parameter_dict[key] = _slerp(t=t, v0=model_state_dict0[key], v1=model_state_dict1[key], DOT_THRESHOLD=DOT_THRESHOLD, eps=eps)
+    elif merging_method == 'lerp':
+        for key in model_state_dict1:
+            fused_parameter_dict[key] = _lerp(t=t, v0=model_state_dict0[key], v1=model_state_dict1[key])
+    else:
+        raise ValueError(f"'merging_method' has unsupported value '{merging_method}'. Choose either 'slerp' or 'lerp'.")
+
+    return fused_parameter_dict
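+
+# Example usage (a minimal sketch; the toy tensors below stand in for the state dicts of
+# the trained 'setfit' and 'label_embedding' bodies, which are the actual inputs in FusionSent):
+#
+#     import torch
+#     from fusionsent.merging_methods import merge_models
+#
+#     sd0 = {"linear.weight": torch.ones(2, 2)}
+#     sd1 = {"linear.weight": torch.eye(2)}
+#     fused = merge_models(sd0, sd1, merging_method="slerp", t=0.5)
+#     # t=0.5 weights both parameter sets equally; t=0 returns sd0, t=1 returns sd1.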
+
+def _lerp(
+    t: float,
+    v0: Union[np.ndarray, torch.Tensor],
+    v1: Union[np.ndarray, torch.Tensor]
+) -> Union[np.ndarray, torch.Tensor]:
+    """
+    Traditional linear interpolation of model parameters as a simple weighted average.
+
+    From: https://github.com/cg123/mergekit#linear
+    Args:
+        t (float/np.ndarray): Float value between 0.0 and 1.0 as interpolation or weighting factor. At t=0 will return v0, at t=1 will return v1.
+        v0 (np.ndarray): Starting vector
+        v1 (np.ndarray): Final vector
+    Returns:
+        v2 (np.ndarray or torch.Tensor, depending on the input vectors): Interpolation vector between v0 and v1
+    """
+    return (1 - t) * v0 + t * v1
+
+def _slerp(
+    t: Union[float, np.ndarray],
+    v0: Union[np.ndarray, torch.Tensor],
+    v1: Union[np.ndarray, torch.Tensor],
+    DOT_THRESHOLD: float = 0.9995,
+    eps: float = 1e-8,
+) -> Union[np.ndarray, torch.Tensor]:
+    """
+    Spherical Linear Interpolation (SLERP) is a method used to smoothly interpolate between two vectors (i.e. model parameters). It maintains a constant rate of change and preserves the geometric properties of the spherical space in which the vectors reside.
+
+    SLERP is implemented using the following steps:
+
+    1. Normalize the input vectors to unit length, ensuring they represent directions rather than magnitudes.
+    2. Calculate the angle between these vectors using their dot product.
+    3. If the vectors are nearly collinear, default to linear interpolation for efficiency. Otherwise, compute scale factors based on the interpolation factor t (t=0 returns 100% of the first vector, t=1 returns 100% of the second) and the angle between the vectors.
+    4. Use these factors to weigh the original vectors, which are then summed to obtain the interpolated vector.
+
+    There are several reasons to prefer SLERP over a traditional linear interpolation. For example, in high-dimensional spaces, linear interpolation can lead to a decrease in the magnitude of the interpolated vector (i.e., it reduces the scale of the weights). Moreover, the change in direction of the weights often represents more meaningful information (like feature learning and representation) than the magnitude of change.
+
+    From: https://github.com/cg123/mergekit#slerp
+    Args:
+        t (float/np.ndarray): Float value between 0.0 and 1.0 as interpolation or weighting factor. At t=0 will return v0, at t=1 will return v1.
+        v0 (np.ndarray): Starting vector
+        v1 (np.ndarray): Final vector
+        DOT_THRESHOLD (float): Threshold for considering the two vectors as collinear. Not recommended to alter this.
+        eps (float): Small value to prevent division by zero during normalization.
+    Returns:
+        v2 (np.ndarray or torch.Tensor, depending on the input vectors): Interpolation vector between v0 and v1
+    """
+    is_torch = False
+    if not isinstance(v0, np.ndarray):
+        is_torch = True
+        v0 = v0.detach().cpu().float().numpy()
+    if not isinstance(v1, np.ndarray):
+        is_torch = True
+        v1 = v1.detach().cpu().float().numpy()
+
+    # Copy the vectors to reuse them later
+    v0_copy = np.copy(v0)
+    v1_copy = np.copy(v1)
+
+    # Normalize the vectors to get the directions and angles
+    v0 = _normalize(v0, eps)
+    v1 = _normalize(v1, eps)
+
+    # Dot product of the normalized vectors (summed over all elements, since the inputs may be matrices)
+    dot = float(np.sum(v0 * v1))
+
+    # If the absolute value of the dot product is almost 1, the vectors are ~collinear, so use lerp
+    if np.abs(dot) > DOT_THRESHOLD:
+        res = _lerp(t, v0_copy, v1_copy)
+        return _maybe_torch(res, is_torch)
+
+    # Calculate initial angle between v0 and v1
+    theta_0 = np.arccos(dot)
+    sin_theta_0 = np.sin(theta_0)
+
+    # Angle at timestep t
+    theta_t = theta_0 * t
+    sin_theta_t = np.sin(theta_t)
+
+    # Finish the slerp algorithm
+    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+    s1 = sin_theta_t / sin_theta_0
+    res = s0 * v0_copy + s1 * v1_copy
+
+    return _maybe_torch(res, is_torch)
+
+def _maybe_torch(v: np.ndarray, is_torch: bool) -> Union[np.ndarray, torch.Tensor]:
+    if not isinstance(v, np.ndarray):
+        v = np.array(v)
+    if is_torch:
+        return torch.from_numpy(v)
+    return v
+
+def _normalize(v: np.ndarray, eps: float) -> np.ndarray:
+    norm_v = np.linalg.norm(v)
+    if norm_v > eps:
+        v = v / norm_v
+    return v
\ No newline at end of file
diff --git a/fusionsent/modeling.py b/fusionsent/modeling.py
new file mode 100644
index 0000000..4bda3c5
--- /dev/null
+++ b/fusionsent/modeling.py
@@ -0,0 +1,274 @@
+# This module contains the actual FusionSent model, with additional sub-models for different prediction strategies and classification heads.
+
+from typing import Callable, Dict, List, Optional, Union
+import warnings
+from packaging.version import Version, parse
+from dataclasses import dataclass, field
+import copy
+import numpy as np
+import torch
+from huggingface_hub.utils import validate_hf_hub_args
+from huggingface_hub import PyTorchModelHubMixin
+from sentence_transformers import SentenceTransformer, models
+from sentence_transformers import __version__ as sentence_transformers_version
+from sklearn.linear_model import LogisticRegression
+from sklearn.multiclass import OneVsRestClassifier
+from sklearn.multioutput import ClassifierChain, MultiOutputClassifier
+
+class FusionModelBody:
+    """
+    This class encapsulates the encoder bodies of all variants ('setfit', 'label_embedding', 'fusion') for the FusionSent model.
+
+    Attributes:
+        setfit_model_body (SentenceTransformer): A copy of the SentenceTransformer model for setfit.
+        label_embedding_model_body (SentenceTransformer): A copy of the SentenceTransformer model for label embedding.
+        fusion_model_body (SentenceTransformer): A copy of the SentenceTransformer model for fusion.
+    """
+
+    def __init__(self, model: SentenceTransformer):
+        self.setfit_model_body = copy.deepcopy(model)
+        self.label_embedding_model_body = copy.deepcopy(model)
+        self.fusion_model_body = copy.deepcopy(model)
+
+
+class FusionModelHead:
+    """
+    This class encapsulates the classification heads for all variants of encoder bodies ('setfit', 'label_embedding', 'fusion') of the FusionSent model.
+
+    Attributes:
+        setfit_model_head (Callable): A copy of the classification head for setfit.
+        label_embedding_model_head (Callable): A copy of the classification head for label embedding.
+        fusion_model_head (Callable): A copy of the classification head for fusion.
+    """
+
+    def __init__(self, model: Callable):
+        self.setfit_model_head = copy.deepcopy(model)
+        self.label_embedding_model_head = copy.deepcopy(model)
+        self.fusion_model_head = copy.deepcopy(model)
+
+@dataclass
+class FusionSentModel(PyTorchModelHubMixin):
+    """
+    This data class for the FusionSent model includes model bodies and heads for different prediction strategies.
+
+    The FusionSentModel is designed to encapsulate three separate sub-models, each with a pretrained language model at its core and a linear classification head on top:
+    - `setfit`: An encoder (body) intended to be trained contrastively, with regular (item, item) pairs (adapted from https://github.com/huggingface/setfit).
+    - `label_embedding`: An encoder (body) intended to be trained contrastively, with (class-description, item) pairs.
+    - `fusion`: An encoder (body) that is the result of a (spherical) linear interpolation between the parameters of both the `setfit` and `label_embedding` sub-models.
+
+    Each sub-model constitutes a distinct 'prediction strategy', i.e., each sub-model (encoder + classification head) can be selected for use at runtime.
+    Only one sub-model can be selected at any given time.
+
+    Attributes:
+        model_body (FusionModelBody): An instance of FusionModelBody containing the model bodies ('fusion', 'label_embedding', 'setfit').
+        model_head (FusionModelHead): An instance of FusionModelHead containing the model heads ('fusion', 'label_embedding', 'setfit').
+        multi_target_strategy (Optional[str]): The strategy for handling multi-target classification ('one-vs-rest', 'multi-output', or 'classifier-chain').
+        prediction_strategy (Optional[str]): The current prediction strategy ('fusion', 'label_embedding', 'setfit').
+        sentence_transformers_kwargs (Dict): Additional keyword arguments for the SentenceTransformer implementation.
+        transformers_config (Optional[Dict]): Configuration for the transformer implementation.
+    """
+
+    model_body: Optional[FusionModelBody] = None
+    model_head: Optional[FusionModelHead] = None
+    multi_target_strategy: Optional[str] = None
+    prediction_strategy: Optional[str] = None
+    sentence_transformers_kwargs: Dict = field(default_factory=dict, repr=False)
+    transformers_config: Optional[Dict] = None
+
+    def get_prediction_strategy(self) -> str:
+        """
+        Returns the prediction strategy for the model body. If not `None`, it can be either `fusion`, `label_embedding` or `setfit`.
+        """
+        return self.prediction_strategy
+
+    def set_prediction_strategy(self, prediction_strategy: str) -> None:
+        """
+        Sets the prediction strategy of the model body. If not `None`, it can be either `fusion`, `label_embedding` or `setfit`.
+        Args:
+            prediction_strategy (`str`): A string representing the prediction strategy of the model. If not `None`, it can be either `fusion`, `label_embedding` or `setfit`.
+        """
+        self.prediction_strategy = prediction_strategy
+
+    def encode(self, texts: List[str], device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) -> np.ndarray:
+        """
+        Convert input texts to embeddings using the SentenceTransformer encoder body selected by the current prediction strategy.
+
+        Args:
+            texts (`List[str]`): A list of texts to encode.
+        """
+        if self.get_prediction_strategy() is None or self.get_prediction_strategy() == "fusion":
+            # get fusion embeddings
+            embeddings = self.get_fusion_embeddings(texts, device=device)
+        elif self.get_prediction_strategy() == "setfit":
+            # get SetFit embeddings
+            embeddings = self.get_setfit_embeddings(texts, device=device)
+        elif self.get_prediction_strategy() == "label_embedding":
+            # get label embeddings
+            embeddings = self.get_label_embeddings(texts, device=device)
+        else:
+            raise ValueError(f"Unsupported prediction strategy '{self.get_prediction_strategy()}'.")
+
+        return embeddings
+
+    def get_fusion_embeddings(self, texts: List[str], device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) -> np.ndarray:
+        """
+        Convert input texts to embeddings using the fusion model body.
+
+        Args:
+            texts (`List[str]`): A list of texts to encode.
+        """
+        # get embeddings from fusion body
+        return self.model_body.fusion_model_body.encode(texts, device=device)
+
+    def get_setfit_embeddings(self, texts: List[str], device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) -> np.ndarray:
+        """
+        Convert input texts to embeddings using the SetFit model body.
+
+        Args:
+            texts (`List[str]`): A list of texts to encode.
+        """
+        # get embeddings from SetFit body
+        return self.model_body.setfit_model_body.encode(texts, device=device)
+
+    def get_label_embeddings(self, texts: List[str], device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) -> np.ndarray:
+        """
+        Convert input texts to embeddings using the label embedding model body.
+
+        Args:
+            texts (`List[str]`): A list of texts to encode.
+        """
+        # get embeddings from label embedding body
+        return self.model_body.label_embedding_model_body.encode(texts, device=device)
+
+    def predict(self, texts: List[str], device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) -> np.ndarray:
+        """
+        Predict classes of input texts.
+
+        Args:
+            texts (`List[str]`): A list of texts for classification.
+        """
+        # encode texts as input features for the classification head
+        features = self.encode(texts, device=device)
+
+        # classify texts
+        if self.get_prediction_strategy() is None or self.get_prediction_strategy() == "fusion":
+            # get fusion head predictions
+            predictions = self.model_head.fusion_model_head.predict(features)
+        elif self.get_prediction_strategy() == "setfit":
+            # get SetFit head predictions
+            predictions = self.model_head.setfit_model_head.predict(features)
+        elif self.get_prediction_strategy() == "label_embedding":
+            # get label embedding head predictions
+            predictions = self.model_head.label_embedding_model_head.predict(features)
+        else:
+            raise ValueError(f"Unsupported prediction strategy '{self.get_prediction_strategy()}'.")
+
+        return predictions
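+
+    # Usage sketch (illustrative only, assuming a trained model):
+    #
+    #     embeddings = model.encode(["Graph neural networks operate on relational data."])
+    #     predictions = model.predict(["Graph neural networks operate on relational data."])
+    #
+    # `encode` returns the features produced by the encoder body selected via the current
+    # prediction strategy, and `predict` feeds those same features to the matching head.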
+
+    @classmethod
+    @validate_hf_hub_args
+    def _from_pretrained(
+        cls,
+        model_id: str,
+        revision: Optional[str] = None,
+        cache_dir: Optional[str] = None,
+        force_download: Optional[bool] = None,
+        proxies: Optional[Dict] = None,
+        resume_download: Optional[bool] = None,
+        local_files_only: Optional[bool] = None,
+        token: Optional[Union[bool, str]] = None,
+        multi_target_strategy: Optional[str] = None,
+        prediction_strategy: Optional[str] = None,
+        device: Optional[Union[torch.device, str]] = None,
+        trust_remote_code: bool = False,
+        **model_kwargs,
+    ) -> 'FusionSentModel':
+        """
+        Internal method to load a pretrained FusionSent model from the Hugging Face Hub.
+
+        This method is called by the Hugging Face Hub framework and should not be modified by the user.
+        It initializes the FusionSent model with pretrained components and configuration from the Hugging Face Hub.
+
+        Args:
+            model_id (str): The ID of the model on the Hugging Face Hub.
+            cache_dir (Optional[str], optional): Directory to cache the model.
+            token (Optional[Union[bool, str]], optional): Token for accessing the Hub.
+            multi_target_strategy (Optional[str], optional): Strategy for multi-target classification ('one-vs-rest', 'multi-output', or 'classifier-chain').
+            prediction_strategy (Optional[str], optional): The prediction strategy to use.
+            device (Optional[Union[torch.device, str]], optional): The device to use for the model.
+            trust_remote_code (bool, optional): Whether to trust custom code from the model repo.
+            **model_kwargs: Additional keyword arguments for the model.
+
+        Returns:
+            FusionSentModel: The loaded FusionSent model.
+        """
+        # Warn if any unused arguments are provided. -- Disabled, because these arguments are always passed by the parent class.
+        # unused_args = [
+        #     ('revision', revision),
+        #     ('force_download', force_download),
+        #     ('proxies', proxies),
+        #     ('resume_download', resume_download),
+        #     ('local_files_only', local_files_only)
+        # ]
+        # for arg_name, arg_value in unused_args:
+        #     if arg_value is not None:
+        #         warnings.warn(f"The '{arg_name}' argument is not used by 'FusionSentModel', and will have no effect.", UserWarning, stacklevel=2)
+
+        # Set up additional arguments for sentence-transformers.
+        # (The token and trust_remote_code arguments depend on the installed sentence-transformers version.)
+        if parse(sentence_transformers_version) >= Version("2.3.0"):
+            sentence_transformers_kwargs = {
+                "cache_folder": cache_dir,
+                "token": token,
+                "device": device,
+                "trust_remote_code": trust_remote_code,
+            }
+        else:
+            if trust_remote_code:
+                raise ValueError(
+                    "The `trust_remote_code` argument is only supported for `sentence-transformers` >= 2.3.0."
+                )
+            sentence_transformers_kwargs = {
+                "cache_folder": cache_dir,
+                "use_auth_token": token,
+                "device": device,
+            }
+
+        # Load model components.
+        word_embedding_model = models.Transformer(model_id)
+        pooling_model = models.Pooling(word_embedding_dimension=word_embedding_model.get_word_embedding_dimension(), pooling_mode='mean')
+        sentence_transformer = SentenceTransformer(modules=[word_embedding_model, pooling_model], **sentence_transformers_kwargs)
+        model_body = FusionModelBody(sentence_transformer)
+
+        # Set device.
+        if parse(sentence_transformers_version) >= Version("2.3.0"):
+            device = sentence_transformer.device
+        else:
+            device = sentence_transformer._target_device
+
+        # Configure classification heads.
+        head_params = model_kwargs.pop("head_params", {})
+        clf = LogisticRegression(**head_params)
+        if multi_target_strategy is not None:
+            if multi_target_strategy == "one-vs-rest":
+                multilabel_classifier = OneVsRestClassifier(clf)
+            elif multi_target_strategy == "multi-output":
+                multilabel_classifier = MultiOutputClassifier(clf)
+            elif multi_target_strategy == "classifier-chain":
+                multilabel_classifier = ClassifierChain(clf)
+            else:
+                raise ValueError(f"multi_target_strategy {multi_target_strategy} is not supported.")
+
+            model_head = FusionModelHead(multilabel_classifier)
+        else:
+            model_head = FusionModelHead(clf)
+
+        return cls(
+            model_body=model_body,
+            model_head=model_head,
+            multi_target_strategy=multi_target_strategy,
+            prediction_strategy=prediction_strategy,
+            sentence_transformers_kwargs=sentence_transformers_kwargs,
+            **model_kwargs,
+        )
\ No newline at end of file
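As a usage sketch of the head configuration above: `head_params` is forwarded to scikit-learn's `LogisticRegression`, so any of its constructor arguments apply (the values below are illustrative, not recommendations):

```python
from fusionsent import FusionSentModel

# Each classification head becomes OneVsRestClassifier(LogisticRegression(max_iter=1000)).
model = FusionSentModel.from_pretrained(
    "malteos/scincl",
    multi_target_strategy="one-vs-rest",  # or "multi-output" / "classifier-chain"
    head_params={"max_iter": 1000},
)
```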
diff --git a/fusionsent/trainer.py b/fusionsent/trainer.py
new file mode 100644
index 0000000..e4af74e
--- /dev/null
+++ b/fusionsent/trainer.py
@@ -0,0 +1,676 @@
+# This module contains the Trainer class, responsible for managing the training and evaluation process for FusionSent.
+
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+import warnings
+import logging
+import math
+import random
+import numpy as np
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+from transformers.trainer_utils import set_seed
+from datasets import Dataset
+from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
+from setfit.trainer import ColumnMappingMixin
+from setfit.sampler import ContrastiveDataset
+from sentence_transformers.datasets import SentenceLabelDataset
+from sentence_transformers import InputExample, losses
+from setfit.losses import SupConLoss
+import gc
+import json
+
+from .training_args import TrainingArguments
+from .modeling import FusionSentModel
+from .merging_methods import merge_models
+
+logging.basicConfig()
+logger = logging.getLogger('FusionSent')
+logger.setLevel(logging.INFO)
+
+class Trainer(ColumnMappingMixin):
+    """
+    The Trainer class is responsible for managing the training and evaluation process for the FusionSent model.
+
+    It facilitates the training of two distinct sub-models ('setfit' and 'label_embedding') and merges their parameters
+    into the unified FusionSent model. This class handles the preparation of datasets, the configuration of training parameters,
+    and the execution of training and evaluation routines.
+    """
+
+    DEFAULT_EVAL_METRICS = {'metric_names': ['f1', 'precision', 'recall', 'accuracy'], 'metric_args': {'average': 'micro'}}
+
+    def __init__(
+        self,
+        model: FusionSentModel = None,
+        args: Optional[TrainingArguments] = None,
+        train_dataset: Optional[Dataset] = None,
+        eval_dataset: Optional[Dataset] = None,
+        eval_metrics: Optional[Dict[str, Any]] = None,
+        column_mapping: Optional[Dict[str, str]] = None,
+    ) -> None:
+        """
+        Initializes the Trainer class with the provided FusionSent model, training arguments, datasets, evaluation metrics, and column mapping.
+
+        Args:
+            model (FusionSentModel): The FusionSent model to be trained. Required.
+            args (Optional[TrainingArguments]): Configuration for training parameters. If not provided, the default settings of 'TrainingArguments' will be used.
+            train_dataset (Optional[Dataset]): The dataset used for training. If provided, applies column mapping if necessary.
+            eval_dataset (Optional[Dataset]): The dataset used for evaluation. If provided, applies column mapping if necessary.
+            eval_metrics (Optional[Dict[str, Any]]): A dictionary specifying the evaluation metrics and their arguments. Defaults to evaluating f1, precision, recall, and accuracy with 'micro' averaging.
+                If not provided or ill-formatted, default metrics will be used. Example format:
+                {
+                    'metric_names': ['f1', 'precision', 'recall', 'accuracy'],
+                    'metric_args': {'average': 'micro'}
+                }
+            column_mapping (Optional[Dict[str, str]]): A mapping of dataset columns to the expected input columns.
+
+        Raises:
+            ValueError: If the `model` parameter is not provided or is not of type 'FusionSentModel', or if the TrainingArguments are ill-formatted.
+        """
+
+        # Verify that a model has been given.
+        if model is None:
+            raise ValueError("`Trainer` requires a `model` argument.")
+        if not isinstance(model, FusionSentModel):
+            raise ValueError("`Trainer` requires a `model` argument of type 'FusionSentModel'.")
+        set_seed(12)  # Seed must be set before instantiating the model when using model_init.
+        self.model = model
+
+        # Initialize 'TrainingArguments' from the given input (or as default), and validate them.
+        # Note: the mutable default is resolved here rather than in the signature, to avoid sharing one instance across Trainers.
+        if args is not None and not isinstance(args, TrainingArguments):
+            raise ValueError("`args` must be a `TrainingArguments` instance imported from `FusionSent`.")
+        self.args = args if args is not None else TrainingArguments()
+        self.args._validate()
+
+        # Assign and validate evaluation metrics, if given.
+        self.eval_metrics: Dict[str, Any] = eval_metrics if eval_metrics is not None else self.DEFAULT_EVAL_METRICS
+        self._validate_eval_metrics()
+
+        # Apply column mapping to 'train_dataset' if necessary.
+        self.column_mapping = column_mapping
+        if train_dataset:
+            self._validate_column_mapping(train_dataset)
+            if self.column_mapping is not None:
+                logger.info("Applying column mapping to the training dataset")
+                train_dataset = self._apply_column_mapping(train_dataset, self.column_mapping)
+        self.train_dataset = train_dataset
+
+        # Apply column mapping to 'eval_dataset' if necessary.
+        if eval_dataset:
+            self._validate_column_mapping(eval_dataset)
+            if self.column_mapping is not None:
+                logger.info("Applying column mapping to the evaluation dataset")
+                eval_dataset = self._apply_column_mapping(eval_dataset, self.column_mapping)
+        self.eval_dataset = eval_dataset
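To make the expected dataset layout and the `eval_metrics` format concrete, here is a hedged construction sketch; the dataset contents are illustrative, and `model` is the FusionSent model loaded in the previous sketch:

```python
from datasets import Dataset
from fusionsent import Trainer, TrainingArguments

# Columns expected by the Trainer: 'text', 'label', and 'label_description'.
train_ds = Dataset.from_dict({
    "text": ["great battery life", "screen cracked after a week"],
    "label": [[1, 0], [0, 1]],  # one-hot labels (multi-label format)
    "label_description": [["positive sentiment"], ["negative sentiment"]],
})

trainer = Trainer(
    model=model,
    args=TrainingArguments(batch_sizes=(16, 1), num_epochs=(1, 3)),
    train_dataset=train_ds,
    eval_metrics={"metric_names": ["f1", "accuracy"], "metric_args": {"average": "micro"}},
)
```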
+ """ + for example in examples: + label = example.label + + # Check if label is a list, tuple, or numpy array (multi-label scenario) + if isinstance(label, (list, tuple, np.ndarray)): + return True + + # Check if label is not binary (i.e., not 0 or 1) + if isinstance(label, (int, float)) and label not in {0, 1}: + return True + return False + + def _get_setfit_dataloader( + self, + x: List[str], + y: Union[List[int], List[List[int]]], + args: TrainingArguments, + max_pairs: int = -1 + ) -> Tuple[DataLoader, nn.Module, int]: + """ + Prepares a DataLoader and corresponding loss function for training the 'setfit' sub-model. + + Args: + x (List[str]): A list of input texts. + y (Union[List[int], List[List[int]]]): A list of binary- or multi-class labels corresponding to the input texts. + args (TrainingArguments): The training arguments configuration. + max_pairs (int, optional): Maximum number of pairs for contrastive sampling. Default is -1, which means no limit. + + Returns: + Tuple[DataLoader, nn.Module, int]: A tuple containing: + - DataLoader: The DataLoader for the 'setfit' sub-model. + - nn.Module: The loss function for the 'setfit' sub-model. + - int: The batch size used for the DataLoader. + """ + + # Adapt input data for sentence-transformers. + input_data = [InputExample(texts=[text], label=label) for text, label in zip(x, y)] + + if args.setfit_loss in [ + losses.BatchAllTripletLoss, + losses.BatchHardTripletLoss, + losses.BatchSemiHardTripletLoss, + losses.BatchHardSoftMarginTripletLoss, + SupConLoss, + ]: + data_sampler = SentenceLabelDataset(input_data, samples_per_label=args.setfit_samples_per_label) + batch_size = min(args.setfit_batch_size, len(data_sampler)) + dataloader = DataLoader(data_sampler, batch_size=batch_size, drop_last=True) + + if args.setfit_loss is losses.BatchHardSoftMarginTripletLoss: + loss = args.setfit_loss( + model=self.model.model_body.setfit_model_body, + distance_metric=args.setfit_distance_metric, + ) + elif args.setfit_loss is SupConLoss: + loss = args.setfit_loss(model=self.model.model_body.setfit_model_body) + else: + loss = args.setfit_loss( + model=self.model.model_body.setfit_model_body, + distance_metric=args.setfit_distance_metric, + margin=args.setfit_margin, + ) + else: + data_sampler = ContrastiveDataset( + examples=input_data, + multilabel=Trainer._has_any_multilabel(input_data), + num_iterations=args.num_iterations, + sampling_strategy=args.setfit_sampling_strategy, + max_pairs=max_pairs, + ) + batch_size = min(args.setfit_batch_size, len(data_sampler)) + dataloader = DataLoader(data_sampler, batch_size=batch_size, drop_last=False) + loss = args.setfit_loss(self.model.model_body.setfit_model_body) + + return dataloader, loss, batch_size + + def _get_label_embedding_dataloader( + self, + texts: List[str], + label_descriptions: List[str], + args: TrainingArguments + ) -> Tuple[DataLoader, nn.Module, int]: + """ + Prepares a DataLoader and corresponding loss function for training the 'label_embedding' sub-model. + + Note that remaining TODO's include: + - Adding additional Sentence-Transformers losses + - Reimplementing sampling strategies to support oversampling of negatives and undersampling of positives. + + Args: + texts (List[str]): A list of input texts. + label_descriptions (List[str]): A list of label descriptions corresponding to the input texts. + args (TrainingArguments): The training arguments configuration. 
+
+    def _get_label_embedding_dataloader(
+        self,
+        texts: List[str],
+        label_descriptions: List[List[str]],
+        args: TrainingArguments
+    ) -> Tuple[DataLoader, nn.Module, int]:
+        """
+        Prepares a DataLoader and corresponding loss function for training the 'label_embedding' sub-model.
+
+        Note that remaining TODOs include:
+        - Adding additional Sentence-Transformers losses.
+        - Reimplementing sampling strategies to support oversampling of negatives and undersampling of positives.
+
+        Args:
+            texts (List[str]): A list of input texts.
+            label_descriptions (List[List[str]]): A list of lists of label descriptions corresponding to the input texts.
+            args (TrainingArguments): The training arguments configuration.
+
+        Returns:
+            Tuple[DataLoader, nn.Module, int]: A tuple containing:
+                - DataLoader: The DataLoader for the 'label_embedding' sub-model.
+                - nn.Module: The loss function for the 'label_embedding' sub-model.
+                - int: The batch size used for the DataLoader.
+        """
+
+        # TODO: add remaining sentence-transformers losses.
+        # TODO: reimplement sampling strategies to support oversampling of negatives and undersampling of positives.
+        if args.label_embedding_loss is losses.MultipleNegativesRankingLoss:
+            # Create the default input data with positive pairs only.
+            input_data = []
+            for i, text in enumerate(texts):
+                for label_description in label_descriptions[i]:
+                    input_data.append(InputExample(texts=[text, label_description]))
+
+        elif args.label_embedding_loss is losses.TripletLoss:
+            if args.label_embedding_sampling_strategy == "oversampling":
+                # Create input data for triplet loss with oversampling of positives.
+                input_data = []
+                unique_labels = set([x for xs in label_descriptions for x in xs])
+                for i, text in enumerate(texts):
+                    negative_labels = unique_labels - set(label_descriptions[i])
+                    # Oversample positive label descriptions.
+                    positive_label_description_samples = random.choices(label_descriptions[i], k=len(negative_labels))
+                    for x in range(len(negative_labels)):
+                        input_data.append(InputExample(texts=[text, positive_label_description_samples[x], list(negative_labels)[x]]))
+            elif args.label_embedding_sampling_strategy == "undersampling":
+                # Create input data for triplet loss with undersampling of negatives.
+                input_data = []
+                unique_labels = set([x for xs in label_descriptions for x in xs])
+                for i, text in enumerate(texts):
+                    negative_labels = unique_labels - set(label_descriptions[i])
+                    # Undersample negative label descriptions.
+                    negative_label_description_samples = random.sample(list(negative_labels), len(label_descriptions[i]))
+                    for x in range(len(label_descriptions[i])):
+                        input_data.append(
+                            InputExample(texts=[text, label_descriptions[i][x], negative_label_description_samples[x]]))
+            else:
+                # Guard against an unbound 'input_data' for unsupported strategy/loss combinations.
+                raise ValueError(
+                    f"Sampling strategy '{args.label_embedding_sampling_strategy}' is not supported for TripletLoss."
+                )
+
+        elif args.label_embedding_loss in [losses.ContrastiveLoss, losses.CosineSimilarityLoss, losses.OnlineContrastiveLoss]:
+            if args.label_embedding_sampling_strategy == "oversampling":
+                # Create input data for contrastive learning with oversampling of positives.
+                input_data = []
+                unique_labels = set([x for xs in label_descriptions for x in xs])
+                for i, text in enumerate(texts):
+                    negative_labels = unique_labels - set(label_descriptions[i])
+                    # Add positive label descriptions for the anchor text.
+                    for positive_label_description in label_descriptions[i]:
+                        input_data.append(InputExample(texts=[text, positive_label_description], label=1.0))
+
+                    # Add negative label descriptions for the anchor text.
+                    for negative_label_description in list(negative_labels):
+                        input_data.append(InputExample(texts=[text, negative_label_description], label=0.0))
+
+                    # Oversample positive label descriptions.
+                    positive_label_description_samples = random.choices(label_descriptions[i], k=len(negative_labels)-1)
+                    for positive_label_description in positive_label_description_samples:
+                        input_data.append(InputExample(texts=[text, positive_label_description], label=1.0))
+
+            elif args.label_embedding_sampling_strategy == "undersampling":
+                # Create input data for contrastive learning with undersampling of negatives.
+                input_data = []
+                unique_labels = set([x for xs in label_descriptions for x in xs])
+                for i, text in enumerate(texts):
+                    negative_labels = unique_labels - set(label_descriptions[i])
+                    # Add positive label descriptions for the anchor text.
+                    for positive_label_description in label_descriptions[i]:
+                        input_data.append(InputExample(texts=[text, positive_label_description], label=1.0))
+
+                    # Add negative label descriptions for the anchor text.
+                    negative_label_description_samples = random.sample(list(negative_labels),
+                                                                       len(label_descriptions[i]))
+                    for negative_label_description in negative_label_description_samples:
+                        input_data.append(InputExample(texts=[text, negative_label_description], label=0.0))
+            else:
+                raise ValueError(
+                    f"Sampling strategy '{args.label_embedding_sampling_strategy}' is not supported for contrastive losses."
+                )
+        else:
+            raise ValueError(f"Loss '{args.label_embedding_loss}' is not yet supported for label-embedding training.")
+
+        # Note: the SentenceLabelDataset is only used to cap the batch size; batching itself runs over the raw input examples.
+        data_sampler = SentenceLabelDataset(input_data, samples_per_label=args.label_embedding_samples_per_label)
+        batch_size = min(args.label_embedding_batch_size, len(data_sampler))
+        dataloader = DataLoader(input_data, shuffle=True, batch_size=batch_size)
+        loss = args.label_embedding_loss(self.model.model_body.label_embedding_model_body)
+
+        return dataloader, loss, batch_size
+
+    def _validate_eval_metrics(
+        self,
+        other: Optional[Dict[str, Any]] = None
+    ):
+        """
+        Validates the local evaluation metrics to ensure they contain at least one of the valid evaluation arguments:
+        `f1`, `precision`, `recall`, `accuracy`.
+
+        Args:
+            other (Optional[Dict[str, Any]]): An alternative set of evaluation metrics to validate instead.
+
+        Raises:
+            ValueError: If the evaluation metrics do not contain at least one of the valid evaluation arguments.
+        """
+        valid_metrics = set(['f1', 'precision', 'recall', 'accuracy'])
+        if other is not None and 'metric_names' in other.keys():
+            provided_metrics = set(other['metric_names'])
+        else:
+            provided_metrics = set(self.eval_metrics.get('metric_names', []))
+
+        if not provided_metrics.intersection(valid_metrics):
+            raise ValueError(
+                "'eval_metrics' did not contain at least one of the following valid values under key 'metric_names': `f1`, `precision`, `recall`, `accuracy`."
+            )
+
+    def _has_evaluation_setting(self) -> bool:
+        """
+        Returns a boolean indicating whether this trainer instance can perform an evaluation (i.e. it has been given an evaluation dataset and metrics).
+        """
+        return bool(self.eval_dataset and self.eval_metrics)
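To illustrate what the oversampling branch above produces for the contrastive losses, here is a stripped-down sketch with toy data (all texts and labels are illustrative):

```python
import random
from sentence_transformers import InputExample

texts = ["solar cells with improved efficiency"]
label_descriptions = [["photovoltaics"]]
unique_labels = {"photovoltaics", "wind energy", "geothermal energy"}

input_data = []
for i, text in enumerate(texts):
    negatives = unique_labels - set(label_descriptions[i])
    # One positive pair per positive description, one negative pair per remaining label...
    for pos in label_descriptions[i]:
        input_data.append(InputExample(texts=[text, pos], label=1.0))
    for neg in negatives:
        input_data.append(InputExample(texts=[text, neg], label=0.0))
    # ...plus oversampled positives so both classes end up roughly balanced.
    for pos in random.choices(label_descriptions[i], k=len(negatives) - 1):
        input_data.append(InputExample(texts=[text, pos], label=1.0))
```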
+
+    def train(
+        self,
+        args: Optional[TrainingArguments] = None,
+        trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
+        **kwargs,
+    ) -> Optional[Dict[str, float]]:
+        """
+        This function represents the main training entry point.
+
+        Note that evaluation is performed automatically if (and only if) a dictionary of evaluation metrics and an evaluation dataset were provided at initialization of this instance.
+        Additionally, evaluation can always be carried out manually via the 'evaluate' method.
+
+        Args:
+            args (Optional[TrainingArguments]): Training arguments to temporarily override the default training arguments for this call.
+            trial (Optional[Union["optuna.Trial", Dict[str, Any]]]): The trial run or hyperparameter dictionary for hyperparameter search.
+
+        Returns:
+            Optional[Dict[str, float]]: The evaluation scores, if an evaluation setting exists; otherwise None.
+
+        Raises:
+            ValueError: If `train_dataset` was not provided, or if the given 'TrainingArguments' are ill-formatted.
+        """
+        if len(kwargs):
+            warnings.warn(
+                f"`{self.__class__.__name__}.train` does not accept keyword arguments anymore. "
+                f"Please provide training arguments via a `TrainingArguments` instance to the `{self.__class__.__name__}` "
+                f"initialisation or the `{self.__class__.__name__}.train` method.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        # Assign and validate training arguments.
+        args = args or self.args or TrainingArguments()
+        args._validate()
+
+        # Check for an existing training dataset.
+        if self.train_dataset is None:
+            raise ValueError(
+                f"Training requires a `train_dataset` given to the `{self.__class__.__name__}` initialization."
+            )
+
+        # Initialize trainer parameters and model for hyperparameter search, if applicable.
+        if trial:
+            self._hp_search_setup(trial)
+
+        # Construct train parameters.
+        train_parameters = self._dataset_to_parameters(self.train_dataset)
+
+        # Train the model body.
+        self._train_sentence_transformers_body(*train_parameters, args=args)
+
+        # If an evaluation dataset and metrics are given...
+        if self._has_evaluation_setting():
+            # ...train the model head, ...
+            self._train_classifier_head(
+                x_train_texts=train_parameters[0],
+                y_train=train_parameters[2],
+                x_eval=self.eval_dataset['text'],
+                y_eval=self.eval_dataset['label'],
+                eval_metrics=self.eval_metrics,
+                args=args
+            )
+
+            # ...and evaluate the whole model.
+            logger.info(" ***** Running evaluation on `eval_dataset` *****")
+            self.eval_scores = self.evaluate(
+                x_eval=self.eval_dataset['text'],
+                y_eval=self.eval_dataset['label'],
+                eval_metrics=self.eval_metrics
+            )
+            return self.eval_scores
+        else:
+            # Else, train only the model head, without evaluation.
+            self._train_classifier_head(
+                x_train_texts=train_parameters[0],
+                y_train=train_parameters[2],
+                args=args
+            )
+            return None
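Putting the entry point above to use, continuing the earlier sketches; the returned dictionary only exists when an evaluation setting was configured:

```python
eval_scores = trainer.train()
if eval_scores is not None:
    print(eval_scores)  # e.g. {'f1': ..., 'precision': ..., 'recall': ..., 'accuracy': ...}
```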
+
+    def _train_sentence_transformers_body(
+        self,
+        x_train_texts: List[str],
+        x_train_label_descriptions: List[List[str]],
+        y_train: Optional[Union[List[int], List[List[int]]]] = None,
+        args: Optional[TrainingArguments] = None
+    ) -> None:
+        """
+        Trains both dual-encoder `SentenceTransformer` bodies of the sub-models ('setfit' and 'label_embedding') for the embedding training phase.
+        After training, it merges the parameters of both sub-models into the final encoder body of FusionSent.
+
+        Args:
+            x_train_texts (List[str]): A list of training texts.
+            x_train_label_descriptions (List[List[str]]): A list of lists including label descriptions for each positive label per training text.
+            y_train (Union[List[int], List[List[int]]], optional): A list of labels corresponding to the training texts.
+            args (TrainingArguments, optional): Temporarily change the training arguments for this training call. If not provided, default training arguments will be used.
+
+        Raises:
+            ValueError: If 'args' is not None and ill-formatted.
+        """
+        args = args or self.args or TrainingArguments()
+        args._validate()
+
+        logger.info(" ***** Preparing training dataset *****")
+
+        # Construct the dataset for SetFit body training.
+        setfit_train_dataloader, setfit_loss_func, setfit_batch_size = self._get_setfit_dataloader(
+            x=x_train_texts, y=y_train, args=args
+        )
+
+        # Construct the dataset for label-embedding training.
+        label_embedding_train_dataloader, label_embedding_loss_func, label_embedding_batch_size = self._get_label_embedding_dataloader(
+            texts=x_train_texts, label_descriptions=x_train_label_descriptions, args=args
+        )
+
+        # Compute the total number of training steps.
+        setfit_total_train_steps = len(setfit_train_dataloader) * args.setfit_num_epochs
+        label_embeddings_total_train_steps = len(label_embedding_train_dataloader) * args.label_embedding_num_epochs
+
+        # Log training statistics.
+        logger.info(" ***** Running sentence transformers body training *****")
+        logger.info(f" Total number of examples = {len(setfit_train_dataloader.dataset)} + {len(label_embedding_train_dataloader.dataset)}")
+        logger.info(f" Number of batches = {len(setfit_train_dataloader)} + {len(label_embedding_train_dataloader)}")
+        logger.info(f" Number of epochs = {args.setfit_num_epochs} + {args.label_embedding_num_epochs}")
+        logger.info(f" Train batch sizes = {setfit_batch_size} & {label_embedding_batch_size}")
+        logger.info(f" Total optimization steps = {setfit_total_train_steps} + {label_embeddings_total_train_steps}")
+
+        # Train the setfit body (only if it is intended to be used).
+        if args.use_setfit_body:
+            setfit_warmup_steps = math.ceil(
+                setfit_total_train_steps * args.setfit_warmup_proportion
+            )
+            self.model.model_body.setfit_model_body.fit(
+                train_objectives=[(setfit_train_dataloader, setfit_loss_func)],
+                epochs=args.setfit_num_epochs, warmup_steps=setfit_warmup_steps,
+                show_progress_bar=args.show_progress_bar
+            )
+            # Free GPU memory before training the next body.
+            setfit_loss_func.to('cpu')
+            self.model.model_body.setfit_model_body.to('cpu')
+            gc.collect()
+            with torch.no_grad():
+                torch.cuda.empty_cache()
+
+        # Train the label_embedding body.
+        label_embeddings_warmup_steps = math.ceil(
+            label_embeddings_total_train_steps * args.label_embedding_warmup_proportion
+        )
+        self.model.model_body.label_embedding_model_body.fit(
+            train_objectives=[(label_embedding_train_dataloader, label_embedding_loss_func)],
+            epochs=args.label_embedding_num_epochs,
+            warmup_steps=label_embeddings_warmup_steps,
+            show_progress_bar=args.show_progress_bar
+        )
+        label_embedding_loss_func.to('cpu')
+        self.model.model_body.label_embedding_model_body.to('cpu')
+        gc.collect()
+        with torch.no_grad():
+            torch.cuda.empty_cache()
+
+        # Get the parameters of both trained models.
+        setfit_parameter_dict = dict(
+            self.model.model_body.setfit_model_body._first_module().auto_model.named_parameters()
+        )
+        label_embedding_parameter_dict = dict(
+            self.model.model_body.label_embedding_model_body._first_module().auto_model.named_parameters()
+        )
+
+        # Fuse/merge the model parameters with the selected algorithm.
+        # With t=0, the merge reduces to the label_embedding parameters alone.
+        t = 0.5 if args.use_setfit_body else 0
+        fused_parameter_dict = merge_models(
+            model_state_dict0=label_embedding_parameter_dict,
+            model_state_dict1=setfit_parameter_dict,
+            t=t,
+            merging_method=args.merging_method
+        )
+
+        # Initialize the body of the final FusionSent model with the fused model parameters.
+        fusion_state_dict = self.model.model_body.fusion_model_body._first_module().auto_model.state_dict()
+        for key in fusion_state_dict:
+            fusion_state_dict[key] = fused_parameter_dict[key]
+        self.model.model_body.fusion_model_body._first_module().auto_model.load_state_dict(fusion_state_dict)
+
+    @staticmethod
+    def _ensure_single_label_format(labels: Union[List[int], List[List[int]]]):
+        """
+        Helper function to convert a list of one-hot label vectors into single-label format, if necessary.
+        """
+        if isinstance(labels[0], list):
+            return np.argmax(labels, axis=1)
+        return labels
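The helper above collapses one-hot label vectors to class indices, for example:

```python
import numpy as np

labels = [[0, 1, 0], [1, 0, 0]]    # one-hot, multi-label format
print(np.argmax(labels, axis=1))   # -> [1 0], single-label format
```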
+
+    def _train_classifier_head(
+        self,
+        x_train_texts: List[str],
+        y_train: Union[List[int], List[List[int]]],
+        x_eval: Optional[List[str]] = None,
+        y_eval: Optional[List[int]] = None,
+        eval_metrics: Optional[Dict[str, Any]] = None,
+        args: Optional[TrainingArguments] = None,
+    ) -> None:
+        """
+        Trains a classification head for each candidate model body (`setfit`, `label_embedding`, and their 'fusion').
+        If evaluation metrics and a dataset are provided, the performance of all final models is evaluated, and the best-performing model is set as the default for further use.
+
+        Note: Cross-validation is yet to be implemented.
+
+        Args:
+            x_train_texts (List[str]): A list of training texts.
+            y_train (Union[List[int], List[List[int]]]): A list of labels corresponding to the training texts.
+            x_eval (Optional[List[str]]): A list of evaluation texts.
+            y_eval (Optional[List[int]]): A list of labels corresponding to the evaluation texts.
+            eval_metrics (Optional[Dict[str, Any]]): A dictionary specifying the evaluation metrics and their respective arguments.
+                If not provided, evaluation will be omitted. If ill-formatted, default metrics will be used. Example format:
+                {
+                    'metric_names': ['f1', 'precision', 'recall', 'accuracy'],
+                    'metric_args': {'average': 'micro'}
+                }
+            args (Optional[TrainingArguments]): Training arguments to temporarily override the default training arguments for this call.
+        """
+        logger.info(" ***** Running classification head training *****")
+        args = args or self.args or TrainingArguments()
+        y_train = Trainer._ensure_single_label_format(y_train)  # Necessary for model head training.
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Get embeddings from the setfit body and train the setfit model head.
+        self.model.set_prediction_strategy("setfit")
+        setfit_train_features = self.model.model_body.setfit_model_body.encode(x_train_texts, device=device)
+        self.model.model_head.setfit_model_head.fit(setfit_train_features, y_train)
+
+        # Get embeddings from the label_embedding body and train the label_embedding model head.
+        self.model.set_prediction_strategy("label_embedding")
+        label_embedding_train_features = self.model.model_body.label_embedding_model_body.encode(x_train_texts, device=device)
+        self.model.model_head.label_embedding_model_head.fit(label_embedding_train_features, y_train)
+
+        # Get embeddings from the fusion body and train the fusion model head.
+        self.model.set_prediction_strategy("fusion")
+        fusion_train_features = self.model.model_body.fusion_model_body.encode(x_train_texts, device=device)
+        self.model.model_head.fusion_model_head.fit(fusion_train_features, y_train)
+
+        # Evaluate classification with the different body features and set the best-performing one as the default.
+        eval_dict = {}
+        if x_eval and y_eval and eval_metrics:
+
+            # Use the evaluation dataset to choose the best-performing features.
+            self.model.set_prediction_strategy("setfit")
+            setfit_eval_scores = self.evaluate(x_eval=x_eval, y_eval=y_eval, eval_metrics=eval_metrics)
+            print("SetFit eval scores:", setfit_eval_scores)
+            eval_dict["SetFit eval scores"] = setfit_eval_scores
+            self.model.set_prediction_strategy("label_embedding")
+            label_embedding_eval_scores = self.evaluate(x_eval=x_eval, y_eval=y_eval, eval_metrics=eval_metrics)
+            print("Label embedding eval scores:", label_embedding_eval_scores)
+            eval_dict["Label embedding eval scores"] = label_embedding_eval_scores
+            self.model.set_prediction_strategy("fusion")
+            fusion_eval_scores = self.evaluate(x_eval=x_eval, y_eval=y_eval, eval_metrics=eval_metrics)
+            print("Fusion eval scores:", fusion_eval_scores)
+            eval_dict["Fusion eval scores"] = fusion_eval_scores
+
+            # Save the evaluation dictionary, if a path was provided.
+            if args.json_path is not None:
+                with open(args.json_path + '.json', 'w') as fp:
+                    json.dump(eval_dict, fp)
+
+            # Choose the best-performing model from the average of the evaluation scores, and set it as the default prediction strategy.
+            mean_eval_scores = {}
+            mean_eval_scores['fusion'] = np.mean(list(fusion_eval_scores.values()))
+            mean_eval_scores['label_embedding'] = np.mean(list(label_embedding_eval_scores.values()))
+            mean_eval_scores['setfit'] = np.mean(list(setfit_eval_scores.values()))
+            self.model.set_prediction_strategy(max(mean_eval_scores, key=mean_eval_scores.get))
+
+        else:
+            # TODO: Perform cross-validation to choose the best-performing features.
+            pass
+
+    def evaluate(
+        self,
+        x_eval: List[str],
+        y_eval: Union[List[int], List[List[int]]],
+        eval_metrics: Optional[Dict[str, Any]] = None
+    ):
+        """
+        Evaluates the performance of the full model on a given evaluation dataset.
+        Note that the result depends on the model's current prediction_strategy (i.e. which encoder body it uses for inference).
+
+        Args:
+            x_eval (List[str]): A list of evaluation texts.
+            y_eval (Union[List[int], List[List[int]]]): A list of labels corresponding to the evaluation texts.
+            eval_metrics (Optional[Dict[str, Any]]): A dictionary specifying the evaluation metrics and their respective arguments, to temporarily override the default (if any).
+                If not provided or ill-formatted, default metrics will be used. Example format:
+                {
+                    'metric_names': ['f1', 'precision', 'recall', 'accuracy'],
+                    'metric_args': {'average': 'micro'}
+                }
+
+        Returns:
+            Dict[str, float]: A dictionary containing the computed scores for each specified metric.
+        """
+
+        # If no eval_metrics were given, use the configured ones.
+        if not eval_metrics:
+            eval_metrics = self.eval_metrics
+
+        # Validate eval_metrics and fall back to the default if ill-formatted.
+        try:
+            self._validate_eval_metrics(eval_metrics)
+        except ValueError as e:
+            if "eval_metrics" in str(e):
+                eval_metrics = self.DEFAULT_EVAL_METRICS
+                warnings.warn(
+                    "'eval_metrics' provided were ill-formatted. Falling back to default metrics.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+
+        # Perform inference on the evaluation dataset.
+        y_pred = self.model.predict(x_eval)
+        y_true = Trainer._ensure_single_label_format(y_eval)
+
+        # Correctly format eval_metrics if only a single metric name was given.
+        if isinstance(eval_metrics['metric_names'], str):
+            eval_metrics['metric_names'] = [eval_metrics['metric_names']]
+
+        # Perform the evaluation.
+        eval_scores = {}
+        for metric in eval_metrics['metric_names']:
+            if metric == "f1":
+                eval_scores["f1"] = f1_score(y_true, y_pred, average=eval_metrics['metric_args']['average'])
+            elif metric == "precision":
+                eval_scores["precision"] = precision_score(y_true, y_pred, average=eval_metrics['metric_args']['average'])
+            elif metric == "recall":
+                eval_scores["recall"] = recall_score(y_true, y_pred, average=eval_metrics['metric_args']['average'])
+            elif metric == "accuracy":
+                eval_scores["accuracy"] = accuracy_score(y_true, y_pred)
+
+        # Check whether evaluation scores are present. Note: this must always be the case after a successful validation (so in theory, this exception should never be raised).
+        if not eval_scores:
+            raise ValueError(
+                "eval_metrics did not contain at least one of the following valid evaluation arguments: `f1`, `precision`, `recall`, `accuracy`."
+            )
+
+        return eval_scores
\ No newline at end of file
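Since `evaluate` respects the model's current prediction strategy, it can also be called manually to compare the encoder bodies. A hedged sketch continuing the earlier Trainer example, where `eval_ds` is an assumed evaluation split shaped like `train_ds`:

```python
for strategy in ("setfit", "label_embedding", "fusion"):
    trainer.model.set_prediction_strategy(strategy)
    scores = trainer.evaluate(
        x_eval=eval_ds["text"],
        y_eval=eval_ds["label"],
        eval_metrics={"metric_names": ["f1"], "metric_args": {"average": "macro"}},
    )
    print(strategy, scores)
```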
diff --git a/fusionsent/training_args.py b/fusionsent/training_args.py
new file mode 100644
index 0000000..ae109d2
--- /dev/null
+++ b/fusionsent/training_args.py
@@ -0,0 +1,259 @@
+# This module encapsulates the set of training arguments that can be passed to the FusionSent model to specify training.
+
+from typing import Callable, Optional, Tuple, Union
+from dataclasses import dataclass
+from sentence_transformers import losses
+import warnings
+
+@dataclass
+class TrainingArguments:
+    """
+    A dataclass containing all the arguments that can be passed to the FusionSent model to specify training.
+    Pass these either at model initialisation (to be the same for all training runs), or specifically, when calling the training method.
+
+    FusionSent trains two distinct sub-models, 'setfit' and 'label_embedding', whose parameters are then fused.
+    For customization purposes, the majority of training arguments can hence be given as a tuple, in which the first and second components are destined for the 'setfit' and 'label_embedding' sub-model, respectively.
+    If only a single value is provided, it will be used for both sub-models.
+
+    After instantiation, each sub-model's specific training arguments are referenceable through custom properties of this class.
+    Example:
+        batch_sizes[0] is addressed to 'setfit', accessible as property 'TrainingArguments.setfit_batch_size'.
+        batch_sizes[1] is addressed to 'label_embedding', accessible as property 'TrainingArguments.label_embedding_batch_size'.
+
+    Attributes:
+        batch_sizes (Optional[Union[int, Tuple[int, int]]]): Batch sizes for training. A single integer for both sub-models, or a tuple, to address each one individually. Default is (16, 1).
+        num_epochs (Optional[Union[int, Tuple[int, int]]]): Number of epochs for training. A single integer for both sub-models, or a tuple, to address each one individually. Default is (1, 3).
+        sampling_strategies (Optional[Union[str, Tuple[str, str]]]): Sampling strategies for the training data. A single string for both sub-models, or a tuple, to address each one individually. Choose either "oversampling" (default), "unique", or "undersampling". See 'setfit.ContrastiveDataset' for more details.
+        num_iterations (Optional[int]): Number of iterations for training. Always the same for both sub-models.
+        distance_metrics (Optional[Union[Callable, Tuple[Callable, Callable]]]): Distance metrics for the loss functions. A single 'Callable' for both sub-models, or a tuple, to address each one individually. Default is cosine distance for triplet loss.
+        losses (Optional[Union[Callable, Tuple[Callable, Callable]]]): Loss functions for training. A single 'Callable' for both sub-models, or a tuple, to address each one individually. Default is (CosineSimilarityLoss, ContrastiveLoss).
+        merging_method (Optional[str]): Method for merging the parameters of both sub-models after training. Choose either 'slerp' (default) or 'lerp'.
+        margins (Optional[Union[float, Tuple[float, float]]]): Margin values for the loss functions, determining the threshold for considering examples as similar or dissimilar. A single float for both sub-models, or a tuple, to address each one individually. Default is 0.25.
+        warmup_proportions (Optional[Union[float, Tuple[float, float]]]): Proportion of the total training steps used for warming up the learning rates. A single float for both sub-models, or a tuple, to address each one individually. Default is 0.1.
+        samples_per_label (Optional[Union[int, Tuple[int, int]]]): Number of samples per label for training. A single integer for both sub-models, or a tuple, to address each one individually. Default is 2.
+        show_progress_bar (Optional[bool]): Whether to show a progress bar during training. Default is True.
+        use_setfit_body (Optional[bool]): Whether to train the 'setfit' sub-model and use its parameters in the merged FusionSent model. Set this to False when you only want to evaluate the 'label_embedding' sub-model. Default is True.
+        json_path (Optional[str]): Path to save evaluation results as JSON.
+    """
+
+    batch_sizes: Optional[Union[int, Tuple[int, int]]] = (16, 1)
+    num_epochs: Optional[Union[int, Tuple[int, int]]] = (1, 3)
+    sampling_strategies: Optional[Union[str, Tuple[str, str]]] = "oversampling"
+    num_iterations: Optional[int] = None
+    distance_metrics: Optional[Union[Callable, Tuple[Callable, Callable]]] = losses.BatchHardTripletLossDistanceFunction.cosine_distance
+    losses: Optional[Union[Callable, Tuple[Callable, Callable]]] = (losses.CosineSimilarityLoss, losses.ContrastiveLoss)
+    merging_method: Optional[str] = 'slerp'
+    margins: Optional[Union[float, Tuple[float, float]]] = 0.25
+    warmup_proportions: Optional[Union[float, Tuple[float, float]]] = 0.1
+    samples_per_label: Optional[Union[int, Tuple[int, int]]] = 2
+    show_progress_bar: Optional[bool] = True
+    use_setfit_body: Optional[bool] = True
+    json_path: Optional[str] = None
+
+    @property
+    def setfit_batch_size(self) -> int:
+        """
+        Batch size for training the 'setfit' sub-model.
+        """
+        if isinstance(self.batch_sizes, int):
+            return self.batch_sizes
+        else:
+            return self.batch_sizes[0]
+
+    @property
+    def label_embedding_batch_size(self) -> int:
+        """
+        Batch size for training the 'label_embedding' sub-model.
+        """
+        if isinstance(self.batch_sizes, int):
+            return self.batch_sizes
+        else:
+            return self.batch_sizes[1]
+
+    @property
+    def setfit_num_epochs(self) -> int:
+        """
+        Number of epochs for training the 'setfit' sub-model.
+        """
+        if isinstance(self.num_epochs, int):
+            return self.num_epochs
+        else:
+            return self.num_epochs[0]
+
+    @property
+    def label_embedding_num_epochs(self) -> int:
+        """
+        Number of epochs for training the 'label_embedding' sub-model.
+        """
+        if isinstance(self.num_epochs, int):
+            return self.num_epochs
+        else:
+            return self.num_epochs[1]
+
+    @property
+    def setfit_sampling_strategy(self) -> str:
+        """
+        Sampling strategy for the training data of the 'setfit' sub-model.
+        Either "oversampling" (default), "unique", or "undersampling".
+        See 'setfit.ContrastiveDataset' for more details.
+        """
+        if isinstance(self.sampling_strategies, str):
+            return self.sampling_strategies
+        else:
+            return self.sampling_strategies[0]
+
+    @property
+    def label_embedding_sampling_strategy(self) -> str:
+        """
+        Sampling strategy for the training data of the 'label_embedding' sub-model.
+        Either "oversampling" (default), "unique", or "undersampling".
+        See 'setfit.ContrastiveDataset' for more details.
+        """
+        if isinstance(self.sampling_strategies, str):
+            return self.sampling_strategies
+        else:
+            return self.sampling_strategies[1]
+
+    @property
+    def setfit_distance_metric(self) -> Callable:
+        """
+        Distance metric for the loss function of the 'setfit' sub-model.
+        """
+        if isinstance(self.distance_metrics, Callable):
+            return self.distance_metrics
+        else:
+            return self.distance_metrics[0]
+
+    @property
+    def label_embedding_distance_metric(self) -> Callable:
+        """
+        Distance metric for the loss function of the 'label_embedding' sub-model.
+        """
+        if isinstance(self.distance_metrics, Callable):
+            return self.distance_metrics
+        else:
+            return self.distance_metrics[1]
+ """ + if isinstance(self.losses, Callable): + return self.losses + else: + return self.losses[0] + + @property + def label_embedding_loss(self) -> Callable: + """ + Loss function for training the 'label_embedding' sub-model. + """ + if isinstance(self.losses, Callable): + return self.losses + else: + return self.losses[1] + + @property + def setfit_margin(self) -> float: + """ + Margin values for the loss function of the 'setfit' sub-model. + This determines the threshold for considering examples as similar or dissimilar. + """ + if isinstance(self.margins, float): + return self.margins + else: + return self.margins[0] + + @property + def label_embedding_margin(self) -> float: + """ + Margin values for the loss function of the 'label_embedding' sub-model. + This determines the threshold for considering examples as similar or dissimilar. + """ + if isinstance(self.margins, float): + return self.margins + else: + return self.margins[1] + + @property + def setfit_warmup_proportion(self) -> float: + """ + Proportion of the total training steps used for warming up the learning rate for training the 'setfit' sub-model. + """ + if isinstance(self.warmup_proportions, float): + return self.warmup_proportions + else: + return self.warmup_proportions[0] + + @property + def label_embedding_warmup_proportion(self) -> float: + """ + Proportion of the total training steps used for warming up the learning rate for training the 'label_embedding' sub-model. + """ + if isinstance(self.warmup_proportions, float): + return self.warmup_proportions + else: + return self.warmup_proportions[1] + + @property + def setfit_samples_per_label(self) -> int: + """ + Number of samples per label for training the 'setfit' submodel. + """ + if isinstance(self.samples_per_label, int): + return self.samples_per_label + else: + return self.samples_per_label[0] + + @property + def label_embedding_samples_per_label(self) -> int: + """ + Number of samples per label for training the 'label_embedding' submodel. + """ + if isinstance(self.samples_per_label, int): + return self.samples_per_label + else: + return self.samples_per_label[1] + + def _validate(self): + """ + Validates the provided training arguments to ensure they are in the correct format and contain necessary values. + Raises warnings for missing optional arguments and exceptions for missing non-optional ones. + """ + # Optional warning for missing json_path + if self.json_path is None: + warnings.warn( + f"`{self.__class__.__name__}.train` did not receive a `json_path`." + f"Evaluation results will not be saved to file." 
+ f"Please provide a `json_path` to the `TrainingArguments` instance to suppress this warning.", + UserWarning, + stacklevel=2, + ) + + # Validate required fields by accessing properties and catching errors + required_properties = [ + ('setfit_batch_size', int), ('label_embedding_batch_size', int), + ('setfit_num_epochs', int), ('label_embedding_num_epochs', int), + ('setfit_sampling_strategy', str), ('label_embedding_sampling_strategy', str), + ('setfit_distance_metric', Callable), ('label_embedding_distance_metric', Callable), + ('setfit_loss', Callable), ('label_embedding_loss', Callable), + ('setfit_margin', float), ('label_embedding_margin', float), + ('setfit_warmup_proportion', float), ('label_embedding_warmup_proportion', float), + ('setfit_samples_per_label', int), ('label_embedding_samples_per_label', int) + ] + for prop, expected_type in required_properties: + try: + value = getattr(self, prop) + if not isinstance(value, expected_type): + raise TypeError(f"Expected type {expected_type} for {prop}, but got {type(value)}.") + except Exception as e: + raise ValueError(f"Invalid value for {prop}: {str(e)}") + + # Check for valid values in sampling_strategies + valid_sampling_strategies = {"oversampling", "unique", "undersampling"} + sampling_strategy_props = ['setfit_sampling_strategy', 'label_embedding_sampling_strategy'] + for strategy_prop in sampling_strategy_props: + value = getattr(self, strategy_prop) + if value not in valid_sampling_strategies: + raise ValueError(f"Invalid value '{value}' for '{prop}'. Must be one of {valid_sampling_strategies}.") \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..fbcce5e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools", "wheel", "numpy"] +build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..56a6798 --- /dev/null +++ b/setup.py @@ -0,0 +1,78 @@ +from setuptools import setup, find_packages + +# Read the contents of README file for 'long_description'. 
+with open("README.md", "r") as fh: + _long_description = fh.read() + +setup( + name="fusionsent", + version="0.0.8", + author="Tim Schopf, Alexander Blatzheim", + author_email="tim.schopf@tum.de, alexander.blatzheim@tum.de", + description="FusionSent: A Fusion-Based Multi-Task Sentence Embedding Model", + long_description=_long_description, + long_description_content_type="text/markdown", + url="https://github.com/sebischair/FusionSent", + packages=find_packages(), + install_requires= [ + "accelerate==0.32.1", + "aiohappyeyeballs==2.4.3", + "aiohttp==3.9.5", + "aiosignal==1.3.1", + "async-timeout==4.0.3", + "attrs==23.2.0", + "certifi==2024.8.30", + "charset-normalizer==3.3.2", + "datasets==2.20.0", + "dill==0.3.5.1", + "evaluate==0.4.2", + "filelock==3.15.4", + "frozenlist==1.4.1", + "fsspec==2023.10.0", + "huggingface-hub==0.21.2", + "idna==3.10", + "Jinja2==3.1.4", + "joblib==1.4.2", + "MarkupSafe==2.1.5", + "mpmath==1.3.0", + "multidict==6.0.5", + "multiprocess==0.70.13", + "networkx==3.3", + "numpy==1.23.5", + "packaging==24.1", + "pandas==2.2.2", + "pillow==10.4.0", + "psutil==6.0.0", + "pyarrow==15.0.0", + "python-dateutil==2.9.0.post0", + "pytz==2024.2", + "PyYAML==6.0.1", + "regex==2024.5.15", + "requests==2.32.3", + "safetensors==0.4.3", + "scikit-learn==1.5.1", + "scipy==1.10.1", + "sentence-transformers==3.0.1", + "setfit==1.0.3", + "six==1.16.0", + "sympy==1.13.3", + "threadpoolctl==3.5.0", + "tokenizers==0.19.1", + "torch==2.3.1", + "tqdm==4.66.4", + "transformers==4.40.0", + "typing_extensions==4.12.2", + "tzdata==2024.1", + "urllib3==2.2.2", + "xxhash==3.4.1", + "yarl==1.9.4" + ], + classifiers=[ + "Programming Language :: Python :: 3.10", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + ], + python_requires='>=3.10', + license="Apache-2.0", + include_package_data=True +) \ No newline at end of file